xref: /aosp_15_r20/external/pytorch/aten/src/ATen/miopen/Exceptions.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/miopen/miopen-wrapper.h>
4 #include <string>
5 #include <stdexcept>
6 #include <sstream>
7 
8 namespace at { namespace native {
9 
10 class miopen_exception : public std::runtime_error {
11 public:
12   miopenStatus_t status;
miopen_exception(miopenStatus_t status,const char * msg)13   miopen_exception(miopenStatus_t status, const char* msg)
14       : std::runtime_error(msg)
15       , status(status) {}
miopen_exception(miopenStatus_t status,const std::string & msg)16   miopen_exception(miopenStatus_t status, const std::string& msg)
17       : std::runtime_error(msg)
18       , status(status) {}
19 };
20 
MIOPEN_CHECK(miopenStatus_t status)21 inline void MIOPEN_CHECK(miopenStatus_t status)
22 {
23   if (status != miopenStatusSuccess) {
24     if (status == miopenStatusNotImplemented) {
25         throw miopen_exception(status, std::string(miopenGetErrorString(status)) +
26                 ". This error may appear if you passed in a non-contiguous input.");
27     }
28     throw miopen_exception(status, miopenGetErrorString(status));
29   }
30 }
31 
HIP_CHECK(hipError_t error)32 inline void HIP_CHECK(hipError_t error)
33 {
34   if (error != hipSuccess) {
35     std::string msg("HIP error: ");
36     msg += hipGetErrorString(error);
37     throw std::runtime_error(msg);
38   }
39 }
40 
41 }} // namespace at::native
42