# Source code for the k-space sub-sampling mask functions.

# encoding: utf-8
# Parts of the code have been taken from an external project; the original reference link was lost (TODO: restore it).
__author__ = "Dimitrios Karkalousos"

import contextlib
from typing import Optional, Sequence, Tuple, Union

import numpy as np
import torch

@contextlib.contextmanager
def temp_seed(rng: np.random.RandomState, seed: Optional[Union[int, Tuple[int, ...]]]):
    """
    Temporarily seed a random number generator and restore its previous state on exit.

    Parameters
    ----------
    rng: The random number generator. Only ``get_state``, ``seed`` and ``set_state`` are used, so a
        ``np.random.RandomState`` instance (or any object with those three methods) works.
    seed: The seed to set. If None, the generator is used as-is and its state is NOT restored afterwards.

    Returns
    -------
    A context manager. When ``seed`` is given, the generator's prior state is restored on exit, even if the
    body of the ``with`` statement raises.
    """
    if seed is None:
        # Nothing was seeded, so there is no saved state to restore.
        yield
        return

    state = rng.get_state()
    rng.seed(seed)
    try:
        yield
    finally:
        rng.set_state(state)
class MaskFunc:
    """
    Base class for k-space sub-sampling mask functions.

    It stores the candidate ``(center_fraction, acceleration)`` pairs and a per-instance random number generator
    used to pick among them. Subclasses implement ``__call__`` to build the actual mask.
    """

    def __init__(self, center_fractions: Sequence[float], accelerations: Sequence[int]):
        """
        Initialize the mask function.

        Parameters
        ----------
        center_fractions: Fraction of low-frequency columns to be retained. If multiple values are provided, then
            one of these numbers is chosen uniformly each time. For 2D setting this value corresponds to setting the
            Full-Width-Half-Maximum.
        accelerations: Amount of under-sampling. This should have the same length as center_fractions. If multiple
            values are provided, then one of these is chosen uniformly each time.

        Raises
        ------
        ValueError
            If ``center_fractions`` and ``accelerations`` do not have the same length.
        """
        num_fractions = len(center_fractions)
        num_accelerations = len(accelerations)
        if num_fractions != num_accelerations:
            raise ValueError("Number of center fractions should match number of accelerations")

        self.center_fractions = center_fractions
        self.accelerations = accelerations
        # Per-instance generator: reseeding it does not touch NumPy's global random state.
        self.rng = np.random.RandomState()  # pylint: disable=no-member

    def __call__(
        self,
        shape: Sequence[int],
        seed: Optional[Union[int, Tuple[int, ...]]] = None,
        half_scan_percentage: Optional[float] = 0.0,
        scale: Optional[float] = 0.02,
    ) -> Tuple[torch.Tensor, int]:
        """
        Create a mask; must be overridden by subclasses.

        Parameters
        ----------
        shape: Shape of the input tensor.
        seed: Seed for the random number generator.
        half_scan_percentage: Percentage of the low-frequency columns to be retained.
        scale: Scale of the mask.

        Returns
        -------
        A tuple of the mask and the acceleration used.

        Raises
        ------
        NotImplementedError
            Always; the base class defines no sampling pattern.
        """
        raise NotImplementedError

    def choose_acceleration(self):
        """
        Pick one ``(center_fraction, acceleration)`` pair uniformly at random using ``self.rng``.

        Returns
        -------
        A tuple ``(center_fraction, acceleration)`` taken from the same index of both sequences.
        """
        index = self.rng.randint(0, len(self.accelerations))
        return self.center_fractions[index], self.accelerations[index]
