# coding=utf-8
__author__ = "Dimitrios Karkalousos, Chaoping Zhang"
from typing import List, Sequence, Union
import torch
from mridc.collections.common.parts.fft import fft2, ifft2
from mridc.collections.common.parts.utils import coil_combination, complex_mul
class RescaleByMax:
    """
    Rescale data by the maximum magnitude taken over its dims 2 and 3.

    ``forward`` divides the data by that maximum (plus ``slack``) and returns the scale factor;
    ``reverse`` multiplies every entry along dim 0 back by its own scale factor.
    """

    def __init__(self, slack=1e-6):
        # Added to the maximum so that an all-zero input does not lead to a division by zero.
        self.slack = slack

    def forward(self, data):
        """
        Apply scaling.

        Parameters
        ----------
        data: Tensor with at least 4 dimensions (the maximum is taken over dims 3 and 2).

        Returns
        -------
        Tuple ``(data / gamma, gamma)`` where ``gamma`` keeps dims 2 and 3 as size-1 dims so it broadcasts
        against ``data``.
        """
        magnitude = torch.abs(data)
        peak_over_dim3 = torch.max(magnitude, dim=3, keepdim=True).values
        gamma = torch.max(peak_over_dim3, dim=2, keepdim=True).values + self.slack
        scaled = data / gamma
        return scaled, gamma

    @staticmethod
    def reverse(data, gamma):
        """Reverse scaling: multiply ``data[i]`` by ``gamma[i]`` for every index ``i`` along dim 0 and re-stack."""
        restored = [data[idx] * gamma[idx] for idx in range(data.shape[0])]
        return torch.stack(restored, dim=0)
class SignalForwardModel(object):
    """
    Defines a signal forward model.

    Parameters
    ----------
    sequence: Name of the sequence, matched case-insensitively.
        str, "MEGRE" or "MEGRE_no_phase"; any other value (including None) makes ``__call__`` raise ValueError.
    """

    def __init__(self, sequence: Union[str, None] = None):
        super(SignalForwardModel, self).__init__()
        self.sequence = sequence.lower() if isinstance(sequence, str) else None
        # Factor applied to the echo times inside the exponent/phase (presumably ms -> s; confirm with callers).
        self.scaling = 1e-3

    def __call__(
        self,
        R2star_map: torch.Tensor,
        S0_map: torch.Tensor,
        B0_map: torch.Tensor,
        phi_map: torch.Tensor,
        TEs=None,
    ):
        """
        Defines forward model based on sequence.

        Parameters
        ----------
        R2star_map: R2* map.
            torch.Tensor, shape [batch_size, n_x, n_y]
        S0_map: S0 map.
            torch.Tensor, shape [batch_size, n_x, n_y]
        B0_map: B0 map.
            torch.Tensor, shape [batch_size, n_x, n_y]
        phi_map: phi map.
            torch.Tensor, shape [batch_size, n_x, n_y]
        TEs: List of echo times. Defaults to [3.0, 11.5, 20.0, 28.5] when None.
            List of float, shape [n_echoes]

        Returns
        -------
        Predicted signal with echoes stacked at dim 1 and (real, imaginary) in the last dim,
        i.e. [batch_size, n_echoes, n_x, n_y, 2] for maps of the documented shape.

        Raises
        ------
        ValueError
            If the sequence is neither "megre" nor "megre_no_phase".
        """
        if TEs is None:
            TEs = torch.Tensor([3.0, 11.5, 20.0, 28.5])
        if self.sequence == "megre":
            return self.MEGRESignalModel(R2star_map, S0_map, B0_map, phi_map, TEs)
        if self.sequence == "megre_no_phase":
            return self.MEGRENoPhaseSignalModel(R2star_map, S0_map, TEs)
        raise ValueError(
            "Only MEGRE and MEGRE no phase are supported as signal forward models at the moment. "
            f"Found {self.sequence}"
        )

    def _echo_terms(self, R2star_map: torch.Tensor, B0_map: torch.Tensor, TE):
        """Return (exp(-TE*scaling*R2*), cos(-B0*scaling*TE), sin(-B0*scaling*TE)) for a single echo time."""
        decay = torch.exp(-TE * self.scaling * R2star_map)
        phase = B0_map * self.scaling * -TE
        return decay, torch.cos(phase), torch.sin(phase)

    def MEGRESignalModel(
        self,
        R2star_map: torch.Tensor,
        S0_map: torch.Tensor,
        B0_map: torch.Tensor,
        phi_map: torch.Tensor,
        TEs: List,
    ):
        """
        MEGRE forward model.

        For every echo time TE computes (S0 + i*phi) * exp(-TE*scaling*R2*) * exp(-i*B0*scaling*TE),
        split into real and imaginary channels. NaN entries are replaced by 0.

        Parameters
        ----------
        R2star_map: R2* map.
            torch.Tensor, shape [batch_size, n_x, n_y]
        S0_map: S0 map (real part of the complex proton density).
            torch.Tensor, shape [batch_size, n_x, n_y]
        B0_map: B0 map.
            torch.Tensor, shape [batch_size, n_x, n_y]
        phi_map: phi map (imaginary part of the complex proton density).
            torch.Tensor, shape [batch_size, n_x, n_y]
        TEs: List of echo times.
            List of float, shape [n_echoes]

        Returns
        -------
        torch.Tensor with echoes stacked at dim 1 and (real, imaginary) in the last dim.
        """
        S0_map_real = S0_map
        S0_map_imag = phi_map
        echoes = []
        for TE in TEs:
            # Each term is evaluated once per echo instead of once per use.
            decay, cos_term, sin_term = self._echo_terms(R2star_map, B0_map, TE)
            echoes.append(
                torch.stack(
                    (
                        S0_map_real * decay * cos_term - S0_map_imag * decay * sin_term,
                        S0_map_real * decay * sin_term + S0_map_imag * decay * cos_term,
                    ),
                    -1,
                )
            )
        pred = torch.stack(echoes, 1)
        pred[pred != pred] = 0.0
        # Returned directly: a complex round-trip (real + 1j * imag -> view_as_real) is an identity for finite
        # values but costs a copy and turns an infinite imaginary part into NaN in the real channel.
        return pred

    def MEGRENoPhaseSignalModel(
        self,
        R2star_map: torch.Tensor,
        S0_map: torch.Tensor,
        TEs: List,
    ):
        """
        MEGRE no phase forward model.

        For every echo time TE computes S0 * exp(-TE*scaling*R2*) and writes that value into BOTH the first
        and the second channel of the last dim. NaN entries are replaced by 0.

        Parameters
        ----------
        R2star_map: R2* map.
            torch.Tensor, shape [batch_size, n_x, n_y]
        S0_map: S0 map.
            torch.Tensor, shape [batch_size, n_x, n_y]
        TEs: List of echo times.
            List of float, shape [n_echoes]

        Returns
        -------
        torch.Tensor with echoes stacked at dim 1 and two identical channels in the last dim.
        """
        echoes = []
        for TE in TEs:
            magnitude = S0_map * torch.exp(-TE * self.scaling * R2star_map)
            # NOTE(review): the same value fills the imaginary channel too; a phase-free signal would usually
            # have a zero imaginary part -- confirm this is intended before changing it.
            echoes.append(torch.stack((magnitude, S0_map * torch.exp(-TE * self.scaling * R2star_map)), -1))
        pred = torch.stack(echoes, 1)
        pred[pred != pred] = 0.0
        return pred
