xref: /aosp_15_r20/external/pytorch/torch/csrc/utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifndef THP_UTILS_H
2 #define THP_UTILS_H
3 
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <torch/csrc/Storage.h>
#include <torch/csrc/THConcat.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_numbers.h>
#include <cmath>
#include <string>
#include <type_traits>
#include <vector>

#ifdef USE_CUDA
#include <c10/cuda/CUDAStream.h>
#endif
18 
// THPUtils_(NAME) pastes together THP, Real, Utils_ and NAME via TH_CONCAT_4,
// e.g. THPUtils_(unpackReal) -> THP<Real>Utils_unpackReal. `Real` is not
// defined in this header; presumably the including generic code #defines it.
#define THPUtils_(NAME) TH_CONCAT_4(THP, Real, Utils_, NAME)

// C string naming the Python type of `obj` (its tp_name slot).
#define THPUtils_typename(obj) (Py_TYPE(obj)->tp_name)

// Branch-prediction hint: THP_EXPECT(x, y) evaluates to `x` and tells the
// compiler `x` is expected to equal `y`. Compilers without __builtin_expect
// just get the plain expression.
#if !defined(__GNUC__) && !defined(__ICL) && !defined(__clang__)
#define THP_EXPECT(x, y) (x)
#else
#define THP_EXPECT(x, y) (__builtin_expect((x), (y)))
#endif
28 
// ---------------------------------------------------------------------------
// Per-kind primitives behind the THP<Type>Utils_* aliases.
//
//   THPUtils_checkReal_<KIND>(object)  - true if `object`'s Python type is
//                                        accepted for that kind.
//   THPUtils_unpackReal_<KIND>(object) - converts `object` to a C++ value, or
//                                        throws std::runtime_error("Could not
//                                        parse real") for any other type.
//   THPUtils_newReal_<KIND>(value)     - builds a new Python object.
//
// These are macros: `object` may be evaluated several times, so do not pass
// an expression with side effects.
//
// NOTE(review): the results of PyFloat_AsDouble / PyLong_AsLongLong are not
// error-checked here. Per the Python C API they return -1 with an exception
// set on failure (e.g. int overflow); confirm callers check PyErr_Occurred().
// ---------------------------------------------------------------------------

#define THPUtils_checkReal_FLOAT(object) \
  (PyFloat_Check(object) || PyLong_Check(object))

// Python float -> double; Python int -> long long, which the conditional
// operator then converts to double.
#define THPUtils_unpackReal_FLOAT(object)           \
  (PyFloat_Check(object) ? PyFloat_AsDouble(object) \
       : PyLong_Check(object)                       \
       ? PyLong_AsLongLong(object)                  \
       : (throw std::runtime_error("Could not parse real"), 0))

#define THPUtils_checkReal_INT(object) PyLong_Check(object)

// Python int -> long long.
#define THPUtils_unpackReal_INT(object) \
  (PyLong_Check(object)                 \
       ? PyLong_AsLongLong(object)      \
       : (throw std::runtime_error("Could not parse real"), 0))

// Python bool -> C++ bool (true iff `object` is Py_True). Compare against
// Py_True rather than yielding the object itself: a PyObject* converts to
// `true` even for Py_False, and an (int64_t) cast of it gives a pointer value.
#define THPUtils_unpackReal_BOOL(object) \
  (PyBool_Check(object)                  \
       ? ((object) == Py_True)           \
       : (throw std::runtime_error("Could not parse real"), false))

// Python complex / float / int -> c10::complex<double> (imaginary part 0 for
// float and int inputs).
#define THPUtils_unpackReal_COMPLEX(object)                                   \
  (PyComplex_Check(object)                                                    \
       ? (c10::complex<double>(                                               \
             PyComplex_RealAsDouble(object), PyComplex_ImagAsDouble(object))) \
       : PyFloat_Check(object)                                                \
       ? (c10::complex<double>(PyFloat_AsDouble(object), 0))                  \
       : PyLong_Check(object)                                                 \
       ? (c10::complex<double>(PyLong_AsLongLong(object), 0))                 \
       : (throw std::runtime_error("Could not parse real"),                   \
          c10::complex<double>(0, 0)))

#define THPUtils_checkReal_BOOL(object) PyBool_Check(object)

// Fully parenthesized so that `!THPUtils_checkReal_COMPLEX(x)` negates the
// whole test. Python 3 has a single int type, so PyLong_Check covers ints
// (the old Py2 PyInt_Check alternative was redundant).
#define THPUtils_checkReal_COMPLEX(object)             \
  (PyComplex_Check(object) || PyFloat_Check(object) || \
   PyLong_Check(object))

#define THPUtils_newReal_FLOAT(value) PyFloat_FromDouble(value)
// PyLong_FromLongLong instead of a `long`-taking constructor: `long` is only
// 32 bits on LLP64 platforms (Windows), which would truncate int64_t values.
#define THPUtils_newReal_INT(value) PyLong_FromLongLong(value)

