# coding=utf-8
__author__ = "Dimitrios Karkalousos"
import math
from abc import ABC
from typing import Generator, Union
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from torch.nn import L1Loss
from mridc.collections.common.losses.ssim import SSIMLoss
from mridc.collections.common.parts.fft import ifft2
from mridc.collections.common.parts.rnn_utils import rnn_weights_init
from mridc.collections.common.parts.utils import coil_combination
from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel
from mridc.collections.reconstruction.models.rim.rim_block import RIMBlock
from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
from mridc.core.classes.common import typecheck
__all__ = ["CIRIM"]
class CIRIM(BaseMRIReconstructionModel, ABC):
    """
    Implementation of the Cascades of Independently Recurrent Inference Machines (CIRIM).

    The model is a stack of ``num_cascades`` :class:`RIMBlock` modules that are applied one after the other to the
    subsampled k-space.

    References
    ----------
    ..

        Karkalousos, D. et al. (2021) 'Assessment of Data Consistency through Cascades of Independently Recurrent
        Inference Machines for fast and robust accelerated MRI reconstruction'. Available at:
        https://arxiv.org/abs/2111.15498v1
    """

    # NOTE(review): the previous docstring described the quantitative RIM (qRIM, Zhang, C. et al. 2022), which does
    # not match this class; confirm that the CIRIM reference above is the intended citation.

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Build the cascade of RIM blocks and the loss functions from the configuration.

        Parameters
        ----------
        cfg: Model configuration. It is resolved into a plain container and read with ``.get``, so missing keys
            silently become ``None``.
        trainer: Optional PyTorch Lightning trainer, forwarded to the superclass.
        """
        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        cfg_dict = OmegaConf.to_container(cfg, resolve=True)

        self.recurrent_filters = cfg_dict.get("recurrent_filters")

        # Round the number of time-steps up to the next multiple of 8 (done for fast fp16 training).
        self.time_steps = 8 * math.ceil(cfg_dict.get("time_steps") / 8)

        self.no_dc = cfg_dict.get("no_dc")
        self.fft_centered = cfg_dict.get("fft_centered")
        self.fft_normalization = cfg_dict.get("fft_normalization")
        self.spatial_dims = cfg_dict.get("spatial_dims")
        self.coil_dim = cfg_dict.get("coil_dim")
        self.num_cascades = cfg_dict.get("num_cascades")

        # Cascades of RIM blocks: every cascade is an identically configured, independently parameterised RIMBlock.
        self.cirim = torch.nn.ModuleList(
            [
                RIMBlock(
                    recurrent_layer=cfg_dict.get("recurrent_layer"),
                    conv_filters=cfg_dict.get("conv_filters"),
                    conv_kernels=cfg_dict.get("conv_kernels"),
                    conv_dilations=cfg_dict.get("conv_dilations"),
                    conv_bias=cfg_dict.get("conv_bias"),
                    recurrent_filters=self.recurrent_filters,
                    recurrent_kernels=cfg_dict.get("recurrent_kernels"),
                    recurrent_dilations=cfg_dict.get("recurrent_dilations"),
                    recurrent_bias=cfg_dict.get("recurrent_bias"),
                    depth=cfg_dict.get("depth"),
                    time_steps=self.time_steps,
                    conv_dim=cfg_dict.get("conv_dim"),
                    no_dc=self.no_dc,
                    fft_centered=self.fft_centered,
                    fft_normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                    coil_dim=self.coil_dim,
                    dimensionality=cfg_dict.get("dimensionality"),
                )
                for _ in range(self.num_cascades)
            ]
        )

        # Keep estimation through the cascades if keep_eta is True or re-estimate it if False.
        self.keep_eta = cfg_dict.get("keep_eta")

        self.coil_combination_method = cfg_dict.get("coil_combination_method")

        # initialize weights if not using pretrained cirim
        if not cfg_dict.get("pretrained", False):
            std_init_range = 1 / self.recurrent_filters[0] ** 0.5
            self.cirim.apply(lambda module: rnn_weights_init(module, std_init_range))

        # Any value other than "ssim" (including a missing key) selects the L1 loss.
        self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss()
        self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss()

        self.dc_weight = torch.nn.Parameter(torch.ones(1))

        # forward() always emits every intermediate estimate, so process_loss() accumulates over all of them.
        self.accumulate_estimates = True

    @typecheck()
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
    ) -> Union[Generator, torch.Tensor]:
        """
        Forward pass of the network.

        This method is a generator: it yields exactly once, after all cascades have been evaluated.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction. Ignored (treated as None) when it is None or has fewer than 4 dimensions.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data, forwarded to ``process_intermediate_pred`` together with every intermediate estimate.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Yields
        ------
        cascades_etas: list with one entry per cascade; every entry is the list of that cascade's intermediate
            estimates after ``process_intermediate_pred``.
        """
        prediction = y.clone()
        init_pred = None if init_pred is None or init_pred.dim() < 4 else init_pred
        hx = None
        sigma = 1.0
        cascades_etas = []
        for i, cascade in enumerate(self.cirim):
            # Forward pass through the cascades. The first cascade never keeps eta; later ones follow self.keep_eta.
            prediction, hx = cascade(
                prediction,
                y,
                sensitivity_maps,
                mask,
                init_pred,
                hx,
                sigma,
                keep_eta=False if i == 0 else self.keep_eta,
            )
            time_steps_etas = [self.process_intermediate_pred(pred, sensitivity_maps, target) for pred in prediction]
            cascades_etas.append(time_steps_etas)
        yield cascades_etas

    def process_loss(self, target, pred, _loss_fn=None, mask=None):
        """
        Process the loss.

        This method is a generator that yields the loss exactly once, in both the accumulating and the
        non-accumulating mode.

        Parameters
        ----------
        target: Target data.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        pred: Prediction(s).
            If self.accumulate_estimates is True, a list (cascades) of lists (time-steps) of torch.Tensor;
            otherwise a single torch.Tensor.
        _loss_fn: Loss function.
            torch.nn.Module, default torch.nn.L1Loss()
        mask: Forwarded to the internal loss closure, which currently does not use it.

        Yields
        ------
        loss: torch.FloatTensor
            If self.accumulate_estimates is True, the accumulated loss over all intermediate estimates averaged over
            the cascades; otherwise the loss of ``pred`` against ``target``.
        """
        if _loss_fn is None:
            # Honour the documented default instead of failing later with "'NoneType' object is not callable".
            _loss_fn = L1Loss()

        # Compare magnitudes scaled to a maximum of 1.
        target = torch.abs(target / torch.max(torch.abs(target)))

        if "ssim" in str(_loss_fn).lower():
            max_value = np.array(torch.max(torch.abs(target)).item()).astype(np.float32)

            def loss_fn(x, y, m):
                """Calculate the ssim loss."""
                y = torch.abs(y / torch.max(torch.abs(y)))
                return _loss_fn(
                    x.unsqueeze(dim=self.coil_dim),
                    y.unsqueeze(dim=self.coil_dim),
                    data_range=torch.tensor(max_value).unsqueeze(dim=0).to(x.device),
                )

        else:

            def loss_fn(x, y, m):
                """Calculate other loss."""
                y = torch.abs(y / torch.max(torch.abs(y)))
                return _loss_fn(x, y)

        if self.accumulate_estimates:
            cascades_loss = []
            for cascade_pred in pred:
                time_steps_loss = [loss_fn(target, time_step_pred, mask) for time_step_pred in cascade_pred]
                # Loop-invariant weights, created once per cascade on the device/dtype of the losses.
                weights = torch.logspace(-1, 0, steps=self.time_steps).to(time_steps_loss[0])
                # NOTE(review): every loss is multiplied by the *whole* weight vector, so after the two sums below
                # each time-step contributes with the same factor sum(weights) / time_steps. If a per-time-step
                # weighting was intended, weights and losses need to be paired instead. Left unchanged on purpose
                # to keep the loss values identical.
                _loss = [x * weights for x in time_steps_loss]
                cascades_loss.append(sum(sum(_loss) / self.time_steps))
            yield sum(cascades_loss) / len(self.cirim)
        else:
            # `return value` inside a generator only sets StopIteration.value, so consumers would never see the
            # loss; yield it so both branches behave the same.
            yield loss_fn(target, pred, mask)