xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/CUDADataType.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/ScalarType.h>
4 
5 #include <cuda.h>
6 #include <library_types.h>
7 
8 namespace at::cuda {
9 
10 template <typename scalar_t>
getCudaDataType()11 cudaDataType getCudaDataType() {
12   static_assert(false && sizeof(scalar_t), "Cannot convert type to cudaDataType.");
13   return {};
14 }
15 
16 template<> inline cudaDataType getCudaDataType<at::Half>() {
17   return CUDA_R_16F;
18 }
19 template<> inline cudaDataType getCudaDataType<float>() {
20   return CUDA_R_32F;
21 }
22 template<> inline cudaDataType getCudaDataType<double>() {
23   return CUDA_R_64F;
24 }
25 template<> inline cudaDataType getCudaDataType<c10::complex<c10::Half>>() {
26   return CUDA_C_16F;
27 }
28 template<> inline cudaDataType getCudaDataType<c10::complex<float>>() {
29   return CUDA_C_32F;
30 }
31 template<> inline cudaDataType getCudaDataType<c10::complex<double>>() {
32   return CUDA_C_64F;
33 }
34 
35 template<> inline cudaDataType getCudaDataType<uint8_t>() {
36   return CUDA_R_8U;
37 }
38 template<> inline cudaDataType getCudaDataType<int8_t>() {
39   return CUDA_R_8I;
40 }
41 template<> inline cudaDataType getCudaDataType<int>() {
42   return CUDA_R_32I;
43 }
44 
45 template<> inline cudaDataType getCudaDataType<int16_t>() {
46   return CUDA_R_16I;
47 }
48 template<> inline cudaDataType getCudaDataType<int64_t>() {
49   return CUDA_R_64I;
50 }
51 template<> inline cudaDataType getCudaDataType<at::BFloat16>() {
52   return CUDA_R_16BF;
53 }
54 
ScalarTypeToCudaDataType(const c10::ScalarType & scalar_type)55 inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) {
56   switch (scalar_type) {
57     case c10::ScalarType::Byte:
58       return CUDA_R_8U;
59     case c10::ScalarType::Char:
60       return CUDA_R_8I;
61     case c10::ScalarType::Int:
62       return CUDA_R_32I;
63     case c10::ScalarType::Half:
64       return CUDA_R_16F;
65     case c10::ScalarType::Float:
66       return CUDA_R_32F;
67     case c10::ScalarType::Double:
68       return CUDA_R_64F;
69     case c10::ScalarType::ComplexHalf:
70       return CUDA_C_16F;
71     case c10::ScalarType::ComplexFloat:
72       return CUDA_C_32F;
73     case c10::ScalarType::ComplexDouble:
74       return CUDA_C_64F;
75     case c10::ScalarType::Short:
76       return CUDA_R_16I;
77     case c10::ScalarType::Long:
78       return CUDA_R_64I;
79     case c10::ScalarType::BFloat16:
80       return CUDA_R_16BF;
81 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
82     case c10::ScalarType::Float8_e4m3fn:
83       return CUDA_R_8F_E4M3;
84     case c10::ScalarType::Float8_e5m2:
85       return CUDA_R_8F_E5M2;
86 #endif
87 #if defined(USE_ROCM)
88 #if defined(HIP_NEW_TYPE_ENUMS)
89     case c10::ScalarType::Float8_e4m3fnuz:
90       return HIP_R_8F_E4M3_FNUZ;
91     case c10::ScalarType::Float8_e5m2fnuz:
92       return HIP_R_8F_E5M2_FNUZ;
93 #else
94     case c10::ScalarType::Float8_e4m3fnuz:
95       return static_cast<hipDataType>(1000);
96     case c10::ScalarType::Float8_e5m2fnuz:
97       return static_cast<hipDataType>(1001);
98 #endif
99 #endif
100     default:
101       TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.")
102   }
103 }
104 
105 } // namespace at::cuda
106