xref: /aosp_15_r20/external/pytorch/c10/core/ScalarType.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/BFloat16.h>
4 #include <c10/util/Deprecated.h>
5 #include <c10/util/Exception.h>
6 #include <c10/util/Float8_e4m3fn.h>
7 #include <c10/util/Float8_e4m3fnuz.h>
8 #include <c10/util/Float8_e5m2.h>
9 #include <c10/util/Float8_e5m2fnuz.h>
10 #include <c10/util/Half.h>
11 #include <c10/util/bits.h>
12 #include <c10/util/complex.h>
13 #include <c10/util/qint32.h>
14 #include <c10/util/qint8.h>
15 #include <c10/util/quint2x4.h>
16 #include <c10/util/quint4x2.h>
17 #include <c10/util/quint8.h>
18 
19 #include <array>
20 #include <cstddef>
21 #include <cstdint>
22 #include <limits>
23 #include <ostream>
24 #include <type_traits>
25 #include <unordered_map>
26 
27 namespace c10 {
28 
// Empty placeholder used as the C++ element type of the sub-byte unsigned
// dtypes UInt1 .. UInt7; the template argument is the bit width (1 to 7).
// It intentionally carries no data or operations: the actual functionality
// of these dtypes is implemented in Python with a Tensor subclass.
template <unsigned Bits>
struct dummy_uint1_7_t {};
33 
// For the macros below:
//
// For users: If you want to generate code with a macro for all non-QInt
// scalar types (i.e. types with complete information), you probably want one
// of the AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below,
// which are designed to behave similarly to the Dispatch macros with the same
// name.
//
// For adding a new dtype: In the beginning, we had an idea that there was a
// list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to
// iterate over them.  But over the years we added weird types which couldn't
// be handled uniformly everywhere, and so we ended up with a mishmash of
// helper macros, with most use sites making their own call about which
// dtypes they can or can't support.  So if you want to add a new dtype,
// the preferred approach is to find a dtype similar to what you want,
// grep for it, and edit all the sites you find this way.  If you need to add
// a completely new kind of dtype, you're going to have to laboriously audit
// all of the sites everywhere to figure out how it should work.  Consulting
// some old PRs where we added new dtypes (check the history of this file) can
// help give you an idea of where to start.
53 
// Master X-macro over every dtype: invokes _(cpp_type, EnumName) once per
// entry. Because the ScalarType enum is generated from this list without
// explicit values, the index written in front of each entry is that
// enumerator's numeric value.
//
// NB: the order is load-bearing -- it is relied upon by _promoteTypesLookup
// and by the serialization format -- so existing entries must never be
// reordered.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
  /*  0 */ _(uint8_t, Byte)                              \
  /*  1 */ _(int8_t, Char)                               \
  /*  2 */ _(int16_t, Short)                             \
  /*  3 */ _(int, Int)                                   \
  /*  4 */ _(int64_t, Long)                              \
  /*  5 */ _(at::Half, Half)                             \
  /*  6 */ _(float, Float)                               \
  /*  7 */ _(double, Double)                             \
  /*  8 */ _(c10::complex<c10::Half>, ComplexHalf)       \
  /*  9 */ _(c10::complex<float>, ComplexFloat)          \
  /* 10 */ _(c10::complex<double>, ComplexDouble)        \
  /* 11 */ _(bool, Bool)                                 \
  /* 12 */ _(c10::qint8, QInt8)                          \
  /* 13 */ _(c10::quint8, QUInt8)                        \
  /* 14 */ _(c10::qint32, QInt32)                        \
  /* 15 */ _(at::BFloat16, BFloat16)                     \
  /* 16 */ _(c10::quint4x2, QUInt4x2)                    \
  /* 17 */ _(c10::quint2x4, QUInt2x4)                    \
  /* 18 */ _(c10::bits1x8, Bits1x8)                      \
  /* 19 */ _(c10::bits2x4, Bits2x4)                      \
  /* 20 */ _(c10::bits4x2, Bits4x2)                      \
  /* 21 */ _(c10::bits8, Bits8)                          \
  /* 22 */ _(c10::bits16, Bits16)                        \
  /* 23 */ _(c10::Float8_e5m2, Float8_e5m2)              \
  /* 24 */ _(c10::Float8_e4m3fn, Float8_e4m3fn)          \
  /* 25 */ _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz)      \
  /* 26 */ _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz)      \
  /* 27 */ _(uint16_t, UInt16)                           \
  /* 28 */ _(uint32_t, UInt32)                           \
  /* 29 */ _(uint64_t, UInt64)                           \
  /* 30 */ _(c10::dummy_uint1_7_t<1>, UInt1)             \
  /* 31 */ _(c10::dummy_uint1_7_t<2>, UInt2)             \
  /* 32 */ _(c10::dummy_uint1_7_t<3>, UInt3)             \
  /* 33 */ _(c10::dummy_uint1_7_t<4>, UInt4)             \
  /* 34 */ _(c10::dummy_uint1_7_t<5>, UInt5)             \
  /* 35 */ _(c10::dummy_uint1_7_t<6>, UInt6)             \
  /* 36 */ _(c10::dummy_uint1_7_t<7>, UInt7)
