1 #pragma once
2
#include <c10/util/BFloat16.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>
#include <c10/util/bits.h>
#include <c10/util/complex.h>
#include <c10/util/qint32.h>
#include <c10/util/qint8.h>
#include <c10/util/quint2x4.h>
#include <c10/util/quint4x2.h>
#include <c10/util/quint8.h>

#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <ostream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
26
27 namespace c10 {
28
// Empty placeholder element type for the UInt1..UInt7 dtypes; NumBits is the
// bit width (dummy_uint1_7_t<1> backs UInt1, ..., dummy_uint1_7_t<7> backs
// UInt7). It deliberately carries no data or operations: the real behavior of
// these dtypes is implemented in Python through a Tensor subclass.
template <unsigned int NumBits>
struct dummy_uint1_7_t {};
33
34 // For the macros below:
35 //
// For users: If you want to expand some code for all non-QInt scalar types
// (i.e. types with complete information), you probably want one of the
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are
// designed to behave similarly to the Dispatch macros with the same name.
40 //
41 // For adding a new dtype: In the beginning, we had an idea that there was a
42 // list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to
// iterate over them. But over the years we added weird types which couldn't
// be handled uniformly everywhere, and so we ended up with a mishmash of
// helper macros, with most use sites deciding for themselves which dtypes
// they can or can't support. So if you want to add a new dtype,
47 // the preferred resolution is to find a dtype similar to what you want,
48 // grep for it and edit all the sites you find this way. If you need to add
49 // a completely new kind of dtype, you're going to have to laboriously audit
50 // all of the sites everywhere to figure out how it should work. Consulting
51 // some old PRs where we added new dtypes (check history of this file) can
52 // help give you an idea where to start.
53
// Master X-macro: invokes `_(cpp_type, Name)` once for every dtype known to
// c10, as (C++ element type, ScalarType enumerator name) pairs. This header
// expands it to generate the ScalarType enumerators, the
// ScalarType <-> C++ type mappings, the k<Name> constants, toString() and
// elementSize(). The trailing /* N */ comments record the enum value each
// entry receives.
//
// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format. Existing entries must
// therefore not be reordered.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
  _(uint8_t, Byte) /* 0 */                               \
  _(int8_t, Char) /* 1 */                                \
  _(int16_t, Short) /* 2 */                              \
  _(int, Int) /* 3 */                                    \
  _(int64_t, Long) /* 4 */                               \
  _(at::Half, Half) /* 5 */                              \
  _(float, Float) /* 6 */                                \
  _(double, Double) /* 7 */                              \
  _(c10::complex<c10::Half>, ComplexHalf) /* 8 */        \
  _(c10::complex<float>, ComplexFloat) /* 9 */           \
  _(c10::complex<double>, ComplexDouble) /* 10 */        \
  _(bool, Bool) /* 11 */                                 \
  _(c10::qint8, QInt8) /* 12 */                          \
  _(c10::quint8, QUInt8) /* 13 */                        \
  _(c10::qint32, QInt32) /* 14 */                        \
  _(at::BFloat16, BFloat16) /* 15 */                     \
  _(c10::quint4x2, QUInt4x2) /* 16 */                    \
  _(c10::quint2x4, QUInt2x4) /* 17 */                    \
  _(c10::bits1x8, Bits1x8) /* 18 */                      \
  _(c10::bits2x4, Bits2x4) /* 19 */                      \
  _(c10::bits4x2, Bits4x2) /* 20 */                      \
  _(c10::bits8, Bits8) /* 21 */                          \
  _(c10::bits16, Bits16) /* 22 */                        \
  _(c10::Float8_e5m2, Float8_e5m2) /* 23 */              \
  _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */          \
  _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */      \
  _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */      \
  _(uint16_t, UInt16) /* 27 */                           \
  _(uint32_t, UInt32) /* 28 */                           \
  _(uint64_t, UInt64) /* 29 */                           \
  _(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */             \
  _(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */             \
  _(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */             \
  _(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */             \
  _(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */             \
  _(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */             \
  _(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */
94
// Like AT_FORALL_SCALAR_TYPES_WITH_COMPLEX, but without ComplexHalf and
// without the *fnuz Float8 variants (Float8_e5m2fnuz / Float8_e4m3fnuz);
// presumably "F8NZ" in the name refers to those fnuz types.
//
// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...
//
// TODO: To add unsigned int types here, we must define accumulate type.
// But uint8 currently accumulates into int64, so we would have to make
// an inconsistent choice for the larger types. Difficult.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \
  _(uint8_t, Byte)                                                      \
  _(int8_t, Char)                                                       \
  _(int16_t, Short)                                                     \
  _(int, Int)                                                           \
  _(int64_t, Long)                                                      \
  _(at::Half, Half)                                                     \
  _(float, Float)                                                       \
  _(double, Double)                                                     \
  _(c10::complex<float>, ComplexFloat)                                  \
  _(c10::complex<double>, ComplexDouble)                                \
  _(bool, Bool)                                                         \
  _(at::BFloat16, BFloat16)                                             \
  _(at::Float8_e5m2, Float8_e5m2)                                       \
  _(at::Float8_e4m3fn, Float8_e4m3fn)
117
// This macro controls many of our C++ APIs, including constructors
// for Scalar as well as the data() and item() accessors on Tensor.
//
// Covers the "complete" numeric dtypes: integers up to int64 (no
// UInt16/32/64), all floating types including every Float8 variant, all
// complex types, and Bool. Quantized, Bits* and UInt1..UInt7 dtypes are not
// included. Adding an entry here adds API surface, so be deliberate.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
  _(uint8_t, Byte)                             \
  _(int8_t, Char)                              \
  _(int16_t, Short)                            \
  _(int, Int)                                  \
  _(int64_t, Long)                             \
  _(at::Half, Half)                            \
  _(float, Float)                              \
  _(double, Double)                            \
  _(c10::complex<c10::Half>, ComplexHalf)      \
  _(c10::complex<float>, ComplexFloat)         \
  _(c10::complex<double>, ComplexDouble)       \
  _(bool, Bool)                                \
  _(at::BFloat16, BFloat16)                    \
  _(at::Float8_e5m2, Float8_e5m2)              \
  _(at::Float8_e4m3fn, Float8_e4m3fn)          \
  _(at::Float8_e5m2fnuz, Float8_e5m2fnuz)      \
  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz)
138
139 enum class ScalarType : int8_t {
140 #define DEFINE_ST_ENUM_VAL_(_1, n) n,
141 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_)
142 #undef DEFINE_ENUM_ST_ENUM_VAL_
143 Undefined,
144 NumOptions
145 };
146
147 constexpr uint16_t NumScalarTypes =
148 static_cast<uint16_t>(ScalarType::NumOptions);
149
namespace impl {

// These are used to map ScalarTypes to C++ types.
//
// ScalarTypeToCPPType<N>::type is the C++ element type for dtype N. It is
// specialized below for every entry of
// AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS. The primary template is
// only declared, so instantiating it for any other value (e.g. Undefined)
// fails to compile.
template <c10::ScalarType N>
struct ScalarTypeToCPPType;

// Each specialization also declares (but never defines) a static member `t`
// of the mapped type. It exists only to be named inside decltype(...), which
// is how the AT_FORALL_SCALAR_TYPES_AND* macros in this header recover the
// C++ type; see the workaround note inside the macro.
#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type)                \
  template <>                                                                \
  struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> {                 \
    using type = cpp_type;                                                   \
                                                                             \
    /* This is a workaround for the CUDA bug which prevents */               \
    /* ::detail::ScalarTypeToCType<T>::type being used directly due to */    \
    /* ambiguous reference which can't be resolved. For some reason it */    \
    /* can't pick between at::detail and at::cuda::detail. */                \
    /* For repro example, please see: */                                     \
    /* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */     \
    /* TODO: remove once the bug is fixed. */                                \
    static type t;                                                           \
  };

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)

#undef SPECIALIZE_ScalarTypeToCPPType

// Shorthand alias: ScalarTypeToCPPTypeT<ScalarType::Float> is float.
template <c10::ScalarType N>
using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;

} // namespace impl
180
181 template <typename T>
182 struct CppTypeToScalarType;
183
184 #define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
185 template <> \
186 struct CppTypeToScalarType<cpp_type> \
187 : std:: \
188 integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
189 };
190
191 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
192
193 #undef SPECIALIZE_CppTypeToScalarType
194
// NB: despite its generic sounding name, the macros that don't take _AND
// are mostly only used by tensorexpr
//
// AT_FORALL_INT_TYPES: the signed integers plus Byte (uint8_t). It does not
// include Bool or UInt16/UInt32/UInt64.
#define AT_FORALL_INT_TYPES(_) \
  _(uint8_t, Byte)             \
  _(int8_t, Char)              \
  _(int16_t, Short)            \
  _(int, Int)                  \
  _(int64_t, Long)

// AT_FORALL_SCALAR_TYPES: AT_FORALL_INT_TYPES plus Float and Double. There is
// no Half/BFloat16/Bool/complex; use the _AND variants below to add those.
#define AT_FORALL_SCALAR_TYPES(_) \
  _(uint8_t, Byte)                \
  _(int8_t, Char)                 \
  _(int16_t, Short)               \
  _(int, Int)                     \
  _(int64_t, Long)                \
  _(float, Float)                 \
  _(double, Double)
212
// These macros are often controlling how many template instantiations we
// create for kernels. It is typically inappropriate to add new dtypes here,
// instead, new types should be added to use sites on a case-by-case basis.
// We generally are not accepting new dtypes due to binary size concerns.
//
// AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _): the seven AT_FORALL_SCALAR_TYPES
// entries followed by one extra dtype, given by its ScalarType enumerator
// name (e.g. Half). The extra entry's C++ type is recovered via
// decltype(impl::ScalarTypeToCPPType<...>::t).
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
  _(uint8_t, Byte)                                \
  _(int8_t, Char)                                 \
  _(int16_t, Short)                               \
  _(int, Int)                                     \
  _(int64_t, Long)                                \
  _(float, Float)                                 \
  _(double, Double)                               \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE>::t),  \
    SCALARTYPE)
229
// As AT_FORALL_SCALAR_TYPES_AND, with two extra dtypes appended in argument
// order.
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
  _(uint8_t, Byte)                                               \
  _(int8_t, Char)                                                \
  _(int16_t, Short)                                              \
  _(int, Int)                                                    \
  _(int64_t, Long)                                               \
  _(float, Float)                                                \
  _(double, Double)                                              \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                   \
             ::c10::ScalarType::SCALARTYPE1>::t),                \
    SCALARTYPE1)                                                 \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                   \
             ::c10::ScalarType::SCALARTYPE2>::t),                \
    SCALARTYPE2)
