Source code for mridc.core.utils.neural_type_utils

# encoding: utf-8
__author__ = "Dimitrios Karkalousos"

# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/utils/neural_type_utils.py

from collections import defaultdict

from mridc.core.neural_types.axes import AxisKind
from mridc.core.neural_types.neural_type import NeuralType


[docs]def get_io_names(types, disabled_names): """ This method will return a list of input and output names for a given NeuralType. Parameters ---------- types: The NeuralType of the module or model to be inspected. disabled_names: A list of names that should be excluded from the result. Returns ------- A list of input and output names. """ names = list(types.keys()) for name in disabled_names: if name in names: names.remove(name) return names
[docs]def extract_dynamic_axes(name: str, ntype: NeuralType): """ This method will extract BATCH and TIME dimension ids from each provided input/output name argument. For example, if module/model accepts argument named "input_signal" with type corresponding to [Batch, Time, Dim] shape, then the returned result should contain "input_signal" -> [0, 1] because Batch and Time are dynamic axes as they can change from call to call during inference. Parameters ---------- name: Name of input or output parameter ntype: Corresponding Neural Type Returns ------- A dictionary with input/output name as key and a list of dynamic axes as value. """ def unpack_nested_neural_type(neural_type): """ This method will unpack nested NeuralTypes. Parameters ---------- neural_type: The NeuralType to be unpacked. Returns ------- A list of all the nested NeuralTypes. """ if type(neural_type) in (list, tuple): return unpack_nested_neural_type(neural_type[0]) return neural_type dynamic_axes = defaultdict(list) if type(ntype) in (list, tuple): ntype = unpack_nested_neural_type(ntype) if ntype.axes: for ind, axis in enumerate(ntype.axes): if axis.kind in [AxisKind.Batch, AxisKind.Time, AxisKind.Width, AxisKind.Height]: dynamic_axes[name].append(ind) return dynamic_axes
[docs]def get_dynamic_axes(types, names): """ This method will return a dictionary with input/output names as keys and a list of dynamic axes as values. Parameters ---------- types: The NeuralType of the module or model to be inspected. names: A list of names that should be inspected. Returns ------- A dictionary with input/output names as keys and a list of dynamic axes as values. """ dynamic_axes = defaultdict(list) for name in names: if name in types: dynamic_axes |= extract_dynamic_axes(name, types[name]) return dynamic_axes