94 
// The AT_FORALL_SCALAR_TYPES_WITH_COMPLEX list without ComplexHalf and
// without the two fnuz Float8 dtypes (presumably the "F8NZ" in the name).
//
// To support ComplexHalf here for real, add it to this list (and rename the
// macro) -- but beware: convert() does not handle every conversion you will
// need.
//
// TODO: Adding the wider unsigned integer types here requires defining an
// accumulate type for them. uint8 currently accumulates into int64, so we
// would have to make an inconsistent choice for the larger types. Difficult.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \
  _(uint8_t, Byte)                                                      \
  _(int8_t, Char)                                                       \
  _(int16_t, Short)                                                     \
  _(int, Int)                                                           \
  _(int64_t, Long)                                                      \
  _(at::Half, Half)                                                     \
  _(float, Float)                                                       \
  _(double, Double)                                                     \
  _(c10::complex<float>, ComplexFloat)                                  \
  _(c10::complex<double>, ComplexDouble)                                \
  _(bool, Bool)                                                         \
  _(at::BFloat16, BFloat16)                                             \
  _(at::Float8_e5m2, Float8_e5m2)                                       \
  _(at::Float8_e4m3fn, Float8_e4m3fn)
117 
// Dtypes with first-class C++ API support. This list controls many of our
// C++ APIs, including the Scalar constructors and the data() / item()
// accessors on Tensor. It covers every dtype except the quantized, bits,
// UInt16/UInt32/UInt64 and UInt1..UInt7 ones.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_)  \
  _(uint8_t, Byte)                              \
  _(int8_t, Char)                               \
  _(int16_t, Short)                             \
  _(int, Int)                                   \
  _(int64_t, Long)                              \
  _(at::Half, Half)                             \
  _(float, Float)                               \
  _(double, Double)                             \
  _(c10::complex<c10::Half>, ComplexHalf)       \
  _(c10::complex<float>, ComplexFloat)          \
  _(c10::complex<double>, ComplexDouble)        \
  _(bool, Bool)                                 \
  _(at::BFloat16, BFloat16)                     \
  _(at::Float8_e5m2, Float8_e5m2)               \
  _(at::Float8_e4m3fn, Float8_e4m3fn)           \
  _(at::Float8_e5m2fnuz, Float8_e5m2fnuz)       \
  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz)