class RandomMaskFunc(MaskFunc):
    """
    RandomMaskFunc creates a sub-sampling mask of a given shape.

    The mask selects a subset of columns from the input k-space data. If the k-space data has N columns, the mask
    picks out:
        1. N_low_freqs = (N * center_fraction) columns in the center corresponding to low-frequencies.
        2. The other columns are selected uniformly at random with a probability equal to
           prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). This ensures that the expected number of
           columns selected is equal to (N / acceleration).

    It is possible to use multiple center_fractions and accelerations, in which case one possible (center_fraction,
    acceleration) is chosen uniformly at random each time the RandomMaskFunc object is called.

    For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there is a 50% probability that
    4-fold acceleration with 8% center fraction is selected and a 50% probability that 8-fold acceleration with 4%
    center fraction is selected.
    """

    def __call__(
        self,
        shape: Sequence[int],
        seed: Optional[Union[int, Tuple[int, ...]]] = None,
        half_scan_percentage: Optional[float] = 0.0,
        scale: Optional[float] = 0.02,
    ) -> Tuple[torch.Tensor, int]:
        """
        Build a random column mask.

        Parameters
        ----------
        shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn
            along the second last dimension.
        seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time
            for the same shape. The random state is reset afterwards.
        half_scan_percentage: Unused by this mask type (kept for interface compatibility).
        scale: Unused by this mask type (kept for interface compatibility).

        Returns
        -------
        A tuple ``(mask, acceleration)``: a float32 tensor that is 1 everywhere except along dimension -2, where it
        holds the selected columns, and the acceleration that was chosen.

        Raises
        ------
        ValueError
            If ``shape`` has fewer than 3 dimensions.
        """
        if len(shape) < 3:
            raise ValueError("Shape should have 3 or more dimensions")

        with temp_seed(self.rng, seed):
            num_cols = shape[-2]
            center_fraction, acceleration = self.choose_acceleration()
            num_low_freqs = int(round(num_cols * center_fraction))

            # Keep-probability for non-central columns so that, in expectation, num_cols / acceleration columns
            # are selected in total (central block included).
            keep_prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs)
            draws = self.rng.uniform(size=num_cols)
            selected = draws < keep_prob  # type: ignore

            # Always keep the contiguous low-frequency block around the middle column.
            start = torch.div((num_cols - num_low_freqs + 1), 2, rounding_mode="trunc").item()
            selected[start : start + num_low_freqs] = True

            # Broadcastable shape: size 1 everywhere except the sampled (second last) dimension.
            target_shape = [1] * len(shape)
            target_shape[-2] = num_cols
            mask = torch.from_numpy(selected.reshape(*target_shape).astype(np.float32))

            return mask, acceleration
class Equispaced1DMaskFunc(MaskFunc):
    """
    Equispaced1DMaskFunc creates a sub-sampling mask of a given shape.

    The mask selects a subset of columns from the input k-space data. If the k-space data has N columns, the mask
    picks out:
        1. N_low_freqs = (N * center_fraction) columns in the center corresponding to low-frequencies.
        2. The other columns are selected with equal spacing at a proportion that reaches the desired acceleration
           rate taking into consideration the number of low frequencies. This ensures that the expected number of
           columns selected is equal to (N / acceleration).

    It is possible to use multiple center_fractions and accelerations, in which case one possible (center_fraction,
    acceleration) is chosen uniformly at random each time the Equispaced1DMaskFunc object is called.

    Note that this function may not give exactly equispaced samples, because the fractional sample positions are
    rounded to integer columns. This will require modifications to standard GRAPPA approaches. Nonetheless, this
    aspect of the function has been preserved to match the public multicoil data.
    """

    def __call__(
        self,
        shape: Sequence[int],
        seed: Optional[Union[int, Tuple[int, ...]]] = None,
        half_scan_percentage: Optional[float] = 0.0,
        scale: Optional[float] = 0.02,
    ) -> Tuple[torch.Tensor, int]:
        """
        Build an equispaced column mask.

        Parameters
        ----------
        shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn
            along the second last dimension.
        seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time
            for the same shape. The random state is reset afterwards.
        half_scan_percentage: Unused by this mask type (kept for interface compatibility).
        scale: Unused by this mask type (kept for interface compatibility).

        Returns
        -------
        A tuple ``(mask, acceleration)``: a float32 tensor that is 1 everywhere except along dimension -2, where it
        holds the selected columns, and the acceleration that was chosen.

        Raises
        ------
        ValueError
            If ``shape`` has fewer than 3 dimensions.
        """
        if len(shape) < 3:
            raise ValueError("Shape should have 3 or more dimensions")

        with temp_seed(self.rng, seed):
            center_fraction, acceleration = self.choose_acceleration()
            num_cols = shape[-2]
            num_low_freqs = int(round(num_cols * center_fraction))

            # Fully sampled low-frequency block around the middle column.
            column_mask = np.zeros(num_cols, dtype=np.float32)
            start = torch.div((num_cols - num_low_freqs + 1), 2, rounding_mode="trunc").item()
            column_mask[start : start + num_low_freqs] = True  # type: ignore

            # Spacing of the outer columns, corrected so that together with the central block roughly
            # num_cols / acceleration columns end up selected.
            spacing = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols)
            first_col = self.rng.randint(0, round(spacing))
            positions = np.around(np.arange(first_col, num_cols - 1, spacing)).astype(np.uint)
            column_mask[positions] = True

            # Broadcastable shape: size 1 everywhere except the sampled (second last) dimension.
            target_shape = [1] * len(shape)
            target_shape[-2] = num_cols
            mask = torch.from_numpy(column_mask.reshape(*target_shape).astype(np.float32))

            return mask, acceleration
