xref: /aosp_15_r20/external/armnn/python/pyarmnn/src/pyarmnn/__init__.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1# Copyright © 2020 Arm Ltd. All rights reserved.
2# SPDX-License-Identifier: MIT
3import inspect
4import sys
5import logging
6
7from ._generated.pyarmnn_version import GetVersion, GetMajorVersion, GetMinorVersion
8
9# Parsers
10
try:
    from ._generated.pyarmnn_onnxparser import IOnnxParser
except ImportError as err:
    logger = logging.getLogger(__name__)
    # Parser-specific name on purpose: the fallback below reads it only when it is
    # called (long after import), so a generic module-level name such as ``message``
    # could have been rebound elsewhere in this module by then and the user would be
    # told about the wrong parser.
    _ONNX_PARSER_UNAVAILABLE_MSG = "Your ArmNN library instance does not support Onnx models parser functionality. "
    logger.warning("%s Skipped IOnnxParser import.", _ONNX_PARSER_UNAVAILABLE_MSG)
    logger.debug("%s", err)

    def IOnnxParser(*args, **kwargs):
        """Stand-in used when Arm NN was built without the ONNX parser.

        Accepts any arguments so that every call, whatever its signature, reports the
        missing parser instead of failing with a ``TypeError``.

        Raises:
            RuntimeError: always, explaining that ONNX parsing is unavailable.
        """
        raise RuntimeError(_ONNX_PARSER_UNAVAILABLE_MSG)
23
try:
    from ._generated.pyarmnn_tfliteparser import ITfLiteParser, TfLiteParserOptions
except ImportError as err:
    logger = logging.getLogger(__name__)
    # Parser-specific name on purpose: the fallbacks below read it only when called,
    # so a generic module-level name could have been rebound elsewhere in this module
    # by then and the user would be told about the wrong parser.
    _TFLITE_PARSER_UNAVAILABLE_MSG = "Your ArmNN library instance does not support TF lite models parser functionality. "
    logger.warning("%s Skipped ITfLiteParser import.", _TFLITE_PARSER_UNAVAILABLE_MSG)
    logger.debug("%s", err)

    def ITfLiteParser(*args, **kwargs):
        """Stand-in used when Arm NN was built without the TF Lite parser.

        Accepts any arguments (e.g. a parser-options object) so every call reports the
        missing parser instead of failing with a ``TypeError``.

        Raises:
            RuntimeError: always, explaining that TF Lite parsing is unavailable.
        """
        raise RuntimeError(_TFLITE_PARSER_UNAVAILABLE_MSG)

    def TfLiteParserOptions(*args, **kwargs):
        """Stand-in for ``TfLiteParserOptions`` when the TF Lite parser is unavailable.

        It is imported together with ``ITfLiteParser`` above. Without this fallback,
        ``pyarmnn.TfLiteParserOptions`` would not exist and users would get an
        ``AttributeError`` instead of the explanatory message.

        Raises:
            RuntimeError: always, explaining that TF Lite parsing is unavailable.
        """
        raise RuntimeError(_TFLITE_PARSER_UNAVAILABLE_MSG)
36
try:
    from ._generated.pyarmnn_deserializer import IDeserializer
except ImportError as err:
    logger = logging.getLogger(__name__)
    # Parser-specific name on purpose: the fallback below reads it only when called,
    # so a generic module-level name could have been rebound elsewhere in this module
    # by then and the user would be told about the wrong parser.
    _DESERIALIZER_UNAVAILABLE_MSG = "Your ArmNN library instance does not have ArmNN model (.armnn) parser functionality. "
    logger.warning("%s Skipped IDeserializer import.", _DESERIALIZER_UNAVAILABLE_MSG)
    logger.debug("%s", err)

    def IDeserializer(*args, **kwargs):
        """Stand-in used when Arm NN was built without the .armnn model deserializer.

        Accepts any arguments so every call reports the missing parser instead of
        failing with a ``TypeError``.

        Raises:
            RuntimeError: always, explaining that .armnn parsing is unavailable.
        """
        raise RuntimeError(_DESERIALIZER_UNAVAILABLE_MSG)
