// xref: /aosp_15_r20/external/executorch/extension/android/jni/jni_layer_constants.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <unordered_map>

#include <executorch/runtime/core/exec_aten/exec_aten.h>

13 namespace executorch::extension {
14 
15 constexpr static int kTensorDTypeUInt8 = 0;
16 constexpr static int kTensorDTypeInt8 = 1;
17 constexpr static int kTensorDTypeInt16 = 2;
18 constexpr static int kTensorDTypeInt32 = 3;
19 constexpr static int kTensorDTypeInt64 = 4;
20 constexpr static int kTensorDTypeHalf = 5;
21 constexpr static int kTensorDTypeFloat = 6;
22 constexpr static int kTensorDTypeDouble = 7;
23 // These types are not supported yet
24 // constexpr static int kTensorDTypeComplexHalf = 8;
25 // constexpr static int kTensorDTypeComplexFloat = 9;
26 // constexpr static int kTensorDTypeComplexDouble = 10;
27 constexpr static int kTensorDTypeBool = 11;
28 constexpr static int kTensorDTypeQint8 = 12;
29 constexpr static int kTensorDTypeQuint8 = 13;
30 constexpr static int kTensorDTypeQint32 = 14;
31 constexpr static int kTensorDTypeBFloat16 = 15;
32 constexpr static int kTensorDTypeQuint4x2 = 16;
33 constexpr static int kTensorDTypeQuint2x4 = 17;
34 constexpr static int kTensorDTypeBits1x8 = 18;
35 constexpr static int kTensorDTypeBits2x4 = 19;
36 constexpr static int kTensorDTypeBits4x2 = 20;
37 constexpr static int kTensorDTypeBits8 = 21;
38 constexpr static int kTensorDTypeBits16 = 22;
39 
40 using executorch::aten::ScalarType;
41 
42 const std::unordered_map<ScalarType, int> scalar_type_to_java_dtype = {
43     {ScalarType::Byte, kTensorDTypeUInt8},
44     {ScalarType::Char, kTensorDTypeInt8},
45     {ScalarType::Short, kTensorDTypeInt16},
46     {ScalarType::Int, kTensorDTypeInt32},
47     {ScalarType::Long, kTensorDTypeInt64},
48     {ScalarType::Half, kTensorDTypeHalf},
49     {ScalarType::Float, kTensorDTypeFloat},
50     {ScalarType::Double, kTensorDTypeDouble},
51     // These types are not supported yet
52     // {ScalarType::ComplexHalf, kTensorDTypeComplexHalf},
53     // {ScalarType::ComplexFloat, kTensorDTypeComplexFloat},
54     // {ScalarType::ComplexDouble, kTensorDTypeComplexDouble},
55     {ScalarType::Bool, kTensorDTypeBool},
56     {ScalarType::QInt8, kTensorDTypeQint8},
57     {ScalarType::QUInt8, kTensorDTypeQuint8},
58     {ScalarType::QInt32, kTensorDTypeQint32},
59     {ScalarType::BFloat16, kTensorDTypeBFloat16},
60     {ScalarType::QUInt4x2, kTensorDTypeQuint4x2},
61     {ScalarType::QUInt2x4, kTensorDTypeQuint2x4},
62     {ScalarType::Bits1x8, kTensorDTypeBits1x8},
63     {ScalarType::Bits2x4, kTensorDTypeBits2x4},
64     {ScalarType::Bits4x2, kTensorDTypeBits4x2},
65     {ScalarType::Bits8, kTensorDTypeBits8},
66     {ScalarType::Bits16, kTensorDTypeBits16},
67 };
68 
69 const std::unordered_map<int, ScalarType> java_dtype_to_scalar_type = {
70     {kTensorDTypeUInt8, ScalarType::Byte},
71     {kTensorDTypeInt8, ScalarType::Char},
72     {kTensorDTypeInt16, ScalarType::Short},
73     {kTensorDTypeInt32, ScalarType::Int},
74     {kTensorDTypeInt64, ScalarType::Long},
75     {kTensorDTypeHalf, ScalarType::Half},
76     {kTensorDTypeFloat, ScalarType::Float},
77     {kTensorDTypeDouble, ScalarType::Double},
78     // These types are not supported yet
79     // {kTensorDTypeComplexHalf, ScalarType::ComplexHalf},
80     // {kTensorDTypeComplexFloat, ScalarType::ComplexFloat},
81     // {kTensorDTypeComplexDouble, ScalarType::ComplexDouble},
82     {kTensorDTypeBool, ScalarType::Bool},
83     {kTensorDTypeQint8, ScalarType::QInt8},
84     {kTensorDTypeQuint8, ScalarType::QUInt8},
85     {kTensorDTypeQint32, ScalarType::QInt32},
86     {kTensorDTypeBFloat16, ScalarType::BFloat16},
87     {kTensorDTypeQuint4x2, ScalarType::QUInt4x2},
88     {kTensorDTypeQuint2x4, ScalarType::QUInt2x4},
89     {kTensorDTypeBits1x8, ScalarType::Bits1x8},
90     {kTensorDTypeBits2x4, ScalarType::Bits2x4},
91     {kTensorDTypeBits4x2, ScalarType::Bits4x2},
92     {kTensorDTypeBits8, ScalarType::Bits8},
93     {kTensorDTypeBits16, ScalarType::Bits16},
94 };
95 
96 } // namespace executorch::extension
97