244
// As AT_FORALL_SCALAR_TYPES_AND, with three extra dtypes appended in argument
// order.
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
  _(uint8_t, Byte)                                                            \
  _(int8_t, Char)                                                             \
  _(int16_t, Short)                                                           \
  _(int, Int)                                                                 \
  _(int64_t, Long)                                                            \
  _(float, Float)                                                             \
  _(double, Double)                                                           \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                                \
             ::c10::ScalarType::SCALARTYPE1>::t),                             \
    SCALARTYPE1)                                                              \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                                \
             ::c10::ScalarType::SCALARTYPE2>::t),                             \
    SCALARTYPE2)                                                              \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                                \
             ::c10::ScalarType::SCALARTYPE3>::t),                             \
    SCALARTYPE3)
262
// As AT_FORALL_SCALAR_TYPES_AND, with seven extra dtypes appended in argument
// order.
#define AT_FORALL_SCALAR_TYPES_AND7(              \
    SCALARTYPE1,                                  \
    SCALARTYPE2,                                  \
    SCALARTYPE3,                                  \
    SCALARTYPE4,                                  \
    SCALARTYPE5,                                  \
    SCALARTYPE6,                                  \
    SCALARTYPE7,                                  \
    _)                                            \
  _(uint8_t, Byte)                                \
  _(int8_t, Char)                                 \
  _(int16_t, Short)                               \
  _(int, Int)                                     \
  _(int64_t, Long)                                \
  _(float, Float)                                 \
  _(double, Double)                               \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE1>::t), \
    SCALARTYPE1)                                  \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE2>::t), \
    SCALARTYPE2)                                  \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE3>::t), \
    SCALARTYPE3)                                  \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE4>::t), \
    SCALARTYPE4)                                  \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE5>::t), \
    SCALARTYPE5)                                  \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE6>::t), \
    SCALARTYPE6)                                  \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE7>::t), \
    SCALARTYPE7)
