xref: /aosp_15_r20/external/pytorch/c10/core/SafePyObject.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/impl/PyInterpreter.h>
4 #include <c10/macros/Export.h>
5 #include <c10/util/python_stub.h>
6 #include <utility>
7 
8 namespace c10 {
9 
10 // This is a safe owning holder for a PyObject, akin to pybind11's
11 // py::object, with two major differences:
12 //
13 //  - It is in c10/core; i.e., you can use this type in contexts where
14 //    you do not have a libpython dependency
15 //
16 //  - It is multi-interpreter safe (ala torchdeploy); when you fetch
17 //    the underlying PyObject* you are required to specify what the current
18 //    interpreter context is and we will check that you match it.
19 //
20 // It is INVALID to store a reference to a Tensor object in this way;
21 // you should just use TensorImpl directly in that case!
22 struct C10_API SafePyObject {
23   // Steals a reference to data
SafePyObjectSafePyObject24   SafePyObject(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
25       : data_(data), pyinterpreter_(pyinterpreter) {}
SafePyObjectSafePyObject26   SafePyObject(SafePyObject&& other) noexcept
27       : data_(std::exchange(other.data_, nullptr)),
28         pyinterpreter_(other.pyinterpreter_) {}
29   // For now it's not used, so we just disallow it.
30   SafePyObject& operator=(SafePyObject&&) = delete;
31 
SafePyObjectSafePyObject32   SafePyObject(SafePyObject const& other)
33       : data_(other.data_), pyinterpreter_(other.pyinterpreter_) {
34     if (data_ != nullptr) {
35       (*pyinterpreter_)->incref(data_);
36     }
37   }
38 
39   SafePyObject& operator=(SafePyObject const& other) {
40     if (this == &other) {
41       return *this; // Handle self-assignment
42     }
43     if (other.data_ != nullptr) {
44       (*other.pyinterpreter_)->incref(other.data_);
45     }
46     if (data_ != nullptr) {
47       (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
48     }
49     data_ = other.data_;
50     pyinterpreter_ = other.pyinterpreter_;
51     return *this;
52   }
53 
~SafePyObjectSafePyObject54   ~SafePyObject() {
55     if (data_ != nullptr) {
56       (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
57     }
58   }
59 
pyinterpreterSafePyObject60   c10::impl::PyInterpreter& pyinterpreter() const {
61     return *pyinterpreter_;
62   }
63   PyObject* ptr(const c10::impl::PyInterpreter*) const;
64 
65   // stop tracking the current object, and return it
releaseSafePyObject66   PyObject* release() {
67     auto rv = data_;
68     data_ = nullptr;
69     return rv;
70   }
71 
72  private:
73   PyObject* data_;
74   c10::impl::PyInterpreter* pyinterpreter_;
75 };
76 
77 // A newtype wrapper around SafePyObject for type safety when a python object
78 // represents a specific type. Note that `T` is only used as a tag and isn't
79 // actually used for any true purpose.
80 template <typename T>
81 struct SafePyObjectT : private SafePyObject {
SafePyObjectTSafePyObjectT82   SafePyObjectT(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
83       : SafePyObject(data, pyinterpreter) {}
SafePyObjectTSafePyObjectT84   SafePyObjectT(SafePyObjectT&& other) noexcept : SafePyObject(other) {}
85   SafePyObjectT(SafePyObjectT const&) = delete;
86   SafePyObjectT& operator=(SafePyObjectT const&) = delete;
87 
88   using SafePyObject::ptr;
89   using SafePyObject::pyinterpreter;
90   using SafePyObject::release;
91 };
92 
93 // Like SafePyObject, but non-owning.  Good for references to global PyObjects
94 // that will be leaked on interpreter exit.  You get a copy constructor/assign
95 // this way.
96 struct C10_API SafePyHandle {
SafePyHandleSafePyHandle97   SafePyHandle() : data_(nullptr), pyinterpreter_(nullptr) {}
SafePyHandleSafePyHandle98   SafePyHandle(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
99       : data_(data), pyinterpreter_(pyinterpreter) {}
100 
pyinterpreterSafePyHandle101   c10::impl::PyInterpreter& pyinterpreter() const {
102     return *pyinterpreter_;
103   }
104   PyObject* ptr(const c10::impl::PyInterpreter*) const;
resetSafePyHandle105   void reset() {
106     data_ = nullptr;
107     pyinterpreter_ = nullptr;
108   }
109   operator bool() {
110     return data_;
111   }
112 
113  private:
114   PyObject* data_;
115   c10::impl::PyInterpreter* pyinterpreter_;
116 };
117 
118 } // namespace c10
119