def expand_op(x, sensitivity_maps):
    """
    Expand a coil-combined image to multicoil.

    Multiplies ``x`` by ``sensitivity_maps`` via ``complex_mul`` and replaces every NaN entry of the product
    with 0 (in place on the freshly computed product).
    """
    multicoil = complex_mul(x, sensitivity_maps)
    nan_mask = torch.isnan(multicoil)
    if nan_mask.any():
        multicoil[nan_mask] = 0
    return multicoil
def analytical_log_likelihood_gradient(
    linear_forward_model: SignalForwardModel,
    R2star_map: torch.Tensor,
    S0_map: torch.Tensor,
    B0_map: torch.Tensor,
    phi_map: torch.Tensor,
    TEs: List,
    sensitivity_maps: torch.Tensor,
    masked_kspace: torch.Tensor,
    sampling_mask: torch.Tensor,
    fft_centered: bool,
    fft_normalization: str,
    spatial_dims: Sequence[int],
    coil_dim: int,
    coil_combination_method: str = "SENSE",
    scaling: float = 1e-3,
) -> torch.Tensor:
    """
    Computes the analytical gradient of the log-likelihood function.

    The predicted signal is expanded to coils, transformed to k-space, compared with ``masked_kspace`` under
    ``sampling_mask`` and the residual is brought back to a coil-combined image ``r``. For a real parameter
    theta the gradient is Re(conj(d pred / d theta) * r), averaged over echoes.

    Parameters
    ----------
    linear_forward_model: SignalForwardModel
        Signal forward model to use. The partial derivatives below are those of the MEGRE model; the forward
        model should use the same ``scaling``.
    R2star_map: R2* map.
        torch.Tensor, documented as [batch_size, n_x, n_y]. NOTE(review): the maps get a leading axis via
        unsqueeze(0) and echoes are stacked at dim 1, which suggests a single sample [n_x, n_y] -- confirm.
    S0_map: S0 map.
        torch.Tensor, same shape as R2star_map
    B0_map: B0 map.
        torch.Tensor, same shape as R2star_map
    phi_map: phi map.
        torch.Tensor, same shape as R2star_map
    TEs: List of echo times.
        List of float, shape [n_echoes]
    sensitivity_maps: Coil sensitivity maps.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    masked_kspace: Data.
        torch.Tensor, shape [batch_size, n_echoes, n_coils, n_x, n_y, 2]
    sampling_mask: Mask of the sampling.
        torch.Tensor, shape [batch_size, 1, n_x, n_y, 1]
    fft_centered: If True, the FFT is centered.
        bool
    fft_normalization: Normalization of the FFT.
        str, one of "ortho", "forward", "backward", None
    spatial_dims: Spatial dimensions of the input.
        Sequence of int, shape [n_dims]
    coil_dim: Coils dimension of the input.
        int
    coil_combination_method: Method to use for coil combination.
        str, one of "SENSE", "RSS"
    scaling: Scaling factor applied to the echo times.
        float

    Returns
    -------
    Analytical gradient of the log-likelihood function, stacked along dim 0 as
    [R2* gradient, S0 gradient, B0 gradient, phi gradient].
    """
    # Leading singleton axis; the forward model then stacks echoes at dim 1.
    R2star_map = R2star_map.unsqueeze(0)
    S0_map = S0_map.unsqueeze(0)
    B0_map = B0_map.unsqueeze(0)
    phi_map = phi_map.unsqueeze(0)

    pred = linear_forward_model(R2star_map, S0_map, B0_map, phi_map, TEs)
    # S0 and phi act as the real and imaginary parts of the complex proton density.
    S0_map_real = S0_map
    S0_map_imag = phi_map

    # Same broadcastable view is used for coil expansion and for coil combination.
    expanded_sensitivity_maps = sensitivity_maps.unsqueeze(0).unsqueeze(coil_dim - 1)
    pred_kspace = fft2(
        expand_op(pred.unsqueeze(coil_dim), expanded_sensitivity_maps),
        fft_centered,
        fft_normalization,
        spatial_dims,
    )
    diff_data = (pred_kspace - masked_kspace) * sampling_mask
    diff_data_inverse = coil_combination(
        ifft2(diff_data, fft_centered, fft_normalization, spatial_dims),
        expanded_sensitivity_maps,
        method=coil_combination_method,
        dim=coil_dim,
    )

    # Conjugated partial derivatives of the MEGRE signal, one entry per echo (evaluated once per echo).
    S0_part_der_per_echo = []
    R2str_part_der_per_echo = []
    for TE in TEs:
        decay = torch.exp(-TE * scaling * R2star_map)
        phase = B0_map * scaling * -TE
        cos_term = torch.cos(phase)
        sin_term = torch.sin(phase)
        S0_part_der_per_echo.append(torch.stack((decay * cos_term, -decay * sin_term), -1))
        R2str_part_der_per_echo.append(
            torch.stack(
                (
                    -TE * scaling * decay * (S0_map_real * cos_term - S0_map_imag * sin_term),
                    -TE * scaling * decay * (-S0_map_real * sin_term - S0_map_imag * cos_term),
                ),
                -1,
            )
        )
    # Echoes live on dim 1 from here on.
    S0_part_der = torch.stack(S0_part_der_per_echo, 1)
    R2str_part_der = torch.stack(R2str_part_der_per_echo, 1)

    residual_real = diff_data_inverse[..., 0]
    residual_imag = diff_data_inverse[..., 1]
    # Real parts of conj(d pred/d theta) * r give the S0 and R2* gradients. The imaginary parts give the phi and
    # B0 gradients, because d pred/d phi = i * d pred/d S0 and d pred/d B0 = -i*scaling*TE*pred, while
    # d pred/d R2* = -scaling*TE*pred.
    S0_map_real_grad = residual_real * S0_part_der[..., 0] - residual_imag * S0_part_der[..., 1]
    S0_map_imag_grad = residual_real * S0_part_der[..., 1] + residual_imag * S0_part_der[..., 0]
    R2star_map_real_grad = residual_real * R2str_part_der[..., 0] - residual_imag * R2str_part_der[..., 1]
    R2star_map_imag_grad = residual_real * R2str_part_der[..., 1] + residual_imag * R2str_part_der[..., 0]

    # Average over the echo axis (dim 1) BEFORE squeezing. Squeezing first dropped the echo axis when a single
    # echo was given, so the mean then ran over the first spatial axis instead.
    S0_map_grad = torch.stack([S0_map_real_grad, S0_map_imag_grad], -1).mean(dim=1).squeeze()
    R2star_map_grad = torch.stack([R2star_map_real_grad, R2star_map_imag_grad], -1).mean(dim=1).squeeze()
    return torch.stack([R2star_map_grad[..., 0], S0_map_grad[..., 0], R2star_map_grad[..., 1], S0_map_grad[..., 1]], 0)