class Equispaced2DMaskFunc(MaskFunc):
    """
    Same as Equispaced1DMaskFunc, but for 2D k-space data.

    A central rectangle of low frequencies is fully sampled. In addition, a regular grid of points is sampled along
    both the second-last (columns) and third-last (rows) dimensions.
    """

    def __call__(
        self,
        shape: Sequence[int],
        seed: Optional[Union[int, Tuple[int, ...]]] = None,
        half_scan_percentage: Optional[float] = 0.0,
        scale: Optional[float] = 0.02,
    ) -> Tuple[torch.Tensor, int]:
        """
        Build a 2D equispaced mask.

        Parameters
        ----------
        shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn
            along the second-last and third-last dimensions.
        seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time for
            the same shape. The random state is reset afterwards.
        half_scan_percentage: Unused by this mask type (kept for interface compatibility).
        scale: Unused by this mask type (kept for interface compatibility).

        Returns
        -------
        A tuple ``(mask, acceleration)``: a float32 tensor that is 1 everywhere except along dimensions -3 and -2,
        where it holds the sampling pattern, and the acceleration that was chosen.

        Raises
        ------
        ValueError
            If ``shape`` has fewer than 3 dimensions.
        """
        if len(shape) < 3:
            raise ValueError("Shape should have 3 or more dimensions")

        with temp_seed(self.rng, seed):
            center_fraction, acceleration = self.choose_acceleration()
            # The chosen acceleration and center fraction are halved and applied along each axis.
            # NOTE(review): a grid with step acceleration/2 on both axes undersamples the outer region by roughly
            # (acceleration / 2) ** 2, which equals `acceleration` only when acceleration == 4 — confirm intended.
            axis_acceleration = acceleration / 2
            axis_center_fraction = center_fraction / 2

            num_cols = shape[-2]
            num_rows = shape[-3]
            num_low_freq_cols = int(round(num_cols * axis_center_fraction))
            num_low_freq_rows = int(round(num_rows * axis_center_fraction))

            # Fully sampled low-frequency rectangle around the middle of the grid.
            mask = np.zeros([num_rows, num_cols], dtype=np.float32)
            pad_cols = torch.div((num_cols - num_low_freq_cols + 1), 2, rounding_mode="trunc").item()
            pad_rows = torch.div((num_rows - num_low_freq_rows + 1), 2, rounding_mode="trunc").item()
            mask[pad_rows : pad_rows + num_low_freq_rows, pad_cols : pad_cols + num_low_freq_cols] = True  # type: ignore

            # Regular grid: every row/column at (truncated) multiples of axis_acceleration. One vectorised
            # outer-product assignment replaces the former nested Python loop.
            grid_rows = np.arange(0, num_rows, axis_acceleration).astype(int)
            grid_cols = np.arange(0, num_cols, axis_acceleration).astype(int)
            mask[np.ix_(grid_rows, grid_cols)] = True

            # Broadcastable shape: size 1 everywhere except the two sampled dimensions.
            mask_shape = [1 for _ in shape]
            mask_shape[-2] = num_cols
            mask_shape[-3] = num_rows
            mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32))

            # Return the originally chosen acceleration; previously `acceleration / 2 * 2` turned it into a float.
            return mask, acceleration
