1 #include <c10/core/ScalarType.h>
2 #include <c10/util/Array.h>
3 #include <array>
4
5 namespace c10 {
6
7 namespace {
8
9 constexpr auto u1 = ScalarType::Byte;
10 constexpr auto i1 = ScalarType::Char;
11 constexpr auto i2 = ScalarType::Short;
12 constexpr auto i4 = ScalarType::Int;
13 constexpr auto i8 = ScalarType::Long;
14 constexpr auto f2 = ScalarType::Half;
15 constexpr auto f4 = ScalarType::Float;
16 constexpr auto f8 = ScalarType::Double;
17 constexpr auto c2 = ScalarType::ComplexHalf;
18 constexpr auto c4 = ScalarType::ComplexFloat;
19 constexpr auto c8 = ScalarType::ComplexDouble;
20 constexpr auto b1 = ScalarType::Bool;
21 constexpr auto bf = ScalarType::BFloat16;
22 constexpr auto ud = ScalarType::Undefined;
23
24 constexpr auto index2dtype = array_of<
25 c10::ScalarType>(u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf);
26
27 constexpr std::array<int64_t, static_cast<size_t>(ScalarType::NumOptions)>
calculate_dtype2index()28 calculate_dtype2index() {
29 std::array<int64_t, static_cast<size_t>(ScalarType::NumOptions)> inverse = {};
30 for (int64_t i = 0; i < static_cast<int64_t>(ScalarType::NumOptions); i++) {
31 inverse[i] = -1;
32 }
33 for (int64_t i = 0; i < static_cast<int64_t>(index2dtype.size()); i++) {
34 inverse[static_cast<int64_t>(index2dtype[i])] = i;
35 }
36 return inverse;
37 }
38
39 constexpr auto dtype2index = calculate_dtype2index();
40
41 } // anonymous namespace
42
promoteTypes(ScalarType a,ScalarType b)43 ScalarType promoteTypes(ScalarType a, ScalarType b) {
44 // This is generated according to NumPy's promote_types
45 if (a == ud || b == ud) {
46 return ScalarType::Undefined;
47 }
48
49 // If the two types are equal, return that type
50 if (a == b) {
51 return a;
52 }
53
54 // Handle identically equal types
55 if (isQIntType(a) || isQIntType(b)) {
56 TORCH_CHECK(
57 false,
58 "promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: ",
59 toString(a),
60 " ",
61 toString(b));
62 }
63
64 if (isBitsType(a) || isBitsType(b)) {
65 return ScalarType::Undefined;
66 }
67
68 if (isFloat8Type(a) || isFloat8Type(b)) {
69 TORCH_CHECK(
70 false,
71 "Promotion for Float8 Types is not supported, attempted to promote ",
72 toString(a),
73 " and ",
74 toString(b));
75 }
76
77 if (isBarebonesUnsignedType(a) || isBarebonesUnsignedType(b)) {
78 // There are two problems with promotion here:
79 //
80 // - Our promotion rule for uint8 is inconsistent with Numpy; Numpy
81 // promotes to uint64, but since we never had uint64 for the longest
82 // time, we promote to int64. Changing this is BC-breaking
83 //
84 // - We must not promote uint64 to int64 because this will overflow.
85 //
86 // It'll be a bit of work to fix it, so we're punting on it for now.
87 // However, float promotion is fine, so we handle that.
88 if (isFloatingType(a)) {
89 return a;
90 }
91 if (isFloatingType(b)) {
92 return b;
93 }
94 TORCH_CHECK(
95 false,
96 "Promotion for uint16, uint32, uint64 types is not supported, attempted to promote ",
97 toString(a),
98 " and ",
99 toString(b));
100 }
101
102 auto ix_a = dtype2index[static_cast<int64_t>(a)];
103 TORCH_INTERNAL_ASSERT(ix_a != -1);
104 auto ix_b = dtype2index[static_cast<int64_t>(b)];
105 TORCH_INTERNAL_ASSERT(ix_b != -1);
106
107 // This table axes must be consistent with index2dtype
108 // clang-format off
109 static constexpr std::
110 array<std::array<ScalarType, index2dtype.size()>, index2dtype.size()>
111 _promoteTypesLookup = {{
112 /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 bf*/
113 /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, bf},
114 /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, bf},
115 /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, bf},
116 /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, bf},
117 /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, bf},
118 /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, f4},
119 /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, f4},
120 /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, f8},
121 /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, c4},
122 /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, c4},
123 /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8},
124 /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf},
125 /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, bf},
126 }};
127 // clang-format on
128 return _promoteTypesLookup[ix_a][ix_b];
129 }
130
getDtypeNames(c10::ScalarType scalarType)131 std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
132 switch (scalarType) {
133 case c10::ScalarType::UInt1:
134 return std::make_pair("uint1", "bit");
135 case c10::ScalarType::UInt2:
136 return std::make_pair("uint2", "");
137 case c10::ScalarType::UInt3:
138 return std::make_pair("uint3", "");
139 case c10::ScalarType::UInt4:
140 return std::make_pair("uint4", "");
141 case c10::ScalarType::UInt5:
142 return std::make_pair("uint5", "");
143 case c10::ScalarType::UInt6:
144 return std::make_pair("uint6", "");
145 case c10::ScalarType::UInt7:
146 return std::make_pair("uint7", "");
147 case c10::ScalarType::Byte:
148 // no "byte" because byte is signed in numpy and we overload
149 // byte to mean bool often
150 return std::make_pair("uint8", "");
151 case c10::ScalarType::UInt16:
152 return std::make_pair("uint16", "");
153 case c10::ScalarType::UInt32:
154 return std::make_pair("uint32", "");
155 case c10::ScalarType::UInt64:
156 return std::make_pair("uint64", "");
157 case c10::ScalarType::Char:
158 // no "char" because it is not consistently signed or unsigned; we want
159 // to move to int8
160 return std::make_pair("int8", "");
161 case c10::ScalarType::Double:
162 return std::make_pair("float64", "double");
163 case c10::ScalarType::Float:
164 return std::make_pair("float32", "float");
165 case c10::ScalarType::Int:
166 return std::make_pair("int32", "int");
167 case c10::ScalarType::Long:
168 return std::make_pair("int64", "long");
169 case c10::ScalarType::Short:
170 return std::make_pair("int16", "short");
171 case c10::ScalarType::Half:
172 return std::make_pair("float16", "half");
173 case c10::ScalarType::ComplexHalf:
174 return std::make_pair("complex32", "chalf");
175 case c10::ScalarType::ComplexFloat:
176 return std::make_pair("complex64", "cfloat");
177 case c10::ScalarType::ComplexDouble:
178 return std::make_pair("complex128", "cdouble");
179 case c10::ScalarType::Bool:
180 return std::make_pair("bool", "");
181 case c10::ScalarType::QInt8:
182 return std::make_pair("qint8", "");
183 case c10::ScalarType::QUInt8:
184 return std::make_pair("quint8", "");
185 case c10::ScalarType::QInt32:
186 return std::make_pair("qint32", "");
187 case c10::ScalarType::BFloat16:
188 return std::make_pair("bfloat16", "");
189 case c10::ScalarType::QUInt4x2:
190 return std::make_pair("quint4x2", "");
191 case c10::ScalarType::QUInt2x4:
192 return std::make_pair("quint2x4", "");
193 case c10::ScalarType::Bits1x8:
194 return std::make_pair("bits1x8", "");
195 case c10::ScalarType::Bits2x4:
196 return std::make_pair("bits2x4", "");
197 case c10::ScalarType::Bits4x2:
198 return std::make_pair("bits4x2", "");
199 case c10::ScalarType::Bits8:
200 return std::make_pair("bits8", "");
201 case c10::ScalarType::Bits16:
202 return std::make_pair("bits16", "");
203 case c10::ScalarType::Float8_e5m2:
204 return std::make_pair("float8_e5m2", "");
205 case c10::ScalarType::Float8_e4m3fn:
206 return std::make_pair("float8_e4m3fn", "");
207 case c10::ScalarType::Float8_e5m2fnuz:
208 return std::make_pair("float8_e5m2fnuz", "");
209 case c10::ScalarType::Float8_e4m3fnuz:
210 return std::make_pair("float8_e4m3fnuz", "");
211 default:
212 throw std::runtime_error("Unimplemented scalar type");
213 }
214 }
215
getStringToDtypeMap()216 const std::unordered_map<std::string, ScalarType>& getStringToDtypeMap() {
217 static std::unordered_map<std::string, ScalarType> result;
218 if (!result.empty()) {
219 return result;
220 }
221
222 #define DEFINE_SCALAR_TYPE(_1, n) c10::ScalarType::n,
223
224 auto all_scalar_types = {
225 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
226
227 #undef DEFINE_SCALAR_TYPE
228
229 for (auto scalar_type : all_scalar_types) {
230 auto names = getDtypeNames(scalar_type);
231 result[std::get<0>(names)] = scalar_type;
232 if (!std::get<1>(names).empty()) {
233 result[std::get<1>(names)] = scalar_type;
234 }
235 }
236 return result;
237 }
238
239 } // namespace c10
240