138 
139 enum class ScalarType : int8_t {
140 #define DEFINE_ST_ENUM_VAL_(_1, n) n,
141   AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_)
142 #undef DEFINE_ENUM_ST_ENUM_VAL_
143       Undefined,
144   NumOptions
145 };
146 
147 constexpr uint16_t NumScalarTypes =
148     static_cast<uint16_t>(ScalarType::NumOptions);
149 
150 namespace impl {
151 
152 // These are used to map ScalarTypes to C++ types.
153 
154 template <c10::ScalarType N>
155 struct ScalarTypeToCPPType;
156 
157 #define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type)                \
158   template <>                                                                \
159   struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> {                 \
160     using type = cpp_type;                                                   \
161                                                                              \
162     /* This is a workaround for the CUDA bug which prevents */               \
163     /* ::detail::ScalarTypeToCType<T>::type being used directly due to */    \
164     /* ambiguous reference which can't to be resolved. For some reason it */ \
165     /* can't pick between at::detail and at::cuda::detail. */                \
166     /* For repro example, please see: */                                     \
167     /* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */    \
168     /* TODO: remove once the bug is fixed. */                                \
169     static type t;                                                           \
170   };
171 
172 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
173 
174 #undef SPECIALIZE_ScalarTypeToCPPType
175 
176 template <c10::ScalarType N>
177 using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;
178 
179 } // namespace impl
180 
181 template <typename T>
182 struct CppTypeToScalarType;
183 
184 #define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type)                  \
185   template <>                                                                  \
186   struct CppTypeToScalarType<cpp_type>                                         \
187       : std::                                                                  \
188             integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
189   };
190 
191 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
192 
193 #undef SPECIALIZE_CppTypeToScalarType
194 
// NB: despite their generic-sounding names, the macros without an _AND
// suffix are mostly only used by tensorexpr.

// Byte, Char, Short, Int and Long (UInt16/UInt32/UInt64 are not included).
#define AT_FORALL_INT_TYPES(_)  \
  _(uint8_t, Byte)              \
  _(int8_t, Char)               \
  _(int16_t, Short)             \
  _(int, Int)                   \
  _(int64_t, Long)

// The AT_FORALL_INT_TYPES entries followed by Float and Double.
#define AT_FORALL_SCALAR_TYPES(_)  \
  _(uint8_t, Byte)                 \
  _(int8_t, Char)                  \
  _(int16_t, Short)                \
  _(int, Int)                      \
  _(int64_t, Long)                 \
  _(float, Float)                  \
  _(double, Double)
212 
// These macros often control how many template instantiations we create for
// kernels, so it is typically inappropriate to add new dtypes to them.
// Instead, new types should be added at use sites on a case-by-case basis;
// we generally are not accepting new dtypes here due to binary size concerns.
//
// Each AT_FORALL_SCALAR_TYPES_AND* variant expands to the seven
// AT_FORALL_SCALAR_TYPES entries (Byte, Char, Short, Int, Long, Float,
// Double), followed by one entry per extra dtype given by enum name. The C++
// type of an extra dtype is recovered as
// decltype(::c10::impl::ScalarTypeToCPPType<...>::t).

#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _)                                          \
  _(uint8_t, Byte)                                                                         \
  _(int8_t, Char)                                                                          \
  _(int16_t, Short)                                                                        \
  _(int, Int)                                                                              \
  _(int64_t, Long)                                                                         \
  _(float, Float)                                                                          \
  _(double, Double)                                                                        \
  _(decltype(::c10::impl::ScalarTypeToCPPType< ::c10::ScalarType::SCALARTYPE>::t), SCALARTYPE)

#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _)                             \
  _(uint8_t, Byte)                                                                           \
  _(int8_t, Char)                                                                            \
  _(int16_t, Short)                                                                          \
  _(int, Int)                                                                                \
  _(int64_t, Long)                                                                           \
  _(float, Float)                                                                            \
  _(double, Double)                                                                          \
  _(decltype(::c10::impl::ScalarTypeToCPPType< ::c10::ScalarType::SCALARTYPE1>::t), SCALARTYPE1) \
  _(decltype(::c10::impl::ScalarTypeToCPPType< ::c10::ScalarType::SCALARTYPE2>::t), SCALARTYPE2)

#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _)                \
  _(uint8_t, Byte)                                                                           \
  _(int8_t, Char)                                                                            \
  _(int16_t, Short)                                                                          \
  _(int, Int)                                                                                \
  _(int64_t, Long)                                                                           \
  _(float, Float)                                                                            \
  _(double, Double)                                                                          \
  _(decltype(::c10::impl::ScalarTypeToCPPType< ::c10::ScalarType::SCALARTYPE1>::t), SCALARTYPE1) \
  _(decltype(::c10::impl::ScalarTypeToCPPType< ::c10::ScalarType::SCALARTYPE2>::t), SCALARTYPE2) \
  _(decltype(::c10::impl::ScalarTypeToCPPType< ::c10::ScalarType::SCALARTYPE3>::t), SCALARTYPE3)

