# coding=utf-8
__author__ = "Dimitrios Karkalousos, Chaoping Zhang"
from typing import List, Optional, Tuple
import torch
from matplotlib import pyplot as plt
from mridc.collections.common.parts.fft import fft2, ifft2
from mridc.collections.common.parts.utils import coil_combination, complex_conj, complex_mul
from mridc.collections.quantitative.models.qrim.utils import SignalForwardModel
class qVarNetBlock(torch.nn.Module):
    """
    Implementation of the quantitative End-to-end Variational Network (qVN), as presented in Zhang, C. et al.

    One block maps the current quantitative parameter maps through a signal forward model, computes a weighted
    soft data-consistency term against the measured k-space, and adds the output of a regulariser model to the
    initial maps.

    References
    ----------
    ..

        Zhang, C. et al. (2022) 'A unified model for reconstruction and R2 mapping of accelerated 7T data using
        quantitative Recurrent Inference Machine'. In review.

    """

    def __init__(
        self,
        model: torch.nn.Module,
        fft_centered: bool = True,
        fft_normalization: str = "ortho",
        spatial_dims: Optional[Tuple[int, int]] = None,
        coil_dim: int = 1,
        no_dc: bool = False,
        linear_forward_model=None,
    ):
        """
        Initialize the model block.

        Parameters
        ----------
        model: Module applied in ``forward`` to the coil-combined soft data-consistency term. Its output is passed
            to ``torch.view_as_complex``.
        fft_centered: Forwarded as ``centered`` to ``fft2`` / ``ifft2``.
        fft_normalization: Forwarded as ``normalization`` to ``fft2`` / ``ifft2``.
        spatial_dims: Forwarded as ``spatial_dims`` to ``fft2`` / ``ifft2``. Falls back to ``[-2, -1]`` when None.
        coil_dim: Dimension summed over in ``sens_reduce``. ``forward`` also inserts a new axis into the
            sensitivity maps at ``coil_dim - 1``.
        no_dc: Stored on the instance as ``self.no_dc``. No method of this class reads it.
        linear_forward_model: Callable invoked as ``(R2star, S0, B0, phi, TEs)`` to produce the initial
            prediction. When None, ``SignalForwardModel(sequence="MEGRE")`` is used.
        """
        super().__init__()

        if linear_forward_model is None:
            linear_forward_model = SignalForwardModel(sequence="MEGRE")
        self.linear_forward_model = linear_forward_model

        self.model = model
        self.fft_centered = fft_centered
        self.fft_normalization = fft_normalization
        self.spatial_dims = spatial_dims if spatial_dims is not None else [-2, -1]
        self.coil_dim = coil_dim
        self.no_dc = no_dc
        # Learnable scalar weight of the soft data-consistency term, initialised to 1.
        self.dc_weight = torch.nn.Parameter(torch.ones(1))

    def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        """
        Multiply the input by the coil sensitivity maps and apply the forward 2D FFT.

        Parameters
        ----------
        x: Input data.
        sens_maps: Coil sensitivity maps.

        Returns
        -------
        Result of ``fft2(complex_mul(x, sens_maps))`` using the FFT settings of this block.
        """
        weighted = complex_mul(x, sens_maps)
        return fft2(
            weighted,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )

    def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        """
        Apply the inverse 2D FFT and combine over ``coil_dim`` using the conjugate sensitivity maps.

        Parameters
        ----------
        x: Input data.
        sens_maps: Coil sensitivity maps.

        Returns
        -------
        ``complex_mul(ifft2(x), complex_conj(sens_maps))`` summed over ``self.coil_dim``.
        """
        image = ifft2(
            x,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        return complex_mul(image, complex_conj(sens_maps)).sum(dim=self.coil_dim)

    def forward(
        self,
        prediction: torch.Tensor,
        masked_kspace: torch.Tensor,
        R2star_map_init: torch.Tensor,
        S0_map_init: torch.Tensor,
        B0_map_init: torch.Tensor,
        phi_map_init: torch.Tensor,
        TEs: List,
        sensitivity_maps: torch.Tensor,
        sampling_mask: torch.Tensor,
        gamma: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run one block: forward-model the initial maps, form the soft data-consistency term and update the maps.

        Tensor shapes are not checked here. The four initial maps must share one shape because they are stacked
        along dim 1.

        Parameters
        ----------
        prediction: Accepted for interface compatibility. It is not read by this method.
        masked_kspace: Measured (subsampled) k-space. Also used as the dtype/device target for the
            coil-combined data-consistency term.
        R2star_map_init: Initial R2* map.
        S0_map_init: Initial S0 map.
        B0_map_init: Initial B0 map.
        phi_map_init: Initial phi map.
        TEs: Echo times, passed unchanged to the linear forward model.
        sensitivity_maps: Coil sensitivity maps. A new axis is inserted at ``coil_dim - 1`` before use.
        sampling_mask: Sampling mask multiplied onto the k-space residual.
        gamma: Scaling factors indexed as ``gamma[0]`` .. ``gamma[3]`` for the R2*, S0, B0 and phi maps
            respectively. When None, the maps are passed to the forward model unscaled.

        Returns
        -------
        Real view (trailing dimension of size 2, from ``torch.view_as_real``) of the initial maps plus the
        model output. Negative values in the R2* slot (index 0 along dim 1) are set to 0.
        """
        init_maps = (R2star_map_init, S0_map_init, B0_map_init, phi_map_init)
        # The unscaled maps are the base the model output is added to.
        init_eta = torch.stack(init_maps, dim=1)

        # ``gamma`` defaults to None; indexing it unconditionally would raise a TypeError.
        if gamma is not None:
            init_maps = tuple(param_map * gamma[idx] for idx, param_map in enumerate(init_maps))
        forward_inputs = [param_map.unsqueeze(0) for param_map in init_maps]

        init_pred = self.linear_forward_model(*forward_inputs, TEs)

        expanded_sens = sensitivity_maps.unsqueeze(self.coil_dim - 1)
        pred_kspace = self.sens_expand(init_pred, expanded_sens)
        soft_dc = (pred_kspace - masked_kspace) * sampling_mask * self.dc_weight
        init_pred = self.sens_reduce(soft_dc, expanded_sens).to(masked_kspace)

        eta = torch.view_as_real(init_eta + torch.view_as_complex(self.model(init_pred)))

        # ``eta[:, 0, ...]`` is a view, so this masked write updates ``eta`` in place; no write-back is needed.
        r2star = eta[:, 0, ...]
        r2star[r2star < 0] = 0
        return eta