48
49# Network
50from ._generated.pyarmnn import Optimize, OptimizerOptions, IOptimizedNetwork, IInputSlot, \
51    IOutputSlot, IConnectableLayer, INetwork
52
53# Backend
54from ._generated.pyarmnn import BackendId
55from ._generated.pyarmnn import IDeviceSpec
56from ._generated.pyarmnn import BackendOptions, BackendOption
57
58# Tensors
59from ._generated.pyarmnn import TensorInfo, TensorShape
60
61# Runtime
62from ._generated.pyarmnn import IRuntime, CreationOptions, INetworkProperties
63
64# Profiler
65from ._generated.pyarmnn import IProfiler
66
67# Types
68from ._generated.pyarmnn import DataType_Float16, DataType_Float32, DataType_QAsymmU8, DataType_Signed32, \
69    DataType_Boolean, DataType_QSymmS16, DataType_QSymmS8, DataType_QAsymmS8, ShapeInferenceMethod_ValidateOnly, \
70    ShapeInferenceMethod_InferAndValidate
71from ._generated.pyarmnn import DataLayout_NCHW, DataLayout_NHWC, DataLayout_NCDHW, DataLayout_NDHWC
72from ._generated.pyarmnn import MemorySource_Malloc, MemorySource_Undefined, MemorySource_DmaBuf, \
73    MemorySource_DmaBufProtected
74from ._generated.pyarmnn import ProfilingDetailsMethod_Undefined, ProfilingDetailsMethod_DetailsWithEvents, \
75    ProfilingDetailsMethod_DetailsOnly
76
77from ._generated.pyarmnn import ActivationFunction_Abs, ActivationFunction_BoundedReLu, ActivationFunction_LeakyReLu, \
78    ActivationFunction_Linear, ActivationFunction_ReLu, ActivationFunction_Sigmoid, ActivationFunction_SoftReLu, \
79    ActivationFunction_Sqrt, ActivationFunction_Square, ActivationFunction_TanH, ActivationDescriptor
80from ._generated.pyarmnn import ArgMinMaxFunction_Max, ArgMinMaxFunction_Min, ArgMinMaxDescriptor
81from ._generated.pyarmnn import BatchNormalizationDescriptor, BatchToSpaceNdDescriptor
82from ._generated.pyarmnn import ChannelShuffleDescriptor, ComparisonDescriptor, ComparisonOperation_Equal, \
83    ComparisonOperation_Greater, ComparisonOperation_GreaterOrEqual, ComparisonOperation_Less, \
84    ComparisonOperation_LessOrEqual, ComparisonOperation_NotEqual
85from ._generated.pyarmnn import UnaryOperation_Abs, UnaryOperation_Exp, UnaryOperation_Sqrt, UnaryOperation_Rsqrt, \
86    UnaryOperation_Neg, ElementwiseUnaryDescriptor
87from ._generated.pyarmnn import LogicalBinaryOperation_LogicalAnd, LogicalBinaryOperation_LogicalOr, \
88    LogicalBinaryDescriptor
89from ._generated.pyarmnn import Convolution2dDescriptor, Convolution3dDescriptor, DepthToSpaceDescriptor, \
90    DepthwiseConvolution2dDescriptor, DetectionPostProcessDescriptor, FakeQuantizationDescriptor, FillDescriptor, \
91    FullyConnectedDescriptor, GatherDescriptor, InstanceNormalizationDescriptor, LstmDescriptor, \
92    L2NormalizationDescriptor, MeanDescriptor
93from ._generated.pyarmnn import NormalizationAlgorithmChannel_Across, NormalizationAlgorithmChannel_Within, \
94    NormalizationAlgorithmMethod_LocalBrightness, NormalizationAlgorithmMethod_LocalContrast, NormalizationDescriptor
95from ._generated.pyarmnn import PaddingMode_Constant, PaddingMode_Reflect, PaddingMode_Symmetric, PadDescriptor
96from ._generated.pyarmnn import PermutationVector, PermuteDescriptor
97from ._generated.pyarmnn import OutputShapeRounding_Ceiling, OutputShapeRounding_Floor, \
98    PaddingMethod_Exclude, PaddingMethod_IgnoreValue, PoolingAlgorithm_Average, PoolingAlgorithm_L2, \
99    PoolingAlgorithm_Max, Pooling2dDescriptor, Pooling3dDescriptor
100from ._generated.pyarmnn import ReduceDescriptor, ReduceOperation_Prod, ReduceOperation_Max, ReduceOperation_Mean, \
101    ReduceOperation_Min, ReduceOperation_Sum
102from ._generated.pyarmnn import ResizeMethod_Bilinear, ResizeMethod_NearestNeighbor, ResizeDescriptor, \
103    ReshapeDescriptor, SliceDescriptor, SpaceToBatchNdDescriptor, SpaceToDepthDescriptor, StandInDescriptor, \
104    StackDescriptor, StridedSliceDescriptor, SoftmaxDescriptor, TransposeConvolution2dDescriptor, \
105    TransposeDescriptor, SplitterDescriptor
106from ._generated.pyarmnn import ConcatDescriptor, CreateDescriptorForConcatenation
107
108from ._generated.pyarmnn import LstmInputParams, QuantizedLstmInputParams
109
110# Public API
111# Quantization
112from ._quantization.quantize_and_dequantize import quantize, dequantize
113
114# Tensor
115from ._tensor.tensor import Tensor
116from ._tensor.const_tensor import ConstTensor
117from ._tensor.workload_tensors import make_input_tensors, make_output_tensors, workload_tensors_to_ndarray
118
119# Utilities
120from ._utilities.profiling_helper import ProfilerData, get_profiling_data
121
122from ._version import __version__, __arm_ml_version__
123
# Version value reported by ``GetVersion()`` from the generated bindings.
ARMNN_VERSION = GetVersion()


def __check_version():
    """Validate the bound Arm NN library version for this package.

    Passes ``ARMNN_VERSION`` to ``check_armnn_version`` from the package's
    ``_version`` module; what happens on a mismatch is decided by that helper.
    The helper is imported lazily, only when the check actually runs.
    """
    from ._version import check_armnn_version as _run_version_check
    _run_version_check(ARMNN_VERSION)


# The compatibility check runs once, while the package is being imported.
__check_version()
133
# Module-level functions that must never be exported as public API.
__private_api_names = ['__check_version']

# Public API: every class and Python function currently bound in this module,
# except the private names listed above, in ``inspect.getmembers`` (name-sorted)
# order. Anything that is neither a class nor a Python function (modules, plain
# values, the logger) is left out.
# A comprehension is used so that the iteration variables do not linger as
# package attributes (a module-level for-loop would leave ``name``/``obj`` behind).
__all__ = [
    member_name
    for member_name, member in inspect.getmembers(sys.modules[__name__])
    if (inspect.isclass(member) or inspect.isfunction(member))
    and member_name not in __private_api_names
]
142