/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/py_buffer.h"

#include <functional>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/base/casts.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/python_utils.h"
#include "tensorflow/compiler/xla/python/transfer_guard_lib.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/python/util.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

namespace py = pybind11;

namespace {

// Representation of a DeviceArrayBase as a Python object. Since a
// DeviceArrayBase has no fields, this is just a PyObject.
struct PyBufferBasePyObject {
  PyObject_HEAD;
};
static_assert(std::is_standard_layout<PyBufferBasePyObject>::value,
              "PyBufferBasePyObject must be standard layout");

// Representation of a DeviceArray as a Python object.
struct PyBufferPyObject {
  PyBufferBasePyObject base;
  PyBuffer buffer;
  // Used by the Python interpreter to maintain a list of weak references to
  // this object.
  PyObject* weakrefs;
};
static_assert(std::is_standard_layout<PyBufferPyObject>::value,
              "PyBufferPyObject must be standard layout");

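// Note: the PyBuffer is stored inline in the Python object rather than held
// through a pointer. The standard-layout assertions above make the offsetof
// arithmetic in PyBuffer::AsHandle() well defined, letting us recover the
// owning PyObject from a PyBuffer* without extra indirection or allocation.
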
PyObject* PyBuffer_tp_new(PyTypeObject* subtype, PyObject* args,
                          PyObject* kwds) {
  PyBufferPyObject* self =
      reinterpret_cast<PyBufferPyObject*>(subtype->tp_alloc(subtype, 0));
  if (!self) return nullptr;
  self->weakrefs = nullptr;
  return reinterpret_cast<PyObject*>(self);
}

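// Instances of heap-allocated types own a strong reference to their type
// object, so after destroying the inline PyBuffer and freeing the instance,
// the deallocator must also drop that reference (the Py_DECREF(tp) below).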
void PyBuffer_tp_dealloc(PyObject* self) {
  PyTypeObject* tp = Py_TYPE(self);
  PyBufferPyObject* o = reinterpret_cast<PyBufferPyObject*>(self);
  if (o->weakrefs) {
    PyObject_ClearWeakRefs(self);
  }
  o->buffer.~PyBuffer();
  tp->tp_free(self);
  Py_DECREF(tp);
}

}  // namespace

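// Make() pairs tp_new's raw allocation with a placement-new of the inline
// PyBuffer; PyBuffer_tp_dealloc above runs the destructor explicitly before
// freeing the storage.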
/*static*/ PyBuffer::object PyBuffer::Make(
    std::shared_ptr<PyClient> client, std::shared_ptr<PjRtBuffer> buffer,
    std::shared_ptr<Traceback> traceback) {
  py::object obj = py::reinterpret_steal<py::object>(PyBuffer_tp_new(
      reinterpret_cast<PyTypeObject*>(type_), nullptr, nullptr));
  PyBufferPyObject* buf = reinterpret_cast<PyBufferPyObject*>(obj.ptr());
  new (&buf->buffer)
      PyBuffer(std::move(client), std::move(buffer), std::move(traceback));
  return py::reinterpret_borrow<PyBuffer::object>(obj);
}

bool PyBuffer::IsPyBuffer(py::handle handle) {
  return handle.get_type() == PyBuffer::type();
}

/*static*/ PyBuffer* PyBuffer::AsPyBufferUnchecked(pybind11::handle handle) {
  return &(reinterpret_cast<PyBufferPyObject*>(handle.ptr())->buffer);
}

/*static*/ StatusOr<PyBuffer*> PyBuffer::AsPyBuffer(pybind11::handle handle) {
  if (!IsPyBuffer(handle)) {
    return InvalidArgument("Expected a DeviceArray");
  }
  return AsPyBufferUnchecked(handle);
}

py::handle PyBuffer::AsHandle() {
  return reinterpret_cast<PyObject*>(reinterpret_cast<char*>(this) -
                                     offsetof(PyBufferPyObject, buffer));
}

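// Each PyClient keeps an intrusive, per-device doubly-linked list of live
// PyBuffers. The constructor pushes this buffer at the head of its device's
// list and the destructor unlinks it; both run under the GIL, which
// serializes mutation of the list.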
PyBuffer::PyBuffer(std::shared_ptr<PyClient> client,
                   std::shared_ptr<PjRtBuffer> buffer,
                   std::shared_ptr<Traceback> traceback)
    : client_(std::move(client)),
      buffer_(std::move(buffer)),
      traceback_(std::move(traceback)) {
  CHECK(PyGILState_Check());
  next_ = client_->buffers_[buffer_->device()->id()];
  client_->buffers_[buffer_->device()->id()] = this;
  prev_ = nullptr;
  if (next_) {
    next_->prev_ = this;
  }
}

PyBuffer::~PyBuffer() {
  CHECK(PyGILState_Check());
  if (client_->buffers_[device()->id()] == this) {
    client_->buffers_[device()->id()] = next_;
  }
  if (prev_) {
    prev_->next_ = next_;
  }
  if (next_) {
    next_->prev_ = prev_;
  }
}

StatusOr<int64_t> PyBuffer::size() {
  Shape max_buffer_shape = buffer()->on_device_shape();
  if (max_buffer_shape.is_dynamic()) {
    TF_ASSIGN_OR_RETURN(const auto* dynamic_shape, xla_dynamic_shape());
    return ShapeUtil::ElementsIn(*dynamic_shape);
  }
  return ShapeUtil::ElementsIn(max_buffer_shape);
}

StatusOr<const Shape*> PyBuffer::xla_dynamic_shape() {
  CHECK(PyGILState_Check());
  if (buffer_->on_device_shape().is_static()) {
    return &buffer_->on_device_shape();
  }
  // The Python buffer protocol references shape data by pointer, so we must
  // store a valid copy of the shape.
  if (!dynamic_shape_) {
    Shape dynamic_shape;
    {
      py::gil_scoped_release gil_release;
      TF_ASSIGN_OR_RETURN(dynamic_shape, buffer_->logical_on_device_shape());
    }
    dynamic_shape_ = dynamic_shape;
  }
  return &dynamic_shape_.value();
}

pybind11::tuple PyBuffer::python_shape() const {
  return SpanToTuple(buffer()->on_device_shape().dimensions());
}

pybind11::dtype PyBuffer::python_dtype() const {
  PrimitiveType primitive = buffer()->on_device_shape().element_type();
  return PrimitiveTypeToDtype(primitive).ValueOrDie();
}

ClientAndPtr<PjRtDevice> PyBuffer::device() const {
  return WrapWithClient(client_, buffer_->device());
}

PyBuffer::object PyBuffer::Clone() const {
  auto buffer = Make(client_, buffer_, traceback_);
  buffer.buf()->sticky_device_ = sticky_device_;
  buffer.buf()->aval_ = aval_;
  return buffer;
}

StatusOr<py::object> PyBuffer::CopyToDevice(
    const ClientAndPtr<PjRtDevice>& dst_device) const {
  CHECK(dst_device.get() != nullptr);
  auto transfer_guard_formatter = [this, &dst_device] {
    auto shape = py::cast<std::string>(py::str(python_shape()));
    auto dtype = py::cast<std::string>(py::str(python_dtype()));
    return absl::StrCat("shape=", shape, ", dtype=", dtype,
                        ", device=", device()->DebugString(),
                        ", dst_device=", dst_device->DebugString());
  };
  TF_RETURN_IF_ERROR(
      jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter));

  GlobalPyRefManager()->CollectGarbage();
  std::unique_ptr<PjRtBuffer> out;
  {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(out, buffer_->CopyToDevice(dst_device.get()));
  }
  auto traceback = Traceback::Get();
  return Make(dst_device.client, std::move(out), std::move(traceback));
}

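// The completion callback below may fire on a PjRt-managed thread, so the
// result is handed back through a mutex/condition pair. We release the GIL
// before blocking on the condition so other Python threads can make progress
// while the sends are dispatched.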
std::pair<Status, bool> PyBuffer::CopyToRemoteDevice(
    absl::string_view serialized_descriptor) const {
  absl::Mutex mu;
  bool done = false;
  Status status;
  bool sends_were_enqueued;
  buffer_->CopyToRemoteDevice(
      serialized_descriptor,
      [&done, &status, &sends_were_enqueued, &mu](Status s, bool dispatched) {
        absl::MutexLock l(&mu);
        done = true;
        status = s;
        sends_were_enqueued = dispatched;
      });
  {
    py::gil_scoped_release gil_release;
    absl::MutexLock l(&mu);
    mu.Await(absl::Condition(
        +[](bool* done) { return *done; }, &done));
  }
  return std::make_pair(status, sends_were_enqueued);
}

Status PyBuffer::BlockHostUntilReady() {
  GlobalPyRefManager()->CollectGarbage();
  py::gil_scoped_release gil_release;
  return buffer_->BlockHostUntilReady();
}

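// host_value_ caches at most one (pending or completed) device-to-host copy,
// so repeated conversions to a NumPy array reuse a single transfer.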
Status PyBuffer::CopyToHostAsync() {
  if (!buffer_->IsOnCpu() && !host_value_) {
    auto transfer_guard_formatter = [this] {
      auto shape = py::cast<std::string>(py::str(python_shape()));
      auto dtype = py::cast<std::string>(py::str(python_dtype()));
      return absl::StrCat("shape=", shape, ", dtype=", dtype,
                          ", device=", device()->DebugString());
    };
    TF_RETURN_IF_ERROR(
        jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter));

    std::shared_ptr<HostValue> host_value = std::make_shared<HostValue>();
    host_value_ = host_value;
    // TODO(b/182461453): This is a blocking call. If we further implemented
    // populating dynamic shape metadata while fetching the literal, we
    // wouldn't need this static approach.
    TF_ASSIGN_OR_RETURN(const auto* dynamic_shape, xla_dynamic_shape());

    py::gil_scoped_release gil;
    host_value->value = std::make_shared<Literal>(
        ShapeUtil::DeviceShapeToHostShape(*dynamic_shape));
    Literal* literal = host_value->value.get();
    buffer_->ToLiteral(literal,
                       [host_value{std::move(host_value)}](Status status) {
                         host_value->status = std::move(status);
                         host_value->ready.Notify();
                       });
  }
  return OkStatus();
}

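// For CPU buffers the returned ndarray aliases the device memory directly;
// the capsule below keeps both the Python wrapper and an external reference
// to the PjRtBuffer alive for as long as the array exists.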
StatusOr<pybind11::object> PyBuffer::AsNumPyArray(py::handle this_obj) {
  if (buffer_->IsDeleted()) {
    return InvalidArgument("DeviceArray has been deleted.");
  }
  TF_RET_CHECK(buffer_->on_device_shape().IsArray());
  // On CPU, we can return the value in a zero-copy way.
  if (buffer_->IsOnCpu()) {
    TF_ASSIGN_OR_RETURN(const auto* shape, xla_dynamic_shape());
    TF_ASSIGN_OR_RETURN(py::dtype dtype,
                        PrimitiveTypeToDtype(shape->element_type()));
    // Objects that must be kept alive while the array is alive.
    struct Hold {
      py::object buffer;
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold;
    };
    auto hold = std::make_unique<Hold>();
    TF_ASSIGN_OR_RETURN(hold->external_reference_hold,
                        buffer_->AcquireExternalReference());
    hold->buffer = py::reinterpret_borrow<py::object>(this_obj);
    void* data = hold->external_reference_hold->OpaqueDeviceMemoryDataPointer();
    py::capsule hold_capsule(hold.release(),
                             [](void* h) { delete static_cast<Hold*>(h); });
    py::array array(dtype, shape->dimensions(), ByteStridesForShape(*shape),
                    data, hold_capsule);
    array.attr("flags").attr("writeable") = Py_False;
    {
      py::gil_scoped_release gil;
      TF_RETURN_IF_ERROR(buffer_->BlockHostUntilReady());
    }
    return array;
  }

  TF_RETURN_IF_ERROR(CopyToHostAsync());
  if (!host_value_->ready.HasBeenNotified()) {
    py::gil_scoped_release gil;
    host_value_->ready.WaitForNotification();
  }
  TF_RETURN_IF_ERROR(host_value_->status);
  TF_ASSIGN_OR_RETURN(py::object array, LiteralToPython(host_value_->value));
  array.attr("flags").attr("writeable") = Py_False;
  return array;
}

StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
  return client_->pjrt_client()->UnsafeBufferPointer(buffer_.get());
}

StatusOr<py::dict> PyBuffer::CudaArrayInterface() {
  // TODO(zhangqiaorjc): Differentiate between NVidia and other GPUs.
  if (buffer_->client()->platform_id() != GpuId()) {
    return InvalidArgument(
        "__cuda_array_interface__ is only defined for NVidia GPU buffers.");
  }
  if (!buffer_->on_device_shape().IsArray()) {
    return InvalidArgument(
        "__cuda_array_interface__ is only defined for array buffers.");
  }
  if (buffer_->on_device_shape().element_type() == BF16) {
    return InvalidArgument(
        "__cuda_array_interface__ is not supported for bfloat16 buffers.");
  }
  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
      buffer_->on_device_shape().layout()));

  py::dict result;
  TF_ASSIGN_OR_RETURN(const auto* dynamic_shape, xla_dynamic_shape());
  result["shape"] = SpanToTuple(dynamic_shape->dimensions());
  TF_ASSIGN_OR_RETURN(py::str typestr,
                      TypeDescriptorForPrimitiveType(
                          buffer_->on_device_shape().element_type()));
  result["typestr"] = std::move(typestr);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold,
      buffer_->AcquireExternalReference());
  const void* root_ptr =
      external_reference_hold->OpaqueDeviceMemoryDataPointer();
  py::tuple data(2);
  data[0] = py::int_(absl::bit_cast<std::uintptr_t>(root_ptr));
  data[1] = py::bool_(true);  // read-only
  result["data"] = std::move(data);
  result["version"] = py::int_(2);
  return result;
}
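
// Consumers that understand the CUDA array interface (CuPy, Numba, etc.) can
// wrap the device memory without a copy. A hypothetical Python sketch, with
// `device_array` standing in for a GPU-backed DeviceArray:
//
//   import cupy as cp
//   gpu_view = cp.asarray(device_array)  # reads __cuda_array_interface__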

// PEP 3118 buffer protocol implementation.

namespace {

// Extra data to be kept alive by the consumer of the buffer protocol.
struct ExtraBufferInfo {
  explicit ExtraBufferInfo(
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold)
      : external_reference_hold(std::move(external_reference_hold)) {}

  std::string format;
  std::vector<Py_ssize_t> strides;
  // We keep an external reference hold to the PjRtBuffer. This prevents a
  // use-after-free in the event that Delete() is called on a buffer with a
  // live buffer protocol view. It does, however, mean that Delete() sometimes
  // won't actually delete immediately.
  std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold;
};

int PyBuffer_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) {
  Status status = [&]() {
    TF_ASSIGN_OR_RETURN(PyBuffer * py_buffer, PyBuffer::AsPyBuffer(exporter));
    PjRtBuffer& buffer = *py_buffer->buffer();
    TF_ASSIGN_OR_RETURN(const auto* shape, py_buffer->xla_dynamic_shape());
    // Py_buffer objects are POD C structures, so we don't need to hold the
    // GIL. Additionally we call BlockHostUntilReady() below, which may block.
    py::gil_scoped_release gil_release;

    if (!buffer.IsOnCpu()) {
      return InvalidArgument(
          "Python buffer protocol is only defined for CPU buffers.");
    }
    if (!buffer.on_device_shape().IsArray()) {
      return InvalidArgument(
          "Python buffer protocol is only defined for array buffers.");
    }
    // If we allowed exports of formatted BF16 buffers, consumers would get
    // confused about the type because there is no way to describe BF16 to
    // Python.
    if (buffer.on_device_shape().element_type() == BF16 &&
        ((flags & PyBUF_FORMAT) == PyBUF_FORMAT)) {
      return InvalidArgument(
          "bfloat16 buffer format not supported by Python buffer protocol.");
    }
    if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) {
      return InvalidArgument("XLA buffers are read-only.");
    }
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold,
        buffer.AcquireExternalReference());
    if (buffer.IsDeleted()) {
      return InvalidArgument("Deleted buffer used in buffer protocol.");
    }

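    // Note: (flags & PyBUF_STRIDES) == PyBUF_ND is true exactly when the
    // consumer requested shape information (PyBUF_ND) without strides; such
    // consumers assume C-contiguous data, hence the layout check below.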
    if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS ||
         (flags & PyBUF_STRIDES) == PyBUF_ND) &&
        !LayoutUtil::IsMonotonicWithDim0Major(shape->layout())) {
      return InvalidArgument("Buffer is not in C-contiguous layout.");
    } else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS &&
               !LayoutUtil::IsMonotonicWithDim0Minor(shape->layout())) {
      return InvalidArgument("Buffer is not in F-contiguous layout.");
    } else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS &&
               !LayoutUtil::IsMonotonicWithDim0Major(shape->layout()) &&
               !LayoutUtil::IsMonotonicWithDim0Minor(shape->layout())) {
      return InvalidArgument("Buffer is not in contiguous layout.");
    }
    std::memset(view, 0, sizeof(Py_buffer));
    const void* root_ptr =
        external_reference_hold->OpaqueDeviceMemoryDataPointer();
    view->buf = const_cast<void*>(root_ptr);
    auto extra =
        std::make_unique<ExtraBufferInfo>(std::move(external_reference_hold));
    view->itemsize = ShapeUtil::ByteSizeOfPrimitiveType(shape->element_type());
    view->len = ShapeUtil::ByteSizeOf(*shape);
    view->readonly = 1;
    if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
      TF_ASSIGN_OR_RETURN(extra->format, FormatDescriptorForPrimitiveType(
                                             shape->element_type()));
      view->format = const_cast<char*>(extra->format.c_str());
    }
    if ((flags & PyBUF_ND) == PyBUF_ND) {
      view->ndim = shape->dimensions_size();
      static_assert(sizeof(int64_t) == sizeof(Py_ssize_t),
                    "Py_ssize_t must be 64 bits");
      if (view->ndim != 0) {
        view->shape = reinterpret_cast<Py_ssize_t*>(
            const_cast<int64_t*>(shape->dimensions().data()));
        if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
          extra->strides = ByteStridesForShape(*shape);
          view->strides = extra->strides.data();
        }
      }
    }
    TF_RETURN_IF_ERROR(buffer.BlockHostUntilReady());
    view->internal = extra.release();
    return OkStatus();
  }();
  if (!status.ok()) {
    // numpy.asarray(...) silences the PyExc_BufferError. Adding a log here
    // helps debugging when the error really occurs.
    VLOG(1) << "Buffer Protocol Error: " << status;
    PyErr_SetString(PyExc_BufferError, status.ToString().c_str());
    return -1;
  }
  view->obj = exporter;
  Py_INCREF(view->obj);
  return 0;
}

void PyBuffer_bf_releasebuffer(PyObject*, Py_buffer* buffer) {
  auto extra = static_cast<ExtraBufferInfo*>(buffer->internal);
  delete extra;
}

PyBufferProcs PyBuffer_tp_as_buffer = []() {
  PyBufferProcs procs;
  procs.bf_getbuffer = &PyBuffer_bf_getbuffer;
  procs.bf_releasebuffer = &PyBuffer_bf_releasebuffer;
  return procs;
}();
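
// With these hooks installed, a CPU-backed DeviceArray can be viewed from
// Python without copying. A hypothetical sketch (`device_array` is
// illustrative):
//
//   view = memoryview(device_array)    # calls PyBuffer_bf_getbuffer
//   arr = numpy.asarray(device_array)  # read-only, zero-copy ndarray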

}  // namespace

PyObject* PyBuffer::base_type_ = nullptr;
PyObject* PyBuffer::type_ = nullptr;

Status PyBuffer::RegisterTypes(py::module& m) {
  // We do not use pybind11::class_ to build Python wrapper objects because
  // creation, destruction, and casting of buffer objects is performance
  // critical. By using hand-written Python classes, we can avoid extra C heap
  // allocations, and we can avoid pybind11's slow cast<>() implementation
  // during jit dispatch.

  // We need to use heap-allocated type objects because we want to add
  // additional methods dynamically.
  {
    py::str name = py::str("DeviceArrayBase");
    py::str qualname = py::str("DeviceArrayBase");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called. Otherwise the GC might see a half-constructed
    // type object.
    if (!heap_type) {
      return Internal("Unable to create heap type object");
    }
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "DeviceArrayBase";
    type->tp_basicsize = sizeof(PyBufferBasePyObject);
    type->tp_flags =
        Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_BASETYPE;
    TF_RET_CHECK(PyType_Ready(type) == 0);
    base_type_ = reinterpret_cast<PyObject*>(type);
  }
  py::object base_type = py::reinterpret_borrow<py::object>(base_type_);
  base_type.attr("__module__") = m.attr("__name__");
  m.attr("DeviceArrayBase") = base_type;

  {
    py::tuple bases = py::make_tuple(base_type);
    py::str name = py::str("DeviceArray");
    py::str qualname = py::str("DeviceArray");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called below. Otherwise the GC might see a
    // half-constructed type object.
    if (!heap_type) {
      return Internal("Unable to create heap type object");
    }
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "DeviceArray";
    type->tp_basicsize = sizeof(PyBufferPyObject);
    type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE;
    type->tp_bases = bases.release().ptr();
    type->tp_dealloc = PyBuffer_tp_dealloc;
    type->tp_new = PyBuffer_tp_new;
    // Supported protocols
    type->tp_as_number = &heap_type->as_number;
    type->tp_as_sequence = &heap_type->as_sequence;
    type->tp_as_mapping = &heap_type->as_mapping;
    type->tp_as_buffer = &PyBuffer_tp_as_buffer;

    // Allow weak references to DeviceArray objects.
    type->tp_weaklistoffset = offsetof(PyBufferPyObject, weakrefs);

    TF_RET_CHECK(PyType_Ready(type) == 0);
    type_ = reinterpret_cast<PyObject*>(type);
  }
  py::object type = py::reinterpret_borrow<py::object>(type_);
  m.attr("DeviceArray") = type;
  m.attr("PyLocalBuffer") = type;
  m.attr("Buffer") = type;

  // Add methods and properties to the class. We use pybind11 and add methods
  // dynamically mostly because this is easy to write and allows us to use
  // pybind11's casting logic. This is most likely slightly slower than
  // hand-writing bindings, but most of these methods are not performance
  // critical.
  using jax::property;
  using jax::property_readonly;
  type.attr("__array_priority__") =
      property_readonly([](py::object self) -> int { return 100; });
  type.attr("_device") = property(
      [](PyBuffer::object self) -> ClientAndPtr<PjRtDevice> {
        return WrapWithClient(self.buf()->client(),
                              self.buf()->sticky_device());
      },
      [](PyBuffer::object self, PjRtDevice* sticky_device) {
        return self.buf()->set_sticky_device(sticky_device);
      });
  type.attr("aval") = property(
      [](PyBuffer::object self) -> py::object { return self.buf()->GetAval(); },
      [](PyBuffer::object self, py::object aval) {
        return self.buf()->SetAval(std::move(aval));
      });
  type.attr("weak_type") = property(
      [](PyBuffer::object self) -> std::optional<bool> {
        return self.buf()->weak_type();
      },
      [](PyBuffer::object self, std::optional<bool> weak_type) {
        return self.buf()->set_weak_type(weak_type);
      });
  type.attr("device_buffer") =
      property_readonly([](py::object self) { return self; });
  type.attr("shape") =
      property_readonly([](PyBuffer::object self) -> py::tuple {
        return SpanToTuple(
            self.buf()->buffer()->on_device_shape().dimensions());
      });
  type.attr("dtype") = property_readonly([](PyBuffer::object self) {
    PrimitiveType primitive =
        self.buf()->buffer()->on_device_shape().element_type();
    return PrimitiveTypeToDtype(primitive).ValueOrDie();
  });
  type.attr("size") =
      property_readonly([](PyBuffer::object self) -> StatusOr<int64_t> {
        return self.buf()->size();
      });
  type.attr("ndim") = property_readonly(
      [](PyBuffer::object self) -> int { return self.buf()->ndim(); });
  type.attr("_value") = property_readonly(
      [](PyBuffer::object self) -> StatusOr<pybind11::object> {
        GlobalPyRefManager()->CollectGarbage();
        return self.buf()->AsNumPyArray(self);
      });
  type.attr("copy_to_device") = py::cpp_function(
      [](PyBuffer::object self, const ClientAndPtr<PjRtDevice>& dst_device) {
        return self.buf()->CopyToDevice(dst_device);
      },
      py::is_method(type));
  type.attr("copy_to_remote_device") = py::cpp_function(
      [](PyBuffer::object self, const py::bytes serialized_descriptor) {
        // TODO(phawkins): remove the std::string cast after C++17 is required.
        // py::bytes has a std::string_view cast, but not an absl::string_view
        // cast.
        return self.buf()->CopyToRemoteDevice(
            static_cast<std::string>(serialized_descriptor));
      },
      py::is_method(type));

  type.attr("on_device_size_in_bytes") = py::cpp_function(
      [](PyBuffer::object self) -> StatusOr<size_t> {
        return self.buf()->OnDeviceSizeInBytes();
      },
      py::is_method(type));
  type.attr("delete") = py::cpp_function(
      [](PyBuffer::object self) { self.buf()->Delete(); }, py::is_method(type));
  type.attr("block_host_until_ready") = py::cpp_function(
      [](PyBuffer::object self) {
        // TODO(phawkins): remove 3 months after the release of jaxlib >= 0.3.2.
        PythonDeprecationWarning(
            "block_host_until_ready() on a JAX array object is deprecated, use "
            "block_until_ready() instead.");
        return self.buf()->BlockHostUntilReady();
      },
      py::is_method(type));
  type.attr("is_ready") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->IsReady(); },
      py::is_method(type));
  type.attr("is_known_ready") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->IsKnownReady(); },
      py::is_method(type));
  type.attr("block_until_ready") = py::cpp_function(
      [](PyBuffer::object self) -> StatusOr<PyBuffer::object> {
        TF_RETURN_IF_ERROR(self.buf()->BlockHostUntilReady());
        return std::move(self);
      },
      py::is_method(type));
  type.attr("copy_to_host_async") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->CopyToHostAsync(); },
      py::is_method(type));
  type.attr("to_py") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->AsNumPyArray(self); },
      py::is_method(type));
  type.attr("xla_shape") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->shape(); },
      py::is_method(type));
  type.attr("xla_dynamic_shape") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->xla_dynamic_shape(); },
      py::is_method(type));
  type.attr("client") = property_readonly(
      [](PyBuffer::object self) { return self.buf()->client(); });
  type.attr("device") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->device(); },
      py::is_method(type));
  type.attr("platform") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->platform_name(); },
      py::is_method(type));
  type.attr("is_deleted") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->is_deleted(); },
      py::is_method(type));
  type.attr("unsafe_buffer_pointer") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->UnsafeBufferPointer(); },
      py::is_method(type));
  type.attr("__cuda_array_interface__") = property_readonly(
      [](PyBuffer::object self) { return self.buf()->CudaArrayInterface(); });
  type.attr("traceback") = property_readonly(
      [](PyBuffer::object self) { return self.buf()->traceback(); });
  type.attr("clone") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->Clone(); },
      py::is_method(type));
  type.attr("__module__") = m.attr("__name__");
  return OkStatus();
}

}  // namespace xla