# coding=utf-8
__author__ = "Dimitrios Karkalousos"
from abc import ABC
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from torch.nn import L1Loss
from mridc.collections.common.losses.ssim import SSIMLoss
from mridc.collections.common.parts.fft import ifft2
from mridc.collections.common.parts.utils import complex_conj, complex_mul
from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel
from mridc.collections.reconstruction.models.didn.didn import DIDN
from mridc.collections.reconstruction.models.sigmanet.dc_layers import (
DataGDLayer,
DataIDLayer,
DataProxCGLayer,
DataVSLayer,
)
from mridc.collections.reconstruction.models.sigmanet.sensitivity_net import SensitivityNetwork
from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
from mridc.core.classes.common import typecheck
__all__ = ["DUNet"]
class DUNet(BaseMRIReconstructionModel, ABC):
"""
Implementation of the Down-Up NET, inspired by Hammernik, K, Schlemper, J, Qin, C, et al.
References
----------
..
Hammernik, K, Schlemper, J, Qin, C, et al. Systematic evaluation of iterative deep neural networks for fast \
parallel MRI reconstruction with sensitivity-weighted coil combination. Magn Reson Med. 2021; 86: 1859-1872. \
https://doi.org/10.1002/mrm.28827
"""
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# init superclass
super().__init__(cfg=cfg, trainer=trainer)
cfg_dict = OmegaConf.to_container(cfg, resolve=True)
self.fft_centered = cfg_dict.get("fft_centered")
self.fft_normalization = cfg_dict.get("fft_normalization")
self.spatial_dims = cfg_dict.get("spatial_dims")
self.coil_dim = cfg_dict.get("coil_dim")
reg_model_architecture = cfg_dict.get("reg_model_architecture")
if reg_model_architecture == "DIDN":
reg_model = DIDN(
in_channels=2,
out_channels=2,
hidden_channels=cfg_dict.get("didn_hidden_channels"),
num_dubs=cfg_dict.get("didn_num_dubs"),
num_convs_recon=cfg_dict.get("didn_num_convs_recon"),
)
elif reg_model_architecture in ["UNET", "NORMUNET"]:
reg_model = NormUnet(
cfg_dict.get("unet_num_filters"),
cfg_dict.get("unet_num_pool_layers"),
in_chans=2,
out_chans=2,
drop_prob=cfg_dict.get("unet_dropout_probability"),
padding_size=cfg_dict.get("unet_padding_size"),
normalize=cfg_dict.get("unet_normalize"),
)
else:
raise NotImplementedError(
"DUNet is currently implemented for reg_model_architecture == 'DIDN', 'UNET' or 'NORMUNET'."
f" Got reg_model_architecture == {reg_model_architecture}."
)
data_consistency_term = cfg_dict.get("data_consistency_term")
if data_consistency_term == "GD":
dc_layer = DataGDLayer(
lambda_init=cfg_dict.get("data_consistency_lambda_init"),
fft_centered=self.fft_centered,
fft_normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
elif data_consistency_term == "PROX":
dc_layer = DataProxCGLayer(
lambda_init=cfg_dict.get("data_consistency_lambda_init"),
fft_centered=self.fft_centered,
fft_normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
elif data_consistency_term == "VS":
dc_layer = DataVSLayer(
alpha_init=cfg_dict.get("data_consistency_alpha_init"),
beta_init=cfg_dict.get("data_consistency_beta_init"),
fft_centered=self.fft_centered,
fft_normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
else:
dc_layer = DataIDLayer()
self.model = SensitivityNetwork(
cfg_dict.get("num_iter"),
reg_model,
dc_layer,
shared_params=cfg_dict.get("shared_params"),
save_space=False,
reset_cache=False,
)
self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss()
self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss()
self.dc_weight = torch.nn.Parameter(torch.ones(1))
self.accumulate_estimates = False
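# A minimal configuration sketch (illustrative values only; the keys mirror the
# `cfg_dict.get(...)` calls above, and any keys required by the base class are
# omitted here):
#
#     cfg = OmegaConf.create(
#         {
#             "fft_centered": True,
#             "fft_normalization": "ortho",
#             "spatial_dims": [-2, -1],
#             "coil_dim": 1,
#             "reg_model_architecture": "DIDN",
#             "didn_hidden_channels": 64,
#             "didn_num_dubs": 2,
#             "didn_num_convs_recon": 1,
#             "data_consistency_term": "VS",
#             "data_consistency_alpha_init": 0.1,
#             "data_consistency_beta_init": 0.1,
#             "num_iter": 10,
#             "shared_params": False,
#             "train_loss_fn": "ssim",
#             "eval_loss_fn": "ssim",
#         }
#     )
#     model = DUNet(cfg, trainer=None)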
@typecheck()
def forward(
self,
y: torch.Tensor,
sensitivity_maps: torch.Tensor,
mask: torch.Tensor,
init_pred: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:
"""
Forward pass of the network.
Parameters
----------
y: Subsampled k-space data.
torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
sensitivity_maps: Coil sensitivity maps.
torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
mask: Sampling mask.
torch.Tensor, shape [1, 1, n_x, n_y, 1]
init_pred: Initial prediction.
torch.Tensor, shape [batch_size, n_x, n_y, 2]
target: Target data to compute the loss.
torch.Tensor, shape [batch_size, n_x, n_y, 2]
Returns
-------
pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2]
If self.accumulate_estimates is True, returns a list of all intermediate estimates.
If False, returns the final estimate.
"""
init_pred = torch.sum(
complex_mul(
ifft2(
y, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims
),
complex_conj(sensitivity_maps),
),
self.coil_dim,
)
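# Unrolled reconstruction: alternate regularization and data consistency, then
# combine coils with the conjugate sensitivities, convert to a complex tensor,
# and crop to the target size for loss computation.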
image = self.model(init_pred, y, sensitivity_maps, mask)
image = torch.sum(complex_mul(image, complex_conj(sensitivity_maps)), self.coil_dim)
image = torch.view_as_complex(image)
_, image = center_crop_to_smallest(target, image)
return image
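# A usage sketch for the forward pass (illustrative shapes only, following the
# docstring above; `cfg` as in the configuration sketch in `__init__`):
#
#     model = DUNet(cfg)
#     y = torch.randn(1, 15, 320, 320, 2)                 # subsampled k-space
#     sensitivity_maps = torch.randn(1, 15, 320, 320, 2)  # coil sensitivities
#     mask = torch.ones(1, 1, 320, 320, 1)                # sampling mask
#     init_pred = torch.zeros(1, 320, 320, 2)             # recomputed internally
#     target = torch.randn(1, 320, 320, 2)                # reference image
#     pred = model.forward(y, sensitivity_maps, mask, init_pred, target)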