xref: /aosp_15_r20/external/pytorch/aten/src/ATen/miopen/Types.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/miopen/Types.h>
2 
3 #include <ATen/ATen.h>
4 #include <miopen/version.h>
5 
6 namespace at { namespace native {
7 
getMiopenDataType(const at::Tensor & tensor)8 miopenDataType_t getMiopenDataType(const at::Tensor& tensor) {
9   if (tensor.scalar_type() == at::kFloat) {
10     return miopenFloat;
11   } else if (tensor.scalar_type() == at::kHalf) {
12     return miopenHalf;
13   }  else if (tensor.scalar_type() == at::kBFloat16) {
14     return miopenBFloat16;
15   }
16   std::string msg("getMiopenDataType() not supported for ");
17   msg += toString(tensor.scalar_type());
18   throw std::runtime_error(msg);
19 }
20 
miopen_version()21 int64_t miopen_version() {
22   return (MIOPEN_VERSION_MAJOR<<8) + (MIOPEN_VERSION_MINOR<<4) + MIOPEN_VERSION_PATCH;
23 }
24 
25 }}  // namespace at::miopen
26