#define AT_FORALL_SCALAR_TYPES_AND7(                                                         \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4,                                      \
    SCALARTYPE5, SCALARTYPE6, SCALARTYPE7, _)                                                \
  _(uint8_t, Byte)                                                                           \
  _(int8_t, Char)                                                                            \
  _(int16_t, Short)                                                                          \
  _(int, Int)                                                                                \
  _(int64_t, Long)                                                                           \
  _(float, Float)                                                                            \
  _(double, Double)                                                                          \
  _(decltype(::c10::impl::ScalarTypeToCPPType< ::c10::ScalarType::SCALARTYPE1>::t), SCALARTYPE1) \
  _(decltype(::c10::impl::ScalarTypeToCPPType< ::c10::ScalarType::SCALARTYPE2>::t), SCALARTYPE2) \
  _(decltype(::c10::impl::ScalarTypeToCPPType< ::c10::ScalarType::SCALARTYPE3>::t), SCALARTYPE3) \
  _(decltype(::c10::impl::ScalarTypeToCPPType< ::c10::ScalarType::SCALARTYPE4>::t), SCALARTYPE4) \
  _(decltype(::c10::impl::ScalarTypeToCPPType< ::c10::ScalarType::SCALARTYPE5>::t), SCALARTYPE5) \
  _(decltype(::c10::impl::ScalarTypeToCPPType< ::c10::ScalarType::SCALARTYPE6>::t), SCALARTYPE6) \
  _(decltype(::c10::impl::ScalarTypeToCPPType< ::c10::ScalarType::SCALARTYPE7>::t), SCALARTYPE7)
300 
// Every quantized dtype.
#define AT_FORALL_QINT_TYPES(_)  \
  _(c10::qint8, QInt8)           \
  _(c10::quint8, QUInt8)         \
  _(c10::qint32, QInt32)         \
  _(c10::quint4x2, QUInt4x2)     \
  _(c10::quint2x4, QUInt2x4)

// The complex dtypes, excluding ComplexHalf.
#define AT_FORALL_COMPLEX_TYPES(_)      \
  _(c10::complex<float>, ComplexFloat)  \
  _(c10::complex<double>, ComplexDouble)
