Source code for mridc.collections.common.parts.training_utils

# encoding: utf-8
__author__ = "Dimitrios Karkalousos"

# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/parts/training_utils.py

from contextlib import nullcontext

import torch

__all__ = ["avoid_bfloat16_autocast_context", "avoid_float16_autocast_context"]


[docs]def avoid_bfloat16_autocast_context(): """If the current autocast context is bfloat16, cast it to float32.""" if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16: return torch.cuda.amp.autocast(dtype=torch.float32) else: return nullcontext()
[docs]def avoid_float16_autocast_context(): """If the current autocast context is float16, cast it to bfloat16 if available or float32.""" if not torch.is_autocast_enabled() or torch.get_autocast_gpu_dtype() != torch.float16: return nullcontext() if torch.cuda.is_bf16_supported(): return torch.cuda.amp.autocast(dtype=torch.bfloat16) else: return torch.cuda.amp.autocast(dtype=torch.float32)