xref: /aosp_15_r20/external/pytorch/c10/core/impl/PyInterpreter.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/Device.h>
4 #include <c10/core/DispatchKeySet.h>
5 #include <c10/core/Layout.h>
6 #include <c10/core/MemoryFormat.h>
7 #include <c10/core/SymIntArrayRef.h>
8 #include <c10/macros/Export.h>
9 #include <c10/util/ArrayRef.h>
10 #include <c10/util/intrusive_ptr.h>
11 #include <c10/util/python_stub.h>
12 #include <string>
13 #include <vector>
14 
15 // Forward declarations
16 
namespace c10 {
// Forward declarations only: this header needs these types by
// pointer/reference, so we avoid pulling in their (heavyweight) headers.
struct IValue;
class OperatorHandle;
struct TensorImpl;
} // namespace c10

namespace torch::jit {
// Local alias for the JIT interpreter stack type so c10 does not have to
// include torch headers (layering).  NOTE(review): presumably kept in sync
// manually with the definition in torch/csrc/jit -- confirm they match.
using Stack = std::vector<c10::IValue>;
}
26 
27 // Actual implementation
28 
29 namespace c10::impl {
30 
31 struct C10_API PyInterpreter;
32 
33 // Note [Python interpreter tag]
34 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
35 // Traditionally, PyTorch is layered such that our Python library
36 // (libtorch_python) references our pure C++ library (libtorch) as the
37 // natural order of things.  However, sometimes this natural order is
38 // subverted: C++ objects refer to Python objects (for example, we
39 // store a PyObject* pointer on TensorImpl so that converting from a
40 // C++ Tensor to a Python Tensor is just a memory dereference).
41 //
42 // These unusual orderings must be treated with care.  To start, you need to
43 // virtualize the destructor so that the PyObject can be decref'ed on
44 // destruction (because the C++ object itself doesn't know anything about
45 // Python--remember, layering!).  This process itself is fraught, since
46 // acquiring the GIL could lead to deadlocks if someone is blocking on you
47 // while holding the GIL.  Furthermore, if the C++ objects outlive the
48 // interpreter (which can happen if you stash them in a static global
49 // variable defined in libtorch), you may attempt to decref the object when
50 // the Python interpreter has already been shutdown.
51 //
52 // BUT WAIT, IT GETS WORSE.  With torchdeploy, there may be multiple Python
53 // interpreters in a single process. If a C++ object is accessible from
// multiple interpreters, we must take care not to accidentally use a
// PyObject created by one interpreter with a different interpreter.
56 //
57 // To prevent these mixups, we introduce a PyInterpreter "tag" (object with
58 // a vtable), which specifies a specific Python interpreter.
59 //
60 //  - Any given object can be associated with AT MOST one Python interpreter.
61 //    We represent the interpreter tag as a memory address to an instance of
62 //    a virtual class that is allocated once per interpreter (this is so that
63 //    we can request the interpreter to perform operations for us, if
64 //    necessary).
65 //
66 //  - It can be recorded with a PyObject (PyInterpreterObject) so that
67 //    we know what interpreter the object is associated with, and we can
68 //    raise an error if you try to use the PyObject from the wrong
69 //    interpreter context.
70 //
71 //  - It contains a vtable that can be used to perform various Python
72 //    operations from ordinary C++ code that ordinarily wouldn't be accessible
73 //    from libtorch.
74 //
75 // A simple use case is when a C++ object must be associated with a PyObject.
76 // However, for TensorImpl, we lazily allocate a PyObject the first time the
77 // object passes into Python.  The invariants for this situation are more
78 // subtle:
79 //
80 //  - A given TensorImpl's interpreter tag can only go from uninitialized to
81 //    tagged; once tagged, this is a quiescent state (once tagged to an
82 //    interpreter, ALWAYS tagged to that interpreter)
83 //
84 //  - A thread may mutate the PyObject field of a TensorImpl if and only if it
85 //    holds the GIL for the interpreter tagged on the TensorImpl.  (If the
86 //    TensorImpl is not tagged, it must first atomically claim its tag before it
87 //    can validly write)
88 //
89 // WARNING: This class has to be written very carefully, because it may be
// possible for a Tensor to have a reference to an interpreter corresponding to
91 // a shared library that has ALREADY BEEN UNLOADED.  This makes blindly calling
92 // virtual methods very dangerous, because the vtable may be garbage at that
93 // point (on a good day, you might get "pure virtual method called").
94 //
95 // The idea to solve this problem is we always leak PyInterpreters (so they
96 // always stay live even after dlclose), and make sure we can disarm their
97 // virtual methods by indirecting through a separate PyInterpreterVTable
98 // object.  This can be replaced with a no-op vtable from libc10.so, which
99 // is guaranteed to stick around until the bitter end.
100 //
101 // NB: The downside with representing PyInterpreter tags as full objects is that
102 // it takes an extra word on TensorImpl.  If tags were instead just integer
103 // indices, on 64-bit architectures we could pack the tag and PyObject together
104 // into a single atomic word.  On 32-bit architectures we could simply say that
105 // only one Python interpreter is supported (erroring if a nontrivial
106 // interpreter tag is attempted to be set).
107 //
108 // The difficulty with this scheme is we need to maintain an out-of-line table
109 // to get at the PyInterpreters so that we can do virtual method calls on them,
110 // and registration/deregistration to this table must be done in a thread safe
111 // manner.  This can be easily done if the number of possible PyInterpreters is
112 // small enough (e.g., 8-bit integer) by simply preallocating an array of
// sufficient size to hold all possible interpreters.  Surely 128
// interpreters is more than enough for anyone!
115 //
116 // I didn't decide to do this technique at the moment, because the extra word
117 // added by the PyInterpreter tag takes us to 24 words, which means that we
118 // still fit inside three eight word cache lines.  If you need to penny pinch
119 // another word consider doing this!
120 
// Table of operations that ordinary C++ code can request a specific Python
// interpreter to perform.  Concrete implementations live on the Python side
// (libtorch_python); see Note [Python interpreter tag] above for why this is
// a separate object from PyInterpreter (it lets disarm() swap in a no-op
// vtable after the interpreter's shared library is unloaded).
struct C10_API PyInterpreterVTable {
  virtual ~PyInterpreterVTable() = default;

  // Report the name of this interpreter
  virtual std::string name() const = 0;

  // Run Py_INCREF on a PyObject.
  virtual void incref(PyObject* pyobj) const = 0;
  // Run Py_DECREF on a PyObject.  We DO NOT assume the GIL is held on call
  // See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
  virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;

  // Perform a detach by deferring to the __torch_dispatch__ implementation of
  // detach, which will also arrange for the PyObject to get copied in this
  // situation
  virtual c10::intrusive_ptr<TensorImpl> detach(
      const TensorImpl* self) const = 0;

  // Invoke the Python boxed fallback dispatch to go back into Python
  virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
      const = 0;

  // Invoke `callback` (a Python callable) to report an error for the given
  // dispatch key.  NOTE(review): exact semantics inferred from the name;
  // confirm against the libtorch_python implementation.
  virtual void reportErrorCallback(PyObject* callback, DispatchKey key)
      const = 0;

  // This is only invoked in the multipy/torchdeploy situation from
  // pythonOpRegistrationTrampoline; this lets us get to the Python
  // interpreter to actually find the appropriate Python op registration
  // entry to call.
  virtual void python_op_registration_trampoline(
      const c10::OperatorHandle& op,
      c10::DispatchKey,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack,
      bool with_keyset,
      bool with_op) const = 0;

  // Raise an error telling the user that the abstract impl for `opname`
  // lives in Python module `pymodule` (not imported); `context` supplies
  // extra detail.  NOTE(review): message text lives on the Python side.
  virtual void throw_abstract_impl_not_imported_error(
      std::string opname,
      const char* pymodule,
      const char* context) const = 0;

  // Invoke the Python dispatcher to handle this call
  virtual void python_dispatcher(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet,
      torch::jit::Stack* stack) const = 0;

  // Tensor metadata queries answered on the Python side for `self`
  // (presumably for Python tensor subclasses whose metadata is defined in
  // Python -- confirm at the libtorch_python implementation).
  virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat)
      const = 0;
  virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
      const = 0;
  virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0;
  virtual c10::Device device(const TensorImpl* self) const = 0;
  virtual int64_t dim(const TensorImpl* self) const = 0;
  virtual c10::IntArrayRef strides(const TensorImpl* self) const = 0;
  virtual c10::IntArrayRef sizes(const TensorImpl* self) const = 0;
  virtual c10::SymIntArrayRef sym_sizes(const TensorImpl* self) const = 0;
  virtual c10::Layout layout(const TensorImpl* self) const = 0;
  virtual int64_t numel(const TensorImpl* self) const = 0;
  virtual c10::SymInt sym_numel(const TensorImpl* self) const = 0;
  virtual c10::SymIntArrayRef sym_strides(const TensorImpl* self) const = 0;
  virtual c10::SymInt sym_storage_offset(const TensorImpl* self) const = 0;

  // Callbacks that report GPU event/stream/memory activity to a Python-side
  // tracer.  `event`, `stream` and `ptr` are opaque handles passed as raw
  // addresses.  NOTE(review): tracer semantics inferred from the names;
  // confirm against the Python-side implementation.
  virtual void trace_gpu_event_creation(
      c10::DeviceType device_type,
      uintptr_t event) const = 0;
  virtual void trace_gpu_event_deletion(
      c10::DeviceType device_type,
      uintptr_t event) const = 0;
  virtual void trace_gpu_event_record(
      c10::DeviceType device_type,
      uintptr_t event,
      uintptr_t stream) const = 0;
  virtual void trace_gpu_event_wait(
      c10::DeviceType device_type,
      uintptr_t event,
      uintptr_t stream) const = 0;
  virtual void trace_gpu_memory_allocation(
      c10::DeviceType device_type,
      uintptr_t ptr) const = 0;
  virtual void trace_gpu_memory_deallocation(
      c10::DeviceType device_type,
      uintptr_t ptr) const = 0;
  virtual void trace_gpu_stream_creation(
      c10::DeviceType device_type,
      uintptr_t stream) const = 0;
  virtual void trace_gpu_device_synchronization(
      c10::DeviceType device_type) const = 0;
  virtual void trace_gpu_stream_synchronization(
      c10::DeviceType device_type,
      uintptr_t stream) const = 0;
  virtual void trace_gpu_event_synchronization(
      c10::DeviceType device_type,
      uintptr_t event) const = 0;

  // Reset the backward hooks associated with `self` on the Python side.
  // NOTE(review): behavior inferred from the name; implemented in
  // libtorch_python.
  virtual void reset_backward_hooks(const TensorImpl* self) const = 0;
};
219 
220 struct C10_API PyInterpreter {
221   const PyInterpreterVTable* vtable_;
222 
PyInterpreterPyInterpreter223   PyInterpreter(const PyInterpreterVTable* vtable) : vtable_(vtable){};
224 
225   const PyInterpreterVTable& operator*() const noexcept {
226     return *vtable_;
227   }
228   const PyInterpreterVTable* operator->() const noexcept {
229     return vtable_;
230   }
231 
232   // Disarm this PyInterpreter, making all of its methods noops.
233   // The vtable pointer is not an atomic at the moment, which means
234   // a disarm() invocation that is concurrent with active destructors
235   // is not thread safe and will trigger TSAN.  My hope is that this
236   // situations doesn't ever actually happen; tensor destruction should
237   // quiesce when a dlclose happens, and any long lived tensors whose
238   // destructors would be disarmed here only begin the destruction process
239   // on process shutdown (long after the dlclose has occurred).
240   void disarm() noexcept;
241 };
242 
// PyInterpreterStatus describes what the state of its interpreter tag
// is, relative to the thread currently holding the GIL.
enum class PyInterpreterStatus {
  // We just allocated the Tensor, it hasn't escaped to other threads,
  // we know that it definitely hasn't been tagged to be associated
  // with an interpreter.
  DEFINITELY_UNINITIALIZED,
  // We queried the interpreter field and it looked uninitialized.  But
  // another thread may have raced with us to tag it with some other
  // interpreter id.  So we will have to do a CEX (compare-and-exchange)
  // to make sure we can actually nab it.
  MAYBE_UNINITIALIZED,
  // We queried the interpreter field and it was tagged to belong to us.
  // This means we have sole write access (as we hold the GIL for this
  // interpreter)
  TAGGED_BY_US,
  // Someone else tagged this.  We can't use this TensorImpl from Python.
  TAGGED_BY_OTHER,
};
262 
263 } // namespace c10::impl
264