xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mkl/Exceptions.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <string>
4 #include <stdexcept>
5 #include <sstream>
6 #include <mkl_dfti.h>
7 #include <mkl_spblas.h>
8 
9 namespace at::native {
10 
MKL_DFTI_CHECK(MKL_INT status)11 static inline void MKL_DFTI_CHECK(MKL_INT status)
12 {
13   if (status && !DftiErrorClass(status, DFTI_NO_ERROR)) {
14     std::ostringstream ss;
15     ss << "MKL FFT error: " << DftiErrorMessage(status);
16     throw std::runtime_error(ss.str());
17   }
18 }
19 
20 }  // namespace at::native
21 
22 namespace at::mkl::sparse {
_mklGetErrorString(sparse_status_t status)23 static inline const char* _mklGetErrorString(sparse_status_t status) {
24   if (status == SPARSE_STATUS_SUCCESS) {
25     return "SPARSE_STATUS_SUCCESS";
26   }
27   if (status == SPARSE_STATUS_NOT_INITIALIZED) {
28     return "SPARSE_STATUS_NOT_INITIALIZED";
29   }
30   if (status == SPARSE_STATUS_ALLOC_FAILED) {
31     return "SPARSE_STATUS_ALLOC_FAILED";
32   }
33   if (status == SPARSE_STATUS_INVALID_VALUE) {
34     return "SPARSE_STATUS_INVALID_VALUE";
35   }
36   if (status == SPARSE_STATUS_EXECUTION_FAILED) {
37     return "SPARSE_STATUS_EXECUTION_FAILED";
38   }
39   if (status == SPARSE_STATUS_INTERNAL_ERROR) {
40     return "SPARSE_STATUS_INTERNAL_ERROR";
41   }
42   if (status == SPARSE_STATUS_NOT_SUPPORTED) {
43     return "SPARSE_STATUS_NOT_SUPPORTED";
44   }
45   return "<unknown>";
46 }
47 } // namespace at::mkl::sparse
48 
// Evaluates EXPR -- an expression yielding an MKL sparse BLAS sparse_status_t --
// exactly once, and raises a TORCH_CHECK failure unless the result equals
// SPARSE_STATUS_SUCCESS. The failure message is built from "MKL error: ", the
// enumerator name of the status, and the source text of EXPR, e.g.
// "MKL error: SPARSE_STATUS_INVALID_VALUE when calling `mkl_sparse_foo(...)`"
// (assuming TORCH_CHECK concatenates its message arguments).
//
// NOTE: this header does not itself include the header that defines
// TORCH_CHECK; it must already be visible wherever this macro is expanded.
//
// The temporary deliberately uses an unusual, non-reserved name. Identifiers
// containing "__" are reserved to the implementation. A name the caller might
// also use would be captured by EXPR (e.g. `T err = f(err);` reads the new,
// uninitialized local instead of the caller's variable).
#define TORCH_MKLSPARSE_CHECK(EXPR)                                    \
  do {                                                                 \
    sparse_status_t torch_mklsparse_err_ = (EXPR);                     \
    TORCH_CHECK(                                                       \
        torch_mklsparse_err_ == SPARSE_STATUS_SUCCESS,                 \
        "MKL error: ",                                                 \
        at::mkl::sparse::_mklGetErrorString(torch_mklsparse_err_),     \
        " when calling `" #EXPR "`");                                  \
  } while (0)
58 
// Evaluates `status` -- a sparse_status_t -- exactly once, and raises a
// TORCH_CHECK failure unless it is SPARSE_STATUS_SUCCESS or
// SPARSE_STATUS_INVALID_VALUE. SPARSE_STATUS_INVALID_VALUE passes silently.
// NOTE(review): the reason INVALID_VALUE is tolerated is not visible here;
// confirm against the call sites.
//
// `function_name` names the MKL routine in the error message only. It is
// handed to TORCH_CHECK as its own message argument, not spliced in by
// string-literal concatenation. String literals therefore produce the same
// text as before, and runtime strings (const char*, std::string) are also
// accepted (assuming TORCH_CHECK formats each message argument separately).
//
// NOTE: this header does not itself include the header that defines
// TORCH_CHECK; it must already be visible wherever this macro is expanded.
//
// The temporary avoids the reserved "__" prefix and uses a name unlikely to
// collide with identifiers appearing in the `status` expression.
#define TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, function_name)   \
  do {                                                                   \
    sparse_status_t torch_mklsparse_status_ = (status);                  \
    TORCH_CHECK(                                                         \
        torch_mklsparse_status_ == SPARSE_STATUS_SUCCESS ||              \
            torch_mklsparse_status_ == SPARSE_STATUS_INVALID_VALUE,      \
        "MKL error: ",                                                   \
        at::mkl::sparse::_mklGetErrorString(torch_mklsparse_status_),    \
        " when calling `",                                               \
        (function_name),                                                 \
        "`");                                                            \
  } while (0)
68   } while (0)
69