# Copyright © 2020 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
"""Package entry point.

Re-exports the names provided by the ``_generated`` binding modules together
with the quantization, tensor and profiling helpers, verifies the Arm NN
version at import time and builds ``__all__`` from the classes and functions
found in this module.

The parser modules are optional: when one of them cannot be imported, a
warning is logged and a placeholder callable is installed under the same name
which raises ``RuntimeError`` when called.
"""
import inspect
import sys
import logging

from ._generated.pyarmnn_version import GetVersion, GetMajorVersion, GetMinorVersion


def __missing_parser(parser_name, message):
    """Return a placeholder for a parser entry point that could not be imported.

    Args:
        parser_name: Public name the placeholder is installed under; used for
            the placeholder's ``__name__``/``__qualname__``.
        message: Explanation raised to the caller. It is captured here, per
            placeholder, so later reassignments of the module-level ``message``
            variable cannot change what an earlier placeholder reports.

    Returns:
        A function accepting any arguments that always raises ``RuntimeError``.
    """

    def placeholder(*args, **kwargs):
        """In case people try importing without having Arm NN built with this parser."""
        # Any call signature is accepted so that the caller always gets the
        # explanatory RuntimeError rather than a TypeError about arguments.
        raise RuntimeError(message)

    placeholder.__name__ = placeholder.__qualname__ = parser_name
    return placeholder


# Parsers

try:
    from ._generated.pyarmnn_onnxparser import IOnnxParser
except ImportError as err:
    logger = logging.getLogger(__name__)
    message = "Your ArmNN library instance does not support Onnx models parser functionality. "
    logger.warning("%s Skipped IOnnxParser import.", message)
    logger.debug(str(err))

    IOnnxParser = __missing_parser("IOnnxParser", message)

try:
    from ._generated.pyarmnn_tfliteparser import ITfLiteParser, TfLiteParserOptions
except ImportError as err:
    logger = logging.getLogger(__name__)
    message = "Your ArmNN library instance does not support TF lite models parser functionality. "
    logger.warning("%s Skipped ITfLiteParser import.", message)
    logger.debug(str(err))

    ITfLiteParser = __missing_parser("ITfLiteParser", message)
    # Both names come from the same optional module, so both get a placeholder.
    TfLiteParserOptions = __missing_parser("TfLiteParserOptions", message)

try:
    from ._generated.pyarmnn_deserializer import IDeserializer
except ImportError as err:
    logger = logging.getLogger(__name__)
    message = "Your ArmNN library instance does not have ArmNN model (.armnn) parser functionality. "
    logger.warning("%s Skipped IDeserializer import.", message)
    logger.debug(str(err))

    IDeserializer = __missing_parser("IDeserializer", message)

# Network
from ._generated.pyarmnn import Optimize, OptimizerOptions, IOptimizedNetwork, IInputSlot, \
    IOutputSlot, IConnectableLayer, INetwork

# Backend
from ._generated.pyarmnn import BackendId
from ._generated.pyarmnn import IDeviceSpec
from ._generated.pyarmnn import BackendOptions, BackendOption

# Tensors
from ._generated.pyarmnn import TensorInfo, TensorShape

# Runtime
from ._generated.pyarmnn import IRuntime, CreationOptions, INetworkProperties

# Profiler
from ._generated.pyarmnn import IProfiler

# Types
from ._generated.pyarmnn import DataType_Float16, DataType_Float32, DataType_QAsymmU8, DataType_Signed32, \
    DataType_Boolean, DataType_QSymmS16, DataType_QSymmS8, DataType_QAsymmS8, ShapeInferenceMethod_ValidateOnly, \
    ShapeInferenceMethod_InferAndValidate
from ._generated.pyarmnn import DataLayout_NCHW, DataLayout_NHWC, DataLayout_NCDHW, DataLayout_NDHWC
from ._generated.pyarmnn import MemorySource_Malloc, MemorySource_Undefined, MemorySource_DmaBuf, \
    MemorySource_DmaBufProtected
from ._generated.pyarmnn import ProfilingDetailsMethod_Undefined, ProfilingDetailsMethod_DetailsWithEvents, \
    ProfilingDetailsMethod_DetailsOnly

from ._generated.pyarmnn import ActivationFunction_Abs, ActivationFunction_BoundedReLu, ActivationFunction_LeakyReLu, \
    ActivationFunction_Linear, ActivationFunction_ReLu, ActivationFunction_Sigmoid, ActivationFunction_SoftReLu, \
    ActivationFunction_Sqrt, ActivationFunction_Square, ActivationFunction_TanH, ActivationDescriptor
