xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cudnn/Types.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/cudnn/Types.h>
2 
3 #include <ATen/ATen.h>
4 
5 namespace at::native {
6 
getCudnnDataTypeFromScalarType(const at::ScalarType dtype)7 cudnnDataType_t getCudnnDataTypeFromScalarType(const at::ScalarType dtype) {
8   if (dtype == c10::kQInt8) {
9     return CUDNN_DATA_INT8;
10   } else if (dtype == at::kFloat) {
11     return CUDNN_DATA_FLOAT;
12   } else if (dtype == at::kDouble) {
13     return CUDNN_DATA_DOUBLE;
14   } else if (dtype == at::kHalf) {
15     return CUDNN_DATA_HALF;
16   } else if (dtype == at::kBFloat16) {
17     return CUDNN_DATA_BFLOAT16;
18   } else if (dtype == at::kInt) {
19     return CUDNN_DATA_INT32;
20   } else if (dtype == at::kByte) {
21     return CUDNN_DATA_UINT8;
22   } else if (dtype == at::kChar) {
23     return CUDNN_DATA_INT8;
24   }
25   std::string msg("getCudnnDataTypeFromScalarType() not supported for ");
26   msg += toString(dtype);
27   throw std::runtime_error(msg);
28 }
29 
getCudnnDataType(const at::Tensor & tensor)30 cudnnDataType_t getCudnnDataType(const at::Tensor& tensor) {
31   return getCudnnDataTypeFromScalarType(tensor.scalar_type());
32 }
33 
cudnn_version()34 int64_t cudnn_version() {
35   return CUDNN_VERSION;
36 }
37 
38 } // namespace at::native
39