mridc.collections.common.metrics package
Submodules
mridc.collections.common.metrics.global_average_loss_metric module
- class mridc.collections.common.metrics.global_average_loss_metric.GlobalAverageLossMetric(compute_on_step=True, dist_sync_on_step=False, process_group=None, take_avg_loss=True)[source]
Bases: Metric
This class is for averaging loss across multiple processes if a distributed backend is used. The true average is computed, not a running average. It does not accumulate gradients, so the averaged loss cannot be used for optimization.
Note
If take_avg_loss is True, the loss argument of the update() method has to be a mean loss. If take_avg_loss is False, the loss argument of the update() method has to be a sum of losses. See PyTorch Lightning Metrics for the metric usage instructions. A short construction sketch follows the parameter list below.
- Parameters
compute_on_step – The forward() method only calls update() and returns None if this is set to False. Default: True.
dist_sync_on_step – Synchronize metric state across processes at each forward() call before returning the value at the step. Default: False.
process_group – Specify the process group on which synchronization is called. Default: None (which selects the entire world).
take_avg_loss – If True, the loss argument of the update() method has to be a mean loss. If False, it has to be a sum of losses. Default: True.
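The following is a minimal construction sketch, not part of the original documentation. The import path follows the fully qualified class name above, and only the documented keyword arguments are used; the variable names are illustrative.

from mridc.collections.common.metrics.global_average_loss_metric import GlobalAverageLossMetric

# take_avg_loss=True: update() expects the mean loss of each batch.
avg_loss_metric = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True)

# take_avg_loss=False: update() expects the summed loss of each batch instead.
sum_loss_metric = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=False)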
- update(loss, num_measurements)[source]
Updates loss_sum and num_measurements. A usage sketch follows the parameter list below.
- Parameters
loss – A zero-dimensional float torch.Tensor which is either the sum or the average of the losses for the processed examples. See the take_avg_loss parameter of __init__().
num_measurements – A zero-dimensional integer torch.Tensor which contains the number of loss measurements. The sum or mean of the results of these measurements is in the loss parameter.
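As an illustration of the update() contract, here is a hedged usage sketch. compute() is inherited from Metric; its return value (the true average over all recorded measurements) is an assumption consistent with the class description, not something stated explicitly in this section.

import torch
from mridc.collections.common.metrics.global_average_loss_metric import GlobalAverageLossMetric

metric = GlobalAverageLossMetric(take_avg_loss=True)

# Two batches of different sizes; loss is the per-batch mean because take_avg_loss=True.
metric.update(loss=torch.tensor(0.5), num_measurements=torch.tensor(8))
metric.update(loss=torch.tensor(0.3), num_measurements=torch.tensor(2))

# Assumed behaviour: the true average over all 10 measurements,
# (0.5 * 8 + 0.3 * 2) / 10 = 0.46, not the mean of the two batch means.
averaged = metric.compute()
print(averaged)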