# Source code for mridc.collections.common.parts.fft

# encoding: utf-8
__author__ = "Dimitrios Karkalousos"

# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI

from typing import List, Sequence, Union

import numpy as np
import torch
from omegaconf import ListConfig

__all__ = ["fft2", "ifft2"]


def fft2(
    data: torch.Tensor,
    centered: bool = True,
    normalization: Union[str, None] = "ortho",
    spatial_dims: Union[Sequence[int], None] = None,
) -> torch.Tensor:
    """
    Apply 2 dimensional Fast Fourier Transform.

    Parameters
    ----------
    data: Input data. A non-complex tensor whose last dimension has size 2 is reinterpreted as complex
        (real, imaginary) via ``torch.view_as_complex``; any other tensor is transformed as given.
    centered: Whether to center the fft. If True, ``ifftshift`` is applied over ``spatial_dims`` before the
        transform and ``fftshift`` after it.
    normalization: Forwarded as the ``norm`` argument of ``torch.fft.fft2``. Defaults to "ortho". Passing
        ``None`` or the string "none" (any letter case) forwards ``norm=None``.
    spatial_dims: Dimensions to apply the FFT over, indexed on the complex-valued tensor. Defaults to
        ``[-2, -1]``. Any sequence of ints (list, tuple, omegaconf ``ListConfig``) is accepted.

    Returns
    -------
    The FFT of the input, as a real tensor with a trailing dimension of size 2 (``torch.view_as_real``).
    """
    # Only real-valued (..., 2) tensors can be viewed as complex; a tensor that is
    # already complex must be left alone even if its last dimension happens to be 2.
    if data.shape[-1] == 2 and not torch.is_complex(data):
        data = torch.view_as_complex(data)

    # Normalise every accepted sequence type to a plain list of ints.
    spatial_dims = [-2, -1] if spatial_dims is None else list(spatial_dims)

    # ``None`` must be checked before calling ``.lower()``, otherwise it raises AttributeError.
    if normalization is None or normalization.lower() == "none":
        norm = None
    else:
        norm = normalization

    if centered:
        data = ifftshift(data, dim=spatial_dims)

    data = torch.fft.fft2(data, dim=spatial_dims, norm=norm)

    if centered:
        data = fftshift(data, dim=spatial_dims)

    return torch.view_as_real(data)
def ifft2(
    data: torch.Tensor,
    centered: bool = True,
    normalization: Union[str, None] = "ortho",
    spatial_dims: Union[Sequence[int], None] = None,
) -> torch.Tensor:
    """
    Apply 2 dimensional Inverse Fast Fourier Transform.

    Parameters
    ----------
    data: Input data. A non-complex tensor whose last dimension has size 2 is reinterpreted as complex
        (real, imaginary) via ``torch.view_as_complex``; any other tensor is transformed as given.
    centered: Whether to center the fft. If True, ``ifftshift`` is applied over ``spatial_dims`` before the
        transform and ``fftshift`` after it.
    normalization: Forwarded as the ``norm`` argument of ``torch.fft.ifft2``. Defaults to "ortho". Passing
        ``None`` or the string "none" (any letter case) forwards ``norm=None``.
    spatial_dims: Dimensions to apply the inverse FFT over, indexed on the complex-valued tensor. Defaults
        to ``[-2, -1]``. Any sequence of ints (list, tuple, omegaconf ``ListConfig``) is accepted.

    Returns
    -------
    The inverse FFT of the input, as a real tensor with a trailing dimension of size 2
    (``torch.view_as_real``).
    """
    # Only real-valued (..., 2) tensors can be viewed as complex; a tensor that is
    # already complex must be left alone even if its last dimension happens to be 2.
    if data.shape[-1] == 2 and not torch.is_complex(data):
        data = torch.view_as_complex(data)

    # Normalise every accepted sequence type to a plain list of ints.
    spatial_dims = [-2, -1] if spatial_dims is None else list(spatial_dims)

    # ``None`` must be checked before calling ``.lower()``, otherwise it raises AttributeError.
    if normalization is None or normalization.lower() == "none":
        norm = None
    else:
        norm = normalization

    if centered:
        data = ifftshift(data, dim=spatial_dims)

    data = torch.fft.ifft2(data, dim=spatial_dims, norm=norm)

    if centered:
        data = fftshift(data, dim=spatial_dims)

    return torch.view_as_real(data)
def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    """
    Similar to roll but for only one dim.

    Parameters
    ----------
    x: A PyTorch tensor.
    shift: Amount to roll. Reduced modulo the size of ``dim``, so negative and oversized values are accepted.
    dim: Which dimension to roll.

    Returns
    -------
    Rolled version of x. ``x`` itself is returned when the effective shift is zero or ``dim`` is empty.
    """
    size = x.size(dim)
    if size == 0:
        # Nothing to rotate, and the modulo below would divide by zero.
        return x

    shift %= size
    if shift == 0:
        return x

    # Move the trailing ``shift`` entries in front of the leading ``size - shift`` entries.
    left = x.narrow(dim, 0, size - shift)
    right = x.narrow(dim, size - shift, shift)
    return torch.cat((right, left), dim=dim)


def roll(x: torch.Tensor, shift: List[int], dim: Union[List[int], Sequence[int]]) -> torch.Tensor:
    """
    Similar to np.roll but applies to PyTorch Tensors.

    Parameters
    ----------
    x: A PyTorch tensor.
    shift: Amount to roll, one entry per entry of ``dim``.
    dim: Which dimensions to roll.

    Returns
    -------
    Rolled version of x.

    Raises
    ------
    ValueError: If ``shift`` and ``dim`` have different lengths.
    """
    if len(shift) != len(dim):
        raise ValueError("len(shift) must match len(dim)")

    # Dimensions are rolled one after another, pairing shift[i] with dim[i].
    for amount, axis in zip(shift, dim):
        x = roll_one_dim(x, amount, axis)

    return x


def fftshift(x: torch.Tensor, dim: Union[List[int], Sequence[int], None] = None) -> torch.Tensor:
    """
    Similar to np.fft.fftshift but applies to PyTorch Tensors

    Parameters
    ----------
    x: A PyTorch tensor.
    dim: Which dimensions to fftshift. Defaults to all dimensions of ``x``. Any sequence of ints
        (list, tuple, omegaconf ``ListConfig``) is accepted.

    Returns
    -------
    fftshifted version of x.
    """
    dim = list(range(x.dim())) if dim is None else list(dim)

    # Roll each dimension by floor(n / 2); plain ``//`` keeps the shifts as Python ints.
    shift = [x.shape[dim_num] // 2 for dim_num in dim]

    return roll(x, shift, dim)


def ifftshift(x: torch.Tensor, dim: Union[List[int], Sequence[int], None] = None) -> torch.Tensor:
    """
    Similar to np.fft.ifftshift but applies to PyTorch Tensors

    Parameters
    ----------
    x: A PyTorch tensor.
    dim: Which dimensions to ifftshift. Defaults to all dimensions of ``x``. Any sequence of ints
        (list, tuple, omegaconf ``ListConfig``) is accepted.

    Returns
    -------
    ifftshifted version of x.
    """
    dim = list(range(x.dim())) if dim is None else list(dim)

    # Roll each dimension by ceil(n / 2), which undoes fftshift for odd sizes as well.
    shift = [(x.shape[dim_num] + 1) // 2 for dim_num in dim]

    return roll(x, shift, dim)