311 
312 #define DEFINE_CONSTANT(_, name) \
313   constexpr ScalarType k##name = ScalarType::name;
314 
315 // NOLINTNEXTLINE(clang-diagnostic-unused-const-variable)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)316 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
317 #undef DEFINE_CONSTANT
318 
319 inline const char* toString(ScalarType t) {
320 #define DEFINE_CASE(_, name) \
321   case ScalarType::name:     \
322     return #name;
323 
324   switch (t) {
325     AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
326     default:
327       return "UNKNOWN_SCALAR";
328   }
329 #undef DEFINE_CASE
330 }
331 
elementSize(ScalarType t)332 inline size_t elementSize(ScalarType t) {
333 #define CASE_ELEMENTSIZE_CASE(ctype, name) \
334   case ScalarType::name:                   \
335     return sizeof(ctype);
336 
337   switch (t) {
338     AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
339     default:
340       TORCH_CHECK(false, "Unknown ScalarType");
341   }
342 #undef CASE_ELEMENTSIZE_CASE
343 }
344 
isIntegralType(ScalarType t,bool includeBool)345 inline bool isIntegralType(ScalarType t, bool includeBool) {
346   bool isIntegral =
347       (t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
348        t == ScalarType::Long || t == ScalarType::Short ||
349        t == ScalarType::UInt16 || t == ScalarType::UInt32 ||
350        t == ScalarType::UInt64);
351 
352   return isIntegral || (includeBool && t == ScalarType::Bool);
353 }
354 
355 C10_DEPRECATED_MESSAGE(
356     "isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.")
isIntegralType(ScalarType t)357 inline bool isIntegralType(ScalarType t) {
358   return isIntegralType(t, /*includeBool=*/false);
359 }
360 
isFloat8Type(ScalarType t)361 inline bool isFloat8Type(ScalarType t) {
362   return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e5m2fnuz ||
363       t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz;
364 }
365 
isReducedFloatingType(ScalarType t)366 inline bool isReducedFloatingType(ScalarType t) {
367   return t == ScalarType::Half || t == ScalarType::BFloat16 || isFloat8Type(t);
368 }
369 
isFloatingType(ScalarType t)370 inline bool isFloatingType(ScalarType t) {
371   return t == ScalarType::Double || t == ScalarType::Float ||
372       isReducedFloatingType(t);
373 }
374 
isComplexType(ScalarType t)375 inline bool isComplexType(ScalarType t) {
376   return (
377       t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat ||
378       t == ScalarType::ComplexDouble);
379 }
380 
isQIntType(ScalarType t)381 inline bool isQIntType(ScalarType t) {
382   // Don't forget to extend this when adding new QInt types
383   return t == ScalarType::QInt8 || t == ScalarType::QUInt8 ||
384       t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 ||
385       t == ScalarType::QUInt2x4;
386 }
387 
isBitsType(ScalarType t)388 inline bool isBitsType(ScalarType t) {
389   return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 ||
390       t == ScalarType::Bits4x2 || t == ScalarType::Bits8 ||
391       t == ScalarType::Bits16;
392 }
393 
isBarebonesUnsignedType(ScalarType t)394 inline bool isBarebonesUnsignedType(ScalarType t) {
395   return t == ScalarType::UInt1 || t == ScalarType::UInt2 ||
396       t == ScalarType::UInt3 || t == ScalarType::UInt4 ||
397       t == ScalarType::UInt5 || t == ScalarType::UInt6 ||
398       t == ScalarType::UInt7 || t == ScalarType::UInt16 ||
399       t == ScalarType::UInt32 || t == ScalarType::UInt64;
400 }
401 
toQIntType(ScalarType t)402 inline ScalarType toQIntType(ScalarType t) {
403   switch (t) {
404     case ScalarType::Byte:
405       return ScalarType::QUInt8;
406     case ScalarType::Char:
407       return ScalarType::QInt8;
408     case ScalarType::Int:
409       return ScalarType::QInt32;
410     default:
411       return t;
412   }
413 }
414 
toUnderlying(ScalarType t)415 inline ScalarType toUnderlying(ScalarType t) {
416   switch (t) {
417     case ScalarType::QUInt8:
418     case ScalarType::QUInt4x2:
419       [[fallthrough]];
420     case ScalarType::QUInt2x4:
421       return ScalarType::Byte;
422     case ScalarType::QInt8:
423       return ScalarType::Char;
424     case ScalarType::QInt32:
425       return ScalarType::Int;
426     default:
427       return t;
428   }
429 }
430 
isSignedType(ScalarType t)431 inline bool isSignedType(ScalarType t) {
432 #define CASE_ISSIGNED(name)     \
433   case ScalarType::name:        \
434     return std::numeric_limits< \
435         ::c10::impl::ScalarTypeToCPPTypeT<ScalarType::name>>::is_signed;
436 
437   switch (t) {
438     case ScalarType::QInt8:
439     case ScalarType::QUInt8:
440     case ScalarType::QInt32:
441     case ScalarType::QUInt4x2:
442     case ScalarType::QUInt2x4:
443       TORCH_CHECK(false, "isSignedType not supported for quantized types");
444     case ScalarType::Bits1x8:
445     case ScalarType::Bits2x4:
446     case ScalarType::Bits4x2:
447     case ScalarType::Bits8:
448     case ScalarType::Bits16:
449       TORCH_CHECK(false, "Bits types are undefined");
450       CASE_ISSIGNED(UInt16);
451       CASE_ISSIGNED(UInt32);
452       CASE_ISSIGNED(UInt64);
453       CASE_ISSIGNED(BFloat16);
454       CASE_ISSIGNED(Float8_e5m2);
455       CASE_ISSIGNED(Float8_e5m2fnuz);
456       CASE_ISSIGNED(Float8_e4m3fn);
457       CASE_ISSIGNED(Float8_e4m3fnuz);
458       CASE_ISSIGNED(Byte);
459       CASE_ISSIGNED(Char);
460       CASE_ISSIGNED(Short);
461       CASE_ISSIGNED(Int);
462       CASE_ISSIGNED(Long);
463       CASE_ISSIGNED(Half);
464       CASE_ISSIGNED(Float);
465       CASE_ISSIGNED(Double);
466       CASE_ISSIGNED(ComplexHalf);
467       CASE_ISSIGNED(ComplexFloat);
468       CASE_ISSIGNED(ComplexDouble);
469       CASE_ISSIGNED(Bool);
470     case ScalarType::UInt1:
471     case ScalarType::UInt2:
472     case ScalarType::UInt3:
473     case ScalarType::UInt4:
474     case ScalarType::UInt5:
475     case ScalarType::UInt6:
476     case ScalarType::UInt7:
477       return true;
478     case ScalarType::Undefined:
479     case ScalarType::NumOptions:
480       break;
481       // Do not add default here, but rather define behavior of every new entry
482       // here.  `-Wswitch-enum` would raise a warning in those cases.
483   }
484   TORCH_CHECK(false, "Unknown ScalarType ", t);
485 #undef CASE_ISSIGNED
486 }
487 
isUnderlying(ScalarType type,ScalarType qtype)488 inline bool isUnderlying(ScalarType type, ScalarType qtype) {
489   return type == toUnderlying(qtype);
490 }
491 
toRealValueType(ScalarType t)492 inline ScalarType toRealValueType(ScalarType t) {
493   switch (t) {
494     case ScalarType::ComplexHalf:
495       return ScalarType::Half;
496     case ScalarType::ComplexFloat:
497       return ScalarType::Float;
498     case ScalarType::ComplexDouble:
499       return ScalarType::Double;
500     default:
501       return t;
502   }
503 }
504 
toComplexType(ScalarType t)505 inline ScalarType toComplexType(ScalarType t) {
506   switch (t) {
507     case ScalarType::BFloat16:
508       // BFloat16 has range equivalent to Float,
509       // so we map it to ComplexFloat.
510       return ScalarType::ComplexFloat;
511     case ScalarType::Half:
512       return ScalarType::ComplexHalf;
513     case ScalarType::Float:
514       return ScalarType::ComplexFloat;
515     case ScalarType::Double:
516       return ScalarType::ComplexDouble;
517     case ScalarType::ComplexHalf:
518       return ScalarType::ComplexHalf;
519     case ScalarType::ComplexFloat:
520       return ScalarType::ComplexFloat;
521     case ScalarType::ComplexDouble:
522       return ScalarType::ComplexDouble;
523     default:
524       TORCH_CHECK(false, "Unknown Complex ScalarType for ", t);
525   }
526 }
527 
528 // see tensor_attributes.rst for detailed explanation and examples
529 // of casting rules.
canCast(const ScalarType from,const ScalarType to)530 inline bool canCast(const ScalarType from, const ScalarType to) {
531   // We disallow complex -> non complex, e.g., float_tensor *= complex is
532   // disallowed.
533   if (isComplexType(from) && !isComplexType(to)) {
534     return false;
535   }
536   // We disallow float -> integral, e.g., int_tensor *= float is disallowed.
537   if (isFloatingType(from) && isIntegralType(to, false)) {
538     return false;
539   }
540 
541   // Treat bool as a distinct "category," to be consistent with type promotion
542   // rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same
543   // category as `bool_tensor`, we would not promote. Differing categories
544   // implies `bool_tensor += 5` is disallowed.
545   //
546   // NB: numpy distinguishes "unsigned" as a category to get the desired
547   // `bool_tensor + 5 -> int64_tensor` behavior. We don't, because:
548   // * We don't want the performance hit of checking the runtime sign of
549   // Scalars.
550   // * `uint8_tensor + 5 -> int64_tensor` would be undesirable.
551   if (from != ScalarType::Bool && to == ScalarType::Bool) {
552     return false;
553   }
554   return true;
555 }
556 
557 C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);
558 
559 inline std::ostream& operator<<(
560     std::ostream& stream,
561     at::ScalarType scalar_type) {
562   return stream << toString(scalar_type);
563 }
564 
565 // Returns a pair of strings representing the names for each dtype.
566 // The returned pair is (name, legacy_name_if_applicable)
567 C10_API std::pair<std::string, std::string> getDtypeNames(
568     c10::ScalarType scalarType);
569 
570 // Returns a map of string name to dtype.
571 C10_API const std::unordered_map<std::string, ScalarType>& getStringToDtypeMap();
572 
573 } // namespace c10
574