xref: /aosp_15_r20/external/pytorch/c10/core/impl/PyObjectSlot.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/impl/HermeticPyObjectTLS.h>
4 #include <c10/core/impl/PyInterpreter.h>
5 #include <c10/util/python_stub.h>
6 #include <optional>
7 
8 #include <atomic>
9 
10 namespace c10::impl {
11 
12 struct C10_API PyObjectSlot {
13  public:
14   PyObjectSlot();
15 
16   ~PyObjectSlot();
17 
18   void maybe_destroy_pyobj();
19 
20   // Associate the TensorImpl with the specified PyObject, and, if necessary,
21   // also tag the interpreter.
22   //
23   // NB: This lives in a header so that we can inline away the switch on status
24   //
25   // NB: THIS FUNCTION CAN RAISE AN EXCEPTION.  Make sure to clean up after
26   // PyObject if necessary!
init_pyobjPyObjectSlot27   void init_pyobj(
28       PyInterpreter* self_interpreter,
29       PyObject* pyobj,
30       PyInterpreterStatus status) {
31     impl::PyInterpreter* expected = nullptr;
32     switch (status) {
33       case impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED:
34         // caller guarantees there is no multithreaded access; if there is
35         // no data race OK to do a relaxed store
36         pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed);
37         break;
38       case impl::PyInterpreterStatus::TAGGED_BY_US:
39         // no tagging is necessary, the tag is already correct
40         break;
41       case impl::PyInterpreterStatus::MAYBE_UNINITIALIZED:
42         // attempt to claim this TensorImpl with the specified interpreter
43         // tag
44         if (pyobj_interpreter_.compare_exchange_strong(
45                 expected, self_interpreter, std::memory_order_acq_rel)) {
46           break;
47         }
48         // test if, actually, it was already tagged by us!  this situation can't
49         // be caused by a race, but it could be caused by a situation
50         // where someone conservatively tagged the tensor as MAYBE_UNINITIALIZED
51         // (because they didn't pre-check the tag) when actually it was
52         // owned by the interpreter
53         if (expected == self_interpreter) {
54           break;
55         }
56         // fallthrough, we lost the race.  We are guaranteed not to lose the
57         // race with ourself, as calls to init_pyobj with the same interpreter
58         // ID must be sequentialized by the GIL
59         [[fallthrough]];
60       case impl::PyInterpreterStatus::TAGGED_BY_OTHER:
61         TORCH_CHECK(
62             false,
63             "cannot allocate PyObject for Tensor on interpreter ",
64             self_interpreter,
65             " that has already been used by another torch deploy interpreter ",
66             pyobj_interpreter_.load());
67     }
68 
69     // we are the ONLY thread that can have gotten to this point.  It is not
70     // possible to conflict with another zero interpreter as access is protected
71     // by GIL
72     // NB: owns_pyobj tag is initially false
73     pyobj_ = pyobj;
74   }
75 
76   // Query the PyObject interpreter.  This may return null if there is no
77   // interpreter.  This is racy!
78   PyInterpreter* pyobj_interpreter();
79 
80   PyObject* _unchecked_untagged_pyobj() const;
81 
82   // Test the interpreter tag.  If tagged for the current interpreter, return
83   // a non-nullopt (but possibly null) PyObject.  If (possibly) untagged,
84   // returns a nullopt.  If it is definitely invalid, raises an error.
85   //
86   // If `ignore_hermetic_tls` is false and this function is called from a
87   // hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
88   // nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
89   // context is ignored, allowing you to check the interpreter tag of a
90   // nonhermetic PyObject from within a hermetic context. This is necessary
91   // because there are some cases where the deallocator function of a
92   // nonhermetic PyObject is called from within a hermetic context, so it must
93   // be properly treated as a nonhermetic PyObject.
94   //
95   // NB: this lives in header so that we can avoid actually creating the
96   // std::optional
97   std::optional<PyObject*> check_pyobj(
98       PyInterpreter* self_interpreter,
99       bool ignore_hermetic_tls = false) const {
100     // Note [Memory ordering on Python interpreter tag]
101     impl::PyInterpreter* interpreter =
102         pyobj_interpreter_.load(std::memory_order_acquire);
103     if (interpreter == nullptr) {
104       // NB: This never returns DEFINITELY_UNINITIALIZED because there is
105       // always the possibility that another thread races to initialize
106       // after we query here.  The only time when we can conclude a tensor
107       // is definitely uninitialized is when we have just allocated it and
108       // it cannot have escaped to other threads yet
109       return std::nullopt;
110     } else if (interpreter == self_interpreter) {
111       // NB: pyobj_ could still be null!
112       if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
113         return std::nullopt;
114       } else {
115         return std::make_optional(_unchecked_untagged_pyobj());
116       }
117     } else {
118       TORCH_CHECK(
119           false,
120           "cannot access PyObject for Tensor on interpreter ",
121           (*self_interpreter)->name(),
122           " that has already been used by another torch deploy interpreter ",
123           (*pyobj_interpreter_.load())->name());
124     }
125   }
126 
127   // Clear the PyObject field for an interpreter, in situations where we
128   // statically know the tensor is tagged with our interpreter.
129   void unchecked_clear_pyobj(PyInterpreter* interpreter);
130 
131   PyInterpreter& load_pyobj_interpreter() const;
132 
133   // Check if the PyObjectSlot's interpreter is the same as the specified
134   // interpreter
135   bool check_interpreter(PyInterpreter* interpreter);
136 
137   // Check if the PyObjectSlot is holding a PyObject, owned or non-owned
138   bool has_pyobj_nonhermetic();
139 
140   bool owns_pyobj();
141 
142   void set_owns_pyobj(bool b);
143 
144  private:
145   // This field contains the interpreter tag for this object.  See
146   // Note [Python interpreter tag] for general context
147   //
148   // Note [Memory ordering on Python interpreter tag]
149   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
150   // What memory_order do we need when accessing this atomic?  We don't
151   // need a single total modification order (as provided by
152   // memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
153   // transition from -1 to some positive integer and never changes afterwards.
154   // Because there is only one modification, it trivially already has a total
155   // modification order (e.g., we don't need fences or locked instructions on
156   // x86)
157   //
158   // In fact, one could make a reasonable argument that relaxed reads are OK,
159   // due to the presence of external locking (GIL) to ensure that interactions
160   // with other data structures are still correctly synchronized, so that
161   // we fall in the "Single-Location Data Structures" case as described in
162   // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
163   // However, on x86, it doesn't matter if I use acquire or relaxed on the load
164   // as I get the same assembly in both cases.  So I just use the more
165   // conservative acquire (which will impede compiler optimizations but I don't
166   // care)
167   std::atomic<PyInterpreter*> pyobj_interpreter_;
168 
169   // This field contains a reference to a PyObject representing this Tensor.
170   // If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
171   // PyObject for it and set this field.  This field does not have to be
172   // protected by an atomic as it is only allowed to be accessed when you hold
173   // the GIL, or during destruction of the tensor.
174   //
175   // When a PyObject dies, you are obligated to clear this field
176   // (otherwise, you will try to use-after-free the pyobj); this currently
177   // occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
178   //
179   // NB: Ordinarily, this should not be a strong reference, as if the
180   // PyObject owns the Tensor, this would create a reference cycle.
181   // However, sometimes this ownership flips.  To track who owns
182   // who, this has a single pointer tag indicating whether or not the
183   // C++ object owns the PyObject (the common case, zero, means PyObject
184   // owns the C++ object); see _unchecked_untagged_pyobj for raw access
185   // or check_pyobj for checked access.  See references to PyObject
186   // resurrection in torch/csrc/autograd/python_variable.cpp
187   PyObject* pyobj_;
188 };
189 
190 } // namespace c10::impl
191