1 #pragma once 2 3 #include <torch/csrc/python_headers.h> 4 #include <torch/csrc/utils/pythoncapi_compat.h> 5 6 #include <ATen/core/Tensor.h> 7 #include <ATen/core/jit_type_base.h> 8 #include <c10/util/irange.h> 9 #include <pybind11/pybind11.h> 10 #include <pybind11/stl.h> 11 12 #include <torch/csrc/Device.h> 13 #include <torch/csrc/Dtype.h> 14 #include <torch/csrc/DynamicTypes.h> 15 #include <torch/csrc/Generator.h> 16 #include <torch/csrc/MemoryFormat.h> 17 #include <torch/csrc/Stream.h> 18 #include <torch/csrc/utils/tensor_memoryformats.h> 19 20 namespace py = pybind11; 21 22 // This makes intrusive_ptr to be available as a custom pybind11 holder type, 23 // see 24 // https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers 25 PYBIND11_DECLARE_HOLDER_TYPE(T, c10::intrusive_ptr<T>, true); 26 27 PYBIND11_DECLARE_HOLDER_TYPE(T, c10::SingletonOrSharedTypePtr<T>); 28 PYBIND11_DECLARE_HOLDER_TYPE(T, c10::SingletonTypePtr<T>, true); 29 30 namespace pybind11::detail { 31 32 // torch.Tensor <-> at::Tensor conversions (without unwrapping) 33 template <> 34 struct TORCH_PYTHON_API type_caster<at::Tensor> { 35 public: 36 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 37 PYBIND11_TYPE_CASTER(at::Tensor, _("torch.Tensor")); 38 39 bool load(handle src, bool); 40 41 static handle cast( 42 const at::Tensor& src, 43 return_value_policy /* policy */, 44 handle /* parent */); 45 }; 46 47 // torch._StorageBase <-> at::Storage 48 template <> 49 struct type_caster<at::Storage> { 50 public: 51 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 52 PYBIND11_TYPE_CASTER(at::Storage, _("torch.StorageBase")); 53 54 bool load(handle src, bool) { 55 PyObject* obj = src.ptr(); 56 if (torch::isStorage(obj)) { 57 value = torch::createStorage(obj); 58 return true; 59 } 60 return false; 61 } 62 63 static handle cast( 64 const at::Storage& src, 65 return_value_policy /* policy */, 66 handle /* parent */) { 67 return handle(torch::createPyObject(src)); 68 } 69 }; 70 71 template <> 72 struct type_caster<at::Generator> { 73 public: 74 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 75 PYBIND11_TYPE_CASTER(at::Generator, _("torch.Generator")); 76 77 bool load(handle src, bool) { 78 PyObject* obj = src.ptr(); 79 if (THPGenerator_Check(obj)) { 80 value = reinterpret_cast<THPGenerator*>(obj)->cdata; 81 return true; 82 } 83 return false; 84 } 85 86 static handle cast( 87 const at::Generator& src, 88 return_value_policy /* policy */, 89 handle /* parent */) { 90 return handle(THPGenerator_Wrap(src)); 91 } 92 }; 93 94 template <> 95 struct TORCH_PYTHON_API type_caster<at::IntArrayRef> { 96 public: 97 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 98 PYBIND11_TYPE_CASTER(at::IntArrayRef, _("Tuple[int, ...]")); 99 100 bool load(handle src, bool); 101 static handle cast( 102 at::IntArrayRef src, 103 return_value_policy /* policy */, 104 handle /* parent */); 105 106 private: 107 std::vector<int64_t> v_value; 108 }; 109 110 template <> 111 struct TORCH_PYTHON_API type_caster<at::SymIntArrayRef> { 112 public: 113 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 114 PYBIND11_TYPE_CASTER(at::SymIntArrayRef, _("List[int]")); 115 116 bool load(handle src, bool); 117 static handle cast( 118 at::SymIntArrayRef src, 119 return_value_policy /* policy */, 120 handle /* parent */); 121 122 private: 123 std::vector<c10::SymInt> v_value; 124 }; 125 126 template <> 127 struct TORCH_PYTHON_API type_caster<at::ArrayRef<c10::SymNode>> { 128 public: 129 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 130 PYBIND11_TYPE_CASTER(at::ArrayRef<c10::SymNode>, _("List[SymNode]")); 131 132 bool load(handle src, bool); 133 static handle cast( 134 at::ArrayRef<c10::SymNode> src, 135 return_value_policy /* policy */, 136 handle /* parent */); 137 138 private: 139 std::vector<c10::SymNode> v_value; 140 }; 141 142 template <> 143 struct type_caster<at::MemoryFormat> { 144 public: 145 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 146 PYBIND11_TYPE_CASTER(at::MemoryFormat, _("torch.memory_format")); 147 148 bool load(handle src, bool) { 149 PyObject* obj = src.ptr(); 150 if (THPMemoryFormat_Check(obj)) { 151 value = reinterpret_cast<THPMemoryFormat*>(obj)->memory_format; 152 return true; 153 } 154 return false; 155 } 156 static handle cast( 157 at::MemoryFormat src, 158 return_value_policy /* policy */, 159 handle /* parent */) { 160 return handle(Py_NewRef(torch::utils::getTHPMemoryFormat(src))); 161 } 162 }; 163 164 template <> 165 struct type_caster<at::Device> { 166 public: 167 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 168 PYBIND11_TYPE_CASTER(at::Device, _("torch.device")); 169 170 // PYBIND11_TYPE_CASTER defines a member field called value. Since at::Device 171 // cannot be default-initialized, we provide this constructor to explicitly 172 // initialize that field. The value doesn't matter as it will be overwritten 173 // after a successful call to load. 174 type_caster() : value(c10::kCPU) {} 175 176 bool load(handle src, bool) { 177 PyObject* obj = src.ptr(); 178 if (THPDevice_Check(obj)) { 179 value = reinterpret_cast<THPDevice*>(obj)->device; 180 return true; 181 } 182 return false; 183 } 184 185 static handle cast( 186 const at::Device& src, 187 return_value_policy /* policy */, 188 handle /* parent */) { 189 return handle(THPDevice_New(src)); 190 } 191 }; 192 193 template <> 194 struct type_caster<at::ScalarType> { 195 public: 196 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 197 PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype")); 198 199 // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType 200 // cannot be default-initialized, we provide this constructor to explicitly 201 // initialize that field. The value doesn't matter as it will be overwritten 202 // after a successful call to load. 203 type_caster() : value(at::kFloat) {} 204 205 bool load(handle src, bool) { 206 PyObject* obj = src.ptr(); 207 if (THPDtype_Check(obj)) { 208 value = reinterpret_cast<THPDtype*>(obj)->scalar_type; 209 return true; 210 } 211 return false; 212 } 213 214 static handle cast( 215 const at::ScalarType& src, 216 return_value_policy /* policy */, 217 handle /* parent */) { 218 return Py_NewRef(torch::getTHPDtype(src)); 219 } 220 }; 221 222 template <> 223 struct type_caster<c10::Stream> { 224 public: 225 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 226 PYBIND11_TYPE_CASTER(c10::Stream, _("torch.Stream")); 227 228 // PYBIND11_TYPE_CASTER defines a member field called value. Since c10::Stream 229 // cannot be default-initialized, we provide this constructor to explicitly 230 // initialize that field. The value doesn't matter as it will be overwritten 231 // after a successful call to load. 232 type_caster() : value(c10::Stream::DEFAULT, c10::Device(c10::kCPU, 0)) {} 233 234 bool load(handle src, bool) { 235 PyObject* obj = src.ptr(); 236 if (THPStream_Check(obj)) { 237 value = c10::Stream::unpack3( 238 ((THPStream*)obj)->stream_id, 239 static_cast<c10::DeviceIndex>(((THPStream*)obj)->device_index), 240 static_cast<c10::DeviceType>(((THPStream*)obj)->device_type)); 241 return true; 242 } 243 return false; 244 } 245 246 static handle cast( 247 const c10::Stream& src, 248 return_value_policy /* policy */, 249 handle /* parent */) { 250 return handle(THPStream_Wrap(src)); 251 } 252 }; 253 254 template <> 255 struct type_caster<c10::DispatchKey> 256 : public type_caster_base<c10::DispatchKey> { 257 using base = type_caster_base<c10::DispatchKey>; 258 c10::DispatchKey tmp{}; 259 260 public: 261 bool load(handle src, bool convert) { 262 if (base::load(src, convert)) { 263 return true; 264 } else if (py::isinstance( 265 src, py::module_::import("builtins").attr("str"))) { 266 tmp = c10::parseDispatchKey(py::cast<std::string>(src)); 267 value = &tmp; 268 return true; 269 } 270 return false; 271 } 272 273 static handle cast( 274 c10::DispatchKey src, 275 return_value_policy policy, 276 handle parent) { 277 return base::cast(src, policy, parent); 278 } 279 }; 280 281 template <> 282 struct TORCH_PYTHON_API type_caster<c10::Scalar> { 283 public: 284 PYBIND11_TYPE_CASTER( 285 c10::Scalar, 286 _("Union[Number, torch.SymInt, torch.SymFloat, torch.SymBool]")); 287 bool load(py::handle src, bool); 288 289 static py::handle cast( 290 const c10::Scalar& si, 291 return_value_policy /* policy */, 292 handle /* parent */); 293 }; 294 295 template <> 296 struct TORCH_PYTHON_API type_caster<c10::SymInt> { 297 public: 298 PYBIND11_TYPE_CASTER(c10::SymInt, _("Union[int, torch.SymInt]")); 299 bool load(py::handle src, bool); 300 301 static py::handle cast( 302 const c10::SymInt& si, 303 return_value_policy /* policy */, 304 handle /* parent */); 305 }; 306 307 template <> 308 struct TORCH_PYTHON_API type_caster<c10::SymFloat> { 309 public: 310 PYBIND11_TYPE_CASTER(c10::SymFloat, _("float")); 311 bool load(py::handle src, bool); 312 313 static py::handle cast( 314 const c10::SymFloat& si, 315 return_value_policy /* policy */, 316 handle /* parent */); 317 }; 318 319 template <> 320 struct TORCH_PYTHON_API type_caster<c10::SymBool> { 321 public: 322 PYBIND11_TYPE_CASTER(c10::SymBool, _("Union[bool, torch.SymBool]")); 323 bool load(py::handle src, bool); 324 325 static py::handle cast( 326 const c10::SymBool& si, 327 return_value_policy /* policy */, 328 handle /* parent */); 329 }; 330 331 template <typename T> 332 struct type_caster<c10::complex<T>> { 333 public: 334 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 335 PYBIND11_TYPE_CASTER(c10::complex<T>, _("complex")); 336 337 bool load(handle src, bool) { 338 PyObject* obj = src.ptr(); 339 340 // Refered from `THPUtils_unpackComplexDouble` 341 Py_complex py_complex = PyComplex_AsCComplex(obj); 342 if (py_complex.real == -1.0 && PyErr_Occurred()) { 343 return false; 344 } 345 346 // Python's Complex is always double precision. 347 value = c10::complex<double>(py_complex.real, py_complex.imag); 348 return true; 349 } 350 351 static handle cast( 352 const c10::complex<T>& complex, 353 return_value_policy /* policy */, 354 handle /* parent */) { 355 // Python only knows double precision complex. 356 return handle(PyComplex_FromDoubles(complex.real(), complex.imag())); 357 } 358 }; 359 360 } // namespace pybind11::detail 361 362 namespace torch::impl { 363 364 // Use this function if you have a C++ object that is used from both C++ 365 // and Python contexts, and you need its GIL to be released when you 366 // destruct it in the Python context. 367 // 368 // This function is a valid shared_ptr destructor and can be used to 369 // conveniently allocate a shared_ptr to an object whose destructor will be run 370 // without the GIL. Pass it as the second argument to shared_ptr, e.g., 371 // 372 // shared_ptr<T>(new T(), destroy_without_gil<T>) 373 // 374 // Attaching the GIL release logic to the holder pointer rather than the 375 // actual destructor of T is helpful when T is Python-agnostic and 376 // shouldn't refer to the PYthon API. 377 // 378 // Note there are limitations to the correctness of code that makes use of this. 379 // In particular, if a shared_ptr is constructed from C++ code without this 380 // destructor and then passed to pybind11, pybind11 will happily take ownership 381 // of the shared_ptr (and be willing to destruct it from a context where it is 382 // holding the GIL). unique_ptr with a type branded deleter is less prone to 383 // this problem, because a stock deleter unique_ptr is not convertible with it. 384 // I plan to mitigate this problem by adding DEBUG-only asserts to the true C++ 385 // destructors that the GIL is not held (using a virtual call to get to the 386 // Python interpreter); alternately, we could use a virtual call to simply 387 // ensure we release the GIL in the C++ destructor, however, this is a layering 388 // violation (why does code that is ostensibly Python agnostic calling into the 389 // GIL). 390 // 391 // Adapted from 392 // https://github.com/pybind/pybind11/issues/1446#issuecomment-406341510 393 template <typename T> 394 inline void destroy_without_gil(T* ptr) { 395 // Because the ownership of a shared_ptr is diffuse, it's not possible to 396 // necessarily predict whether or not the last reference to an object will 397 // be destructed from Python or C++. This means that in the destructor here, 398 // we don't necessarily know if we actually have the GIL or not; in fact, 399 // we don't even know if the Python interpreter still exists! Thus, we have 400 // to test for it before releasing the GIL. 401 // 402 // PyGILState_Check is hopefully self explanatory. But Py_IsInitialized or 403 // _PyIsFinalizing? Both get set at the same time during the Python 404 // destruction process: 405 // https://github.com/python/cpython/blob/d92513390a1a0da781bb08c284136f4d7abea36d/Python/pylifecycle.c#L1716-L1717 406 // so the operant question is whether or not you want to release the GIL after 407 // finalization has completed (and there is just no Python interpreter). 408 // Clearly there is no need to release GIL in that state, so we want 409 // Py_IsInitialized. 410 if (Py_IsInitialized() && PyGILState_Check()) { 411 pybind11::gil_scoped_release nogil; 412 delete ptr; 413 } else { 414 delete ptr; 415 } 416 } 417 418 } // namespace torch::impl 419