#define THPUtils_newReal_BOOL(value) PyBool_FromLong(value)

// `value` is parenthesized so compound arguments (e.g. `a + b`) are converted
// as a whole rather than having .real()/.imag() bind to their last operand.
#define THPUtils_newReal_COMPLEX(value) \
  PyComplex_FromDoubles((value).real(), (value).imag())
74 
// ---------------------------------------------------------------------------
// Per-element-type aliases THP<Type>Utils_{checkReal,unpackReal,newReal,...}
// forwarding to the THPUtils_*Real_<KIND> primitives. These are the names
// that THPUtils_(NAME) expands to for a given `Real`.
//
// Every cast-producing macro carries an outer pair of parentheses so that a
// postfix operator applied to the result binds to the converted value. For
// example, THPComplexFloatUtils_unpackReal(o).real() takes .real() of the
// c10::complex<float>, not of the c10::complex<double> before the cast.
// The narrowing casts (short, char, unsigned char, int, ...) do not
// range-check: out-of-range values are converted without error.
// ---------------------------------------------------------------------------

#define THPDoubleUtils_checkReal(object) THPUtils_checkReal_FLOAT(object)
#define THPDoubleUtils_unpackReal(object) \
  ((double)THPUtils_unpackReal_FLOAT(object))
#define THPDoubleUtils_newReal(value) THPUtils_newReal_FLOAT(value)
#define THPFloatUtils_checkReal(object) THPUtils_checkReal_FLOAT(object)
#define THPFloatUtils_unpackReal(object) \
  ((float)THPUtils_unpackReal_FLOAT(object))
#define THPFloatUtils_newReal(value) THPUtils_newReal_FLOAT(value)
#define THPHalfUtils_checkReal(object) THPUtils_checkReal_FLOAT(object)
#define THPHalfUtils_unpackReal(object) \
  ((at::Half)THPUtils_unpackReal_FLOAT(object))
#define THPHalfUtils_newReal(value) PyFloat_FromDouble(value)
#define THPHalfUtils_newAccreal(value) THPUtils_newReal_FLOAT(value)
// The complex checkReal/newReal aliases add their own parentheses so they
// stay well-formed even if the underlying primitive is not fully
// parenthesized.
#define THPComplexDoubleUtils_checkReal(object) \
  (THPUtils_checkReal_COMPLEX(object))
#define THPComplexDoubleUtils_unpackReal(object) \
  THPUtils_unpackReal_COMPLEX(object)
#define THPComplexDoubleUtils_newReal(value) THPUtils_newReal_COMPLEX((value))
#define THPComplexFloatUtils_checkReal(object) \
  (THPUtils_checkReal_COMPLEX(object))
#define THPComplexFloatUtils_unpackReal(object) \
  ((c10::complex<float>)THPUtils_unpackReal_COMPLEX(object))
#define THPComplexFloatUtils_newReal(value) THPUtils_newReal_COMPLEX((value))
#define THPBFloat16Utils_checkReal(object) THPUtils_checkReal_FLOAT(object)
#define THPBFloat16Utils_unpackReal(object) \
  ((at::BFloat16)THPUtils_unpackReal_FLOAT(object))
#define THPBFloat16Utils_newReal(value) PyFloat_FromDouble(value)
#define THPBFloat16Utils_newAccreal(value) THPUtils_newReal_FLOAT(value)

#define THPBoolUtils_checkReal(object) THPUtils_checkReal_BOOL(object)
#define THPBoolUtils_unpackReal(object) THPUtils_unpackReal_BOOL(object)
#define THPBoolUtils_newReal(value) THPUtils_newReal_BOOL(value)
#define THPBoolUtils_checkAccreal(object) THPUtils_checkReal_BOOL(object)
// NOTE(review): this yields 0/1 only if THPUtils_unpackReal_BOOL produces a
// bool; if it produces the PyObject* itself, the cast yields a pointer
// value. Verify against the primitive's definition.
#define THPBoolUtils_unpackAccreal(object) \
  ((int64_t)THPUtils_unpackReal_BOOL(object))
#define THPBoolUtils_newAccreal(value) THPUtils_newReal_BOOL(value)
#define THPLongUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPLongUtils_unpackReal(object) \
  ((int64_t)THPUtils_unpackReal_INT(object))
#define THPLongUtils_newReal(value) THPUtils_newReal_INT(value)
#define THPIntUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPIntUtils_unpackReal(object) ((int)THPUtils_unpackReal_INT(object))
#define THPIntUtils_newReal(value) THPUtils_newReal_INT(value)
#define THPShortUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPShortUtils_unpackReal(object) \
  ((short)THPUtils_unpackReal_INT(object))
#define THPShortUtils_newReal(value) THPUtils_newReal_INT(value)
#define THPCharUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPCharUtils_unpackReal(object) ((char)THPUtils_unpackReal_INT(object))
#define THPCharUtils_newReal(value) THPUtils_newReal_INT(value)
#define THPByteUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPByteUtils_unpackReal(object) \
  ((unsigned char)THPUtils_unpackReal_INT(object))
