# encoding: utf-8
__author__ = "Dimitrios Karkalousos"
# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/connectors/save_restore_connector.py
import os
import shutil
import tarfile
import tempfile
import uuid
from typing import Optional, Union
import torch
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer
import mridc.utils
from mridc.utils import logging
from mridc.utils.app_state import AppState
from mridc.utils.get_rank import is_global_rank_zero
class SaveRestoreConnector:
    """Save a model (configuration, weights and registered artifacts) to a .mridc archive and restore it."""

    def __init__(self) -> None:
        # File names of the config and of the weights inside the .mridc archive.
        self._model_config_yaml = "model_config.yaml"
        self._model_weights_ckpt = "model_weights.ckpt"
        # Optional directory holding an already-extracted archive. When it is an existing directory,
        # load_config_and_state_dict restores from it instead of unpacking `restore_path`.
        self._model_extracted_dir = None

    def save_to(self, model, save_path: str):
        """
        Saves model instance (weights and configuration) into a .mridc file.

        You can use the "restore_from" method to fully restore the instance from the .mridc file.

        The .mridc file is an uncompressed tar archive containing:
            - model_config.yaml - model configuration in .yaml format. You can deserialize this into the cfg argument
              of the model's constructor.
            - model_weights.ckpt - model checkpoint (the model's state dict).
            - any registered artifacts, copied under unique names.

        Only global rank zero writes the archive; every other rank returns immediately.

        Parameters
        ----------
        model: ModelPT object to be saved.
        save_path: Path to .mridc file where model instance should be saved
        """
        if not is_global_rank_zero():
            return

        with tempfile.TemporaryDirectory() as tmpdir:
            config_yaml = os.path.join(tmpdir, self.model_config_yaml)
            model_weights = os.path.join(tmpdir, self.model_weights_ckpt)
            model.to_config_file(path2yaml_file=config_yaml)
            if hasattr(model, "artifacts") and model.artifacts is not None:
                self._handle_artifacts(model, mridc_file_folder=tmpdir)
                # We should not update self._cfg here - the model can still be in use
                self._update_artifact_paths(model, path2yaml_file=config_yaml)
            self._save_state_dict_to_disk(model.state_dict(), model_weights)
            self._make_mridc_file_from_folder(filename=save_path, source_dir=tmpdir)

    def load_config_and_state_dict(
        self,
        calling_cls,
        restore_path: str,
        override_config_path: Optional[Union["OmegaConf", str]] = None,
        map_location: Optional[torch.device] = None,
        strict: bool = True,
        return_config: bool = False,
        trainer: Optional["Trainer"] = None,
    ):
        """
        Restores a model instance (weights and configuration) from a .mridc file.

        Parameters
        ----------
        calling_cls: Class of the model to be restored.
        restore_path: path to .mridc file from which model should be instantiated
        override_config_path: path to a yaml config that will override the internal config file or an
            OmegaConf/DictConfig object representing the model config.
        map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will
            select a GPU if available, falling back to CPU otherwise.
        strict: Passed to load_state_dict. By default, True.
        return_config: If set to true, will return just the underlying config of the restored model as an OmegaConf
            DictConfig object without instantiating the model.
        trainer: Optional trainer object passed to ``calling_cls.from_config_dict``.

        Example
        -------
        ```
        model = MyModelClass.restore_from('model.mridc')
        assert isinstance(model, MyModelClass)
        ```

        Returns
        -------
        An instance of type cls (with its weights already loaded) or its underlying config (if return_config is set).
        """
        # Remember where the command is executed so the working directory can be restored afterwards.
        cwd = os.getcwd()

        if map_location is None:
            map_location = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        app_state = AppState()
        with tempfile.TemporaryDirectory() as tmpdir:
            try:
                if self.model_extracted_dir is not None and os.path.isdir(self.model_extracted_dir):
                    logging.info(
                        "Restoration will occur within pre-extracted directory : " f"`{self.model_extracted_dir}`."
                    )
                    # Use the pre-extracted directory instead of the temporary one. It is made absolute so that the
                    # paths joined below stay valid after the os.chdir() into it.
                    tmpdir = os.path.abspath(self.model_extracted_dir)
                else:
                    self._unpack_mridc_file(path2file=restore_path, out_folder=tmpdir)

                # Step into the extracted archive while the config is loaded (original .mridc behavior).
                os.chdir(tmpdir)

                if override_config_path is None:
                    config_yaml = os.path.join(tmpdir, self.model_config_yaml)
                else:
                    # can be str path or OmegaConf / DictConfig object
                    config_yaml = override_config_path

                if not isinstance(config_yaml, (OmegaConf, DictConfig)):
                    conf = OmegaConf.load(config_yaml)
                else:
                    conf = config_yaml
                    if override_config_path is not None:
                        # Resolve the override config's interpolations and re-wrap it as a fresh config object.
                        conf = OmegaConf.to_container(conf, resolve=True)
                        conf = OmegaConf.create(conf)

                # If override is top level config, extract just `model` from it
                if "model" in conf:
                    conf = conf.model

                if return_config:
                    return conf

                # With model_parallel_size > 1 the weights path is rewritten by inject_model_parallel_rank
                # (presumably to a per-rank checkpoint). Computed once, with the condition that always decided it.
                if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
                    model_weights = self._inject_model_parallel_rank_for_ckpt(tmpdir, self.model_weights_ckpt)
                else:
                    model_weights = os.path.join(tmpdir, self.model_weights_ckpt)

                OmegaConf.set_struct(conf, True)
                os.chdir(cwd)

                calling_cls._set_model_restore_state(is_being_restored=True, folder=tmpdir)  # type: ignore
                instance = calling_cls.from_config_dict(config=conf, trainer=trainer)
                instance = instance.to(map_location)
                instance.load_state_dict(
                    self._load_state_dict_from_disk(model_weights, map_location=map_location), strict=strict
                )
                logging.info(f"Model {instance.__class__.__name__} was successfully restored from {restore_path}.")
                instance._set_model_restore_state(is_being_restored=False)  # type: ignore
            finally:
                # Always go back to the caller's directory (before the temporary directory is removed).
                os.chdir(cwd)
        return instance

    @staticmethod
    def load_instance_with_state_dict(instance, state_dict, strict):
        """Load ``state_dict`` into ``instance`` and mark the restoration as finished."""
        instance.load_state_dict(state_dict, strict=strict)
        instance._set_model_restore_state(is_being_restored=False)  # type: ignore

    def restore_from(
        self,
        calling_cls,
        restore_path: str,
        override_config_path: Optional[Union["OmegaConf", str]] = None,
        map_location: Optional[torch.device] = None,
        strict: bool = True,
        return_config: bool = False,
        trainer: Optional["Trainer"] = None,
    ):
        """
        Restores a model instance (weights and configuration) from a .mridc file.

        Parameters
        ----------
        calling_cls: The class of the model to be restored.
        restore_path: path to .mridc file from which model should be instantiated
        override_config_path: path to a yaml config that will override the internal config file or an
            OmegaConf/DictConfig object representing the model config.
        map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will
            select a GPU if available, falling back to CPU otherwise.
        strict: Passed to load_state_dict. By default, True.
        return_config: If set to true, will return just the underlying config of the restored model as an
            OmegaConf/DictConfig object without instantiating the model.
        trainer: Optional trainer object to be used for restoring the model.

        Returns
        -------
        An instance of type cls or its underlying config (if return_config is set).
        """
        loaded_params = self.load_config_and_state_dict(
            calling_cls,
            restore_path,
            override_config_path,
            map_location,
            strict,
            return_config,
            trainer,
        )
        # load_config_and_state_dict above returns either the config or an already-loaded instance. A
        # (config, instance, state_dict) tuple is still handled, presumably for overrides that return one.
        if not isinstance(loaded_params, tuple):
            return loaded_params
        _, instance, state_dict = loaded_params
        self.load_instance_with_state_dict(instance, state_dict, strict)
        logging.info(f"Model {instance.__class__.__name__} was successfully restored from {restore_path}.")
        return instance

    @staticmethod
    def register_artifact(model, config_path: str, src: str, verify_src_exists: bool = True):
        """
        Register model artifacts with this function. These artifacts (files) will be included inside .mridc file
        when model.save_to("mymodel.mridc") is called.

        How it works:
        1. It always returns existing absolute path which can be used during Model constructor call. EXCEPTION: src is
           None or "" in which case nothing will be done and src will be returned
        2. It will add (config_path, model_utils.ArtifactItem()) pair to model.artifacts. If "src" is local existing
           path, then it will be returned in absolute path form. elif "src" starts with "mridc:unique_artifact_name":
           the artifact is looked up in the folder the .mridc file was extracted to and an actual existing path will
           be returned else an error will be raised.

        WARNING: use .register_artifact calls in your models' constructors.
        The returned path is not guaranteed to exist after you have exited your model's constructor.

        Parameters
        ----------
        model: ModelPT object to register artifact for.
        config_path: Artifact key. Usually corresponds to the model config.
        src: Path to artifact.
        verify_src_exists: If set to False, then the artifact is optional and register_artifact will return None
            even if src is not found. Defaults to True.

        Returns
        --------
        If src is not None or empty it always returns absolute path which is guaranteed to exist during model instance
        life.
        """
        # Documented no-op: nothing to register for a missing source.
        if src is None or src == "":
            return src

        app_state = AppState()
        artifact_item = mridc.utils.model_utils.ArtifactItem()  # type: ignore

        # This is for backward compatibility, if the src objects exists simply inside the tarfile
        # without its key having been overridden, this pathway will be used.
        src_obj_name = os.path.basename(src)
        if app_state.mridc_file_folder is not None:
            src_obj_path = os.path.abspath(os.path.join(app_state.mridc_file_folder, src_obj_name))
        else:
            src_obj_path = src_obj_name

        if os.path.exists(os.path.abspath(src)):
            # src is a local existing path - register artifact and return exact same path for usage by the model
            return_path = os.path.abspath(src)
            artifact_item.path_type = mridc.utils.model_utils.ArtifactPathType.LOCAL_PATH  # type: ignore
        elif src.startswith("mridc:"):
            # Strip the whole "mridc:" prefix to get the artifact's name inside the extracted archive.
            return_path = os.path.abspath(os.path.join(app_state.mridc_file_folder, src[len("mridc:") :]))
            artifact_item.path_type = mridc.utils.model_utils.ArtifactPathType.TAR_PATH  # type: ignore
        elif os.path.exists(src_obj_path):
            return_path = src_obj_path
            artifact_item.path_type = mridc.utils.model_utils.ArtifactPathType.TAR_PATH  # type: ignore
        elif verify_src_exists:
            raise FileNotFoundError(
                f"src path does not exist or it is not a path in mridc file. src value I got was: {src}. "
                f"Absolute: {os.path.abspath(src)}"
            )
        else:
            # artifact is optional and we simply return None
            return None

        if not os.path.exists(return_path):
            raise AssertionError(f"Resolved artifact path does not exist: {return_path}")

        artifact_item.path = os.path.abspath(src)
        model.artifacts[config_path] = artifact_item
        # we were called by ModelPT
        if hasattr(model, "cfg"):
            with open_dict(model._cfg):
                OmegaConf.update(model.cfg, config_path, return_path)
        return return_path

    def _handle_artifacts(self, model, mridc_file_folder):
        """
        Copy every registered artifact of ``model`` into ``mridc_file_folder`` (the folder about to be archived) and
        update ``model.artifacts`` accordingly. Called by ``save_to``.

        - LOCAL_PATH artifacts are copied under a unique "<uuid4 hex>_<basename>" name and get a "mridc:" hashed_path.
        - TAR_PATH artifacts are extracted from the archive the model was restored from (if any) and re-registered
          as "mridc:<name>".

        Parameters
        ----------
        model: ModelPT object whose artifacts are handled.
        mridc_file_folder: Folder whose content will be packed into the .mridc file.

        Raises
        ------
        FileNotFoundError: if a LOCAL_PATH artifact no longer exists.
        ValueError: for any other artifact path type.
        """
        tarfile_artifacts = []
        app_state = AppState()
        for conf_path, artiitem in model.artifacts.items():
            if artiitem.path_type == mridc.utils.model_utils.ArtifactPathType.LOCAL_PATH:
                if not os.path.exists(artiitem.path):
                    raise FileNotFoundError(f"Artifact {conf_path} not found at location: {artiitem.path}")

                # Generate new uniq artifact name and copy it to mridc_file_folder
                # Note uuid.uuid4().hex is guaranteed to be 32 character long
                artifact_base_name = os.path.basename(artiitem.path)
                artifact_uniq_name = f"{uuid.uuid4().hex}_{artifact_base_name}"
                shutil.copy2(artiitem.path, os.path.join(mridc_file_folder, artifact_uniq_name))

                # Update artifacts registry
                artiitem.hashed_path = f"mridc:{artifact_uniq_name}"
                model.artifacts[conf_path] = artiitem
            elif artiitem.path_type == mridc.utils.model_utils.ArtifactPathType.TAR_PATH:
                # process all tarfile artifacts in one go, so preserve key-value pair
                tarfile_artifacts.append((conf_path, artiitem))
            else:
                raise ValueError("Directly referencing artifacts from other mridc files isn't supported yet")

        # Process current tarfile artifacts by unpacking the previous tarfile and extract the artifacts
        # that are currently required.
        model_metadata = app_state.get_model_metadata_from_guid(model.model_guid)
        if tarfile_artifacts and model_metadata.restoration_path is not None:
            with tempfile.TemporaryDirectory() as archive_dir:
                self._unpack_mridc_file(path2file=model_metadata.restoration_path, out_folder=archive_dir)
                for conf_path, artiitem in tarfile_artifacts:
                    # Name of the artifact inside the archive.
                    if "mridc:" in artiitem.path:
                        artifact_base_name = artiitem.path.split("mridc:")[1]
                    else:
                        artifact_base_name = os.path.basename(artiitem.path)
                    # no need to hash here as we are in tarfile_artifacts which are already hashed
                    artifact_uniq_name = artifact_base_name
                    # Resolve against the extracted archive directly instead of changing the process cwd.
                    shutil.copy2(
                        os.path.join(archive_dir, artifact_base_name),
                        os.path.join(mridc_file_folder, artifact_uniq_name),
                    )

                    # Update artifacts registry
                    new_artiitem = mridc.utils.model_utils.ArtifactItem()
                    new_artiitem.path = f"mridc:{artifact_uniq_name}"
                    new_artiitem.path_type = mridc.utils.model_utils.ArtifactPathType.TAR_PATH
                    model.artifacts[conf_path] = new_artiitem

    @staticmethod
    def _update_artifact_paths(model, path2yaml_file):
        """
        Rewrite the YAML config at ``path2yaml_file`` so each artifact key points to its ``hashed_path`` (when set)
        or its ``path`` from ``model.artifacts``. Called by ``save_to``; no-op when there are no artifacts.
        """
        if model.artifacts is not None and len(model.artifacts) > 0:
            conf = OmegaConf.load(path2yaml_file)
            for conf_path, item in model.artifacts.items():
                if item.hashed_path is None:
                    OmegaConf.update(conf, conf_path, item.path)
                else:
                    OmegaConf.update(conf, conf_path, item.hashed_path)
            with open(path2yaml_file, "w", encoding="utf-8") as fout:
                OmegaConf.save(config=conf, f=fout, resolve=True)

    @staticmethod
    def _inject_model_parallel_rank_for_ckpt(dirname, basename):
        """
        Join ``dirname``/``basename`` and pass the result through ``model_utils.inject_model_parallel_rank``
        (presumably adding the current model-parallel rank to the checkpoint path).
        """
        model_weights = os.path.join(dirname, basename)
        model_weights = mridc.utils.model_utils.inject_model_parallel_rank(model_weights)
        return model_weights

    @staticmethod
    def _make_mridc_file_from_folder(filename, source_dir):
        """
        Pack the content of ``source_dir`` into an uncompressed tar archive at ``filename``, creating the parent
        directory of ``filename`` if needed. Called by ``save_to``.
        """
        dirname = os.path.dirname(filename)
        # A bare file name (e.g. "model.mridc") has no directory part; os.makedirs("") would raise.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        with tarfile.open(filename, "w") as tar:
            tar.add(source_dir, arcname=".")

    @staticmethod
    def _unpack_mridc_file(path2file: str, out_folder: str) -> str:
        """
        Extract the .mridc archive ``path2file`` into ``out_folder`` and return ``out_folder``.

        Uncompressed tar is tried first (the format written by ``_make_mridc_file_from_folder``), falling back to
        gzip-compressed tar for older archives.

        Raises
        ------
        FileNotFoundError: if ``path2file`` does not exist.
        """
        if not os.path.exists(path2file):
            raise FileNotFoundError(f"{path2file} does not exist")

        tar_header = "r:"
        try:
            with tarfile.open(path2file, tar_header):
                pass
        except tarfile.ReadError:
            # can be older checkpoint => try compressed tar
            tar_header = "r:gz"

        with tarfile.open(path2file, tar_header) as tar:
            if hasattr(tarfile, "data_filter"):
                # Archives may come from untrusted sources: reject members escaping out_folder (absolute paths,
                # "..", unsafe links) when the interpreter supports extraction filters.
                tar.extractall(path=out_folder, filter="data")
            else:
                tar.extractall(path=out_folder)
        return out_folder

    @staticmethod
    def _save_state_dict_to_disk(state_dict, filepath):
        """Serialize ``state_dict`` to ``filepath`` with ``torch.save``. Called by ``save_to``."""
        torch.save(state_dict, filepath)

    @staticmethod
    def _load_state_dict_from_disk(model_weights, map_location=None):
        """Load a state dict saved with ``torch.save`` from ``model_weights``, mapping tensors to ``map_location``."""
        # NOTE(review): torch.load unpickles the file; only restore checkpoints from trusted sources.
        return torch.load(model_weights, map_location=map_location)

    @property
    def model_config_yaml(self) -> str:
        """File name of the model config YAML inside the .mridc archive."""
        return self._model_config_yaml

    @model_config_yaml.setter
    def model_config_yaml(self, path: str):
        """Set the file name of the model config YAML inside the .mridc archive."""
        self._model_config_yaml = path

    @property
    def model_weights_ckpt(self) -> str:
        """File name of the model weights checkpoint inside the .mridc archive."""
        return self._model_weights_ckpt

    @model_weights_ckpt.setter
    def model_weights_ckpt(self, path: str):
        """Set the file name of the model weights checkpoint inside the .mridc archive."""
        self._model_weights_ckpt = path

    @property
    def model_extracted_dir(self) -> Optional[str]:
        """Directory holding a pre-extracted .mridc archive to restore from, or None to unpack the archive."""
        return self._model_extracted_dir

    @model_extracted_dir.setter
    def model_extracted_dir(self, path: Optional[str]):
        """Set (or clear with None) the pre-extracted archive directory used during restoration."""
        self._model_extracted_dir = path