#pragma once

#include <c10/core/impl/PyInterpreter.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/python_stub.h>

#include <atomic>

namespace c10 {

// A PyHandleCache represents a cached pointer from a C++ object to
// a Python object that represents that object analogously in Python.
// Upon a cache hit, the relevant object can be retrieved after a test
// and then a memory load.  Two conditions must hold to be able to use this
// class:
//
//  - This must truly be a cache; e.g., the caller must be able to produce
//    the object some other way if the cache hit misses.
//
//  - This must truly be a handle; e.g., the Python object referenced by
//    this class must have static lifetime.  This means we don't have to
//    maintain strong ownership or deallocate the object when the C++ object
//    dies.  Static lifetime is a good idea in conjunction with the cache,
//    since if you are producing a fresh object on miss you won't be
//    maintaining object identity.  If you need bidirectional ownership,
//    you will want to factor out the pattern in TensorImpl with
//    resurrection.
//
// This cache is expected to not improve perf under torchdeploy, as one
// interpreter will fill up the cache, and all the interpreters will be
// unable to use the slot.  A potential improvement is to have multiple
// slots (one per interpreter), which will work in deployment scenarios
// where there a stable, fixed number of interpreters.  You can also store
// the relevant state in the Python library, rather than in the non-Python
// library (although in many cases, this is not convenient, as there may
// not be a way to conveniently index based on the object.)
class PyHandleCache {
 public:
  // Starts empty: a null interpreter tag marks the cache slot as unclaimed.
  PyHandleCache() : pyinterpreter_(nullptr) {}

  // Attempt to fetch the pointer from the cache, if the PyInterpreter
  // matches.  If it doesn't exist, or the cache entry is not valid,
  // use slow_accessor to get the real pointer value and return that
  // (possibly writing it to the cache, if the cache entry is
  // available.)
  //
  // self_interpreter: the interpreter on whose behalf the caller is asking;
  //   only this interpreter's cached value (if any) is returned.
  // slow_accessor: nullary callable returning PyObject*; invoked on every
  //   miss (and never on a hit for the matching interpreter).
  template <typename F>
  PyObject* ptr_or(impl::PyInterpreter* self_interpreter, F slow_accessor)
      const {
    // Note [Memory ordering on Python interpreter tag]
    // The acquire load pairs with the acq_rel compare-exchange below, so a
    // reader that observes its own interpreter tag also observes the data_
    // written by the thread that claimed the slot.
    impl::PyInterpreter* interpreter =
        pyinterpreter_.load(std::memory_order_acquire);
    if (C10_LIKELY(interpreter == self_interpreter)) {
      // Cache hit: the slot was claimed by our interpreter, so data_ is the
      // previously published handle.
      return data_;
    } else if (interpreter == nullptr) {
      // Slot unclaimed: compute the value, then race (benignly) to claim it.
      auto* r = slow_accessor();
      impl::PyInterpreter* expected = nullptr;
      // attempt to claim this cache entry with the specified interpreter tag
      if (pyinterpreter_.compare_exchange_strong(
              expected, self_interpreter, std::memory_order_acq_rel)) {
        // NOTE(review): data_ is a plain (non-atomic) store performed after
        // the tag is published; this appears to rely on the GIL serializing
        // all accesses from the same interpreter (consistent with the assert
        // comment below) — confirm against Note [Memory ordering on Python
        // interpreter tag] elsewhere in the codebase.
        data_ = r;
      }
      // This shouldn't be possible, as you should be GIL protected
      // (expected == self_interpreter on CAS failure would mean a concurrent
      // thread of the *same* interpreter claimed the slot, which the GIL
      // forbids.)
      TORCH_INTERNAL_ASSERT(expected != self_interpreter);
      // Return the freshly computed value regardless of who won the claim.
      return r;
    } else {
      // Slot owned by a different interpreter: never touch the cache,
      // always take the slow path.
      return slow_accessor();
    }
  }

 private:
  // Interpreter tag guarding the slot: nullptr = unclaimed; otherwise the
  // interpreter whose handle is stored in data_.  mutable because ptr_or
  // is const but may claim the slot.
  mutable std::atomic<impl::PyInterpreter*> pyinterpreter_;
  // Cached handle; meaningful only once pyinterpreter_ is non-null.  Plain
  // pointer (not atomic): cross-thread visibility is mediated by the
  // acquire/release operations on pyinterpreter_ plus GIL serialization.
  mutable PyObject* data_{nullptr};
};

} // namespace c10