1 #pragma once 2 3 #include <c10/core/ScalarType.h> 4 #include <c10/util/BFloat16.h> 5 #include <c10/util/Exception.h> 6 #include <c10/util/Float8_e4m3fn.h> 7 #include <c10/util/Float8_e4m3fnuz.h> 8 #include <c10/util/Float8_e5m2.h> 9 #include <c10/util/Float8_e5m2fnuz.h> 10 #include <c10/util/Half.h> 11 12 namespace at { 13 14 // For FP16 or BFloat16 inputs, ops should perform internal math in FP32. 15 template <typename scalar_t> 16 struct OpMathType { 17 using type = scalar_t; 18 }; 19 template <> 20 struct OpMathType<at::Half> { 21 using type = float; 22 }; 23 template <> 24 struct OpMathType<at::BFloat16> { 25 using type = float; 26 }; 27 template <> 28 struct OpMathType<at::Float8_e5m2> { 29 using type = float; 30 }; 31 template <> 32 struct OpMathType<at::Float8_e4m3fn> { 33 using type = float; 34 }; 35 template <> 36 struct OpMathType<at::Float8_e5m2fnuz> { 37 using type = float; 38 }; 39 template <> 40 struct OpMathType<at::Float8_e4m3fnuz> { 41 using type = float; 42 }; 43 template <> 44 struct OpMathType<c10::complex<Half>> { 45 using type = c10::complex<float>; 46 }; 47 48 template <typename T> 49 using opmath_type = typename OpMathType<T>::type; 50 51 namespace { 52 53 inline c10::ScalarType toOpMathType(const c10::ScalarType type) { 54 switch (type) { 55 #define DEFINE_CASE(scalar_t, TypeNum) \ 56 case ScalarType::TypeNum: \ 57 return CppTypeToScalarType<at::opmath_type<scalar_t>>::value; 58 59 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE) 60 #undef DEFINE_CASE 61 62 default: 63 TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type); 64 } 65 } 66 67 } // namespace 68 69 } // namespace at 70