300
// Invokes `_(cpp_type, Name)` for each quantized integer dtype. Keep in sync
// with isQIntType() in this header.
#define AT_FORALL_QINT_TYPES(_) \
  _(c10::qint8, QInt8)          \
  _(c10::quint8, QUInt8)        \
  _(c10::qint32, QInt32)        \
  _(c10::quint4x2, QUInt4x2)    \
  _(c10::quint2x4, QUInt2x4)
307
// Invokes `_(cpp_type, Name)` for ComplexFloat and ComplexDouble only;
// ComplexHalf is intentionally not part of this list.
#define AT_FORALL_COMPLEX_TYPES(_)     \
  _(c10::complex<float>, ComplexFloat) \
  _(c10::complex<double>, ComplexDouble)
311
312 #define DEFINE_CONSTANT(_, name) \
313 constexpr ScalarType k##name = ScalarType::name;
314
315 // NOLINTNEXTLINE(clang-diagnostic-unused-const-variable)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)316 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
317 #undef DEFINE_CONSTANT
318
319 inline const char* toString(ScalarType t) {
320 #define DEFINE_CASE(_, name) \
321 case ScalarType::name: \
322 return #name;
323
324 switch (t) {
325 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
326 default:
327 return "UNKNOWN_SCALAR";
328 }
329 #undef DEFINE_CASE
330 }
331
elementSize(ScalarType t)332 inline size_t elementSize(ScalarType t) {
333 #define CASE_ELEMENTSIZE_CASE(ctype, name) \
334 case ScalarType::name: \
335 return sizeof(ctype);
336
337 switch (t) {
338 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
339 default:
340 TORCH_CHECK(false, "Unknown ScalarType");
341 }
342 #undef CASE_ELEMENTSIZE_CASE
343 }
344
isIntegralType(ScalarType t,bool includeBool)345 inline bool isIntegralType(ScalarType t, bool includeBool) {
346 bool isIntegral =
347 (t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
348 t == ScalarType::Long || t == ScalarType::Short ||
349 t == ScalarType::UInt16 || t == ScalarType::UInt32 ||
350 t == ScalarType::UInt64);
351
352 return isIntegral || (includeBool && t == ScalarType::Bool);
353 }
354
355 C10_DEPRECATED_MESSAGE(
356 "isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.")
isIntegralType(ScalarType t)357 inline bool isIntegralType(ScalarType t) {
358 return isIntegralType(t, /*includeBool=*/false);
359 }
360
isFloat8Type(ScalarType t)361 inline bool isFloat8Type(ScalarType t) {
362 return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e5m2fnuz ||
363 t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz;
364 }
365
isReducedFloatingType(ScalarType t)366 inline bool isReducedFloatingType(ScalarType t) {
367 return t == ScalarType::Half || t == ScalarType::BFloat16 || isFloat8Type(t);
368 }
369
isFloatingType(ScalarType t)370 inline bool isFloatingType(ScalarType t) {
371 return t == ScalarType::Double || t == ScalarType::Float ||
372 isReducedFloatingType(t);
373 }
374
isComplexType(ScalarType t)375 inline bool isComplexType(ScalarType t) {
376 return (
377 t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat ||
378 t == ScalarType::ComplexDouble);
379 }
380
isQIntType(ScalarType t)381 inline bool isQIntType(ScalarType t) {
382 // Don't forget to extend this when adding new QInt types
383 return t == ScalarType::QInt8 || t == ScalarType::QUInt8 ||
384 t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 ||
385 t == ScalarType::QUInt2x4;
386 }
387
isBitsType(ScalarType t)388 inline bool isBitsType(ScalarType t) {
389 return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 ||
390 t == ScalarType::Bits4x2 || t == ScalarType::Bits8 ||
391 t == ScalarType::Bits16;
392 }
393
isBarebonesUnsignedType(ScalarType t)394 inline bool isBarebonesUnsignedType(ScalarType t) {
395 return t == ScalarType::UInt1 || t == ScalarType::UInt2 ||
396 t == ScalarType::UInt3 || t == ScalarType::UInt4 ||
397 t == ScalarType::UInt5 || t == ScalarType::UInt6 ||
398 t == ScalarType::UInt7 || t == ScalarType::UInt16 ||
399 t == ScalarType::UInt32 || t == ScalarType::UInt64;
400 }
401
toQIntType(ScalarType t)402 inline ScalarType toQIntType(ScalarType t) {
403 switch (t) {
404 case ScalarType::Byte:
405 return ScalarType::QUInt8;
406 case ScalarType::Char:
407 return ScalarType::QInt8;
408 case ScalarType::Int:
409 return ScalarType::QInt32;
410 default:
411 return t;
412 }
413 }
414
toUnderlying(ScalarType t)415 inline ScalarType toUnderlying(ScalarType t) {
416 switch (t) {
417 case ScalarType::QUInt8:
418 case ScalarType::QUInt4x2:
419 [[fallthrough]];
420 case ScalarType::QUInt2x4:
421 return ScalarType::Byte;
422 case ScalarType::QInt8:
423 return ScalarType::Char;
424 case ScalarType::QInt32:
425 return ScalarType::Int;
426 default:
427 return t;
428 }
429 }
430
isSignedType(ScalarType t)431 inline bool isSignedType(ScalarType t) {
432 #define CASE_ISSIGNED(name) \
433 case ScalarType::name: \
434 return std::numeric_limits< \
435 ::c10::impl::ScalarTypeToCPPTypeT<ScalarType::name>>::is_signed;
436
437 switch (t) {
438 case ScalarType::QInt8:
439 case ScalarType::QUInt8:
440 case ScalarType::QInt32:
441 case ScalarType::QUInt4x2:
442 case ScalarType::QUInt2x4:
443 TORCH_CHECK(false, "isSignedType not supported for quantized types");
444 case ScalarType::Bits1x8:
445 case ScalarType::Bits2x4:
446 case ScalarType::Bits4x2:
447 case ScalarType::Bits8:
448 case ScalarType::Bits16:
449 TORCH_CHECK(false, "Bits types are undefined");
450 CASE_ISSIGNED(UInt16);
451 CASE_ISSIGNED(UInt32);
452 CASE_ISSIGNED(UInt64);
453 CASE_ISSIGNED(BFloat16);
454 CASE_ISSIGNED(Float8_e5m2);
455 CASE_ISSIGNED(Float8_e5m2fnuz);
456 CASE_ISSIGNED(Float8_e4m3fn);
457 CASE_ISSIGNED(Float8_e4m3fnuz);
458 CASE_ISSIGNED(Byte);
459 CASE_ISSIGNED(Char);
460 CASE_ISSIGNED(Short);
461 CASE_ISSIGNED(Int);
462 CASE_ISSIGNED(Long);
463 CASE_ISSIGNED(Half);
464 CASE_ISSIGNED(Float);
465 CASE_ISSIGNED(Double);
466 CASE_ISSIGNED(ComplexHalf);
467 CASE_ISSIGNED(ComplexFloat);
468 CASE_ISSIGNED(ComplexDouble);
469 CASE_ISSIGNED(Bool);
470 case ScalarType::UInt1:
471 case ScalarType::UInt2:
472 case ScalarType::UInt3:
473 case ScalarType::UInt4:
474 case ScalarType::UInt5:
475 case ScalarType::UInt6:
476 case ScalarType::UInt7:
477 return true;
478 case ScalarType::Undefined:
479 case ScalarType::NumOptions:
480 break;
481 // Do not add default here, but rather define behavior of every new entry
482 // here. `-Wswitch-enum` would raise a warning in those cases.
483 }
484 TORCH_CHECK(false, "Unknown ScalarType ", t);
485 #undef CASE_ISSIGNED
486 }
487
isUnderlying(ScalarType type,ScalarType qtype)488 inline bool isUnderlying(ScalarType type, ScalarType qtype) {
489 return type == toUnderlying(qtype);
490 }
491
toRealValueType(ScalarType t)492 inline ScalarType toRealValueType(ScalarType t) {
493 switch (t) {
494 case ScalarType::ComplexHalf:
495 return ScalarType::Half;
496 case ScalarType::ComplexFloat:
497 return ScalarType::Float;
498 case ScalarType::ComplexDouble:
499 return ScalarType::Double;
500 default:
501 return t;
502 }
503 }
504
toComplexType(ScalarType t)505 inline ScalarType toComplexType(ScalarType t) {
506 switch (t) {
507 case ScalarType::BFloat16:
508 // BFloat16 has range equivalent to Float,
509 // so we map it to ComplexFloat.
510 return ScalarType::ComplexFloat;
511 case ScalarType::Half:
512 return ScalarType::ComplexHalf;
513 case ScalarType::Float:
514 return ScalarType::ComplexFloat;
515 case ScalarType::Double:
516 return ScalarType::ComplexDouble;
517 case ScalarType::ComplexHalf:
518 return ScalarType::ComplexHalf;
519 case ScalarType::ComplexFloat:
520 return ScalarType::ComplexFloat;
521 case ScalarType::ComplexDouble:
522 return ScalarType::ComplexDouble;
523 default:
524 TORCH_CHECK(false, "Unknown Complex ScalarType for ", t);
525 }
526 }
527
528 // see tensor_attributes.rst for detailed explanation and examples
529 // of casting rules.
canCast(const ScalarType from,const ScalarType to)530 inline bool canCast(const ScalarType from, const ScalarType to) {
531 // We disallow complex -> non complex, e.g., float_tensor *= complex is
532 // disallowed.
533 if (isComplexType(from) && !isComplexType(to)) {
534 return false;
535 }
536 // We disallow float -> integral, e.g., int_tensor *= float is disallowed.
537 if (isFloatingType(from) && isIntegralType(to, false)) {
538 return false;
539 }
540
541 // Treat bool as a distinct "category," to be consistent with type promotion
542 // rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same
543 // category as `bool_tensor`, we would not promote. Differing categories
544 // implies `bool_tensor += 5` is disallowed.
545 //
546 // NB: numpy distinguishes "unsigned" as a category to get the desired
547 // `bool_tensor + 5 -> int64_tensor` behavior. We don't, because:
548 // * We don't want the performance hit of checking the runtime sign of
549 // Scalars.
550 // * `uint8_tensor + 5 -> int64_tensor` would be undesirable.
551 if (from != ScalarType::Bool && to == ScalarType::Bool) {
552 return false;
553 }
554 return true;
555 }
556
// Declared here and defined out of line (exported via C10_API). Presumably
// returns the dtype produced by type promotion of `a` and `b`.
// NOTE(review): per the comment on
// AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS, the _promoteTypesLookup
// table it relies on depends on that macro's entry order.
C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);
558
559 inline std::ostream& operator<<(
560 std::ostream& stream,
561 at::ScalarType scalar_type) {
562 return stream << toString(scalar_type);
563 }
564
565 // Returns a pair of strings representing the names for each dtype.
566 // The returned pair is (name, legacy_name_if_applicable)
567 C10_API std::pair<std::string, std::string> getDtypeNames(
568 c10::ScalarType scalarType);
569
570 // Returns a map of string name to dtype.
571 C10_API const std::unordered_map<std::string, ScalarType>& getStringToDtypeMap();
572
573 } // namespace c10
574