1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #include <unordered_map> 10 11 #include <executorch/runtime/core/exec_aten/exec_aten.h> 12 13 namespace executorch::extension { 14 15 constexpr static int kTensorDTypeUInt8 = 0; 16 constexpr static int kTensorDTypeInt8 = 1; 17 constexpr static int kTensorDTypeInt16 = 2; 18 constexpr static int kTensorDTypeInt32 = 3; 19 constexpr static int kTensorDTypeInt64 = 4; 20 constexpr static int kTensorDTypeHalf = 5; 21 constexpr static int kTensorDTypeFloat = 6; 22 constexpr static int kTensorDTypeDouble = 7; 23 // These types are not supported yet 24 // constexpr static int kTensorDTypeComplexHalf = 8; 25 // constexpr static int kTensorDTypeComplexFloat = 9; 26 // constexpr static int kTensorDTypeComplexDouble = 10; 27 constexpr static int kTensorDTypeBool = 11; 28 constexpr static int kTensorDTypeQint8 = 12; 29 constexpr static int kTensorDTypeQuint8 = 13; 30 constexpr static int kTensorDTypeQint32 = 14; 31 constexpr static int kTensorDTypeBFloat16 = 15; 32 constexpr static int kTensorDTypeQuint4x2 = 16; 33 constexpr static int kTensorDTypeQuint2x4 = 17; 34 constexpr static int kTensorDTypeBits1x8 = 18; 35 constexpr static int kTensorDTypeBits2x4 = 19; 36 constexpr static int kTensorDTypeBits4x2 = 20; 37 constexpr static int kTensorDTypeBits8 = 21; 38 constexpr static int kTensorDTypeBits16 = 22; 39 40 using executorch::aten::ScalarType; 41 42 const std::unordered_map<ScalarType, int> scalar_type_to_java_dtype = { 43 {ScalarType::Byte, kTensorDTypeUInt8}, 44 {ScalarType::Char, kTensorDTypeInt8}, 45 {ScalarType::Short, kTensorDTypeInt16}, 46 {ScalarType::Int, kTensorDTypeInt32}, 47 {ScalarType::Long, kTensorDTypeInt64}, 48 {ScalarType::Half, kTensorDTypeHalf}, 49 {ScalarType::Float, kTensorDTypeFloat}, 50 {ScalarType::Double, kTensorDTypeDouble}, 51 // These types are not supported yet 52 // {ScalarType::ComplexHalf, kTensorDTypeComplexHalf}, 53 // {ScalarType::ComplexFloat, kTensorDTypeComplexFloat}, 54 // {ScalarType::ComplexDouble, kTensorDTypeComplexDouble}, 55 {ScalarType::Bool, kTensorDTypeBool}, 56 {ScalarType::QInt8, kTensorDTypeQint8}, 57 {ScalarType::QUInt8, kTensorDTypeQuint8}, 58 {ScalarType::QInt32, kTensorDTypeQint32}, 59 {ScalarType::BFloat16, kTensorDTypeBFloat16}, 60 {ScalarType::QUInt4x2, kTensorDTypeQuint4x2}, 61 {ScalarType::QUInt2x4, kTensorDTypeQuint2x4}, 62 {ScalarType::Bits1x8, kTensorDTypeBits1x8}, 63 {ScalarType::Bits2x4, kTensorDTypeBits2x4}, 64 {ScalarType::Bits4x2, kTensorDTypeBits4x2}, 65 {ScalarType::Bits8, kTensorDTypeBits8}, 66 {ScalarType::Bits16, kTensorDTypeBits16}, 67 }; 68 69 const std::unordered_map<int, ScalarType> java_dtype_to_scalar_type = { 70 {kTensorDTypeUInt8, ScalarType::Byte}, 71 {kTensorDTypeInt8, ScalarType::Char}, 72 {kTensorDTypeInt16, ScalarType::Short}, 73 {kTensorDTypeInt32, ScalarType::Int}, 74 {kTensorDTypeInt64, ScalarType::Long}, 75 {kTensorDTypeHalf, ScalarType::Half}, 76 {kTensorDTypeFloat, ScalarType::Float}, 77 {kTensorDTypeDouble, ScalarType::Double}, 78 // These types are not supported yet 79 // {kTensorDTypeComplexHalf, ScalarType::ComplexHalf}, 80 // {kTensorDTypeComplexFloat, ScalarType::ComplexFloat}, 81 // {kTensorDTypeComplexDouble, ScalarType::ComplexDouble}, 82 {kTensorDTypeBool, ScalarType::Bool}, 83 {kTensorDTypeQint8, ScalarType::QInt8}, 84 {kTensorDTypeQuint8, ScalarType::QUInt8}, 85 {kTensorDTypeQint32, ScalarType::QInt32}, 86 {kTensorDTypeBFloat16, ScalarType::BFloat16}, 87 {kTensorDTypeQuint4x2, ScalarType::QUInt4x2}, 88 {kTensorDTypeQuint2x4, ScalarType::QUInt2x4}, 89 {kTensorDTypeBits1x8, ScalarType::Bits1x8}, 90 {kTensorDTypeBits2x4, ScalarType::Bits2x4}, 91 {kTensorDTypeBits4x2, ScalarType::Bits4x2}, 92 {kTensorDTypeBits8, ScalarType::Bits8}, 93 {kTensorDTypeBits16, ScalarType::Bits16}, 94 }; 95 96 } // namespace executorch::extension 97