1 #pragma once 2 3 #include <c10/core/ScalarType.h> 4 5 #include <cuda.h> 6 #include <library_types.h> 7 8 namespace at::cuda { 9 10 template <typename scalar_t> getCudaDataType()11cudaDataType getCudaDataType() { 12 static_assert(false && sizeof(scalar_t), "Cannot convert type to cudaDataType."); 13 return {}; 14 } 15 16 template<> inline cudaDataType getCudaDataType<at::Half>() { 17 return CUDA_R_16F; 18 } 19 template<> inline cudaDataType getCudaDataType<float>() { 20 return CUDA_R_32F; 21 } 22 template<> inline cudaDataType getCudaDataType<double>() { 23 return CUDA_R_64F; 24 } 25 template<> inline cudaDataType getCudaDataType<c10::complex<c10::Half>>() { 26 return CUDA_C_16F; 27 } 28 template<> inline cudaDataType getCudaDataType<c10::complex<float>>() { 29 return CUDA_C_32F; 30 } 31 template<> inline cudaDataType getCudaDataType<c10::complex<double>>() { 32 return CUDA_C_64F; 33 } 34 35 template<> inline cudaDataType getCudaDataType<uint8_t>() { 36 return CUDA_R_8U; 37 } 38 template<> inline cudaDataType getCudaDataType<int8_t>() { 39 return CUDA_R_8I; 40 } 41 template<> inline cudaDataType getCudaDataType<int>() { 42 return CUDA_R_32I; 43 } 44 45 template<> inline cudaDataType getCudaDataType<int16_t>() { 46 return CUDA_R_16I; 47 } 48 template<> inline cudaDataType getCudaDataType<int64_t>() { 49 return CUDA_R_64I; 50 } 51 template<> inline cudaDataType getCudaDataType<at::BFloat16>() { 52 return CUDA_R_16BF; 53 } 54 ScalarTypeToCudaDataType(const c10::ScalarType & scalar_type)55inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) { 56 switch (scalar_type) { 57 case c10::ScalarType::Byte: 58 return CUDA_R_8U; 59 case c10::ScalarType::Char: 60 return CUDA_R_8I; 61 case c10::ScalarType::Int: 62 return CUDA_R_32I; 63 case c10::ScalarType::Half: 64 return CUDA_R_16F; 65 case c10::ScalarType::Float: 66 return CUDA_R_32F; 67 case c10::ScalarType::Double: 68 return CUDA_R_64F; 69 case c10::ScalarType::ComplexHalf: 70 return CUDA_C_16F; 71 case c10::ScalarType::ComplexFloat: 72 return CUDA_C_32F; 73 case c10::ScalarType::ComplexDouble: 74 return CUDA_C_64F; 75 case c10::ScalarType::Short: 76 return CUDA_R_16I; 77 case c10::ScalarType::Long: 78 return CUDA_R_64I; 79 case c10::ScalarType::BFloat16: 80 return CUDA_R_16BF; 81 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 82 case c10::ScalarType::Float8_e4m3fn: 83 return CUDA_R_8F_E4M3; 84 case c10::ScalarType::Float8_e5m2: 85 return CUDA_R_8F_E5M2; 86 #endif 87 #if defined(USE_ROCM) 88 #if defined(HIP_NEW_TYPE_ENUMS) 89 case c10::ScalarType::Float8_e4m3fnuz: 90 return HIP_R_8F_E4M3_FNUZ; 91 case c10::ScalarType::Float8_e5m2fnuz: 92 return HIP_R_8F_E5M2_FNUZ; 93 #else 94 case c10::ScalarType::Float8_e4m3fnuz: 95 return static_cast<hipDataType>(1000); 96 case c10::ScalarType::Float8_e5m2fnuz: 97 return static_cast<hipDataType>(1001); 98 #endif 99 #endif 100 default: 101 TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.") 102 } 103 } 104 105 } // namespace at::cuda 106