class Gaussian1DMaskFunc(MaskFunc):
    """
    Creates a 1D sub-sampling mask of a given shape.

    For autocalibration purposes, a central band of ``int(shape[-3] * scale)`` rows is fully sampled. The remaining
    rows are sampled according to a Gaussian distribution. The center fractions here act as Full-Width at
    Half-Maximum (FWHM) values.
    """

    def __call__(
        self,
        shape: Union[Sequence[int], np.ndarray],
        seed: Optional[Union[int, Tuple[int, ...]]] = None,
        half_scan_percentage: Optional[float] = 0.0,
        scale: Optional[float] = 0.02,
    ) -> Tuple[torch.Tensor, int]:
        """
        Build a 1D Gaussian mask.

        Parameters
        ----------
        shape: The shape of the mask to be created. The shape should have at least 3 dimensions. The returned mask
            varies along the second last dimension.
        seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time
            for the same shape. The random state is reset afterwards.
        half_scan_percentage: Optional; fraction of the leading rows (axis 0 of the internal 2D mask) that is zeroed.
        scale: Fraction of ``shape[-3]`` used as the height of the fully sampled central band.

        Returns
        -------
        A tuple ``(mask, acceleration)``: a float32 tensor that is 1 everywhere except along dimension -2, and the
        acceleration that was chosen.

        Raises
        ------
        ValueError
            If ``shape`` has fewer than 3 dimensions.
        """
        if len(shape) < 3:
            raise ValueError("Shape should have 3 or more dimensions")

        dims = [1 for _ in shape]
        self.shape = tuple(shape[-3:-1])
        dims[-2] = self.shape[-1]

        # All randomness (acceleration choice and Gaussian draws) comes from self.rng inside temp_seed, so a given
        # seed reproduces the mask. Previously the seed was ignored and the draws used the global np.random state.
        with temp_seed(self.rng, seed):
            full_width_half_maximum, acceleration = self.choose_acceleration()
            if not isinstance(full_width_half_maximum, list):
                full_width_half_maximum = [full_width_half_maximum] * 2
            self.full_width_half_maximum = full_width_half_maximum
            self.acceleration = acceleration
            self.scale = scale

            mask = self.gaussian_kspace()
            mask[tuple(self.gaussian_coordinates())] = 1.0

        # NOTE(review): gaussian_kspace and gaussian_coordinates only ever fill complete rows (axis 0), so every row
        # of `mask` is constant and `mask[0]` below yields an all-zeros or all-ones profile. Sampling probably should
        # run along axis 1 (the axis that is returned); confirm the intended layout before relying on this mask.
        mask = np.fft.ifftshift(np.fft.ifftshift(np.fft.ifftshift(mask, axes=0), axes=0), axes=(0, 1))

        if half_scan_percentage != 0:
            mask[: int(np.round(mask.shape[0] * half_scan_percentage)), :] = 0.0

        return torch.from_numpy(mask[0].reshape(dims).astype(np.float32)), acceleration

    def gaussian_kspace(self):
        """
        Create the fully sampled center: a band of ``int(self.shape[0] * self.scale)`` ones rows, with zero rows
        above and below it, of total shape ``self.shape``.
        """
        scaled = int(self.shape[0] * self.scale)
        center = np.ones((scaled, self.shape[1]))
        top_scaled = torch.div((self.shape[0] - scaled), 2, rounding_mode="trunc").item()
        bottom_scaled = self.shape[0] - scaled - top_scaled
        top = np.zeros((top_scaled, self.shape[1]))
        btm = np.zeros((bottom_scaled, self.shape[1]))
        return np.concatenate((top, center, btm))

    def gaussian_coordinates(self):
        """
        Draw ``int(self.shape[0] / self.acceleration)`` distinct rows from the Gaussian kernel using ``self.rng``.

        Returns
        -------
        A tuple ``(xsamples, ysamples)`` of index arrays covering every column of each drawn row. Both arrays are
        empty when no row is drawn.
        """
        n_sample = int(self.shape[0] / self.acceleration)
        kernel = self.gaussian_kernel()
        idxs = self.rng.choice(int(self.shape[0]), size=n_sample, replace=False, p=kernel)
        # repeat/tile give the same coordinates as concatenating per-row tiles, but also work when idxs is empty.
        xsamples = np.repeat(idxs, self.shape[1])
        ysamples = np.tile(np.arange(self.shape[1]), len(idxs))
        return xsamples, ysamples

    def gaussian_kernel(self):
        """
        Create a normalised 1D Gaussian over ``self.shape[0]`` points on [-1, 1], with the first FWHM value.

        Returns
        -------
        A 1D array of probabilities that sums to 1.
        """
        fwhm, kern_len = self.full_width_half_maximum[0], self.shape[0]
        # FWHM = sqrt(8 ln 2) * sigma for a Gaussian.
        sigma = fwhm / np.sqrt(8 * np.log(2))
        x = np.linspace(-1.0, 1.0, kern_len)
        kernel = np.exp(-(x**2 / (2 * sigma**2)))
        return kernel / kernel.sum()
