Source code for mridc.collections.quantitative.parts.transforms

# encoding: utf-8
__author__ = "Dimitrios Karkalousos, Chaoping Zhang"

import math
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from skimage.restoration import unwrap_phase
from torch import Tensor
from torch.nn import functional as F

from mridc.collections.common.parts.fft import fft2, ifft2
from mridc.collections.common.parts.utils import (
    is_none,
    rss,
    sense,
    to_tensor,
)
from mridc.collections.reconstruction.data.subsample import MaskFunc
from mridc.collections.reconstruction.parts.transforms import (
    GeometricDecompositionCoilCompression,
    NoisePreWhitening,
)
from mridc.collections.reconstruction.parts.utils import apply_mask, center_crop, complex_center_crop

__all__ = ["qMRIDataTransforms"]


class qMRIDataTransforms:
    """
    qMRI preprocessing data transforms.

    Turns one multi-echo slice (k-space, coil sensitivity maps, sampling/brain/head masks, quantitative maps and an
    initial estimation) into tensors. The steps are:

    - optional noise prewhitening and geometric-decomposition coil compression;
    - k-space zero-filling;
    - target computation and cropping;
    - undersampling;
    - optional precomputation of the quantitative maps;
    - optional max-normalization.
    """

    def __init__(
        self,
        TEs: Optional[List[float]],
        precompute_quantitative_maps: bool = True,
        apply_prewhitening: bool = False,
        prewhitening_scale_factor: float = 1.0,
        prewhitening_patch_start: int = 10,
        prewhitening_patch_length: int = 30,
        apply_gcc: bool = False,
        gcc_virtual_coils: int = 10,
        gcc_calib_lines: int = 24,
        gcc_align_data: bool = True,
        coil_combination_method: str = "SENSE",
        dimensionality: int = 2,
        mask_func: Optional[List[MaskFunc]] = None,
        shift_mask: bool = False,
        mask_center_scale: Optional[float] = 0.02,
        half_scan_percentage: float = 0.0,
        remask: bool = False,
        crop_size: Optional[Tuple[int, int]] = None,
        kspace_crop: bool = False,
        crop_before_masking: bool = True,
        kspace_zero_filling_size: Optional[Tuple] = None,
        normalize_inputs: bool = False,
        fft_centered: bool = True,
        fft_normalization: str = "ortho",
        max_norm: bool = True,
        spatial_dims: Sequence[int] = None,
        coil_dim: int = 0,
        shift_B0_input: bool = False,
        use_seed: bool = True,
    ):
        """
        Initialize the data transform.

        Parameters
        ----------
        TEs : List[float]
            Echo times. Required.
        precompute_quantitative_maps : bool
            If True, the R2*/S0/B0/phi maps are estimated from the data in ``__call__``. Otherwise they are read
            from ``qmaps``.
        apply_prewhitening : bool
            Apply noise prewhitening to every echo.
        prewhitening_scale_factor : float
            Prewhitening scale factor.
        prewhitening_patch_start : int
            Start index of the noise patch (used for both spatial axes).
        prewhitening_patch_length : int
            Length of the noise patch (used for both spatial axes).
        apply_gcc : bool
            Apply Geometric Decomposition Coil Compression.
        gcc_virtual_coils : int
            Number of GCC virtual coils.
        gcc_calib_lines : int
            Number of GCC calibration lines.
        gcc_align_data : bool
            GCC align data.
        coil_combination_method : str
            Coil combination used to compute the target: "SENSE" (default) or "RSS".
        dimensionality : int
            Data dimensionality (2 or 3).
        mask_func : List[MaskFunc]
            Mask function(s) used to undersample the k-space.
        shift_mask : bool
            Apply an fftshift to the mask.
        mask_center_scale : float
            Mask center scale, forwarded to ``apply_mask``.
        half_scan_percentage : float
            Half scan percentage, forwarded to ``apply_mask``.
        remask : bool
            For 3D data, reuse the mask of the previous slice. Default: False.
        crop_size : Tuple[int, int]
            Crop size. It is clamped per sample to the target's spatial shape.
        kspace_crop : bool
            Crop in k-space instead of image space.
        crop_before_masking : bool
            Crop before (True) or after (False) undersampling.
        kspace_zero_filling_size : Tuple
            K-space zero-filling size.
        normalize_inputs : bool
            Normalize inputs.
        fft_centered : bool
            Use centered FFTs.
        fft_normalization : str
            FFT normalization.
        max_norm : bool
            Normalize by the maximum absolute value.
        spatial_dims : Sequence[int]
            Spatial dimensions. Default: [-2, -1].
        coil_dim : int
            Coil dimension. For 2D data the stored value is ``coil_dim - 1``.
        shift_B0_input : bool
            Forwarded as ``shift`` to the quantitative mapping functions.
        use_seed : bool
            Seed the mask functions with the file name.

        Raises
        ------
        ValueError
            If ``TEs`` is None.
        """
        self.TEs = TEs
        if self.TEs is None:
            raise ValueError("Please specify echo times (TEs).")

        self.precompute_quantitative_maps = precompute_quantitative_maps
        self.coil_combination_method = coil_combination_method
        self.dimensionality = dimensionality
        self.mask_func = mask_func
        self.shift_mask = shift_mask
        self.mask_center_scale = mask_center_scale
        self.half_scan_percentage = half_scan_percentage
        self.remask = remask
        self.crop_size = crop_size
        self.kspace_crop = kspace_crop
        self.crop_before_masking = crop_before_masking
        self.kspace_zero_filling_size = kspace_zero_filling_size
        self.normalize_inputs = normalize_inputs
        self.fft_centered = fft_centered
        self.fft_normalization = fft_normalization
        self.max_norm = max_norm
        self.spatial_dims = spatial_dims if spatial_dims is not None else [-2, -1]
        self.coil_dim = coil_dim - 1 if self.dimensionality == 2 else coil_dim
        self.shift_B0_input = shift_B0_input

        self.apply_prewhitening = apply_prewhitening
        self.prewhitening = (
            NoisePreWhitening(
                patch_size=[
                    prewhitening_patch_start,
                    prewhitening_patch_length + prewhitening_patch_start,
                    prewhitening_patch_start,
                    prewhitening_patch_length + prewhitening_patch_start,
                ],
                scale_factor=prewhitening_scale_factor,
            )
            if apply_prewhitening
            else None
        )

        self.gcc = (
            GeometricDecompositionCoilCompression(
                virtual_coils=gcc_virtual_coils,
                calib_lines=gcc_calib_lines,
                align_data=gcc_align_data,
                fft_centered=self.fft_centered,
                fft_normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )
            if apply_gcc
            else None
        )

        self.use_seed = use_seed

    def _fft2(self, data: torch.Tensor) -> torch.Tensor:
        """Forward 2D FFT with this transform's centering, normalization and spatial dimensions."""
        return fft2(
            data, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims
        )

    def _ifft2(self, data: torch.Tensor) -> torch.Tensor:
        """Inverse 2D FFT with this transform's centering, normalization and spatial dimensions."""
        return ifft2(
            data, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims
        )

    def _crop_in_image_space(self, data: torch.Tensor, crop_size: Tuple[int, int]) -> torch.Tensor:
        """
        Center-crop image-domain ``data``.

        When ``kspace_crop`` is set, the crop is applied to its k-space instead.
        """
        if self.kspace_crop:
            return self._ifft2(complex_center_crop(self._fft2(data), crop_size))
        return complex_center_crop(data, crop_size)

    def _crop_in_kspace(self, data: torch.Tensor, crop_size: Tuple[int, int]) -> torch.Tensor:
        """
        Center-crop k-space ``data``.

        Unless ``kspace_crop`` is set, the crop is applied to its image instead.
        """
        if self.kspace_crop:
            return complex_center_crop(data, crop_size)
        return self._fft2(complex_center_crop(self._ifft2(data), crop_size))

    def _normalize_kspace(self, data: torch.Tensor) -> torch.Tensor:
        """
        Scale k-space ``data`` so that its image has unit maximum magnitude (when ``max_norm`` is set).

        For the "backward"/"ortho"/"forward" normalizations the data always makes an ifft2/fft2 round trip. For
        "none"/None the scaling happens only when ``max_norm`` is set. Any other normalization value returns the
        data unchanged.
        """
        if self.fft_normalization in ("backward", "ortho", "forward"):
            imspace = self._ifft2(data)
            if self.max_norm:
                imspace = imspace / torch.max(torch.abs(imspace))
            return self._fft2(imspace)
        if self.fft_normalization in ("none", None) and self.max_norm:
            imspace = torch.fft.ifftn(torch.view_as_complex(data), dim=list(self.spatial_dims), norm=None)
            imspace = imspace / torch.max(torch.abs(imspace))
            return torch.view_as_real(torch.fft.fftn(imspace, dim=list(self.spatial_dims), norm=None))
        return data

    def __call__(
        self,
        kspace: np.ndarray,
        sensitivity_map: np.ndarray,
        qmaps: np.ndarray,
        mask: np.ndarray,
        eta: np.ndarray,
        target: np.ndarray,
        attrs: Dict,
        fname: str,
        slice_idx: int,
    ) -> Tuple[
        Union[Tensor, List[Any], List[Tensor]],
        Union[Tensor, Any],
        Union[Tensor, List[Any], List[Tensor]],
        Union[Tensor, Any],
        Union[Tensor, List[Any], List[Tensor]],
        Union[Tensor, Any],
        Union[Tensor, List[Any], List[Tensor]],
        Union[Tensor, Any],
        Tensor,
        Tensor,
        Union[Union[Tensor, List[Union[Union[float, Tensor], Any]], float], Any],
        Union[Optional[Tensor], Any],
        Union[Union[List[Union[Tensor, Any]], Tensor, List[Tensor]], Any],
        Union[Tensor, Any],
        Union[Optional[Tensor], Any],
        Union[Tensor, Any],
        str,
        int,
        Union[List[int], int, Tensor],
    ]:
        """
        Apply the data transform.

        Parameters
        ----------
        kspace : np.ndarray
            The k-space. It is indexed per echo along its first dimension.
        sensitivity_map : np.ndarray
            Coil sensitivity maps. May be empty when not available.
        qmaps : np.ndarray
            Stored quantitative maps ordered as (B0, S0, R2*, phi). For each map, the last entry is the target and
            the preceding entries are the initializations. Used only when ``precompute_quantitative_maps`` is False.
        mask : sequence
            (sampling mask(s), brain mask, head mask). The sampling mask may be a list of masks.
        eta : np.ndarray
            The initial estimation. May be empty.
        target : np.ndarray
            The stored target. Used when the target cannot be computed with the coil combination method.
        attrs : Dict
            The attributes. ``padding_left``, ``padding_right``, ``target`` and ``target_rss`` are read when present.
        fname : str
            The file name. It seeds the mask functions when ``use_seed`` is True.
        slice_idx : int
            The slice number, returned unchanged.

        Returns
        -------
        tuple
            The transformed data, in this order:
            (R2star_map_init, R2star_map_target, S0_map_init, S0_map_target, B0_map_init, B0_map_target,
            phi_map_init, phi_map_target, TEs, kspace, masked_kspace, sensitivity_map, mask, mask_brain, eta,
            target, fname, slice_idx, acc).

        Raises
        ------
        ValueError
            If no target can be computed or found, or if the dimensionality is unsupported with a list of mask
            functions.
        """
        kspace = to_tensor(kspace)

        # This condition is necessary in case of auto estimation of sense maps: only non-empty arrays are converted.
        if sensitivity_map is not None and sensitivity_map.size != 0:
            sensitivity_map = to_tensor(sensitivity_map)

        # ``mask`` bundles the sampling mask(s), the brain mask and the head mask.
        mask_head = mask[2]
        mask_brain = mask[1]
        mask = mask[0]

        # 0-d arrays act as "not available" placeholders and are left untouched.
        if mask_brain.ndim != 0:
            mask_brain = torch.from_numpy(mask_brain)
        if mask_head.ndim != 0:
            mask_head = torch.from_numpy(mask_head)
        if isinstance(mask, list):
            mask = [torch.from_numpy(m) for m in mask]
        elif not is_none(mask) and mask.ndim != 0:
            mask = torch.from_numpy(mask)

        if self.apply_prewhitening:
            kspace = torch.stack(
                [self.prewhitening(kspace[echo]) for echo in range(kspace.shape[0])], dim=0  # type: ignore
            )

        if self.gcc is not None:
            kspace = torch.stack([self.gcc(kspace[echo]) for echo in range(kspace.shape[0])], dim=0)
            if isinstance(sensitivity_map, torch.Tensor):
                # Compress the sensitivity maps with the same operator, applied to their k-space.
                sensitivity_map = self._ifft2(self.gcc(self._fft2(sensitivity_map)))

        # Apply zero-filling on kspace. The sensitivity maps are zero-filled in k-space too, so both stay on the
        # same grid.
        if self.kspace_zero_filling_size is not None and self.kspace_zero_filling_size not in ("", "None"):
            padding_top = np.floor_divide(abs(int(self.kspace_zero_filling_size[0]) - kspace.shape[2]), 2)
            padding_bottom = padding_top
            padding_left = np.floor_divide(abs(int(self.kspace_zero_filling_size[1]) - kspace.shape[3]), 2)
            padding_right = padding_left
            zero_filling_pad = (padding_left, padding_right, padding_top, padding_bottom)

            kspace = torch.view_as_real(
                torch.nn.functional.pad(torch.view_as_complex(kspace), pad=zero_filling_pad, mode="constant", value=0)
            )

            sensitivity_map = torch.view_as_complex(self._fft2(sensitivity_map))
            sensitivity_map = torch.nn.functional.pad(sensitivity_map, pad=zero_filling_pad, mode="constant", value=0)
            sensitivity_map = self._ifft2(torch.view_as_real(sensitivity_map))

        # Initial estimation
        eta = to_tensor(eta) if eta is not None and eta.size != 0 else torch.tensor([])

        # Compute the target with the requested coil combination when possible. Otherwise fall back to the stored
        # target, then to the metadata.
        has_sensitivity_map = sensitivity_map is not None and sensitivity_map.size != 0
        coil_combination_method = self.coil_combination_method.upper()
        if coil_combination_method == "RSS":
            target = rss(self._ifft2(kspace), dim=self.coil_dim)
        elif coil_combination_method == "SENSE" and has_sensitivity_map:
            target = sense(self._ifft2(kspace), sensitivity_map.unsqueeze(0), dim=self.coil_dim)
        elif target is not None and target.size != 0:
            target = to_tensor(target)
        elif "target" in attrs:
            target = torch.tensor(attrs["target"])
        elif "target_rss" in attrs:
            target = torch.tensor(attrs["target_rss"])
        else:
            raise ValueError("No target found")
        target = torch.view_as_complex(target)
        target = torch.abs(target / torch.max(torch.abs(target)))

        seed = tuple(map(ord, fname)) if self.use_seed else None
        # Acquisition window stored in the metadata, forwarded to the mask handling below.
        acq_start = attrs["padding_left"] if "padding_left" in attrs else 0
        acq_end = attrs["padding_right"] if "padding_right" in attrs else 0

        # Resolve this sample's crop size without mutating ``self.crop_size``. The crop never exceeds the target's
        # spatial shape, and a value clamped for one sample must not shrink the crop of the following samples.
        crop_size = None
        if self.crop_size is not None and self.crop_size not in ("", "None"):
            crop_size = (
                int(min(int(self.crop_size[0]), target.shape[1])),
                int(min(int(self.crop_size[1]), target.shape[2])),
            )

            target = center_crop(target, crop_size)
            if has_sensitivity_map:
                sensitivity_map = self._crop_in_image_space(sensitivity_map, crop_size)
            if eta is not None and eta.ndim > 2:
                eta = self._crop_in_image_space(eta, crop_size)
            if mask_brain.dim() != 0:
                mask_brain = center_crop(mask_brain, crop_size)

        # Cropping before masking will maintain the shape of original kspace intact for masking.
        if crop_size is not None and self.crop_before_masking:
            kspace = self._crop_in_kspace(kspace, crop_size)

        if isinstance(mask, list):
            # Several stored sampling masks: apply each one to the k-space.
            masked_kspaces = []
            masks = []
            for _mask in mask:
                if list(_mask.shape) == [kspace.shape[-3], kspace.shape[-2]]:
                    _mask = _mask.unsqueeze(0).unsqueeze(-1)

                padding = (acq_start, acq_end)
                if (not is_none(padding[0]) and not is_none(padding[1])) and padding[0] != 0:
                    _mask[:, :, : padding[0]] = 0
                    _mask[:, :, padding[1] :] = 0  # padding value inclusive on right of zeros

                if isinstance(_mask, np.ndarray):
                    _mask = torch.from_numpy(_mask).unsqueeze(0).unsqueeze(-1)

                if self.shift_mask:
                    _mask = torch.fft.fftshift(_mask, dim=(self.spatial_dims[0] - 1, self.spatial_dims[1] - 1))

                if crop_size is not None and self.crop_before_masking:
                    _mask = complex_center_crop(_mask, crop_size)

                masked_kspaces.append(kspace * _mask + 0.0)
                masks.append(_mask)
            masked_kspace = masked_kspaces
            mask = masks
            acc = 1
        elif not is_none(mask) and mask.ndim != 0:
            # A single stored sampling mask.
            for _mask in mask:
                if list(_mask.shape) == [kspace.shape[-3], kspace.shape[-2]]:
                    # ``mask`` was already converted to a tensor above, so ``torch.from_numpy`` would raise here.
                    mask = torch.as_tensor(_mask).unsqueeze(0).unsqueeze(-1)
                    break

            padding = (acq_start, acq_end)
            if (not is_none(padding[0]) and not is_none(padding[1])) and padding[0] != 0:
                mask[:, :, : padding[0]] = 0
                mask[:, :, padding[1] :] = 0  # padding value inclusive on right of zeros

            if isinstance(mask, np.ndarray):
                mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(-1)

            if self.shift_mask:
                mask = torch.fft.fftshift(mask, dim=(self.spatial_dims[0] - 1, self.spatial_dims[1] - 1))

            if crop_size is not None and self.crop_before_masking:
                mask = complex_center_crop(mask, crop_size)

            masked_kspace = kspace * mask + 0.0  # the + 0.0 removes the sign of the zeros
            acc = 1
        elif is_none(self.mask_func):
            # No mask function: keep the k-space, multiplied by the available mask or by an all-ones mask.
            masked_kspace = kspace.clone()
            acc = torch.tensor([1])

            if is_none(mask) or mask.ndim == 0:
                mask = torch.ones(masked_kspace.shape[-3], masked_kspace.shape[-2]).type(torch.float32)  # type: ignore
            else:
                mask = torch.as_tensor(mask)

                if mask.dim() == 1:
                    mask = mask.unsqueeze(0)

                if mask.shape[0] == masked_kspace.shape[2]:  # type: ignore
                    mask = mask.permute(1, 0)
                elif mask.shape[0] != masked_kspace.shape[1]:  # type: ignore
                    mask = torch.ones(
                        [masked_kspace.shape[-3], masked_kspace.shape[-2]], dtype=torch.float32  # type: ignore
                    )

            if mask.shape[-2] == 1:  # 1D mask
                mask = torch.as_tensor(mask).unsqueeze(0).unsqueeze(-1)
            else:  # 2D mask
                # Crop loaded mask.
                if crop_size is not None:
                    mask = center_crop(mask, crop_size)
                mask = mask.unsqueeze(0).unsqueeze(-1)

            if self.shift_mask:
                mask = torch.fft.fftshift(mask, dim=(1, 2))

            masked_kspace = masked_kspace * mask
            mask = mask.byte()
        elif isinstance(self.mask_func, list):
            # One undersampled k-space per mask function.
            masked_kspaces = []
            masks = []
            accs = []
            for m in self.mask_func:
                if self.dimensionality == 2:
                    _masked_kspace, _mask, _acc = apply_mask(
                        kspace,
                        m,
                        seed,
                        (acq_start, acq_end),
                        shift=self.shift_mask,
                        half_scan_percentage=self.half_scan_percentage,
                        center_scale=self.mask_center_scale,
                    )
                elif self.dimensionality == 3:
                    _masked_kspace = []
                    _mask = None
                    for i in range(kspace.shape[0]):
                        _i_masked_kspace, _i_mask, _i_acc = apply_mask(
                            kspace[i],
                            m,
                            seed,
                            (acq_start, acq_end),
                            shift=self.shift_mask,
                            half_scan_percentage=self.half_scan_percentage,
                            center_scale=self.mask_center_scale,
                            existing_mask=_mask,
                        )
                        if self.remask:
                            _mask = _i_mask
                        if i == 0:
                            _acc = _i_acc
                        _masked_kspace.append(_i_masked_kspace)
                    _masked_kspace = torch.stack(_masked_kspace, dim=0)
                    _mask = _i_mask.unsqueeze(0)
                else:
                    raise ValueError(f"Unsupported data dimensionality {self.dimensionality}D.")
                masked_kspaces.append(_masked_kspace)
                masks.append(_mask.byte())
                accs.append(_acc)
            masked_kspace = masked_kspaces
            mask = masks
            acc = accs  # type: ignore
        else:
            masked_kspace, mask, acc = apply_mask(
                kspace,
                self.mask_func[0],  # type: ignore
                seed,
                (acq_start, acq_end),
                shift=self.shift_mask,
                half_scan_percentage=self.half_scan_percentage,
                center_scale=self.mask_center_scale,
            )
            mask = mask.byte()

        # Cropping after masking.
        if crop_size is not None and not self.crop_before_masking:
            kspace = self._crop_in_kspace(kspace, crop_size)
            if isinstance(masked_kspace, list):
                masked_kspace = [self._crop_in_kspace(y, crop_size) for y in masked_kspace]
                mask = [center_crop(m.squeeze(-1), crop_size).unsqueeze(-1) for m in mask]
            else:
                masked_kspace = self._crop_in_kspace(masked_kspace, crop_size)
                mask = center_crop(mask.squeeze(-1), crop_size).unsqueeze(-1)

        # NOTE(review): the head mask loaded above is replaced by an all-ones mask here, so it never reaches the
        # quantitative mapping — confirm this is intended.
        mask_head = torch.ones_like(mask_brain)

        if self.precompute_quantitative_maps:
            R2star_maps_init = []
            S0_maps_init = []
            B0_maps_init = []
            phi_maps_init = []
            for y in masked_kspace:
                # NOTE(review): this rebinds ``eta``, so the returned initial estimation is the SENSE image of the
                # last entry of ``masked_kspace`` rather than the loaded ``eta`` — confirm this is intended.
                eta = sense(self._ifft2(y), sensitivity_map.unsqueeze(0), dim=self.coil_dim)
                R2star_map_init, S0_map_init, B0_map_init, phi_map_init = R2star_B0_real_S0_complex_mapping(
                    eta,
                    self.TEs,
                    mask_brain,
                    mask_head,
                    fully_sampled=True,
                    shift=self.shift_B0_input,
                    fft_centered=self.fft_centered,
                    fft_normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                )
                R2star_maps_init.append(R2star_map_init)
                S0_maps_init.append(S0_map_init)
                B0_maps_init.append(B0_map_init)
                phi_maps_init.append(phi_map_init)
            R2star_map_init = torch.stack(R2star_maps_init, dim=0)
            S0_map_init = torch.stack(S0_maps_init, dim=0)
            B0_map_init = torch.stack(B0_maps_init, dim=0)
            phi_map_init = torch.stack(phi_maps_init, dim=0)

            # Reference maps from the (not undersampled) k-space.
            mask_brain_tmp = torch.ones_like(torch.abs(mask_brain))
            mask_brain_tmp = mask_brain_tmp.unsqueeze(0) if mask_brain.dim() == 2 else mask_brain_tmp
            imspace = sense(
                self._ifft2(kspace) * mask_brain_tmp.unsqueeze(self.coil_dim - 1).unsqueeze(-1),
                sensitivity_map.unsqueeze(0),
                dim=self.coil_dim,
            )
            R2star_map_target, S0_map_target, B0_map_target, phi_map_target = R2star_B0_real_S0_complex_mapping(
                imspace,
                self.TEs,
                mask_brain,
                mask_head,
                fully_sampled=True,
                shift=self.shift_B0_input,
                fft_centered=self.fft_centered,
                fft_normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )
        elif qmaps[0][0].ndim != 0:
            # Stored maps: for each map, the last entry is the target and the others are the initializations.
            B0_map, S0_map, R2star_map, phi_map = qmaps
            B0_map = [torch.from_numpy(x).squeeze(0) for x in B0_map]
            B0_map_target = B0_map[-1]
            B0_map_init = B0_map[:-1]
            S0_map = [torch.from_numpy(x).squeeze(0) for x in S0_map]
            S0_map_target = S0_map[-1]
            S0_map_init = S0_map[:-1]
            R2star_map = [torch.from_numpy(x).squeeze(0) for x in R2star_map]
            R2star_map_target = R2star_map[-1]
            R2star_map_init = R2star_map[:-1]
            phi_map = [torch.from_numpy(x).squeeze(0) for x in phi_map]
            phi_map_target = phi_map[-1]
            phi_map_init = phi_map[:-1]
        else:
            # No maps available: return empty placeholders, one initialization per undersampled k-space.
            B0_map_target = torch.tensor([])
            B0_map_init = [torch.tensor([])] * len(masked_kspace)
            S0_map_target = torch.tensor([])
            S0_map_init = [torch.tensor([])] * len(masked_kspace)
            R2star_map_target = torch.tensor([])
            R2star_map_init = [torch.tensor([])] * len(masked_kspace)
            phi_map_target = torch.tensor([])
            phi_map_init = [torch.tensor([])] * len(masked_kspace)

        # Normalize by the max value.
        if self.normalize_inputs:
            kspace = self._normalize_kspace(kspace)
            if isinstance(self.mask_func, list):
                masked_kspace = [self._normalize_kspace(y) for y in masked_kspace]
            else:
                masked_kspace = self._normalize_kspace(masked_kspace)

            if self.max_norm:
                if sensitivity_map.size != 0:
                    sensitivity_map = sensitivity_map / torch.max(torch.abs(sensitivity_map))
                if eta.size != 0 and eta.ndim > 2:
                    eta = eta / torch.max(torch.abs(eta))
                target = target / torch.max(torch.abs(target))

        return (
            R2star_map_init,
            R2star_map_target,
            S0_map_init,
            S0_map_target,
            B0_map_init,
            B0_map_target,
            phi_map_init,
            phi_map_target,
            torch.tensor(self.TEs),
            kspace,
            masked_kspace,
            sensitivity_map,
            mask,
            mask_brain,
            eta,
            target,
            fname,
            slice_idx,
            acc,
        )
