# coding=utf-8
__author__ = "Dimitrios Karkalousos"
import math
from abc import ABC
from typing import Optional
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from torch.nn import L1Loss
from mridc.collections.common.losses.ssim import SSIMLoss
from mridc.collections.common.parts.fft import fft2, ifft2
from mridc.collections.common.parts.rnn_utils import rnn_weights_init
from mridc.collections.common.parts.utils import coil_combination, complex_conj, complex_mul
from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel
from mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet import RecurrentInit, RecurrentVarNetBlock
from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
from mridc.core.classes.common import typecheck
# Public API of this module: only the model class is exported.
__all__ = ["RecurrentVarNet"]
class RecurrentVarNet(BaseMRIReconstructionModel, ABC):
    """
    Implementation of the Recurrent Variational Network, as presented in Yiasemis, George, et al.

    The model refines a k-space estimate over ``num_steps`` iterations of :class:`RecurrentVarNetBlock`, optionally
    starting from a hidden state produced by a learned :class:`RecurrentInit` initializer, and returns the
    coil-combined image of the final k-space estimate.

    References
    ----------
    ..

        Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to \
        the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, \
        http://arxiv.org/abs/2111.09639.
    """

    def __init__(self, cfg: DictConfig, trainer: Optional[Trainer] = None):
        """
        Build the recurrent blocks, the optional learned initializer and the loss functions from ``cfg``.

        Parameters
        ----------
        cfg: Model configuration. It is resolved to a plain container and read with ``.get``, so missing keys
            yield ``None`` (``pretrained`` defaults to ``False``).
        trainer: Optional PyTorch Lightning trainer, forwarded to the superclass.

        Raises
        ------
        ValueError
            If a learned initializer is fully configured but ``initializer_initialization`` is not one of
            ``sense``, ``input_image`` or ``zero_filled``.
        """
        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        cfg_dict = OmegaConf.to_container(cfg, resolve=True)

        # Hyper-parameters of the recurrent variational blocks
        self.in_channels = cfg_dict.get("in_channels")
        self.recurrent_hidden_channels = cfg_dict.get("recurrent_hidden_channels")
        self.recurrent_num_layers = cfg_dict.get("recurrent_num_layers")
        self.no_parameter_sharing = cfg_dict.get("no_parameter_sharing")

        # make time-steps size divisible by 8 for fast fp16 training
        # (the configured value is rounded up to the next multiple of 8)
        self.num_steps = 8 * math.ceil(cfg_dict.get("num_steps") / 8)

        # Learned initializer settings
        self.learned_initializer = cfg_dict.get("learned_initializer")
        self.initializer_initialization = cfg_dict.get("initializer_initialization")
        self.initializer_channels = cfg_dict.get("initializer_channels")
        self.initializer_dilations = cfg_dict.get("initializer_dilations")

        # The initializer is only built when it is enabled AND fully specified; otherwise it is silently disabled
        # and the first block receives ``None`` as its previous state.
        if (
            self.learned_initializer
            and self.initializer_initialization is not None
            and self.initializer_channels is not None
            and self.initializer_dilations is not None
        ):
            if self.initializer_initialization not in [
                "sense",
                "input_image",
                "zero_filled",
            ]:
                raise ValueError(
                    "Unknown initializer_initialization. Expected `sense`, `input_image` or `zero_filled`. "
                    f"Got {self.initializer_initialization}."
                )
            self.initializer = RecurrentInit(
                self.in_channels,
                self.recurrent_hidden_channels,
                channels=self.initializer_channels,
                dilations=self.initializer_dilations,
                depth=self.recurrent_num_layers,
                multiscale_depth=cfg_dict.get("initializer_multiscale"),
            )
        else:
            self.initializer = None  # type: ignore

        # FFT / coil handling settings shared by the blocks and the forward pass
        self.fft_centered = cfg_dict.get("fft_centered")
        self.fft_normalization = cfg_dict.get("fft_normalization")
        self.spatial_dims = cfg_dict.get("spatial_dims")
        self.coil_dim = cfg_dict.get("coil_dim")
        self.coil_combination_method = cfg_dict.get("coil_combination_method")

        # One block per time-step when parameters are not shared, otherwise a single block reused at every step.
        self.block_list: torch.nn.Module = torch.nn.ModuleList()
        for _ in range(self.num_steps if self.no_parameter_sharing else 1):
            self.block_list.append(
                RecurrentVarNetBlock(
                    in_channels=self.in_channels,
                    hidden_channels=self.recurrent_hidden_channels,
                    num_layers=self.recurrent_num_layers,
                    fft_centered=self.fft_centered,
                    fft_normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                    coil_dim=self.coil_dim,
                )
            )

        std_init_range = 1 / self.recurrent_hidden_channels**0.5

        # Apply the custom weight initialization unless the config marks the model as pretrained.
        if not cfg_dict.get("pretrained", False):
            self.block_list.apply(lambda module: rnn_weights_init(module, std_init_range))

        # Any value other than "ssim" falls back to the L1 loss.
        self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss()
        self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss()

        # forward() returns a single (final) estimate rather than a list of intermediate ones.
        self.accumulate_estimates = False

    @typecheck()
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction. Not used by this model.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data; only used here to center-crop the prediction to the smaller common size.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        **kwargs: Must contain ``initial_image`` when ``initializer_initialization`` is ``input_image``.

        Returns
        -------
        pred: torch.Tensor
            The final estimate only: the coil-combined image of the last k-space prediction, converted to a
            complex-valued tensor with ``torch.view_as_complex`` and center-cropped against ``target``.

        Raises
        ------
        ValueError
            If ``initializer_initialization`` is ``input_image`` and ``initial_image`` is not given in ``kwargs``.
        """
        previous_state: Optional[torch.Tensor] = None

        if self.initializer is not None:
            # Build the image the learned initializer starts from; every branch keeps a coil axis so that the
            # reduction below is uniform.
            if self.initializer_initialization == "sense":
                # Conjugate-sensitivity weighted sum over coils of the zero-filled coil images.
                initializer_input_image = (
                    complex_mul(
                        ifft2(
                            y,
                            centered=self.fft_centered,
                            normalization=self.fft_normalization,
                            spatial_dims=self.spatial_dims,
                        ),
                        complex_conj(sensitivity_maps),
                    )
                    .sum(self.coil_dim)
                    .unsqueeze(self.coil_dim)
                )
            elif self.initializer_initialization == "input_image":
                if "initial_image" not in kwargs:
                    raise ValueError(
                        "`initial_image` is required as input if initializer_initialization "
                        f"is {self.initializer_initialization}."
                    )
                initializer_input_image = kwargs["initial_image"].unsqueeze(self.coil_dim)
            elif self.initializer_initialization == "zero_filled":
                initializer_input_image = ifft2(
                    y,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                )

            # Back to k-space, reduce the coil axis (same axis as in the `sense` branch) and move the trailing
            # real/imaginary axis to position 1 before feeding the initializer.
            previous_state = self.initializer(
                fft2(
                    initializer_input_image,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                )
                .sum(self.coil_dim)
                .permute(0, 3, 1, 2)
            )

        # Work on a copy so the measured k-space `y` stays available unchanged for every block.
        kspace_prediction = y.clone()
        for step in range(self.num_steps):
            block = self.block_list[step] if self.no_parameter_sharing else self.block_list[0]
            kspace_prediction, previous_state = block(
                kspace_prediction,
                y,
                mask,
                sensitivity_maps,
                previous_state,
            )

        # Final k-space estimate -> coil images -> coil-combined image.
        eta = ifft2(
            kspace_prediction,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        eta = coil_combination(eta, sensitivity_maps, method=self.coil_combination_method, dim=self.coil_dim)
        eta = torch.view_as_complex(eta)
        _, eta = center_crop_to_smallest(target, eta)
        return eta