1 #include <ATen/core/PythonFallbackKernel.h>
2 #include <ATen/core/PythonOpRegistrationTrampoline.h>
3 #include <torch/csrc/PyInterpreter.h>
4 #include <torch/csrc/THP.h>
5 #include <torch/csrc/autograd/generated/VariableType.h>
6 #include <torch/csrc/utils/python_arg_parser.h>
7 #include <torch/csrc/utils/python_dispatch.h>
8
9 #include <string>
10
11 using namespace torch;
12 using namespace at;
13 using namespace c10;
14
15 namespace torch::detail {
16
17 namespace {
18
19 // NB: This is a macro and not a template function (like it was before)
20 // because passing in constexpr char* as template argument breaks some
21 // versions of MSVC that are being used internally at Meta.
22 // MSVC 14.16.27023 (vs2017_15.9)
23 #define CONCRETE_GPU_TRACE(device_type, func_name, ...) \
24 at::impl::MaybeSetTLSOnEntryGuard guard; \
25 if (Py_IsInitialized()) { \
26 pybind11::gil_scoped_acquire gil; \
27 try { \
28 /* Masquerade hip as cuda because hip uses `torch.cuda` module. */ \
29 if (device_type == at::kHIP) { \
30 device_type = at::kCUDA; \
31 } \
32 std::string module_name = "torch." + DeviceTypeName(device_type, true); \
33 py::module mod = py::module::import(module_name.c_str()); \
34 py::object hook = \
35 mod.attr("_gpu_trace").attr(func_name).attr("fire_callbacks"); \
36 hook(__VA_ARGS__); \
37 } catch (const std::exception& e) { \
38 LOG(ERROR) << device_type \
39 << " trace hook execution failed: " << e.what(); \
40 } \
41 }
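// The Python-side counterpart of this macro (used by the trace_gpu_* hooks
// below) is the per-backend `torch.<device>._gpu_trace` module: each
// `func_name` names a callback-registry object whose `fire_callbacks(...)`
// broadcasts the event to registered listeners. Illustrative sketch of the
// consumer side (Python), assuming the registry exposes `add_callback` as in
// `torch._utils.CallbackRegistry`:
//
//   import torch.cuda._gpu_trace as gpu_trace
//   gpu_trace.EventCreationCallbacks.add_callback(lambda event: print(event))
//
// Callback failures are logged (see the catch block above) rather than
// propagated, so tracing never turns into a hard error on the hot path.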
42
43 struct ConcretePyInterpreterVTable final
44 : public c10::impl::PyInterpreterVTable {
45 std::string name() const override;
46
47 void incref(PyObject* pyobj) const override;
48 void decref(PyObject* pyobj, bool has_pyobj_slot) const override;
49
50 // TODO: Need to make this work for StorageImpl too. I imagine I'll want to
51 // operate upon a PyObjectSlot rather than a TensorImpl
52 c10::intrusive_ptr<c10::TensorImpl> detach(
53 const c10::TensorImpl* self) const override;
54
55 void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
56 const override;
57 void reportErrorCallback(PyObject* callback, DispatchKey key) const override;
58 void python_dispatcher(
59 const c10::OperatorHandle& op,
60 c10::DispatchKeySet,
61 torch::jit::Stack* stack) const override;
62 // NB: this is defined in python_dispatch.cpp
63 void python_op_registration_trampoline(
64 const c10::OperatorHandle& op,
65 c10::DispatchKey key,
66 c10::DispatchKeySet keyset,
67 torch::jit::Stack* stack,
68 bool with_keyset,
69 bool with_op) const override {
70 torch::impl::dispatch::python_op_registration_trampoline_impl(
71 op, key, keyset, stack, with_keyset, with_op);
72 }
73 void throw_abstract_impl_not_imported_error(
74 std::string opname,
75 const char* pymodule,
76 const char* context) const override {
77 py::gil_scoped_acquire gil;
78 pybind11::module::import("torch._utils_internal")
79 .attr("throw_abstract_impl_not_imported_error")(
80 opname, pymodule, context);
81 }
82
83 bool is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
84 const override;
85 bool is_strides_like(const c10::TensorImpl* self, at::MemoryFormat)
86 const override;
87 bool is_non_overlapping_and_dense(const c10::TensorImpl* self) const override;
88 c10::Device device(const c10::TensorImpl* self) const override;
89 int64_t dim(const c10::TensorImpl* self) const override;
90 c10::IntArrayRef strides(const c10::TensorImpl* self) const override;
91 c10::IntArrayRef sizes(const c10::TensorImpl* self) const override;
92 c10::SymIntArrayRef sym_sizes(const c10::TensorImpl* self) const override;
93 c10::Layout layout(const c10::TensorImpl* self) const override;
94 int64_t numel(const c10::TensorImpl* self) const override;
95 c10::SymInt sym_numel(const c10::TensorImpl* self) const override;
96 c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override;
97 c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override;
98
99 void trace_gpu_event_creation(at::DeviceType device_type, uintptr_t event)
100 const override {
101 CONCRETE_GPU_TRACE(device_type, "EventCreationCallbacks", event);
102 }
103 void trace_gpu_event_deletion(at::DeviceType device_type, uintptr_t event)
104 const override {
105 CONCRETE_GPU_TRACE(device_type, "EventDeletionCallbacks", event);
106 }
107 void trace_gpu_event_record(
108 at::DeviceType device_type,
109 uintptr_t event,
110 uintptr_t stream) const override {
111 CONCRETE_GPU_TRACE(device_type, "EventRecordCallbacks", event, stream);
112 }
113 void trace_gpu_event_wait(
114 at::DeviceType device_type,
115 uintptr_t event,
116 uintptr_t stream) const override {
117 CONCRETE_GPU_TRACE(device_type, "EventWaitCallbacks", event, stream);
118 }
119 void trace_gpu_memory_allocation(at::DeviceType device_type, uintptr_t ptr)
120 const override {
121 CONCRETE_GPU_TRACE(device_type, "MemoryAllocationCallbacks", ptr);
122 }
123 void trace_gpu_memory_deallocation(at::DeviceType device_type, uintptr_t ptr)
124 const override {
125 CONCRETE_GPU_TRACE(device_type, "MemoryDeallocationCallbacks", ptr);
126 }
127 void trace_gpu_stream_creation(at::DeviceType device_type, uintptr_t stream)
128 const override {
129 CONCRETE_GPU_TRACE(device_type, "StreamCreationCallbacks", stream);
130 }
131 void trace_gpu_device_synchronization(
132 at::DeviceType device_type) const override {
133 CONCRETE_GPU_TRACE(device_type, "DeviceSynchronizationCallbacks");
134 }
135 void trace_gpu_stream_synchronization(
136 at::DeviceType device_type,
137 uintptr_t stream) const override {
138 CONCRETE_GPU_TRACE(device_type, "StreamSynchronizationCallbacks", stream);
139 }
140 void trace_gpu_event_synchronization(
141 at::DeviceType device_type,
142 uintptr_t event) const override {
143 CONCRETE_GPU_TRACE(device_type, "EventSynchronizationCallbacks", event);
144 }
145
146 void reset_backward_hooks(const c10::TensorImpl* self) const override;
147
148 static ConcretePyInterpreterVTable* instance() {
149 static ConcretePyInterpreterVTable s;
150 return &s;
151 }
152 };
153
154 class PyInterpreterHolder {
155 public:
156 PyInterpreterHolder()
157 : impl_(new c10::impl::PyInterpreter(
158 ConcretePyInterpreterVTable::instance())),
159 is_main_interpreter_(
160 at::impl::PythonOpRegistrationTrampoline::registerInterpreter(
161 impl_)) {}
162 // NB: intentionally leaks the PyInterpreter, as there may still be
163 // references to it that are live, living in objects that aren't being
164 // destructed while Python is being cleaned up.
165 ~PyInterpreterHolder() {
166 impl_->disarm();
167 }
168 c10::impl::PyInterpreter* get() const noexcept {
169 return impl_;
170 }
171 bool is_main_interpreter() const noexcept {
172 return is_main_interpreter_;
173 }
174
175 private:
176 c10::impl::PyInterpreter* impl_;
177 bool is_main_interpreter_;
178 };
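// A single PyInterpreterHolder (`self_interpreter`, defined near the end of
// this file) is created per Python interpreter; getPyInterpreter() hands out
// its PyInterpreter*, and the destructor only disarms (rather than deletes)
// it, matching the intentional leak described above.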
179
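// Shared helper for the metadata vtable entries below (detach, sizes, strides,
// device, layout, ...): wraps `self` in a THPVariable, marks it as the sole
// overloaded argument, and routes the call through
// handle_torch_function_no_python_arg_parser with
// TorchFunctionName::TorchDispatch, i.e. the tensor subclass's (or active
// mode's) __torch_dispatch__.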
180 py::object torchDispatchFromTensorImpl(
181 const c10::TensorImpl* self,
182 const char* func_name,
183 PyObject* torch_api_function,
184 const char* module_name,
185 // WARNING: MUST NOT BE TENSOR ARGS
186 c10::SmallVector<py::object, 1> extra_args = {}) {
187 if (torch_api_function == nullptr) {
188 throw python_error();
189 }
190 TORCH_CHECK(
191 PyGILState_Check(),
192 "GIL must be held before you call parseIValuesToPyArgsKwargs");
193
194 std::vector<PyObject*> overloaded_args;
195 // TODO: there should be a shorter way to spell this
196 // TODO: fix the constness of target
197 at::Tensor self_t = at::Tensor(
198 c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
199 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
200 unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
201 auto self_p =
202 py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t)));
203 // NB: this may not be a python tensor if you got here from a mode!
204 // TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
205 append_overloaded_tensor(&overloaded_args, self_p.ptr());
206 auto args = py::reinterpret_steal<py::object>(
207 PyTuple_New(static_cast<Py_ssize_t>(1 + extra_args.size())));
208 PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
209 int64_t i = 1;
210 for (auto& a : extra_args) {
211 if (a.ptr() == nullptr)
212 throw python_error();
213 PyTuple_SET_ITEM(args.ptr(), i, std::move(a).release().ptr());
214 i++;
215 }
216
217 py::dict kwargs;
218
219 return py::reinterpret_steal<py::object>(
220 handle_torch_function_no_python_arg_parser(
221 overloaded_args,
222 args.ptr(),
223 kwargs.ptr(),
224 func_name,
225 torch_api_function,
226 module_name,
227 TorchFunctionName::TorchDispatch));
228 }
229
230 // NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
231 // Before calling PyInterpreter::decref, we must statically know if the
232 // pyobj has a PyObjectSlot or not.
233 // - If it has a PyObjectSlot, we need to be careful about PyObject resurrection
234 // - If it does not have a PyObjectSlot, we can freely decref
235 // One alternative to this is using PyObject_IsInstance
236 // to get at this information. However, we don't want to risk an incorrect
237 // `__instancecheck__` changing the semantics here.
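// Illustrative (Python) sketch of how the Tensor branch below can be reached;
// the exact resurrection mechanics are described in THPVariable_clear:
//
//   t = torch.ones(2)     # TensorImpl + its PyObject
//   wr = weakref.ref(t)
//   del t                 # ownership of the PyObject may flip to C++
//   t2 = wr()             # new Python reference, ownership not flipped back
//   # if the TensorImpl now dies while t2 is alive and t2._fix_weakref() was
//   # never called, decref() sees Py_REFCNT(pyobj) > 1 and stubs out the
//   # PyObject instead of destroying it.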
238 void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
239 const {
240 // Leak the pyobj if not initialized. This can happen if we are running
241 // exit handlers that are destructing tensors with residual (owned)
242 // PyObjects stored in them.
243 if (!Py_IsInitialized())
244 return;
245
246 pybind11::gil_scoped_acquire gil;
247 // Two possibilities:
248 // 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or
249 // Storage. Then we must be careful about PyObject resurrection (see
250 // THPVariable_clear).
251 // 2. We are decref-ing some other Python object. We don't do
252 // PyObject resurrection on non-Tensors, so we just carry on as usual
253 if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) {
254 if (THPVariable_Check(pyobj)) {
255 // It's still alive! This can happen if a weak ref resurrected
256 // the PyObject without flipping ownership. At this point it is
257 // too late to rescue the object, so just stub out the PyObject
258 // so that it fails on subsequent uses. Don't raise an error here;
259 // you're probably in a destructor.
260 TORCH_WARN(
261 "Deallocating Tensor that still has live PyObject references. "
262 "This probably happened because you took out a weak reference to "
263 "Tensor and didn't call _fix_weakref() after dereferencing it. "
264 "Subsequent accesses to this tensor via the PyObject will now fail.");
265 ((THPVariable*)pyobj)->cdata =
266 c10::MaybeOwned<torch::autograd::Variable>();
267 } else if (THPStorage_Check(pyobj)) {
268 TORCH_WARN(
269 "Deallocating UntypedStorage that still has live PyObject references. "
270 "This probably happened because you took out a weak reference to "
271 "UntypedStorage and didn't call _fix_weakref() after dereferencing it. "
272 "Subsequent accesses to this storage via the PyObject will now fail.");
273 ((THPStorage*)pyobj)->cdata = c10::MaybeOwned<c10::Storage>();
274 }
275 }
276 Py_DECREF(pyobj);
277 };
278
279 void ConcretePyInterpreterVTable::incref(PyObject* pyobj) const {
280 if (!Py_IsInitialized())
281 return;
282 pybind11::gil_scoped_acquire gil;
283 Py_INCREF(pyobj);
284 };
285
286 bool isPythonTensor(const at::Tensor& tensor) {
287 return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
288 }
289
290 void ConcretePyInterpreterVTable::reportErrorCallback(
291 PyObject* callback,
292 DispatchKey key) const {
293 py::gil_scoped_acquire g;
294 auto func = py::reinterpret_borrow<py::object>(callback);
295 // Not all DispatchKeys are pybind'ed into Python and we do not have infra
296 // to ensure this, so just pass a string back to Python.
297 func(c10::toString(key));
298 }
299
300 void ConcretePyInterpreterVTable::dispatch(
301 const c10::OperatorHandle& op,
302 torch::jit::Stack* stack) const {
303 const auto& schema = op.schema();
304 const auto num_arguments = schema.arguments().size();
305 auto arguments = torch::jit::pop(*stack, num_arguments);
306
307 // The plan: convert all the arguments back into PyObjects,
308 // extracting out the tensor handles, then call
309 // handle_torch_function_no_python_arg_parser
310 // NB: at the point arguments are pushed to the stack, ALL defaults
311 // are already present
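// From the Python side this surfaces as the standard protocol, e.g.
//   class MyTensor(torch.Tensor):
//       @classmethod
//       def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
//           ...  # func is the resolved OpOverload, e.g. torch.ops.aten.add.Tensor
// where `func` is the object returned by getTorchApiFunction(op) below.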
312
313 py::gil_scoped_acquire g;
314
315 std::vector<PyObject*> overloaded_args;
316 py::handle torch_api_function_overload = getTorchApiFunction(op);
317
318 // Find overloaded tensors
319 for (const auto idx : c10::irange(arguments.size())) {
320 const auto& ivalue = arguments[idx];
321 if (ivalue.isTensor()) {
322 const auto& tensor = ivalue.toTensor();
323 if (isPythonTensor(tensor)) {
324 append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
325 }
326 } else if (ivalue.isList()) {
327 const auto& list = ivalue.toListRef();
328 for (const auto jdx : c10::irange(list.size())) {
329 const auto& nv = list[jdx];
330 if (nv.isTensor()) {
331 const auto& tensor = nv.toTensor();
332 if (isPythonTensor(tensor)) {
333 append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
334 }
335 }
336 }
337 }
338 }
339
340 auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
341 auto args = std::move(args_kwargs.first);
342 auto kwargs = std::move(args_kwargs.second);
343
344 PyObject* obj = handle_torch_function_no_python_arg_parser(
345 overloaded_args,
346 args.ptr(),
347 kwargs.ptr(),
348 nullptr,
349 torch_api_function_overload.ptr(),
350 nullptr,
351 TorchFunctionName::TorchDispatch);
352 pushPyOutToStack(
353 op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__");
354 }
355
356 void ConcretePyInterpreterVTable::python_dispatcher(
357 const c10::OperatorHandle& op,
358 c10::DispatchKeySet ks,
359 torch::jit::Stack* stack) const {
360 py::gil_scoped_acquire g;
361 py::handle torch_api_function_overload = getTorchApiFunction(op);
362 // TODO: if necessary, can optimize to cache the cache lookup
363 // TODO: if necessary, can optimize OpOverload to have slots
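// Cache protocol (maintained on the Python side by OpOverload._dispatch_cache
// and _get_dispatch): the dict maps a DispatchKey either to a Python handler
// to call, or to a plain DispatchKey, which means "no Python logic registered,
// just run the C++ kernel for that key" (handled via callBoxedForDispatchKey
// below).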
364 auto cache = py::dict(torch_api_function_overload.attr("_dispatch_cache"));
365 if (cache.ptr() == nullptr) {
366 throw python_error();
367 }
368
369 c10::DispatchKey k = ks.highestPriorityTypeId();
370 // TODO: allow this to be non-owning
371 auto handler = py::reinterpret_borrow<py::object>(
372 PyDict_GetItem(cache.ptr(), py::cast(k).ptr()));
373 if (handler.ptr() == nullptr) {
374 // Slow path
375 handler = torch_api_function_overload.attr("_get_dispatch")(k);
376 }
377 if (py::isinstance<c10::DispatchKey>(handler)) {
378 // NB: not redispatch, as that will permanently remove the python
379 // dispatcher for subsequent redispatches
380 op.callBoxedForDispatchKey(py::cast<c10::DispatchKey>(handler), *stack);
381 return;
382 }
383
384 const auto& schema = op.schema();
385 const auto num_arguments = schema.arguments().size();
386 auto arguments = torch::jit::pop(*stack, num_arguments);
387
388 auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
389 auto args = std::move(args_kwargs.first);
390 auto kwargs = std::move(args_kwargs.second);
391
392 py::object obj = py::reinterpret_steal<py::object>(
393 PyObject_Call(handler.ptr(), args.ptr(), kwargs.ptr()));
394
395 if (obj.ptr() == nullptr) {
396 throw python_error();
397 }
398
399 pushPyOutToStack(op, stack, std::move(obj), "Python dispatcher");
400 }
401
402 c10::intrusive_ptr<c10::TensorImpl> ConcretePyInterpreterVTable::detach(
403 const c10::TensorImpl* self) const {
404 pybind11::gil_scoped_acquire gil;
405 at::impl::MaybeSetTLSOnEntryGuard guard;
406
407 auto out = torchDispatchFromTensorImpl(
408 self,
409 "detach",
410 py::module::import("torch")
411 .attr("ops")
412 .attr("aten")
413 .attr("detach")
414 .attr("default")
415 .ptr(),
416 "torch.ops.aten");
417
418 TORCH_CHECK(
419 THPVariable_Check(out.ptr()),
420 "detach returned invalid type ",
421 py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
422 ", expected Tensor");
423 const at::Tensor& res_t = THPVariable_Unpack(out.ptr());
424 return res_t.getIntrusivePtr();
425 }
426
427 bool ConcretePyInterpreterVTable::is_contiguous(
428 const c10::TensorImpl* self,
429 at::MemoryFormat memory_format) const {
430 pybind11::gil_scoped_acquire gil;
431 at::impl::MaybeSetTLSOnEntryGuard guard;
432
433 py::object out;
434 if (memory_format == at::MemoryFormat::Contiguous) {
435 // For backwards compatibility
436 out = torchDispatchFromTensorImpl(
437 self,
438 "is_contiguous",
439 py::module::import("torch")
440 .attr("ops")
441 .attr("aten")
442 .attr("is_contiguous")
443 .attr("default")
444 .ptr(),
445 "torch.ops.aten");
446 } else {
447 out = torchDispatchFromTensorImpl(
448 self,
449 "is_contiguous",
450 py::module::import("torch")
451 .attr("ops")
452 .attr("aten")
453 .attr("is_contiguous")
454 .attr("memory_format")
455 .ptr(),
456 "torch.ops.aten",
457 {py::cast(memory_format)});
458 }
459
460 if (out.is_none()) {
461 return self->is_contiguous_default(memory_format);
462 }
463
464 TORCH_CHECK(
465 PyBool_Check(out.ptr()),
466 "is_contiguous returned invalid type ",
467 py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
468 ", expected bool");
469
470 return PyObject_IsTrue(out.ptr());
471 }
472
473 bool ConcretePyInterpreterVTable::is_strides_like(
474 const c10::TensorImpl* self,
475 at::MemoryFormat memory_format) const {
476 pybind11::gil_scoped_acquire gil;
477 at::impl::MaybeSetTLSOnEntryGuard guard;
478
479 auto out = torchDispatchFromTensorImpl(
480 self,
481 "is_strides_like",
482 py::module::import("torch")
483 .attr("ops")
484 .attr("aten")
485 // NB: intentionally suffixed with _format to avoid
486 // triggering matches against "_like" suffix
487 .attr("is_strides_like_format")
488 .attr("default")
489 .ptr(),
490 "torch.ops.aten",
491 {py::cast(memory_format)});
492
493 if (out.is_none()) {
494 return self->is_strides_like_default(memory_format);
495 }
496
497 TORCH_CHECK(
498 PyBool_Check(out.ptr()),
499 "is_strides_like_format returned invalid type ",
500 py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
501 ", expected bool");
502
503 return PyObject_IsTrue(out.ptr());
504 }
505
506 bool ConcretePyInterpreterVTable::is_non_overlapping_and_dense(
507 const c10::TensorImpl* self) const {
508 pybind11::gil_scoped_acquire gil;
509 at::impl::MaybeSetTLSOnEntryGuard guard;
510
511 auto out = torchDispatchFromTensorImpl(
512 self,
513 "is_non_overlapping_and_dense",
514 py::module::import("torch")
515 .attr("ops")
516 .attr("aten")
517 .attr("is_non_overlapping_and_dense")
518 .attr("default")
519 .ptr(),
520 "torch.ops.aten");
521
522 if (out.is_none()) {
523 return self->is_non_overlapping_and_dense_default();
524 }
525
526 TORCH_CHECK(
527 PyBool_Check(out.ptr()),
528 "is_non_overlapping_and_dense returned invalid type ",
529 py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
530 ", expected bool");
531
532 return PyObject_IsTrue(out.ptr());
533 }
534
535 int64_t ConcretePyInterpreterVTable::dim(const c10::TensorImpl* self) const {
536 pybind11::gil_scoped_acquire gil;
537 at::impl::MaybeSetTLSOnEntryGuard guard;
538
539 auto out = torchDispatchFromTensorImpl(
540 self,
541 "dim",
542 py::module::import("torch")
543 .attr("ops")
544 .attr("aten")
545 .attr("dim")
546 .attr("default")
547 .ptr(),
548 "torch.ops.aten");
549
550 TORCH_CHECK(
551 PyLong_Check(out.ptr()),
552 "dim returned invalid type ",
553 py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
554 ", expected int");
555
556 return THPUtils_unpackLong(out.ptr());
557 }
558
559 c10::Device ConcretePyInterpreterVTable::device(
560 const c10::TensorImpl* self) const {
561 pybind11::gil_scoped_acquire gil;
562 at::impl::MaybeSetTLSOnEntryGuard guard;
563
564 auto out = torchDispatchFromTensorImpl(
565 self,
566 "device",
567 py::module::import("torch")
568 .attr("ops")
569 .attr("prim")
570 .attr("device")
571 .attr("default")
572 .ptr(),
573 "torch.ops.prim");
574
575 return toDevice(out.ptr());
576 }
577
578 static void set_tensor_attr_with_capsule(
579 const c10::TensorImpl* tensor,
580 py::capsule& capsule,
581 const char* attr_name) {
582 std::optional<PyObject*> mb_obj = tensor->pyobj_slot()->check_pyobj(
583 getPyInterpreter(), /*ignore_hermetic_tls=*/false);
584 TORCH_CHECK(
585 mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
586 auto obj = mb_obj.value();
587 py::handle(obj).attr(attr_name) = capsule;
588 }
589
590 // Note [Tensor Subclass custom size/stride caching strategy]
591 // Tensor subclasses can use __torch_dispatch__ to override size/stride calls.
592 // However, this presents a problem:
593 // (1) When you return a custom (maybe symbolic) size/stride
594 // from python, we need to stash this fresh vector of ints/symints
595 // somewhere so that it has the same lifetime as the tensor.
596 // (2) If the subclass experiences a metadata mutation,
597 // this stashed vector is no longer valid, so we need to allocate a fresh
598 // buffer to store the new sizes the next time someone asks for them.
599 //
600 // We handle this in the same way that `TensorImpl::sizes_default()`
601 // handles its buffer: we simply reallocate the buffer whenever
602 // the number of dimensions changes due to a resize.
603 // Notably, we do *not* reallocate the buffer if the values changed,
604 // but the number of dimensions stayed the same (e.g. `.transpose_()`).
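// Concretely, for a base_attr_name of e.g. "_sizes_capsule" the values live on
// the tensor's PyObject as `t._sizes_capsule` (a py::capsule owning a heap
// array of T) plus `t._sizes_capsule_len` (the logical length); see
// get_set_cached_attr below.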
605 template <typename T>
606 static c10::ArrayRef<T> get_set_cached_attr(
607 const c10::TensorImpl* tensor,
608 const char* base_attr_name,
609 const py::object& obj) {
610 std::optional<PyObject*> mb_obj =
611 tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
612 TORCH_CHECK(
613 mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
614 auto tensor_obj = mb_obj.value();
615 auto buffer_len_attr_name = std::string(base_attr_name) + std::string("_len");
616
617 bool is_buffer_allocated = false;
618 size_t curr_size = 0;
619 if (PyObject_HasAttrString(tensor_obj, buffer_len_attr_name.c_str())) {
620 auto len_pyobj = py::handle(tensor_obj).attr(buffer_len_attr_name.c_str());
621 curr_size = py::cast<size_t>(len_pyobj);
622 is_buffer_allocated = true;
623 }
624
625 size_t new_size = py::len(obj);
626
627 // We do the smallvector optimization here: any time the new_size is <=5,
628 // we always allocate our buffer to size 5, so that if the next resize
629 // is also to <=5 elements, we don't need to reallocate.
630 // Note: I tried removing this optimization and tripped ASAN
631 // in a batchnorm kernel here:
632 // https://pipelinesghubeus21.actions.githubusercontent.com/mBh68xKhi8LyM7tp3vECvYXNFvuV4gyVGgmYCteuEZP9JH92QN/_apis/pipelines/1/runs/3373307/signedlogcontent/790?urlExpires=2023-09-15T21%3A13%3A51.4327798Z&urlSigningMethod=HMACV1&urlSignature=tDeX7ZqaARVU5NNwyr5yYqqkWq3A2j4z8FFdqYwGr0Q%3D
633 // We should fix this instead.
634 bool needs_resize = false;
635 // We need to resize if:
636 // (1) we haven't allocated our buffer at all yet
637 // (2) Our buffer size is different from the new size
638 // (note: we use the small vector optimization, where our buffer
639 // is always allocated to at least size 5, and any resizes
640 // within the <= 5 regime do not require a reallocation).
641 auto is_smallvector = curr_size <= 5;
642 needs_resize = !is_buffer_allocated || (is_smallvector && new_size > 5) ||
643 (!is_smallvector && curr_size != new_size);
644 if (needs_resize) {
645 // If our current buffer is not the right size (either because we haven't
646 // allocated it yet, or there was a metadata mutation that changed the
647 // number of dims of the tensor), allocate a fresh buffer. Note that this
648 // will trash the previous buffer if there already was one, invalidating any
649 // existing SymIntArrayRef's from an old .sym_size() call.
650 auto new_buffer_size = new_size;
651 if (new_size <= 5) {
652 // This is the smallvector optimization
653 new_buffer_size = 5;
654 }
655 T* ptr = new T[new_buffer_size];
656 auto capsule =
657 py::capsule(ptr, [](void* p) { delete[] reinterpret_cast<T*>(p); });
658 int64_t idx = 0;
659 for (auto it = obj.begin(); it != obj.end(); ++it, ++idx) {
660 ptr[idx] = py::cast<T>(*it);
661 }
662 // Set the buffer
663 set_tensor_attr_with_capsule(tensor, capsule, base_attr_name);
664 // Set the len buffer
665 py::handle(tensor_obj).attr(buffer_len_attr_name.c_str()) = new_size;
666 } else {
667 TORCH_INTERNAL_ASSERT(PyObject_HasAttrString(tensor_obj, base_attr_name));
668 auto curr_buffer_pyobj = py::handle(tensor_obj).attr(base_attr_name);
669 void* buffer_pycapsule =
670 PyCapsule_GetPointer(curr_buffer_pyobj.ptr(), nullptr);
671 auto curr_buffer = reinterpret_cast<T*>(buffer_pycapsule);
672
673 // Overwrite the buffer with our new values, but only if any of them changed
674 // (due to a metadata mutation).
675 // This is technically not thread safe, because the update happens lazily.
676 // The original metadata mutation call on the tensor might have been thread
677 // safe (e.g. a .resize_() call), but we won't actually mutate the size
678 // buffer until the first call to .sizes() which the user might not access
679 // in a thread-safe way. For now we are not explicitly locking, but maybe we
680 // should.
681 int64_t idx = 0;
682 // Quick sanity assert that our buffer size is large enough
683 // to compare against all the elements in the new buffer.
684 size_t curr_buffer_size = 5;
685 if (curr_buffer_size < curr_size) {
686 curr_buffer_size = curr_size;
687 }
688 TORCH_INTERNAL_ASSERT(curr_buffer_size >= new_size);
689 for (auto it = obj.begin(); it != obj.end(); ++it, ++idx) {
690 auto actual_val = py::cast<T>(*it);
691 if constexpr (std::is_same_v<T, c10::SymInt>) {
692 // if our SymInts are symbolic, we are *not* doing an equality check on
693 // the symints. we just want to see if the nodes are the same. this is
694 // because we don't want to introduce any guards here.
695 if (!curr_buffer[idx].is_same(actual_val)) {
696 curr_buffer[idx] = actual_val;
697 }
698 } else {
699 if (curr_buffer[idx] != actual_val) {
700 curr_buffer[idx] = actual_val;
701 }
702 }
703 }
704 }
705
706 // The correct data is now stored at the buffer - read and return it.
707 auto curr_buffer_pyobj = py::handle(tensor_obj).attr(base_attr_name);
708 void* buffer_pycapsule =
709 PyCapsule_GetPointer(curr_buffer_pyobj.ptr(), nullptr);
710 auto curr_buffer = reinterpret_cast<T*>(buffer_pycapsule);
711 return c10::ArrayRef<T>(curr_buffer, new_size);
712 }
713
714 c10::IntArrayRef ConcretePyInterpreterVTable::strides(
715 const c10::TensorImpl* self) const {
716 pybind11::gil_scoped_acquire gil;
717 at::impl::MaybeSetTLSOnEntryGuard guard;
718
719 auto out = torchDispatchFromTensorImpl(
720 self,
721 "stride",
722 py::module::import("torch")
723 .attr("ops")
724 .attr("aten")
725 .attr("stride")
726 .attr("default")
727 .ptr(),
728 "torch.ops.aten");
729
730 if (out.is_none()) {
731 TORCH_CHECK(
732 !self->has_symbolic_sizes_strides(),
733 "Cannot call strides on a tensor with symbolic shapes/strides");
734 return self->strides_default();
735 }
736 TORCH_CHECK(
737 py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
738 "strides must be a list or a tuple");
739 auto updated_strides =
740 get_set_cached_attr<int64_t>(self, "_strides_capsule", out);
741 return updated_strides;
742 }
743
744 c10::IntArrayRef ConcretePyInterpreterVTable::sizes(
745 const c10::TensorImpl* self) const {
746 pybind11::gil_scoped_acquire gil;
747 at::impl::MaybeSetTLSOnEntryGuard guard;
748 HANDLE_TH_ERRORS
749 auto out = torchDispatchFromTensorImpl(
750 self,
751 "size",
752 py::module::import("torch")
753 .attr("ops")
754 .attr("aten")
755 .attr("size")
756 .attr("default")
757 .ptr(),
758 "torch.ops.aten");
759 if (out.is_none()) {
760 TORCH_CHECK(
761 !self->has_symbolic_sizes_strides(),
762 "Cannot call sizes on a tensor with symbolic shapes/strides");
763 return self->sizes_default();
764 }
765 TORCH_CHECK(
766 py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
767 "sizes must be a list or a tuple");
768
769 auto updated_sizes =
770 get_set_cached_attr<int64_t>(self, "_sizes_capsule", out);
771 return updated_sizes;
772 END_HANDLE_TH_ERRORS_PYBIND
773 }
774
775 c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_sizes(
776 const c10::TensorImpl* self) const {
777 pybind11::gil_scoped_acquire gil;
778 at::impl::MaybeSetTLSOnEntryGuard guard;
779 HANDLE_TH_ERRORS
780 auto out = torchDispatchFromTensorImpl(
781 self,
782 "sym_size",
783 py::module::import("torch")
784 .attr("ops")
785 .attr("aten")
786 .attr("sym_size")
787 .attr("default")
788 .ptr(),
789 "torch.ops.aten");
790
791 if (out.is_none()) {
792 return self->sym_sizes_default();
793 }
794 TORCH_CHECK(
795 py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
796 "sym_size must be a list or a tuple");
797
798 // See Note [Tensor Subclass custom size/stride caching strategy]
799 auto updated_sym_sizes =
800 get_set_cached_attr<c10::SymInt>(self, "_sym_sizes_capsule", out);
801 return updated_sym_sizes;
802 END_HANDLE_TH_ERRORS_PYBIND
803 }
804
805 c10::Layout ConcretePyInterpreterVTable::layout(
806 const c10::TensorImpl* self) const {
807 pybind11::gil_scoped_acquire gil;
808 at::impl::MaybeSetTLSOnEntryGuard guard;
809 auto out = torchDispatchFromTensorImpl(
810 self,
811 "layout",
812 py::module::import("torch")
813 .attr("ops")
814 .attr("prim")
815 .attr("layout")
816 .attr("default")
817 .ptr(),
818 "torch.ops.prim");
819
820 TORCH_CHECK(
821 THPLayout_Check(out.ptr()) || PyLong_Check(out.ptr()),
822 "layout returned invalid type ",
823 py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
824 ", expected Layout");
825
826 if (THPLayout_Check(out.ptr())) {
827 return toLayout(out.ptr());
828 } else {
829 return c10::Layout(py::cast<int64_t>(out));
830 }
831 }
832
833 int64_t ConcretePyInterpreterVTable::numel(const c10::TensorImpl* self) const {
834 pybind11::gil_scoped_acquire gil;
835 at::impl::MaybeSetTLSOnEntryGuard guard;
836 auto out = torchDispatchFromTensorImpl(
837 self,
838 "numel",
839 py::module::import("torch")
840 .attr("ops")
841 .attr("aten")
842 .attr("numel")
843 .attr("default")
844 .ptr(),
845 "torch.ops.aten");
846
847 if (out.is_none()) {
848 TORCH_CHECK(
849 !self->has_symbolic_sizes_strides(),
850 "Cannot call sizes on a tensor with symbolic shapes/strides");
851 return self->numel_default();
852 }
853 return py::cast<int64_t>(out);
854 }
855
856 c10::SymInt ConcretePyInterpreterVTable::sym_numel(
857 const c10::TensorImpl* self) const {
858 pybind11::gil_scoped_acquire gil;
859 at::impl::MaybeSetTLSOnEntryGuard guard;
860 auto out = torchDispatchFromTensorImpl(
861 self,
862 "sym_numel",
863 py::module::import("torch")
864 .attr("ops")
865 .attr("aten")
866 .attr("sym_numel")
867 .attr("default")
868 .ptr(),
869 "torch.ops.aten");
870
871 if (out.is_none()) {
872 return self->sym_numel_default();
873 }
874 return torch::is_symint(out) ? out.cast<c10::SymInt>()
875 : c10::SymInt{py::cast<int64_t>(out)};
876 }
877
878 c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
879 const c10::TensorImpl* self) const {
880 pybind11::gil_scoped_acquire gil;
881 at::impl::MaybeSetTLSOnEntryGuard guard;
882 auto out = torchDispatchFromTensorImpl(
883 self,
884 "sym_storage_offset",
885 py::module::import("torch")
886 .attr("ops")
887 .attr("aten")
888 .attr("sym_storage_offset")
889 .attr("default")
890 .ptr(),
891 "torch.ops.aten");
892
893 if (out.is_none()) {
894 return self->sym_storage_offset_default();
895 }
896 return torch::is_symint(out) ? out.cast<c10::SymInt>()
897 : c10::SymInt{py::cast<int64_t>(out)};
898 }
899
900 c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
901 const c10::TensorImpl* self) const {
902 pybind11::gil_scoped_acquire gil;
903 at::impl::MaybeSetTLSOnEntryGuard guard;
904 HANDLE_TH_ERRORS
905 auto out = torchDispatchFromTensorImpl(
906 self,
907 "sym_stride",
908 py::module::import("torch")
909 .attr("ops")
910 .attr("aten")
911 .attr("sym_stride")
912 .attr("default")
913 .ptr(),
914 "torch.ops.aten");
915
916 if (out.is_none()) {
917 return self->sym_strides_default();
918 }
919 // We need to squeeze SymIntNodes and ints into `SymInts`
920 // since that is the format `sym_strides()` is stored in
921 TORCH_CHECK(
922 py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
923 "sym_strides must be a list or a tuple");
924
925 auto updated_sym_strides =
926 get_set_cached_attr<c10::SymInt>(self, "_sym_strides_capsule", out);
927 return updated_sym_strides;
928 END_HANDLE_TH_ERRORS_PYBIND
929 }
930
931 void ConcretePyInterpreterVTable::reset_backward_hooks(
932 const c10::TensorImpl* self) const {
933 pybind11::gil_scoped_acquire gil;
934 at::impl::MaybeSetTLSOnEntryGuard guard;
935 HANDLE_TH_ERRORS
936 Tensor self_t =
937 Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
938 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
939 unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
940 auto self_p =
941 py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t)));
942 PyObject_SetAttrString(self_p.ptr(), "_backward_hooks", Py_None);
943 END_HANDLE_TH_ERRORS_PYBIND
944 }
945
946 std::string ConcretePyInterpreterVTable::name() const {
947 std::stringstream ss;
948 ss << getPyInterpreter();
949 return ss.str();
950 }
951
952 PyInterpreterHolder self_interpreter;
953
954 } // anonymous namespace
955
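// Resolves (and caches on the OperatorHandle, keyed by interpreter) the Python
// OpOverload for an operator, e.g. "aten::add" with overload name "Tensor"
// becomes torch.ops.aten.add.Tensor; operators without an overload name
// resolve to `.default`.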
956 py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
957 return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
958 // Parse the name into namespace and name (no overload_name)
959 // TODO: put this into the library
960 const auto& schema = op.schema();
961 const auto& qualified_name = op.operator_name().name;
962 const auto& overload_name = schema.overload_name();
963 auto pos = qualified_name.find("::");
964 TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
965 // Make me some null terminated strings
966 std::string ns_str = qualified_name.substr(0, pos);
967 const char* ns = ns_str.c_str();
968 const char* func_name = qualified_name.c_str() + pos + strlen("::");
969
970 py::handle torch_api_function =
971 py::module::import("torch").attr("ops").attr(ns).attr(func_name);
972 if (overload_name.empty()) {
973 return torch_api_function.attr("default").ptr();
974 } else {
975 return torch_api_function.attr(overload_name.c_str()).ptr();
976 }
977 });
978 }
979
980 } // namespace torch::detail
981
982 c10::impl::PyInterpreter* getPyInterpreter() {
983 return torch::detail::self_interpreter.get();
984 }
985
986 bool isMainPyInterpreter() {
987 return torch::detail::self_interpreter.is_main_interpreter();
988 }
989