class Gaussian2DMaskFunc(MaskFunc):
    """
    Creates a 2D sub-sampling mask of a given shape.

    For autocalibration purposes, data points near the k-space center will be fully sampled within an ellipse of
    which the half-axes will set to the set scale % of the fully sampled region. The remaining points will be sampled
    according to a Gaussian distribution.

    The center fractions here act as Full-Width at Half-Maximum (FWHM) values.
    """

    def __call__(
        self,
        shape: Union[Sequence[int], np.ndarray],
        seed: Optional[Union[int, Tuple[int, ...]]] = None,
        half_scan_percentage: Optional[float] = 0.0,
        scale: Optional[float] = 0.02,
    ) -> Tuple[torch.Tensor, int]:
        """
        Build a 2D Gaussian mask.

        Parameters
        ----------
        shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn
            over the third-last and second-last dimensions.
        seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time for
            the same shape. The random state is reset afterwards.
        half_scan_percentage: Optional; fraction of the leading rows (dimension -3) that is zeroed.
        scale: For autocalibration purposes, data points near the k-space center will be fully sampled within an
            ellipse of which the half-axes will set to the set scale % of the fully sampled region.

        Returns
        -------
        A tuple ``(mask, acceleration)``: a float32 tensor that is 1 everywhere except along dimensions -3 and -2,
        and the acceleration that was chosen.

        Raises
        ------
        ValueError
            If ``shape`` has fewer than 3 dimensions.
        """
        if len(shape) < 3:
            raise ValueError("Shape should have 3 or more dimensions")

        dims = [1 for _ in shape]
        self.shape = tuple(shape[-3:-1])
        dims[-3:-1] = self.shape

        # All randomness comes from self.rng inside temp_seed, so a given seed reproduces the mask. Previously the
        # seed was ignored and the Gaussian draws used the global np.random state.
        with temp_seed(self.rng, seed):
            full_width_half_maximum, acceleration = self.choose_acceleration()
            if not isinstance(full_width_half_maximum, list):
                full_width_half_maximum = [full_width_half_maximum] * 2
            self.full_width_half_maximum = full_width_half_maximum
            self.acceleration = acceleration
            self.scale = scale

            mask = self.gaussian_kspace()
            mask[tuple(self.gaussian_coordinates())] = 1.0

        if half_scan_percentage != 0:
            mask[: int(np.round(mask.shape[0] * half_scan_percentage)), :] = 0.0

        return torch.from_numpy(mask.reshape(dims).astype(np.float32)), acceleration

    def gaussian_kspace(self):
        """
        Create the fully sampled center: a float 0/1 image with ones strictly inside the ellipse centred at
        ``(shape[0] / 2, shape[1] / 2)`` with half-axes ``scale * shape[0]`` and ``scale * shape[1]``.
        """
        a, b = self.scale * self.shape[0], self.scale * self.shape[1]
        afocal, bfocal = self.shape[0] / 2, self.shape[1] / 2
        xx, yy = np.mgrid[: self.shape[0], : self.shape[1]]
        ellipse = np.power((xx - afocal) / a, 2) + np.power((yy - bfocal) / b, 2)
        return (ellipse < 1).astype(float)

    def gaussian_coordinates(self):
        """
        Draw ``int(shape[0] * shape[1] / acceleration)`` distinct grid points from the 2D kernel using ``self.rng``.

        Returns
        -------
        A tuple of two index arrays ``(rows, cols)``. Both are empty when no point is drawn; an empty draw previously
        produced ``mask[()] = 1.0``, which filled the whole mask.
        """
        n_sample = int(self.shape[0] * self.shape[1] / self.acceleration)
        kernel = self.gaussian_kernel()
        # Draw flat (C-order) indices and convert them to (row, col). This matches enumerating np.ndindex(self.shape)
        # without materialising one Python tuple per grid point.
        flat_idxs = self.rng.choice(kernel.size, size=n_sample, replace=False, p=kernel.ravel())
        return np.unravel_index(flat_idxs, self.shape)

    def gaussian_kernel(self):
        """
        Create a normalised 2D kernel: the element-wise square root of the outer product of two 1D Gaussians on
        [-1, 1], one per axis, each with its own FWHM.
        """
        kernels = []
        for fwhm, kern_len in zip(self.full_width_half_maximum, self.shape):
            # FWHM = sqrt(8 ln 2) * sigma for a Gaussian.
            sigma = fwhm / np.sqrt(8 * np.log(2))
            x = np.linspace(-1.0, 1.0, kern_len)
            g = np.exp(-(x**2 / (2 * sigma**2)))
            kernels.append(g)
        kernel = np.sqrt(np.outer(kernels[0], kernels[1]))
        kernel = kernel / kernel.sum()
        return kernel
