Source code for mridc.collections.reconstruction.models.cirim

# coding=utf-8
__author__ = "Dimitrios Karkalousos"

import math
from abc import ABC
from typing import Generator, Union

import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from torch.nn import L1Loss

from mridc.collections.common.losses.ssim import SSIMLoss
from mridc.collections.common.parts.fft import ifft2
from mridc.collections.common.parts.rnn_utils import rnn_weights_init
from mridc.collections.common.parts.utils import coil_combination
from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel
from mridc.collections.reconstruction.models.rim.rim_block import RIMBlock
from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
from mridc.core.classes.common import typecheck

__all__ = ["CIRIM"]


class CIRIM(BaseMRIReconstructionModel, ABC):
    """
    Reconstruction model built from a cascade of Recurrent Inference Machine (RIM) blocks.

    ``num_cascades`` :class:`RIMBlock` modules are stacked in a ``torch.nn.ModuleList``; the output of each
    cascade is fed to the next one together with the hidden state.

    References
    ----------
    .. Zhang, C. et al. (2022) ‘A unified model for reconstruction and R2 mapping of accelerated 7T data using \
        quantitative Recurrent Inference Machine’. In review.

    NOTE(review): the citation above describes the quantitative RIM (qRIM), while this class is named ``CIRIM``
    and only builds cascades of RIM blocks -- confirm that this is the intended reference.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Build the cascades, initialize their weights and select the loss functions.

        Parameters
        ----------
        cfg: Model configuration. It is resolved to a plain container and read with ``.get``, so missing keys
            silently become ``None``.
        trainer: Optional trainer, forwarded to the superclass.
        """
        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        cfg_dict = OmegaConf.to_container(cfg, resolve=True)

        self.recurrent_filters = cfg_dict.get("recurrent_filters")

        # make time-steps size divisible by 8 for fast fp16 training
        self.time_steps = 8 * math.ceil(cfg_dict.get("time_steps") / 8)

        self.no_dc = cfg_dict.get("no_dc")
        self.fft_centered = cfg_dict.get("fft_centered")
        self.fft_normalization = cfg_dict.get("fft_normalization")
        self.spatial_dims = cfg_dict.get("spatial_dims")
        self.coil_dim = cfg_dict.get("coil_dim")
        self.num_cascades = cfg_dict.get("num_cascades")

        # Cascades of RIM blocks: every cascade is an independent block created with the same settings.
        self.cirim = torch.nn.ModuleList(
            [
                RIMBlock(
                    recurrent_layer=cfg_dict.get("recurrent_layer"),
                    conv_filters=cfg_dict.get("conv_filters"),
                    conv_kernels=cfg_dict.get("conv_kernels"),
                    conv_dilations=cfg_dict.get("conv_dilations"),
                    conv_bias=cfg_dict.get("conv_bias"),
                    recurrent_filters=self.recurrent_filters,
                    recurrent_kernels=cfg_dict.get("recurrent_kernels"),
                    recurrent_dilations=cfg_dict.get("recurrent_dilations"),
                    recurrent_bias=cfg_dict.get("recurrent_bias"),
                    depth=cfg_dict.get("depth"),
                    time_steps=self.time_steps,
                    conv_dim=cfg_dict.get("conv_dim"),
                    no_dc=self.no_dc,
                    fft_centered=self.fft_centered,
                    fft_normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                    coil_dim=self.coil_dim,
                    dimensionality=cfg_dict.get("dimensionality"),
                )
                for _ in range(self.num_cascades)
            ]
        )

        # Keep estimation through the cascades if keep_eta is True or re-estimate it if False.
        self.keep_eta = cfg_dict.get("keep_eta")
        self.coil_combination_method = cfg_dict.get("coil_combination_method")

        # initialize weights if not using pretrained cirim
        if not cfg_dict.get("pretrained", False):
            # The init range is derived from the first entry of recurrent_filters: 1 / sqrt(filters).
            std_init_range = 1 / self.recurrent_filters[0] ** 0.5
            self.cirim.apply(lambda module: rnn_weights_init(module, std_init_range))

        # Any value other than "ssim" selects the L1 loss.
        self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss()
        self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss()

        self.dc_weight = torch.nn.Parameter(torch.ones(1))
        # Read by process_loss to decide between the accumulated loss and a single loss value.
        self.accumulate_estimates = True

    @typecheck()
    def forward(
        self,
        y: torch.Tensor,
        sensitivity_maps: torch.Tensor,
        mask: torch.Tensor,
        init_pred: torch.Tensor,
        target: torch.Tensor,
    ) -> Union[Generator, torch.Tensor]:
        """
        Forward pass of the network.

        Parameters
        ----------
        y: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        mask: Sampling mask.
            torch.Tensor, shape [1, 1, n_x, n_y, 1]
        init_pred: Initial prediction. Ignored (replaced by None) when it is None or has fewer than 4 dimensions.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        target: Target data, only used to crop the intermediate predictions to size.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]

        Yields
        ------
        cascades_etas: list of lists of torch.Tensor
            Yielded exactly once. One inner list per cascade, holding every estimate returned by that cascade
            after it went through :meth:`process_intermediate_pred`. The estimates are always accumulated here,
            regardless of ``self.accumulate_estimates``.
        """
        prediction = y.clone()
        init_pred = None if init_pred is None or init_pred.dim() < 4 else init_pred
        hx = None
        sigma = 1.0
        cascades_etas = []
        for i, cascade in enumerate(self.cirim):
            # Forward pass through the cascades. The first cascade never keeps eta, there is nothing to keep yet.
            prediction, hx = cascade(
                prediction,
                y,
                sensitivity_maps,
                mask,
                init_pred,
                hx,
                sigma,
                keep_eta=False if i == 0 else self.keep_eta,
            )
            time_steps_etas = [self.process_intermediate_pred(pred, sensitivity_maps, target) for pred in prediction]
            cascades_etas.append(time_steps_etas)
        yield cascades_etas

    def process_intermediate_pred(self, pred, sensitivity_maps, target, do_coil_combination=False):
        """
        Process the intermediate prediction.

        Parameters
        ----------
        pred: Intermediate prediction.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        sensitivity_maps: Coil sensitivity maps.
            torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
        target: Target data to crop to size.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        do_coil_combination: Whether to force the inverse FFT and coil combination even when ``self.no_dc`` is set.
            bool, default False

        Returns
        -------
        pred: torch.Tensor, shape [batch_size, n_x, n_y, 2]
            Processed prediction, viewed as a complex tensor and center-cropped against ``target``.
        """
        # The inverse FFT and the coil combination are skipped only when no_dc is set and not explicitly requested.
        if not self.no_dc or do_coil_combination:
            pred = ifft2(
                pred, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims
            )
            pred = coil_combination(pred, sensitivity_maps, method=self.coil_combination_method, dim=self.coil_dim)
        pred = torch.view_as_complex(pred)
        _, pred = center_crop_to_smallest(target, pred)
        return pred

    def process_loss(self, target, pred, _loss_fn=None, mask=None):
        """
        Process the loss.

        Parameters
        ----------
        target: Target data.
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        pred: Final prediction(s).
            list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or
            torch.Tensor, shape [batch_size, n_x, n_y, 2]
        _loss_fn: Loss function. ``None`` falls back to ``torch.nn.L1Loss()``.
            torch.nn.Module, default torch.nn.L1Loss()
        mask: Passed to the inner loss functions, which currently ignore it.

        Returns
        -------
        loss: generator or torch.FloatTensor
            If ``self.accumulate_estimates`` is True, a generator that yields a single value: the accumulated
            loss over all cascades and time-steps. If False, the loss of ``pred`` itself is returned directly.
        """
        if _loss_fn is None:
            # Honor the documented default instead of failing later when calling None.
            _loss_fn = L1Loss()

        # Magnitude of the target, scaled by its maximum magnitude.
        target = torch.abs(target / torch.max(torch.abs(target)))

        if "ssim" in str(_loss_fn).lower():
            max_value = np.array(torch.max(torch.abs(target)).item()).astype(np.float32)

            def loss_fn(x, y, m):
                """Calculate the ssim loss."""
                y = torch.abs(y / torch.max(torch.abs(y)))
                return _loss_fn(
                    x.unsqueeze(dim=self.coil_dim),
                    y.unsqueeze(dim=self.coil_dim),
                    data_range=torch.tensor(max_value).unsqueeze(dim=0).to(x.device),
                )

        else:

            def loss_fn(x, y, m):
                """Calculate other loss."""
                y = torch.abs(y / torch.max(torch.abs(y)))
                return _loss_fn(x, y)

        if not self.accumulate_estimates:
            # A plain return: the accumulation lives in its own generator so this value reaches the caller.
            return loss_fn(target, pred, mask)
        return self._accumulated_loss(loss_fn, target, pred, mask)

    def _accumulated_loss(self, loss_fn, target, pred, mask):
        """Yield, once, the loss accumulated over every cascade and time-step of ``pred``."""
        cascades_loss = []
        for cascade_pred in pred:
            time_steps_loss = [loss_fn(target, time_step_pred, mask) for time_step_pred in cascade_pred]
            # Loop-invariant weights, built once per cascade on the device/dtype of the losses.
            weights = torch.logspace(-1, 0, steps=self.time_steps).to(time_steps_loss[0])
            # NOTE(review): every scalar loss is multiplied by the whole weight vector, so after the two sums
            # below all time-steps carry the same effective weight. Confirm whether a per-time-step weighting
            # (weights[t] * loss[t]) was intended; the numerics are deliberately left unchanged here.
            _loss = [x * weights for x in time_steps_loss]
            cascades_loss.append(sum(sum(_loss) / self.time_steps))
        yield sum(cascades_loss) / len(self.cirim)