mridc.utils package

Subpackages

Submodules

mridc.utils.app_state module

class mridc.utils.app_state.AppState(*args, **kwargs)[source]

Bases: object

A singleton class that holds the state of the application.
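
AppState is a singleton, so constructing it anywhere in the code base yields the same shared object. A minimal usage sketch (the printed properties are documented below and may be None until something sets them):

    from mridc.utils.app_state import AppState

    state_a = AppState()
    state_b = AppState()
    assert state_a is state_b  # singleton: every construction returns the same instance

    # Read global state populated by exp_manager and the distributed setup.
    print(state_a.device_id, state_a.global_rank, state_a.world_size)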

property checkpoint_callback_params

Returns the checkpoint_callback_params set by exp_manager.

property checkpoint_name

Returns the name set by exp_manager.

property create_checkpoint_callback

Returns the create_checkpoint_callback set by exp_manager.

property data_parallel_group

Property returns the data parallel group.

property data_parallel_rank

Property returns the data parallel rank.

property data_parallel_size

Property returns the number of GPUs in each data parallel group.

property device_id

Property returns the device_id.

property exp_dir

Returns the exp_dir set by exp_manager.

get_model_metadata_from_guid(guid) ModelMetadataRegistry[source]

Returns the global model idx and restoration path.

property global_rank

Property returns the global rank.

property is_model_being_restored: bool

Returns whether a model is being restored.

property local_rank

Property returns the local rank.

property log_dir

Returns the log_dir set by exp_manager.

property model_parallel_size

Property returns the number of GPUs in each model parallel group.

property model_restore_path

Returns the model_restore_path set by exp_manager.

property mridc_file_folder: str

Returns the mridc_file_folder set by exp_manager.

property name

Returns the name set by exp_manager.

property pipeline_model_parallel_group

Property returns the pipeline model parallel group.

property pipeline_model_parallel_rank

Property returns the pipeline model parallel rank.

property pipeline_model_parallel_size

Property returns the number of GPUs in each pipeline model parallel group.

property pipeline_model_parallel_split_rank

Property returns the pipeline model parallel split rank.

property random_seed

Property returns the random seed.

register_model_guid(guid: str, restoration_path: Optional[str] = None)[source]

Maps a guid to its restore path (None or last absolute path).

reset_model_guid_registry()[source]

Resets the model guid registry.

property tensor_model_parallel_group

Property returns the tensor model parallel group.

property tensor_model_parallel_rank

Property returns the tensor model parallel rank.

property tensor_model_parallel_size

Property returns the number of GPUs in each tensor model parallel group.

property version

Returns the version set by exp_manager.

property world_size

Property returns the total number of GPUs.

class mridc.utils.app_state.ModelMetadataRegistry(guid: str, gidx: int, restoration_path: Optional[str] = None)[source]

Bases: object

A registry for model metadata.

gidx: int
guid: str
restoration_path: Optional[str] = None

mridc.utils.arguments module

mridc.utils.arguments.add_optimizer_args(parent_parser: ArgumentParser, optimizer: str = 'adam', default_lr: Optional[float] = None, default_opt_args: Optional[Union[Dict[str, Any], List[str]]] = None) ArgumentParser[source]

Extends existing argparse with default optimizer args.

Example of adding optimizer args to the command line:

    python train_script.py … --optimizer "novograd" --lr 0.01 --opt_args betas=0.95,0.5 weight_decay=0.001

Parameters
  • parent_parser (Custom CLI parser that will be extended.) – ArgumentParser

  • optimizer (Default optimizer required.) – str, default “adam”

  • default_lr (Default learning rate.) – float, default None

  • default_opt_args (Default optimizer arguments.) – Optional[Union[Dict[str, Any], List[str]]], default None

Returns

ArgumentParser

Return type

Parser extended by optimizer arguments.
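
A minimal sketch of extending a parser with these defaults and parsing the command-line example above; the namespace attribute names (optimizer, lr, opt_args) are assumed from the flag names shown there:

    from argparse import ArgumentParser

    from mridc.utils.arguments import add_optimizer_args

    parser = ArgumentParser(description="training script")
    parser = add_optimizer_args(parser, optimizer="novograd", default_lr=0.01)

    # Mirrors the command-line example above.
    args = parser.parse_args(
        ["--optimizer", "novograd", "--lr", "0.01", "--opt_args", "betas=0.95,0.5", "weight_decay=0.001"]
    )
    print(args.optimizer, args.lr, args.opt_args)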

mridc.utils.arguments.add_recon_args(parent_parser: ArgumentParser) ArgumentParser[source]

Extends existing argparse with default reconstruction args.

Parameters

parent_parser (Custom CLI parser that will be extended.) – ArgumentParser

Returns

ArgumentParser

Return type

Parser extended by Reconstruction arguments.

mridc.utils.arguments.add_scheduler_args(parent_parser: ArgumentParser) ArgumentParser[source]

Extends existing argparse with default scheduler args.

Parameters

parent_parser (Custom CLI parser that will be extended.) – ArgumentParser

Returns

ArgumentParser

Return type

Parser extended by scheduler arguments.

mridc.utils.cloud module

mridc.utils.cloud.maybe_download_from_cloud(url, filename, subfolder=None, cache_dir=None, refresh_cache=False) str[source]

Download a file from a URL if it does not exist in the cache.

Parameters
  • url (URL to download the file from.) – str

  • filename (What to download. The request will be issued to url/filename) – str

  • subfolder (Subfolder within cache_dir. The file will be stored in cache_dir/subfolder. Subfolder can be empty.) – str

  • cache_dir (A cache directory to download into. If not present, this function will attempt to create it.) – str. If None (default), it will be $HOME/.cache/torch/mridc

  • refresh_cache (If True and a cached file is present, delete it and re-fetch.) – bool

Return type

If successful - absolute local path to the downloaded file else empty string.
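
A hedged usage sketch; the URL and filename below are placeholders, and per the return contract an empty string signals failure rather than an exception:

    from mridc.utils.cloud import maybe_download_from_cloud

    # Hypothetical URL and filename, purely for illustration.
    path = maybe_download_from_cloud(
        url="https://example.com/models/",
        filename="model.tar.gz",
        subfolder="checkpoints",
        refresh_cache=False,
    )
    print(path if path else "download failed (empty string returned)")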

mridc.utils.config_utils module

mridc.utils.config_utils.assert_dataclass_signature_match(cls: class_type, datacls: dataclass, ignore_args: Optional[List[str]] = None, remap_args: Optional[Dict[str, str]] = None)[source]

Analyses the signature of a provided class and its respective dataclass, asserting that the dataclass signature matches the class __init__ signature.

Note: This is not a value-based check. This function only checks whether all argument names exist on both the class and the dataclass, and logs mismatches.

Parameters
  • cls (Any class type – but not an instance of a class. Pass type(x), where x is an instance, if the class type is not easily available.) –

  • datacls (A corresponding dataclass for the above class.) –

  • ignore_args ((Optional) A list of string argument names which are forcibly ignored, even if mismatched in the signature. Useful when a dataclass is a superset of the arguments of a class.) –

  • remap_args ((Optional) A dictionary mapping an argument name that exists (in either the class or its dataclass) to another name. Useful when argument names are mismatched between a class and its dataclass due to indirect instantiation via a helper method.) –

Returns

  1. A bool value which is True if the signatures matched exactly / after ignoring values.

    False otherwise.

  2. A set of arguments names that exist in the class, but do not exist in the dataclass.

    If exact signature match occurs, this will be None instead.

  3. A set of argument names that exist in the data class, but do not exist in the class itself.

    If exact signature match occurs, this will be None instead.

Return type

A tuple containing information about the analysis
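
A small sketch of the documented return contract; the class and dataclass here are defined only for illustration:

    from dataclasses import dataclass

    from mridc.utils.config_utils import assert_dataclass_signature_match


    class Block:
        def __init__(self, channels: int = 32, dropout: float = 0.0):
            self.channels = channels
            self.dropout = dropout


    @dataclass
    class BlockConfig:
        channels: int = 32
        dropout: float = 0.0


    matched, cls_only, datacls_only = assert_dataclass_signature_match(Block, BlockConfig)
    # matched is True and both sets are None when the signatures agree exactly.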

mridc.utils.config_utils.update_model_config(model_cls: MRIDCConfig, update_cfg: DictConfig, drop_missing_subconfigs: bool = True)[source]
Helper function that updates the default values of a ModelPT config class with the values in a DictConfig that mirrors the structure of the config class. Assumes the update_cfg is a DictConfig (either generated manually, via hydra, or instantiated via yaml/model.cfg). This update_cfg is then used to override the default values preset inside the ModelPT config class. If drop_missing_subconfigs is set, then certain sub-configs of the ModelPT config class will be removed, if they are not found in the mirrored update_cfg. The following sub-configs are subject to potential removal:
  • train_ds

  • validation_ds

  • test_ds

  • optim + nested sched

Parameters
  • model_cls (A subclass of MRIDCConfig, detailing in entirety all the parameters that constitute the MRIDC model.) –

  • update_cfg (A DictConfig that mirrors the structure of the MRIDCConfig data class. Used to update the default values of the config class.) –

  • drop_missing_subconfigs (Bool which determines whether to drop certain sub-configs from the MRIDCConfig class, if the corresponding sub-config is missing from update_cfg.) –

Return type

A DictConfig with updated values that can be used to instantiate the MRIDC Model along with supporting infrastructure.

mridc.utils.debug_hook module

mridc.utils.debug_hook.get_backward_hook(name, trainer, rank, logger, dump_to_file=False)[source]

A backward hook to dump all the module input and output grad norms. The hook will be called every time the gradients with respect to module inputs are computed. Only float type input/output grad tensor norms are computed.

For more details about the backward hook, check: https://pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_full_backward_hook.html

Parameters
  • name (str) – tensor name

  • trainer (PTL trainer) – PTL trainer

  • rank (int) – worker rank

  • logger (PTL log function) – PTL log function

  • dump_to_file (bool, optional) – whether to dump the CSV file to disk, by default False

Return type

backward_hook

mridc.utils.debug_hook.get_forward_hook(name, trainer, rank, logger, dump_to_file=False)[source]

A forward hook to dump all the module input and output norms. It is called every time after forward() has computed an output. Only float type input/output tensor norms are computed.

For more details about the forward hook, check: https://pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_hook.html

Parameters
  • name (str) – tensor name

  • trainer (PTL trainer) – PTL trainer

  • rank (int) – worker rank

  • logger (PTL log function) – PTL log function

  • dump_to_file (bool, optional) – whether to dump the CSV file to disk, by default False

Return type

forward_hook

mridc.utils.debug_hook.get_tensor_hook(module, name, trainer, rank, logger, dump_to_file=False)[source]

A tensor hook to dump all of the tensor weight norms and grad norms at the end of each of the backward steps.

For more details about the tensor hook, check: https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html

Parameters
  • module (torch.nn.Module) – module to register the hook

  • name (str) – tensor name

  • trainer (PTL trainer) – PTL trainer

  • rank (int) – worker rank

  • logger (PTL log function) – PTL log function

  • dump_to_file (bool, optional) – whether to dump the CSV file to disk, by default False

Return type

tensor_hook

mridc.utils.debug_hook.register_debug_hooks(module, trainer, logger, dump_to_file=False)[source]

Registers debug hooks that can: 1) track the module forward-step input/output norms, 2) track the module backward-step input/output grad norms, and 3) track the parameter weight norms and grad norms.
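
A hedged sketch of attaching the hooks to a toy LightningModule; the logger argument is assumed to be a PTL-style log function (module.log here), and the module and trainer are stand-ins:

    import pytorch_lightning as pl
    import torch

    from mridc.utils.debug_hook import register_debug_hooks


    class ToyModule(pl.LightningModule):
        """Minimal stand-in module used only to illustrate hook registration."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 2)

        def forward(self, x):
            return self.layer(x)


    module = ToyModule()
    trainer = pl.Trainer(max_epochs=1)

    # Every forward/backward pass now logs input/output norms and parameter
    # weight/grad norms; dump_to_file=True would also write CSV files.
    register_debug_hooks(module, trainer, module.log, dump_to_file=False)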

mridc.utils.distributed module

mridc.utils.distributed.initialize_distributed(args, backend='nccl')[source]

Initialize distributed training.

Parameters
  • args (The arguments object.) –

  • backend (The backend to use.) – default: “nccl”

Returns

  • local_rank (The local rank of the process.)

  • rank (The rank of the process.)

  • world_size (The number of processes.)
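
A hedged sketch of the call pattern; it assumes the args object exposes a local_rank attribute (typically provided by the launch script) and that the usual distributed environment variables are set by the launcher:

    from argparse import Namespace

    from mridc.utils.distributed import initialize_distributed

    args = Namespace(local_rank=0)  # stand-in for a parsed argument object
    local_rank, rank, world_size = initialize_distributed(args, backend="nccl")
    print(f"local_rank={local_rank}, rank={rank}, world_size={world_size}")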

mridc.utils.env_var_parsing module

exception mridc.utils.env_var_parsing.CoercionError(key, value, func)[source]

Bases: Exception

Custom error raised when a value cannot be coerced.

exception mridc.utils.env_var_parsing.RequiredSettingMissingError(key)[source]

Bases: Exception

Custom error raised when a required env var is missing.

mridc.utils.env_var_parsing.get_env(key, *default, **kwargs)[source]

Return env var. This is the parent function of all other get_foo functions, and is responsible for unpacking args/kwargs into the values that _get_env expects (it is the root function that actually interacts with environ).

Parameters
  • key (string, the env var name to look up.) –

  • default ((optional) the value to use if the env var does not exist. If this value is not supplied, then the env var is considered to be required, and a RequiredSettingMissingError error will be raised if it does not exist.) –

  • kwargs – coerce: a func that may be supplied to coerce the value into something else. This is used by the default get_foo functions to cast strings to builtin types, but could be a function that returns a custom class.

Return type

The env var, coerced if required, and a default if supplied.

mridc.utils.env_var_parsing.get_envbool(key, *default)[source]

Return env var cast as boolean.

mridc.utils.env_var_parsing.get_envdate(key, *default)[source]

Return env var as a date.

mridc.utils.env_var_parsing.get_envdatetime(key, *default)[source]

Return env var as a datetime.

mridc.utils.env_var_parsing.get_envdecimal(key, *default)[source]

Return env var cast as Decimal.

mridc.utils.env_var_parsing.get_envdict(key, *default)[source]

Return env var as a dict.

mridc.utils.env_var_parsing.get_envfloat(key, *default)[source]

Return env var cast as float.

mridc.utils.env_var_parsing.get_envint(key, *default)[source]

Return env var cast as integer.

mridc.utils.env_var_parsing.get_envlist(key, *default, **kwargs)[source]

Return env var as a list.
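
A small sketch of the typed getters; the environment variable names are illustrative only:

    import os

    from mridc.utils.env_var_parsing import get_env, get_envbool, get_envint

    os.environ["MY_FLAG"] = "true"
    os.environ["MY_WORKERS"] = "4"

    debug = get_envbool("MY_FLAG", False)            # -> True
    workers = get_envint("MY_WORKERS", 1)            # -> 4
    fallback = get_env("MY_MISSING_VAR", "default")  # default avoids RequiredSettingMissingError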

mridc.utils.exceptions module

class mridc.utils.exceptions.CheckInstall(*args, **kwargs)[source]

Bases: object

Class to check if a package is installed.

exception mridc.utils.exceptions.LightningNotInstalledException(obj)[source]

Bases: MRIDCBaseException

Exception raised when PyTorch Lightning is not installed.

exception mridc.utils.exceptions.MRIDCBaseException[source]

Bases: Exception

MRIDC base exception. All exceptions created in MRIDC should inherit from this class.

mridc.utils.exp_manager module

class mridc.utils.exp_manager.CallbackParams(filepath: Optional[str] = None, dirpath: Optional[str] = None, filename: Optional[str] = None, monitor: Optional[str] = 'val_loss', verbose: Optional[bool] = True, save_last: Optional[bool] = True, save_top_k: Optional[int] = 3, save_weights_only: Optional[bool] = False, mode: Optional[str] = 'min', every_n_epochs: Optional[int] = 1, prefix: Optional[str] = None, postfix: str = '.mridc', save_best_model: bool = False, always_save_mridc: bool = False, save_mridc_on_train_end: Optional[bool] = True, model_parallel_size: Optional[int] = None)[source]

Bases: object

Parameters for a callback

always_save_mridc: bool = False
dirpath: Optional[str] = None
every_n_epochs: Optional[int] = 1
filename: Optional[str] = None
filepath: Optional[str] = None
mode: Optional[str] = 'min'
model_parallel_size: Optional[int] = None
monitor: Optional[str] = 'val_loss'
postfix: str = '.mridc'
prefix: Optional[str] = None
save_best_model: bool = False
save_last: Optional[bool] = True
save_mridc_on_train_end: Optional[bool] = True
save_top_k: Optional[int] = 3
save_weights_only: Optional[bool] = False
verbose: Optional[bool] = True
exception mridc.utils.exp_manager.CheckpointMisconfigurationError[source]

Bases: MRIDCBaseException

Raised when a mismatch between trainer.callbacks and exp_manager occurs

class mridc.utils.exp_manager.ExpManagerConfig(explicit_log_dir: Optional[str] = None, exp_dir: Optional[str] = None, name: Optional[str] = None, version: Optional[str] = None, use_datetime_version: Optional[bool] = True, resume_if_exists: Optional[bool] = False, resume_past_end: Optional[bool] = False, resume_ignore_no_checkpoint: Optional[bool] = False, create_tensorboard_logger: Optional[bool] = True, summary_writer_kwargs: Optional[Dict[Any, Any]] = None, create_wandb_logger: Optional[bool] = False, wandb_logger_kwargs: Optional[Dict[Any, Any]] = None, create_checkpoint_callback: Optional[bool] = True, checkpoint_callback_params: Optional[CallbackParams] = CallbackParams(filepath=None, dirpath=None, filename=None, monitor='val_loss', verbose=True, save_last=True, save_top_k=3, save_weights_only=False, mode='min', every_n_epochs=1, prefix=None, postfix='.mridc', save_best_model=False, always_save_mridc=False, save_mridc_on_train_end=True, model_parallel_size=None), files_to_copy: Optional[List[str]] = None, log_step_timing: Optional[bool] = True, step_timing_kwargs: Optional[StepTimingParams] = StepTimingParams(reduction='mean', sync_cuda=False, buffer_size=1), log_local_rank_0_only: Optional[bool] = False, log_global_rank_0_only: Optional[bool] = False, model_parallel_size: Optional[int] = None)[source]

Bases: object

Configuration for the experiment manager.

checkpoint_callback_params: Optional[CallbackParams] = CallbackParams(filepath=None, dirpath=None, filename=None, monitor='val_loss', verbose=True, save_last=True, save_top_k=3, save_weights_only=False, mode='min', every_n_epochs=1, prefix=None, postfix='.mridc', save_best_model=False, always_save_mridc=False, save_mridc_on_train_end=True, model_parallel_size=None)
create_checkpoint_callback: Optional[bool] = True
create_tensorboard_logger: Optional[bool] = True
create_wandb_logger: Optional[bool] = False
exp_dir: Optional[str] = None
explicit_log_dir: Optional[str] = None
files_to_copy: Optional[List[str]] = None
log_global_rank_0_only: Optional[bool] = False
log_local_rank_0_only: Optional[bool] = False
log_step_timing: Optional[bool] = True
model_parallel_size: Optional[int] = None
name: Optional[str] = None
resume_if_exists: Optional[bool] = False
resume_ignore_no_checkpoint: Optional[bool] = False
resume_past_end: Optional[bool] = False
step_timing_kwargs: Optional[StepTimingParams] = StepTimingParams(reduction='mean', sync_cuda=False, buffer_size=1)
summary_writer_kwargs: Optional[Dict[Any, Any]] = None
use_datetime_version: Optional[bool] = True
version: Optional[str] = None
wandb_logger_kwargs: Optional[Dict[Any, Any]] = None
class mridc.utils.exp_manager.LoggerList(_logger_iterable, mridc_name=None, mridc_version='')[source]

Bases: LoggerCollection

A thin wrapper on Lightning’s LoggerCollection such that name and version are better aligned with exp_manager

property name: str

The name of the experiment.

property version: str

The version of the experiment. If the logger was created with a version, this will be the version.

exception mridc.utils.exp_manager.LoggerMisconfigurationError(message)[source]

Bases: MRIDCBaseException

Raised when a mismatch between trainer.logger and exp_manager occurs

class mridc.utils.exp_manager.MRIDCModelCheckpoint(always_save_mridc=False, save_mridc_on_train_end=True, save_best_model=False, postfix='.mridc', n_resume=False, model_parallel_size=None, **kwargs)[source]

Bases: ModelCheckpoint

Light wrapper around Lightning’s ModelCheckpoint to force a saved checkpoint on train_end

mridc_topk_check_previous_run()[source]

Check if there are previous runs with the same topk value.

on_save_checkpoint(trainer, pl_module, checkpoint)[source]

Override the default on_save_checkpoint to save the best model if needed.

Parameters
  • trainer (The trainer object.) –

  • pl_module (The PyTorch-Lightning module.) –

  • checkpoint (The checkpoint object.) –

on_train_end(trainer, pl_module)[source]

This is called at the end of training.

Parameters
  • trainer (The trainer object.) –

  • pl_module (The PyTorch-Lightning module.) –

exception mridc.utils.exp_manager.NotFoundError[source]

Bases: MRIDCBaseException

Raised when a file or folder is not found

class mridc.utils.exp_manager.StatelessTimer(duration: Optional[Union[str, timedelta, Dict[str, int]]] = None, interval: str = Interval.step, verbose: bool = True)[source]

Bases: Timer

Extension of PTL timers to be per run.

load_state_dict(state_dict: Dict[str, Any]) None[source]

Loads the state of the timer.

state_dict() Dict[str, Any][source]

Saves the state of the timer.

class mridc.utils.exp_manager.StepTimingParams(reduction: Optional[str] = 'mean', sync_cuda: Optional[bool] = False, buffer_size: Optional[int] = 1)[source]

Bases: object

Parameters for the step timing callback.

buffer_size: Optional[int] = 1
reduction: Optional[str] = 'mean'
sync_cuda: Optional[bool] = False
class mridc.utils.exp_manager.TimingCallback(timer_kwargs=None)[source]

Bases: Callback

Logs execution time of train/val/test steps

on_after_backward(trainer, pl_module)[source]

Note: this is called after the optimizer step

on_before_backward(trainer, pl_module, loss)[source]

Logs the time taken for backward pass

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]

Logs execution time of test steps

on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)[source]

Logs execution time of test steps

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, **kwargs)[source]

Logs the time taken by the training batch

on_train_batch_start(trainer, pl_module, batch, batch_idx, **kwargs)[source]

Called at the beginning of each training batch

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]

Logs the time taken by the validation step

on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)[source]

Logs the time taken by the validation batch

mridc.utils.exp_manager.check_explicit_log_dir(trainer: Trainer, explicit_log_dir: List[Union[Path, str]], exp_dir: str, name: str, version: str) Tuple[Path, str, str, str][source]

Checks that the passed arguments are compatible with explicit_log_dir.

Parameters
  • trainer (The trainer to check.) –

  • explicit_log_dir (The explicit log dir to check.) –

  • exp_dir (The experiment directory to check.) –

  • name (The experiment name to check.) –

  • version (The experiment version to check.) –

Return type

The log_dir, exp_dir, name, and version that should be used.

Raises

LoggerMisconfigurationError

mridc.utils.exp_manager.check_resume(trainer: Trainer, log_dir: str, resume_past_end: bool = False, resume_ignore_no_checkpoint: bool = False)[source]

Checks that resume=True was used correctly with the arguments passed to exp_manager. Sets trainer._checkpoint_connector.resume_from_checkpoint_fit_path as necessary.

Parameters
  • trainer (The trainer that is being used.) –

  • log_dir (The directory where the logs are being saved.) –

  • resume_past_end (Whether to resume from the end of the experiment.) –

  • resume_ignore_no_checkpoint (Whether to ignore if there is no checkpoint to resume from.) –

Raises

  • NotFoundError – If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.

  • ValueError – If resume is True and more than one checkpoint was found.

mridc.utils.exp_manager.check_slurm(trainer)[source]

Checks if the trainer is running on a slurm cluster. If so, it will check if the trainer is running on the master node. If it is not, it will exit.

Parameters

trainer (The trainer to check.) –

Return type

True if the trainer is running on the master node, False otherwise.

mridc.utils.exp_manager.configure_checkpointing(trainer: Trainer, log_dir: Path, name: str, resume: bool, params: DictConfig)[source]

Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint callback or if trainer.weights_save_path was passed to Trainer.

mridc.utils.exp_manager.configure_loggers(trainer: Trainer, exp_dir: List[Union[Path, str]], name: str, version: str, create_tensorboard_logger: bool, summary_writer_kwargs: dict, create_wandb_logger: bool, wandb_kwargs: dict)[source]

Creates TensorboardLogger and/or WandBLogger and attaches them to the trainer. Raises ValueError if summary_writer_kwargs or wandb_kwargs are misconfigured.

Parameters
  • trainer (The trainer to attach the loggers to.) –

  • exp_dir (The experiment directory.) –

  • name (The name of the experiment.) –

  • version (The version of the experiment.) –

  • create_tensorboard_logger (Whether to create a TensorboardLogger.) –

  • summary_writer_kwargs (The kwargs to pass to the TensorboardLogger.) –

  • create_wandb_logger (Whether to create a Weights & Biases logger.) –

  • wandb_kwargs (The kwargs to pass to the Weights & Biases logger.) –

Returns

LoggerList

Return type

A list of loggers.

mridc.utils.exp_manager.error_checks(trainer: Trainer, cfg: Optional[Union[DictConfig, Dict]] = None)[source]
Checks that the passed trainer is compliant with MRIDC and exp_manager’s passed configuration. Checks that:
  • An error is thrown when hydra has changed the working directory. This causes issues with lightning’s DDP.

  • An error is thrown when the trainer has loggers defined but create_tensorboard_logger or create_wandb_logger is True.

  • Error messages are printed when 1) run on multi-node and not Slurm, and 2) run on multi-gpu without DDP.

mridc.utils.exp_manager.exp_manager(trainer: Trainer, cfg: Optional[Union[DictConfig, Dict]] = None) Optional[Path][source]

exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning paradigm of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will get exp_dir, name, and version from the logger. Otherwise, it will use the exp_dir and name arguments to create the logging directory. exp_manager also allows for explicit folder creation via explicit_log_dir.

The version can be a datetime string or an integer. The datetime version can be disabled by setting use_datetime_version to False. It optionally creates TensorBoardLogger, WandBLogger, and ModelCheckpoint objects from pytorch lightning. It copies sys.argv, and git information if available, to the logging directory. It creates a log file for each process to log their output into.

exp_manager additionally has a resume feature (resume_if_exists) which can be used to continue training from the constructed log_dir. When you need to continue training repeatedly (for example, on a cluster where you need multiple consecutive jobs), you need to avoid creating the version folders. Therefore, from v1.0.0, when resume_if_exists is set to True, creating the version folders is skipped.

Parameters
  • trainer (The lightning trainer object.) –

  • cfg (Can have the following keys:) –

    • explicit_log_dir: Can be used to override exp_dir/name/version folder creation. Defaults to None, which will use exp_dir, name, and version to construct the logging directory.

    • exp_dir: The base directory to create the logging directory. Defaults to None, which logs to ./mridc_experiments.

    • name: The name of the experiment. Defaults to None which turns into “default” via name = name or “default”.

    • version: The version of the experiment. Defaults to None which uses either a datetime string or lightning’s TensorboardLogger system of using version_{int}.

    • use_datetime_version: Whether to use a datetime string for version. Defaults to True.

    • resume_if_exists: Whether this experiment is resuming from a previous run. If True, it sets trainer._checkpoint_connector.resume_from_checkpoint_fit_path so that the trainer should auto-resume. exp_manager will move files under log_dir to log_dir/run_{int}. Defaults to False. From v1.0.0, when resume_if_exists is True, we would not create version folders to make it easier to find the log folder for next runs.

    • resume_past_end: exp_manager errors out if resume_if_exists is True and a checkpoint matching *end.ckpt (indicating a previous training run fully completed) exists. This behaviour can be disabled, in which case the *end.ckpt will be loaded, by setting resume_past_end to True. Defaults to False.

    • resume_ignore_no_checkpoint: exp_manager errors out if resume_if_exists is True and no checkpoint could be found. This behaviour can be disabled, in which case exp_manager will print a message and continue without restoring, by setting resume_ignore_no_checkpoint to True. Defaults to False.

    • create_tensorboard_logger: Whether to create a tensorboard logger and attach it to the pytorch lightning trainer. Defaults to True.

    • summary_writer_kwargs: A dictionary of kwargs that can be passed to lightning’s TensorboardLogger class. Note that log_dir is passed by exp_manager and cannot exist in this dict. Defaults to None.

    • create_wandb_logger: Whether to create a Weights and Biases logger and attach it to the pytorch lightning trainer. Defaults to False.

    • wandb_logger_kwargs: A dictionary of kwargs that can be passed to lightning’s WandBLogger class. Note that name and project are required parameters if create_wandb_logger is True. Defaults to None.

    • create_checkpoint_callback: Whether to create a ModelCheckpoint callback and attach it to the pytorch lightning trainer. The ModelCheckpoint saves the top 3 models with the best “val_loss”, the most recent checkpoint under *last.ckpt, and the final checkpoint after training completes under *end.ckpt. Defaults to True.

    • files_to_copy: A list of files to copy to the experiment logging directory. Defaults to None which copies no files.

    • log_local_rank_0_only: Whether to only create log files for local rank 0. Defaults to False. Set this to True if you are using DDP with many GPUs and do not want many log files in your exp dir.

    • log_global_rank_0_only: Whether to only create log files for global rank 0. Defaults to False. Set this to True if you are using DDP with many GPUs and do not want many log files in your exp dir.

Return type

The final logging directory where logging files are saved. Usually the concatenation of exp_dir, name, and version.
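
A hedged sketch of a typical call, using only keys documented above with illustrative values; the trainer is created without its own logger and checkpoint callback so that exp_manager can attach them, per the checks described above:

    import pytorch_lightning as pl
    from omegaconf import OmegaConf

    from mridc.utils.exp_manager import exp_manager

    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)

    cfg = OmegaConf.create(
        {
            "exp_dir": "./mridc_experiments",
            "name": "demo_run",
            "create_tensorboard_logger": True,
            "create_checkpoint_callback": True,
            "resume_if_exists": False,
        }
    )

    log_dir = exp_manager(trainer, cfg)  # the resolved logging directory (a Path)
    print(log_dir)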

mridc.utils.exp_manager.get_git_diff()[source]

Helper function that tries to get the git diff if running inside a git folder.

Returns

  • Bool (Whether the git subprocess ran without error.)

  • String (git subprocess output or error message)

mridc.utils.exp_manager.get_git_hash()[source]

Helper function that tries to get the commit hash if running inside a git folder.

Returns

  • Bool (Whether the git subprocess ran without error.)

  • String (git subprocess output or error message)

mridc.utils.exp_manager.get_log_dir(trainer: Trainer, exp_dir: Optional[str] = None, name: Optional[str] = None, version: Optional[str] = None, explicit_log_dir: Optional[str] = None, use_datetime_version: bool = True, resume_if_exists: bool = False) Tuple[Path, str, str, str][source]

Obtains the log_dir used for exp_manager.

Parameters
  • trainer (The trainer to check.) –

  • exp_dir (The experiment directory to check.) –

  • name (The experiment name to check.) –

  • version (The experiment version to check.) –

  • explicit_log_dir (The explicit log dir to check.) –

  • use_datetime_version (Whether to use datetime versioning.) –

  • resume_if_exists (Whether to resume if the log_dir already exists.) –

Raises
  • LoggerMisconfigurationError – If trainer is incompatible with the arguments given.

  • NotFoundError – If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.

  • ValueError – If resume is True and more than one checkpoint was found.

mridc.utils.export_utils module

class mridc.utils.export_utils.CastToFloat(mod)[source]

Bases: Module

Cast input to float

forward(x)[source]

Forward pass

training: bool
class mridc.utils.export_utils.ExportFormat(value)[source]

Bases: Enum

Which format to use when exporting a Neural Module for deployment

ONNX = (1,)
TORCHSCRIPT = (2,)
mridc.utils.export_utils.augment_filename(output: str, prepend: str)[source]

Augment output filename with prepend

mridc.utils.export_utils.cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32)[source]

Cast all tensors in x from from_dtype to to_dtype

mridc.utils.export_utils.cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32)[source]

Cast tensor from from_dtype to to_dtype

mridc.utils.export_utils.forward_method(self)[source]

Forward method for export

mridc.utils.export_utils.get_export_format(filename: str)[source]

Get export format from filename

mridc.utils.export_utils.parse_input_example(input_example)[source]

Parse input example to onnxrt input format

mridc.utils.export_utils.replace_for_export(model: Module) Module[source]

Top-level function to replace a default set of modules in a model. NOTE: This occurs in place; if you want to preserve the model, make sure to copy it first.

Parameters

model (Top-level model to replace modules in.) –

Return type

The model with replaced modules.

mridc.utils.export_utils.replace_modules(model: Module, expansions: Optional[Dict[str, Callable[[Module], Optional[Module]]]] = None) Module[source]

Top-level function to replace modules in a model, specified by class name, with a desired replacement. NOTE: This occurs in place; if you want to preserve the model, make sure to copy it first.

Parameters
  • model (Top-level model to replace modules in.) –

  • expansions (A dictionary of module class names to functions to replace them with.) –

Return type

The model with replaced modules.

mridc.utils.export_utils.run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01)[source]

Run onnxrt and compare with output example

mridc.utils.export_utils.simple_replace(BaseT: Type[Module], DestT: Type[Module]) Callable[[Module], Optional[Module]][source]

Generic function generator to replace BaseT module with DestT. BaseT and DestT should have same attributes. No weights are copied.

Parameters
  • BaseT (The base type of the module.) –

  • DestT (The destination type of the module.) –

Return type

A function to replace BaseT with DestT.
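
A hedged sketch combining simple_replace with replace_modules (documented above); FancyReLU is an illustrative stand-in for an export-unfriendly module, and its class name is used as the expansions key per the "specified by class name" note above:

    import torch

    from mridc.utils.export_utils import replace_modules, simple_replace


    class FancyReLU(torch.nn.ReLU):
        """Illustrative subclass standing in for a module we want swapped before export."""


    model = torch.nn.Sequential(torch.nn.Linear(4, 4), FancyReLU())

    # simple_replace builds the swap function; no weights are copied, and the
    # two types share the same attributes, as required.
    expansions = {"FancyReLU": simple_replace(FancyReLU, torch.nn.ReLU)}
    model = replace_modules(model, expansions)  # swaps in place and returns the model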

mridc.utils.export_utils.swap_modules(model: Module, mapping: Dict[str, Module])[source]

This function swaps nested modules, as specified by “dot paths” in mapping, with a desired replacement. This allows for swapping nested modules through arbitrary levels of children. NOTE: This occurs in place; if you want to preserve the model, make sure to copy it first.

mridc.utils.export_utils.to_onnxrt_input(ort_input_names, input_names, input_dict, input_list)[source]

Convert input to onnxrt input

mridc.utils.export_utils.verify_runtime(model, output, input_examples, input_names, check_tolerance=0.01)[source]

Verify runtime output with onnxrt.

mridc.utils.export_utils.wrap_forward_method(self)[source]

Wraps the forward method of the module with a function that returns the output of the forward method

mridc.utils.export_utils.wrap_module(BaseT: Type[Module], DestT: Type[Module]) Callable[[Module], Optional[Module]][source]

Generic function generator to replace BaseT module with DestT. BaseT and DestT should have same attributes. No weights are copied.

Parameters
  • BaseT (The base type of the module.) –

  • DestT (The destination type of the module.) –

Return type

A function to replace BaseT with DestT.

mridc.utils.get_rank module

mridc.utils.get_rank.get_rank()[source]

Helper function that returns torch.distributed.get_rank() if DDP has been initialized otherwise returns 0.

mridc.utils.get_rank.is_global_rank_zero()[source]

Helper function to determine if the current process is global_rank 0 (the main process).
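
A small sketch of guarding work that should only run on the main process:

    from mridc.utils.get_rank import get_rank, is_global_rank_zero

    if is_global_rank_zero():
        print(f"rank {get_rank()}: writing summary files only on the main process")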

mridc.utils.lightning_logger_patch module

mridc.utils.lightning_logger_patch.add_filehandlers_to_pl_logger(all_log_file, err_log_file)[source]

Adds two filehandlers to pytorch_lightning’s logger. Called in mridc.utils.exp_manager(). The first filehandler logs all messages to all_log_file while the second filehandler logs all WARNING and higher messages to err_log_file. If “memory_err” and “memory_all” exist in HANDLERS, then those buffers are flushed to err_log_file and all_log_file respectively, and then closed.

mridc.utils.lightning_logger_patch.add_memory_handlers_to_pl_logger()[source]

Adds two MemoryHandlers to pytorch_lightning’s logger. These two handlers are essentially message buffers. This function is called in mridc.utils.__init__.py. These handlers are used in add_filehandlers_to_pl_logger to flush buffered messages to files.

mridc.utils.metaclasses module

class mridc.utils.metaclasses.Singleton[source]

Bases: type

Implementation of a generic, thread-safe singleton metaclass. Classes that use it as a metaclass will only ever create a single instance.

__call__(*args, **kwargs)[source]

Returns singleton instance. A thread safe implementation.
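
A minimal sketch of using the metaclass; the registry class below is illustrative:

    from mridc.utils.metaclasses import Singleton


    class GlobalRegistry(metaclass=Singleton):
        """Illustrative class; every instantiation returns the same object."""

        def __init__(self):
            self.items = []


    assert GlobalRegistry() is GlobalRegistry()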

mridc.utils.model_utils module

class mridc.utils.model_utils.ArtifactItem[source]

Bases: object

ArtifactItem is a dataclass that holds the information of an artifact.

hashed_path: Optional[str] = None
path: str
path_type: ArtifactPathType
class mridc.utils.model_utils.ArtifactPathType(value)[source]

Bases: Enum

ArtifactPathType refers to the type of the path that the artifact is located at. LOCAL_PATH: A user local filepath that exists on the file system. TAR_PATH: A (generally flattened) filepath that exists inside of an archive (that may have its own full path).

LOCAL_PATH = 0
TAR_PATH = 1
mridc.utils.model_utils.check_lib_version(lib_name: str, checked_version: str, operator) Tuple[Optional[bool], str][source]

Checks if a library is installed, and if it is, evaluates operator(lib.__version__, checked_version). This bool result, along with a string analysis of the result, is returned. If the library is not installed at all, returns None instead, along with a string explaining that the library is not installed.

Parameters
  • lib_name (lower case str name of the library that must be imported.) –

  • checked_version (semver string that is compared against lib.__version__.) –

  • operator (binary callable function func(a, b) -> bool that compares lib.__version__ against checked_version in some manner. Must return a boolean.) –
Returns

  • Bool or None. Bool if the library could be imported, and the result of operator(lib.__version__, checked_version) or False if __version__ is not implemented in lib. None is passed if the library is not installed at all.

  • A string analysis of the check.

Return type

A tuple of results
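
A small sketch using a stdlib comparison operator as the operator argument; the library name and version are illustrative:

    import operator

    from mridc.utils.model_utils import check_lib_version

    # Returns (True/False/None, human-readable analysis string).
    ok, message = check_lib_version("torch", "1.10.0", operator.ge)
    print(ok, message)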

mridc.utils.model_utils.convert_model_config_to_dict_config(cfg: Union[DictConfig, MRIDCConfig]) DictConfig[source]

Converts its input into a standard DictConfig.

Possible input values are:
  • DictConfig

  • A dataclass which is a subclass of MRIDCConfig

Parameters

cfg (A dict-like object.) –

Return type

The equivalent DictConfig.

mridc.utils.model_utils.import_class_by_path(path: str)[source]

Recursive import of class by path string.

mridc.utils.model_utils.inject_model_parallel_rank(filepath)[source]

Injects tensor/pipeline model parallel ranks into the filepath. Does nothing if not using model parallelism.

mridc.utils.model_utils.maybe_update_config_version(cfg: DictConfig)[source]

Recursively converts Hydra 0.x configs to Hydra 1.x configs. Changes include:
  • cls -> _target_.

  • params -> drop params and shift all arguments to the parent.

  • target -> _target_ cannot be performed due to ModelPT injecting target inside the class.

Parameters

cfg (Any Hydra compatible DictConfig) –

Return type

An updated DictConfig that conforms to Hydra 1.x format.
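
A hedged sketch of converting an old-style config; the target path and keys follow the cls/params convention described above and are illustrative only:

    from omegaconf import OmegaConf

    from mridc.utils.model_utils import maybe_update_config_version

    # "my_pkg.modules.Decoder" is a hypothetical target path.
    old_cfg = OmegaConf.create(
        {"decoder": {"cls": "my_pkg.modules.Decoder", "params": {"hidden_size": 64}}}
    )

    new_cfg = maybe_update_config_version(old_cfg)
    # Expected shape after conversion, per the rules above:
    #   decoder:
    #     _target_: my_pkg.modules.Decoder
    #     hidden_size: 64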

mridc.utils.model_utils.parse_dataset_as_name(name: str) str[source]

Constructs a valid prefix-name from a provided file path.

Parameters
name (Path to some valid data/manifest file, or a python object that will be used as a name for the data loader (via str() cast).) –

Return type

A valid prefix-name for the data loader.

mridc.utils.model_utils.resolve_cache_dir() Path[source]

Utility method to resolve a cache directory for MRIDC that can be overridden by an environment variable.

Example

MRIDC_CACHE_DIR="~/mridc_cache_dir/" python mridc_example_script.py

Returns

A Path object, resolved to the absolute path of the cache directory. If no override is provided, uses an inbuilt default which adapts to mridc version strings.

mridc.utils.model_utils.resolve_dataset_name_from_cfg(cfg: DictConfig) Union[str, int, Enum, float, bool, None, Any][source]

Parses items of the provided sub-config to find the first potential key that resolves to an existing file or directory.

Fast-path resolution: in order to handle cases where we need to resolve items that are not paths, a fastpath key can be provided, as defined in the global _VAL_TEST_FASTPATH_KEY. This key can be used in two ways.

1. _VAL_TEST_FASTPATH_KEY points to another key in the config. If _VAL_TEST_FASTPATH_KEY points to another key in this config itself, then we assume we want to loop through the values of that key. This allows any key in the config to become a fastpath key.

Example

validation_ds:

    splits: "val"
    ...
    <_VAL_TEST_FASTPATH_KEY>: "splits"  <-- this points to the key name "splits"

Then we can write the following when overriding in hydra:

    python train_file.py ... model.validation_ds.splits=[val1, val2, dev1, dev2] ...

2. _VAL_TEST_FASTPATH_KEY itself acts as the resolved key. If _VAL_TEST_FASTPATH_KEY does not point to another key in the config, then it is assumed that the items of this key itself are used for resolution.

Example

validation_ds:

    <_VAL_TEST_FASTPATH_KEY>: "val"

Then we can write the following when overriding in hydra:

    python train_file.py ... model.validation_ds.<_VAL_TEST_FASTPATH_KEY>=[val1, val2, dev1, dev2] ...

IMPORTANT NOTE: the resolution can potentially mismatch if there exist more than two valid paths and the first path does not resolve to the data file (but does resolve to some other valid path). To avoid this side effect, place the data path as the first item in the config file.

Parameters

cfg (Sub-config of the config file.) –

Return type

A str representing the key of the config which hosts the filepath(s), or None in case path could not be resolved.

mridc.utils.model_utils.resolve_subclass_pretrained_model_info(base_class) Union[List[PretrainedModelInfo], Set[Any]][source]

Recursively traverses the inheritance graph of subclasses to extract all pretrained model info. First constructs a set of unique pretrained model info by performing DFS over the inheritance graph. All model info belonging to the same class is added together.

Parameters

base_class (The root class, whose subclass graph will be traversed.) –

Return type

A list of unique pretrained model infos belonging to all the inherited subclasses of this baseclass.

mridc.utils.model_utils.resolve_validation_dataloaders(model: ModelPT)[source]

Helper method that operates on the ModelPT class to automatically support multiple dataloaders for the validation set. It does so by first resolving the path to one or more data files via resolve_dataset_name_from_cfg(). If this resolution fails, it assumes the data loader is prepared to manually support / not support multiple data loaders and simply calls the appropriate setup method.

If resolution succeeds:
  • Checks if the provided path is to a single file or a list of files. If a single file is provided, simply tags that file as such and loads it via the setup method.

  • If multiple files are provided: injects a new manifest path at index “i” into the resolved key, calls the appropriate setup method to set the data loader, and collects the initialized data loader in a list and preserves it.

  • Once all data loaders are processed, assigns the list of loaded loaders to the ModelPT.

  • Finally, assigns a list of unique names resolved from the file paths to the ModelPT.

Parameters

model (ModelPT subclass, which requires >=1 Validation Dataloaders to be setup.) –

mridc.utils.model_utils.uninject_model_parallel_rank(filepath)[source]

Uninjects tensor/pipeline model parallel ranks from the filepath.

mridc.utils.model_utils.unique_names_check(name_list: Optional[List[str]])[source]

Performs a uniqueness check on the name list resolved, so that it can warn users about non-unique keys.

Parameters

name_list (List of strings resolved for data loaders.) –

mridc.utils.model_utils.wrap_training_step(wrapper=None, enabled=None, adapter=None, proxy=<class 'FunctionWrapper'>)[source]

Wraps the training step of the LightningModule.

Parameters
  • wrapped (The wrapped function.) –

  • instance (The LightningModule instance.) –

  • args (The arguments passed to the wrapped function.) –

  • kwargs (The keyword arguments passed to the wrapped function.) –

Return type

The return value of the wrapped function.

mridc.utils.mridc_logging module

class mridc.utils.mridc_logging.LogMode(value)[source]

Bases: IntEnum

Enum for the different logging modes.

EACH = 0
ONCE = 1
class mridc.utils.mridc_logging.Logger(*args, **kwargs)[source]

Bases: object

Singleton class for logging.
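
A minimal sketch using only the methods and constants documented below; since the class is a singleton, repeated construction returns the same object:

    from mridc.utils.mridc_logging import Logger, LogMode

    logging = Logger()
    logging.set_verbosity(Logger.INFO)
    logging.info("starting reconstruction")
    logging.warning("this message is emitted only once", mode=LogMode.ONCE)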

CRITICAL = 50
DEBUG = 10
ERROR = 40
INFO = 20
NOTSET = 0
WARNING = 30
add_err_file_handler(log_file)[source]

Add a FileHandler to logger that logs all WARNING and higher messages to a file. If the logger had a MemoryHandler at self._handlers[“memory_err”], those buffered messages are flushed to the new file, and the MemoryHandler is closed.

add_file_handler(log_file)[source]

Add a FileHandler to logger that logs all messages to a file. If the logger had a MemoryHandler at self._handlers[“memory_all”], those buffered messages are flushed to the new file, and the MemoryHandler is closed.

add_stream_handlers(formatter=<class 'mridc.utils.formaters.base.BaseMRIDCFormatter'>)[source]

Adds StreamHandlers that log to stdout and stderr to the logger. INFO and lower logs are streamed to stdout while WARNING and higher are streamed to stderr. If the MRIDC_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR environment variable is set, all logs are sent to stderr instead.

captureWarnings(capture)[source]

If capture is true, redirect all warnings to the logging package. If capture is False, ensure that warnings are not redirected to logging but to their original destinations.

critical(msg, *args, mode=LogMode.EACH, **kwargs) None[source]

Log ‘msg % args’ with severity ‘CRITICAL’. To pass exception information, use the keyword argument exc_info with a true value, e.g. logger.critical(“Houston, we have %s”, “major disaster”, exc_info=1)

Parameters
  • msg (the message to log) –

  • *args (the arguments to the message) –

  • mode (the mode to log the message in) –

  • **kwargs (the keyword arguments to the message) –

debug(msg, *args, mode=LogMode.EACH, **kwargs)[source]

Log ‘msg % args’ with severity ‘DEBUG’. To pass exception information, use the keyword argument exc_info with a true value, e.g. logger.debug(“Houston, we have %s”, “thorny problem”, exc_info=1)

error(msg, *args, mode=LogMode.EACH, **kwargs)[source]

Log ‘msg % args’ with severity ‘ERROR’. To pass exception information, use the keyword argument exc_info with a true value, e.g. logger.error(“Houston, we have %s”, “major problem”, exc_info=1)

getEffectiveLevel()[source]

Return how much logging output will be produced.

get_verbosity()[source]

See getEffectiveLevel

info(msg, *args, mode=LogMode.EACH, **kwargs)[source]

Log ‘msg % args’ with severity ‘INFO’. To pass exception information, use the keyword argument exc_info with a true value, e.g. logger.info(“Houston, we have %s”, “interesting problem”, exc_info=1)

patch_stderr_handler(stream)[source]

Sends messages that should log to stderr to stream instead. Useful for unittests

patch_stdout_handler(stream)[source]

Sends messages that should log to stdout to stream instead. Useful for unittests

remove_stream_handlers()[source]

Removes StreamHandler that log to stdout and stderr from the logger.

reset_stream_handler(formatter=<class 'mridc.utils.formaters.base.BaseMRIDCFormatter'>)[source]

Removes then adds stream handlers.

setLevel(verbosity_level)[source]

Sets the threshold for what messages will be logged.

set_verbosity(verbosity_level)[source]

See setLevel

temp_verbosity(verbosity_level)[source]

Sets a temporary threshold for what messages will be logged.

warning(msg, *args, mode=LogMode.EACH, **kwargs)[source]

Log ‘msg % args’ with severity ‘WARNING’. To pass exception information, use the keyword argument exc_info with a true value, e.g. logger.warning(“Houston, we have %s”, “bit of a problem”, exc_info=1)

mridc.utils.timers module

class mridc.utils.timers.NamedTimer(reduction='mean', sync_cuda=False, buffer_size=-1)[source]

Bases: object

A timer class that supports multiple named timers. A named timer can be used multiple times, in which case the average dt will be returned. A named timer cannot be started if it is already currently running. Use case: measuring execution of multiple code blocks.
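
A minimal usage sketch; the sleep stands in for the code block being measured:

    import time

    from mridc.utils.timers import NamedTimer

    timer = NamedTimer(reduction="mean")

    for _ in range(3):
        timer.start("data_loading")
        time.sleep(0.01)
        timer.stop("data_loading")

    print(timer.get("data_loading"))  # mean duration over the three runs
    print(timer.export())             # dict with average/all dt per named timer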

active_timers()[source]

Return list of all active named timers

property buffer_size

Returns the buffer size of the timer.

export()[source]

Exports a dictionary with average/all dt per named timer

get(name='')[source]

Returns the value of a named timer

Parameters

name (timer name to return) –

reset(name=None)[source]

Resets all timers, or a specific timer.

Parameters

name (Timer name to reset (if None all timers are reset)) –

start(name='')[source]

Starts measuring a named timer.

Parameters

name (timer name to start) –

stop(name='')[source]

Stops measuring a named timer.

Parameters

name (timer name to stop) –

Module contents