xref: /aosp_15_r20/external/pytorch/aten/src/ATen/OpMathType.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/ScalarType.h>
4 #include <c10/util/BFloat16.h>
5 #include <c10/util/Exception.h>
6 #include <c10/util/Float8_e4m3fn.h>
7 #include <c10/util/Float8_e4m3fnuz.h>
8 #include <c10/util/Float8_e5m2.h>
9 #include <c10/util/Float8_e5m2fnuz.h>
10 #include <c10/util/Half.h>
11 
12 namespace at {
13 
14 // For FP16 or BFloat16 inputs, ops should perform internal math in FP32.
// Primary template: by default a type does its internal math in itself
// (identity mapping). Reduced-precision types override this below.
template <typename T>
struct OpMathType {
  using type = T;
};
// Reduced-precision floating-point inputs (FP16, BF16, and the FP8
// variants) accumulate in float to limit precision loss and overflow.
template <>
struct OpMathType<at::Half> {
  using type = float;
};
template <>
struct OpMathType<at::BFloat16> {
  using type = float;
};
template <>
struct OpMathType<at::Float8_e5m2> {
  using type = float;
};
template <>
struct OpMathType<at::Float8_e4m3fn> {
  using type = float;
};
template <>
struct OpMathType<at::Float8_e5m2fnuz> {
  using type = float;
};
template <>
struct OpMathType<at::Float8_e4m3fnuz> {
  using type = float;
};
// Complex half-precision likewise promotes to complex<float>.
template <>
struct OpMathType<c10::complex<Half>> {
  using type = c10::complex<float>;
};
47 
// Convenience alias: opmath_type<T> is the type an op should use for its
// internal math when operating on T values (see OpMathType above).
template <typename T>
using opmath_type = typename OpMathType<T>::type;
50 
51 namespace {
52 
53 inline c10::ScalarType toOpMathType(const c10::ScalarType type) {
54   switch (type) {
55 #define DEFINE_CASE(scalar_t, TypeNum) \
56   case ScalarType::TypeNum:            \
57     return CppTypeToScalarType<at::opmath_type<scalar_t>>::value;
58 
59     AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
60 #undef DEFINE_CASE
61 
62     default:
63       TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
64   }
65 }
66 
67 } // namespace
68 
69 } // namespace at
70