Source code for mridc.utils.env_var_parsing

# encoding: utf-8
__author__ = "Dimitrios Karkalousos"

# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/env_var_parsing.py

import decimal
import json
import os

from dateutil import parser  # type: ignore

__all__ = [
    "get_env",
    "get_envbool",
    "get_envint",
    "get_envfloat",
    "get_envdecimal",
    "get_envdate",
    "get_envdatetime",
    "get_envlist",
    "get_envdict",
    "CoercionError",
    "RequiredSettingMissingError",
]


class CoercionError(Exception):
    """Custom error raised when an env var value cannot be coerced.

    Parameters
    ----------
    key: The name of the env var whose value could not be coerced.
    value: The raw value of the env var that was passed to ``func``.
    func: The coercion callable that raised.

    The constructor arguments are also kept as the ``key``, ``value`` and ``func`` attributes.
    """

    def __init__(self, key, value, func):
        # Not every callable has a ``__name__`` (e.g. ``functools.partial`` objects or instances defining
        # ``__call__``). Fall back to its repr so building this error never raises AttributeError itself,
        # which would otherwise mask the original coercion failure.
        func_name = getattr(func, "__name__", None) or repr(func)
        msg = f"Unable to coerce '{key}={value}' using {func_name}."
        super().__init__(msg)
        self.key = key
        self.value = value
        self.func = func
class RequiredSettingMissingError(Exception):
    """Raised when an env var that must be present is absent from the environment.

    Parameters
    ----------
    key: The name of the env var that could not be found.
    """

    def __init__(self, key):
        super().__init__(f"Required env var '{key}' is missing.")
def _get_env(key, default=None, coerce=lambda x: x, required=False):
    """
    Return env var coerced into a type other than string.

    This function extends the standard os.getenv function to enable the coercion of values into data types other
    than string (all env vars are strings by default).

    Parameters
    ----------
    key: The name of the env var to retrieve.
    default: The default value to return if the env var is not set. NB the default value is **not** coerced, and is
        assumed to be of the correct type.
    coerce: A function that takes a string and returns a value of the desired type.
    required: If truthy, raises a RequiredSettingMissingError if the env var is not set.

    Returns
    -------
    The value of the env var coerced into the desired type, or ``default`` if the env var is not set and not
    required.

    Raises
    ------
    RequiredSettingMissingError
        If ``required`` is truthy and the env var is not set.
    CoercionError
        If ``coerce`` raises any exception for the env var's value (the original exception is chained).
    """
    try:
        value = os.environ[key]
    except KeyError as e:
        if required:
            raise RequiredSettingMissingError(key) from e
        return default
    try:
        return coerce(value)
    except Exception as exc:
        # Wrap every coercion failure so callers only need to handle a single error type.
        raise CoercionError(key, value, coerce) from exc


# standard type coercion functions


def _bool(value):
    """Return env var cast as boolean.

    Booleans are returned unchanged and ``None`` is False. For strings, surrounding whitespace is ignored and the
    comparison is case-insensitive: "false", "0", "no", "n", "f" and "none" are False; anything else is True.
    """
    if isinstance(value, bool):
        return value
    # Strip surrounding whitespace so e.g. "false " (a common .env/shell quoting slip) is not silently truthy;
    # _int/_float/_decimal/_dict already tolerate surrounding whitespace.
    return value is not None and value.strip().lower() not in (
        "false",
        "0",
        "no",
        "n",
        "f",
        "none",
    )


def _int(value):
    """Return env var cast as integer."""
    return int(value)


def _float(value):
    """Return env var cast as float."""
    return float(value)


def _decimal(value):
    """Return env var cast as Decimal."""
    return decimal.Decimal(value)


def _dict(value):
    """Return env var parsed as JSON.

    NOTE(review): json.loads accepts any JSON document, so a value such as "[1, 2]" yields a list, not a dict;
    the result type is not enforced here.
    """
    return json.loads(value)


def _datetime(value):
    """Return env var parsed into a datetime with dateutil's parser."""
    return parser.parse(value)


def _date(value):
    """Return env var parsed with dateutil's parser, keeping only the date part."""
    return parser.parse(value).date()
