Source code for mridc.collections.reconstruction.models.cascadenet.ccnn_block

# coding=utf-8
__author__ = "Dimitrios Karkalousos"

from typing import Optional, Tuple

import torch

from mridc.collections.common.parts.fft import fft2, ifft2
from mridc.collections.common.parts.utils import complex_conj, complex_mul


[docs]class CascadeNetBlock(torch.nn.Module): """ Model block for CascadeNet & Convolution Recurrent Neural Network. This model applies a combination of soft data consistency with the input model as a regularizer. A series of these blocks can be stacked to form the full variational network. """ def __init__( self, model: torch.nn.Module, fft_centered: bool = True, fft_normalization: str = "ortho", spatial_dims: Optional[Tuple[int, int]] = None, coil_dim: int = 1, no_dc: bool = False, ): """ Initializes the model block. Parameters ---------- model: Model to apply soft data consistency. torch.nn.Module fft_centered: Whether to center the FFT. bool fft_normalization: Whether to normalize the FFT. str spatial_dims: Spatial dimensions of the input. Tuple[int, int] coil_dim: Coil dimension. int no_dc: Flag to disable the soft data consistency. bool """ super().__init__() self.model = model self.fft_centered = fft_centered self.fft_normalization = fft_normalization self.spatial_dims = spatial_dims if spatial_dims is not None else [-2, -1] self.coil_dim = coil_dim self.no_dc = no_dc self.dc_weight = torch.nn.Parameter(torch.ones(1))
[docs] def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: """ Expand the sensitivity maps to the same size as the input. Parameters ---------- x: Input data. torch.Tensor, shape [batch_size, n_coils, height, width, 2] sens_maps: Sensitivity maps. torch.Tensor, shape [batch_size, n_coils, height, width, 2] Returns ------- SENSE reconstruction expanded to the same size as the input. torch.Tensor, shape [batch_size, n_coils, height, width, 2] """ return fft2( complex_mul(x, sens_maps), centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims, )
[docs] def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: """ Reduce the sensitivity maps to the same size as the input. Parameters ---------- x: Input data. torch.Tensor, shape [batch_size, n_coils, height, width, 2] sens_maps: Sensitivity maps. torch.Tensor, shape [batch_size, n_coils, height, width, 2] Returns ------- SENSE reconstruction. torch.Tensor, shape [batch_size, height, width, 2] """ x = ifft2(x, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims) return complex_mul(x, complex_conj(sens_maps)).sum(dim=self.coil_dim, keepdim=True)
[docs] def forward( self, pred: torch.Tensor, ref_kspace: torch.Tensor, sens_maps: torch.Tensor, mask: torch.Tensor, ) -> torch.Tensor: """ Forward pass of the model block. Parameters ---------- pred: Predicted k-space data. torch.Tensor, shape [batch_size, n_coils, height, width, 2] ref_kspace: Reference k-space data. torch.Tensor, shape [batch_size, n_coils, height, width, 2] sens_maps: Sensitivity maps. torch.Tensor, shape [batch_size, n_coils, height, width, 2] mask: Mask to apply to the data. torch.Tensor, shape [batch_size, 1, height, width, 1] Returns ------- Reconstructed image. torch.Tensor, shape [batch_size, height, width, 2] """ zero = torch.zeros(1, 1, 1, 1, 1).to(pred) soft_dc = torch.where(mask.bool(), pred - ref_kspace, zero) * self.dc_weight eta = self.sens_reduce(pred, sens_maps) eta = self.model(eta.squeeze(self.coil_dim).permute(0, 3, 1, 2)).permute(0, 2, 3, 1) eta = self.sens_expand(eta, sens_maps) if not self.no_dc: eta = pred - soft_dc - eta return eta