1 #include <ATen/miopen/Types.h> 2 3 #include <ATen/ATen.h> 4 #include <miopen/version.h> 5 6 namespace at { namespace native { 7 getMiopenDataType(const at::Tensor & tensor)8miopenDataType_t getMiopenDataType(const at::Tensor& tensor) { 9 if (tensor.scalar_type() == at::kFloat) { 10 return miopenFloat; 11 } else if (tensor.scalar_type() == at::kHalf) { 12 return miopenHalf; 13 } else if (tensor.scalar_type() == at::kBFloat16) { 14 return miopenBFloat16; 15 } 16 std::string msg("getMiopenDataType() not supported for "); 17 msg += toString(tensor.scalar_type()); 18 throw std::runtime_error(msg); 19 } 20 miopen_version()21int64_t miopen_version() { 22 return (MIOPEN_VERSION_MAJOR<<8) + (MIOPEN_VERSION_MINOR<<4) + MIOPEN_VERSION_PATCH; 23 } 24 25 }} // namespace at::miopen 26