mridc.collections.reconstruction.models.sigmanet package

Submodules

mridc.collections.reconstruction.models.sigmanet.dc_layers module

class mridc.collections.reconstruction.models.sigmanet.dc_layers.ConjugateGradient(*args, **kwargs)

Bases: Function

Conjugate gradient solver for the proximal operator (prox) of the data term.
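For reference, the prox that such a solver targets can be written as below. This assumes, as the forward signature suggests, a data term ½‖Ax − y‖² with A = MFS (sampling mask, FFT, coil sensitivity maps), weight λ (lambdaa) and input image z. Setting the gradient of the prox objective to zero gives a linear system whose operator is Hermitian positive definite for λ ≥ 0, which is what makes conjugate gradient applicable:

```latex
\begin{aligned}
\operatorname{prox}_{\lambda D}(z) &= \arg\min_x \; \frac{\lambda}{2}\,\lVert A x - y \rVert_2^2 + \frac{1}{2}\,\lVert x - z \rVert_2^2, \qquad A = M F S, \\
0 &= \lambda A^H (A x - y) + (x - z)
\;\;\Longrightarrow\;\;
\left( I + \lambda A^H A \right) x = z + \lambda A^H y .
\end{aligned}
```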

static backward(ctx, grad_x)

Backward pass of the conjugate gradient solver.

Parameters
  • ctx – Context object.

  • grad_x – Gradient of the output image.

Returns

grad_z – Gradient of the input image.

static complexDot(data1, data2)

Complex dot product of two tensors.
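A minimal sketch of a complex dot product over all entries, using Python complex numbers instead of tensors with a trailing real/imaginary dimension. The convention that the second argument is conjugated, and the helper name `complex_dot`, are assumptions for illustration.

```python
def complex_dot(a, b):
    """Return sum(a_i * conj(b_i)) over all entries (hypothetical helper)."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same number of entries")
    return sum(ai * bi.conjugate() for ai, bi in zip(a, b))


a = [1 + 2j, 3 - 1j]
b = [2 - 1j, 1j]
print(complex_dot(a, b))  # (-1+2j)
print(complex_dot(a, a))  # squared norm of a: (15+0j)
```

Conjugating one argument makes complex_dot(a, a) the real, non-negative squared norm, which is what a CG solver needs for its residual checks.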

static forward(ctx, z, lambdaa, y, smaps, mask, tol, max_iter, fft_centered, fft_normalization, spatial_dims)

Forward pass of the conjugate gradient solver.

Parameters
  • ctx – Context object.

  • z – Input image.

  • lambdaa – Regularization parameter.

  • y – Subsampled k-space data.

  • smaps – Coil sensitivity maps.

  • mask – Sampling mask.

  • tol – Tolerance for the stopping criterion.

  • max_iter – Maximum number of iterations.

  • fft_centered – Boolean flag for centering the FFT.

  • fft_normalization – Normalization mode of the FFT (e.g. 'ortho').

  • spatial_dims – Spatial dimensions.

Returns

z – Output image.

static solve(x0, M, tol, max_iter)

Solve the linear system Mx=b using conjugate gradient.
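A minimal pure-Python sketch of conjugate gradient for a Hermitian positive-definite operator. It assumes, as the signature suggests, that M is a callable applying the operator and that x0 plays the role of the right-hand side b, with the iterate starting from zero; the names `cg_solve` and `vdot`, the list-based vectors, and the stopping rule (relative squared residual at most tol) are assumptions, not the library's code.

```python
def vdot(a, b):
    """Inner product <a, b> = sum(conj(a_i) * b_i)."""
    return sum(ai.conjugate() * bi for ai, bi in zip(a, b))


def cg_solve(b, M, tol=1e-6, max_iter=10):
    """Solve M x = b for a Hermitian positive-definite operator M given as a callable."""
    x = [0j] * len(b)        # start from x = 0, so the initial residual is b
    r = list(b)
    p = list(r)
    rr = vdot(r, r).real
    bb = rr
    if bb == 0.0:
        return x             # b = 0 has the trivial solution x = 0
    for _ in range(max_iter):
        if rr / bb <= tol:   # relative squared residual below tolerance
            break
        q = M(p)
        alpha = rr / vdot(p, q).real      # p^H M p is real for Hermitian M
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * qi for ri, qi in zip(r, q)]
        rr_new = vdot(r, r).real
        beta = rr_new / rr
        p = [ri + beta * pi for ri, pi in zip(r, p)]
        rr = rr_new
    return x


A = [[4, 1 - 1j], [1 + 1j, 3]]  # small Hermitian positive-definite example


def M(v):
    return [sum(A[i][j] * v[j] for j in range(2)) for i in range(2)]


b = [1, 2j]
x = cg_solve(b, M, tol=1e-12, max_iter=10)
residual = [mi - bi for mi, bi in zip(M(x), b)]
print(max(abs(ri) for ri in residual))  # close to 0
```

In exact arithmetic CG converges in at most n iterations for an n×n system, so two iterations suffice here; in the prox setting, M would be p ↦ p + λA^H A p.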

class mridc.collections.reconstruction.models.sigmanet.dc_layers.DCLayer(lambda_init=0.0, learnable=True, fft_centered: bool = True, fft_normalization: str = 'ortho', spatial_dims: Optional[Tuple[int, int]] = None)

Bases: Module

Data-consistency layer from DC-CNN, intended mainly for single-coil data.

forward(x, y, mask)

Forward pass of the data-consistency block.

Parameters
  • x – Input image.

  • y – Subsampled k-space data.

  • mask – Sampling mask.

Returns

Output image.
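A minimal sketch of the k-space blending step behind DC-CNN-style data consistency, on a flattened single-coil k-space given as Python lists. The convex combination λ·k_pred + (1 − λ)·y at sampled locations is an assumption about how the weight is applied (DC-CNN is often written as (k + λy)/(1 + λ), which is the same kind of convex combination under a different parametrization). `data_consistency_kspace` is a hypothetical name, and the FFT before and inverse FFT after this step are omitted.

```python
def data_consistency_kspace(k_pred, y, mask, lam):
    """Blend predicted and measured k-space where sampled (hypothetical helper).

    mask entries are 1 at sampled locations and 0 elsewhere.
    """
    return [
        # unsampled (m = 0): keep the prediction; sampled (m = 1): blend with y
        (1 - m) * kp + m * (lam * kp + (1 - lam) * yi)
        for kp, yi, m in zip(k_pred, y, mask)
    ]


k_pred = [1 + 1j, 2, 3j, 4]  # network prediction in k-space
y = [0, 5, 0, 8]             # measured samples (zeros where not acquired)
mask = [0, 1, 0, 1]
print(data_consistency_kspace(k_pred, y, mask, 0.5))
```

With lam = 0 the measured samples replace the prediction exactly (hard data consistency); larger values trust the network more.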

set_learnable(flag)

Set the learnable flag of the parameters.

Parameters

flag – If True, the parameters of the layer are learnable.

training: bool

class mridc.collections.reconstruction.models.sigmanet.dc_layers.DataGDLayer(lambda_init, learnable=True, fft_centered: bool = True, fft_normalization: str = 'ortho', spatial_dims: Optional[Tuple[int, int]] = None)

Bases: Module

Data layer that computes the gradient of the L2 data term.

forward(x, y, smaps, mask)

Parameters
  • x – Input image.

  • y – Subsampled k-space data.

  • smaps – Coil sensitivity maps.

  • mask – Sampling mask.

Returns

data_loss – Data term loss.
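The gradient of the L2 data term D(x) = ½‖Ax − y‖² is A^H(Ax − y). The sketch below computes it with a small dense matrix standing in for the forward operator (mask, FFT and coil sensitivities) and takes one gradient step x − λ∇D(x). Using the gradient as a descent step weighted by a λ like lambda_init is an assumption, and all helper names are hypothetical.

```python
def matvec(A, v):
    """Compute A v for a dense matrix A given as a list of rows."""
    return [sum(a_ij * v_j for a_ij, v_j in zip(row, v)) for row in A]


def adjoint_matvec(A, v):
    """Compute A^H v, i.e. the conjugate transpose of A applied to v."""
    n_cols = len(A[0])
    return [sum(A[i][j].conjugate() * v[i] for i in range(len(A))) for j in range(n_cols)]


def data_term_gradient(x, y, A):
    """Gradient of 0.5 * ||A x - y||^2 with respect to x: A^H (A x - y)."""
    residual = [ax_i - y_i for ax_i, y_i in zip(matvec(A, x), y)]
    return adjoint_matvec(A, residual)


def data_gradient_step(x, y, A, lam):
    """One gradient-descent step on the data term with step size lam (hypothetical)."""
    grad = data_term_gradient(x, y, A)
    return [x_i - lam * g_i for x_i, g_i in zip(x, grad)]


A = [[1j, 0], [0, 2]]
print(data_term_gradient([1, 1], [0, 0], A))
print(data_gradient_step([1, 1], [0, 0], A, 0.1))
```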

training: bool

class mridc.collections.reconstruction.models.sigmanet.dc_layers.DataIDLayer(*args, **kwargs)

Bases: Module

Placeholder for the data layer.

training: bool

class mridc.collections.reconstruction.models.sigmanet.dc_layers.DataProxCGLayer(lambda_init, tol=1e-06, iter=10, learnable=True, fft_centered: bool = True, fft_normalization: str = 'ortho', spatial_dims: Optional[Tuple[int, int]] = None)

Bases: Module

Computes the proximal operator of the data term by conjugate gradient, as proposed by Aggarwal et al.

forward(x, f, smaps, mask)

Parameters
  • x – Input image.

  • f – Subsampled k-space data.

  • smaps – Coil sensitivity maps.

  • mask – Sampling mask.

Returns

data_loss – Data term loss.

set_learnable(flag)
training: bool

class mridc.collections.reconstruction.models.sigmanet.dc_layers.DataVSLayer(alpha_init, beta_init, learnable=True, fft_centered: bool = True, fft_normalization: str = 'ortho', spatial_dims: Optional[Tuple[int, int]] = None)

Bases: Module

Data layer using a variable-splitting formulation.

forward(x, y, smaps, mask)

Forward pass of the data-consistency block.

Parameters
  • x – Input image.

  • y – Subsampled k-space data.

  • smaps – Coil sensitivity maps.

  • mask – Sampling mask.

Returns

Output image.
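A rough single-coil sketch of a variable-splitting data step. It assumes that alpha weights a k-space data-consistency blend and that beta then blends the result with the input image; these roles are inferred from the parameter names alpha_init and beta_init and are not confirmed by the documentation. An orthonormal DFT on a 1D list stands in for the FFT, coil sensitivity maps are omitted, and all function names are hypothetical.

```python
import cmath
import math


def dft(x):
    """Orthonormal 1D DFT, standing in for the layer's FFT."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n)) / math.sqrt(n)
            for k in range(n)]


def idft(k):
    """Inverse of dft (also orthonormal)."""
    n = len(k)
    return [sum(k[f] * cmath.exp(2j * math.pi * f * t / n) for f in range(n)) / math.sqrt(n)
            for t in range(n)]


def variable_splitting_step(x, y, mask, alpha, beta):
    """Hypothetical single-coil variable-splitting data step."""
    k = dft(x)
    # k-space data consistency: blend prediction and measurement where sampled
    k_dc = [(1 - m) * kx + m * (alpha * kx + (1 - alpha) * yi)
            for kx, yi, m in zip(k, y, mask)]
    x_dc = idft(k_dc)
    # image-domain blend between the input image and the data-consistent image
    return [beta * xi + (1 - beta) * xd for xi, xd in zip(x, x_dc)]


x = [1, 2, 3, 4]
y = dft([0, 1, 0, 1])  # fully sampled "measurement" of a different image
out = variable_splitting_step(x, y, [1, 1, 1, 1], alpha=0.0, beta=0.0)
print([round(v.real, 6) for v in out])
```

With a full mask and alpha = beta = 0 the output reduces to the measured image; with an empty mask the DFT round trip leaves x unchanged.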

set_learnable(flag)

Set the learnable flag of the parameters.

Parameters

flag – If True, the parameters of the layer are learnable.

training: bool

mridc.collections.reconstruction.models.sigmanet.sensitivity_net module

class mridc.collections.reconstruction.models.sigmanet.sensitivity_net.ComplexInstanceNorm

Bases: Module

Complex instance normalization, motivated by ‘Deep Complex Networks’ (https://arxiv.org/pdf/1705.09792.pdf).

complex_instance_norm(x, eps=1e-05)

Apply complex instance normalization to images x of size [nBatch, nSmaps, nFE, nPE, 2].

complex_pseudocovariance(data)

Compute the pseudo-covariance of data, which must already be mean-free. Operates on data of size [nBatch, nSmaps, nFE, nPE, 2].

forward(input)

Operates on input of size [nBatch, nSmaps, nFE, nPE, 2].

normalize(x)

Normalize the input x.

set_normalization(input)

Set the normalization parameters for a given input.

training: bool
unnormalize(x)

Unnormalize the input x.
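A minimal pure-Python sketch of the whitening described in ‘Deep Complex Networks’: subtract the mean, form the 2x2 covariance of the real and imaginary parts, and multiply by its inverse square root, using the closed form for 2x2 symmetric positive-definite matrices. It works on a flat list of complex samples rather than a [nBatch, nSmaps, nFE, nPE, 2] tensor; the helper names, and the idea that the statistics are kept to undo the transform (as set_normalization/unnormalize suggest), are assumptions.

```python
import math


def complex_whiten(values, eps=1e-5):
    """Whiten complex samples; return (whitened, stats) where stats undo the transform."""
    n = len(values)
    mean = sum(values) / n
    centered = [v - mean for v in values]
    # 2x2 covariance of (real, imag); eps keeps it positive definite
    vrr = sum(c.real * c.real for c in centered) / n + eps
    vii = sum(c.imag * c.imag for c in centered) / n + eps
    vri = sum(c.real * c.imag for c in centered) / n
    # closed-form square root of a 2x2 SPD matrix V: (V + s I) / t
    s = math.sqrt(vrr * vii - vri * vri)  # sqrt(det V)
    t = math.sqrt(vrr + vii + 2 * s)      # sqrt(trace V + 2 s)
    # inverse square root: [[vii + s, -vri], [-vri, vrr + s]] / (s t)
    wrr, wii, wri = (vii + s) / (s * t), (vrr + s) / (s * t), -vri / (s * t)
    whitened = [complex(wrr * c.real + wri * c.imag, wri * c.real + wii * c.imag)
                for c in centered]
    stats = (mean, (vrr + s) / t, vri / t, (vii + s) / t)  # mean and V^(1/2) entries
    return whitened, stats


def complex_unwhiten(whitened, stats):
    """Undo complex_whiten: multiply by V^(1/2) and add the mean back."""
    mean, srr, sri, sii = stats
    return [complex(srr * w.real + sri * w.imag, sri * w.real + sii * w.imag) + mean
            for w in whitened]


vals = [1 + 2j, -1 + 0.5j, 3 - 1j, 0.5 + 0.5j, 2 + 3j]
w, stats = complex_whiten(vals)
restored = complex_unwhiten(w, stats)
```

Whitening with the full 2x2 covariance, rather than scaling real and imaginary parts separately, also removes correlation between them.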

class mridc.collections.reconstruction.models.sigmanet.sensitivity_net.ComplexNormWrapper(model)

Bases: Module

Wrapper for complex normalization.

forward(input)

Apply the wrapped model to the input with complex normalization.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool

class mridc.collections.reconstruction.models.sigmanet.sensitivity_net.SensitivityNetwork(num_iter, model, datalayer, shared_params=True, save_space=False, reset_cache=False)

Bases: Module

Sensitivity network whose data term is based on forward and adjoint operators that include the coil sensitivity maps.

copy_params(src_i, trg_j)

Copy the parameters of cascade src_i to cascade trg_j.

forward(x, y, smaps, mask)

Parameters
  • x – Input data.

  • y – Subsampled k-space data.

  • smaps – Coil sensitivity maps.

  • mask – Sampling mask.

Returns

Output data.
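A minimal sketch of an unrolled cascade that alternates a regularizer step and a data layer for num_iter iterations. The update x_half = x − R(x) followed by a data layer is an assumption about how model and datalayer are combined; scalar toy functions stand in for the CNN and the data layer, smaps and mask are folded into the data layer, and `unrolled_reconstruction` is a hypothetical name.

```python
def unrolled_reconstruction(x, y, regularizers, data_layers, num_iter):
    """Alternate a regularizer step and a data layer for num_iter cascades (hypothetical)."""
    for i in range(num_iter):
        # with shared parameters, a single regularizer/data layer is reused every cascade
        reg = regularizers[i % len(regularizers)]
        dc = data_layers[i % len(data_layers)]
        x_half = x - reg(x)  # step on the learned regularizer
        x = dc(x_half, y)    # data layer; smaps and mask are folded into dc here
    return x


# toy scalar stand-ins for the CNN and the data layer
out = unrolled_reconstruction(
    0.0, 1.0,
    regularizers=[lambda v: 0.1 * v],
    data_layers=[lambda v, meas: 0.5 * v + 0.5 * meas],
    num_iter=2,
)
print(out)  # approximately 0.725
```

Passing lists with one entry per cascade instead of a single shared entry mimics shared_params=False.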

forward_save_space(x, y, smaps, mask)

Parameters
  • x – Input data.

  • y – Subsampled k-space data.

  • smaps – Coil sensitivity maps.

  • mask – Sampling mask.

Returns

Output data.

freeze(i)

Freeze the parameters of cascade i.

freeze_all()

Freeze the parameters of all cascades.

stage_training_init()

Set the stage-training flag to True.

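stage_training_transition_i(copy=False)

Advance stage-wise training from the current cascade to the next; if copy is True, the parameters of the current cascade are copied to the next one.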

training: bool
unfreeze(i)

Unfreeze the parameters of cascade i.

unfreeze_all()

Unfreeze the parameters of all cascades.
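The freeze/unfreeze and stage-training methods suggest a stage-wise training schedule over the cascades. The sketch below only tracks a hypothetical is_trainable list of booleans; the transition rule (freeze the current cascade, unfreeze the next, unfreeze all once the last cascade is reached) is an assumption, and a real implementation would toggle requires_grad on each cascade's parameters.

```python
class StageSchedule:
    """Hypothetical tracker of which cascades are trainable during stage-wise training."""

    def __init__(self, num_cascades):
        self.is_trainable = [True] * num_cascades

    def freeze(self, i):
        self.is_trainable[i] = False

    def unfreeze(self, i):
        self.is_trainable[i] = True

    def freeze_all(self):
        self.is_trainable = [False] * len(self.is_trainable)

    def unfreeze_all(self):
        self.is_trainable = [True] * len(self.is_trainable)

    def stage_init(self):
        # start by training only the first cascade
        self.freeze_all()
        self.unfreeze(0)

    def transition(self):
        if all(self.is_trainable):
            return  # every cascade is already trainable
        if True not in self.is_trainable:
            raise RuntimeError("call stage_init() before transition()")
        i = self.is_trainable.index(True)
        if i == len(self.is_trainable) - 1:
            self.unfreeze_all()  # last stage reached: train all cascades jointly
        else:
            self.freeze(i)
            self.unfreeze(i + 1)


schedule = StageSchedule(3)
schedule.stage_init()
for _ in range(3):
    schedule.transition()
    print(schedule.is_trainable)
```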

mridc.collections.reconstruction.models.sigmanet.sensitivity_net.matrix_invert(xx, xy, yx, yy)

Invert a 2x2 matrix.
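For reference, the closed-form inverse of a 2x2 matrix [[xx, xy], [yx, yy]] is (1/det)·[[yy, −xy], [−yx, xx]] with det = xx·yy − xy·yx. The scalar sketch below assumes the arguments are the entries in that row-major order and returns the inverse's entries in the same order; `invert_2x2` is a hypothetical name, and the library function presumably applies the same formula elementwise to tensors.

```python
def invert_2x2(xx, xy, yx, yy):
    """Invert [[xx, xy], [yx, yy]]; return the inverse's entries in the same order."""
    det = xx * yy - xy * yx
    if det == 0:
        raise ZeroDivisionError("matrix is singular")
    return yy / det, -xy / det, -yx / det, xx / det


inv = invert_2x2(4, 7, 2, 6)
print(inv)  # (0.6, -0.7, -0.2, 0.4)
```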

Module contents