#define THPByteUtils_newReal(value) THPUtils_newReal_INT(value)
// quantized types
#define THPQUInt8Utils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPQUInt8Utils_unpackReal(object) ((int)THPUtils_unpackReal_INT(object))
#define THPQUInt8Utils_newReal(value) THPUtils_newReal_INT(value)
#define THPQInt8Utils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPQInt8Utils_unpackReal(object) ((int)THPUtils_unpackReal_INT(object))
#define THPQInt8Utils_newReal(value) THPUtils_newReal_INT(value)
#define THPQInt32Utils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPQInt32Utils_unpackReal(object) ((int)THPUtils_unpackReal_INT(object))
#define THPQInt32Utils_newReal(value) THPUtils_newReal_INT(value)
#define THPQUInt4x2Utils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPQUInt4x2Utils_unpackReal(object) \
  ((int)THPUtils_unpackReal_INT(object))
#define THPQUInt4x2Utils_newReal(value) THPUtils_newReal_INT(value)
#define THPQUInt2x4Utils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPQUInt2x4Utils_unpackReal(object) \
  ((int)THPUtils_unpackReal_INT(object))
#define THPQUInt2x4Utils_newReal(value) THPUtils_newReal_INT(value)
144 
/*
   Borrowed from CPython's Modules/xxsubtype.c (v3.7.0):
   https://github.com/python/cpython/blob/v3.7.0/Modules/xxsubtype.c

   In shared-library builds, some compilers refuse to place the address of a
   Python object owned by another library into a static PyTypeObject
   initializer. DEFERRED_ADDRESS tags such slots: it always expands to nullptr,
   and the module init function that adds the type to its module has to fill
   the tagged slots in at runtime. The argument is never used; it only
   documents which address belongs in the slot.
*/
#define DEFERRED_ADDRESS(ADDR_FOR_DOCS) nullptr
155 
156 TORCH_PYTHON_API void THPUtils_setError(const char* format, ...);
157 TORCH_PYTHON_API void THPUtils_invalidArguments(
158     PyObject* given_args,
159     PyObject* given_kwargs,
160     const char* function_name,
161     size_t num_options,
162     ...);
163 
164 bool THPUtils_checkIntTuple(PyObject* arg);
165 std::vector<int> THPUtils_unpackIntTuple(PyObject* arg);
166 
167 TORCH_PYTHON_API void THPUtils_addPyMethodDefs(
168     std::vector<PyMethodDef>& vector,
169     PyMethodDef* methods);
170 
171 int THPUtils_getCallable(PyObject* arg, PyObject** result);
172 
173 typedef THPPointer<THPGenerator> THPGeneratorPtr;
174 typedef class THPPointer<THPStorage> THPStoragePtr;
175 
176 TORCH_PYTHON_API std::vector<int64_t> THPUtils_unpackLongs(PyObject* arg);
177 PyObject* THPUtils_dispatchStateless(
178     PyObject* tensor,
179     const char* name,
180     PyObject* args,
181     PyObject* kwargs);
182 
// mod_traits<T>::mod(a, b) computes the remainder of a / b with truncated
// division (the result has the sign of `a`):
//   - floating-point T: std::fmod
//   - integral T:       the built-in % operator
// The primary template is deliberately empty, so using ::mod with any other
// type fails to compile.
template <typename _real, typename = void>
struct mod_traits {};

template <typename _real>
struct mod_traits<_real, std::enable_if_t<std::is_floating_point_v<_real>>> {
  static _real mod(_real a, _real b) {
    // std::fmod (not the unqualified C ::fmod(double, double)) so that float
    // and long double use their own overloads instead of round-tripping
    // through double, which loses long double precision.
    return std::fmod(a, b);
  }
};

template <typename _real>
struct mod_traits<_real, std::enable_if_t<std::is_integral_v<_real>>> {
  // NOTE: b == 0 (and, for signed types, min() % -1) is undefined behavior;
  // callers must rule these out.
  static _real mod(_real a, _real b) {
    return a % b;
  }
};
199 
200 void setBackCompatBroadcastWarn(bool warn);
201 bool getBackCompatBroadcastWarn();
202 
203 void setBackCompatKeepdimWarn(bool warn);
204 bool getBackCompatKeepdimWarn();
205 bool maybeThrowBackCompatKeepdimWarn(char* func);
206 
207 // NB: This is in torch/csrc/cuda/utils.cpp, for whatever reason
208 #ifdef USE_CUDA
209 std::vector<std::optional<at::cuda::CUDAStream>>
210 THPUtils_PySequence_to_CUDAStreamList(PyObject* obj);
211 #endif
212 
213 void storage_fill(const at::Storage& self, uint8_t value);
214 void storage_set(const at::Storage& self, ptrdiff_t idx, uint8_t value);
215 uint8_t storage_get(const at::Storage& self, ptrdiff_t idx);
216 
217 #endif
218