# Source code for devo_ml.modelmanager.engines

"""Engine code literals that identify the supported ML engines, and helpers
to look up the model file extensions associated with each engine.
"""

from typing import Sequence


#: | Constant denoting `Open Neural Network Exchange <https://onnx.ai/>`_ model.
#: | Extensions: ``.onnx``
ONNX = "ONNX"

#: | Constant denoting `H2O <https://h2o.ai/>`_ model.
#: | Extensions: ``.zip``
H2O = "H2O"

#: | Constant denoting `BigML <https://bigml.com/>`_ model.
#: | Extensions: ``.json``
BIGML = "BIGML"

#: | Constant denoting `CatBoost <https://catboost.ai/>`_ model.
#: | Extensions: ``.cbm``
CATBOOST = "CATBOOST"

# NOTE(review): the ``_aware_engines`` table maps DT and IDA to ``.json``, so
# they are documented that way below; confirm ``.zip`` was not intended.

#: | Constant denoting Decision Tree model.
#: | Extensions: ``.json``
DT = "DT"

#: | Constant denoting Iterated Distillation and Amplification model.
#: | Extensions: ``.json``
IDA = "IDA"

#: | Constant denoting ML STATS model.
#: | Extensions: none
MLSTATS = "MLSTATS"

#: | Constant denoting MUA model.
#: | Extensions: ``.zip``
MUA = "MUA"

#: | Constant denoting UNICODE model.
#: | Extensions: ``.zip``
UNICODE = "UNICODE"

#: | Constant denoting WORKFLOWS model.
#: | Extensions: ``.json``
WORKFLOWS = "WORKFLOWS"

# File extension literals (dot included) used to build the engine table.
_dot_json = ".json"
_dot_zip = ".zip"
_dot_onnx = ".onnx"
_dot_h5 = ".h5"  # not referenced by ``_aware_engines``
# CatBoost's native binary model format ("cbm") uses the ``.cbm`` extension.
_dot_cbm = ".cbm"

# Registry of known engines, keyed by engine code. Each entry stores the code
# itself and a list of its file extensions (dot included); the first entry of
# that list, when present, is what ``get_default_engine_extension`` returns.
_aware_engines = {
    _code: {"code": _code, "extensions": list(_exts)}
    for _code, _exts in (
        (ONNX, (_dot_onnx,)),
        (H2O, (_dot_zip,)),
        (BIGML, (_dot_json,)),
        (CATBOOST, (_dot_cbm,)),
        (DT, (_dot_json,)),
        (IDA, (_dot_json,)),
        (MLSTATS, ()),
        (MUA, (_dot_zip,)),
        (UNICODE, (_dot_zip,)),
        (WORKFLOWS, (_dot_json,)),
    )
}


def get_engine_extensions(engine_code: str) -> Sequence[str]:
    """Returns file extensions associated with an engine represented by its
    code.

    The code is upper-cased before lookup, so matching is case-insensitive.
    An empty list will be returned if it is an unknown engine code.

    :param engine_code: The code of the engine
    :return: A new list with the file extensions, or an empty list.
        Extensions include the dot, e.g. ``.json``, ``.zip`` ...
    """
    engine = _aware_engines.get(engine_code.upper())
    if not engine:
        return []
    # Hand back a copy so callers cannot mutate the shared engine registry.
    return list(engine.get("extensions", []))


def get_default_engine_extension(engine_code: str) -> str:
    """Returns the default file extension associated with an engine
    represented by its code.

    The default is the first extension returned by
    ``get_engine_extensions``. An empty string will be returned if it is an
    unknown engine code or the engine has no associated extensions.

    :param engine_code: The code of the engine
    :return: The extension associated with the engine or an empty string.
        Extensions include the dot, e.g. ``.json``, ``.zip`` ...
    """
    extensions = get_engine_extensions(engine_code)
    return extensions[0] if extensions else ""