Source code for mridc.utils.timers

# encoding: utf-8
__author__ = "Dimitrios Karkalousos"

# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/timers.py

import time

import numpy as np
import torch

__all__ = ["NamedTimer"]


class NamedTimer:
    """
    A timer class that supports multiple named timers.

    A named timer can be used multiple times, in which case the average dt will be returned.
    A named timer cannot be started if it is already running.

    Use case: measuring execution of multiple code blocks.
    """

    _REDUCTION_TYPE = ["mean", "sum", "min", "max", "none"]

    def __init__(self, reduction="mean", sync_cuda=False, buffer_size=-1):
        """
        Parameters
        ----------
        reduction: Reduction over multiple timings of the same timer ("none" returns the list instead of a scalar).
        sync_cuda: If True, torch.cuda.synchronize() is called on start/stop.
        buffer_size: If positive, limits the number of stored measurements per name.
        """
        if reduction not in self._REDUCTION_TYPE:
            raise ValueError(f"Unknown reduction={reduction}, please use one of {self._REDUCTION_TYPE}")

        self._reduction = reduction
        self._sync_cuda = sync_cuda
        self._buffer_size = buffer_size

        self.reset()

    def __getitem__(self, k):
        return self.get(k)

    @property
    def buffer_size(self):
        """Returns the buffer size of the timer."""
        return self._buffer_size

    @property
    def _reduction_fn(self):
        """Returns the reduction function for the timer."""
        if self._reduction == "none":

            def fn(x):
                return x

        else:
            fn = getattr(np, self._reduction)

        return fn
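
    # Example of the reduction in practice (values are illustrative): with
    # reduction="mean" and stored dt values [1.0, 2.0, 3.0] for a name, get()
    # returns np.mean([1.0, 2.0, 3.0]) == 2.0; with reduction="none", get()
    # returns the raw list [1.0, 2.0, 3.0] unchanged.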

    def reset(self, name=None):
        """
        Resets all timers or a specific named timer.

        Parameters
        ----------
        name: Timer name to reset (if None, all timers are reset).
        """
        if name is None:
            self.timers = {}
        else:
            self.timers[name] = {}

    def start(self, name=""):
        """
        Starts measuring a named timer.

        Parameters
        ----------
        name: Timer name to start.
        """
        timer_data = self.timers.get(name, {})

        if "start" in timer_data:
            raise RuntimeError(f"Cannot start timer = '{name}' since it is already active")

        # synchronize pytorch cuda execution if supported
        if self._sync_cuda and torch.cuda.is_initialized():
            torch.cuda.synchronize()

        timer_data["start"] = time.time()

        self.timers[name] = timer_data

    def stop(self, name=""):
        """
        Stops measuring a named timer.

        Parameters
        ----------
        name: Timer name to stop.
        """
        timer_data = self.timers.get(name)
        if (timer_data is None) or ("start" not in timer_data):
            raise RuntimeError(f"Cannot end timer = '{name}' since it is not active")

        # synchronize pytorch cuda execution if supported
        if self._sync_cuda and torch.cuda.is_initialized():
            torch.cuda.synchronize()

        # compute dt and make the timer inactive
        dt = time.time() - timer_data.pop("start")

        # store dt
        timer_data["dt"] = timer_data.get("dt", []) + [dt]

        # enforce buffer_size if positive
        if self._buffer_size > 0:
            timer_data["dt"] = timer_data["dt"][-self._buffer_size :]

        self.timers[name] = timer_data

    def active_timers(self):
        """Returns a list of all active named timers."""
        return [k for k, v in self.timers.items() if "start" in v]

    def get(self, name=""):
        """
        Returns the value of a named timer.

        Parameters
        ----------
        name: Timer name to return.
        """
        dt_list = self.timers[name].get("dt", [])

        return self._reduction_fn(dt_list)

    def export(self):
        """Exports a dictionary with the average/all dt per named timer."""
        fn = self._reduction_fn

        return {k: fn(v["dt"]) for k, v in self.timers.items() if "dt" in v}
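

For reference, a minimal usage sketch of the class above (not part of the module source; the timer name "batch" and the time.sleep() stand-in are illustrative):

import time

from mridc.utils.timers import NamedTimer

timer = NamedTimer(reduction="mean")  # average repeated measurements per name

for _ in range(3):
    timer.start("batch")          # begin one timed block
    time.sleep(0.01)              # stand-in for the real work being measured
    timer.stop("batch")           # append the elapsed dt to the "batch" buffer

print(timer.get("batch"))         # mean dt over the three runs
print(timer["batch"])             # __getitem__ is an alias for get()
print(timer.active_timers())      # [] - nothing is running after stop()
print(timer.export())             # {"batch": <mean dt>}

timer.reset("batch")              # clear one timer; reset() clears all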