mridc.collections.reconstruction.models.recurrentvarnet package
Submodules
mridc.collections.reconstruction.models.recurrentvarnet.conv2gru module
- class mridc.collections.reconstruction.models.recurrentvarnet.conv2gru.Conv2dGRU(in_channels: int, hidden_channels: int, out_channels: Optional[int] = None, num_layers: int = 2, gru_kernel_size=1, orthogonal_initialization: bool = True, instance_norm: bool = False, dense_connect: int = 0, replication_padding: bool = True)[source]
Bases:
Module
2D Convolutional GRU Network.
- forward(cell_input: Tensor, previous_state: Tensor) → Tuple[Tensor, Tensor] [source]
Computes Conv2dGRU forward pass given tensors cell_input and previous_state.
- Parameters
cell_input (torch.Tensor) – Reconstruction input.
previous_state (torch.Tensor) – Tensor of previous states.
- Return type
Output and new states.
- training: bool
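To make the recurrence concrete, here is a minimal NumPy sketch of one convolutional GRU update with 1×1 kernels (which reduce the convolutions to per-pixel channel mixing). The function name `conv_gru_cell` and all weights are illustrative, not mridc's API; the real Conv2dGRU stacks `num_layers` of these and optionally adds instance normalization, replication padding, and dense connections.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def conv_gru_cell(x, h, Wz, Uz, Wr, Ur, Wh, Uh):
    """One GRU update with 1x1 convolutions (illustrative, not mridc's code).

    x: cell input,      shape [in_ch, H, W]
    h: previous state,  shape [hid_ch, H, W]
    W*: input weights [hid_ch, in_ch]; U*: state weights [hid_ch, hid_ch]
    """
    mix = lambda W, t: np.einsum("oc,chw->ohw", W, t)  # 1x1 conv = channel mixing
    z = sigmoid(mix(Wz, x) + mix(Uz, h))               # update gate
    r = sigmoid(mix(Wr, x) + mix(Ur, h))               # reset gate
    h_tilde = np.tanh(mix(Wh, x) + mix(Uh, r * h))     # candidate state
    return (1.0 - z) * h + z * h_tilde                 # new hidden state

in_ch, hid_ch, H, W = 2, 8, 4, 4
x = rng.standard_normal((in_ch, H, W))
h0 = np.zeros((hid_ch, H, W))
# weights alternate [hid, in] / [hid, hid] for the z, r, h branches
Ws = [rng.standard_normal((hid_ch, c)) * 0.1 for c in (in_ch, hid_ch) * 3]
h1 = conv_gru_cell(x, h0, *Ws)
print(h1.shape)  # (8, 4, 4)
```

As in the `forward` signature above, the cell consumes an input and a previous state and returns the new state; with a zero initial state the output is bounded by the tanh candidate.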
mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet module
- class mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet.RecurrentInit(in_channels: int, out_channels: int, channels: Tuple[int, ...], dilations: Tuple[int, ...], depth: int = 2, multiscale_depth: int = 1)[source]
Bases:
Module
Recurrent State Initializer (RSI) module of the Recurrent Variational Network as presented in Yiasemis, George, et al. The RSI module learns to initialize the recurrent hidden state \(h_0\), which is the input to the first RecurrentVarNetBlock of the RecurrentVarNet.
References
Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639.
- forward(x: Tensor) → Tensor [source]
Computes initialization for recurrent unit given input x.
- Parameters
x (torch.Tensor) – Input tensor from which the initial recurrent hidden state is computed.
- Return type
Initial recurrent hidden state from input x.
- training: bool
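The shape contract of the initializer can be sketched as follows. This is a deliberately simplified stand-in: the function name `recurrent_init` and the plain 1×1 projections are assumptions for illustration, whereas the actual RSI uses dilated convolutions (`dilations`) and a multiscale depth to produce one initial hidden state per recurrent layer.

```python
import numpy as np

rng = np.random.default_rng(1)

def recurrent_init(x, proj_weights):
    """Sketch of a learned recurrent-state initializer (not mridc's RSI).

    x: input image, shape [in_ch, H, W]
    proj_weights: one [out_ch, in_ch] map per recurrent layer
    Returns a list of initial hidden states, one per RecurrentVarNetBlock layer.
    """
    return [np.einsum("oc,chw->ohw", W, x) for W in proj_weights]

in_ch, out_ch, H, W, depth = 2, 8, 6, 6, 2
x = rng.standard_normal((in_ch, H, W))
weights = [rng.standard_normal((out_ch, in_ch)) for _ in range(depth)]
h0 = recurrent_init(x, weights)
print(len(h0), h0[0].shape)  # 2 (8, 6, 6)
```

The point is that \(h_0\) is a learned function of the input rather than a zero tensor, so the first unrolled block starts from an informed state.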
- class mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet.RecurrentVarNetBlock(in_channels: int = 2, hidden_channels: int = 64, num_layers: int = 4, fft_centered: bool = True, fft_normalization: str = 'ortho', spatial_dims: Optional[Tuple[int, int]] = None, coil_dim: int = 1)[source]
Bases:
Module
Recurrent Variational Network Block \(\mathcal{H}_{\theta_{t}}\) as presented in Yiasemis, George, et al.
References
Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639.
- forward(current_kspace: Tensor, masked_kspace: Tensor, sampling_mask: Tensor, sensitivity_map: Tensor, hidden_state: Union[None, Tensor]) → Tuple[Tensor, Tensor] [source]
Computes forward pass of RecurrentVarNetBlock.
- Parameters
current_kspace (Current k-space prediction.) – torch.Tensor, shape [batch_size, n_coil, height, width, 2]
masked_kspace (Subsampled k-space.) – torch.Tensor, shape [batch_size, n_coil, height, width, 2]
sampling_mask (Sampling mask.) – torch.Tensor, shape [batch_size, 1, height, width, 1]
sensitivity_map (Coil sensitivities.) – torch.Tensor, shape [batch_size, n_coil, height, width, 2]
hidden_state (ConvGRU hidden state.) – None or torch.Tensor, shape [batch_size, n_l, height, width, hidden_channels]
- Returns
new_kspace (torch.Tensor, shape [batch_size, n_coil, height, width, 2]) – New k-space prediction.
hidden_state (list of torch.Tensor, shape [batch_size, hidden_channels, height, width, num_layers]) – Next hidden state.
- training: bool
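The k-space update performed by one unrolled block can be sketched as a data-consistency gradient step plus a learned refinement. The function name `varnet_block_update`, the fixed step size `eta`, and the zeroed refinement are assumptions for illustration: in the actual block the refinement is produced by the ConvGRU in the image domain (via the coil sensitivities and centered FFTs) and the step size is learned per block.

```python
import numpy as np

rng = np.random.default_rng(2)

def varnet_block_update(current_kspace, masked_kspace, mask, refinement, eta=0.1):
    """One unrolled k-space update (sketch, not mridc's implementation).

    Data-consistency term pulls sampled locations toward the measured data;
    `refinement` stands in for the learned ConvGRU correction.
    Tensors use the real-valued layout with a trailing dim of 2 (re/im).
    """
    dc = mask * (current_kspace - masked_kspace)   # error at sampled locations
    return current_kspace - eta * dc + refinement

n_coil, H, W = 4, 8, 8
mask = (rng.random((1, H, W, 1)) < 0.3).astype(float)   # sampling mask
y = mask * rng.standard_normal((n_coil, H, W, 2))       # measured k-space
k = rng.standard_normal((n_coil, H, W, 2))              # current prediction
refinement = np.zeros_like(k)                           # learned term, zeroed here
k_next = varnet_block_update(k, y, mask, refinement)
print(k_next.shape)  # (4, 8, 8, 2)
```

With the refinement zeroed, the update shrinks the residual at sampled k-space locations by a factor of \(1 - \eta\) and leaves unsampled locations untouched, which is exactly the role of the data-consistency term in the unrolled scheme.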