#pragma once

#include <c10/core/Device.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/macros/Export.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/python_stub.h>
#include <string>
#include <vector>

// Forward declarations

namespace c10 {
struct IValue;
class OperatorHandle;
struct TensorImpl;
} // namespace c10

namespace torch::jit {
using Stack = std::vector<c10::IValue>;
}

// Actual implementation

namespace c10::impl {

struct C10_API PyInterpreter;

// Note [Python interpreter tag]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Traditionally, PyTorch is layered such that our Python library
// (libtorch_python) references our pure C++ library (libtorch) as the
// natural order of things. However, sometimes this natural order is
// subverted: C++ objects refer to Python objects (for example, we
// store a PyObject* pointer on TensorImpl so that converting from a
// C++ Tensor to a Python Tensor is just a memory dereference).
//
// These unusual orderings must be treated with care. To start, you need to
// virtualize the destructor so that the PyObject can be decref'ed on
// destruction (because the C++ object itself doesn't know anything about
// Python--remember, layering!). This process itself is fraught, since
// acquiring the GIL could lead to deadlocks if someone is blocking on you
// while holding the GIL. Furthermore, if the C++ objects outlive the
// interpreter (which can happen if you stash them in a static global
// variable defined in libtorch), you may attempt to decref the object when
// the Python interpreter has already been shut down.
//
// BUT WAIT, IT GETS WORSE. With torchdeploy, there may be multiple Python
// interpreters in a single process. If a C++ object is accessible from
// multiple interpreters, we must take care not to accidentally use a
// PyObject from one interpreter with another interpreter.
//
// To prevent these mixups, we introduce a PyInterpreter "tag" (object with
// a vtable), which specifies a specific Python interpreter.
//
// - Any given object can be associated with AT MOST one Python interpreter.
//   We represent the interpreter tag as a memory address to an instance of
//   a virtual class that is allocated once per interpreter (this is so that
//   we can request the interpreter to perform operations for us, if
//   necessary).
//
// - It can be recorded with a PyObject (PyInterpreterObject) so that
//   we know what interpreter the object is associated with, and we can
//   raise an error if you try to use the PyObject from the wrong
//   interpreter context.
//
// - It contains a vtable that can be used to perform various Python
//   operations from ordinary C++ code that ordinarily wouldn't be accessible
//   from libtorch.
//
// A simple use case is when a C++ object must be associated with a PyObject.
// However, for TensorImpl, we lazily allocate a PyObject the first time the
// object passes into Python. The invariants for this situation are more
// subtle:
//
// - A given TensorImpl's interpreter tag can only go from uninitialized to
//   tagged; once tagged, this is a quiescent state (once tagged to an
//   interpreter, ALWAYS tagged to that interpreter)
//
// - A thread may mutate the PyObject field of a TensorImpl if and only if it
//   holds the GIL for the interpreter tagged on the TensorImpl. (If the
//   TensorImpl is not tagged, it must first atomically claim its tag before
//   it can validly write.) A sketch of this claim protocol follows.
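//
// Illustratively, assuming a hypothetical atomic tag field named
// pyobj_interpreter_ (the real bookkeeping lives on TensorImpl, not in
// this header), claiming the tag looks roughly like:
//
//   std::atomic<PyInterpreter*> pyobj_interpreter_{nullptr};
//
//   PyInterpreter* expected = nullptr;
//   if (pyobj_interpreter_.compare_exchange_strong(
//           expected, self_interpreter, std::memory_order_acq_rel)) {
//     // We won the race: the TensorImpl is now permanently tagged to
//     // self_interpreter, and we may write the PyObject field because
//     // we hold that interpreter's GIL.
//   } else if (expected == self_interpreter) {
//     // Already tagged by us; writes are still valid under our GIL.
//   } else {
//     // Tagged by a different interpreter; we must not touch the
//     // PyObject field from this interpreter.
//   }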
//
// WARNING: This class has to be written very carefully, because it may be
// possible for a Tensor to have a reference to an interpreter corresponding
// to a shared library that has ALREADY BEEN UNLOADED. This makes blindly
// calling virtual methods very dangerous, because the vtable may be garbage
// at that point (on a good day, you might get "pure virtual method called").
//
// The way we solve this problem is to always leak PyInterpreters (so they
// always stay live even after dlclose), and to make sure we can disarm their
// virtual methods by indirecting through a separate PyInterpreterVTable
// object. This can be replaced with a no-op vtable from libc10.so, which
// is guaranteed to stick around until the bitter end. A sketch of the
// disarm mechanism is given below.
//
// NB: The downside with representing PyInterpreter tags as full objects is
// that it takes an extra word on TensorImpl. If tags were instead just
// integer indices, on 64-bit architectures we could pack the tag and
// PyObject together into a single atomic word. On 32-bit architectures we
// could simply say that only one Python interpreter is supported (erroring
// if a nontrivial interpreter tag is attempted to be set).
//
// The difficulty with this scheme is we need to maintain an out-of-line
// table to get at the PyInterpreters so that we can do virtual method calls
// on them, and registration/deregistration to this table must be done in a
// thread safe manner. This can be easily done if the number of possible
// PyInterpreters is small enough (e.g., 8-bit integer) by simply
// preallocating an array of sufficient size to hold all possible
// interpreters. Surely 128 interpreters is more than enough for anyone!
//
// I decided not to do this technique at the moment, because the extra word
// added by the PyInterpreter tag takes us to 24 words, which means that we
// still fit inside three eight-word cache lines. If you need to penny-pinch
// another word consider doing this!
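//
// Concretely, every call goes PyInterpreter -> vtable_ -> virtual method,
// so neutralizing a dead interpreter only requires swapping vtable_ for a
// table whose methods do nothing. A hypothetical sketch (the actual no-op
// table is an implementation detail of libc10, not spelled out here):
//
//   struct NoopPyInterpreterVTable final : PyInterpreterVTable {
//     std::string name() const override {
//       return "<unloaded interpreter>";
//     }
//     void incref(PyObject*) const override {}  // leak rather than crash
//     // ... every other method is likewise a no-op (or raises a clear
//     // error), so a stale tag can never jump through a garbage vtable.
//   };
//
//   void PyInterpreter::disarm() noexcept {
//     static NoopPyInterpreterVTable noop;  // leaked; lives in libc10.so
//     vtable_ = &noop;
//   }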

struct C10_API PyInterpreterVTable {
  virtual ~PyInterpreterVTable() = default;

  // Report the name of this interpreter
  virtual std::string name() const = 0;

  // Run Py_INCREF on a PyObject.
  virtual void incref(PyObject* pyobj) const = 0;
  // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call.
  // See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
  virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;

  // Perform a detach by deferring to the __torch_dispatch__ implementation of
  // detach, which will also arrange for the PyObject to get copied in this
  // situation
  virtual c10::intrusive_ptr<TensorImpl> detach(
      const TensorImpl* self) const = 0;

  // Invoke the Python boxed fallback dispatch to go back into Python
  virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
      const = 0;

  virtual void reportErrorCallback(PyObject* callback, DispatchKey key)
      const = 0;

  // This is only invoked in the multipy/torchdeploy situation from
  // pythonOpRegistrationTrampoline; this lets us get to the Python
  // interpreter to actually find the appropriate Python op registration
  // entry to call.
  virtual void python_op_registration_trampoline(
      const c10::OperatorHandle& op,
      c10::DispatchKey,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack,
      bool with_keyset,
      bool with_op) const = 0;

  virtual void throw_abstract_impl_not_imported_error(
      std::string opname,
      const char* pymodule,
      const char* context) const = 0;

  // Invoke the Python dispatcher to handle this call
  virtual void python_dispatcher(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet,
      torch::jit::Stack* stack) const = 0;

  virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat)
      const = 0;
  virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
      const = 0;
  virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0;
  virtual c10::Device device(const TensorImpl* self) const = 0;
  virtual int64_t dim(const TensorImpl* self) const = 0;
  virtual c10::IntArrayRef strides(const TensorImpl* self) const = 0;
  virtual c10::IntArrayRef sizes(const TensorImpl* self) const = 0;
  virtual c10::SymIntArrayRef sym_sizes(const TensorImpl* self) const = 0;
  virtual c10::Layout layout(const TensorImpl* self) const = 0;
  virtual int64_t numel(const TensorImpl* self) const = 0;
  virtual c10::SymInt sym_numel(const TensorImpl* self) const = 0;
  virtual c10::SymIntArrayRef sym_strides(const TensorImpl* self) const = 0;
  virtual c10::SymInt sym_storage_offset(const TensorImpl* self) const = 0;

  virtual void trace_gpu_event_creation(
      c10::DeviceType device_type,
      uintptr_t event) const = 0;
  virtual void trace_gpu_event_deletion(
      c10::DeviceType device_type,
      uintptr_t event) const = 0;
  virtual void trace_gpu_event_record(
      c10::DeviceType device_type,
      uintptr_t event,
      uintptr_t stream) const = 0;
  virtual void trace_gpu_event_wait(
      c10::DeviceType device_type,
      uintptr_t event,
      uintptr_t stream) const = 0;
  virtual void trace_gpu_memory_allocation(
      c10::DeviceType device_type,
      uintptr_t ptr) const = 0;
  virtual void trace_gpu_memory_deallocation(
      c10::DeviceType device_type,
      uintptr_t ptr) const = 0;
  virtual void trace_gpu_stream_creation(
      c10::DeviceType device_type,
      uintptr_t stream) const = 0;
  virtual void trace_gpu_device_synchronization(
      c10::DeviceType device_type) const = 0;
  virtual void trace_gpu_stream_synchronization(
      c10::DeviceType device_type,
      uintptr_t stream) const = 0;
  virtual void trace_gpu_event_synchronization(
      c10::DeviceType device_type,
      uintptr_t event) const = 0;

  virtual void reset_backward_hooks(const TensorImpl* self) const = 0;
};
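
// Since decref() may be called without the GIL (see above), a concrete
// vtable implementation on the Python side has to take the GIL itself.
// A hypothetical sketch of such an override (the real implementation
// lives in libtorch_python and also consults `has_pyobj_slot` to decide
// whether the owning TensorImpl might resurrect the PyObject):
//
//   void ConcreteVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
//       const {
//     pybind11::gil_scoped_acquire gil;  // caller does NOT hold the GIL
//     // ... resurrection handling based on has_pyobj_slot ...
//     Py_DECREF(pyobj);
//   }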

struct C10_API PyInterpreter {
  const PyInterpreterVTable* vtable_;

  PyInterpreter(const PyInterpreterVTable* vtable) : vtable_(vtable) {}

  const PyInterpreterVTable& operator*() const noexcept {
    return *vtable_;
  }
  const PyInterpreterVTable* operator->() const noexcept {
    return vtable_;
  }

  // Disarm this PyInterpreter, making all of its methods noops.
  // The vtable pointer is not an atomic at the moment, which means
  // a disarm() invocation that is concurrent with active destructors
  // is not thread safe and will trigger TSAN. My hope is that this
  // situation doesn't ever actually happen; tensor destruction should
  // quiesce when a dlclose happens, and any long lived tensors whose
  // destructors would be disarmed here only begin the destruction process
  // on process shutdown (long after the dlclose has occurred).
  void disarm() noexcept;
};

// PyInterpreterStatus describes what the state of its interpreter tag
// is, relative to the thread currently holding the GIL.
enum class PyInterpreterStatus {
  // We just allocated the Tensor, it hasn't escaped to other threads,
  // we know that it definitely hasn't been tagged to be associated
  // with an interpreter.
  DEFINITELY_UNINITIALIZED,
  // We queried the interpreter field and it looked uninitialized. But
  // another thread may have raced with us to tag it with some other
  // interpreter id. So we will have to do a CEX (compare-and-exchange)
  // to make sure we can actually nab it.
  MAYBE_UNINITIALIZED,
  // We queried the interpreter field and it was tagged to belong to us.
  // This means we have sole write access (as we hold the GIL for this
  // interpreter)
  TAGGED_BY_US,
  // Someone else tagged this. We can't use this TensorImpl from Python.
  TAGGED_BY_OTHER,
};

} // namespace c10::impl
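
// How the status feeds back into the tagging protocol sketched in
// Note [Python interpreter tag] (illustrative only; the real consumer
// is TensorImpl's lazy PyObject initialization):
//
//   switch (status) {
//     case PyInterpreterStatus::DEFINITELY_UNINITIALIZED:
//       // Fresh tensor, no races possible: plain store of the tag.
//       break;
//     case PyInterpreterStatus::MAYBE_UNINITIALIZED:
//       // Race possible: claim the tag with a compare-and-exchange.
//       break;
//     case PyInterpreterStatus::TAGGED_BY_US:
//       // We hold the tagged interpreter's GIL: plain writes are fine.
//       break;
//     case PyInterpreterStatus::TAGGED_BY_OTHER:
//       // Not ours: raise an error rather than touch the PyObject.
//       break;
//   }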