mridc.collections.common.metrics package
Submodules
mridc.collections.common.metrics.global_average_loss_metric module
- class mridc.collections.common.metrics.global_average_loss_metric.GlobalAverageLossMetric(compute_on_step=True, dist_sync_on_step=False, process_group=None, take_avg_loss=True)
Bases: Metric

This class averages the loss across multiple processes when a distributed backend is used. It computes the true average, not a running average. It does not accumulate gradients, so the averaged loss cannot be used for optimization.
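The difference between a true global average and a naive average of per-process results can be sketched in plain Python. This is not the mridc implementation: it assumes, for illustration only, that each process keeps a loss total and a measurement count, and that these two states are summed across processes before a single division. The helper names `mean_of_means` and `true_global_average` are hypothetical.

```python
# Hypothetical illustration, not the mridc code.
# Assumption: each process stores a (loss_sum, count) pair, and the metric
# combines processes by summing both values before dividing once.


def mean_of_means(per_process):
    """Naive approach: give every process's mean loss equal weight."""
    means = [loss_sum / count for loss_sum, count in per_process]
    return sum(means) / len(means)


def true_global_average(per_process):
    """Sum loss totals and counts over all processes, then divide once."""
    total_loss = sum(loss_sum for loss_sum, _ in per_process)
    total_count = sum(count for _, count in per_process)
    return total_loss / total_count


# Two processes that saw different numbers of examples (means 2.0 and 1.0).
per_process = [(6.0, 3), (1.0, 1)]
print(mean_of_means(per_process))        # 1.5
print(true_global_average(per_process))  # 1.75
```

The naive version over-weights the process that saw fewer examples; summing totals and counts first weights every measurement equally.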
Note
If take_avg_loss is True, the loss argument of the update() method must be a mean loss. If take_avg_loss is False, the loss argument of update() must be a sum of losses. See PyTorch Lightning Metrics for instructions on using metrics.
- Parameters
compute_on_step – If set to False, forward() only calls update() and returns None. Default: True.
dist_sync_on_step – Synchronize the metric state across processes at each forward() call before returning the value at the step. Default: False.
process_group – The process group on which synchronization is called. Default: None (which selects the entire world).
take_avg_loss – If True, the loss argument of update() must be a mean loss. If False, it must be a sum of losses. Default: True.
- update(loss, num_measurements)
Updates loss_sum and num_measurements.
- Parameters
loss – A float zero-dimensional torch.Tensor holding either the sum or the average of the losses for the processed examples. See the take_avg_loss parameter of __init__().
num_measurements – An integer zero-dimensional torch.Tensor holding the number of loss measurements. The loss parameter contains the sum or mean of the results of these measurements.
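The contract of update() can be sketched with a small stand-in class. `AverageLossSketch` is hypothetical, not the mridc class: it uses plain floats and ints instead of zero-dimensional tensors and performs no distributed synchronization. It assumes that a mean loss is scaled by num_measurements before being added to loss_sum, which is what makes the mean and sum conventions produce the same true average.

```python
class AverageLossSketch:
    """Hypothetical stand-in that mirrors the documented update() contract.

    Plain floats/ints replace zero-dimensional tensors, and there is no
    cross-process synchronization.
    """

    def __init__(self, take_avg_loss=True):
        self.take_avg_loss = take_avg_loss
        self.loss_sum = 0.0
        self.num_measurements = 0

    def update(self, loss, num_measurements):
        if self.take_avg_loss:
            # `loss` is a mean, so scale it back to a sum before accumulating.
            self.loss_sum += loss * num_measurements
        else:
            # `loss` is already a sum of per-example losses.
            self.loss_sum += loss
        self.num_measurements += num_measurements

    def compute(self):
        # Assumption: report NaN when nothing has been recorded yet.
        if self.num_measurements == 0:
            return float("nan")
        return self.loss_sum / self.num_measurements


# Two batches: 4 examples with mean loss 0.5, then 2 examples with mean loss 2.0.
mean_metric = AverageLossSketch(take_avg_loss=True)
mean_metric.update(0.5, 4)
mean_metric.update(2.0, 2)

# The same batches reported as sums of losses instead of means.
sum_metric = AverageLossSketch(take_avg_loss=False)
sum_metric.update(2.0, 4)  # 0.5 * 4
sum_metric.update(4.0, 2)  # 2.0 * 2

print(mean_metric.compute(), sum_metric.compute())  # 1.0 1.0
```

Both conventions yield (2.0 + 4.0) / 6 = 1.0, whereas averaging the two batch means (0.5 and 2.0) would give 1.25.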