# encoding: utf-8
__author__ = "Dimitrios Karkalousos"
# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/neural_types/elements.py
from abc import ABC, ABCMeta
from typing import Dict, Optional, Tuple
__all__ = [
"ElementType",
"VoidType",
"ChannelType",
"MRISignal",
"RecurrentsType",
"LabelsType",
"LogprobsType",
"ProbsType",
"LossType",
"RegressionValuesType",
"CategoricalValuesType",
"PredictionsType",
"LengthsType",
"MaskType",
"Target",
"ReconstructionTarget",
"ImageFeatureValue",
"Index",
"ImageValue",
"NormalizedImageValue",
"StringLabel",
"StringType",
"Length",
"IntType",
"FloatType",
"NormalDistributionSamplesType",
"NormalDistributionMeanType",
"NormalDistributionLogVarianceType",
"LogDeterminantType",
"SequenceToSequenceAlignmentType",
]
from mridc.core.neural_types.comparison import NeuralTypeComparisonResult
[docs]class ElementType(ABC):
"""Abstract class defining semantics of the tensor elements. We are relying on Python for inheritance checking"""
[docs] def __str__(self):
"""Override this method to provide a human readable representation of the type"""
return self.__doc__
[docs] def __repr__(self):
"""Override this method to provide a human readable representation of the type"""
return self.__class__.__name__
@property
def type_parameters(self) -> Dict:
"""
Override this property to parametrize your type. For example, you can specify 'storage' type such as float,
int, bool with 'dtype' keyword. Another example, is if you want to represent a signal with a particular
property (say, sample frequency), then you can put sample_freq->value in there. When two types are compared
their type_parameters must match."
"""
return {}
@property
def fields(self) -> Optional[Tuple]:
"""
This should be used to logically represent tuples/structures. For example, if you want to represent a \
bounding box (x, y, width, height) you can put a tuple with names ('x', y', 'w', 'h') in here. Under the \
hood this should be converted to the last tensor dimension of fixed size = len(fields). When two types are \
compared their fields must match.
"""
return None
[docs] def compare(self, second) -> NeuralTypeComparisonResult:
"""Override this method to provide a comparison between two types."""
# First, check general compatibility
first_t = type(self)
second_t = type(second)
if first_t == second_t:
result = NeuralTypeComparisonResult.SAME
elif issubclass(first_t, second_t):
result = NeuralTypeComparisonResult.LESS
elif issubclass(second_t, first_t):
result = NeuralTypeComparisonResult.GREATER
else:
result = NeuralTypeComparisonResult.INCOMPATIBLE
if result != NeuralTypeComparisonResult.SAME:
return result
# now check that all parameters match
check_params = set(self.type_parameters.keys()) == set(second.type_parameters.keys())
if not check_params:
return NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS
for k1, v1 in self.type_parameters.items():
if v1 is None or second.type_parameters[k1] is None:
# Treat None as Void
continue
if v1 != second.type_parameters[k1]:
return NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS
# check that all fields match
if self.fields == second.fields:
return NeuralTypeComparisonResult.SAME
return NeuralTypeComparisonResult.INCOMPATIBLE
[docs]class VoidType(ElementType):
"""
Void-like type which is compatible with everything. It is a good practice to use this type only as necessary.
For example, when you need template-like functionality.
"""
[docs] def compare(cls, second: ABCMeta) -> NeuralTypeComparisonResult:
"""Void type is compatible with everything."""
return NeuralTypeComparisonResult.SAME
# TODO: Consider moving these files elsewhere
[docs]class ChannelType(ElementType):
"""Element to represent convolutional input/output channel."""
[docs]class RecurrentsType(ElementType):
"""Element type to represent recurrent layers"""
[docs]class LengthsType(ElementType):
"""Element type representing lengths of something"""
[docs]class ProbsType(ElementType):
"""Element type to represent probabilities. For example, outputs of softmax layers."""
[docs]class LogprobsType(ElementType):
"""Element type to represent log-probabilities. For example, outputs of log softmax layers."""
[docs]class LossType(ElementType):
"""Element type to represent outputs of Loss modules"""
[docs]class MRISignal(ElementType):
"""
Element type to represent encoded representation returned by the mri model
Parameters
----------
freq: sampling frequency of a signal. Note that two signals will only be the same if their freq is the same.
"""
def __init__(self, freq: int = None):
self._params = {"freq": freq}
@property
def type_parameters(self):
"""Returns the type parameters of the element type."""
return self._params
[docs]class LabelsType(ElementType):
"""Element type to represent labels of something. For example, labels of a dataset."""
[docs]class PredictionsType(LabelsType):
"""Element type to represent some sort of predictions returned by model"""
[docs]class RegressionValuesType(PredictionsType):
"""Element type to represent labels for regression task"""
[docs]class CategoricalValuesType(PredictionsType):
"""Element type to represent labels for categorical classification task"""
[docs]class MaskType(PredictionsType):
"""Element type to represent a boolean mask"""
[docs]class Index(ElementType):
"""Type representing an element being an index of the sample."""
[docs]class Target(ElementType):
"""Type representing an element being a target value."""
[docs]class ReconstructionTarget(Target):
"""
Type representing an element being target value in the reconstruction task, i.e. identifier of a desired
class.
"""
[docs]class ImageValue(ElementType):
"""Type representing an element/value of a single image channel,"""
[docs]class NormalizedImageValue(ImageValue):
"""Type representing an element/value of a single image channel normalized to <0-1> range."""
[docs]class ImageFeatureValue(ImageValue):
"""Type representing an element (single value) of a (image) feature maps."""
[docs]class StringType(ElementType):
"""Element type representing a single string"""
[docs]class StringLabel(StringType):
"""Type representing a label being a string with class name (e.g. the "hamster" class in CIFAR100)."""
class BoolType(ElementType):
"""Element type representing a single integer"""
[docs]class IntType(ElementType):
"""Element type representing a single integer"""
[docs]class FloatType(ElementType):
"""Element type representing a single float"""
[docs]class Length(IntType):
"""Type representing an element storing a "length" (e.g. length of a list)."""
class ProbabilityDistributionSamplesType(ElementType):
"""Element to represent tensors that meant to be sampled from a valid probability distribution"""
[docs]class NormalDistributionSamplesType(ProbabilityDistributionSamplesType):
"""Element to represent tensors that meant to be sampled from a valid normal distribution"""
[docs]class SequenceToSequenceAlignmentType(ElementType):
"""
Class to represent the alignment from seq-to-seq attention outputs. Generally a mapping from encoder time steps
to decoder time steps.
"""
[docs]class NormalDistributionMeanType(ElementType):
"""Element to represent the mean of a normal distribution"""
[docs]class NormalDistributionLogVarianceType(ElementType):
"""Element to represent the log variance of a normal distribution"""
[docs]class LogDeterminantType(ElementType):
"""Element for representing log determinants usually used in flow models"""