class GaussianSmoothing(torch.nn.Module):
    """
    Apply gaussian smoothing on a 1d, 2d or 3d tensor.

    Filtering is performed separately for each channel in the input using a depthwise convolution. No padding is
    applied, so every spatial dimension of the output is ``kernel_size - 1`` smaller than the input.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: Union[Optional[List[int]], int],
        sigma: float,
        dim: int = 2,
        shift: bool = False,
        fft_centered: bool = True,
        fft_normalization: str = "ortho",
        spatial_dims: Sequence[int] = None,
    ):
        """
        Initialize the module with the gaussian kernel size and standard deviation.

        Parameters
        ----------
        channels : int
            Number of channels in the input tensor.
        kernel_size : Union[Optional[List[int]], int]
            Gaussian kernel size. A single int is used for every dimension.
        sigma : float
            Gaussian kernel standard deviation. A scalar (int or float) is used for every dimension; a sequence
            gives one value per dimension.
        dim : int
            Number of spatial dimensions of the input tensor (1, 2 or 3).
        shift : bool
            If True, ``forward`` applies an fftshift in k-space (fft2 -> fftshift -> ifft2) to the input before
            filtering and to the output after filtering.
        fft_centered : bool
            Whether to center the FFT for a real- or complex-valued input.
        fft_normalization : str
            Whether to normalize the FFT output (None, "ortho", "backward", "forward", "none").
        spatial_dims : Sequence[int]
            Spatial dimensions to keep in the FFT.

        Raises
        ------
        RuntimeError
            If ``dim`` is not 1, 2 or 3.
        """
        super().__init__()
        self.shift = shift
        self.fft_centered = fft_centered
        self.fft_normalization = fft_normalization
        self.spatial_dims = spatial_dims

        if isinstance(kernel_size, int):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, (int, float)):
            sigma = [float(sigma)] * dim  # type: ignore

        # The gaussian kernel is the product of the gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij"  # type: ignore
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):  # type: ignore
            mean = (size - 1) / 2
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / std) ** 2) / 2)

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to depthwise convolutional weight: (channels, 1, *kernel_size).
        kernel = kernel.view(1, 1, *kernel.size())  # type: ignore
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))  # type: ignore

        self.register_buffer("weight", kernel)
        self.groups = channels

        if dim == 1:
            self.conv = F.conv1d
        elif dim == 2:
            self.conv = F.conv2d
        elif dim == 3:
            self.conv = F.conv3d
        else:
            raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim))

    def _kspace_fftshift(self, data: torch.Tensor) -> torch.Tensor:
        """
        Apply fft2 -> fftshift -> ifft2 to a tensor holding real/imaginary parts in channels 0 and 1.

        Assumes ``data`` is (batch, 2, H, W) — TODO confirm against callers.
        """
        data = data.permute(0, 2, 3, 1)
        data = torch.view_as_real(data[..., 0] + 1j * data[..., 1])
        # NOTE(review): ``torch.fft.fftshift`` without ``dim`` shifts every dimension of its input (PyTorch
        # default), including a trailing real/imaginary axis if ``fft2`` returns one — confirm this is intended.
        data = ifft2(
            torch.fft.fftshift(fft2(data, self.fft_centered, self.fft_normalization, self.spatial_dims)),
            self.fft_centered,
            self.fft_normalization,
            self.spatial_dims,
        )
        return data.permute(0, 3, 1, 2)

    def forward(self, input):
        """
        Apply gaussian filter to input.

        Parameters
        ----------
        input : torch.Tensor
            Input to apply gaussian filter on.

        Returns
        -------
        torch.Tensor
            Filtered output, detached from the autograd graph.
        """
        if self.shift:
            input = self._kspace_fftshift(input)

        x = self.conv(input, weight=self.weight.to(input), groups=self.groups).to(input).detach()

        if self.shift:
            x = self._kspace_fftshift(x)
        return x


class LeastSquares:
    """Least-squares solvers used by the quantitative map fits."""

    def __init__(self, device):
        """
        Parameters
        ----------
        device : torch.device or str
            Device the batched complex solver moves its inputs to.
        """
        super().__init__()
        self.device = device

    def lstq(self, A, Y, lamb=0.0):
        """
        Differentiable least square via a batched reduced QR of ``A``: returns ``R^-1 Q^T Y``.

        ``lamb`` is accepted for API compatibility and is unused.
        """
        q, r = torch.linalg.qr(A)
        return torch.inverse(r) @ q.permute(0, 2, 1) @ Y

    def lstq_pinv(self, A, Y, lamb=0.0):
        """
        Differentiable inverse least square: solves for ``X`` in ``Y @ X ~= A``.

        A 2D ``Y`` uses its pseudo-inverse. A batched ``Y`` uses the normal equations ``(Y^T Y)^-1 Y^T A``.
        ``lamb`` is accepted for API compatibility and is unused.
        """
        if Y.dim() == 2:
            return torch.matmul(torch.pinverse(Y), A)
        return torch.matmul(torch.matmul(torch.inverse(torch.matmul(Y.permute(0, 2, 1), Y)), Y.permute(0, 2, 1)), A)

    def lstq_pinv_complex_np(self, A, Y, lamb=0.0):
        """
        Differentiable inverse least square for stacked complex inputs.

        A 2D ``Y`` is solved with NumPy's pseudo-inverse. A batched ``Y`` is solved on ``self.device`` with the
        complex normal equations ``(Y^H Y)^-1 Y^H A``, and the trailing dimension of the result is dropped.
        ``lamb`` is accepted for API compatibility and is unused.
        """
        if Y.ndim == 2:
            return np.matmul(np.linalg.pinv(Y), A)
        Y = Y.to(self.device)
        A = A.to(Y)
        x = torch.matmul(torch.conj(Y).permute(0, 2, 1), Y)
        x = torch.matmul(torch.inverse(x), torch.conj(Y).permute(0, 2, 1))
        return torch.bmm(x, A)[..., 0]


def R2star_B0_real_S0_complex_mapping(
    prediction: torch.Tensor,
    TEs: Union[Optional[List[float]], float],
    brain_mask: torch.Tensor,
    head_mask: torch.Tensor,
    fully_sampled: bool = True,
    shift: bool = False,
    fft_centered: bool = True,
    fft_normalization: str = "ortho",
    spatial_dims: Sequence[int] = None,
):
    """
    Maps the prediction to R2*, B0 and complex S0 maps.

    Parameters
    ----------
    prediction : torch.Tensor
        Multi-echo complex images stored as real tensors with a trailing dimension of size 2.
    TEs : Union[Optional[List[float]], float]
        The TEs of the images.
    brain_mask : torch.Tensor
        The brain mask of the images.
    head_mask : torch.Tensor
        The head mask of the images.
    fully_sampled : bool
        Forwarded to ``B0_phi_mapping``.
    shift : bool
        If True, a k-space fftshift is applied in the B0 and S0 estimation.
    fft_centered : bool
        Whether to center the FFT for a real- or complex-valued input.
    fft_normalization : str
        Whether to normalize the FFT output (None, "ortho", "backward", "forward", "none").
    spatial_dims : Sequence[int]
        Spatial dimensions to keep in the FFT.

    Returns
    -------
    R2star : torch.Tensor
        The R2* map.
    S0_real : torch.Tensor
        The real part of the complex S0 map.
    B0 : torch.Tensor
        The B0 map (the negated output of ``B0_phi_mapping``).
    S0_imag : torch.Tensor
        The imaginary part of the complex S0 map. Callers in this module store it as the "phi" map.
    """
    R2star_map = R2star_S0_mapping(prediction, TEs)
    B0_map = -B0_phi_mapping(
        prediction,
        TEs,
        brain_mask,
        head_mask,
        fully_sampled,
        shift=shift,
        fft_centered=fft_centered,
        fft_normalization=fft_normalization,
        spatial_dims=spatial_dims,
    )[0]
    S0_map_real, S0_map_imag = S0_mapping_complex(
        prediction,
        TEs,
        R2star_map,
        B0_map,
        shift=shift,
        fft_centered=fft_centered,
        fft_normalization=fft_normalization,
        spatial_dims=spatial_dims,
    )
    return R2star_map, S0_map_real, B0_map, S0_map_imag


def R2star_S0_mapping(
    prediction: torch.Tensor,
    TEs: Union[Optional[List[float]], float],
    scaling_factor: float = 1e-3,
):
    """
    R2* map estimation for multi-echo GRE from the magnitude of images acquired at multiple TEs.

    For every voxel, ``log|S| = a - R2* * TE * scaling_factor`` is fitted by weighted linear least squares with
    weights ``|S|``. This is equivalent to ``np.polyfit(TE * scaling_factor, log|S|, 1, w=sqrt(|S|))``, and ``R2*``
    is the negated slope. The fit is vectorized over all voxels.

    Parameters
    ----------
    prediction : torch.Tensor
        Complex images stored as real tensors with a trailing dimension of size 2. The first dimension indexes the
        echoes.
    TEs : Union[Optional[List[float]], float]
        The TEs of the images, one per echo.
    scaling_factor : float
        Factor applied to the TEs before fitting.

    Returns
    -------
    R2star : torch.Tensor
        The R2* map with the spatial shape of ``prediction``, on the CPU, in the default floating dtype.
    """
    magnitude = torch.abs(torch.view_as_complex(prediction)) + 1e-8
    # (echoes, voxels), detached and fitted on the CPU in float64 for numerical stability.
    signal = torch.flatten(magnitude, start_dim=1, end_dim=-1).detach().cpu().to(torch.float64)
    te = torch.as_tensor(TEs, dtype=torch.float64, device="cpu").flatten()[:, None] * scaling_factor  # (echoes, 1)

    # polyfit's weights multiply the unsquared residuals, so the effective squared weights are sqrt(|S|)**2 = |S|.
    weights = signal
    log_signal = torch.log(signal)
    weight_sum = weights.sum(dim=0)
    te_mean = (weights * te).sum(dim=0) / weight_sum
    log_mean = (weights * log_signal).sum(dim=0) / weight_sum
    te_centered = te - te_mean
    slope = (weights * te_centered * (log_signal - log_mean)).sum(dim=0) / (weights * te_centered**2).sum(dim=0)

    return torch.reshape(-slope, magnitude.shape[1:4]).to(torch.get_default_dtype())


def B0_phi_mapping(
    prediction: torch.Tensor,
    TEs: Union[Optional[List[float]], float],
    brain_mask: torch.Tensor,
    head_mask: torch.Tensor,
    fully_sampled: bool = True,
    scaling_factor: float = 1e-3,
    shift: bool = False,
    fft_centered: bool = True,
    fft_normalization: str = "ortho",
    spatial_dims: Sequence[int] = None,
):
    """
    B0 map and Phi map estimation for multi-echo GRE from the phase of images acquired at multiple TEs.

    The complex images are smoothed with a 9x9 gaussian (sigma 1), their phases are unwrapped inside the head mask,
    and B0 is fitted by least squares to the phase differences of consecutive echoes. The input ``prediction`` is
    not modified.

    Parameters
    ----------
    prediction : torch.Tensor
        Multi-echo complex images stored as real tensors with a trailing dimension of size 2.
    TEs : Union[Optional[List[float]], float]
        The TEs of the images.
    brain_mask : torch.Tensor
        The brain mask, used to remove whole multiples of 2*pi from the phase differences.
    head_mask : torch.Tensor
        The head mask. Pixels with a value <= 0.5 are excluded from phase unwrapping; the B0 map is multiplied by
        ``abs(head_mask)``.
    fully_sampled : bool
        Currently unused.
    scaling_factor : float
        Factor applied to the TEs.
    shift : bool
        If True, the smoothed images are fftshift-ed in k-space (fft2 -> fftshift -> ifft2) before phase extraction.
    fft_centered : bool
        Whether to center the FFT for a real- or complex-valued input.
    fft_normalization : str
        Whether to normalize the FFT output (None, "ortho", "backward", "forward", "none").
    spatial_dims : Sequence[int]
        Spatial dimensions to keep in the FFT.

    Returns
    -------
    B0 : torch.Tensor
        The B0 map.
    phi : torch.Tensor
        The phi map.
    """
    lsq = LeastSquares(device=prediction.device)

    # Number of trailing echoes left out of the phase-difference fit.
    TEnotused = 3  # if fully_sampled else 3
    TEs = torch.tensor(TEs)

    # brain_mask is used only for descale of phase difference (so that phase_diff is in between -2pi and 2pi)
    brain_mask_descale = brain_mask
    shape = prediction.shape

    # apply gaussian blur with radius r to
    smoothing = GaussianSmoothing(
        channels=2,
        kernel_size=9,
        sigma=1.0,
        dim=2,
        shift=shift,
        fft_centered=fft_centered,
        fft_normalization=fft_normalization,
        spatial_dims=spatial_dims,
    )
    # Work on a copy: ``unsqueeze``/``permute`` return views, so writing the smoothed echoes into them would overwrite
    # the caller's ``prediction``, which is reused afterwards (e.g. for the S0 fit).
    prediction = prediction.clone().unsqueeze(1).permute([0, 1, 4, 2, 3])  # add a dummy batch dimension
    for i in range(prediction.shape[0]):
        # Reflect-pad by 4 so the unpadded 9x9 convolution preserves the spatial size.
        prediction[i] = smoothing(F.pad(prediction[i], (4, 4, 4, 4), mode="reflect"))
    prediction = prediction.permute([0, 1, 3, 4, 2]).squeeze(1)

    if shift:
        prediction = ifft2(
            torch.fft.fftshift(fft2(prediction, fft_centered, fft_normalization, spatial_dims), dim=(1, 2)),
            fft_centered,
            fft_normalization,
            spatial_dims,
        )

    phase = torch.angle(torch.view_as_complex(prediction))

    # unwrap phases; pixels outside the head (head_mask <= 0.5) are masked out of the unwrapping.
    phase_unwrapped = torch.zeros_like(phase)
    mask_head_np = np.invert(head_mask.cpu().detach().numpy() > 0.5)

    # loop over echo times
    for i in range(phase.shape[0]):
        phase_unwrapped[i] = torch.from_numpy(
            unwrap_phase(np.ma.array(phase[i].detach().cpu().numpy(), mask=mask_head_np)).data
        ).to(prediction)

    phase_diff_set = []
    TE_diff = []

    # obtain phase differences and TE differences
    for i in range(phase_unwrapped.shape[0] - TEnotused):
        phase_diff_set.append(torch.flatten(phase_unwrapped[i + 1] - phase_unwrapped[i]))
        # Remove the whole number of 2*pi turns given by the brain-averaged phase difference.
        phase_diff_set[i] = (
            phase_diff_set[i]
            - torch.round(
                torch.abs(
                    torch.sum(phase_diff_set[i] * torch.flatten(brain_mask_descale))
                    / torch.sum(brain_mask_descale)
                    / 2
                    / np.pi
                )
            )
            * 2
            * np.pi
        )
        TE_diff.append(TEs[i + 1] - TEs[i])  # type: ignore

    phase_diff_set = torch.stack(phase_diff_set, 0)
    TE_diff = torch.stack(TE_diff, 0).to(prediction)

    # least squares fitting to obtain phase map
    B0_map_tmp = lsq.lstq_pinv(
        phase_diff_set.unsqueeze(2).permute(1, 0, 2), TE_diff.unsqueeze(1) * scaling_factor  # type: ignore
    )
    B0_map = B0_map_tmp.reshape(shape[-3], shape[-2])
    B0_map = B0_map * torch.abs(head_mask)

    # obtain phi map
    phi_map = (phase_unwrapped[0] - scaling_factor * TEs[0] * B0_map).squeeze(0)  # type: ignore

    return B0_map.to(prediction), phi_map.to(prediction)


def S0_mapping_complex(
    prediction: torch.Tensor,
    TEs: Union[Optional[List[float]], float],
    R2star_map: torch.Tensor,
    B0_map: torch.Tensor,
    scaling_factor: float = 1e-3,
    shift: bool = False,
    fft_centered: bool = True,
    fft_normalization: str = "ortho",
    spatial_dims: Sequence[int] = None,
):
    """
    Complex S0 mapping.

    For every voxel, the complex S0 is fitted by least squares to
    ``S(TE) = S0 * exp(-TE * scaling_factor * (R2* + 1j * B0))`` over all echoes, using one TE per echo.

    Parameters
    ----------
    prediction : torch.Tensor
        Multi-echo complex images stored as real tensors with a trailing dimension of size 2. The first dimension
        indexes the echoes.
    TEs : Union[Optional[List[float]], float]
        The TEs of the images, at least one per echo; extra trailing TEs are ignored.
    R2star_map : torch.Tensor
        The R2* map.
    B0_map : torch.Tensor
        The B0 map.
    scaling_factor : float
        Factor applied to the TEs.
    shift : bool
        If True, the S0 map is fftshift-ed in k-space (fft2 -> fftshift -> ifft2).
    fft_centered : bool
        Whether to center the FFT for a real- or complex-valued input.
    fft_normalization : str
        Whether to normalize the FFT output (None, "ortho", "backward", "forward", "none").
    spatial_dims : Sequence[int]
        Spatial dimensions to keep in the FFT.

    Returns
    -------
    S0_real : torch.Tensor
        The real part of the S0 map.
    S0_imag : torch.Tensor
        The imaginary part of the S0 map.
    """
    lsq = LeastSquares(device=prediction.device)

    prediction = torch.view_as_complex(prediction)
    prediction_flatten = prediction.reshape(prediction.shape[0], -1)
    num_echoes = prediction_flatten.shape[0]

    TEs = torch.tensor(TEs).to(prediction)

    R2star_B0_complex_map = R2star_map.to(prediction) + 1j * B0_map.to(prediction)
    R2star_B0_complex_map_flatten = R2star_B0_complex_map.flatten()

    # One TE per echo: (echoes, voxels) exponents of the signal model.
    TEs_r2 = TEs[:num_echoes].unsqueeze(1) * -R2star_B0_complex_map_flatten  # type: ignore

    S0_map = lsq.lstq_pinv_complex_np(
        prediction_flatten.permute(1, 0).unsqueeze(2),
        torch.exp(scaling_factor * TEs_r2.permute(1, 0).unsqueeze(2)),
    )

    S0_map = torch.view_as_real(S0_map.reshape(prediction.shape[1:]))

    if shift:
        S0_map = ifft2(
            torch.fft.fftshift(fft2(S0_map, fft_centered, fft_normalization, spatial_dims), dim=(0, 1)),
            fft_centered,
            fft_normalization,
            spatial_dims,
        )

    S0_map_real, S0_map_imag = torch.chunk(S0_map, 2, dim=-1)

    return S0_map_real.squeeze(-1), S0_map_imag.squeeze(-1)