mridc.collections.common.metrics package

Submodules

mridc.collections.common.metrics.global_average_loss_metric module

class mridc.collections.common.metrics.global_average_loss_metric.GlobalAverageLossMetric(compute_on_step=True, dist_sync_on_step=False, process_group=None, take_avg_loss=True)[source]

Bases: Metric

This class averages loss across multiple processes when a distributed backend is used. The true average is computed, not a running average. It does not accumulate gradients, so the averaged loss cannot be used for optimization.

Note

If take_avg_loss is True, the loss argument of the update() method has to be a mean loss. If take_avg_loss is False, the loss argument of the update() method has to be a sum of losses. See PyTorch Lightning Metrics for metric usage instructions.

Parameters
  • compute_on_step (If set to False, the forward() method only calls update() and returns None. Default: True) –

  • dist_sync_on_step (Synchronize metric state across processes at each forward() call before returning the value at the step. Default: False) –

  • process_group (Specify the process group on which synchronization is called. Default: None, which selects the entire world) –

  • take_avg_loss (If True, the loss argument of the update() method has to be a mean loss; if False, it has to be a sum of losses. Default: True) –

compute()[source]

Returns the mean loss over all accumulated measurements.

update(loss, num_measurements)[source]

Updates loss_sum and num_measurements.

Parameters
  • loss (A float zero-dimensional torch.Tensor which is either the sum or the average of losses for the processed examples. See the take_avg_loss parameter of __init__().) –

  • num_measurements (An integer zero-dimensional torch.Tensor containing the number of loss measurements. The sum or mean of the results of these measurements is in the loss parameter.) –
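To make the update()/compute() semantics concrete, the following is a minimal, single-process sketch of the accumulation logic in plain Python. The class name GlobalAverageLossSketch is hypothetical; the real GlobalAverageLossMetric keeps its state in torch.Tensor buffers and synchronizes them across processes, which this sketch omits.

```python
class GlobalAverageLossSketch:
    """Hypothetical single-process sketch of GlobalAverageLossMetric's
    accumulation logic. Illustration only: no torch state, no
    distributed synchronization."""

    def __init__(self, take_avg_loss=True):
        self.take_avg_loss = take_avg_loss
        self.loss_sum = 0.0
        self.num_measurements = 0

    def update(self, loss, num_measurements):
        # With take_avg_loss=True, `loss` is a mean over
        # `num_measurements` examples, so rescale it to a sum
        # before accumulating.
        if self.take_avg_loss:
            loss = loss * num_measurements
        self.loss_sum += loss
        self.num_measurements += num_measurements

    def compute(self):
        # True average over all measurements seen so far,
        # not a running average of per-batch means.
        return self.loss_sum / self.num_measurements


metric = GlobalAverageLossSketch(take_avg_loss=True)
metric.update(loss=0.5, num_measurements=4)   # batch mean 0.5 over 4 examples
metric.update(loss=0.25, num_measurements=8)  # batch mean 0.25 over 8 examples
print(metric.compute())  # (0.5*4 + 0.25*8) / 12 ≈ 0.3333
```

Note that naively averaging the two batch means would give (0.5 + 0.25) / 2 = 0.375; weighting by the number of measurements gives the true per-example average of 1/3.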

Module contents