mridc.utils package
Subpackages
- mridc.utils.decorators package
- mridc.utils.formaters package
- Submodules
- mridc.utils.formaters.base module
- mridc.utils.formaters.colors module
AnsiBack
AnsiBack.BLACK
AnsiBack.BLUE
AnsiBack.CYAN
AnsiBack.GREEN
AnsiBack.LIGHTBLACK_EX
AnsiBack.LIGHTBLUE_EX
AnsiBack.LIGHTCYAN_EX
AnsiBack.LIGHTGREEN_EX
AnsiBack.LIGHTMAGENTA_EX
AnsiBack.LIGHTRED_EX
AnsiBack.LIGHTWHITE_EX
AnsiBack.LIGHTYELLOW_EX
AnsiBack.MAGENTA
AnsiBack.RED
AnsiBack.RESET
AnsiBack.WHITE
AnsiBack.YELLOW
AnsiCodes
AnsiCursor
AnsiFore
AnsiFore.BLACK
AnsiFore.BLUE
AnsiFore.CYAN
AnsiFore.GREEN
AnsiFore.LIGHTBLACK_EX
AnsiFore.LIGHTBLUE_EX
AnsiFore.LIGHTCYAN_EX
AnsiFore.LIGHTGREEN_EX
AnsiFore.LIGHTMAGENTA_EX
AnsiFore.LIGHTRED_EX
AnsiFore.LIGHTWHITE_EX
AnsiFore.LIGHTYELLOW_EX
AnsiFore.MAGENTA
AnsiFore.RED
AnsiFore.RESET
AnsiFore.WHITE
AnsiFore.YELLOW
AnsiStyle
clear_line()
clear_screen()
code_to_chars()
set_title()
- mridc.utils.formaters.utils module
- Module contents
Submodules
mridc.utils.app_state module
- class mridc.utils.app_state.AppState(*args, **kwargs)[source]
Bases:
object
A singleton class that holds the state of the application.
- property checkpoint_callback_params
Returns the checkpoint_callback_params set by exp_manager.
- property checkpoint_name
Returns the name set by exp_manager.
- property create_checkpoint_callback
Returns the create_checkpoint_callback set by exp_manager.
- property data_parallel_group
Property returns the data parallel group.
- property data_parallel_rank
Property returns the data parallel rank.
- property data_parallel_size
Property returns the number of GPUs in each data parallel group.
- property device_id
Property returns the device_id.
- property exp_dir
Returns the exp_dir set by exp_manager.
- get_model_metadata_from_guid(guid) ModelMetadataRegistry [source]
Returns the global model idx and restoration path.
- property global_rank
Property returns the global rank.
- property is_model_being_restored: bool
Returns whether a model is being restored.
- property local_rank
Property returns the local rank.
- property log_dir
Returns the log_dir set by exp_manager.
- property model_parallel_size
Property returns the number of GPUs in each model parallel group.
- property model_restore_path
Returns the model_restore_path set by exp_manager.
- property mridc_file_folder: str
Returns the mridc_file_folder set by exp_manager.
- property name
Returns the name set by exp_manager.
- property pipeline_model_parallel_group
Property returns the pipeline model parallel group.
- property pipeline_model_parallel_rank
Property returns the pipeline model parallel rank.
- property pipeline_model_parallel_size
Property returns the number of GPUs in each pipeline model parallel group.
- property pipeline_model_parallel_split_rank
Property returns the pipeline model parallel split rank.
- property random_seed
Property returns the random seed.
- register_model_guid(guid: str, restoration_path: Optional[str] = None)[source]
Maps a guid to its restore path (None or last absolute path).
- property tensor_model_parallel_group
Property returns the tensor model parallel group.
- property tensor_model_parallel_rank
Property returns the tensor model parallel rank.
- property tensor_model_parallel_size
Property returns the number of GPUs in each tensor model parallel group.
- property version
Returns the version set by exp_manager.
- property world_size
Property returns the total number of GPUs.
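AppState is documented as a singleton. Below is a minimal sketch of that pattern. It assumes a metaclass-based singleton, and the names `Singleton` and `AppStateSketch` are hypothetical, not mridc's implementation. The sketch shows why every `AppState()` call returns the same object, so settings such as `world_size` are shared across the application:

```python
import threading
from typing import Any, Dict, Optional


class Singleton(type):
    """Hypothetical metaclass: each class using it gets exactly one instance."""

    _instances: Dict[type, Any] = {}
    _lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        # Double-checked under a lock so concurrent first calls still create one instance.
        if cls not in Singleton._instances:
            with Singleton._lock:
                if cls not in Singleton._instances:
                    Singleton._instances[cls] = super().__call__(*args, **kwargs)
        return Singleton._instances[cls]


class AppStateSketch(metaclass=Singleton):
    """Hypothetical stand-in for AppState holding two of the documented properties."""

    def __init__(self) -> None:
        self._world_size: Optional[int] = None
        self._global_rank: Optional[int] = None

    @property
    def world_size(self) -> Optional[int]:
        """Total number of GPUs."""
        return self._world_size

    @world_size.setter
    def world_size(self, size: int) -> None:
        self._world_size = size

    @property
    def global_rank(self) -> Optional[int]:
        return self._global_rank

    @global_rank.setter
    def global_rank(self, rank: int) -> None:
        self._global_rank = rank


state = AppStateSketch()
state.world_size = 8
# A second "construction" returns the same object, so the value is visible everywhere.
print(AppStateSketch().world_size)
```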
mridc.utils.arguments module
- mridc.utils.arguments.add_optimizer_args(parent_parser: ArgumentParser, optimizer: str = 'adam', default_lr: Optional[float] = None, default_opt_args: Optional[Union[Dict[str, Any], List[str]]] = None) ArgumentParser [source]
Extends existing argparse with default optimizer args.
Example of adding optimizer args to the command line: python train_script.py … --optimizer "novograd" --lr 0.01 --opt_args betas=0.95,0.5 weight_decay=0.001
- Parameters
parent_parser (ArgumentParser) – Custom CLI parser that will be extended.
optimizer (str, default "adam") – Default optimizer to use.
default_lr (float, default None) – Default learning rate.
default_opt_args (Optional[Union[Dict[str, Any], List[str]]], default None) – Default optimizer arguments.
- Returns
Parser extended by optimizer arguments.
- Return type
ArgumentParser
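The command-line example above can be reproduced with plain argparse. This sketch is an assumption about how such flags could be declared and how the `--opt_args` key=value pairs might be parsed: comma-separated values become tuples of floats. The helpers `parse_opt_args` and `add_optimizer_args_sketch` are hypothetical and are not the mridc implementation.

```python
import argparse
from typing import Any, Dict, List, Optional


def parse_opt_args(pairs: Optional[List[str]]) -> Dict[str, Any]:
    """Hypothetical helper: turn ["betas=0.95,0.5", "weight_decay=0.001"] into a dict."""
    parsed: Dict[str, Any] = {}
    for pair in pairs or []:
        key, _, raw = pair.partition("=")
        values = [float(v) for v in raw.split(",")]
        # A single value stays a scalar; several values become a tuple (e.g. Adam betas).
        parsed[key] = values[0] if len(values) == 1 else tuple(values)
    return parsed


def add_optimizer_args_sketch(parent_parser: argparse.ArgumentParser, optimizer: str = "adam",
                              default_lr: Optional[float] = None) -> argparse.ArgumentParser:
    """Extend a parser with optimizer flags, loosely mirroring the documented signature."""
    parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False, conflict_handler="resolve")
    parser.add_argument("--optimizer", type=str, default=optimizer)
    parser.add_argument("--lr", type=float, default=default_lr)
    parser.add_argument("--opt_args", type=str, nargs="+", default=None)
    return parser


parser = add_optimizer_args_sketch(argparse.ArgumentParser(add_help=False))
args = parser.parse_args(["--optimizer", "novograd", "--lr", "0.01",
                          "--opt_args", "betas=0.95,0.5", "weight_decay=0.001"])
print(args.optimizer, args.lr, parse_opt_args(args.opt_args))
```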
- mridc.utils.arguments.add_recon_args(parent_parser: ArgumentParser) ArgumentParser [source]
Extends existing argparse with default reconstruction args.
- Parameters
parent_parser (ArgumentParser) – Custom CLI parser that will be extended.
- Returns
Parser extended by reconstruction arguments.
- Return type
ArgumentParser
- mridc.utils.arguments.add_scheduler_args(parent_parser: ArgumentParser) ArgumentParser [source]
Extends existing argparse with default scheduler args.
- Parameters
parent_parser (ArgumentParser) – Custom CLI parser that will be extended.
- Returns
Parser extended by scheduler arguments.
- Return type
ArgumentParser
mridc.utils.cloud module
- mridc.utils.cloud.maybe_download_from_cloud(url, filename, subfolder=None, cache_dir=None, refresh_cache=False) str [source]
Download a file from a URL if it does not exist in the cache.
- Parameters
url (str) – URL to download the file from.
filename (str) – What to download. The request will be issued to url/filename.
subfolder (str) – Subfolder within cache_dir. The file will be stored in cache_dir/subfolder. The subfolder can be empty.
cache_dir (str) – Cache directory to download into. If it does not exist, this function will attempt to create it. If None (default), $HOME/.cache/torch/mridc is used.
refresh_cache (bool) – If True and a cached file is present, it is deleted and re-fetched.
- Returns
If successful, the absolute local path to the downloaded file; otherwise an empty string.
- Return type
str
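The documented caching behaviour can be sketched with the standard library alone. The sketch assumes the file is fetched with `urllib.request.urlretrieve` from `url/filename`, and `maybe_download_sketch` is a hypothetical name. mridc's actual download mechanics may differ. The sketch reuses a cached copy unless `refresh_cache` is set, and it returns an empty string on failure.

```python
import urllib.request
from pathlib import Path
from typing import Optional


def maybe_download_sketch(url: str, filename: str, subfolder: Optional[str] = None,
                          cache_dir: Optional[str] = None, refresh_cache: bool = False) -> str:
    # Default cache location documented above: $HOME/.cache/torch/mridc
    root = Path(cache_dir) if cache_dir else Path.home() / ".cache" / "torch" / "mridc"
    destination_dir = root / subfolder if subfolder else root
    destination = destination_dir / filename

    if destination.exists():
        if not refresh_cache:
            return str(destination.resolve())  # cache hit
        destination.unlink()  # refresh requested: delete and re-fetch

    try:
        destination_dir.mkdir(parents=True, exist_ok=True)
        # The request is issued to url/filename.
        urllib.request.urlretrieve(f"{url.rstrip('/')}/{filename}", str(destination))
    except (OSError, ValueError):  # URLError is an OSError subclass
        return ""  # documented failure value: empty string
    return str(destination.resolve())
```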
mridc.utils.config_utils module
- mridc.utils.config_utils.assert_dataclass_signature_match(cls: class_type, datacls: dataclass, ignore_args: Optional[List[str]] = None, remap_args: Optional[Dict[str, str]] = None)[source]
Analyses the signature of a provided class and its respective dataclass, asserting that the dataclass signature matches the class __init__ signature.
Note: This is not a value-based check. This function only checks that all argument names exist on both the class and the dataclass, and logs mismatches.
- Parameters
cls – Any class type, but not an instance of a class. If the class type is not easily available, pass type(x), where x is an instance.
datacls – A corresponding dataclass for the above class.
ignore_args – (Optional) A list of string argument names which are forcibly ignored, even if mismatched in the signature. Useful when a dataclass is a superset of the arguments of a class.
remap_args – (Optional) A dictionary mapping an argument name that exists in either the class or its dataclass to another name. Useful when argument names are mismatched between a class and its dataclass due to indirect instantiation via a helper method.
- Returns
A tuple containing information about the analysis:
- A bool value which is True if the signatures matched exactly or after ignoring values, and False otherwise.
- A set of argument names that exist in the class but not in the dataclass. If an exact signature match occurs, this will be None instead.
- A set of argument names that exist in the dataclass but not in the class itself. If an exact signature match occurs, this will be None instead.
- Return type
tuple
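The name-only comparison described above can be sketched with `inspect.signature` and `dataclasses.fields`. This is an illustration under assumptions, not the mridc source: the handling of `self`, `*args`/`**kwargs` and remapping is guessed. `signature_match_sketch`, `Encoder` and `EncoderConfig` are hypothetical names.

```python
import dataclasses
import inspect
from typing import Dict, List, Optional, Set, Tuple


def signature_match_sketch(cls: type, datacls: type, ignore_args: Optional[List[str]] = None,
                           remap_args: Optional[Dict[str, str]] = None
                           ) -> Tuple[bool, Optional[Set[str]], Optional[Set[str]]]:
    ignore = set(ignore_args or [])
    remap = remap_args or {}
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    # Names accepted by cls.__init__, skipping self and *args / **kwargs.
    params = inspect.signature(cls.__init__).parameters.values()
    class_args = {p.name for p in params if p.name != "self" and p.kind not in variadic}
    data_args = {f.name for f in dataclasses.fields(datacls)}

    # Apply renames on both sides, then drop forcibly ignored names.
    class_args = {remap.get(a, a) for a in class_args} - ignore
    data_args = {remap.get(a, a) for a in data_args} - ignore

    only_in_class = class_args - data_args
    only_in_dataclass = data_args - class_args
    if not only_in_class and not only_in_dataclass:
        return True, None, None
    return False, only_in_class, only_in_dataclass


class Encoder:
    def __init__(self, channels: int, depth: int = 4, dropout: float = 0.0):
        self.channels, self.depth, self.dropout = channels, depth, dropout


@dataclasses.dataclass
class EncoderConfig:
    channels: int = 32
    depth: int = 4


print(signature_match_sketch(Encoder, EncoderConfig))                          # dropout is missing from the dataclass
print(signature_match_sketch(Encoder, EncoderConfig, ignore_args=["dropout"]))  # matches after ignoring it
```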
- mridc.utils.config_utils.update_model_config(model_cls: MRIDCConfig, update_cfg: DictConfig, drop_missing_subconfigs: bool = True)[source]
- Helper function that updates the default values of a ModelPT config class with the values in a DictConfig that mirrors the structure of the config class. Assumes the update_cfg is a DictConfig (either generated manually, via hydra, or instantiated via yaml/model.cfg). This update_cfg is then used to override the default values preset inside the ModelPT config class. If drop_missing_subconfigs is set, certain sub-configs of the ModelPT config class will be removed if they are not found in the mirrored update_cfg. The following sub-configs are subject to potential removal:
train_ds
validation_ds
test_ds
optim + nested sched
- Parameters
model_cls – A subclass of MRIDC that details in entirety all the parameters that constitute the MRIDC model.
update_cfg – A DictConfig that mirrors the structure of the MRIDCConfig dataclass. Used to update the default values of the config class.
drop_missing_subconfigs – Bool which determines whether to drop certain sub-configs from the MRIDCConfig class if the corresponding sub-config is missing from update_cfg.
- Returns
A DictConfig with updated values that can be used to instantiate the MRIDC model along with supporting infrastructure.
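The drop-missing behaviour can be illustrated with plain nested dicts. This sketch assumes a recursive dict merge. The real function operates on OmegaConf `DictConfig` objects and a dataclass, and neither is reproduced here. `update_config_sketch` and the sample keys and values are hypothetical.

```python
import copy
from typing import Any, Dict

# Sub-configs documented as subject to removal.
DROPPABLE = ("train_ds", "validation_ds", "test_ds", "optim")


def _merge(base: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively override values in base with those in update."""
    merged = copy.deepcopy(base)
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = _merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def update_config_sketch(defaults: Dict[str, Any], update_cfg: Dict[str, Any],
                         drop_missing_subconfigs: bool = True) -> Dict[str, Any]:
    base = copy.deepcopy(defaults)
    if drop_missing_subconfigs:
        for key in DROPPABLE:
            if key not in update_cfg:
                base.pop(key, None)  # optim is dropped together with its nested sched
    return _merge(base, update_cfg)


defaults = {"model": {"lr": 1e-3, "depth": 4},
            "train_ds": {"batch_size": 1},
            "test_ds": {"batch_size": 1},
            "optim": {"name": "adam", "sched": {"name": "cosine"}}}
updated = update_config_sketch(defaults, {"model": {"depth": 6}, "train_ds": {"batch_size": 2}})
print(updated)  # test_ds and optim were removed; model.depth and train_ds.batch_size were overridden
```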
mridc.utils.debug_hook module
- mridc.utils.debug_hook.get_backward_hook(name, trainer, rank, logger, dump_to_file=False)[source]
A backward hook to dump all the module input and output grad norms. The hook will be called every time the gradients with respect to module inputs are computed. Only float type input/output grad tensor norms are computed.
For more details about the backward hook, check: https://pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_full_backward_hook.html
- Parameters
name (str) – tensor name
trainer (PTL trainer) – PTL trainer
rank (int) – worker rank
logger (PTL log function) – PTL log function
dump_to_file (bool, optional) – whether to dump the CSV file to disk, by default False
- Return type
backward_hook
- mridc.utils.debug_hook.get_forward_hook(name, trainer, rank, logger, dump_to_file=False)[source]
A forward hook to dump all the module input and output norms. It is called every time after forward() has computed an output. Only float type input/output tensor norms are computed.
For more details about the forward hook, check: https://pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_hook.html
- Parameters
name (str) – tensor name
trainer (PTL trainer) – PTL trainer
rank (int) – worker rank
logger (PTL log function) – PTL log function
dump_to_file (bool, optional) – whether to dump the CSV file to disk, by default False
- Return type
forward_hook
- mridc.utils.debug_hook.get_tensor_hook(module, name, trainer, rank, logger, dump_to_file=False)[source]
A tensor hook to dump all of the tensor weight norms and grad norms at the end of each of the backward steps.
For more details about the tensor hook, check: https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html
- Parameters
module (torch.nn.Module) – module to register the hook
name (str) – tensor name
trainer (PTL trainer) – PTL trainer
rank (int) – worker rank
logger (PTL log function) – PTL log function
dump_to_file (bool, optional) – whether to dump the CSV file to disk, by default False
- Return type
tensor_hook
mridc.utils.distributed module
- mridc.utils.distributed.initialize_distributed(args, backend='nccl')[source]
Initialize distributed training.
- Parameters
args – The arguments object.
backend – The backend to use. Default: "nccl".
- Returns
local_rank – The local rank of the process.
rank – The rank of the process.
world_size – The number of processes.
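The relationship between the three return values can be sketched from the environment variables that `torchrun` exports: `LOCAL_RANK`, `RANK` and `WORLD_SIZE`. Reading them this way is an assumption about mridc's implementation. The actual process-group setup (`torch.distributed.init_process_group(backend=...)`) is left out so that the sketch runs without GPUs. `read_distributed_env` is a hypothetical name.

```python
import os
from typing import Mapping, Optional, Tuple


def read_distributed_env(env: Optional[Mapping[str, str]] = None) -> Tuple[int, int, int]:
    """Return (local_rank, rank, world_size), defaulting to a single process."""
    env = os.environ if env is None else env
    local_rank = int(env.get("LOCAL_RANK", 0))  # index of the process on its node
    rank = int(env.get("RANK", 0))              # global index across all nodes
    world_size = int(env.get("WORLD_SIZE", 1))  # total number of processes
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside world_size {world_size}")
    return local_rank, rank, world_size


# Second process on the second node of a 2-node x 4-GPU job:
print(read_distributed_env({"LOCAL_RANK": "1", "RANK": "5", "WORLD_SIZE": "8"}))
```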
mridc.utils.env_var_parsing module
- exception mridc.utils.env_var_parsing.CoercionError(key, value, func)[source]
Bases:
Exception
Custom error raised when a value cannot be coerced.
- exception mridc.utils.env_var_parsing.RequiredSettingMissingError(key)[source]
Bases:
Exception
Custom error raised when a required env var is missing.
- mridc.utils.env_var_parsing.get_env(key, *default, **kwargs)[source]
Returns an env var. This is the parent function of all other get_foo functions and is responsible for unpacking args/kwargs into the values that _get_env expects (it is the root function that actually interacts with environ).
- Parameters
key – String, the env var name to look up.
default – (Optional) The value to use if the env var does not exist. If this value is not supplied, the env var is considered required, and a RequiredSettingMissingError will be raised if it does not exist.
kwargs – coerce: a function that may be supplied to coerce the value into something else. This is used by the default get_foo functions to cast strings to builtin types, but could be a function that returns a custom class.
- Returns
The env var, coerced if required, or the default if supplied.
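Below is a minimal sketch of the documented contract. It assumes that a missing variable with no default raises `RequiredSettingMissingError`, that a failing `coerce` raises `CoercionError`, and that defaults are returned without coercion. The exception classes are simplified stand-ins for the mridc ones, and the `get_int` wrapper is hypothetical.

```python
import os
from typing import Any, Callable, Optional


class RequiredSettingMissingError(Exception):
    def __init__(self, key: str):
        super().__init__(f"A required environment variable is missing: {key}")


class CoercionError(Exception):
    def __init__(self, key: str, value: str, func: Callable):
        super().__init__(f"Unable to coerce {key}={value!r} using {getattr(func, '__name__', func)}")


def get_env(key: str, *default: Any, coerce: Optional[Callable[[str], Any]] = None) -> Any:
    if key not in os.environ:
        if not default:
            # No default supplied: the variable is treated as required.
            raise RequiredSettingMissingError(key)
        return default[0]  # in this sketch, defaults are returned as-is, without coercion
    value = os.environ[key]
    if coerce is None:
        return value
    try:
        return coerce(value)
    except Exception as exc:
        raise CoercionError(key, value, coerce) from exc


def get_int(key: str, *default: Any) -> int:
    """Hypothetical get_foo-style wrapper that casts the string to int."""
    return get_env(key, *default, coerce=int)
```

For example, `get_int("WORLD_SIZE", 1)` falls back to `1` when the variable is unset. A value such as `"abc"` raises `CoercionError` instead of silently returning a string.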
mridc.utils.exceptions module
- class mridc.utils.exceptions.CheckInstall(*args, **kwargs)[source]
Bases:
object
Class to check if a package is installed.
- exception mridc.utils.exceptions.LightningNotInstalledException(obj)[source]
Bases:
MRIDCBaseException
Exception for when lightning is not installed
mridc.utils.exp_manager module
- class mridc.utils.exp_manager.CallbackParams(filepath: Optional[str] = None, dirpath: Optional[str] = None, filename: Optional[str] = None, monitor: Optional[str] = 'val_loss', verbose: Optional[bool] = True, save_last: Optional[bool] = True, save_top_k: Optional[int] = 3, save_weights_only: Optional[bool] = False, mode: Optional[str] = 'min', every_n_epochs: Optional[int] = 1, prefix: Optional[str] = None, postfix: str = '.mridc', save_best_model: bool = False, always_save_mridc: bool = False, save_mridc_on_train_end: Optional[bool] = True, model_parallel_size: Optional[int] = None)[source]
Bases:
object
Parameters for a callback
- always_save_mridc: bool = False
- dirpath: Optional[str] = None
- every_n_epochs: Optional[int] = 1
- filename: Optional[str] = None
- filepath: Optional[str] = None
- mode: Optional[str] = 'min'
- model_parallel_size: Optional[int] = None
- monitor: Optional[str] = 'val_loss'
- postfix: str = '.mridc'
- prefix: Optional[str] = None
- save_best_model: bool = False
- save_last: Optional[bool] = True
- save_mridc_on_train_end: Optional[bool] = True
- save_top_k: Optional[int] = 3
- save_weights_only: Optional[bool] = False
- verbose: Optional[bool] = True
- exception mridc.utils.exp_manager.CheckpointMisconfigurationError[source]
Bases:
MRIDCBaseException
Raised when a mismatch between trainer.callbacks and exp_manager occurs
- class mridc.utils.exp_manager.ExpManagerConfig(explicit_log_dir: Optional[str] = None, exp_dir: Optional[str] = None, name: Optional[str] = None, version: Optional[str] = None, use_datetime_version: Optional[bool] = True, resume_if_exists: Optional[bool] = False, resume_past_end: Optional[bool] = False, resume_ignore_no_checkpoint: Optional[bool] = False, create_tensorboard_logger: Optional[bool] = True, summary_writer_kwargs: Optional[Dict[Any, Any]] = None, create_wandb_logger: Optional[bool] = False, wandb_logger_kwargs: Optional[Dict[Any, Any]] = None, create_checkpoint_callback: Optional[bool] = True, checkpoint_callback_params: Optional[CallbackParams] = CallbackParams(filepath=None, dirpath=None, filename=None, monitor='val_loss', verbose=True, save_last=True, save_top_k=3, save_weights_only=False, mode='min', every_n_epochs=1, prefix=None, postfix='.mridc', save_best_model=False, always_save_mridc=False, save_mridc_on_train_end=True, model_parallel_size=None), files_to_copy: Optional[List[str]] = None, log_step_timing: Optional[bool] = True, step_timing_kwargs: Optional[StepTimingParams] = StepTimingParams(reduction='mean', sync_cuda=False, buffer_size=1), log_local_rank_0_only: Optional[bool] = False, log_global_rank_0_only: Optional[bool] = False, model_parallel_size: Optional[int] = None)[source]
Bases:
object
Configuration for the experiment manager.
- checkpoint_callback_params: Optional[CallbackParams] = CallbackParams(filepath=None, dirpath=None, filename=None, monitor='val_loss', verbose=True, save_last=True, save_top_k=3, save_weights_only=False, mode='min', every_n_epochs=1, prefix=None, postfix='.mridc', save_best_model=False, always_save_mridc=False, save_mridc_on_train_end=True, model_parallel_size=None)
- create_checkpoint_callback: Optional[bool] = True
- create_tensorboard_logger: Optional[bool] = True
- create_wandb_logger: Optional[bool] = False
- exp_dir: Optional[str] = None
- explicit_log_dir: Optional[str] = None
- files_to_copy: Optional[List[str]] = None
- log_global_rank_0_only: Optional[bool] = False
- log_local_rank_0_only: Optional[bool] = False
- log_step_timing: Optional[bool] = True
- model_parallel_size: Optional[int] = None
- name: Optional[str] = None
- resume_if_exists: Optional[bool] = False
- resume_ignore_no_checkpoint: Optional[bool] = False
- resume_past_end: Optional[bool] = False
- step_timing_kwargs: Optional[StepTimingParams] = StepTimingParams(reduction='mean', sync_cuda=False, buffer_size=1)
- summary_writer_kwargs: Optional[Dict[Any, Any]] = None
- use_datetime_version: Optional[bool] = True
- version: Optional[str] = None
- wandb_logger_kwargs: Optional[Dict[Any, Any]] = None
- class mridc.utils.exp_manager.LoggerList(_logger_iterable, mridc_name=None, mridc_version='')[source]
Bases:
LoggerCollection
A thin wrapper on Lightning’s LoggerCollection such that name and version are better aligned with exp_manager
- property name: str
The name of the experiment.
- property version: str
The version of the experiment. If the logger was created with a version, this will be the version.
- exception mridc.utils.exp_manager.LoggerMisconfigurationError(message)[source]
Bases:
MRIDCBaseException
Raised when a mismatch between trainer.logger and exp_manager occurs
- class mridc.utils.exp_manager.MRIDCModelCheckpoint(always_save_mridc=False, save_mridc_on_train_end=True, save_best_model=False, postfix='.mridc', n_resume=False, model_parallel_size=None, **kwargs)[source]
Bases:
ModelCheckpoint
Light wrapper around Lightning’s ModelCheckpoint to force a saved checkpoint on train_end
- exception mridc.utils.exp_manager.NotFoundError[source]
Bases:
MRIDCBaseException
Raised when a file or folder is not found
- class mridc.utils.exp_manager.StatelessTimer(duration: Optional[Union[str, timedelta, Dict[str, int]]] = None, interval: str = Interval.step, verbose: bool = True)[source]
Bases:
Timer
Extension of PTL timers to be per run.
- class mridc.utils.exp_manager.StepTimingParams(reduction: Optional[str] = 'mean', sync_cuda: Optional[bool] = False, buffer_size: Optional[int] = 1)[source]
Bases:
object
Parameters for the step timing callback.
- buffer_size: Optional[int] = 1
- reduction: Optional[str] = 'mean'
- sync_cuda: Optional[bool] = False
- class mridc.utils.exp_manager.TimingCallback(timer_kwargs=None)[source]
Bases:
Callback
Logs execution time of train/val/test steps
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]
Logs execution time of test steps
- on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)[source]
Logs execution time of test steps
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, **kwargs)[source]
Logs the time taken by the training batch
- on_train_batch_start(trainer, pl_module, batch, batch_idx, **kwargs)[source]
Called at the beginning of each training batch
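Per-step timing over a window of `buffer_size` steps, combined with a `reduction` (see StepTimingParams), can be sketched as a small named timer. `StepTimerSketch` is a hypothetical stand-in, not the mridc class. It uses `time.perf_counter`. It supports only `mean` and `sum`, and `sum` is an assumption beyond the documented `'mean'` default. The timer name `train_step_timing` is also illustrative.

```python
import time
from collections import defaultdict, deque
from typing import Deque, Dict


class StepTimerSketch:
    def __init__(self, reduction: str = "mean", buffer_size: int = 1):
        if reduction not in ("mean", "sum"):
            raise ValueError(f"unsupported reduction: {reduction}")
        self.reduction = reduction
        # Keep only the last buffer_size durations per timer name.
        self._durations: Dict[str, Deque[float]] = defaultdict(lambda: deque(maxlen=buffer_size))
        self._starts: Dict[str, float] = {}

    def start(self, name: str) -> None:
        self._starts[name] = time.perf_counter()

    def stop(self, name: str) -> None:
        self._durations[name].append(time.perf_counter() - self._starts.pop(name))

    def record(self, name: str, seconds: float) -> None:
        """Add an already measured duration directly."""
        self._durations[name].append(seconds)

    def value(self, name: str) -> float:
        values = self._durations[name]
        if not values:
            return 0.0
        total = sum(values)
        return total / len(values) if self.reduction == "mean" else total


timer = StepTimerSketch(reduction="mean", buffer_size=2)
timer.start("train_step_timing")  # would be called from on_train_batch_start
time.sleep(0.01)
timer.stop("train_step_timing")   # would be called from on_train_batch_end
print(f"{timer.value('train_step_timing'):.3f} s")
```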
- mridc.utils.exp_manager.check_explicit_log_dir(trainer: Trainer, explicit_log_dir: List[Union[Path, str]], exp_dir: str, name: str, version: str) Tuple[Path, str, str, str] [source]
Checks that the passed arguments are compatible with explicit_log_dir.
- Parameters
trainer – The trainer to check.
explicit_log_dir – The explicit log dir to check.
exp_dir – The experiment directory to check.
name – The experiment name to check.
version – The experiment version to check.
- Returns
The log_dir, exp_dir, name, and version that should be used.
- Raises
- mridc.utils.exp_manager.check_resume(trainer: Trainer, log_dir: str, resume_past_end: bool = False, resume_ignore_no_checkpoint: bool = False)[source]
Checks that resume=True was used correctly with the arguments passed to exp_manager. Sets trainer._checkpoint_connector.resume_from_checkpoint_fit_path as necessary.
- Parameters
trainer – The trainer that is being used.
log_dir – The directory where the logs are being saved.
resume_past_end – Whether to resume from the end of the experiment.
resume_ignore_no_checkpoint – Whether to ignore the case where there is no checkpoint to resume from.
- Raises
NotFoundError – If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.
ValueError – If resume is True and more than one checkpoint was found.
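The resume decision described above, together with the resume_* keys of exp_manager, can be sketched as a pure function over the checkpoint directory. The sketch makes several assumptions. Checkpoints are assumed to live in `log_dir/checkpoints`. A completed run is assumed to leave a `*end.ckpt`, and otherwise a single `*last.ckpt` is resumed. The function returns the path instead of setting `trainer._checkpoint_connector.resume_from_checkpoint_fit_path`. `find_resume_checkpoint` is a hypothetical name.

```python
from pathlib import Path
from typing import Optional, Union


class NotFoundError(Exception):
    """Simplified stand-in for mridc.utils.exp_manager.NotFoundError."""


def find_resume_checkpoint(log_dir: Union[str, Path], resume_past_end: bool = False,
                           resume_ignore_no_checkpoint: bool = False) -> Optional[Path]:
    checkpoint_dir = Path(log_dir) / "checkpoints"
    exists = checkpoint_dir.exists()
    end_ckpts = sorted(checkpoint_dir.rglob("*end.ckpt")) if exists else []
    last_ckpts = sorted(checkpoint_dir.rglob("*last.ckpt")) if exists else []

    if end_ckpts:
        # A *end.ckpt means the previous run finished; only resume it when explicitly allowed.
        if not resume_past_end:
            raise ValueError(f"Found {end_ckpts[0]}, indicating that the previous run already completed.")
        if len(end_ckpts) > 1:
            raise ValueError(f"Found multiple end checkpoints: {end_ckpts}")
        return end_ckpts[0]
    if not last_ckpts:
        if resume_ignore_no_checkpoint:
            return None  # continue training without restoring
        raise NotFoundError(f"No checkpoint to resume from in {checkpoint_dir}")
    if len(last_ckpts) > 1:
        raise ValueError(f"Found multiple last checkpoints: {last_ckpts}")
    return last_ckpts[0]
```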
- mridc.utils.exp_manager.check_slurm(trainer)[source]
Checks if the trainer is running on a slurm cluster. If so, it will check if the trainer is running on the master node. If it is not, it will exit.
- Parameters
trainer – The trainer to check.
- Returns
True if the trainer is running on the master node, False otherwise.
- mridc.utils.exp_manager.configure_checkpointing(trainer: Trainer, log_dir: Path, name: str, resume: bool, params: DictConfig)[source]
Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint callback or if trainer.weights_save_path was passed to Trainer.
- mridc.utils.exp_manager.configure_loggers(trainer: Trainer, exp_dir: List[Union[Path, str]], name: str, version: str, create_tensorboard_logger: bool, summary_writer_kwargs: dict, create_wandb_logger: bool, wandb_kwargs: dict)[source]
Creates a TensorboardLogger and/or a WandBLogger and attaches them to the trainer. Raises ValueError if summary_writer_kwargs or wandb_kwargs are misconfigured.
- Parameters
trainer – The trainer to attach the loggers to.
exp_dir – The experiment directory.
name – The name of the experiment.
version – The version of the experiment.
create_tensorboard_logger – Whether to create a TensorboardLogger.
summary_writer_kwargs – The kwargs to pass to the TensorboardLogger.
create_wandb_logger – Whether to create a Weights & Biases logger.
wandb_kwargs – The kwargs to pass to the Weights & Biases logger.
- Returns
A list of loggers.
- Return type
LoggerList
- mridc.utils.exp_manager.error_checks(trainer: Trainer, cfg: Optional[Union[DictConfig, Dict]] = None)[source]
- Checks that the passed trainer is compliant with MRIDC and exp_manager’s passed configuration. Checks that:
Throws an error when hydra has changed the working directory. This causes issues with Lightning's DDP.
Throws an error when the trainer has loggers defined but create_tensorboard_logger or create_wandb_logger is True.
Prints error messages when 1) running on multiple nodes without Slurm, and 2) running on multiple GPUs without DDP.
- mridc.utils.exp_manager.exp_manager(trainer: Trainer, cfg: Optional[Union[DictConfig, Dict]] = None) Optional[Path] [source]
exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning paradigm of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will get exp_dir, name, and version from the logger. Otherwise, it will use the exp_dir and name arguments to create the logging directory. exp_manager also allows for explicit folder creation via explicit_log_dir.
The version can be a datetime string or an integer. Datetime versioning can be disabled by setting use_datetime_version to False. exp_manager optionally creates TensorBoardLogger, WandBLogger, and ModelCheckpoint objects from PyTorch Lightning. It copies sys.argv and, if available, git information to the logging directory. It creates a log file for each process to log its output into.
exp_manager additionally has a resume feature (resume_if_exists) which can be used to continue training from the constructed log_dir. When you need to continue training repeatedly (e.g. on a cluster where you need multiple consecutive jobs), you need to avoid creating version folders. Therefore, from v1.0.0, when resume_if_exists is set to True, creating the version folders is skipped.
- Parameters
trainer – The lightning trainer object.
cfg – Can have the following keys:
explicit_log_dir: Can be used to override exp_dir/name/version folder creation. Defaults to None, which will use exp_dir, name, and version to construct the logging directory.
exp_dir: The base directory to create the logging directory. Defaults to None, which logs to ./mridc_experiments.
name: The name of the experiment. Defaults to None which turns into “default” via name = name or “default”.
version: The version of the experiment. Defaults to None which uses either a datetime string or lightning’s TensorboardLogger system of using version_{int}.
use_datetime_version: Whether to use a datetime string for version. Defaults to True.
resume_if_exists: Whether this experiment is resuming from a previous run. If True, it sets trainer._checkpoint_connector.resume_from_checkpoint_fit_path so that the trainer should auto-resume. exp_manager will move files under log_dir to log_dir/run_{int}. Defaults to False. From v1.0.0, when resume_if_exists is True, version folders are not created, which makes it easier to find the log folder for subsequent runs.
resume_past_end: exp_manager errors out if resume_if_exists is True and a checkpoint matching *end.ckpt is found, indicating that a previous training run fully completed. This behaviour can be disabled by setting resume_past_end to True, in which case the *end.ckpt will be loaded. Defaults to False.
resume_ignore_no_checkpoint: exp_manager errors out if resume_if_exists is True and no checkpoint could be found. This behaviour can be disabled by setting resume_ignore_no_checkpoint to True, in which case exp_manager will print a message and continue without restoring. Defaults to False.
create_tensorboard_logger: Whether to create a tensorboard logger and attach it to the pytorch lightning trainer. Defaults to True.
summary_writer_kwargs: A dictionary of kwargs that can be passed to lightning’s TensorboardLogger class. Note that log_dir is passed by exp_manager and cannot exist in this dict. Defaults to None.
create_wandb_logger: Whether to create a Weights and Biases logger and attach it to the pytorch lightning trainer. Defaults to False.
wandb_logger_kwargs: A dictionary of kwargs that can be passed to lightning’s WandBLogger class. Note that name and project are required parameters if create_wandb_logger is True. Defaults to None.
create_checkpoint_callback: Whether to create a ModelCheckpoint callback and attach it to the pytorch lightning trainer. The ModelCheckpoint saves the top 3 models with the best “val_loss”, the most recent checkpoint under *last.ckpt, and the final checkpoint after training completes under *end.ckpt. Defaults to True.
files_to_copy: A list of files to copy to the experiment logging directory. Defaults to None which copies no files.
log_local_rank_0_only: Whether to only create log files for local rank 0. Defaults to False. Set this to True if you are using DDP with many GPUs and do not want many log files in your exp dir.
log_global_rank_0_only: Whether to only create log files for global rank 0. Defaults to False. Set this to True if you are using DDP with many GPUs and do not want many log files in your exp dir.
- Returns
The final logging directory where logging files are saved, usually the concatenation of exp_dir, name, and version.
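The folder layout described above can be sketched as a pure path computation. The layout is exp_dir/name/version, explicit_log_dir overrides it, and no version folder is created when resume_if_exists is True. The datetime format `%Y-%m-%d_%H-%M-%S` and the function name `resolve_log_dir` are assumptions. The real exp_manager also consults the trainer's existing logger and can fall back to Lightning's `version_{int}` scheme, and neither is modelled here.

```python
import datetime
from pathlib import Path
from typing import Optional, Tuple


def resolve_log_dir(exp_dir: Optional[str] = None, name: Optional[str] = None,
                    version: Optional[str] = None, explicit_log_dir: Optional[str] = None,
                    use_datetime_version: bool = True, resume_if_exists: bool = False
                    ) -> Tuple[Path, str, str, str]:
    """Return (log_dir, exp_dir, name, version)."""
    if explicit_log_dir:
        # explicit_log_dir overrides exp_dir/name/version folder creation.
        return Path(explicit_log_dir), str(explicit_log_dir), "", ""

    exp_dir = exp_dir or str(Path.cwd() / "mridc_experiments")
    name = name or "default"
    if resume_if_exists:
        version = ""  # no version folder, so consecutive jobs find the same log_dir
    elif version is None and use_datetime_version:
        version = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    # version=None with use_datetime_version=False would defer to Lightning's version_{int} (not sketched).
    log_dir = Path(exp_dir) / name / (version or "")
    return log_dir, exp_dir, name, version or ""


print(resolve_log_dir(exp_dir="/tmp/exps", name="unet", version="v1")[0])
```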
- mridc.utils.exp_manager.get_git_diff()[source]
Helper function that tries to get the git diff if running inside a git folder.
- Returns
A bool indicating whether the git subprocess ran without error.
A string with the git subprocess output or error message.
- mridc.utils.exp_manager.get_git_hash()[source]
Helper function that tries to get the commit hash if running inside a git folder.
- Returns
A bool indicating whether the git subprocess ran without error.
A string with the git subprocess output or error message.
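Both git helpers return a (success, output) pair. Below is a minimal sketch of that contract for the commit hash. It assumes `git rev-parse HEAD` is the command used, and `get_git_hash_sketch` is a hypothetical name. A git diff variant would run `git diff` instead.

```python
import subprocess
from typing import Tuple


def get_git_hash_sketch() -> Tuple[bool, str]:
    try:
        output = subprocess.check_output(["git", "rev-parse", "HEAD"], stderr=subprocess.STDOUT)
        return True, output.decode(errors="replace").strip()
    except subprocess.CalledProcessError as err:
        # Not a git checkout (or another git failure): report git's own message.
        return False, (err.output or b"").decode(errors="replace").strip()
    except OSError:
        # For example, the git executable is not installed.
        return False, "git executable not found"


ok, text = get_git_hash_sketch()
print("commit:" if ok else "git unavailable:", text)
```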
- mridc.utils.exp_manager.get_log_dir(trainer: Trainer, exp_dir: Optional[str] = None, name: Optional[str] = None, version: Optional[str] = None, explicit_log_dir: Optional[str] = None, use_datetime_version: bool = True, resume_if_exists: bool = False) Tuple[Path, str, str, str] [source]
Obtains the log_dir used for exp_manager.
- Parameters
trainer – The trainer to check.
exp_dir – The experiment directory to check.
name – The experiment name to check.
version – The experiment version to check.
explicit_log_dir – The explicit log dir to check.
use_datetime_version – Whether to use datetime versioning.
resume_if_exists – Whether to resume if the log_dir already exists.
- Raises
LoggerMisconfigurationError – If trainer is incompatible with the arguments.
NotFoundError – If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.
ValueError – If resume is True and more than one checkpoint was found.
mridc.utils.export_utils module
- class mridc.utils.export_utils.CastToFloat(mod)[source]
Bases:
Module
Cast input to float
- training: bool
- class mridc.utils.export_utils.ExportFormat(value)[source]
Bases:
Enum
Which format to use when exporting a Neural Module for deployment
- ONNX = (1,)
- TORCHSCRIPT = (2,)
- mridc.utils.export_utils.augment_filename(output: str, prepend: str)[source]
Augment output filename with prepend
- mridc.utils.export_utils.cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32)[source]
Cast all tensors in x from from_dtype to to_dtype
- mridc.utils.export_utils.cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32)[source]
Cast tensor from from_dtype to to_dtype
- mridc.utils.export_utils.parse_input_example(input_example)[source]
Parse input example to onnxrt input format
- mridc.utils.export_utils.replace_for_export(model: Module) Module [source]
Top-level function to replace the default set of modules in model. NOTE: This occurs in place; if you want to preserve model, make sure to copy it first.
- Parameters
model – Top-level model to replace modules in.
- Returns
The model with replaced modules.
- mridc.utils.export_utils.replace_modules(model: Module, expansions: Optional[Dict[str, Callable[[Module], Optional[Module]]]] = None) Module [source]
Top-level function to replace modules in model, specified by class name, with a desired replacement. NOTE: This occurs in place; if you want to preserve model, make sure to copy it first.
- Parameters
model – Top-level model to replace modules in.
expansions – A dictionary of module class names to functions to replace them with.
- Returns
The model with replaced modules.
- mridc.utils.export_utils.run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01)[source]
Run onnxrt and compare with output example
- mridc.utils.export_utils.simple_replace(BaseT: Type[Module], DestT: Type[Module]) Callable[[Module], Optional[Module]] [source]
Generic function generator to replace a BaseT module with DestT. BaseT and DestT should have the same attributes. No weights are copied.
- Parameters
BaseT – The base type of the module.
DestT – The destination type of the module.
- Returns
A function to replace BaseT with DestT.
- mridc.utils.export_utils.swap_modules(model: Module, mapping: Dict[str, Module])[source]
This function swaps nested modules of model, as specified by the "dot paths" in mapping, with the desired replacements. This allows swapping nested modules through arbitrary levels of children. NOTE: This occurs in place; if you want to preserve model, make sure to copy it first.
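The dot-path lookup can be illustrated without PyTorch. The approach walks every path component except the last with `getattr`, then calls `setattr` with the replacement on the parent. The sketch below works on any attribute tree, with plain `SimpleNamespace` objects standing in for `torch.nn.Module` children. `swap_by_dot_path` is a hypothetical name, and the real function may handle module containers differently.

```python
from types import SimpleNamespace
from typing import Any, Dict


def swap_by_dot_path(root: Any, mapping: Dict[str, Any]) -> Any:
    """Replace nested attributes in place, e.g. {"encoder.block.act": new_act}."""
    for path, replacement in mapping.items():
        *parents, leaf = path.split(".")
        owner = root
        for part in parents:  # descend through arbitrary levels of children
            owner = getattr(owner, part)
        if not hasattr(owner, leaf):
            raise AttributeError(f"{path!r} does not exist")
        setattr(owner, leaf, replacement)
    return root


model = SimpleNamespace(encoder=SimpleNamespace(block=SimpleNamespace(act="relu")), head="linear")
swap_by_dot_path(model, {"encoder.block.act": "gelu"})
print(model.encoder.block.act)  # gelu
```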
- mridc.utils.export_utils.to_onnxrt_input(ort_input_names, input_names, input_dict, input_list)[source]
Convert input to onnxrt input
- mridc.utils.export_utils.verify_runtime(model, output, input_examples, input_names, check_tolerance=0.01)[source]
Verify runtime output with onnxrt.
- mridc.utils.export_utils.wrap_forward_method(self)[source]
Wraps the forward method of the module with a function that returns the output of the forward method
- mridc.utils.export_utils.wrap_module(BaseT: Type[Module], DestT: Type[Module]) Callable[[Module], Optional[Module]] [source]
Generic function generator to replace a BaseT module with DestT. BaseT and DestT should have the same attributes. No weights are copied.
- Parameters
BaseT – The base type of the module.
DestT – The destination type of the module.
- Returns
A function to replace BaseT with DestT.
mridc.utils.get_rank module
mridc.utils.lightning_logger_patch module
- mridc.utils.lightning_logger_patch.add_filehandlers_to_pl_logger(all_log_file, err_log_file)[source]
Adds two file handlers to pytorch_lightning’s logger. Called in mridc.utils.exp_manager(). The first file handler logs all messages to all_log_file, while the second logs all WARNING and higher messages to err_log_file. If “memory_err” and “memory_all” exist in HANDLERS, those buffers are flushed to err_log_file and all_log_file respectively, and then closed.
- mridc.utils.lightning_logger_patch.add_memory_handlers_to_pl_logger()[source]
Adds two MemoryHandlers to pytorch_lightning’s logger. These two handlers are essentially message buffers. This function is called in mridc.utils.__init__.py. These handlers are used in add_filehandlers_to_pl_logger to flush buffered messages to files.
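The buffer-then-flush pattern described by these two functions can be sketched with the standard library's logging.handlers.MemoryHandler. This sketch uses hypothetical function names and an arbitrary logger rather than pytorch_lightning’s. A MemoryHandler created without a target keeps its records buffered. Once a FileHandler is set as its target, flush() replays those records into the file.

```python
import logging
import logging.handlers
from typing import Dict, Tuple


def add_memory_handlers(logger: logging.Logger) -> Dict[str, logging.handlers.MemoryHandler]:
    """Buffer records in memory until log files are known (sketch)."""
    handlers = {
        # Without a target, flush() keeps records buffered instead of emitting them.
        "memory_all": logging.handlers.MemoryHandler(capacity=10_000, target=None),
        "memory_err": logging.handlers.MemoryHandler(capacity=10_000, target=None),
    }
    handlers["memory_err"].setLevel(logging.WARNING)  # only buffer WARNING and above
    for handler in handlers.values():
        logger.addHandler(handler)
    return handlers


def add_file_handlers(
    logger: logging.Logger,
    handlers: Dict[str, logging.handlers.MemoryHandler],
    all_log_file: str,
    err_log_file: str,
) -> Tuple[logging.FileHandler, logging.FileHandler]:
    """Attach file handlers and flush any buffered records into them (sketch)."""
    all_handler = logging.FileHandler(all_log_file)
    err_handler = logging.FileHandler(err_log_file)
    err_handler.setLevel(logging.WARNING)
    for key, file_handler in (("memory_all", all_handler), ("memory_err", err_handler)):
        memory_handler = handlers.pop(key, None)
        if memory_handler is not None:
            memory_handler.setTarget(file_handler)
            memory_handler.flush()  # replay buffered records into the file
            logger.removeHandler(memory_handler)
            memory_handler.close()
        logger.addHandler(file_handler)
    return all_handler, err_handler
```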
mridc.utils.metaclasses module
mridc.utils.model_utils module
- class mridc.utils.model_utils.ArtifactItem[source]
Bases:
object
ArtifactItem is a dataclass that holds information about an artifact.
- hashed_path: Optional[str] = None
- path: str
- path_type: ArtifactPathType
- class mridc.utils.model_utils.ArtifactPathType(value)[source]
Bases:
Enum
ArtifactPathType refers to the type of the path that the artifact is located at.
LOCAL_PATH: A user local filepath that exists on the file system.
TAR_PATH: A (generally flattened) filepath that exists inside of an archive (that may have its own full path).
- LOCAL_PATH = 0
- TAR_PATH = 1
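The two types above can be mirrored in a few lines to show how they fit together. This is an illustrative reconstruction from the listed members, not the library source. The reference lists the fields alphabetically, but a dataclass requires the defaulted hashed_path to come after the required fields, so the field order here is an assumption.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class ArtifactPathType(Enum):
    LOCAL_PATH = 0  # a file that exists on the local file system
    TAR_PATH = 1  # a (generally flattened) path inside an archive


@dataclass
class ArtifactItem:
    path: str
    path_type: ArtifactPathType
    hashed_path: Optional[str] = None  # fields with defaults must follow the required ones
```

For example, ArtifactItem("model_config.yaml", ArtifactPathType.TAR_PATH) describes a file stored inside an archive.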
- mridc.utils.model_utils.check_lib_version(lib_name: str, checked_version: str, operator) Tuple[Optional[bool], str] [source]
Checks whether a library is installed and, if it is, evaluates operator(lib.__version__, checked_version). This boolean result is returned along with a string analysis of the result. If the library is not installed at all, returns None instead, along with a string explaining that the library is not installed.
- Parameters
lib_name – Lower-case str name of the library that must be imported.
checked_version – Semver string that is compared against lib.__version__.
operator – Binary callable func(a, b) -> bool that compares lib.__version__ against checked_version in some manner. Must return a boolean.
- Returns
Bool or None. A bool if the library could be imported: the result of operator(lib.__version__, checked_version), or False if __version__ is not implemented in lib. None if the library is not installed at all.
A string analysis of the check.
- Return type
Tuple[Optional[bool], str]
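The three outcomes (None, False, or the operator's result) can be sketched with importlib. This is a hedged sketch with a hypothetical name, not the mridc function. It differs in one deliberate way: the operator is applied to parsed numeric release tuples rather than raw strings, so that "1.10" compares greater than "1.9". A production version would more likely use packaging.version.Version.

```python
import importlib
import re
from typing import Callable, Optional, Tuple


def _release_tuple(version: str) -> Tuple[int, ...]:
    # Simplification: compare only the leading numeric release, e.g. "1.10.2rc1" -> (1, 10, 2).
    match = re.match(r"\d+(?:\.\d+)*", version)
    return tuple(int(part) for part in match.group(0).split(".")) if match else ()


def check_lib_version_sketch(
    lib_name: str,
    checked_version: str,
    operator: Callable[[Tuple[int, ...], Tuple[int, ...]], bool],
) -> Tuple[Optional[bool], str]:
    try:
        lib = importlib.import_module(lib_name)
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        return None, f"Library {lib_name} is not installed."
    if not hasattr(lib, "__version__"):
        return False, f"Library {lib_name} does not implement __version__; could not check its version."
    result = bool(operator(_release_tuple(lib.__version__), _release_tuple(checked_version)))
    verdict = "satisfies" if result else "does not satisfy"
    return result, f"Library {lib_name} version {lib.__version__} {verdict} the check against {checked_version}."
```

Usage: after import operator as op, check_lib_version_sketch("numpy", "1.20", op.ge) returns (True, ...) when an installed numpy is at least 1.20.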
- mridc.utils.model_utils.convert_model_config_to_dict_config(cfg: Union[DictConfig, MRIDCConfig]) DictConfig [source]
Converts its input into a standard DictConfig.
- Possible input values are:
DictConfig
A dataclass which is a subclass of MRIDCConfig
- Parameters
cfg – A dict-like object.
- Returns
The equivalent DictConfig.
- mridc.utils.model_utils.import_class_by_path(path: str)[source]
Recursively imports a class given its dotted path string.
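Importing a class by its path string can be sketched with importlib. This is a simpler, non-recursive variant with a hypothetical name: it splits off the last component, imports the module part, and reads the class attribute. Unlike a recursive resolver, it handles only module-level classes, not classes nested inside other classes.

```python
import importlib


def import_class_by_path_sketch(path: str) -> type:
    """Import a class from a dotted path such as "collections.OrderedDict" (sketch)."""
    module_path, _, class_name = path.rpartition(".")
    if not module_path:
        raise ValueError(f"Expected a dotted path of the form 'module.ClassName', got {path!r}")
    module = importlib.import_module(module_path)  # imports parent packages as needed
    return getattr(module, class_name)
```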
- mridc.utils.model_utils.inject_model_parallel_rank(filepath)[source]
Injects tensor/pipeline model parallel ranks into the filepath. Does nothing if not using model parallelism.
- mridc.utils.model_utils.maybe_update_config_version(cfg: DictConfig)[source]
Recursively converts Hydra 0.x configs to Hydra 1.x configs. Changes include:
- cls -> _target_.
- params -> drop params and shift all arguments to parent.
- target -> _target_ cannot be performed due to ModelPT injecting target inside class.
- Parameters
cfg – Any Hydra-compatible DictConfig.
- Returns
An updated DictConfig that conforms to Hydra 1.x format.
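The first two rewrite rules can be sketched on plain Python dicts and lists. With OmegaConf, one could round-trip a DictConfig through OmegaConf.to_container and OmegaConf.create. The function name is hypothetical. If a parent key and a key inside params collide, this sketch keeps whichever is visited last. That is an assumption, since the reference does not specify it. The target key is deliberately left untouched, matching the third rule.

```python
from typing import Any


def upgrade_hydra_config_sketch(cfg: Any) -> Any:
    """Rename 'cls' to '_target_' and hoist 'params' into the parent, recursively (sketch)."""
    if isinstance(cfg, list):
        return [upgrade_hydra_config_sketch(item) for item in cfg]
    if not isinstance(cfg, dict):
        return cfg  # scalars are returned unchanged
    upgraded = {}
    for key, value in cfg.items():
        if key == "cls":
            upgraded["_target_"] = value
        elif key == "params" and isinstance(value, dict):
            # Drop the 'params' level and move its arguments up to the parent.
            for param_key, param_value in value.items():
                upgraded[param_key] = upgrade_hydra_config_sketch(param_value)
        else:
            upgraded[key] = upgrade_hydra_config_sketch(value)
    return upgraded
```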
- mridc.utils.model_utils.parse_dataset_as_name(name: str) str [source]
Constructs a valid prefix-name from a provided file path.
- Parameters
name – Path to some valid data/manifest file, or a Python object that will be used as a name for the data loader (via str() cast).
- Returns
A valid prefix-name for the data loader.
- mridc.utils.model_utils.resolve_cache_dir() Path [source]
Utility method to resolve a cache directory for MRIDC that can be overridden by an environment variable.
Example
MRIDC_CACHE_DIR="~/mridc_cache_dir/" python mridc_example_script.py
- Returns
A Path object, resolved to the absolute path of the cache directory. If no override is provided, uses an inbuilt default which adapts to MRIDC version strings.
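The override logic can be sketched with os.environ and pathlib. The variable name MRIDC_CACHE_DIR comes from the example above. The fallback directory here (~/.cache/mridc) is a hypothetical placeholder, because the real default adapts to MRIDC version strings. expanduser() matters because the example passes a path starting with "~".

```python
import os
from pathlib import Path
from typing import Optional


def resolve_cache_dir_sketch(default: Optional[Path] = None) -> Path:
    """Return the absolute cache directory, honoring MRIDC_CACHE_DIR if it is set (sketch)."""
    override = os.environ.get("MRIDC_CACHE_DIR", "").strip()
    if override:
        path = Path(override)
    else:
        # Hypothetical default; the real default adapts to MRIDC version strings.
        path = default if default is not None else Path.home() / ".cache" / "mridc"
    return path.expanduser().resolve()
```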
- mridc.utils.model_utils.resolve_dataset_name_from_cfg(cfg: DictConfig) Union[str, int, Enum, float, bool, None, Any] [source]
Parses items of the provided sub-config to find the first potential key that resolves to an existing file or directory.
Fast-path Resolution
To handle cases where we need to resolve items that are not paths, a fast-path key can be provided as defined in the global _VAL_TEST_FASTPATH_KEY. This key can be used in two ways:
_VAL_TEST_FASTPATH_KEY points to another key in the config
If _VAL_TEST_FASTPATH_KEY points to another key in this config itself, then we assume we want to loop through the values of that key. This allows any key in the config to become a fast-path key.
Example
validation_ds:
    splits: "val"
    ...
    <_VAL_TEST_FASTPATH_KEY>: "splits"  <-- this points to the key name "splits"
Then we can write the following when overriding in Hydra:
python train_file.py ... model.validation_ds.splits=[val1, val2, dev1, dev2] ...
_VAL_TEST_FASTPATH_KEY itself acts as the resolved key
If _VAL_TEST_FASTPATH_KEY does not point to another key in the config, then the items of this key itself are used for resolution.
Example
validation_ds:
    ...
    <_VAL_TEST_FASTPATH_KEY>: "val"  <-- this value itself is used for resolution
Then we can write the following when overriding in Hydra:
python train_file.py ... model.validation_ds.<_VAL_TEST_FASTPATH_KEY>=[val1, val2, dev1, dev2] ...
IMPORTANT NOTE: The resolution can mismatch if more than one valid path exists and the first path does not resolve to the data file (but does resolve to some other valid path). To avoid this side effect, place the data path as the first item in the config file.
- Parameters
cfg – Sub-config of the config file.
- Returns
A str representing the key of the config which hosts the filepath(s), or None if the path could not be resolved.
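The two fast-path cases and the fallback scan can be sketched on a plain dict. The literal key name "ds_item" is a hypothetical stand-in for _VAL_TEST_FASTPATH_KEY, whose value is not shown here. The fallback returns the first key, in insertion order, whose value or any list element exists on disk. That ordering is exactly why the note above recommends placing the data path first.

```python
import os
from typing import Any, Mapping, Optional

FASTPATH_KEY = "ds_item"  # hypothetical stand-in for _VAL_TEST_FASTPATH_KEY


def resolve_dataset_key_sketch(cfg: Mapping[str, Any], fastpath_key: str = FASTPATH_KEY) -> Optional[str]:
    """Return the config key that hosts the data path(s), or None (sketch)."""
    fast = cfg.get(fastpath_key)
    if fast is not None:
        # Case 1: the fast-path key names another key in the same config.
        if isinstance(fast, str) and fast in cfg:
            return fast
        # Case 2: the fast-path key's own value(s) are the items to resolve.
        return fastpath_key
    # Otherwise, pick the first key whose value (or any list element) exists on disk.
    for key, value in cfg.items():
        candidates = value if isinstance(value, (list, tuple)) else [value]
        if any(isinstance(item, str) and os.path.exists(item) for item in candidates):
            return key
    return None
```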
- mridc.utils.model_utils.resolve_subclass_pretrained_model_info(base_class) Union[List[PretrainedModelInfo], Set[Any]] [source]
Recursively traverses the inheritance graph of subclasses to extract all pretrained model info. First constructs a set of unique pretrained model info by performing DFS over the inheritance graph. All model info belonging to the same class is added together.
- Parameters
base_class – The root class, whose subclass graph will be traversed.
- Returns
A list of unique pretrained model infos belonging to all the inherited subclasses of this base class.
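The depth-first traversal can be sketched with an explicit stack over __subclasses__(). The class attribute PRETRAINED_INFOS is a hypothetical stand-in for however each subclass actually reports its pretrained models. The visited set matters because, with multiple inheritance, the same class can be reached from more than one parent.

```python
from typing import Any, List, Set


def collect_subclass_infos_sketch(base_class: type) -> List[Any]:
    """DFS over the subclass graph, merging each subclass's own info entries (sketch).

    Assumes each class may define a hypothetical class attribute PRETRAINED_INFOS
    (an iterable of hashable entries).
    """
    infos: Set[Any] = set()
    visited: Set[type] = set()
    stack = list(base_class.__subclasses__())
    while stack:
        cls = stack.pop()
        if cls in visited:  # diamond inheritance can reach a class twice
            continue
        visited.add(cls)
        infos.update(vars(cls).get("PRETRAINED_INFOS", ()))  # only the class's own entries
        stack.extend(cls.__subclasses__())
    return sorted(infos, key=repr)  # deterministic order for the returned list
```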
- mridc.utils.model_utils.resolve_validation_dataloaders(model: ModelPT)[source]
Helper method that operates on the ModelPT class to automatically support multiple dataloaders for the validation set. It does so by first resolving the path to one or more data files via resolve_dataset_name_from_cfg(). If this resolution fails, it assumes the data loader is prepared to manually support (or not support) multiple data loaders and simply calls the appropriate setup method.
If resolution succeeds:
- Checks whether the provided path points to a single file or a list of files. If a single file is provided, simply tags that file as such and loads it via the setup method.
- If multiple files are provided:
  - Injects a new manifest path at index “i” into the resolved key.
  - Calls the appropriate setup method to set the data loader.
  - Collects the initialized data loader in a list and preserves it.
  - Once all data loaders are processed, assigns the list of loaded loaders to the ModelPT.
  - Finally, assigns a list of unique names resolved from the file paths to the ModelPT.
- Parameters
model – ModelPT subclass, which requires one or more validation dataloaders to be set up.
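The multiple-file branch described above can be sketched as a loop that injects one path at a time into a copy of the config. This is a simplified illustration. setup_fn is a hypothetical stand-in for the model's setup method. Naming each loader by its file stem is also an assumption, since the exact naming rule in parse_dataset_as_name is not spelled out.

```python
import copy
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple


def setup_multiple_loaders_sketch(
    cfg: Dict[str, Any], key: str, setup_fn: Callable[[Dict[str, Any]], Any]
) -> Tuple[List[Any], List[str]]:
    """Build one loader per path stored under cfg[key] (sketch)."""
    paths = cfg[key] if isinstance(cfg[key], (list, tuple)) else [cfg[key]]
    loaders, names = [], []
    for path in paths:
        single_cfg = copy.deepcopy(cfg)
        single_cfg[key] = path  # inject path i into the resolved key
        loaders.append(setup_fn(single_cfg))  # hypothetical setup call
        names.append(Path(str(path)).stem)  # hypothetical naming rule
    return loaders, names
```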
- mridc.utils.model_utils.uninject_model_parallel_rank(filepath)[source]
Removes tensor/pipeline model parallel ranks from the filepath.
- mridc.utils.model_utils.unique_names_check(name_list: Optional[List[str]])[source]
Performs a uniqueness check on the resolved name list so that it can warn users about non-unique keys.
- Parameters
name_list – List of strings resolved for data loaders.
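A uniqueness check like this can be sketched with collections.Counter. Unlike the documented function, this hypothetical variant also returns the duplicated names so that the result can be inspected. The warning goes through the standard logging module rather than the mridc logger.

```python
import logging
from collections import Counter
from typing import List, Optional


def unique_names_check_sketch(name_list: Optional[List[str]]) -> List[str]:
    """Warn about duplicate data loader names and return them (sketch)."""
    if not name_list:
        return []
    duplicates = sorted(name for name, count in Counter(name_list).items() if count > 1)
    if duplicates:
        logging.getLogger(__name__).warning("Non-unique data loader names: %s", duplicates)
    return duplicates
```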
- mridc.utils.model_utils.wrap_training_step(wrapper=None, enabled=None, adapter=None, proxy=<class 'FunctionWrapper'>)[source]
Wraps the training step of the LightningModule.
- Parameters
wrapped – The wrapped function.
instance – The LightningModule instance.
args – The arguments passed to the wrapped function.
kwargs – The keyword arguments passed to the wrapped function.
- Returns
The return value of the wrapped function.
mridc.utils.mridc_logging module
- class mridc.utils.mridc_logging.LogMode(value)[source]
Bases:
IntEnum
Enum for the different logging modes.
- EACH = 0
- ONCE = 1
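The reference only names the two modes. A minimal sketch of one plausible reading follows, assuming that ONCE means a given formatted message at a given level is emitted only the first time it is logged. LogModeSketch and OnceAwareLogger are hypothetical and wrap the standard logging module, not mridc's Logger. The boolean return value exists only to make the behavior observable.

```python
import logging
from enum import IntEnum
from typing import Set, Tuple


class LogModeSketch(IntEnum):
    EACH = 0  # log every call
    ONCE = 1  # log a given message only the first time


class OnceAwareLogger:
    """Hypothetical wrapper showing how a mode flag could deduplicate messages."""

    def __init__(self, name: str) -> None:
        self._logger = logging.getLogger(name)
        self._seen: Set[Tuple[int, str]] = set()

    def log(self, level: int, msg: str, *args, mode: LogModeSketch = LogModeSketch.EACH, **kwargs) -> bool:
        if mode == LogModeSketch.ONCE:
            key = (level, msg % args if args else msg)
            if key in self._seen:
                return False  # already emitted once; skip
            self._seen.add(key)
        self._logger.log(level, msg, *args, **kwargs)
        return True
```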
- class mridc.utils.mridc_logging.Logger(*args, **kwargs)[source]
Bases:
object
Singleton class for logging.
- CRITICAL = 50
- DEBUG = 10
- ERROR = 40
- INFO = 20
- NOTSET = 0
- WARNING = 30
- add_err_file_handler(log_file)[source]
Add a FileHandler to logger that logs all WARNING and higher messages to a file. If the logger had a MemoryHandler at self._handlers["memory_err"], those buffered messages are flushed to the new file, and the MemoryHandler is closed.
- add_file_handler(log_file)[source]
Add a FileHandler to logger that logs all messages to a file. If the logger had a MemoryHandler at self._handlers["memory_all"], those buffered messages are flushed to the new file, and the MemoryHandler is closed.
- add_stream_handlers(formatter=<class 'mridc.utils.formaters.base.BaseMRIDCFormatter'>)[source]
Add StreamHandlers that log to stdout and stderr to the logger. INFO and lower logs are streamed to stdout, while WARNING and higher are streamed to stderr. If the MRIDC_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR environment variable is set, all logs are sent to stderr instead.
- captureWarnings(capture)[source]
If capture is True, redirect all warnings to the logging package. If capture is False, ensure that warnings are not redirected to logging but to their original destinations.
- critical(msg, *args, mode=LogMode.EACH, **kwargs) None [source]
Log 'msg % args' with severity 'CRITICAL'. To pass exception information, use the keyword argument exc_info with a true value, e.g. logger.critical("Houston, we have %s", "major disaster", exc_info=1)
- Parameters
msg – The message to log.
*args – The arguments to the message.
mode – The mode to log the message in.
**kwargs – The keyword arguments to the message.
- debug(msg, *args, mode=LogMode.EACH, **kwargs)[source]
Log 'msg % args' with severity 'DEBUG'. To pass exception information, use the keyword argument exc_info with a true value, e.g. logger.debug("Houston, we have %s", "thorny problem", exc_info=1)
- error(msg, *args, mode=LogMode.EACH, **kwargs)[source]
Log 'msg % args' with severity 'ERROR'. To pass exception information, use the keyword argument exc_info with a true value, e.g. logger.error("Houston, we have %s", "major problem", exc_info=1)
- info(msg, *args, mode=LogMode.EACH, **kwargs)[source]
Log 'msg % args' with severity 'INFO'. To pass exception information, use the keyword argument exc_info with a true value, e.g. logger.info("Houston, we have %s", "interesting problem", exc_info=1)
- patch_stderr_handler(stream)[source]
Sends messages that should log to stderr to stream instead. Useful for unit tests.
- patch_stdout_handler(stream)[source]
Sends messages that should log to stdout to stream instead. Useful for unit tests.
- remove_stream_handlers()[source]
Removes the StreamHandlers that log to stdout and stderr from the logger.
- reset_stream_handler(formatter=<class 'mridc.utils.formaters.base.BaseMRIDCFormatter'>)[source]
Removes then adds stream handlers.
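The stdout/stderr split described for add_stream_handlers can be sketched with a logging.Filter on the stdout handler and a level threshold on the stderr handler. The environment variable name REDIRECT_LOGS_TO_STDERR is a hypothetical placeholder for whatever MRIDC_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR refers to. The formatter string is also an illustrative choice, not BaseMRIDCFormatter.

```python
import logging
import os
import sys


class _BelowWarning(logging.Filter):
    """Pass only records below WARNING (DEBUG and INFO)."""

    def filter(self, record: logging.LogRecord) -> bool:
        return record.levelno < logging.WARNING


def add_split_stream_handlers(logger: logging.Logger, redirect_env: str = "REDIRECT_LOGS_TO_STDERR"):
    """INFO and lower to stdout, WARNING and higher to stderr (sketch).

    redirect_env is a hypothetical variable name; if it is set, everything goes to stderr.
    """
    formatter = logging.Formatter("%(levelname)s: %(message)s")
    if os.environ.get(redirect_env):
        handlers = [logging.StreamHandler(sys.stderr)]
    else:
        out = logging.StreamHandler(sys.stdout)
        out.addFilter(_BelowWarning())
        err = logging.StreamHandler(sys.stderr)
        err.setLevel(logging.WARNING)
        handlers = [out, err]
    for handler in handlers:
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return handlers
```

Swapping a handler's stream for an io.StringIO, which is similar in spirit to patch_stdout_handler and patch_stderr_handler, makes this easy to check in unit tests.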
mridc.utils.timers module
- class mridc.utils.timers.NamedTimer(reduction='mean', sync_cuda=False, buffer_size=-1)[source]
Bases:
object
A timer class that supports multiple named timers. A named timer can be used multiple times, in which case the average dt is returned. A named timer cannot be started if it is already running. Use case: measuring the execution time of multiple code blocks.
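The behavior described above can be sketched with time.perf_counter. NamedTimerSketch and its start, stop, and get methods are hypothetical, because this section does not show NamedTimer's method names. The sketch covers only the default reduction='mean' and ignores sync_cuda and buffer_size.

```python
import time
from collections import defaultdict
from typing import Dict, List


class NamedTimerSketch:
    """Several named timers; get() returns the mean of recorded durations (sketch)."""

    def __init__(self) -> None:
        self._starts: Dict[str, float] = {}
        self._durations: Dict[str, List[float]] = defaultdict(list)

    def start(self, name: str) -> None:
        if name in self._starts:
            raise RuntimeError(f"Timer {name!r} is already running")
        self._starts[name] = time.perf_counter()

    def stop(self, name: str) -> None:
        if name not in self._starts:
            raise RuntimeError(f"Timer {name!r} was not started")
        self._durations[name].append(time.perf_counter() - self._starts.pop(name))

    def get(self, name: str) -> float:
        durations = self._durations.get(name)
        if not durations:
            raise KeyError(f"No completed measurements for timer {name!r}")
        return sum(durations) / len(durations)  # the 'mean' reduction
```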
- property buffer_size
Returns the buffer size of the timer.