from ._generated.pyarmnn import ArgMinMaxFunction_Max, ArgMinMaxFunction_Min, ArgMinMaxDescriptor
from ._generated.pyarmnn import BatchNormalizationDescriptor, BatchToSpaceNdDescriptor
from ._generated.pyarmnn import ChannelShuffleDescriptor, ComparisonDescriptor, ComparisonOperation_Equal, \
    ComparisonOperation_Greater, ComparisonOperation_GreaterOrEqual, ComparisonOperation_Less, \
    ComparisonOperation_LessOrEqual, ComparisonOperation_NotEqual
from ._generated.pyarmnn import UnaryOperation_Abs, UnaryOperation_Exp, UnaryOperation_Sqrt, UnaryOperation_Rsqrt, \
    UnaryOperation_Neg, ElementwiseUnaryDescriptor
from ._generated.pyarmnn import LogicalBinaryOperation_LogicalAnd, LogicalBinaryOperation_LogicalOr, \
    LogicalBinaryDescriptor
from ._generated.pyarmnn import Convolution2dDescriptor, Convolution3dDescriptor, DepthToSpaceDescriptor, \
    DepthwiseConvolution2dDescriptor, DetectionPostProcessDescriptor, FakeQuantizationDescriptor, FillDescriptor, \
    FullyConnectedDescriptor, GatherDescriptor, InstanceNormalizationDescriptor, LstmDescriptor, \
    L2NormalizationDescriptor, MeanDescriptor
from ._generated.pyarmnn import NormalizationAlgorithmChannel_Across, NormalizationAlgorithmChannel_Within, \
    NormalizationAlgorithmMethod_LocalBrightness, NormalizationAlgorithmMethod_LocalContrast, NormalizationDescriptor
from ._generated.pyarmnn import PaddingMode_Constant, PaddingMode_Reflect, PaddingMode_Symmetric, PadDescriptor
from ._generated.pyarmnn import PermutationVector, PermuteDescriptor
from ._generated.pyarmnn import OutputShapeRounding_Ceiling, OutputShapeRounding_Floor, \
    PaddingMethod_Exclude, PaddingMethod_IgnoreValue, PoolingAlgorithm_Average, PoolingAlgorithm_L2, \
    PoolingAlgorithm_Max, Pooling2dDescriptor, Pooling3dDescriptor
from ._generated.pyarmnn import ReduceDescriptor, ReduceOperation_Prod, ReduceOperation_Max, ReduceOperation_Mean, \
    ReduceOperation_Min, ReduceOperation_Sum
from ._generated.pyarmnn import ResizeMethod_Bilinear, ResizeMethod_NearestNeighbor, ResizeDescriptor, \
    ReshapeDescriptor, SliceDescriptor, SpaceToBatchNdDescriptor, SpaceToDepthDescriptor, StandInDescriptor, \
    StackDescriptor, StridedSliceDescriptor, SoftmaxDescriptor, TransposeConvolution2dDescriptor, \
    TransposeDescriptor, SplitterDescriptor
from ._generated.pyarmnn import ConcatDescriptor, CreateDescriptorForConcatenation

from ._generated.pyarmnn import LstmInputParams, QuantizedLstmInputParams

# Public API
# Quantization
from ._quantization.quantize_and_dequantize import quantize, dequantize

# Tensor
from ._tensor.tensor import Tensor
from ._tensor.const_tensor import ConstTensor
from ._tensor.workload_tensors import make_input_tensors, make_output_tensors, workload_tensors_to_ndarray

# Utilities
from ._utilities.profiling_helper import ProfilerData, get_profiling_data

from ._version import __version__, __arm_ml_version__

ARMNN_VERSION = GetVersion()


def __check_version():
    """Pass ``ARMNN_VERSION`` to ``check_armnn_version`` from ``._version``."""
    # Imported locally so the checker itself is not picked up as a module member.
    from ._version import check_armnn_version
    check_armnn_version(ARMNN_VERSION)


__check_version()

# Module-level helpers that must not be advertised through ``__all__``.
__private_api_names = ['__check_version', '__missing_parser']

# Export every class and function visible in this module except the private
# helpers above. Members that ``inspect`` reports as neither a class nor a
# function are not listed.
__all__ = [
    name
    for name, obj in inspect.getmembers(sys.modules[__name__])
    if (inspect.isclass(obj) or inspect.isfunction(obj)) and name not in __private_api_names
]