Source code for mridc.collections.reconstruction.parts.utils

# encoding: utf-8
__author__ = "Dimitrios Karkalousos"

# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI

from typing import Any, Optional, Sequence, Tuple, Union

import numpy as np
import torch

__all__ = [
    "apply_mask",
    "mask_center",
    "batched_mask_center",
    "center_crop",
    "complex_center_crop",
    "center_crop_to_smallest",
]

from mridc.collections.reconstruction.data.subsample import MaskFunc


def apply_mask(
    data: torch.Tensor,
    mask_func: MaskFunc,
    seed: Optional[Union[int, Tuple[int, ...]]] = None,
    padding: Optional[Sequence[int]] = None,
    shift: bool = False,
    half_scan_percentage: Optional[float] = 0.0,
    center_scale: Optional[float] = 0.02,
    existing_mask: Optional[torch.Tensor] = None,
) -> Tuple[Any, Any, int]:
    """
    Subsample given k-space by multiplying with a mask.

    Parameters
    ----------
    data: The input k-space data. This should have at least 3 dimensions, where dimensions -3 and -2 are the
        spatial dimensions, and the final dimension has size 2 (for complex values).
    mask_func: A function that takes a shape (tuple of ints) and a random number seed and returns a mask.
    seed: Seed for the random number generator.
    padding: Padding value to apply for mask.
    shift: Toggle to shift mask when subsampling. Applicable on 2D data.
    half_scan_percentage: Percentage of k-space to be dropped.
    center_scale: Scale of the center of the mask. Applicable on Gaussian masks.
    existing_mask: When given, use this mask instead of generating a new one.

    Returns
    -------
    Tuple of subsampled k-space, mask, and acceleration factor.
    """
    shape = np.array(data.shape)
    shape[:-3] = 1

    if existing_mask is None:
        mask, acc = mask_func(shape, seed, half_scan_percentage=half_scan_percentage, scale=center_scale)
    else:
        mask = existing_mask
        # torch.Tensor exposes numel(), not a .size attribute; `mask.size / mask.sum()` would divide a bound method.
        acc = mask.numel() / mask.sum()

    mask = mask.to(data.device)

    if padding is not None and padding[0] != 0:
        mask[:, :, : padding[0]] = 0
        mask[:, :, padding[1] :] = 0  # padding value inclusive on right of zeros

    if shift:
        mask = torch.fft.fftshift(mask, dim=(1, 2))

    masked_data = data * mask + 0.0  # the + 0.0 removes the sign of the zeros

    return masked_data, mask, acc
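
# Usage sketch (illustrative, not part of the library): exercising apply_mask
# through the `existing_mask` path, which sidesteps constructing a MaskFunc
# subclass. All shapes and the `_demo_` name are hypothetical.
def _demo_apply_mask():
    kspace = torch.randn(1, 15, 320, 320, 2)  # (slice, coil, height, width, real/imag)
    # Keep every 4th phase-encoding line; the mask broadcasts over coils and the complex dim.
    mask = torch.zeros(1, 1, 320, 320, 1)
    mask[..., ::4, :] = 1
    masked_kspace, mask, acc = apply_mask(kspace, mask_func=None, existing_mask=mask)
    print(masked_kspace.shape, float(acc))  # acceleration factor of 4x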


def mask_center(
    x: torch.Tensor, mask_from: Optional[int], mask_to: Optional[int], mask_type: str = "2D"
) -> torch.Tensor:
    """
    Initializes a mask with the center filled in, keeping only the central region of the input.

    Parameters
    ----------
    x: The input real image or batch of real images.
    mask_from: Part of center to start filling.
    mask_to: Part of center to end filling.
    mask_type: Type of mask to apply. Can be either "1D" or "2D".

    Returns
    -------
    A mask with the center filled.
    """
    mask = torch.zeros_like(x)

    if isinstance(mask_from, list):
        mask_from = mask_from[0]
    if isinstance(mask_to, list):
        mask_to = mask_to[0]

    if mask_type == "1D":
        mask[:, :, :, mask_from:mask_to] = x[:, :, :, mask_from:mask_to]
    elif mask_type == "2D":
        mask[:, :, mask_from:mask_to] = x[:, :, mask_from:mask_to]

    return mask
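
# Usage sketch (illustrative): mask_center can isolate the fully sampled
# autocalibration region at the center of k-space. Shapes are hypothetical.
def _demo_mask_center():
    x = torch.randn(1, 15, 320, 320)
    acs = mask_center(x, mask_from=148, mask_to=172, mask_type="1D")
    # Columns outside [148, 172) are zeroed; the center band is kept as-is.
    assert torch.count_nonzero(acs[..., :148]) == 0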


def batched_mask_center(
    x: torch.Tensor, mask_from: torch.Tensor, mask_to: torch.Tensor, mask_type: str = "2D"
) -> torch.Tensor:
    """
    Initializes a mask with the center filled in. Can operate with different masks for each batch element.

    Parameters
    ----------
    x: The input real image or batch of real images.
    mask_from: Part of center to start filling.
    mask_to: Part of center to end filling.
    mask_type: Type of mask to apply. Can be either "1D" or "2D".

    Returns
    -------
    A mask with the center filled.
    """
    if mask_from.shape != mask_to.shape:
        raise ValueError("mask_from and mask_to must match shapes.")
    if mask_from.ndim != 1:
        raise ValueError("mask_from and mask_to must have 1 dimension.")
    # A single bound pair is broadcast over the batch; otherwise one pair per batch element is required.
    # (Checking mask_to separately would reject the broadcast case handled below, since shapes already match.)
    if mask_from.shape[0] not in (1, x.shape[0]):
        raise ValueError("mask_from and mask_to must have batch_size length.")

    if mask_from.shape[0] == 1:
        mask = mask_center(x, int(mask_from), int(mask_to), mask_type=mask_type)
    else:
        mask = torch.zeros_like(x)
        for i, (start, end) in enumerate(zip(mask_from, mask_to)):
            mask[i, :, :, start:end] = x[i, :, :, start:end]

    return mask
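
# Usage sketch (illustrative): one center band per batch element. The bounds
# tensors and shapes below are hypothetical.
def _demo_batched_mask_center():
    x = torch.randn(3, 15, 320, 320)
    start = torch.tensor([140, 148, 150])
    end = torch.tensor([180, 172, 170])
    per_sample = batched_mask_center(x, start, end)  # element i keeps columns [start[i], end[i])
    print(per_sample.shape)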


def center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor:
    """
    Apply a center crop to the input real image or batch of real images.

    Parameters
    ----------
    data: The input tensor to be center cropped. It should have at least 2 dimensions and the cropping is applied
        along the last two dimensions.
    shape: The output shape. The shape should be smaller than the corresponding dimensions of data.

    Returns
    -------
    The center cropped image.
    """
    if not (0 < shape[0] <= data.shape[-2] and 0 < shape[1] <= data.shape[-1]):
        raise ValueError("Invalid shapes.")

    w_from = torch.div((data.shape[-2] - shape[0]), 2, rounding_mode="trunc")
    h_from = torch.div((data.shape[-1] - shape[1]), 2, rounding_mode="trunc")
    w_to = w_from + shape[0]
    h_to = h_from + shape[1]

    return data[..., w_from:w_to, h_from:h_to]  # type: ignore
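
# Usage sketch (illustrative): cropping a batch of magnitude images to a
# square field of view. Shapes are hypothetical.
def _demo_center_crop():
    img = torch.randn(4, 1, 320, 368)
    cropped = center_crop(img, (256, 256))
    assert cropped.shape == (4, 1, 256, 256)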


def complex_center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor:
    """
    Apply a center crop to the input image or batch of complex images.

    Parameters
    ----------
    data: The complex input tensor to be center cropped. It should have at least 3 dimensions, the cropping is
        applied along dimensions -3 and -2, and the last dimension should have a size of 2.
    shape: The output shape. The shape should be smaller than the corresponding dimensions of data.

    Returns
    -------
    The center cropped image.
    """
    if not (0 < shape[0] <= data.shape[-3] and 0 < shape[1] <= data.shape[-2]):
        raise ValueError("Invalid shapes.")

    w_from = torch.div((data.shape[-3] - shape[0]), 2, rounding_mode="trunc")
    h_from = torch.div((data.shape[-2] - shape[1]), 2, rounding_mode="trunc")
    w_to = w_from + shape[0]
    h_to = h_from + shape[1]

    return data[..., w_from:w_to, h_from:h_to, :]  # type: ignore
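
# Usage sketch (illustrative): same crop, but for complex data where the
# trailing dimension holds (real, imag). Shapes are hypothetical.
def _demo_complex_center_crop():
    img = torch.randn(4, 1, 320, 368, 2)
    cropped = complex_center_crop(img, (256, 256))
    assert cropped.shape == (4, 1, 256, 256, 2)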


def center_crop_to_smallest(
    x: Union[torch.Tensor, np.ndarray], y: Union[torch.Tensor, np.ndarray]
) -> Tuple[Union[torch.Tensor, np.ndarray], Union[torch.Tensor, np.ndarray]]:
    """
    Apply a center crop on the larger image to the size of the smaller.

    The minimum is taken over dim=-1 and dim=-2. If x is smaller than y at dim=-1 and y is smaller than x at
    dim=-2, then the returned dimension will be a mixture of the two.

    Parameters
    ----------
    x: The first image.
    y: The second image.

    Returns
    -------
    Tuple of tensors x and y, each cropped to the minimum size.
    """
    smallest_width = min(x.shape[-1], y.shape[-1])
    smallest_height = min(x.shape[-2], y.shape[-2])

    x = center_crop(x, (smallest_height, smallest_width))
    y = center_crop(y, (smallest_height, smallest_width))

    return x, y
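
# Usage sketch (illustrative): aligning a target/output pair whose minima lie
# on different axes, e.g. before computing a metric. Shapes are hypothetical.
def _demo_center_crop_to_smallest():
    target = torch.randn(1, 320, 280)
    output = torch.randn(1, 300, 320)
    target, output = center_crop_to_smallest(target, output)
    # Both are cropped to the elementwise minimum over the last two dims.
    assert target.shape[-2:] == output.shape[-2:] == (300, 280)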