class Poisson2DMaskFunc(MaskFunc):
    """
    Creates a 2D sub-sampling mask of a given shape.

    For autocalibration purposes, data points near the k-space center will be fully sampled within a circle whose
    radius is ``int(shape[-3] * scale)``. The remaining points will be sampled according to a (variable density)
    Poisson disc distribution.

    For a given acceleration factor to be accurate, the scale for the fully sampled center should remain at the
    default 0.02. A predefined list is used to convert the acceleration factor to the appropriate r parameter needed
    for the variable density calculation. This list has been made to accommodate acceleration factors of 4 up to 21,
    rounding off to the nearest one available. As such, acceleration factors outside this range cannot be used.
    """

    # Lookup table: the entry at index i presumably is the acceleration obtained with r = i + 40 at the default
    # scale (assumption inferred from how __call__ uses it; confirm). TODO: consider moving this to a yaml file.
    _RFACTOR = (
        21.22, 20.32, 19.06, 18.22, 17.41, 16.56, 15.86, 15.12, 14.42, 13.88,
        13.17, 12.76, 12.21, 11.72, 11.09, 10.68, 10.35, 10.02, 9.61, 9.22,
        9.03, 8.66, 8.28, 8.1, 7.74, 7.62, 7.32, 7.04, 6.94, 6.61,
        6.5, 6.27, 6.15, 5.96, 5.83, 5.59, 5.46, 5.38, 5.15, 5.05,
        4.9, 4.86, 4.67, 4.56, 4.52, 4.41, 4.31, 4.21, 4.11, 3.99,
    )

    def __call__(
        self,
        shape: Union[Sequence[int], np.ndarray],
        seed: Optional[Union[int, Tuple[int, ...]]] = None,
        half_scan_percentage: Optional[float] = 0.0,
        scale: Optional[float] = 0.02,
    ) -> Tuple[torch.Tensor, int]:
        """
        Build a 2D variable-density Poisson disc mask.

        Parameters
        ----------
        shape: The shape of the mask to be created. The shape should have at least 3 dimensions. Samples are drawn
            over the third-last and second-last dimensions.
        seed: Seed for the random number generator. Setting the seed ensures the same mask is generated each time
            for the same shape. The random state is reset afterwards.
        half_scan_percentage: Optional; fraction of the leading rows (dimension -3) that is zeroed.
        scale: For autocalibration purposes, data points near the k-space center will be fully sampled within a
            circle of radius ``int(shape[-3] * scale)``.

        Returns
        -------
        A tuple ``(mask, acceleration)``: a float32 tensor that is 1 everywhere except along dimensions -3 and -2,
        and the acceleration that was chosen.

        Raises
        ------
        ValueError
            If ``shape`` has fewer than 3 dimensions, or the chosen acceleration is outside [3.5, 21.5].
        """
        if len(shape) < 3:
            raise ValueError("Shape should have 3 or more dimensions")

        dims = [1 for _ in shape]
        self.shape = tuple(shape[-3:-1])
        dims[-3:-1] = self.shape

        # All randomness comes from self.rng inside temp_seed, so a given seed reproduces the mask. Previously the
        # seed was ignored and the Poisson disc sampler used the global np.random state.
        with temp_seed(self.rng, seed):
            _, acceleration = self.choose_acceleration()
            if acceleration > 21.5 or acceleration < 3.5:
                raise ValueError(f"Acceleration {acceleration} is not supported for Poisson 2D masking.")

            self.acceleration = acceleration
            self.scale = scale

            # r = 40 + index of the table entry closest to the requested acceleration.
            rfactor = self._RFACTOR
            self.r = min(range(len(rfactor)), key=lambda i: abs(rfactor[i] - self.acceleration)) + 40

            pattern1 = self.poisson_disc2d()
            pattern2 = self.centered_circle()
            mask = np.logical_or(pattern1, pattern2)

        if half_scan_percentage != 0:
            mask[: int(np.round(mask.shape[0] * half_scan_percentage)), :] = 0.0

        return (torch.from_numpy(mask.reshape(dims).astype(np.float32)), acceleration)

    def poisson_disc2d(self):
        """
        Create a boolean 2D variable-density Poisson disc pattern of shape ``self.shape``, drawn with ``self.rng``.

        The minimum distance between samples grows linearly from the pattern center towards the corners (see
        ``calc_r``), so sampling is denser near the center.
        """
        # Number of in-grid candidate tries before discarding a reference point for new samples.
        k = 10
        # Bound on off-grid candidate draws per reference point. Off-grid draws do not count towards `k`, so without
        # this bound a grid smaller than the local sampling annulus would make the search loop forever.
        max_off_grid = 100 * k

        pattern_shape = (self.shape[0] - 1, self.shape[1] - 1)
        center = np.array([1.0 * pattern_shape[0] / 2, 1.0 * pattern_shape[1] / 2])
        width, height = pattern_shape

        # Cell side length.
        a = 1
        # Number of cells in the x- and y-directions of the grid.
        nx, ny = int(width / a), int(height / a)
        # A list of coordinates in the grid of cells.
        coords_list = [(ix, iy) for ix in range(nx + 1) for iy in range(ny + 1)]
        # Each key is a cell's coordinates; the value lists indices of samples whose exclusion radius covers the
        # cell, i.e. samples that might conflict with a new point placed there.
        cells = {coords: [] for coords in coords_list}
        centernorm = np.linalg.norm(center)

        def calc_r(coords):
            """Local exclusion radius: (d / centernorm * 240 + 50) / self.r, with d the distance to the center."""
            return ((np.linalg.norm(np.asarray(coords) - center) / centernorm) * 240 + 50) / self.r

        def get_cell_coords(pt):
            """Get the coordinates of the cell that pt = (x, y) falls in."""
            return int(np.floor_divide(pt[0], a)), int(np.floor_divide(pt[1], a))

        def mark_neighbours(idx):
            """Register sample ``idx`` in every cell within its exclusion radius."""
            coords = samples[idx]
            if idx in cells[get_cell_coords(coords)]:
                # This point is already marked on the grid, so we can skip it.
                return

            rx = calc_r(coords)
            xvals = np.arange(coords[0] - rx, coords[0] + rx)
            yvals = np.arange(coords[1] - rx, coords[1] + rx)
            # Keep only positions inside the grid.
            xvals = xvals[(xvals >= 0) & (xvals <= width)]
            yvals = yvals[(yvals >= 0) & (yvals <= height)]

            def within_radius(x, y):
                """Whether (x, y) lies strictly within rx of the sample."""
                return np.sqrt((coords[0] - x) ** 2 + (coords[1] - y) ** 2) < rx

            xx, yy = np.meshgrid(xvals, yvals, sparse=False)
            pts = np.vstack((xx.ravel(), yy.ravel())).T
            pts = pts[within_radius(pts[:, 0], pts[:, 1])]
            for pt in pts:
                cells[get_cell_coords(pt)].append(idx)

        def point_valid(pt):
            """Check that ``pt`` keeps its exclusion distance from all registered samples in its cell."""
            rx = calc_r(pt)
            if rx < 1:
                # Reject points close to the center (within scale * width of it); elsewhere clamp the radius to 1.
                if np.linalg.norm(pt - center) < self.scale * width:
                    return False
                rx = 1

            neighbour_idxs = cells[get_cell_coords(pt)]
            for n in neighbour_idxs:
                n_coords = samples[n]
                # Euclidean distance between the candidate point, pt, and this nearby sample.
                distance = np.sqrt((n_coords[0] - pt[0]) ** 2 + (n_coords[1] - pt[1]) ** 2)
                if distance < rx:
                    # The points are too close, so pt is not a candidate.
                    return False
            # All points tested: if we're here, pt is valid.
            return True

        def get_point(k, refpt):
            """
            Try to find a candidate point relative to refpt to emit in the sample.

            We draw up to k in-grid points from the annulus of inner radius r, outer radius 2r around the reference
            point, refpt. If none of them are suitable (or too many draws land off the grid) return False.
            Otherwise, return the pt.
            """
            i = 0
            off_grid = 0
            rx = calc_r(refpt)
            while i < k:
                rho, theta = self.rng.uniform(rx, 2 * rx), self.rng.uniform(0, 2 * np.pi)
                pt = refpt[0] + rho * np.cos(theta), refpt[1] + rho * np.sin(theta)
                if not (0 < pt[0] < width and 0 < pt[1] < height):
                    # Off the grid, try again (but not forever).
                    off_grid += 1
                    if off_grid >= max_off_grid:
                        break
                    continue
                if point_valid(pt):
                    return pt
                i += 1
            # We failed to find a suitable point in the vicinity of refpt.
            return False

        # Pick a random point to start with.
        pt = (self.rng.uniform(0, width), self.rng.uniform(0, height))
        samples = [pt]
        cursample = 0
        mark_neighbours(0)
        # Active points are the ones whose neighbourhood may still yield new samples.
        active = [0]
        # As long as there are points in the active list, keep trying to find samples.
        while active:
            # Choose a random "reference" point from the active list.
            idx = self.rng.choice(active)
            refpt = samples[idx]
            # Try to pick a new point relative to the reference point.
            pt = get_point(k, refpt)
            if pt:
                # Point pt is valid: add it to the samples list and mark it as active.
                samples.append(pt)
                cursample += 1
                active.append(cursample)
                mark_neighbours(cursample)
            else:
                # We had to give up looking for valid points near refpt, so remove it from the active list.
                active.remove(idx)

        # Round to integer pixels and drop duplicates (via a complex-number encoding of (x, y)).
        samples = np.rint(np.array(samples)).astype(int)
        samples = np.unique(samples[:, 0] + 1j * samples[:, 1])
        samples = np.column_stack((samples.real, samples.imag)).astype(int)
        poisson_pattern = np.zeros((pattern_shape[0] + 1, pattern_shape[1] + 1), dtype=bool)
        poisson_pattern[samples[:, 0], samples[:, 1]] = True
        return poisson_pattern

    def centered_circle(self):
        """
        Create a boolean image of shape ``self.shape`` that is True strictly inside the circle centred at
        ``((shape[0] - 1) // 2, (shape[1] - 1) // 2)`` with radius ``int(shape[0] * scale)``.
        """
        center_x = int((self.shape[0] - 1) / 2)
        center_y = int((self.shape[1] - 1) / 2)

        X, Y = np.indices(self.shape)
        radius = int(self.shape[0] * self.scale)
        return ((X - center_x) ** 2 + (Y - center_y) ** 2) < radius**2
