# coding=utf-8
__author__ = "Dimitrios Karkalousos"
from typing import Any, Optional, Tuple, Union
import torch
from mridc.collections.common.parts.fft import fft2, ifft2
from mridc.collections.common.parts.utils import complex_conj, complex_mul
from mridc.collections.reconstruction.models.rim.conv_layers import ConvNonlinear, ConvRNNStack
from mridc.collections.reconstruction.models.rim.rnn_cells import ConvGRUCell, ConvMGUCell, IndRNNCell
from mridc.collections.reconstruction.models.rim.utils import log_likelihood_gradient
class RIMBlock(torch.nn.Module):
    """
    A block of Recurrent Inference Machines (RIMs).

    The block unrolls ``time_steps`` updates of an image estimate ``eta``. Each step feeds the log-likelihood gradient
    through a stack of convolutional/recurrent layers and a final convolution, then adds the result to ``eta``.
    Optionally, a soft data-consistency step maps every estimate back to k-space.
    """

    def __init__(
        self,
        recurrent_layer=None,
        conv_filters=None,
        conv_kernels=None,
        conv_dilations=None,
        conv_bias=None,
        recurrent_filters=None,
        recurrent_kernels=None,
        recurrent_dilations=None,
        recurrent_bias=None,
        depth: int = 2,
        time_steps: int = 8,
        conv_dim: int = 2,
        no_dc: bool = False,
        fft_centered: bool = True,
        fft_normalization: str = "ortho",
        spatial_dims: Optional[Tuple[int, int]] = None,
        coil_dim: int = 1,
        dimensionality: int = 2,
    ):
        """
        Initialize the RIMBlock.

        At most three stages are built, because the per-stage settings are zipped with 3-element lists. The first two
        stages may pair a convolution (``nonlinear="relu"``) with a recurrent cell. The third stage has no recurrent
        cell, and its convolution (``nonlinear=None``) becomes ``final_layer``.

        Parameters
        ----------
        recurrent_layer: Type of recurrent layer: "GRU", "MGU" or "IndRNN" (case-insensitive). Used for the first
            two stages only.
        conv_filters: Number of filters of each stage's convolutional layer (0 disables that convolution).
        conv_kernels: Kernel size of each stage's convolutional layer.
        conv_dilations: Dilation of each stage's convolutional layer.
        conv_bias: Bias flag of each stage's convolutional layer.
        recurrent_filters: Number of filters of each stage's recurrent layer (0 disables that recurrent layer).
            ``forward`` also allocates one zero hidden state per non-zero entry.
        recurrent_kernels: Kernel size of each stage's recurrent layer.
        recurrent_dilations: Dilation of each stage's recurrent layer.
        recurrent_bias: Bias flag of each stage's recurrent layer.
        depth: ``2 * depth`` is used as the number of input channels of the first layer.
        time_steps: Number of unrolled update steps performed in ``forward``.
        conv_dim: Dimension of the convolutional layers (forwarded to the conv and recurrent layers).
        no_dc: If True, ``forward`` skips the soft data-consistency step and returns the image-space estimates.
            Otherwise it returns k-space estimates and a learnable ``dc_weight`` is created.
        fft_centered: If True, the FFT is centered.
        fft_normalization: Normalization of the FFT.
        spatial_dims: Spatial dimensions used by the FFT. Defaults to ``[-2, -1]``.
        coil_dim: Coils dimension of the input.
        dimensionality: 2 for slice-wise data. 3 for data with an extra slice axis, which is folded into the batch
            axis.

        Raises
        ------
        ValueError: If ``recurrent_layer`` is not a supported type, or if no final convolutional layer is built
            (the last zipped entry of ``conv_filters`` is 0, or the filter lists are empty).
        """
        super().__init__()
        self.input_size = depth * 2
        self.time_steps = time_steps
        self.layers = torch.nn.ModuleList()

        conv_specs = zip(conv_filters, conv_kernels, conv_dilations, conv_bias, ["relu", "relu", None])
        rnn_specs = zip(
            recurrent_filters,
            recurrent_kernels,
            recurrent_dilations,
            recurrent_bias,
            [recurrent_layer, recurrent_layer, None],
        )

        conv_layer = None
        for (
            (conv_features, conv_k_size, conv_dilation, l_conv_bias, nonlinear),
            (rnn_features, rnn_k_size, rnn_dilation, rnn_bias, rnn_type),
        ) in zip(conv_specs, rnn_specs):
            conv_layer = None
            if conv_features != 0:
                conv_layer = ConvNonlinear(
                    self.input_size,
                    conv_features,
                    conv_dim=conv_dim,
                    kernel_size=conv_k_size,
                    dilation=conv_dilation,
                    bias=l_conv_bias,
                    nonlinear=nonlinear,
                )
                self.input_size = conv_features

            if rnn_features != 0 and rnn_type is not None:
                rnn_layer = self._resolve_rnn_cell(rnn_type)(
                    self.input_size,
                    rnn_features,
                    conv_dim=conv_dim,
                    kernel_size=rnn_k_size,
                    dilation=rnn_dilation,
                    bias=rnn_bias,
                )
                self.input_size = rnn_features
                # Only stages that own a recurrent cell become ConvRNNStack layers. A conv-only stage must not be
                # appended here: it would reuse a stale rnn_layer from an earlier iteration.
                self.layers.append(ConvRNNStack(conv_layer, rnn_layer))

        if conv_layer is None:
            # Fail fast. Sequential(None) would only crash later, inside forward.
            raise ValueError(
                "RIMBlock needs a final convolutional layer: the last entry of conv_filters must be non-zero."
            )
        # The last stage's convolution produces the update that is added to eta.
        self.final_layer = torch.nn.Sequential(conv_layer)

        self.recurrent_filters = recurrent_filters
        self.fft_centered = fft_centered
        self.fft_normalization = fft_normalization
        self.spatial_dims = spatial_dims if spatial_dims is not None else [-2, -1]
        self.coil_dim = coil_dim
        self.no_dc = no_dc
        if not self.no_dc:
            self.dc_weight = torch.nn.Parameter(torch.ones(1))
            self.zero = torch.zeros(1, 1, 1, 1, 1)
        self.dimensionality = dimensionality

    @staticmethod
    def _resolve_rnn_cell(rnn_type: str):
        """Map a recurrent-layer name (case-insensitive) to its cell class; raise ValueError for unknown names."""
        cells = {"GRU": ConvGRUCell, "MGU": ConvMGUCell, "INDRNN": IndRNNCell}
        cell = cells.get(rnn_type.upper())
        if cell is None:
            raise ValueError("Please specify a proper recurrent layer type.")
        return cell

    @staticmethod
    def _merge_batch_and_slices(x: torch.Tensor) -> torch.Tensor:
        """Fold the first two axes into one: ``[batch, slices, ...] -> [batch * slices, ...]``."""
        return x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])

    def forward(
        self,
        pred: torch.Tensor,
        masked_kspace: torch.Tensor,
        sense: torch.Tensor,
        mask: torch.Tensor,
        eta: Optional[torch.Tensor] = None,
        hx: Optional[list] = None,
        sigma: float = 1.0,
        keep_eta: bool = False,
    ) -> Tuple[Any, Union[list, torch.Tensor, None]]:
        """
        Forward pass of the RIMBlock.

        Parameters
        ----------
        pred: Predicted multi-coil k-space. A list/tuple of predictions (e.g. a previous block's output) is replaced
            by its last element, detached.
        masked_kspace: Subsampled k-space.
        sense: Coil sensitivity maps.
        mask: Sampling mask. Nonzero entries count as sampled in the soft data-consistency step.
        eta: Initial image estimate. If None or with fewer than 3 dims, it is derived from ``pred``.
        hx: Initial hidden states, one per non-zero entry of ``recurrent_filters``. If None, zeros are used. A
            passed-in list is updated in place.
        sigma: Noise level passed to the log-likelihood gradient.
        keep_eta: If True and ``eta`` must be initialised, use ``pred`` as-is instead of its coil-combined inverse FFT.

        Returns
        -------
        A tuple ``(estimates, None)``. ``estimates`` holds one entry per time step: the image-space estimates when
        ``no_dc`` is True, otherwise the k-space estimates after soft data consistency.
        """
        if self.dimensionality == 3:
            # 2D pred.shape = [batch, coils, height, width, 2]
            # 3D pred.shape = [batch, slices, coils, height, width, 2] -> [batch * slices, coils, height, width, 2]
            batch, slices = masked_kspace.shape[0], masked_kspace.shape[1]
            if isinstance(pred, (tuple, list)):
                pred = pred[-1].detach()
            else:
                pred = self._merge_batch_and_slices(pred)
            masked_kspace = self._merge_batch_and_slices(masked_kspace)
            mask = self._merge_batch_and_slices(mask)
            sense = self._merge_batch_and_slices(sense)
        elif isinstance(pred, (tuple, list)):
            pred = pred[-1].detach()

        if hx is None:
            hx = [
                masked_kspace.new_zeros((masked_kspace.size(0), f, *masked_kspace.size()[2:-1]))
                for f in self.recurrent_filters
                if f != 0
            ]

        if eta is None or eta.ndim < 3:
            if keep_eta:
                eta = pred
            else:
                # Coil-combine the inverse FFT of pred with the conjugate sensitivity maps.
                coil_images = ifft2(
                    pred,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                )
                eta = torch.sum(complex_mul(coil_images, complex_conj(sense)), self.coil_dim)

        etas = []
        for _ in range(self.time_steps):
            grad_eta = log_likelihood_gradient(
                eta,
                masked_kspace,
                sense,
                mask,
                sigma=sigma,
                fft_centered=self.fft_centered,
                fft_normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
                coil_dim=self.coil_dim,
            ).contiguous()
            if self.dimensionality == 3:
                # NOTE(review): the hard-coded 4 channels presumably equal 2 * depth for the default depth=2 -- confirm.
                grad_eta = grad_eta.view([slices * batch, 4, grad_eta.shape[2], grad_eta.shape[3]]).permute(1, 0, 2, 3)

            # Layer h is paired with hidden state hx[h]; each layer's new state feeds the next layer.
            for h, convrnn in enumerate(self.layers):
                hx[h] = convrnn(grad_eta, hx[h])
                if self.dimensionality == 3:
                    hx[h] = hx[h].squeeze(0)
                grad_eta = hx[h]

            grad_eta = self.final_layer(grad_eta)

            # Move the channel axis last so the update can be added to eta.
            if self.dimensionality == 2:
                grad_eta = grad_eta.permute(0, 2, 3, 1)
            elif self.dimensionality == 3:
                grad_eta = grad_eta.permute(1, 2, 3, 0)
                for h, state in enumerate(hx):
                    hx[h] = state.permute(1, 0, 2, 3)

            eta = eta + grad_eta
            etas.append(eta)

        if self.no_dc:
            return etas, None

        # Residual between pred and the measured k-space at sampled locations, scaled by the learnable dc_weight.
        # torch.where requires a boolean condition, so non-bool (e.g. float) masks are cast explicitly.
        soft_dc = torch.where(mask.bool(), pred - masked_kspace, self.zero.to(masked_kspace)) * self.dc_weight
        current_kspace = [
            masked_kspace
            - soft_dc
            - fft2(
                complex_mul(e.unsqueeze(self.coil_dim), sense),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )
            for e in etas
        ]
        return current_kspace, None