#include <ATen/core/PythonFallbackKernel.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_dispatch.h>

#include <string>

using namespace torch;
using namespace at;
using namespace c10;

namespace torch::detail {

namespace {

// NB: This is a macro and not a template function (like it was before)
// because passing in constexpr char* as template argument breaks some
// versions of MSVC that are being used internally at Meta.
// MSVC 14.16.27023 (vs2017_15.9)
#define CONCRETE_GPU_TRACE(device_type, func_name, ...)                       \
  at::impl::MaybeSetTLSOnEntryGuard guard;                                    \
  if (Py_IsInitialized()) {                                                   \
    pybind11::gil_scoped_acquire gil;                                         \
    try {                                                                     \
      /* Masquerade hip as cuda because hip uses `torch.cuda` module. */      \
      if (device_type == at::kHIP) {                                          \
        device_type = at::kCUDA;                                              \
      }                                                                       \
      std::string module_name = "torch." + DeviceTypeName(device_type, true); \
      py::module mod = py::module::import(module_name.c_str());               \
      py::object hook =                                                       \
          mod.attr("_gpu_trace").attr(func_name).attr("fire_callbacks");      \
      hook(__VA_ARGS__);                                                      \
    } catch (const std::exception& e) {                                       \
      LOG(ERROR) << device_type                                               \
                 << " trace hook execution failed: " << e.what();             \
    }                                                                         \
  }
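
// For reference, a hand-expanded sketch (not the real preprocessor output) of
// a single use, `CONCRETE_GPU_TRACE(device_type, "EventCreationCallbacks", event)`:
//
//   at::impl::MaybeSetTLSOnEntryGuard guard;
//   if (Py_IsInitialized()) {
//     pybind11::gil_scoped_acquire gil;
//     try {
//       if (device_type == at::kHIP) {
//         device_type = at::kCUDA;
//       }
//       std::string module_name = "torch." + DeviceTypeName(device_type, true);
//       py::module mod = py::module::import(module_name.c_str());
//       py::object hook =
//           mod.attr("_gpu_trace").attr("EventCreationCallbacks").attr("fire_callbacks");
//       hook(event);
//     } catch (const std::exception& e) {
//       LOG(ERROR) << device_type << " trace hook execution failed: " << e.what();
//     }
//   }
//
// i.e. it fires whatever callbacks Python registered on
// torch.<device>._gpu_trace.EventCreationCallbacks.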

struct ConcretePyInterpreterVTable final
    : public c10::impl::PyInterpreterVTable {
  std::string name() const override;

  void incref(PyObject* pyobj) const override;
  void decref(PyObject* pyobj, bool has_pyobj_slot) const override;

  // TODO: Need to make this work for StorageImpl too. I imagine I'll want to
  // operate upon a PyObjectSlot rather than a TensorImpl
  c10::intrusive_ptr<c10::TensorImpl> detach(
      const c10::TensorImpl* self) const override;

  void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
      const override;
  void reportErrorCallback(PyObject* callback, DispatchKey key) const override;
  void python_dispatcher(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet,
      torch::jit::Stack* stack) const override;
  // NB: this is defined in python_dispatch.cpp
  void python_op_registration_trampoline(
      const c10::OperatorHandle& op,
      c10::DispatchKey key,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack,
      bool with_keyset,
      bool with_op) const override {
    torch::impl::dispatch::python_op_registration_trampoline_impl(
        op, key, keyset, stack, with_keyset, with_op);
  }
  void throw_abstract_impl_not_imported_error(
      std::string opname,
      const char* pymodule,
      const char* context) const override {
    py::gil_scoped_acquire gil;
    pybind11::module::import("torch._utils_internal")
        .attr("throw_abstract_impl_not_imported_error")(
            opname, pymodule, context);
  }

  bool is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
      const override;
  bool is_strides_like(const c10::TensorImpl* self, at::MemoryFormat)
      const override;
  bool is_non_overlapping_and_dense(const c10::TensorImpl* self) const override;
  c10::Device device(const c10::TensorImpl* self) const override;
  int64_t dim(const c10::TensorImpl* self) const override;
  c10::IntArrayRef strides(const c10::TensorImpl* self) const override;
  c10::IntArrayRef sizes(const c10::TensorImpl* self) const override;
  c10::SymIntArrayRef sym_sizes(const c10::TensorImpl* self) const override;
  c10::Layout layout(const c10::TensorImpl* self) const override;
  int64_t numel(const c10::TensorImpl* self) const override;
  c10::SymInt sym_numel(const c10::TensorImpl* self) const override;
  c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override;
  c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override;

  void trace_gpu_event_creation(at::DeviceType device_type, uintptr_t event)
      const override {
    CONCRETE_GPU_TRACE(device_type, "EventCreationCallbacks", event);
  }
  void trace_gpu_event_deletion(at::DeviceType device_type, uintptr_t event)
      const override {
    CONCRETE_GPU_TRACE(device_type, "EventDeletionCallbacks", event);
  }
  void trace_gpu_event_record(
      at::DeviceType device_type,
      uintptr_t event,
      uintptr_t stream) const override {
    CONCRETE_GPU_TRACE(device_type, "EventRecordCallbacks", event, stream);
  }
  void trace_gpu_event_wait(
      at::DeviceType device_type,
      uintptr_t event,
      uintptr_t stream) const override {
    CONCRETE_GPU_TRACE(device_type, "EventWaitCallbacks", event, stream);
  }
  void trace_gpu_memory_allocation(at::DeviceType device_type, uintptr_t ptr)
      const override {
    CONCRETE_GPU_TRACE(device_type, "MemoryAllocationCallbacks", ptr);
  }
  void trace_gpu_memory_deallocation(at::DeviceType device_type, uintptr_t ptr)
      const override {
    CONCRETE_GPU_TRACE(device_type, "MemoryDeallocationCallbacks", ptr);
  }
  void trace_gpu_stream_creation(at::DeviceType device_type, uintptr_t stream)
      const override {
    CONCRETE_GPU_TRACE(device_type, "StreamCreationCallbacks", stream);
  }
  void trace_gpu_device_synchronization(
      at::DeviceType device_type) const override {
    CONCRETE_GPU_TRACE(device_type, "DeviceSynchronizationCallbacks");
  }
  void trace_gpu_stream_synchronization(
      at::DeviceType device_type,
      uintptr_t stream) const override {
    CONCRETE_GPU_TRACE(device_type, "StreamSynchronizationCallbacks", stream);
  }
  void trace_gpu_event_synchronization(
      at::DeviceType device_type,
      uintptr_t event) const override {
    CONCRETE_GPU_TRACE(device_type, "EventSynchronizationCallbacks", event);
  }

  void reset_backward_hooks(const c10::TensorImpl* self) const override;

  static ConcretePyInterpreterVTable* instance() {
    static ConcretePyInterpreterVTable s;
    return &s;
  }
};

class PyInterpreterHolder {
 public:
  PyInterpreterHolder()
      : impl_(new c10::impl::PyInterpreter(
            ConcretePyInterpreterVTable::instance())),
        is_main_interpreter_(
            at::impl::PythonOpRegistrationTrampoline::registerInterpreter(
                impl_)) {}
  // NB: intentionally leaks the PyInterpreter, as there may still be live
  // references to it in objects that are not destructed while Python is
  // being cleaned up.
  ~PyInterpreterHolder() {
    impl_->disarm();
  }
  c10::impl::PyInterpreter* get() const noexcept {
    return impl_;
  }
  bool is_main_interpreter() const noexcept {
    return is_main_interpreter_;
  }

 private:
  c10::impl::PyInterpreter* impl_;
  bool is_main_interpreter_;
};

py::object torchDispatchFromTensorImpl(
    const c10::TensorImpl* self,
    const char* func_name,
    PyObject* torch_api_function,
    const char* module_name,
    // WARNING: MUST NOT BE TENSOR ARGS
    c10::SmallVector<py::object, 1> extra_args = {}) {
  if (torch_api_function == nullptr) {
    throw python_error();
  }
  TORCH_CHECK(
      PyGILState_Check(),
      "GIL must be held before you call parseIValuesToPyArgsKwargs");

  std::vector<PyObject*> overloaded_args;
  // TODO: there should be a shorter way to spell this
  // TODO: fix the constness of target
  at::Tensor self_t = at::Tensor(
      c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
  auto self_p =
      py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t)));
  // NB: this may not be a python tensor if you got here from a mode!
  // TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
  append_overloaded_tensor(&overloaded_args, self_p.ptr());
  auto args = py::reinterpret_steal<py::object>(
      PyTuple_New(static_cast<Py_ssize_t>(1 + extra_args.size())));
  PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
  int64_t i = 1;
  for (auto& a : extra_args) {
    if (a.ptr() == nullptr)
      throw python_error();
    PyTuple_SET_ITEM(args.ptr(), i, std::move(a).release().ptr());
    i++;
  }

  py::dict kwargs;

  return py::reinterpret_steal<py::object>(
      handle_torch_function_no_python_arg_parser(
          overloaded_args,
          args.ptr(),
          kwargs.ptr(),
          func_name,
          torch_api_function,
          module_name,
          TorchFunctionName::TorchDispatch));
}
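
// Usage sketch (for illustration only; the real call sites are the vtable
// methods below, e.g. dim() and device()):
//
//   auto out = torchDispatchFromTensorImpl(
//       self,
//       "dim",
//       py::module::import("torch")
//           .attr("ops").attr("aten").attr("dim").attr("default").ptr(),
//       "torch.ops.aten");
//
// which behaves roughly like calling torch.ops.aten.dim.default(self) from
// Python and letting the tensor subclass's __torch_dispatch__ handle it.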

// NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
// Before calling PyInterpreter::decref, we must statically know if the
// pyobj has a PyObjectSlot or not.
// - If it has a PyObjectSlot, we need to be careful about PyObject resurrection
// - If it does not have a PyObjectSlot, we can freely decref
// One alternative to this is using PyObject_IsInstance
// to get at this information. However, we don't want to risk an incorrect
// `__instancecheck__` changing the semantics here.
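//
// For illustration, a hedged sketch of the two cases the flag distinguishes
// (the exact call sites live in c10's PyObjectSlot/TensorImpl teardown and
// are not shown in this file):
//
//   // pyobj backs a Tensor/Storage and has a PyObjectSlot:
//   vtable->decref(pyobj, /*has_pyobj_slot=*/true);   // resurrection-aware
//   // pyobj is some other Python object we merely hold a reference to:
//   vtable->decref(pyobj, /*has_pyobj_slot=*/false);  // plain Py_DECREF path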
void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
    const {
  // Leak the pyobj if not initialized.  This can happen if we are running
  // exit handlers that are destructing tensors with residual (owned)
  // PyObjects stored in them.
  if (!Py_IsInitialized())
    return;

  pybind11::gil_scoped_acquire gil;
  // Two possibilities:
  // 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or
  // Storage. Then we must be careful about PyObject resurrection (see
  // THPVariable_clear).
  // 2. We are decref-ing some other Python object. We don't do
  // PyObject resurrection on non-Tensors, so we just carry on as usual
  if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) {
    if (THPVariable_Check(pyobj)) {
      // It's still alive!  This can happen if a weak ref resurrected
      // the PyObject without flipping ownership.  At this point it is
      // too late to rescue the object, so just stub out the PyObject
      // so that it fails on subsequent uses.  Don't raise an error here;
      // you're probably in a destructor.
      TORCH_WARN(
          "Deallocating Tensor that still has live PyObject references.  "
          "This probably happened because you took out a weak reference to "
          "Tensor and didn't call _fix_weakref() after dereferencing it.  "
          "Subsequent accesses to this tensor via the PyObject will now fail.");
      ((THPVariable*)pyobj)->cdata =
          c10::MaybeOwned<torch::autograd::Variable>();
    } else if (THPStorage_Check(pyobj)) {
      TORCH_WARN(
          "Deallocating UntypedStorage that still has live PyObject references.  "
          "This probably happened because you took out a weak reference to "
          "UntypedStorage and didn't call _fix_weakref() after dereferencing it.  "
          "Subsequent accesses to this storage via the PyObject will now fail.");
      ((THPStorage*)pyobj)->cdata = c10::MaybeOwned<c10::Storage>();
    }
  }
  Py_DECREF(pyobj);
};

void ConcretePyInterpreterVTable::incref(PyObject* pyobj) const {
  if (!Py_IsInitialized())
    return;
  pybind11::gil_scoped_acquire gil;
  Py_INCREF(pyobj);
};

bool isPythonTensor(const at::Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
}

void ConcretePyInterpreterVTable::reportErrorCallback(
    PyObject* callback,
    DispatchKey key) const {
  py::gil_scoped_acquire g;
  auto func = py::reinterpret_borrow<py::object>(callback);
  // Not all DispatchKeys are pybind'ed into Python and we do not have infra
  // to ensure this, so just pass a string back to Python.
  func(c10::toString(key));
}

void ConcretePyInterpreterVTable::dispatch(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) const {
  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  auto arguments = torch::jit::pop(*stack, num_arguments);

  // The plan: convert all the arguments back into PyObjects,
  // extracting out the tensor handles, then call
  // handle_torch_function_no_python_arg_parser
  // NB: at the point arguments are pushed to the stack, ALL defaults
  // are already present

  py::gil_scoped_acquire g;

  std::vector<PyObject*> overloaded_args;
  py::handle torch_api_function_overload = getTorchApiFunction(op);

  // Find overloaded tensors
  for (const auto idx : c10::irange(arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      const auto& tensor = ivalue.toTensor();
      if (isPythonTensor(tensor)) {
        append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
      }
    } else if (ivalue.isList()) {
      const auto& list = ivalue.toListRef();
      for (const auto jdx : c10::irange(list.size())) {
        const auto& nv = list[jdx];
        if (nv.isTensor()) {
          const auto& tensor = nv.toTensor();
          if (isPythonTensor(tensor)) {
            append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
          }
        }
      }
    }
  }

  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  auto args = std::move(args_kwargs.first);
  auto kwargs = std::move(args_kwargs.second);

  PyObject* obj = handle_torch_function_no_python_arg_parser(
      overloaded_args,
      args.ptr(),
      kwargs.ptr(),
      nullptr,
      torch_api_function_overload.ptr(),
      nullptr,
      TorchFunctionName::TorchDispatch);
  pushPyOutToStack(
      op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__");
}

void ConcretePyInterpreterVTable::python_dispatcher(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet ks,
    torch::jit::Stack* stack) const {
  py::gil_scoped_acquire g;
  py::handle torch_api_function_overload = getTorchApiFunction(op);
  // TODO: if necessary, can optimize to cache the cache lookup
  // TODO: if necessary, can optimize OpOverload to have slots
  auto cache = py::dict(torch_api_function_overload.attr("_dispatch_cache"));
  if (cache.ptr() == nullptr) {
    throw python_error();
  }

  c10::DispatchKey k = ks.highestPriorityTypeId();
  // TODO: allow this to be non-owning
  auto handler = py::reinterpret_borrow<py::object>(
      PyDict_GetItem(cache.ptr(), py::cast(k).ptr()));
  if (handler.ptr() == nullptr) {
    // Slow path
    handler = torch_api_function_overload.attr("_get_dispatch")(k);
  }
  if (py::isinstance<c10::DispatchKey>(handler)) {
    // NB: not redispatch, as that will permanently remove the python
    // dispatcher for subsequent redispatches
    op.callBoxedForDispatchKey(py::cast<c10::DispatchKey>(handler), *stack);
    return;
  }

  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  auto arguments = torch::jit::pop(*stack, num_arguments);

  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  auto args = std::move(args_kwargs.first);
  auto kwargs = std::move(args_kwargs.second);

  py::object obj = py::reinterpret_steal<py::object>(
      PyObject_Call(handler.ptr(), args.ptr(), kwargs.ptr()));

  if (obj.ptr() == nullptr) {
    throw python_error();
  }

  pushPyOutToStack(op, stack, std::move(obj), "Python dispatcher");
}

c10::intrusive_ptr<c10::TensorImpl> ConcretePyInterpreterVTable::detach(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "detach",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("detach")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  TORCH_CHECK(
      THPVariable_Check(out.ptr()),
      "detach returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected Tensor");
  const at::Tensor& res_t = THPVariable_Unpack(out.ptr());
  return res_t.getIntrusivePtr();
}

bool ConcretePyInterpreterVTable::is_contiguous(
    const c10::TensorImpl* self,
    at::MemoryFormat memory_format) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  py::object out;
  if (memory_format == at::MemoryFormat::Contiguous) {
    // For backwards compatibility
    out = torchDispatchFromTensorImpl(
        self,
        "is_contiguous",
        py::module::import("torch")
            .attr("ops")
            .attr("aten")
            .attr("is_contiguous")
            .attr("default")
            .ptr(),
        "torch.ops.aten");
  } else {
    out = torchDispatchFromTensorImpl(
        self,
        "is_contiguous",
        py::module::import("torch")
            .attr("ops")
            .attr("aten")
            .attr("is_contiguous")
            .attr("memory_format")
            .ptr(),
        "torch.ops.aten",
        {py::cast(memory_format)});
  }

  if (out.is_none()) {
    return self->is_contiguous_default(memory_format);
  }

  TORCH_CHECK(
      PyBool_Check(out.ptr()),
      "is_contiguous returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected bool");

  return PyObject_IsTrue(out.ptr());
}

bool ConcretePyInterpreterVTable::is_strides_like(
    const c10::TensorImpl* self,
    at::MemoryFormat memory_format) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "is_strides_like",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          // NB: intentionally suffixed with _format to avoid
          // triggering matches against "_like" suffix
          .attr("is_strides_like_format")
          .attr("default")
          .ptr(),
      "torch.ops.aten",
      {py::cast(memory_format)});

  if (out.is_none()) {
    return self->is_strides_like_default(memory_format);
  }

  TORCH_CHECK(
      PyBool_Check(out.ptr()),
      "is_strides_like_format returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected bool");

  return PyObject_IsTrue(out.ptr());
}

bool ConcretePyInterpreterVTable::is_non_overlapping_and_dense(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "is_non_overlapping_and_dense",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("is_non_overlapping_and_dense")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->is_non_overlapping_and_dense_default();
  }

  TORCH_CHECK(
      PyBool_Check(out.ptr()),
      "is_non_overlapping_and_dense returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected bool");

  return PyObject_IsTrue(out.ptr());
}

int64_t ConcretePyInterpreterVTable::dim(const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "dim",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("dim")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  TORCH_CHECK(
      PyLong_Check(out.ptr()),
      "dim returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected int");

  return THPUtils_unpackLong(out.ptr());
}

c10::Device ConcretePyInterpreterVTable::device(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "device",
      py::module::import("torch")
          .attr("ops")
          .attr("prim")
          .attr("device")
          .attr("default")
          .ptr(),
      "torch.ops.prim");

  return toDevice(out.ptr());
}

static void set_tensor_attr_with_capsule(
    const c10::TensorImpl* tensor,
    py::capsule& capsule,
    const char* attr_name) {
  std::optional<PyObject*> mb_obj = tensor->pyobj_slot()->check_pyobj(
      getPyInterpreter(), /*ignore_hermetic_tls=*/false);
  TORCH_CHECK(
      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
  auto obj = mb_obj.value();
  py::handle(obj).attr(attr_name) = capsule;
}

// Note [Tensor Subclass custom size/stride caching strategy]
// Tensor subclasses can use __torch_dispatch__ to override size/stride calls.
// However, this presents a problem:
// (1) When you return a custom (maybe symbolic) size/stride
//     from python, we need to stash this fresh vector of ints/symints
//     somewhere so that it has the same lifetime as the tensor.
// (2) If the subclass experiences a metadata mutation,
//     this stashed vector is no longer valid, so we need to allocate a fresh
//     buffer to store the new sizes the next time someone asks for them.
//
// We handle this in the same way that `TensorImpl::sizes_default()`
// handles its buffer: we simply reallocate the buffer whenever
// the number of dimensions changes due to a resize.
// Notably, we do *not* reallocate the buffer if the values changed,
// but the number of dimensions stayed the same (e.g. `.transpose_()`).
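//
// Concretely (for illustration), for a base attribute name such as
// "_sym_sizes_capsule", the helper below stashes two attributes on the
// tensor's PyObject:
//   t._sym_sizes_capsule      -> a PyCapsule owning a heap-allocated T[] buffer
//   t._sym_sizes_capsule_len  -> the logical length of that buffer
// and returns a c10::ArrayRef<T> viewing the capsule's storage.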
template <typename T>
static c10::ArrayRef<T> get_set_cached_attr(
    const c10::TensorImpl* tensor,
    const char* base_attr_name,
    const py::object& obj) {
  std::optional<PyObject*> mb_obj =
      tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
  TORCH_CHECK(
      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
  auto tensor_obj = mb_obj.value();
  auto buffer_len_attr_name = std::string(base_attr_name) + std::string("_len");

  bool is_buffer_allocated = false;
  size_t curr_size = 0;
  if (PyObject_HasAttrString(tensor_obj, buffer_len_attr_name.c_str())) {
    auto len_pyobj = py::handle(tensor_obj).attr(buffer_len_attr_name.c_str());
    curr_size = py::cast<size_t>(len_pyobj);
    is_buffer_allocated = true;
  }

  size_t new_size = py::len(obj);

  // We do the smallvector optimization here: any time the new_size is <=5,
  // we always allocate our buffer to size 5, so that if the next resize
  // is also to <=5 elements, we don't need to reallocate.
  // Note: I tried removing this optimization and tripped ASAN
  // in a batchnorm kernel here:
  // https://pipelinesghubeus21.actions.githubusercontent.com/mBh68xKhi8LyM7tp3vECvYXNFvuV4gyVGgmYCteuEZP9JH92QN/_apis/pipelines/1/runs/3373307/signedlogcontent/790?urlExpires=2023-09-15T21%3A13%3A51.4327798Z&urlSigningMethod=HMACV1&urlSignature=tDeX7ZqaARVU5NNwyr5yYqqkWq3A2j4z8FFdqYwGr0Q%3D
  // We should fix this instead.
  bool needs_resize = false;
  // We need to resize if:
  // (1) we haven't allocated our buffer at all yet
  // (2) our buffer size is different from the new size
  //     (note: we use the small vector optimization, where our buffer
  //     is always allocated to at least size 5, so resizes
  //     within the <= 5 regime do not require a reallocation).
  auto is_smallvector = curr_size <= 5;
  needs_resize = !is_buffer_allocated || (is_smallvector && new_size > 5) ||
      (!is_smallvector && curr_size != new_size);
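  // For example (values for illustration): a buffer tracking 3 elements is
  // physically 5 long, so growing to 4 elements does not reallocate; growing
  // from 3 to 7 does; shrinking from 8 to 6 does; staying at 8 does not.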
  if (needs_resize) {
    // If our current buffer is not the right size (either because we haven't
    // allocated it yet, or there was a metadata mutation that changed the
    // number of dims of the tensor), allocate a fresh buffer. Note that this
    // will trash the previous buffer if there already was one, invalidating any
    // existing SymIntArrayRef's from an old .sym_size() call.
    auto new_buffer_size = new_size;
    if (new_size <= 5) {
      // This is the smallvector optimization
      new_buffer_size = 5;
    }
    T* ptr = new T[new_buffer_size];
    auto capsule =
        py::capsule(ptr, [](void* p) { delete[] reinterpret_cast<T*>(p); });
    int64_t idx = 0;
    for (auto it = obj.begin(); it != obj.end(); ++it, ++idx) {
      ptr[idx] = py::cast<T>(*it);
    }
    // Set the buffer
    set_tensor_attr_with_capsule(tensor, capsule, base_attr_name);
    // Set the len buffer
    py::handle(tensor_obj).attr(buffer_len_attr_name.c_str()) = new_size;
  } else {
    TORCH_INTERNAL_ASSERT(PyObject_HasAttrString(tensor_obj, base_attr_name));
    auto curr_buffer_pyobj = py::handle(tensor_obj).attr(base_attr_name);
    void* buffer_pycapsule =
        PyCapsule_GetPointer(curr_buffer_pyobj.ptr(), nullptr);
    auto curr_buffer = reinterpret_cast<T*>(buffer_pycapsule);

    // Overwrite the buffer with our new values, but only if any of them changed
    // (due to a metadata mutation).
    // This is technically not thread safe, because the update happens lazily.
    // The original metadata mutation call on the tensor might have been thread
    // safe (e.g. a .resize_() call), but we won't actually mutate the size
    // buffer until the first call to .sizes() which the user might not access
    // in a thread-safe way. For now we are not explicitly locking, but maybe we
    // should.
    int64_t idx = 0;
    // Quick sanity assert that our buffer size is large enough
    // to compare against all the elements in the new buffer.
    size_t curr_buffer_size = 5;
    if (curr_buffer_size < curr_size) {
      curr_buffer_size = curr_size;
    }
    TORCH_INTERNAL_ASSERT(curr_buffer_size >= new_size);
    for (auto it = obj.begin(); it != obj.end(); ++it, ++idx) {
      auto actual_val = py::cast<T>(*it);
      if constexpr (std::is_same_v<T, c10::SymInt>) {
        // if our SymInts are symbolic, we are *not* doing an equality check on
        // the symints. we just want to see if the nodes are the same. this is
        // because we don't want to introduce any guards here.
        if (!curr_buffer[idx].is_same(actual_val)) {
          curr_buffer[idx] = actual_val;
        }
      } else {
        if (curr_buffer[idx] != actual_val) {
          curr_buffer[idx] = actual_val;
        }
      }
    }
  }

  // The correct data is now stored at the buffer - read and return it.
  auto curr_buffer_pyobj = py::handle(tensor_obj).attr(base_attr_name);
  void* buffer_pycapsule =
      PyCapsule_GetPointer(curr_buffer_pyobj.ptr(), nullptr);
  auto curr_buffer = reinterpret_cast<T*>(buffer_pycapsule);
  return c10::ArrayRef<T>(curr_buffer, new_size);
}

c10::IntArrayRef ConcretePyInterpreterVTable::strides(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "stride",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("stride")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call strides on a tensor with symbolic shapes/strides");
    return self->strides_default();
  }
  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "strides must be a list or a tuple");
  auto updated_strides =
      get_set_cached_attr<int64_t>(self, "_strides_capsule", out);
  return updated_strides;
}

c10::IntArrayRef ConcretePyInterpreterVTable::sizes(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  auto out = torchDispatchFromTensorImpl(
      self,
      "size",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("size")
          .attr("default")
          .ptr(),
      "torch.ops.aten");
  if (out.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call sizes on a tensor with symbolic shapes/strides");
    return self->sizes_default();
  }
  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "sizes must be a list or a tuple");

  auto updated_sizes =
      get_set_cached_attr<int64_t>(self, "_sizes_capsule", out);
  return updated_sizes;
  END_HANDLE_TH_ERRORS_PYBIND
}

c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_sizes(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_size",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_size")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_sizes_default();
  }
  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "sym_size must be a list or a tuple");

  // See Note [Tensor Subclass custom size/stride caching strategy]
  auto updated_sym_sizes =
      get_set_cached_attr<c10::SymInt>(self, "_sym_sizes_capsule", out);
  return updated_sym_sizes;
  END_HANDLE_TH_ERRORS_PYBIND
}

c10::Layout ConcretePyInterpreterVTable::layout(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  auto out = torchDispatchFromTensorImpl(
      self,
      "layout",
      py::module::import("torch")
          .attr("ops")
          .attr("prim")
          .attr("layout")
          .attr("default")
          .ptr(),
      "torch.ops.prim");

  TORCH_CHECK(
      THPLayout_Check(out.ptr()) || PyLong_Check(out.ptr()),
      "layout returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected Layout");

  if (THPLayout_Check(out.ptr())) {
    return toLayout(out.ptr());
  } else {
    return c10::Layout(py::cast<int64_t>(out));
  }
}

int64_t ConcretePyInterpreterVTable::numel(const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  auto out = torchDispatchFromTensorImpl(
      self,
      "numel",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("numel")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call numel on a tensor with symbolic shapes/strides");
    return self->numel_default();
  }
  return py::cast<int64_t>(out);
}

c10::SymInt ConcretePyInterpreterVTable::sym_numel(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_numel",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_numel")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_numel_default();
  }
  return torch::is_symint(out) ? out.cast<c10::SymInt>()
                               : c10::SymInt{py::cast<int64_t>(out)};
}

c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_storage_offset",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_storage_offset")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_storage_offset_default();
  }
  return torch::is_symint(out) ? out.cast<c10::SymInt>()
                               : c10::SymInt{py::cast<int64_t>(out)};
}

c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_stride",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_stride")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_strides_default();
  }
  // We need to squeeze SymIntNodes and ints into `SymInt`s,
  // since that's the format `sym_strides()` is stored in.
  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "sym_strides must be a list or a tuple");

  auto updated_sym_strides =
      get_set_cached_attr<c10::SymInt>(self, "_sym_strides_capsule", out);
  return updated_sym_strides;
  END_HANDLE_TH_ERRORS_PYBIND
}

void ConcretePyInterpreterVTable::reset_backward_hooks(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  Tensor self_t =
      Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
                 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
             unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
  auto self_p =
      py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t)));
  PyObject_SetAttrString(self_p.ptr(), "_backward_hooks", Py_None);
  END_HANDLE_TH_ERRORS_PYBIND
}

std::string ConcretePyInterpreterVTable::name() const {
  std::stringstream ss;
  ss << getPyInterpreter();
  return ss.str();
}

PyInterpreterHolder self_interpreter;

} // anonymous namespace

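// getTorchApiFunction maps an operator handle to the corresponding
// torch.ops.<ns>.<name>.<overload> object. For example (illustrative names):
// qualified name "aten::add" with overload name "Tensor" resolves to
// torch.ops.aten.add.Tensor, while an empty overload name resolves to
// torch.ops.aten.add.default.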
py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
  return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
    // Parse the name into namespace and name (no overload_name)
    // TODO: put this into the library
    const auto& schema = op.schema();
    const auto& qualified_name = op.operator_name().name;
    const auto& overload_name = schema.overload_name();
    auto pos = qualified_name.find("::");
    TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
    // Make me some null terminated strings
    std::string ns_str = qualified_name.substr(0, pos);
    const char* ns = ns_str.c_str();
    const char* func_name = qualified_name.c_str() + pos + strlen("::");

    py::handle torch_api_function =
        py::module::import("torch").attr("ops").attr(ns).attr(func_name);
    if (overload_name.empty()) {
      return torch_api_function.attr("default").ptr();
    } else {
      return torch_api_function.attr(overload_name.c_str()).ptr();
    }
  });
}

} // namespace torch::detail

c10::impl::PyInterpreter* getPyInterpreter() {
  return torch::detail::self_interpreter.get();
}

bool isMainPyInterpreter() {
  return torch::detail::self_interpreter.is_main_interpreter();
}