def create_mask_for_mask_type(
    mask_type_str: str, center_fractions: Sequence[float], accelerations: Sequence[int]
) -> MaskFunc:
    """
    Instantiate the MaskFunc subclass registered under ``mask_type_str``.

    Parameters
    ----------
    mask_type_str: Name of the mask type: one of "random1d", "equispaced1d", "equispaced2d", "gaussian1d",
        "gaussian2d" or "poisson2d".
    center_fractions: The center fractions passed to the mask constructor.
    accelerations: The accelerations passed to the mask constructor.

    Returns
    -------
    A MaskFunc object.

    Raises
    ------
    NotImplementedError
        If ``mask_type_str`` matches none of the supported names.
    """
    # Checked in order with ==, just like a chain of if-statements.
    registry = (
        ("random1d", RandomMaskFunc),
        ("equispaced1d", Equispaced1DMaskFunc),
        ("equispaced2d", Equispaced2DMaskFunc),
        ("gaussian1d", Gaussian1DMaskFunc),
        ("gaussian2d", Gaussian2DMaskFunc),
        ("poisson2d", Poisson2DMaskFunc),
    )
    for name, mask_cls in registry:
        if mask_type_str == name:
            return mask_cls(center_fractions, accelerations)
    raise NotImplementedError(f"{mask_type_str} not supported")