Source code for mridc.core.neural_types.axes

# encoding: utf-8
__author__ = "Dimitrios Karkalousos"

# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/neural_types/axes.py

from enum import Enum
from typing import Optional

__all__ = ["AxisKindAbstract", "AxisKind", "AxisType"]


[docs]class AxisKindAbstract(Enum): """ This is an abstract Enum to represents what does varying axis dimension mean. In practice, you will almost always use AxisKind Enum. This Enum should be inherited by your OWN Enum if you aren't satisfied with AxisKind. Then your own Enum can be used instead of AxisKind. """
[docs]class AxisKind(AxisKindAbstract): """ This Enum represents what does varying axis dimension mean. For example, does this dimension correspond to width, \ batch, time, etc. The "Dimension" and "Channel" kinds are the same and used to represent a general axis. "Any" \ axis will accept any axis kind fed to it. """ # TODO (wdika): change names of the enums Batch = 0 Time = 1 Dimension = 2 Channel = 2 Width = 3 Height = 4 Any = 5 Sequence = 6 FlowGroup = 7 Singleton = 8 # Used to represent a axis that has size 1
[docs] def __repr__(self): """Returns short string representation of the AxisKind""" return self.__str__()
[docs] def __str__(self): """Returns short string representation of the AxisKind""" return str(self.name).lower()
[docs] def t_with_string(self, text): """It checks if text is 't_<any string>'""" return text.startswith("t_") and text.endswith("_") and text[2:-1] == self.__str__()
[docs] @staticmethod def from_str(label): """Returns AxisKind instance based on short string representation""" _label = label.lower().strip() if _label in ("b", "n", "batch"): return AxisKind.Batch if _label == "t" or _label == "time" or (len(_label) > 2 and _label.startswith("t_")): return AxisKind.Time if _label in ("d", "c", "channel"): return AxisKind.Dimension if _label in ("w", "width"): return AxisKind.Width if _label in ("h", "height"): return AxisKind.Height if _label in ("s", "singleton"): return AxisKind.Singleton if _label in ("seq", "sequence"): return AxisKind.Sequence if _label == "flowgroup": return AxisKind.FlowGroup if _label == "any": return AxisKind.Any raise ValueError(f"Can't create AxisKind from {label}")
[docs]class AxisType: """This class represents axis semantics and (optionally) it's dimensionality Parameters ---------- kind: what kind of axis it is? For example Batch, Height, etc. AxisKindAbstract size: specify if the axis should have a fixed size. By default, it is set to None and you typically do not want to set it for Batch and Time. (int, optional) is_list: whether this is a list or a tensor axis. (bool, default=False) """ def __init__(self, kind: AxisKindAbstract, size: Optional[int] = None, is_list=False): if size is not None and is_list: raise ValueError("The axis can't be list and have a fixed size") self.kind = kind self.size = size self.is_list = is_list
[docs] def __repr__(self): """Returns short string representation of the AxisType""" if self.size is None: representation = str(self.kind) else: representation = f"{str(self.kind)}:{self.size}" if self.is_list: representation += "_listdim" return representation