def get_env(key, *default, **kwargs):
    """
    Look up an env var, optionally coercing it and optionally falling back to a default.

    Every ``get_env*`` helper funnels through here. This function translates its flexible calling convention
    (an optional single positional default, a ``coerce`` keyword) into the explicit arguments of ``_get_env``,
    which is what actually reads ``os.environ``.

    Parameters
    ----------
    key: The name of the env var.
    default: (optional) At most one positional value, returned as-is when the env var is not set. When no
        default is given the env var is treated as required, and a RequiredSettingMissingError is raised if
        it is not set.
    kwargs:
        coerce: A callable applied to the raw string value (identity when omitted). It may return a builtin
            type or any custom object. Other keyword arguments are ignored.

    Returns
    -------
    The (possibly coerced) env var value, or the supplied default.

    Raises
    ------
    AssertionError
        If more than one positional default is supplied.
    """
    if len(default) > 1:
        raise AssertionError("Too many args supplied.")
    converter = kwargs.get("coerce", lambda x: x)
    if default:
        # A default was supplied: the env var is optional.
        return _get_env(key, default=default[0], coerce=converter, required=False)
    # No default: the env var must be present.
    return _get_env(key, default=None, coerce=converter, required=True)
def get_envbool(key, *default):
    """Return the env var ``key`` interpreted as a boolean via ``_bool``.

    An optional single positional ``default`` is returned uncoerced when the var is unset. Without it, the var
    is required.
    """
    lookup_args = (key, *default)
    return get_env(*lookup_args, coerce=_bool)
def get_envint(key, *default):
    """Return the env var ``key`` converted with ``int``.

    An optional single positional ``default`` is returned uncoerced when the var is unset. Without it, the var
    is required.
    """
    converter = _int
    return get_env(key, *default, coerce=converter)
def get_envfloat(key, *default):
    """Return the env var ``key`` converted with ``float``.

    An optional single positional ``default`` is returned uncoerced when the var is unset. Without it, the var
    is required.
    """
    lookup_args = (key, *default)
    return get_env(*lookup_args, coerce=_float)
def get_envdecimal(key, *default):
    """Return the env var ``key`` converted to a ``decimal.Decimal``.

    An optional single positional ``default`` is returned uncoerced when the var is unset. Without it, the var
    is required.
    """
    converter = _decimal
    return get_env(key, *default, coerce=converter)
def get_envdate(key, *default):
    """Return the env var ``key`` parsed by ``_date`` into a date.

    An optional single positional ``default`` is returned uncoerced when the var is unset. Without it, the var
    is required.
    """
    lookup_args = (key, *default)
    return get_env(*lookup_args, coerce=_date)
def get_envdatetime(key, *default):
    """Return the env var ``key`` parsed by ``_datetime`` into a datetime.

    An optional single positional ``default`` is returned uncoerced when the var is unset. Without it, the var
    is required.
    """
    converter = _datetime
    return get_env(key, *default, coerce=converter)
def get_envlist(key, *default, **kwargs):
    """Return env var as a list of strings split on ``separator``.

    Parameters
    ----------
    key: The name of the env var.
    default: (optional) A single positional value returned uncoerced when the env var is unset. If omitted,
        the env var is required.
    kwargs:
        separator: The string to split on (default: a single space). It is passed directly to ``str.split``,
            so consecutive separators produce empty items.

    Returns
    -------
    The list of substrings. An env var that is set but empty yields ``[]`` rather than ``['']``.
    """
    separator = kwargs.get("separator", " ")

    def _split(value):
        # str.split with an explicit separator turns "" into [""]; an empty value means "no items".
        # A named function (instead of a lambda) also gives CoercionError a meaningful name to report.
        return value.split(separator) if value else []

    return get_env(key, *default, coerce=_split)
def get_envdict(key, *default):
    """Return the env var ``key`` decoded as JSON via ``_dict``.

    An optional single positional ``default`` is returned uncoerced when the var is unset. Without it, the var
    is required.
    """
    lookup_args = (key, *default)
    return get_env(*lookup_args, coerce=_dict)