Source code for mridc.collections.common.metrics.global_average_loss_metric

# encoding: utf-8
__author__ = "Dimitrios Karkalousos"

# Taken and adapted from:
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/global_average_loss_metric.py

import torch
from torchmetrics import Metric

__all__ = ["GlobalAverageLossMetric"]


class GlobalAverageLossMetric(Metric):
    """
    This class is for averaging loss across multiple processes if a distributed backend is used. The true average
    is computed, not a running average. It does not accumulate gradients, so the averaged loss cannot be used for
    optimization.

    .. note::
        If ``take_avg_loss`` is ``True``, the ``loss`` argument of the :meth:`update` method has to be a mean loss.
        If ``take_avg_loss`` is ``False``, the ``loss`` argument has to be a sum of losses. See PyTorch Lightning
        Metrics for metric usage instructions.

    Parameters
    ----------
    compute_on_step: If set to ``False``, the :meth:`forward` method only calls ``update()`` and returns ``None``.
        Default: ``True``.
    dist_sync_on_step: Synchronize metric state across processes at each :meth:`forward` call before returning the
        value at the step. Default: ``False``.
    process_group: Specify the process group on which synchronization is called. Default: ``None`` (which selects
        the entire world).
    take_avg_loss: If ``True``, the ``loss`` argument of the :meth:`update` method has to be a mean loss. If
        ``False``, it has to be a sum of losses. Default: ``True``.
    """

    def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, take_avg_loss=True):
        super().__init__(
            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group
        )
        self.add_state("loss_sum", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum")
        self.add_state("num_measurements", torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum")
        self.take_avg_loss = take_avg_loss
    def update(self, loss, num_measurements):
        """
        Updates :attr:`loss_sum` and :attr:`num_measurements`.

        Parameters
        ----------
        loss: A zero-dimensional float ``torch.Tensor`` which is either a sum or an average of losses for processed
            examples. See the ``take_avg_loss`` parameter of :meth:`__init__`.
        num_measurements: A zero-dimensional integer ``torch.Tensor`` containing the number of loss measurements.
            The sum or mean of the results of these measurements is in the ``loss`` parameter.
        """
        if self.take_avg_loss:
            self.loss_sum += loss.detach() * num_measurements
        else:
            self.loss_sum += loss.detach()
        self.num_measurements += num_measurements
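    # Note on the two modes (illustrative comment, not in the original source):
    # with take_avg_loss=True the incoming loss is a per-example mean, so it is
    # scaled back to a sum via loss * num_measurements before accumulation; with
    # take_avg_loss=False the incoming loss is already a sum and is added
    # directly. Either way, loss_sum holds a global sum and compute() divides by
    # the global count, yielding the true average rather than an average of
    # batch averages.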
    def compute(self):
        """Returns the mean loss."""
        if self.num_measurements.eq(0):
            return torch.tensor(float("nan"))
        return self.loss_sum / self.num_measurements
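
# Minimal usage sketch (illustrative; not part of the original module). It
# assumes a single process, so no distributed synchronization is exercised,
# and a torchmetrics version compatible with the constructor arguments above.
if __name__ == "__main__":
    metric = GlobalAverageLossMetric(take_avg_loss=True)

    # Mean losses over two batches of different sizes.
    metric.update(loss=torch.tensor(0.5), num_measurements=torch.tensor(4))
    metric.update(loss=torch.tensor(1.0), num_measurements=torch.tensor(2))

    # True average over all 6 examples: (0.5 * 4 + 1.0 * 2) / 6 = 0.6667.
    print(metric.compute())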