xref: /aosp_15_r20/external/pytorch/c10/core/ScalarType.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/ScalarType.h>
2 #include <c10/util/Array.h>
3 #include <array>
4 
5 namespace c10 {
6 
7 namespace {
8 
9 constexpr auto u1 = ScalarType::Byte;
10 constexpr auto i1 = ScalarType::Char;
11 constexpr auto i2 = ScalarType::Short;
12 constexpr auto i4 = ScalarType::Int;
13 constexpr auto i8 = ScalarType::Long;
14 constexpr auto f2 = ScalarType::Half;
15 constexpr auto f4 = ScalarType::Float;
16 constexpr auto f8 = ScalarType::Double;
17 constexpr auto c2 = ScalarType::ComplexHalf;
18 constexpr auto c4 = ScalarType::ComplexFloat;
19 constexpr auto c8 = ScalarType::ComplexDouble;
20 constexpr auto b1 = ScalarType::Bool;
21 constexpr auto bf = ScalarType::BFloat16;
22 constexpr auto ud = ScalarType::Undefined;
23 
24 constexpr auto index2dtype = array_of<
25     c10::ScalarType>(u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf);
26 
27 constexpr std::array<int64_t, static_cast<size_t>(ScalarType::NumOptions)>
calculate_dtype2index()28 calculate_dtype2index() {
29   std::array<int64_t, static_cast<size_t>(ScalarType::NumOptions)> inverse = {};
30   for (int64_t i = 0; i < static_cast<int64_t>(ScalarType::NumOptions); i++) {
31     inverse[i] = -1;
32   }
33   for (int64_t i = 0; i < static_cast<int64_t>(index2dtype.size()); i++) {
34     inverse[static_cast<int64_t>(index2dtype[i])] = i;
35   }
36   return inverse;
37 }
38 
39 constexpr auto dtype2index = calculate_dtype2index();
40 
41 } // anonymous namespace
42 
promoteTypes(ScalarType a,ScalarType b)43 ScalarType promoteTypes(ScalarType a, ScalarType b) {
44   // This is generated according to NumPy's promote_types
45   if (a == ud || b == ud) {
46     return ScalarType::Undefined;
47   }
48 
49   // If the two types are equal, return that type
50   if (a == b) {
51     return a;
52   }
53 
54   // Handle identically equal types
55   if (isQIntType(a) || isQIntType(b)) {
56     TORCH_CHECK(
57         false,
58         "promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: ",
59         toString(a),
60         " ",
61         toString(b));
62   }
63 
64   if (isBitsType(a) || isBitsType(b)) {
65     return ScalarType::Undefined;
66   }
67 
68   if (isFloat8Type(a) || isFloat8Type(b)) {
69     TORCH_CHECK(
70         false,
71         "Promotion for Float8 Types is not supported, attempted to promote ",
72         toString(a),
73         " and ",
74         toString(b));
75   }
76 
77   if (isBarebonesUnsignedType(a) || isBarebonesUnsignedType(b)) {
78     // There are two problems with promotion here:
79     //
80     // - Our promotion rule for uint8 is inconsistent with Numpy; Numpy
81     //   promotes to uint64, but since we never had uint64 for the longest
82     //   time, we promote to int64.  Changing this is BC-breaking
83     //
84     // - We must not promote uint64 to int64 because this will overflow.
85     //
86     // It'll be a bit of work to fix it, so we're punting on it for now.
87     // However, float promotion is fine, so we handle that.
88     if (isFloatingType(a)) {
89       return a;
90     }
91     if (isFloatingType(b)) {
92       return b;
93     }
94     TORCH_CHECK(
95         false,
96         "Promotion for uint16, uint32, uint64 types is not supported, attempted to promote ",
97         toString(a),
98         " and ",
99         toString(b));
100   }
101 
102   auto ix_a = dtype2index[static_cast<int64_t>(a)];
103   TORCH_INTERNAL_ASSERT(ix_a != -1);
104   auto ix_b = dtype2index[static_cast<int64_t>(b)];
105   TORCH_INTERNAL_ASSERT(ix_b != -1);
106 
107   // This table axes must be consistent with index2dtype
108   // clang-format off
109   static constexpr std::
110   array<std::array<ScalarType, index2dtype.size()>, index2dtype.size()>
111       _promoteTypesLookup = {{
112       /*        u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  bf*/
113       /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, bf},
114       /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, bf},
115       /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, bf},
116       /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, bf},
117       /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, bf},
118       /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, f4},
119       /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, f4},
120       /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, f8},
121       /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, c4},
122       /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, c4},
123       /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8},
124       /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf},
125       /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, bf},
126   }};
127   // clang-format on
128   return _promoteTypesLookup[ix_a][ix_b];
129 }
130 
getDtypeNames(c10::ScalarType scalarType)131 std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
132   switch (scalarType) {
133     case c10::ScalarType::UInt1:
134       return std::make_pair("uint1", "bit");
135     case c10::ScalarType::UInt2:
136       return std::make_pair("uint2", "");
137     case c10::ScalarType::UInt3:
138       return std::make_pair("uint3", "");
139     case c10::ScalarType::UInt4:
140       return std::make_pair("uint4", "");
141     case c10::ScalarType::UInt5:
142       return std::make_pair("uint5", "");
143     case c10::ScalarType::UInt6:
144       return std::make_pair("uint6", "");
145     case c10::ScalarType::UInt7:
146       return std::make_pair("uint7", "");
147     case c10::ScalarType::Byte:
148       // no "byte" because byte is signed in numpy and we overload
149       // byte to mean bool often
150       return std::make_pair("uint8", "");
151     case c10::ScalarType::UInt16:
152       return std::make_pair("uint16", "");
153     case c10::ScalarType::UInt32:
154       return std::make_pair("uint32", "");
155     case c10::ScalarType::UInt64:
156       return std::make_pair("uint64", "");
157     case c10::ScalarType::Char:
158       // no "char" because it is not consistently signed or unsigned; we want
159       // to move to int8
160       return std::make_pair("int8", "");
161     case c10::ScalarType::Double:
162       return std::make_pair("float64", "double");
163     case c10::ScalarType::Float:
164       return std::make_pair("float32", "float");
165     case c10::ScalarType::Int:
166       return std::make_pair("int32", "int");
167     case c10::ScalarType::Long:
168       return std::make_pair("int64", "long");
169     case c10::ScalarType::Short:
170       return std::make_pair("int16", "short");
171     case c10::ScalarType::Half:
172       return std::make_pair("float16", "half");
173     case c10::ScalarType::ComplexHalf:
174       return std::make_pair("complex32", "chalf");
175     case c10::ScalarType::ComplexFloat:
176       return std::make_pair("complex64", "cfloat");
177     case c10::ScalarType::ComplexDouble:
178       return std::make_pair("complex128", "cdouble");
179     case c10::ScalarType::Bool:
180       return std::make_pair("bool", "");
181     case c10::ScalarType::QInt8:
182       return std::make_pair("qint8", "");
183     case c10::ScalarType::QUInt8:
184       return std::make_pair("quint8", "");
185     case c10::ScalarType::QInt32:
186       return std::make_pair("qint32", "");
187     case c10::ScalarType::BFloat16:
188       return std::make_pair("bfloat16", "");
189     case c10::ScalarType::QUInt4x2:
190       return std::make_pair("quint4x2", "");
191     case c10::ScalarType::QUInt2x4:
192       return std::make_pair("quint2x4", "");
193     case c10::ScalarType::Bits1x8:
194       return std::make_pair("bits1x8", "");
195     case c10::ScalarType::Bits2x4:
196       return std::make_pair("bits2x4", "");
197     case c10::ScalarType::Bits4x2:
198       return std::make_pair("bits4x2", "");
199     case c10::ScalarType::Bits8:
200       return std::make_pair("bits8", "");
201     case c10::ScalarType::Bits16:
202       return std::make_pair("bits16", "");
203     case c10::ScalarType::Float8_e5m2:
204       return std::make_pair("float8_e5m2", "");
205     case c10::ScalarType::Float8_e4m3fn:
206       return std::make_pair("float8_e4m3fn", "");
207     case c10::ScalarType::Float8_e5m2fnuz:
208       return std::make_pair("float8_e5m2fnuz", "");
209     case c10::ScalarType::Float8_e4m3fnuz:
210       return std::make_pair("float8_e4m3fnuz", "");
211     default:
212       throw std::runtime_error("Unimplemented scalar type");
213   }
214 }
215 
getStringToDtypeMap()216 const std::unordered_map<std::string, ScalarType>& getStringToDtypeMap() {
217   static std::unordered_map<std::string, ScalarType> result;
218   if (!result.empty()) {
219     return result;
220   }
221 
222 #define DEFINE_SCALAR_TYPE(_1, n) c10::ScalarType::n,
223 
224   auto all_scalar_types = {
225       AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
226 
227 #undef DEFINE_SCALAR_TYPE
228 
229   for (auto scalar_type : all_scalar_types) {
230     auto names = getDtypeNames(scalar_type);
231     result[std::get<0>(names)] = scalar_type;
232     if (!std::get<1>(names).empty()) {
233       result[std::get<1>(names)] = scalar_type;
234     }
235   }
236   return result;
237 }
238 
239 } // namespace c10
240