# Source code for mridc.collections.reconstruction.models.lpd

# coding=utf-8
__author__ = "Dimitrios Karkalousos"

from abc import ABC

import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from torch.nn import L1Loss

from mridc.collections.common.losses.ssim import SSIMLoss
from mridc.collections.common.parts.fft import fft2, ifft2
from mridc.collections.common.parts.utils import complex_conj, complex_mul
from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel
from mridc.collections.reconstruction.models.conv.conv2d import Conv2d
from mridc.collections.reconstruction.models.didn.didn import DIDN
from mridc.collections.reconstruction.models.mwcnn.mwcnn import MWCNN
from mridc.collections.reconstruction.models.primaldual.pd import DualNet, PrimalNet
from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
from mridc.core.classes.common import typecheck

__all__ = ["LPDNet"]


class LPDNet(BaseMRIReconstructionModel, ABC):
    """
    Implementation of the Learned Primal Dual network, inspired by Adler, Jonas, and Ozan Öktem.

    The network unrolls ``num_iter`` iterations. Each iteration first updates a dual (k-space) buffer holding
    ``num_dual`` stacked complex channels with a :class:`DualNet`, then updates a primal (image) buffer holding
    ``num_primal`` stacked complex channels with a :class:`PrimalNet`. The data terms exchanged between the two are
    computed with coil sensitivity weighting, a 2D FFT and the sampling mask.

    References
    ----------

    .. Adler, Jonas, and Ozan Öktem. “Learned Primal-Dual Reconstruction.” IEEE Transactions on Medical Imaging, \
        vol. 37, no. 6, June 2018, pp. 1322–32. arXiv.org, https://doi.org/10.1109/TMI.2018.2799231.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Build the unrolled primal/dual networks from the model configuration.

        Parameters
        ----------
        cfg: Model configuration. Besides whatever the base class consumes, reads ``num_iter``, ``num_primal``,
            ``num_dual``, ``primal_model_architecture`` ('MWCNN', 'UNET' or 'NORMUNET'), ``dual_model_architecture``
            ('CONV', 'DIDN', 'UNET' or 'NORMUNET'), the architecture-specific hyper-parameters, ``fft_centered``,
            ``fft_normalization``, ``spatial_dims``, ``coil_dim``, ``train_loss_fn`` and ``eval_loss_fn``.
        trainer: PyTorch Lightning trainer, forwarded to the base class.

        Raises
        ------
        NotImplementedError
            If ``primal_model_architecture`` or ``dual_model_architecture`` is not one of the supported values
            (the comparison is case-sensitive).
        """
        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        cfg_dict = OmegaConf.to_container(cfg, resolve=True)

        self.num_iter = cfg_dict.get("num_iter")
        self.num_primal = cfg_dict.get("num_primal")
        self.num_dual = cfg_dict.get("num_dual")

        primal_model = self._build_primal_model(cfg_dict)
        dual_model = self._build_dual_model(cfg_dict)

        # NOTE(review): the *same* primal_model / dual_model instance is passed to every PrimalNet / DualNet, so the
        # image- and k-space networks appear to share weights across all unrolled iterations. Left unchanged to keep
        # the parameter layout (and existing checkpoints) intact — confirm this sharing is intended.
        self.primal_net = torch.nn.ModuleList(
            [PrimalNet(self.num_primal, primal_architecture=primal_model) for _ in range(self.num_iter)]
        )
        self.dual_net = torch.nn.ModuleList(
            [DualNet(self.num_dual, dual_architecture=dual_model) for _ in range(self.num_iter)]
        )

        self.fft_centered = cfg_dict.get("fft_centered")
        self.fft_normalization = cfg_dict.get("fft_normalization")
        self.spatial_dims = cfg_dict.get("spatial_dims")
        self.coil_dim = cfg_dict.get("coil_dim")

        # Any value other than "ssim" (including a missing key) selects the L1 loss.
        self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss()
        self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss()

        # forward() only ever returns the final estimate.
        self.accumulate_estimates = False

    def _build_primal_model(self, cfg_dict: dict) -> torch.nn.Module:
        """
        Instantiate the image-domain network passed to every PrimalNet.

        It maps ``2 * (num_primal + 1)`` input channels to ``2 * num_primal`` output channels (presumably the primal
        buffer concatenated with one complex back-projected dual term inside PrimalNet — TODO confirm).

        Raises
        ------
        NotImplementedError
            If ``primal_model_architecture`` is not 'MWCNN', 'UNET' or 'NORMUNET'.
        """
        primal_model_architecture = cfg_dict.get("primal_model_architecture")

        if primal_model_architecture == "MWCNN":
            return torch.nn.Sequential(
                *[
                    MWCNN(
                        input_channels=2 * (self.num_primal + 1),
                        first_conv_hidden_channels=cfg_dict.get("primal_mwcnn_hidden_channels"),
                        num_scales=cfg_dict.get("primal_mwcnn_num_scales"),
                        bias=cfg_dict.get("primal_mwcnn_bias"),
                        batchnorm=cfg_dict.get("primal_mwcnn_batchnorm"),
                    ),
                    # 1x1 conv reduces 2 * (num_primal + 1) channels to 2 * num_primal
                    # (assumes MWCNN keeps its input channel count — TODO confirm).
                    torch.nn.Conv2d(2 * (self.num_primal + 1), 2 * self.num_primal, kernel_size=1),
                ]
            )

        if primal_model_architecture in ["UNET", "NORMUNET"]:
            return NormUnet(
                cfg_dict.get("primal_unet_num_filters"),
                cfg_dict.get("primal_unet_num_pool_layers"),
                in_chans=2 * (self.num_primal + 1),
                out_chans=2 * self.num_primal,
                drop_prob=cfg_dict.get("primal_unet_dropout_probability"),
                padding_size=cfg_dict.get("primal_unet_padding_size"),
                normalize=cfg_dict.get("primal_unet_normalize"),
            )

        raise NotImplementedError(
            "LPDNet is currently implemented for primal_model_architecture == 'MWCNN', 'UNET' or 'NORMUNET'. "
            f"Got primal_model_architecture == {primal_model_architecture}."
        )

    def _build_dual_model(self, cfg_dict: dict) -> torch.nn.Module:
        """
        Instantiate the k-space network passed to every DualNet.

        It maps ``2 * (num_dual + 2)`` input channels to ``2 * num_dual`` output channels (presumably the dual buffer
        concatenated with the forward-projected primal term and the measured k-space inside DualNet — TODO confirm).

        Raises
        ------
        NotImplementedError
            If ``dual_model_architecture`` is not 'CONV', 'DIDN', 'UNET' or 'NORMUNET'.
        """
        dual_model_architecture = cfg_dict.get("dual_model_architecture")

        if dual_model_architecture == "CONV":
            return Conv2d(
                in_channels=2 * (self.num_dual + 2),
                out_channels=2 * self.num_dual,
                hidden_channels=cfg_dict.get("kspace_conv_hidden_channels"),
                n_convs=cfg_dict.get("kspace_conv_n_convs"),
                batchnorm=cfg_dict.get("kspace_conv_batchnorm"),
            )

        if dual_model_architecture == "DIDN":
            return DIDN(
                in_channels=2 * (self.num_dual + 2),
                out_channels=2 * self.num_dual,
                hidden_channels=cfg_dict.get("kspace_didn_hidden_channels"),
                num_dubs=cfg_dict.get("kspace_didn_num_dubs"),
                num_convs_recon=cfg_dict.get("kspace_didn_num_convs_recon"),
            )

        if dual_model_architecture in ["UNET", "NORMUNET"]:
            return NormUnet(
                cfg_dict.get("dual_unet_num_filters"),
                cfg_dict.get("dual_unet_num_pool_layers"),
                in_chans=2 * (self.num_dual + 2),
                out_chans=2 * self.num_dual,
                drop_prob=cfg_dict.get("dual_unet_dropout_probability"),
                padding_size=cfg_dict.get("dual_unet_padding_size"),
                normalize=cfg_dict.get("dual_unet_normalize"),
            )

        raise NotImplementedError(
            "LPDNet is currently implemented for dual_model_architecture == 'CONV', 'DIDN', 'UNET' or 'NORMUNET'. "
            f"Got dual_model_architecture == {dual_model_architecture}."
        )

    def _masked_kspace_to_image(
        self, kspace: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Zero k-space entries where ``mask == 0``, apply the inverse FFT, weight each coil image with the conjugate
        sensitivity map and sum over ``self.coil_dim``.
        """
        masked_kspace = torch.where(mask == 0, torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device), kspace)
        return complex_mul(
            ifft2(
                masked_kspace,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ),
            complex_conj(sensitivity_maps),
        ).sum(self.coil_dim)

    def _image_to_masked_kspace(
        self, image: torch.Tensor, sensitivity_maps: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Expand an image along ``self.coil_dim``, multiply it by the sensitivity maps, apply the FFT (cast back to the
        image's tensor type) and zero the entries where ``mask == 0``.
        """
        kspace = fft2(
            complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        ).type(image.type())
        return torch.where(mask == 0, torch.tensor([0.0], dtype=image.dtype).to(image.device), kspace)

    @typecheck()
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction. Not used by this model.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data; only used to crop the output via ``center_crop_to_smallest``.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: torch.Tensor, shape [batch_size, n_x, n_y]
            Magnitude of the first complex channel of the final primal buffer (``sum(-1)`` removes the real/imaginary
            axis). Intermediate estimates are never returned.
        """
        input_image = self._masked_kspace_to_image(y, sensitivity_maps, mask)
        dual_buffer = torch.cat([y] * self.num_dual, -1).to(y.device)
        primal_buffer = torch.cat([input_image] * self.num_primal, -1).to(y.device)

        for idx in range(self.num_iter):
            # Dual update: forward-project the second complex channel of the primal buffer.
            # NOTE(review): this slice is empty when num_primal == 1, so num_primal >= 2 is assumed.
            f_2 = self._image_to_masked_kspace(primal_buffer[..., 2:4].clone(), sensitivity_maps, mask)
            dual_buffer = self.dual_net[idx](dual_buffer, f_2, y)

            # Primal update: back-project the first complex channel of the dual buffer.
            h_1 = dual_buffer[..., 0:2].clone()
            h_1 = torch.view_as_real(h_1[..., 0] + 1j * h_1[..., 1])  # needed for python3.9
            h_1 = self._masked_kspace_to_image(h_1, sensitivity_maps, mask)
            primal_buffer = self.primal_net[idx](primal_buffer, h_1)

        output = primal_buffer[..., 0:2]
        output = (output**2).sum(-1).sqrt()
        _, output = center_crop_to_smallest(target, output)
        return output