1 #ifndef THP_UTILS_H 2 #define THP_UTILS_H 3 4 #include <ATen/ATen.h> 5 #include <c10/util/Exception.h> 6 #include <torch/csrc/Storage.h> 7 #include <torch/csrc/THConcat.h> 8 #include <torch/csrc/utils/object_ptr.h> 9 #include <torch/csrc/utils/python_compat.h> 10 #include <torch/csrc/utils/python_numbers.h> 11 #include <string> 12 #include <type_traits> 13 #include <vector> 14 15 #ifdef USE_CUDA 16 #include <c10/cuda/CUDAStream.h> 17 #endif 18 19 #define THPUtils_(NAME) TH_CONCAT_4(THP, Real, Utils_, NAME) 20 21 #define THPUtils_typename(obj) (Py_TYPE(obj)->tp_name) 22 23 #if defined(__GNUC__) || defined(__ICL) || defined(__clang__) 24 #define THP_EXPECT(x, y) (__builtin_expect((x), (y))) 25 #else 26 #define THP_EXPECT(x, y) (x) 27 #endif 28 29 #define THPUtils_checkReal_FLOAT(object) \ 30 (PyFloat_Check(object) || PyLong_Check(object)) 31 32 #define THPUtils_unpackReal_FLOAT(object) \ 33 (PyFloat_Check(object) ? PyFloat_AsDouble(object) \ 34 : PyLong_Check(object) \ 35 ? PyLong_AsLongLong(object) \ 36 : (throw std::runtime_error("Could not parse real"), 0)) 37 38 #define THPUtils_checkReal_INT(object) PyLong_Check(object) 39 40 #define THPUtils_unpackReal_INT(object) \ 41 (PyLong_Check(object) \ 42 ? PyLong_AsLongLong(object) \ 43 : (throw std::runtime_error("Could not parse real"), 0)) 44 45 #define THPUtils_unpackReal_BOOL(object) \ 46 (PyBool_Check(object) \ 47 ? object \ 48 : (throw std::runtime_error("Could not parse real"), Py_False)) 49 50 #define THPUtils_unpackReal_COMPLEX(object) \ 51 (PyComplex_Check(object) \ 52 ? (c10::complex<double>( \ 53 PyComplex_RealAsDouble(object), PyComplex_ImagAsDouble(object))) \ 54 : PyFloat_Check(object) \ 55 ? (c10::complex<double>(PyFloat_AsDouble(object), 0)) \ 56 : PyLong_Check(object) \ 57 ? (c10::complex<double>(PyLong_AsLongLong(object), 0)) \ 58 : (throw std::runtime_error("Could not parse real"), \ 59 c10::complex<double>(0, 0))) 60 61 #define THPUtils_checkReal_BOOL(object) PyBool_Check(object) 62 63 #define THPUtils_checkReal_COMPLEX(object) \ 64 PyComplex_Check(object) || PyFloat_Check(object) || PyLong_Check(object) || \ 65 PyInt_Check(object) 66 67 #define THPUtils_newReal_FLOAT(value) PyFloat_FromDouble(value) 68 #define THPUtils_newReal_INT(value) PyInt_FromLong(value) 69 70 #define THPUtils_newReal_BOOL(value) PyBool_FromLong(value) 71 72 #define THPUtils_newReal_COMPLEX(value) \ 73 PyComplex_FromDoubles(value.real(), value.imag()) 74 75 #define THPDoubleUtils_checkReal(object) THPUtils_checkReal_FLOAT(object) 76 #define THPDoubleUtils_unpackReal(object) \ 77 (double)THPUtils_unpackReal_FLOAT(object) 78 #define THPDoubleUtils_newReal(value) THPUtils_newReal_FLOAT(value) 79 #define THPFloatUtils_checkReal(object) THPUtils_checkReal_FLOAT(object) 80 #define THPFloatUtils_unpackReal(object) \ 81 (float)THPUtils_unpackReal_FLOAT(object) 82 #define THPFloatUtils_newReal(value) THPUtils_newReal_FLOAT(value) 83 #define THPHalfUtils_checkReal(object) THPUtils_checkReal_FLOAT(object) 84 #define THPHalfUtils_unpackReal(object) \ 85 (at::Half) THPUtils_unpackReal_FLOAT(object) 86 #define THPHalfUtils_newReal(value) PyFloat_FromDouble(value) 87 #define THPHalfUtils_newAccreal(value) THPUtils_newReal_FLOAT(value) 88 #define THPComplexDoubleUtils_checkReal(object) \ 89 THPUtils_checkReal_COMPLEX(object) 90 #define THPComplexDoubleUtils_unpackReal(object) \ 91 THPUtils_unpackReal_COMPLEX(object) 92 #define THPComplexDoubleUtils_newReal(value) THPUtils_newReal_COMPLEX(value) 93 #define THPComplexFloatUtils_checkReal(object) \ 94 THPUtils_checkReal_COMPLEX(object) 95 #define THPComplexFloatUtils_unpackReal(object) \ 96 (c10::complex<float>)THPUtils_unpackReal_COMPLEX(object) 97 #define THPComplexFloatUtils_newReal(value) THPUtils_newReal_COMPLEX(value) 98 #define THPBFloat16Utils_checkReal(object) THPUtils_checkReal_FLOAT(object) 99 #define THPBFloat16Utils_unpackReal(object) \ 100 (at::BFloat16) THPUtils_unpackReal_FLOAT(object) 101 #define THPBFloat16Utils_newReal(value) PyFloat_FromDouble(value) 102 #define THPBFloat16Utils_newAccreal(value) THPUtils_newReal_FLOAT(value) 103 104 #define THPBoolUtils_checkReal(object) THPUtils_checkReal_BOOL(object) 105 #define THPBoolUtils_unpackReal(object) THPUtils_unpackReal_BOOL(object) 106 #define THPBoolUtils_newReal(value) THPUtils_newReal_BOOL(value) 107 #define THPBoolUtils_checkAccreal(object) THPUtils_checkReal_BOOL(object) 108 #define THPBoolUtils_unpackAccreal(object) \ 109 (int64_t) THPUtils_unpackReal_BOOL(object) 110 #define THPBoolUtils_newAccreal(value) THPUtils_newReal_BOOL(value) 111 #define THPLongUtils_checkReal(object) THPUtils_checkReal_INT(object) 112 #define THPLongUtils_unpackReal(object) \ 113 (int64_t) THPUtils_unpackReal_INT(object) 114 #define THPLongUtils_newReal(value) THPUtils_newReal_INT(value) 115 #define THPIntUtils_checkReal(object) THPUtils_checkReal_INT(object) 116 #define THPIntUtils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) 117 #define THPIntUtils_newReal(value) THPUtils_newReal_INT(value) 118 #define THPShortUtils_checkReal(object) THPUtils_checkReal_INT(object) 119 #define THPShortUtils_unpackReal(object) (short)THPUtils_unpackReal_INT(object) 120 #define THPShortUtils_newReal(value) THPUtils_newReal_INT(value) 121 #define THPCharUtils_checkReal(object) THPUtils_checkReal_INT(object) 122 #define THPCharUtils_unpackReal(object) (char)THPUtils_unpackReal_INT(object) 123 #define THPCharUtils_newReal(value) THPUtils_newReal_INT(value) 124 #define THPByteUtils_checkReal(object) THPUtils_checkReal_INT(object) 125 #define THPByteUtils_unpackReal(object) \ 126 (unsigned char)THPUtils_unpackReal_INT(object) 127 #define THPByteUtils_newReal(value) THPUtils_newReal_INT(value) 128 // quantized types 129 #define THPQUInt8Utils_checkReal(object) THPUtils_checkReal_INT(object) 130 #define THPQUInt8Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) 131 #define THPQUInt8Utils_newReal(value) THPUtils_newReal_INT(value) 132 #define THPQInt8Utils_checkReal(object) THPUtils_checkReal_INT(object) 133 #define THPQInt8Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) 134 #define THPQInt8Utils_newReal(value) THPUtils_newReal_INT(value) 135 #define THPQInt32Utils_checkReal(object) THPUtils_checkReal_INT(object) 136 #define THPQInt32Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) 137 #define THPQInt32Utils_newReal(value) THPUtils_newReal_INT(value) 138 #define THPQUInt4x2Utils_checkReal(object) THPUtils_checkReal_INT(object) 139 #define THPQUInt4x2Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) 140 #define THPQUInt4x2Utils_newReal(value) THPUtils_newReal_INT(value) 141 #define THPQUInt2x4Utils_checkReal(object) THPUtils_checkReal_INT(object) 142 #define THPQUInt2x4Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) 143 #define THPQUInt2x4Utils_newReal(value) THPUtils_newReal_INT(value) 144 145 /* 146 From https://github.com/python/cpython/blob/v3.7.0/Modules/xxsubtype.c 147 If compiled as a shared library, some compilers don't allow addresses of 148 Python objects defined in other libraries to be used in static PyTypeObject 149 initializers. The DEFERRED_ADDRESS macro is used to tag the slots where such 150 addresses appear; the module init function that adds the PyTypeObject to the 151 module must fill in the tagged slots at runtime. The argument is for 152 documentation -- the macro ignores it. 153 */ 154 #define DEFERRED_ADDRESS(ADDR) nullptr 155 156 TORCH_PYTHON_API void THPUtils_setError(const char* format, ...); 157 TORCH_PYTHON_API void THPUtils_invalidArguments( 158 PyObject* given_args, 159 PyObject* given_kwargs, 160 const char* function_name, 161 size_t num_options, 162 ...); 163 164 bool THPUtils_checkIntTuple(PyObject* arg); 165 std::vector<int> THPUtils_unpackIntTuple(PyObject* arg); 166 167 TORCH_PYTHON_API void THPUtils_addPyMethodDefs( 168 std::vector<PyMethodDef>& vector, 169 PyMethodDef* methods); 170 171 int THPUtils_getCallable(PyObject* arg, PyObject** result); 172 173 typedef THPPointer<THPGenerator> THPGeneratorPtr; 174 typedef class THPPointer<THPStorage> THPStoragePtr; 175 176 TORCH_PYTHON_API std::vector<int64_t> THPUtils_unpackLongs(PyObject* arg); 177 PyObject* THPUtils_dispatchStateless( 178 PyObject* tensor, 179 const char* name, 180 PyObject* args, 181 PyObject* kwargs); 182 183 template <typename _real, typename = void> 184 struct mod_traits {}; 185 186 template <typename _real> 187 struct mod_traits<_real, std::enable_if_t<std::is_floating_point_v<_real>>> { 188 static _real mod(_real a, _real b) { 189 return fmod(a, b); 190 } 191 }; 192 193 template <typename _real> 194 struct mod_traits<_real, std::enable_if_t<std::is_integral_v<_real>>> { 195 static _real mod(_real a, _real b) { 196 return a % b; 197 } 198 }; 199 200 void setBackCompatBroadcastWarn(bool warn); 201 bool getBackCompatBroadcastWarn(); 202 203 void setBackCompatKeepdimWarn(bool warn); 204 bool getBackCompatKeepdimWarn(); 205 bool maybeThrowBackCompatKeepdimWarn(char* func); 206 207 // NB: This is in torch/csrc/cuda/utils.cpp, for whatever reason 208 #ifdef USE_CUDA 209 std::vector<std::optional<at::cuda::CUDAStream>> 210 THPUtils_PySequence_to_CUDAStreamList(PyObject* obj); 211 #endif 212 213 void storage_fill(const at::Storage& self, uint8_t value); 214 void storage_set(const at::Storage& self, ptrdiff_t idx, uint8_t value); 215 uint8_t storage_get(const at::Storage& self, ptrdiff_t idx); 216 217 #endif 218