Source code for mridc.collections.reconstruction.models.xpdnet

# coding=utf-8
__author__ = "Dimitrios Karkalousos"

from abc import ABC

import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from torch.nn import L1Loss

from mridc.collections.common.losses.ssim import SSIMLoss
from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel
from mridc.collections.reconstruction.models.conv.conv2d import Conv2d
from mridc.collections.reconstruction.models.crossdomain.crossdomain import CrossDomainNetwork
from mridc.collections.reconstruction.models.crossdomain.multicoil import MultiCoil
from mridc.collections.reconstruction.models.didn.didn import DIDN
from mridc.collections.reconstruction.models.mwcnn.mwcnn import MWCNN
from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
from mridc.core.classes.common import typecheck

__all__ = ["XPDNet"]


[docs]class XPDNet(BaseMRIReconstructionModel, ABC): """ Implementation of the XPDNet, as presented in Ramzi, Zaccharie, et al. References ---------- .. Ramzi, Zaccharie, et al. “XPDNet for MRI Reconstruction: An Application to the 2020 FastMRI Challenge. \ ” ArXiv:2010.07290 [Physics, Stat], July 2021. arXiv.org, http://arxiv.org/abs/2010.07290. """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # init superclass super().__init__(cfg=cfg, trainer=trainer) cfg_dict = OmegaConf.to_container(cfg, resolve=True) num_primal = cfg_dict.get("num_primal") num_dual = cfg_dict.get("num_dual") num_iter = cfg_dict.get("num_iter") kspace_model_architecture = cfg_dict.get("kspace_model_architecture") dual_conv_hidden_channels = cfg_dict.get("dual_conv_hidden_channels") dual_conv_num_dubs = cfg_dict.get("dual_conv_num_dubs") dual_conv_batchnorm = cfg_dict.get("dual_conv_batchnorm") dual_didn_hidden_channels = cfg_dict.get("dual_didn_hidden_channels") dual_didn_num_dubs = cfg_dict.get("dual_didn_num_dubs") dual_didn_num_convs_recon = cfg_dict.get("dual_didn_num_convs_recon") if cfg_dict.get("use_primal_only"): kspace_model_list = None num_dual = 1 elif kspace_model_architecture == "CONV": kspace_model_list = torch.nn.ModuleList( [ MultiCoil( Conv2d( 2 * (num_dual + num_primal + 1), 2 * num_dual, dual_conv_hidden_channels, dual_conv_num_dubs, batchnorm=dual_conv_batchnorm, ) ) for _ in range(num_iter) ] ) elif kspace_model_architecture == "DIDN": kspace_model_list = torch.nn.ModuleList( [ MultiCoil( DIDN( in_channels=2 * (num_dual + num_primal + 1), out_channels=2 * num_dual, hidden_channels=dual_didn_hidden_channels, num_dubs=dual_didn_num_dubs, num_convs_recon=dual_didn_num_convs_recon, ) ) for _ in range(num_iter) ] ) elif kspace_model_architecture in ["UNET", "NORMUNET"]: kspace_model_list = torch.nn.ModuleList( [ MultiCoil( NormUnet( cfg_dict.get("kspace_unet_num_filters"), cfg_dict.get("kspace_unet_num_pool_layers"), in_chans=2 * (num_dual + num_primal + 1), out_chans=2 * num_dual, drop_prob=cfg_dict.get("kspace_unet_dropout_probability"), padding_size=cfg_dict.get("kspace_unet_padding_size"), normalize=cfg_dict.get("kspace_unet_normalize"), ), coil_to_batch=True, ) for _ in range(num_iter) ] ) else: raise NotImplementedError( "XPDNet is currently implemented for kspace_model_architecture == 'CONV' or 'DIDN'." f"Got kspace_model_architecture == {kspace_model_architecture}." ) image_model_architecture = cfg_dict.get("image_model_architecture") mwcnn_hidden_channels = cfg_dict.get("mwcnn_hidden_channels") mwcnn_num_scales = cfg_dict.get("mwcnn_num_scales") mwcnn_bias = cfg_dict.get("mwcnn_bias") mwcnn_batchnorm = cfg_dict.get("mwcnn_batchnorm") if image_model_architecture == "MWCNN": image_model_list = torch.nn.ModuleList( [ torch.nn.Sequential( MWCNN( input_channels=2 * (num_primal + num_dual), first_conv_hidden_channels=mwcnn_hidden_channels, num_scales=mwcnn_num_scales, bias=mwcnn_bias, batchnorm=mwcnn_batchnorm, ), torch.nn.Conv2d(2 * (num_primal + num_dual), 2 * num_primal, kernel_size=3, padding=1), ) for _ in range(num_iter) ] ) elif image_model_architecture in ["UNET", "NORMUNET"]: image_model_list = torch.nn.ModuleList( [ NormUnet( cfg_dict.get("imspace_unet_num_filters"), cfg_dict.get("imspace_unet_num_pool_layers"), in_chans=2 * (num_primal + num_dual), out_chans=2 * num_primal, drop_prob=cfg_dict.get("imspace_unet_dropout_probability"), padding_size=cfg_dict.get("imspace_unet_padding_size"), normalize=cfg_dict.get("imspace_unet_normalize"), ) for _ in range(num_iter) ] ) else: raise NotImplementedError(f"Image model architecture {image_model_architecture} not found for XPDNet.") self.fft_normalization = cfg_dict.get("fft_normalization") self.spatial_dims = cfg_dict.get("spatial_dims") self.coil_dim = cfg_dict.get("coil_dim") self.num_cascades = cfg_dict.get("num_cascades") self.xpdnet = CrossDomainNetwork( image_model_list=image_model_list, kspace_model_list=kspace_model_list, domain_sequence="KI" * num_iter, image_buffer_size=num_primal, kspace_buffer_size=num_dual, normalize_image=cfg_dict.get("normalize_image"), fft_centered=self.fft_centered, fft_normalization=self.fft_normalization, spatial_dims=self.spatial_dims, coil_dim=self.coil_dim, ) self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss() self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss() self.accumulate_estimates = False
[docs] @typecheck() def forward( self, y: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor, init_pred: torch.Tensor, target: torch.Tensor, ) -> torch.Tensor: """ Forward pass of the network. Parameters ---------- y: Subsampled k-space data. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] sensitivity_maps: Coil sensitivity maps. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2] mask: Sampling mask. torch.Tensor, shape [1, 1, n_x, n_y, 1] init_pred: Initial prediction. torch.Tensor, shape [batch_size, n_x, n_y, 2] target: Target data to compute the loss. torch.Tensor, shape [batch_size, n_x, n_y, 2] Returns ------- pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2] If self.accumulate_loss is True, returns a list of all intermediate estimates. If False, returns the final estimate. """ eta = self.xpdnet(y, sensitivity_maps, mask) eta = (eta**2).sqrt().sum(-1) _, eta = center_crop_to_smallest(target, eta) return eta