# Source code for mridc.collections.reconstruction.parts.transforms
# encoding: utf-8
__author__ = "Dimitrios Karkalousos"
from math import sqrt
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from mridc.collections.common.parts.fft import fft2, ifft2
from mridc.collections.common.parts.utils import (
is_none,
reshape_fortran,
rss,
sense,
to_tensor,
)
from mridc.collections.reconstruction.data.subsample import MaskFunc
from mridc.collections.reconstruction.parts.utils import apply_mask, center_crop, complex_center_crop
__all__ = ["MRIDataTransforms"]
class MRIDataTransforms:
"""MRI preprocessing data transforms."""
def __init__(
self,
apply_prewhitening: bool = False,
prewhitening_scale_factor: float = 1.0,
prewhitening_patch_start: int = 10,
prewhitening_patch_length: int = 30,
apply_gcc: bool = False,
gcc_virtual_coils: int = 10,
gcc_calib_lines: int = 24,
gcc_align_data: bool = True,
coil_combination_method: str = "SENSE",
dimensionality: int = 2,
mask_func: Optional[List[MaskFunc]] = None,
shift_mask: bool = False,
mask_center_scale: Optional[float] = 0.02,
half_scan_percentage: float = 0.0,
remask: bool = False,
crop_size: Optional[Tuple[int, int]] = None,
kspace_crop: bool = False,
crop_before_masking: bool = True,
kspace_zero_filling_size: Optional[Tuple] = None,
normalize_inputs: bool = False,
fft_centered: bool = True,
fft_normalization: str = "ortho",
max_norm: bool = True,
spatial_dims: Optional[Sequence[int]] = None,
coil_dim: int = 0,
use_seed: bool = True,
):
"""
Initialize the data transform.
Parameters
----------
apply_prewhitening : bool
Whether to apply prewhitening.
prewhitening_scale_factor : float
The scale factor for the prewhitening.
prewhitening_patch_start : int
The start index for the prewhitening patch.
prewhitening_patch_length : int
The length of the prewhitening patch.
apply_gcc : bool
Whether to apply GCC.
gcc_virtual_coils : int
The number of virtual coils.
gcc_calib_lines : int
The number of calibration lines.
gcc_align_data : bool
Whether to align the data.
coil_combination_method : str
The coil combination method.
dimensionality : int
The dimensionality of the data.
mask_func : Optional[List[MaskFunc]]
The mask functions.
shift_mask : bool
Whether to shift the mask.
mask_center_scale : Optional[float]
The scale for the mask center.
half_scan_percentage : float
The fraction of k-space to mask out, emulating a partial (half) scan acquisition.
remask : bool
Whether to remask the data.
crop_size : Optional[Tuple[int, int]]
The crop size.
kspace_crop : bool
Whether to crop the kspace.
crop_before_masking : bool
Whether to crop before masking.
kspace_zero_filling_size : Optional[Tuple]
The zero filling size.
normalize_inputs : bool
Whether to normalize the inputs.
fft_centered : bool
Whether to center the FFT.
fft_normalization : str
The FFT normalization.
max_norm : bool
Whether to normalize by the maximum absolute value.
spatial_dims : Sequence[int]
The spatial dimensions.
coil_dim : int
The coil dimension.
use_seed : bool
Whether to use a seed.
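Examples
--------
A minimal sketch; the argument values below are illustrative only:
>>> transform = MRIDataTransforms(
...     coil_combination_method="SENSE",
...     crop_size=(320, 320),
...     spatial_dims=[-2, -1],
... )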
"""
self.coil_combination_method = coil_combination_method
self.dimensionality = dimensionality
self.mask_func = mask_func
self.shift_mask = shift_mask
self.mask_center_scale = mask_center_scale
self.half_scan_percentage = half_scan_percentage
self.remask = remask
self.crop_size = crop_size
self.kspace_crop = kspace_crop
self.crop_before_masking = crop_before_masking
self.kspace_zero_filling_size = kspace_zero_filling_size
self.normalize_inputs = normalize_inputs
self.fft_centered = fft_centered
self.fft_normalization = fft_normalization
self.max_norm = max_norm
self.spatial_dims = spatial_dims if spatial_dims is not None else [-2, -1]
self.coil_dim = coil_dim - 1 if self.dimensionality == 2 else coil_dim
self.apply_prewhitening = apply_prewhitening
self.prewhitening = (
NoisePreWhitening(
patch_size=[
prewhitening_patch_start,
prewhitening_patch_length + prewhitening_patch_start,
prewhitening_patch_start,
prewhitening_patch_length + prewhitening_patch_start,
],
scale_factor=prewhitening_scale_factor,
)
if apply_prewhitening
else None
)
self.gcc = (
GeometricDecompositionCoilCompression(
virtual_coils=gcc_virtual_coils,
calib_lines=gcc_calib_lines,
align_data=gcc_align_data,
fft_centered=self.fft_centered,
fft_normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
if apply_gcc
else None
)
self.use_seed = use_seed
def __call__(
self,
kspace: np.ndarray,
sensitivity_map: np.ndarray,
mask: np.ndarray,
eta: np.ndarray,
target: np.ndarray,
attrs: Dict,
fname: str,
slice_idx: int,
) -> Tuple[
torch.Tensor,
Union[List, torch.Tensor],
Union[Optional[torch.Tensor], Any],
Union[List, Any],
Union[Optional[torch.Tensor], Any],
Union[torch.Tensor, Any],
str,
int,
Union[List, Any],
]:
"""
Apply the data transform.
Parameters
----------
kspace: The kspace.
sensitivity_map: The sensitivity map.
mask: The mask.
eta: The initial estimation.
target: The target.
attrs: The attributes.
fname: The file name.
slice_idx: The slice number.
Returns
-------
The transformed data as a tuple: kspace, masked_kspace, sensitivity_map, mask, eta, target, fname, slice_idx, acc.
"""
kspace = to_tensor(kspace)
# This condition is necessary in case of auto estimation of sense maps.
if sensitivity_map is not None and sensitivity_map.size != 0:
sensitivity_map = to_tensor(sensitivity_map)
if self.apply_prewhitening:
kspace = self.prewhitening(kspace) # type: ignore
if self.gcc is not None:
kspace = self.gcc(kspace)
if isinstance(sensitivity_map, torch.Tensor):
sensitivity_map = ifft2(
self.gcc(
fft2(
sensitivity_map,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
),
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
# Apply zero-filling on kspace
if self.kspace_zero_filling_size is not None and self.kspace_zero_filling_size not in ("", "None"):
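# Split the size difference evenly between the two sides of each spatial dimension
# so the k-space center stays centered after padding.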
padding_top = np.floor_divide(abs(int(self.kspace_zero_filling_size[0]) - kspace.shape[1]), 2)
padding_bottom = padding_top
padding_left = np.floor_divide(abs(int(self.kspace_zero_filling_size[1]) - kspace.shape[2]), 2)
padding_right = padding_left
kspace = torch.view_as_complex(kspace)
kspace = torch.nn.functional.pad(
kspace, pad=(padding_left, padding_right, padding_top, padding_bottom), mode="constant", value=0
)
kspace = torch.view_as_real(kspace)
sensitivity_map = fft2(
sensitivity_map,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
sensitivity_map = torch.view_as_complex(sensitivity_map)
sensitivity_map = torch.nn.functional.pad(
sensitivity_map,
pad=(padding_left, padding_right, padding_top, padding_bottom),
mode="constant",
value=0,
)
sensitivity_map = torch.view_as_real(sensitivity_map)
sensitivity_map = ifft2(
sensitivity_map,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
# Initial estimation
eta = to_tensor(eta) if eta is not None and eta.size != 0 else torch.tensor([])
# If the target is not given, we need to compute it.
if self.coil_combination_method.upper() == "RSS":
target = rss(
ifft2(
kspace,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
),
dim=self.coil_dim,
)
elif self.coil_combination_method.upper() == "SENSE":
if sensitivity_map is not None and sensitivity_map.size != 0:
target = sense(
ifft2(
kspace,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
),
sensitivity_map,
dim=self.coil_dim,
)
elif target is not None and target.size != 0:
target = to_tensor(target)
elif "target" in attrs or "target_rss" in attrs:
target = torch.tensor(attrs["target"])
else:
raise ValueError("No target found")
target = torch.view_as_complex(target)
target = torch.abs(target / torch.max(torch.abs(target)))
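# A per-file seed (derived from the filename) makes the generated masks deterministic across epochs.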
seed = tuple(map(ord, fname)) if self.use_seed else None
acq_start = attrs["padding_left"] if "padding_left" in attrs else 0
acq_end = attrs["padding_right"] if "padding_right" in attrs else 0
# This should be outside the condition because it needs to be returned in the end, even if cropping is off.
# crop_size = torch.tensor([attrs["recon_size"][0], attrs["recon_size"][1]])
crop_size = target.shape
if self.crop_size is not None and self.crop_size not in ("", "None"):
# Check for smallest size against the target shape.
h = min(int(self.crop_size[0]), target.shape[0])
w = min(int(self.crop_size[1]), target.shape[1])
# Check for the smallest size against crop_size, which is initialized from the target shape.
if crop_size[0] != 0:
h = h if h <= crop_size[0] else crop_size[0]
if crop_size[1] != 0:
w = w if w <= crop_size[1] else crop_size[1]
self.crop_size = (int(h), int(w))
target = center_crop(target, self.crop_size)
if sensitivity_map is not None and sensitivity_map.size != 0:
sensitivity_map = (
ifft2(
complex_center_crop(
fft2(
sensitivity_map,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
),
self.crop_size,
),
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
if self.kspace_crop
else complex_center_crop(sensitivity_map, self.crop_size)
)
if eta is not None and eta.ndim > 2:
eta = (
ifft2(
complex_center_crop(
fft2(
eta,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
),
self.crop_size,
),
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
if self.kspace_crop
else complex_center_crop(eta, self.crop_size)
)
# Cropping before masking will maintain the shape of original kspace intact for masking.
if self.crop_size is not None and self.crop_size not in ("", "None") and self.crop_before_masking:
kspace = (
complex_center_crop(kspace, self.crop_size)
if self.kspace_crop
else fft2(
complex_center_crop(
ifft2(
kspace,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
),
self.crop_size,
),
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
)
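# Masking proceeds in one of three ways: use a pre-computed mask if one was loaded,
# fall back to a ones mask when no mask function is set, or generate masks with the
# given mask function(s) (a list yields one masked k-space per function).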
if not is_none(mask): # and not is_none(self.mask_func):
for _mask in mask:
if list(_mask.shape) == [kspace.shape[-3], kspace.shape[-2]]:
mask = torch.from_numpy(_mask).unsqueeze(0).unsqueeze(-1)
break
padding = (acq_start, acq_end)
if (not is_none(padding[0]) and not is_none(padding[1])) and padding[0] != 0:
mask[:, :, : padding[0]] = 0
mask[:, :, padding[1] :] = 0  # zero out from the right padding boundary to the end
if isinstance(mask, np.ndarray):
mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(-1)
if self.shift_mask:
mask = torch.fft.fftshift(mask, dim=(self.spatial_dims[0] - 1, self.spatial_dims[1] - 1))
if self.crop_size is not None and self.crop_size not in ("", "None") and self.crop_before_masking:
mask = complex_center_crop(mask, self.crop_size)
masked_kspace = kspace * mask + 0.0 # the + 0.0 removes the sign of the zeros
acc = 1
elif is_none(self.mask_func):
masked_kspace = kspace.clone()
acc = torch.tensor([1])
if mask is None:
mask = torch.ones(masked_kspace.shape[-3], masked_kspace.shape[-2]).type(torch.float32)
else:
mask = torch.from_numpy(mask)
if mask.dim() == 1:
mask = mask.unsqueeze(0)
if mask.shape[0] == masked_kspace.shape[2]: # type: ignore
mask = mask.permute(1, 0)
elif mask.shape[0] != masked_kspace.shape[1]: # type: ignore
mask = torch.ones(
[masked_kspace.shape[-3], masked_kspace.shape[-2]], dtype=torch.float32 # type: ignore
)
if mask.shape[-2] == 1:  # 1D mask
mask = mask.unsqueeze(0).unsqueeze(-1)
else: # 2D mask
# Crop loaded mask.
if self.crop_size is not None and self.crop_size not in ("", "None"):
mask = center_crop(mask, self.crop_size)
mask = mask.unsqueeze(0).unsqueeze(-1)
if self.shift_mask:
mask = torch.fft.fftshift(mask, dim=(1, 2))
masked_kspace = masked_kspace * mask
mask = mask.byte()
elif isinstance(self.mask_func, list):
masked_kspaces = []
masks = []
accs = []
for m in self.mask_func:
if self.dimensionality == 2:
_masked_kspace, _mask, _acc = apply_mask(
kspace,
m,
seed,
(acq_start, acq_end),
shift=self.shift_mask,
half_scan_percentage=self.half_scan_percentage,
center_scale=self.mask_center_scale,
)
elif self.dimensionality == 3:
_masked_kspace = []
_mask = None
for i in range(kspace.shape[0]):
_i_masked_kspace, _i_mask, _i_acc = apply_mask(
kspace[i],
m,
seed,
(acq_start, acq_end),
shift=self.shift_mask,
half_scan_percentage=self.half_scan_percentage,
center_scale=self.mask_center_scale,
existing_mask=_mask,
)
if self.remask:
_mask = _i_mask
if i == 0:
_acc = _i_acc
_masked_kspace.append(_i_masked_kspace)
_masked_kspace = torch.stack(_masked_kspace, dim=0)
_mask = _i_mask.unsqueeze(0)
else:
raise ValueError(f"Unsupported data dimensionality {self.dimensionality}D.")
masked_kspaces.append(_masked_kspace)
masks.append(_mask.byte())
accs.append(_acc)
masked_kspace = masked_kspaces
mask = masks
acc = accs # type: ignore
else:
masked_kspace, mask, acc = apply_mask(
kspace,
self.mask_func[0], # type: ignore
seed,
(acq_start, acq_end),
shift=self.shift_mask,
half_scan_percentage=self.half_scan_percentage,
center_scale=self.mask_center_scale,
)
mask = mask.byte()
# Cropping after masking.
if self.crop_size is not None and self.crop_size not in ("", "None") and not self.crop_before_masking:
kspace = (
complex_center_crop(kspace, self.crop_size)
if self.kspace_crop
else fft2(
complex_center_crop(
ifft2(
kspace,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
),
self.crop_size,
),
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
)
masked_kspace = (
complex_center_crop(masked_kspace, self.crop_size)
if self.kspace_crop
else fft2(
complex_center_crop(
ifft2(
masked_kspace,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
),
self.crop_size,
),
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
)
mask = center_crop(mask.squeeze(-1), self.crop_size).unsqueeze(-1)
# Normalize by the max value.
if self.normalize_inputs:
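# Normalization is done in image space: inverse FFT, divide by the max magnitude, then FFT back to k-space.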
if isinstance(self.mask_func, list):
if self.fft_normalization in ("backward", "ortho", "forward"):
imspace = ifft2(
kspace,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
if self.max_norm:
imspace = imspace / torch.max(torch.abs(imspace))
kspace = fft2(
imspace,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
elif self.fft_normalization in ("none", None) and self.max_norm:
imspace = torch.fft.ifftn(torch.view_as_complex(kspace), dim=list(self.spatial_dims), norm=None)
imspace = imspace / torch.max(torch.abs(imspace))
kspace = torch.view_as_real(torch.fft.fftn(imspace, dim=list(self.spatial_dims), norm=None))
masked_kspaces = []
for y in masked_kspace:
if self.fft_normalization in ("backward", "ortho", "forward"):
imspace = ifft2(
y,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
if self.max_norm:
imspace = imspace / torch.max(torch.abs(imspace))
y = fft2(
imspace,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
elif self.fft_normalization in ("none", None) and self.max_norm:
imspace = torch.fft.ifftn(torch.view_as_complex(y), dim=list(self.spatial_dims), norm=None)
imspace = imspace / torch.max(torch.abs(imspace))
y = torch.view_as_real(torch.fft.fftn(imspace, dim=list(self.spatial_dims), norm=None))
masked_kspaces.append(y)
masked_kspace = masked_kspaces
elif self.fft_normalization in ("backward", "ortho", "forward"):
imspace = ifft2(
kspace,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
if self.max_norm:
imspace = imspace / torch.max(torch.abs(imspace))
kspace = fft2(
imspace,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
imspace = ifft2(
masked_kspace,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
if self.max_norm:
imspace = imspace / torch.max(torch.abs(imspace))
masked_kspace = fft2(
imspace,
centered=self.fft_centered,
normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
)
elif self.fft_normalization in ("none", None) and self.max_norm:
imspace = torch.fft.ifftn(torch.view_as_complex(masked_kspace), dim=list(self.spatial_dims), norm=None)
imspace = imspace / torch.max(torch.abs(imspace))
masked_kspace = torch.view_as_real(torch.fft.fftn(imspace, dim=list(self.spatial_dims), norm=None))
imspace = torch.fft.ifftn(torch.view_as_complex(kspace), dim=list(self.spatial_dims), norm=None)
imspace = imspace / torch.max(torch.abs(imspace))
kspace = torch.view_as_real(torch.fft.fftn(imspace, dim=list(self.spatial_dims), norm=None))
if self.max_norm:
if isinstance(sensitivity_map, torch.Tensor) and sensitivity_map.numel() != 0:
sensitivity_map = sensitivity_map / torch.max(torch.abs(sensitivity_map))
if eta.numel() != 0 and eta.ndim > 2:
eta = eta / torch.max(torch.abs(eta))
target = target / torch.max(torch.abs(target))
return kspace, masked_kspace, sensitivity_map, mask, eta, target, fname, slice_idx, acc
class NoisePreWhitening:
"""Apply noise pre-whitening / coil decorrelation."""
def __init__(
self,
patch_size: List[int],
scale_factor: float = 1.0,
):
"""
Parameters
----------
patch_size : list of ints
Defines the patch from which the noise covariance (psi) is estimated,
given as [x_start, x_end, y_start, y_end].
scale_factor : float
Applied on the noise covariance matrix. Used to adjust for effective noise bandwidth and difference in
sampling rate between noise calibration and actual measurement.
scale_factor = (T_acq_dwell/T_noise_dwell)*NoiseReceiverBandwidthRatio
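Examples
--------
A minimal sketch on random multicoil data; shapes are illustrative only:
>>> data = torch.randn(8, 64, 64, 2)  # [coils, x, y, real/imag]
>>> prewhitening = NoisePreWhitening(patch_size=[10, 40, 10, 40])
>>> prewhitening(data).shape
torch.Size([8, 64, 64, 2])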
"""
super().__init__()
self.patch_size = patch_size
self.scale_factor = scale_factor
def __call__(self, data):
if not self.patch_size:
raise ValueError("Patch size must be defined for noise prewhitening.")
if data.shape[-1] != 2:
data = torch.view_as_real(data)
noise = data[:, self.patch_size[0] : self.patch_size[1], self.patch_size[-2] : self.patch_size[-1]]
noise_int = torch.reshape(noise, (noise.shape[0], int(torch.numel(noise) / noise.shape[0])))
deformation_matrix = (1 / (float(noise_int.shape[1]) - 1)) * torch.mm(noise_int, torch.conj(noise_int).t())
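# Whitening operator: inverse Cholesky factor of the noise covariance, scaled by sqrt(2 * scale_factor).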
psi = torch.linalg.inv(torch.linalg.cholesky(deformation_matrix)) * sqrt(2) * sqrt(self.scale_factor)
return torch.reshape(
torch.mm(psi, torch.reshape(data, (data.shape[0], int(torch.numel(data) / data.shape[0])))), data.shape
)
class GeometricDecompositionCoilCompression:
"""
Geometric Decomposition Coil Compression Based on:
Zhang, T., Pauly, J. M., Vasanawala, S. S., & Lustig, M. (2013). Coil compression for accelerated imaging with
Cartesian sampling. Magnetic Resonance in Medicine, 69(2), 571–582. https://doi.org/10.1002/mrm.24267
"""
def __init__(
self,
virtual_coils: Optional[int] = None,
calib_lines: Optional[int] = None,
align_data: bool = True,
fft_centered: bool = True,
fft_normalization: str = "ortho",
spatial_dims: Optional[Sequence[int]] = None,
):
"""
Parameters
----------
virtual_coils : number of final "virtual" coils to keep
calib_lines : number of calibration lines used to sample data points
align_data : whether to align the data to the first calibration line
fft_centered : whether to center the FFT
fft_normalization : FFT normalization; "ortho" is the default, "backward", "forward", or None may also be used
spatial_dims : dimensions over which to apply the FFT
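Examples
--------
A minimal sketch on random k-space; shapes are illustrative only:
>>> kspace = torch.randn(16, 100, 100, 2)  # [coils, x, y, real/imag]
>>> gcc = GeometricDecompositionCoilCompression(virtual_coils=8, calib_lines=24, spatial_dims=[-2, -1])
>>> compressed = gcc(kspace)  # k-space of 8 virtual coils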
"""
super().__init__()
self.virtual_coils = virtual_coils
self.calib_lines = calib_lines
self.align_data = align_data
self.fft_centered = fft_centered
self.fft_normalization = fft_normalization
self.spatial_dims = spatial_dims
def __call__(self, data):
if not self.virtual_coils:
raise ValueError("Number of virtual coils must be defined for geometric decomposition coil compression.")
self.data = data
if self.data.shape[-1] == 2:
self.data = torch.view_as_complex(self.data)
curr_num_coils = self.data.shape[0]
if curr_num_coils < self.virtual_coils:
raise ValueError(
f"Tried to compress from {curr_num_coils} to {self.virtual_coils} coils; please select fewer virtual coils."
)
# TODO: think about if needed to handle slices here or not
# if "slice" in self.data.names or self.data.dim() <= 3:
# raise ValueError(f"Currently supported only 2D data, you have given as input {self.data.dim()}D data.")
self.data = self.data.permute(1, 2, 0)
self.init_data: torch.Tensor = self.data
self.fft_dim = [0, 1]
_, self.width, self.coils = self.data.shape
# TODO: figure out why this is happening for singlecoil data
# For single-coil data, set the number of calibration lines equal to the number of coils.
if self.virtual_coils == 1:
self.calib_lines = self.data.shape[-1]
self.crop()
self.calculate_gcc()
if self.align_data:
self.align_compressed_coils()
rotated_compressed_data = self.rotate_and_compress(data_to_cc=self.aligned_data)
else:
rotated_compressed_data = self.rotate_and_compress(data_to_cc=self.unaligned_data)
rotated_compressed_data = torch.flip(rotated_compressed_data, dims=[1])
return fft2(
torch.view_as_real(rotated_compressed_data.permute(2, 0, 1)),
self.fft_centered,
self.fft_normalization,
self.spatial_dims,
)
def crop(self):
"""Crop to the size of the calibration lines."""
s = torch.as_tensor([self.calib_lines, self.width, self.coils])
idx = [
torch.arange(
abs(int(self.data.shape[n] // 2 + torch.ceil(-s[n] / 2))),
abs(int(self.data.shape[n] // 2 + torch.ceil(s[n] / 2) + 1)),
)
for n in range(len(s))
]
self.data = (
self.data[idx[0][0] : idx[0][-1], idx[1][0] : idx[1][-1], idx[2][0] : idx[2][-1]]
.unsqueeze(-2)
.permute(1, 0, 2, 3)
)
def calculate_gcc(self):
"""Calculates Geometric Coil-Compression."""
ws = (self.virtual_coils // 2) * 2 + 1
Nx, Ny, Nz, Nc = self.data.shape
im = torch.view_as_complex(
ifft2(torch.view_as_real(self.data), self.fft_centered, self.fft_normalization, spatial_dims=0)
)
s = torch.as_tensor([Nx + ws - 1, Ny, Nz, Nc])
idx = [
torch.arange(
abs(int(im.shape[n] // 2 + torch.ceil((-s[n] / 2).clone().detach()))),
abs(int(im.shape[n] // 2 + torch.ceil((s[n] / 2).clone().detach())) + 1),
)
for n in range(len(s))
]
zpim = torch.zeros((Nx + ws - 1, Ny, Nz, Nc)).type(im.dtype)
zpim[idx[0][0] : idx[0][-1], idx[1][0] : idx[1][-1], idx[2][0] : idx[2][-1], idx[3][0] : idx[3][-1]] = im
self.unaligned_data = torch.zeros((Nc, min(Nc, ws * Ny * Nz), Nx)).type(im.dtype)
for n in range(Nx):
tmpc = reshape_fortran(zpim[n : n + ws, :, :, :], (ws * Ny * Nz, Nc))
_, _, v = torch.svd(tmpc, some=False)
self.unaligned_data[:, :, n] = v
self.unaligned_data = self.unaligned_data[:, : self.virtual_coils, :]
def align_compressed_coils(self):
"""Virtual Coil Alignment."""
self.aligned_data = self.unaligned_data
_, sy, nc = self.aligned_data.shape
ncc = sy
n0 = nc // 2
A00 = self.aligned_data[:, :ncc, n0 - 1]
A0 = A00
for n in range(n0, 0, -1):
A1 = self.aligned_data[:, :ncc, n - 1]
C = torch.conj(A1).T @ A0
u, _, v = torch.svd(C, some=False)
P = v @ torch.conj(u).T
self.aligned_data[:, :ncc, n - 1] = A1 @ torch.conj(P).T
A0 = self.aligned_data[:, :ncc, n - 1]
A0 = A00
for n in range(n0, nc):
A1 = self.aligned_data[:, :ncc, n]
C = torch.conj(A1).T @ A0
u, _, v = torch.svd(C, some=False)
P = v @ torch.conj(u).T
self.aligned_data[:, :ncc, n] = A1 @ torch.conj(P).T
A0 = self.aligned_data[:, :ncc, n]
def rotate_and_compress(self, data_to_cc):
"""Uses compression matrices to project the data onto them -> rotate to the compressed space."""
_data = self.init_data.permute(1, 0, 2).unsqueeze(-2)
_ncc = data_to_cc.shape[1]
data_to_cc = data_to_cc.to(_data.device)
Nx, Ny, Nz, Nc = _data.shape
im = torch.view_as_complex(
ifft2(torch.view_as_real(_data), self.fft_centered, self.fft_normalization, spatial_dims=0)
)
ccdata = torch.zeros((Nx, Ny, Nz, _ncc)).type(_data.dtype).to(_data.device)
for n in range(Nx):
tmpc = im[n, :, :, :].squeeze().reshape(Ny * Nz, Nc)
ccdata[n, :, :, :] = (tmpc @ data_to_cc[:, :, n]).reshape(Ny, Nz, _ncc).unsqueeze(0)
ccdata = (
torch.view_as_complex(
fft2(torch.view_as_real(ccdata), self.fft_centered, self.fft_normalization, spatial_dims=0)
)
.permute(1, 0, 2, 3)
.squeeze()
)
# Singlecoil
if ccdata.dim() == 2:
ccdata = ccdata.unsqueeze(-1)
gcc = torch.zeros(ccdata.shape).type(ccdata.dtype)
for n in range(ccdata.shape[-1]):
gcc[:, :, n] = torch.view_as_complex(
ifft2(
torch.view_as_real(ccdata[:, :, n]), self.fft_centered, self.fft_normalization, self.spatial_dims
)
)
return gcc