Source code for mridc.collections.reconstruction.models.vsnet

# coding=utf-8
__author__ = "Dimitrios Karkalousos"

from abc import ABC

import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from torch.nn import L1Loss

from mridc.collections.common.losses.ssim import SSIMLoss
from mridc.collections.common.parts.fft import ifft2
from mridc.collections.common.parts.utils import coil_combination
from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel
from mridc.collections.reconstruction.models.conv.conv2d import Conv2d
from mridc.collections.reconstruction.models.mwcnn.mwcnn import MWCNN
from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
from mridc.collections.reconstruction.models.variablesplittingnet.vsnet_block import (
    DataConsistencyLayer,
    VSNetBlock,
    WeightedAverageTerm,
)
from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
from mridc.core.classes.common import typecheck

__all__ = ["VSNet"]


class VSNet(BaseMRIReconstructionModel, ABC):
    """
    Implementation of the Variable-Splitting Net, as presented in Duan, J. et al.

    References
    ----------
    .. Duan, J. et al. (2019) 'Vs-net: Variable splitting network for accelerated parallel MRI reconstruction',
        Lecture Notes in Computer Science (including subseries Lecture Notes in Artificial Intelligence and
        Lecture Notes in Bioinformatics), 11767 LNCS, pp. 713-722. doi: 10.1007/978-3-030-32251-9_78.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Build the VSNet model from its configuration.

        Parameters
        ----------
        cfg: Model configuration. The keys read here are ``num_cascades``, ``fft_normalization``,
            ``spatial_dims``, ``coil_dim``, ``imspace_model_architecture`` (one of ``"CONV"``, ``"MWCNN"``,
            ``"UNET"``, ``"NORMUNET"``) together with the hyper-parameters of the selected architecture,
            ``coil_combination_method``, ``train_loss_fn`` and ``eval_loss_fn``.
        trainer: Trainer instance forwarded unchanged to the superclass.

        Raises
        ------
        NotImplementedError
            If ``imspace_model_architecture`` is not one of the supported architectures.
        """
        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        cfg_dict = OmegaConf.to_container(cfg, resolve=True)

        # Read once; used both as an attribute and to size the cascade lists below.
        num_cascades = cfg_dict.get("num_cascades")
        self.num_cascades = num_cascades

        self.fft_normalization = cfg_dict.get("fft_normalization")
        self.spatial_dims = cfg_dict.get("spatial_dims")
        self.coil_dim = cfg_dict.get("coil_dim")

        image_model_architecture = cfg_dict.get("imspace_model_architecture")
        if image_model_architecture == "CONV":
            image_model = Conv2d(
                in_channels=2,
                out_channels=2,
                hidden_channels=cfg_dict.get("imspace_conv_hidden_channels"),
                n_convs=cfg_dict.get("imspace_conv_n_convs"),
                batchnorm=cfg_dict.get("imspace_conv_batchnorm"),
            )
        elif image_model_architecture == "MWCNN":
            # NOTE: the MWCNN options use the "image_" key prefix, unlike the "imspace_" prefix of the
            # other architectures. The keys are kept as they are so existing configurations keep working.
            image_model = MWCNN(
                input_channels=2,
                first_conv_hidden_channels=cfg_dict.get("image_mwcnn_hidden_channels"),
                num_scales=cfg_dict.get("image_mwcnn_num_scales"),
                bias=cfg_dict.get("image_mwcnn_bias"),
                batchnorm=cfg_dict.get("image_mwcnn_batchnorm"),
            )
        elif image_model_architecture in ["UNET", "NORMUNET"]:
            image_model = NormUnet(
                cfg_dict.get("imspace_unet_num_filters"),
                cfg_dict.get("imspace_unet_num_pool_layers"),
                in_chans=2,
                out_chans=2,
                drop_prob=cfg_dict.get("imspace_unet_dropout_probability"),
                padding_size=cfg_dict.get("imspace_unet_padding_size"),
                normalize=cfg_dict.get("imspace_unet_normalize"),
            )
        else:
            raise NotImplementedError(
                "VSNet is currently implemented only with imspace_model_architecture == "
                "'CONV', 'MWCNN', 'UNET' or 'NORMUNET'. "
                f"Got {image_model_architecture}."
            )

        # NOTE(review): ``[module] * num_cascades`` repeats the *same* module instance, so every cascade
        # shares one set of parameters. This is kept as-is to preserve behaviour; confirm it is intended.
        image_model = torch.nn.ModuleList([image_model] * num_cascades)
        data_consistency_model = torch.nn.ModuleList([DataConsistencyLayer()] * num_cascades)
        weighted_average_model = torch.nn.ModuleList([WeightedAverageTerm()] * num_cascades)

        # ``self.fft_centered`` is not assigned in this class; presumably set by the superclass - confirm.
        self.model = VSNetBlock(
            denoiser_block=image_model,
            data_consistency_block=data_consistency_model,
            weighted_average_block=weighted_average_model,
            num_cascades=num_cascades,
            fft_centered=self.fft_centered,
            fft_normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
            coil_dim=self.coil_dim,
        )

        self.coil_combination_method = cfg_dict.get("coil_combination_method")

        # Any value other than "ssim" (including a missing key) selects the L1 loss.
        self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss()
        self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss()

        self.accumulate_estimates = False

    @typecheck()
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps. Replaced by the output of ``self.sens_net(y, mask)`` when
            ``self.use_sens_net`` is truthy.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction. Not used by this model; accepted for interface compatibility.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data. Only passed to ``center_crop_to_smallest`` together with the prediction.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Returns
        -------
        pred: torch.Tensor
            The final estimate as a complex-valued tensor (result of ``torch.view_as_complex``), after
            ``center_crop_to_smallest`` has been applied with ``target``. Presumably of shape
            [batch_size, n_x, n_y] - confirm against ``coil_combination``.
        """
        sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps

        # The block's output is transformed with an inverse FFT before coil combination.
        image = self.model(y, sensitivity_maps, mask)
        image = torch.view_as_complex(
            coil_combination(
                ifft2(
                    image,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                sensitivity_maps,
                method=self.coil_combination_method,
                dim=self.coil_dim,
            )
        )

        # Only the (possibly cropped) prediction is kept; the cropped target is discarded.
        _, image = center_crop_to_smallest(target, image)
        return image