#pragma once
#include <ATen/Config.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>

// Defines the accumulation type for a scalar type.
// Example:
//   using accscalar_t = acc_type<scalar_t, /*is_cuda*/true>;
//
// Accumulation types are an important concept in numeric computing
// because you frequently want to perform intermediate computations
// at a higher precision than the input and output precision, to avoid
// compounding internal rounding errors. Accumulation is the most
// well-known intermediate computation (it is of great importance for
// sum reduction and matrix multiply, for example), but in PyTorch
// acc_type ends up getting used for all sorts of other intermediate
// computations, so it perhaps would be more accurately (ahem) called an
// "accurate" type. acc_type is especially important for reduced-precision
// operations like float16 and bfloat16, where relatively
// benign-looking inputs can easily end up overflowing/underflowing.
//
// acc_type is parametrized by whether you are running on CUDA, because
// double-precision operations are expensive on CUDA, so by default we
// don't want to use double as an acc_type there. Many specializations
// are spelled out below, but the table boils down to a few rules:
//
// If bool:
//   Use 'bool' as acc_type.
// If floating point:
//   If CUDA, use 'float' as acc_type (unless scalar_t is double),
//   otherwise (CPU) use 'double'. (Note that the reduced-precision
//   types Half, BFloat16, and the Float8 variants accumulate in
//   'float' on CPU as well; see the CPU table below.)
// If complex:
//   Apply the floating-point rule to the underlying real type.
// If integral:
//   Use 'int64_t' as acc_type.
//
// You're not forced to use this template; if you happen to know
// something specific about your use case, you can specify your own
// desired behavior. This template, however, will give you a reasonable
// default that will work for all dtypes supported in PyTorch.
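//
// As a hedged sketch of the intended pattern (sum_slow and its
// arguments are illustrative, not part of this header): accumulate in
// acc_type and only cast back down when writing the final result:
//
//   template <typename scalar_t>
//   scalar_t sum_slow(const scalar_t* data, int64_t n) {
//     using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/false>;
//     accscalar_t sum = 0;  // e.g. float when scalar_t is at::Half on CPU
//     for (int64_t i = 0; i < n; i++) {
//       sum += static_cast<accscalar_t>(data[i]);
//     }
//     return static_cast<scalar_t>(sum);
//   }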
#if defined(__CUDACC__)
#include <cuda.h>
#include <cuda_fp16.h>
#elif defined(__HIPCC__)
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#endif

namespace at {

template <typename T, c10::DeviceType D>
struct AccumulateTypeDevice {};

template <typename T, bool>
struct AccumulateType {};

template <typename T>
struct AccumulateType<T, false> {
  using type = typename AccumulateTypeDevice<T, c10::DeviceType::CPU>::type;
};

template <typename T>
struct AccumulateType<T, true> {
  using type = typename AccumulateTypeDevice<T, c10::DeviceType::CUDA>::type;
};

template <typename T, c10::DeviceType device>
using acc_type_device = typename AccumulateTypeDevice<T, device>::type;

template <typename T, bool is_cuda>
using acc_type = typename AccumulateType<T, is_cuda>::type;

// Stamps out the AccumulateTypeDevice specialization mapping scalar
// type `t` to accumulation type `acc_t` on `device_type`.
#define ACC_TYPE(t, acc_t, device_type)         \
  template <>                                   \
  struct AccumulateTypeDevice<t, device_type> { \
    using type = acc_t;                         \
  };
#define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS)
#define XPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::XPU)
#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)

// MPS has no float64 support, so double and c10::complex<double>
// accumulate in single precision here.
MPS_ACC_TYPE(BFloat16, float);
MPS_ACC_TYPE(Half, float);
MPS_ACC_TYPE(Float8_e5m2, float);
MPS_ACC_TYPE(Float8_e4m3fn, float);
MPS_ACC_TYPE(Float8_e5m2fnuz, float);
MPS_ACC_TYPE(Float8_e4m3fnuz, float);
MPS_ACC_TYPE(float, float);
MPS_ACC_TYPE(double, float);
MPS_ACC_TYPE(int8_t, int64_t);
MPS_ACC_TYPE(uint8_t, int64_t);
MPS_ACC_TYPE(char, int64_t);
MPS_ACC_TYPE(int16_t, int64_t);
MPS_ACC_TYPE(int32_t, int64_t);
MPS_ACC_TYPE(int64_t, int64_t);
MPS_ACC_TYPE(bool, bool);
MPS_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
MPS_ACC_TYPE(c10::complex<float>, c10::complex<float>);
MPS_ACC_TYPE(c10::complex<double>, c10::complex<float>);

XPU_ACC_TYPE(BFloat16, float);
XPU_ACC_TYPE(Half, float);
XPU_ACC_TYPE(Float8_e5m2, float);
XPU_ACC_TYPE(Float8_e4m3fn, float);
XPU_ACC_TYPE(Float8_e5m2fnuz, float);
XPU_ACC_TYPE(Float8_e4m3fnuz, float);
XPU_ACC_TYPE(float, float);
XPU_ACC_TYPE(double, double);
XPU_ACC_TYPE(int8_t, int64_t);
XPU_ACC_TYPE(uint8_t, int64_t);
XPU_ACC_TYPE(char, int64_t);
XPU_ACC_TYPE(int16_t, int64_t);
XPU_ACC_TYPE(int32_t, int64_t);
XPU_ACC_TYPE(int64_t, int64_t);
XPU_ACC_TYPE(bool, bool);
XPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
XPU_ACC_TYPE(c10::complex<float>, c10::complex<float>);
XPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);

// `half` here is the raw CUDA/HIP half type, which is only visible
// when compiling with nvcc/hipcc; c10::Half is covered below.
#if defined(__CUDACC__) || defined(__HIPCC__)
CUDA_ACC_TYPE(half, float);
#endif
CUDA_ACC_TYPE(BFloat16, float);
CUDA_ACC_TYPE(Half, float);
CUDA_ACC_TYPE(Float8_e5m2, float);
CUDA_ACC_TYPE(Float8_e4m3fn, float);
CUDA_ACC_TYPE(Float8_e5m2fnuz, float);
CUDA_ACC_TYPE(Float8_e4m3fnuz, float);
CUDA_ACC_TYPE(float, float);
CUDA_ACC_TYPE(double, double);
CUDA_ACC_TYPE(int8_t, int64_t);
CUDA_ACC_TYPE(uint8_t, int64_t);
CUDA_ACC_TYPE(char, int64_t);
CUDA_ACC_TYPE(int16_t, int64_t);
CUDA_ACC_TYPE(int32_t, int64_t);
CUDA_ACC_TYPE(int64_t, int64_t);
CUDA_ACC_TYPE(bool, bool);
CUDA_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
CUDA_ACC_TYPE(c10::complex<float>, c10::complex<float>);
CUDA_ACC_TYPE(c10::complex<double>, c10::complex<double>);
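// A quick illustration (not part of this header) of how the CUDA
// specializations above resolve. With <type_traits> included, both
// assertions hold: reduced-precision floats widen to float, while
// double stays double on CUDA:
//
//   static_assert(std::is_same_v<
//       at::acc_type_device<at::Half, c10::DeviceType::CUDA>, float>);
//   static_assert(std::is_same_v<
//       at::acc_type_device<double, c10::DeviceType::CUDA>, double>);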
CPU_ACC_TYPE(BFloat16, float);
CPU_ACC_TYPE(Half, float);
CPU_ACC_TYPE(Float8_e5m2, float);
CPU_ACC_TYPE(Float8_e4m3fn, float);
CPU_ACC_TYPE(Float8_e5m2fnuz, float);
CPU_ACC_TYPE(Float8_e4m3fnuz, float);
CPU_ACC_TYPE(float, double);
CPU_ACC_TYPE(double, double);
CPU_ACC_TYPE(int8_t, int64_t);
CPU_ACC_TYPE(uint8_t, int64_t);
CPU_ACC_TYPE(char, int64_t);
CPU_ACC_TYPE(int16_t, int64_t);
CPU_ACC_TYPE(int32_t, int64_t);
CPU_ACC_TYPE(int64_t, int64_t);
CPU_ACC_TYPE(bool, bool);
CPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
CPU_ACC_TYPE(c10::complex<float>, c10::complex<double>);
CPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);

// Runtime equivalents of the compile-time mapping above.
TORCH_API c10::ScalarType toAccumulateType(
    c10::ScalarType type,
    c10::DeviceType device);
TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda);

} // namespace at
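// A hedged usage sketch for the runtime overloads; the expected values
// follow directly from the tables in this file:
//
//   at::toAccumulateType(at::kHalf, at::kCUDA);  // == at::kFloat
//   at::toAccumulateType(at::kFloat, at::kCPU);  // == at::kDouble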