#pragma once
#include <ATen/Config.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>

// Defines the accumulation type for a scalar type.
// Example:
//   using accscalar_t = acc_type<scalar_t, /*is_cuda*/true>;
//
// Accumulation types are an important concept in numeric computing
// because you frequently want to perform intermediate computations
// at a higher precision than the input and output precision, to avoid
// compounding internal rounding errors.  Accumulation is the most
// well-known intermediate computation (it is of great importance for
// sum reduction and matrix multiply, for example), but in PyTorch
// acc_type ends up getting used for all sorts of other intermediate
// computations, so it perhaps would be more accurately (ahem) called an
// "accurate" type.  acc_type is especially important for reduced
// precision operations like float16 and bfloat16, where benign-looking
// inputs can easily end up overflowing/underflowing.
//
// acc_type is parametrized by whether or not you are running on CUDA,
// because double precision operations are expensive on CUDA, so by
// default we don't want to use double as an acc_type there.  A lot of
// things are typed out below, but basically, the table is generated by
// a few rules:
//
//  If bool:
//      Use 'bool' as acc_type.
//  If floating point:
//      If CUDA, use 'float' as acc_type (unless scalar_t is double),
//      otherwise (CPU) use 'double'.  Reduced-precision types (Half,
//      BFloat16 and the Float8 variants) accumulate in 'float' on CPU
//      as well.
//  If integral:
//      Use 'int64_t' as acc_type
//
// You're not forced to use this template; if you happen to know
// something specific about your use case, you can specify your own
// desired behavior.  This template, however, will give you a reasonable
// default that will work for all dtypes supported in PyTorch.
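//
// For concreteness, this is how the template resolves for a few common
// dtypes (derived from the specializations below):
//
//   acc_type<at::Half, /*is_cuda=*/true>   -> float
//   acc_type<at::Half, /*is_cuda=*/false>  -> float
//   acc_type<float,    /*is_cuda=*/true>   -> float
//   acc_type<float,    /*is_cuda=*/false>  -> double
//   acc_type<int32_t,  /*is_cuda=*/true>   -> int64_t
//   acc_type<bool,     /*is_cuda=*/false>  -> bool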

#if defined(__CUDACC__)
#include <cuda.h>
#include <cuda_fp16.h>
#elif defined(__HIPCC__)
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#endif

namespace at {

template <typename T, c10::DeviceType D>
struct AccumulateTypeDevice {};

template <typename T, bool>
struct AccumulateType {};

template <typename T>
struct AccumulateType<T, false> {
  using type = typename AccumulateTypeDevice<T, c10::DeviceType::CPU>::type;
};

template <typename T>
struct AccumulateType<T, true> {
  using type = typename AccumulateTypeDevice<T, c10::DeviceType::CUDA>::type;
};

template <typename T, c10::DeviceType device>
using acc_type_device = typename AccumulateTypeDevice<T, device>::type;

template <typename T, bool is_cuda>
using acc_type = typename AccumulateType<T, is_cuda>::type;
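
// A minimal usage sketch of the intended pattern (sum_kernel_cpu is a
// hypothetical helper, not part of this header): accumulate in acc_type
// and cast the result back to the I/O dtype only at the end.
//
//   template <typename scalar_t>
//   scalar_t sum_kernel_cpu(const scalar_t* data, int64_t n) {
//     using acc_t = at::acc_type<scalar_t, /*is_cuda=*/false>;
//     acc_t sum = acc_t(0);
//     for (int64_t i = 0; i < n; ++i) {
//       sum += static_cast<acc_t>(data[i]);
//     }
//     return static_cast<scalar_t>(sum);
//   }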

#define ACC_TYPE(t, acc_t, device_type)         \
  template <>                                   \
  struct AccumulateTypeDevice<t, device_type> { \
    using type = acc_t;                         \
  };
#define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS)
#define XPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::XPU)
#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)
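
// For reference, an invocation such as CPU_ACC_TYPE(Half, float) below
// expands to (roughly):
//
//   template <>
//   struct AccumulateTypeDevice<Half, c10::DeviceType::CPU> {
//     using type = float;
//   };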

MPS_ACC_TYPE(BFloat16, float);
MPS_ACC_TYPE(Half, float);
MPS_ACC_TYPE(Float8_e5m2, float);
MPS_ACC_TYPE(Float8_e4m3fn, float);
MPS_ACC_TYPE(Float8_e5m2fnuz, float);
MPS_ACC_TYPE(Float8_e4m3fnuz, float);
MPS_ACC_TYPE(float, float);
MPS_ACC_TYPE(double, float);
MPS_ACC_TYPE(int8_t, int64_t);
MPS_ACC_TYPE(uint8_t, int64_t);
MPS_ACC_TYPE(char, int64_t);
MPS_ACC_TYPE(int16_t, int64_t);
MPS_ACC_TYPE(int32_t, int64_t);
MPS_ACC_TYPE(int64_t, int64_t);
MPS_ACC_TYPE(bool, bool);
MPS_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
MPS_ACC_TYPE(c10::complex<float>, c10::complex<float>);
MPS_ACC_TYPE(c10::complex<double>, c10::complex<float>);

XPU_ACC_TYPE(BFloat16, float);
XPU_ACC_TYPE(Half, float);
XPU_ACC_TYPE(Float8_e5m2, float);
XPU_ACC_TYPE(Float8_e4m3fn, float);
XPU_ACC_TYPE(Float8_e5m2fnuz, float);
XPU_ACC_TYPE(Float8_e4m3fnuz, float);
XPU_ACC_TYPE(float, float);
XPU_ACC_TYPE(double, double);
XPU_ACC_TYPE(int8_t, int64_t);
XPU_ACC_TYPE(uint8_t, int64_t);
XPU_ACC_TYPE(char, int64_t);
XPU_ACC_TYPE(int16_t, int64_t);
XPU_ACC_TYPE(int32_t, int64_t);
XPU_ACC_TYPE(int64_t, int64_t);
XPU_ACC_TYPE(bool, bool);
XPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
XPU_ACC_TYPE(c10::complex<float>, c10::complex<float>);
XPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);

#if defined(__CUDACC__) || defined(__HIPCC__)
CUDA_ACC_TYPE(half, float);
#endif
CUDA_ACC_TYPE(BFloat16, float);
CUDA_ACC_TYPE(Half, float);
CUDA_ACC_TYPE(Float8_e5m2, float);
CUDA_ACC_TYPE(Float8_e4m3fn, float);
CUDA_ACC_TYPE(Float8_e5m2fnuz, float);
CUDA_ACC_TYPE(Float8_e4m3fnuz, float);
CUDA_ACC_TYPE(float, float);
CUDA_ACC_TYPE(double, double);
CUDA_ACC_TYPE(int8_t, int64_t);
CUDA_ACC_TYPE(uint8_t, int64_t);
CUDA_ACC_TYPE(char, int64_t);
CUDA_ACC_TYPE(int16_t, int64_t);
CUDA_ACC_TYPE(int32_t, int64_t);
CUDA_ACC_TYPE(int64_t, int64_t);
CUDA_ACC_TYPE(bool, bool);
CUDA_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
CUDA_ACC_TYPE(c10::complex<float>, c10::complex<float>);
CUDA_ACC_TYPE(c10::complex<double>, c10::complex<double>);

CPU_ACC_TYPE(BFloat16, float);
CPU_ACC_TYPE(Half, float);
CPU_ACC_TYPE(Float8_e5m2, float);
CPU_ACC_TYPE(Float8_e4m3fn, float);
CPU_ACC_TYPE(Float8_e5m2fnuz, float);
CPU_ACC_TYPE(Float8_e4m3fnuz, float);
CPU_ACC_TYPE(float, double);
CPU_ACC_TYPE(double, double);
CPU_ACC_TYPE(int8_t, int64_t);
CPU_ACC_TYPE(uint8_t, int64_t);
CPU_ACC_TYPE(char, int64_t);
CPU_ACC_TYPE(int16_t, int64_t);
CPU_ACC_TYPE(int32_t, int64_t);
CPU_ACC_TYPE(int64_t, int64_t);
CPU_ACC_TYPE(bool, bool);
CPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
CPU_ACC_TYPE(c10::complex<float>, c10::complex<double>);
CPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);

TORCH_API c10::ScalarType toAccumulateType(
    c10::ScalarType type,
    c10::DeviceType device);
TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda);
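
// toAccumulateType is the runtime counterpart of acc_type, for code where the
// dtype is only known at runtime.  A sketch of the expected mappings, assuming
// they mirror the compile-time table above:
//
//   at::toAccumulateType(at::kHalf, c10::DeviceType::CUDA);  // -> at::kFloat
//   at::toAccumulateType(at::kFloat, /*is_cuda=*/false);     // -> at::kDouble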

} // namespace at