xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/py_buffer.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_
18 
#include <cstdint>
#include <memory>
#include <optional>
#include <stdexcept>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/synchronization/notification.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
32 
33 namespace xla {
34 
35 // Python wrapper around PjRtBuffer. We use a wrapper class:
36 // a) to keep the PjRtClient alive via a std::shared_ptr<>
37 // b) to add Python-specific functionality.
38 //
39 // A `PyBuffer` can be used from Python without being wrapped in a Python
40 // `DeviceArray` object.
41 class PyBuffer {
42  public:
43   // pybind11::object typed subclass for PyBuffer objects.
44   class pyobject : public pybind11::object {
45    public:
46     PYBIND11_OBJECT(pyobject,  // NOLINT
47                     pybind11::object, PyBuffer::IsPyBuffer);
48     pyobject() = default;
buf()49     PyBuffer* buf() const { return PyBuffer::AsPyBufferUnchecked(*this); }
50   };
51   using object = pyobject;
52 
53   static object Make(std::shared_ptr<PyClient> client,
54                      std::shared_ptr<PjRtBuffer> buffer,
55                      std::shared_ptr<Traceback> traceback);
56 
57   // Returns true if `h` is a PyBuffer.
58   static bool IsPyBuffer(pybind11::handle handle);
59   // Converts `handle` to a PyBuffer*. Does not do any checking.
60   static PyBuffer* AsPyBufferUnchecked(pybind11::handle handle);
61   // Converts `handle` to a PyBuffer*. Returns an error status if
62   // !IsPyBuffer(handle)
63   static StatusOr<PyBuffer*> AsPyBuffer(pybind11::handle handle);
64 
65   // Gets a Python handle to an existing PyBuffer. Assumes the PyObject was
66   // allocated on the Python heap, which is the case if Make() was used.
67   pybind11::handle AsHandle();
68 
69   ~PyBuffer();
70 
client()71   std::shared_ptr<PyClient> client() const { return client_; }
buffer()72   PjRtBuffer* buffer() const { return buffer_.get(); }
shared_ptr_buffer()73   std::shared_ptr<PjRtBuffer> shared_ptr_buffer() const { return buffer_; }
74 
75   ClientAndPtr<PjRtDevice> device() const;
platform_name()76   absl::string_view platform_name() const {
77     return buffer_->client()->platform_name();
78   }
is_deleted()79   bool is_deleted() const { return buffer_->IsDeleted(); }
80 
81   StatusOr<pybind11::object> CopyToDevice(
82       const ClientAndPtr<PjRtDevice>& dst_device) const;
83   std::pair<Status, bool> CopyToRemoteDevice(
84       absl::string_view serialized_descriptor) const;
85 
OnDeviceSizeInBytes()86   StatusOr<size_t> OnDeviceSizeInBytes() {
87     return buffer_->GetOnDeviceSizeInBytes();
88   }
89 
Delete()90   void Delete() {
91     buffer_->Delete();
92     host_value_ = nullptr;
93   }
94 
95   // Makes a copy of this PyBuffer object that shares the underlying PjRtBuffer.
96   // This is useful because we may wish to change JAX metadata (e.g., the sticky
97   // device) without copying the buffer.
98   object Clone() const;
99 
100   // Returns xla::InvalidArgument if the buffer has been deleted.
101   // See `PjRtFuture` for the semantics of `IsReady` and `IsKnownReady`.
IsReady()102   StatusOr<bool> IsReady() {
103     if (buffer_->IsDeleted()) {
104       return InvalidArgument("DeviceArray has been deleted.");
105     }
106     return buffer_->GetReadyFuture().IsReady();
107   }
IsKnownReady()108   StatusOr<bool> IsKnownReady() {
109     if (buffer_->IsDeleted()) {
110       return InvalidArgument("DeviceArray has been deleted.");
111     }
112     return buffer_->GetReadyFuture().IsKnownReady();
113   }
114 
115   // Returns xla::InvalidArgument if the buffer has been deleted.
116   Status BlockHostUntilReady();
117   Status CopyToHostAsync();
118 
shape()119   const Shape& shape() { return buffer_->on_device_shape(); }
120 
121   StatusOr<std::uintptr_t> UnsafeBufferPointer() const;
122 
123   // Implementation of the CUDA array interface for sharing GPU buffers with
124   // other Python libraries.
125   StatusOr<pybind11::dict> CudaArrayInterface();
126 
traceback()127   const std::shared_ptr<Traceback>& traceback() const { return traceback_; }
128 
129   // Returns the size (i.e. number of elements) of the (host) numpy array.
130   StatusOr<int64_t> size();
131 
132   // Returns the number of dimensions of the (host) numpy array.
ndim()133   int ndim() const { return buffer()->on_device_shape().dimensions_size(); }
134 
135   pybind11::tuple python_shape() const;
136   pybind11::dtype python_dtype() const;
137 
138   // Representing the logical view of the underlying dynamic shapes.
139   StatusOr<const Shape*> xla_dynamic_shape();
140 
set_sticky_device(PjRtDevice * sticky_device)141   Status set_sticky_device(PjRtDevice* sticky_device) {
142     TF_RET_CHECK(sticky_device == nullptr ||
143                  sticky_device == buffer_->device());
144     sticky_device_ = sticky_device;
145     return OkStatus();
146   }
sticky_device()147   PjRtDevice* sticky_device() const { return sticky_device_; }
148 
set_weak_type(std::optional<bool> weak_type)149   void set_weak_type(std::optional<bool> weak_type) { weak_type_ = weak_type; }
weak_type()150   std::optional<bool> weak_type() const { return weak_type_; }
151 
152   StatusOr<pybind11::object> AsNumPyArray(pybind11::handle this_obj);
153 
SetAval(pybind11::object aval)154   void SetAval(pybind11::object aval) { aval_ = aval; }
GetAval()155   pybind11::object GetAval() const { return aval_; }
156 
157   static Status RegisterTypes(pybind11::module& m);
base_type()158   static PyObject* base_type() { return base_type_; }
type()159   static PyObject* type() { return type_; }
160 
161  private:
162   // PyBuffer objects must not be allocated directly since they must always live
163   // on the Python heap. Use Make() instead.
164   PyBuffer(std::shared_ptr<PyClient> client, std::shared_ptr<PjRtBuffer> buffer,
165            std::shared_ptr<Traceback> traceback);
166 
167   static PyObject* base_type_;
168   static PyObject* type_;
169 
170   friend class PyClient;
171 
172   struct HostValue {
173     absl::Notification ready;
174     Status status;
175     std::shared_ptr<xla::Literal> value;
176   };
177   std::shared_ptr<PyClient> client_;
178   std::shared_ptr<PjRtBuffer> buffer_;
179   std::shared_ptr<Traceback> traceback_;
180   std::shared_ptr<HostValue> host_value_;  // Protected by the GIL.
181 
182   // JAX uses this field to record whether a buffer is committed to a particular
183   // device by the user (https://github.com/google/jax/pull/1916).
184   PjRtDevice* sticky_device_ = nullptr;
185 
186   // TODO(phawkins): consider not keeping an explicit aval on C++ buffer
187   // objects.
188   pybind11::object aval_ = pybind11::none();
189 
190   // An optional weak type. If absent, the JAX jit code computes the weak_type
191   // from the aval_.weak_type attribute. This is a backwards compatibility
192   // measure for older Python code that does not set weak_type explicitly.
193   // TODO(phawkins): drop support for older jax Python versions and make
194   // weak_type mandatory.
195   std::optional<bool> weak_type_ = std::nullopt;
196 
197   std::optional<Shape> dynamic_shape_ = std::nullopt;
198   // Doubly-linked list of all PyBuffers known to the client. Protected by the
199   // GIL. Since multiple PyBuffers may share the same PjRtBuffer, there may be
200   // duplicate PjRtBuffers in this list.
201   PyBuffer* next_;
202   PyBuffer* prev_;
203 };
204 
205 }  // namespace xla
206 
207 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_
208