/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/py_buffer.h"

#include <functional>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/base/casts.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/python_utils.h"
#include "tensorflow/compiler/xla/python/transfer_guard_lib.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/python/util.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

namespace py = pybind11;

namespace {

// Representation of a DeviceArrayBase as a Python object. Since
// a DeviceArrayBase has no fields, this is just a PyObject.
struct PyBufferBasePyObject {
  PyObject_HEAD;
};
static_assert(std::is_standard_layout<PyBufferBasePyObject>::value,
              "PyBufferBasePyObject must be standard layout");

// Representation of a DeviceArray as a Python object.
struct PyBufferPyObject {
  PyBufferBasePyObject base;
  PyBuffer buffer;
  // Used by the Python interpreter to maintain a list of weak references to
  // this object.
  PyObject* weakrefs;
};
static_assert(std::is_standard_layout<PyBufferPyObject>::value,
              "PyBufferPyObject must be standard layout");

PyObject* PyBuffer_tp_new(PyTypeObject* subtype, PyObject* args,
                          PyObject* kwds) {
  PyBufferPyObject* self =
      reinterpret_cast<PyBufferPyObject*>(subtype->tp_alloc(subtype, 0));
  if (!self) return nullptr;
  self->weakrefs = nullptr;
  return reinterpret_cast<PyObject*>(self);
}
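
// Note that tp_new deliberately does not construct the embedded C++
// PyBuffer; the placement new happens in PyBuffer::Make() below, so an
// object allocated here holds uninitialized storage until Make() fills it.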

void PyBuffer_tp_dealloc(PyObject* self) {
  PyTypeObject* tp = Py_TYPE(self);
  PyBufferPyObject* o = reinterpret_cast<PyBufferPyObject*>(self);
  if (o->weakrefs) {
    PyObject_ClearWeakRefs(self);
  }
  o->buffer.~PyBuffer();
  tp->tp_free(self);
  Py_DECREF(tp);
}
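
// The Py_DECREF(tp) above is required because DeviceArray is a heap type
// (Py_TPFLAGS_HEAPTYPE): instances of heap types hold a strong reference to
// their type object, which the deallocator must release.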

}  // namespace

/*static*/ PyBuffer::object PyBuffer::Make(
    std::shared_ptr<PyClient> client, std::shared_ptr<PjRtBuffer> buffer,
    std::shared_ptr<Traceback> traceback) {
  py::object obj = py::reinterpret_steal<py::object>(PyBuffer_tp_new(
      reinterpret_cast<PyTypeObject*>(type_), nullptr, nullptr));
  PyBufferPyObject* buf = reinterpret_cast<PyBufferPyObject*>(obj.ptr());
  new (&buf->buffer)
      PyBuffer(std::move(client), std::move(buffer), std::move(traceback));
  return py::reinterpret_borrow<PyBuffer::object>(obj);
}
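
// A minimal usage sketch (the client and PjRt buffer values here are
// hypothetical, not defined in this file):
//   std::shared_ptr<PyClient> client = ...;        // an existing PyClient
//   std::shared_ptr<PjRtBuffer> pjrt_buffer = ...; // a device buffer
//   PyBuffer::object arr =
//       PyBuffer::Make(client, std::move(pjrt_buffer), Traceback::Get());
//   PyBuffer* buf = arr.buf();  // C++ view of the Python DeviceArray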

bool PyBuffer::IsPyBuffer(py::handle handle) {
  return handle.get_type() == PyBuffer::type();
}

/*static*/ PyBuffer* PyBuffer::AsPyBufferUnchecked(pybind11::handle handle) {
  return &(reinterpret_cast<PyBufferPyObject*>(handle.ptr())->buffer);
}

/*static*/ StatusOr<PyBuffer*> PyBuffer::AsPyBuffer(pybind11::handle handle) {
  if (!IsPyBuffer(handle)) {
    return InvalidArgument("Expected a DeviceArray");
  }
  return AsPyBufferUnchecked(handle);
}

py::handle PyBuffer::AsHandle() {
  return reinterpret_cast<PyObject*>(reinterpret_cast<char*>(this) -
                                     offsetof(PyBufferPyObject, buffer));
}

PyBuffer::PyBuffer(std::shared_ptr<PyClient> client,
                   std::shared_ptr<PjRtBuffer> buffer,
                   std::shared_ptr<Traceback> traceback)
    : client_(std::move(client)),
      buffer_(std::move(buffer)),
      traceback_(std::move(traceback)) {
  CHECK(PyGILState_Check());
  next_ = client_->buffers_[buffer_->device()->id()];
  client_->buffers_[buffer_->device()->id()] = this;
  prev_ = nullptr;
  if (next_) {
    next_->prev_ = this;
  }
}
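
// The constructor pushes this buffer onto the head of a per-device,
// doubly-linked list of live buffers owned by the client; the destructor
// below unlinks it again. Both operations are O(1) and rely on the GIL
// (checked above) for mutual exclusion.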

PyBuffer::~PyBuffer() {
  CHECK(PyGILState_Check());
  if (client_->buffers_[device()->id()] == this) {
    client_->buffers_[device()->id()] = next_;
  }
  if (prev_) {
    prev_->next_ = next_;
  }
  if (next_) {
    next_->prev_ = prev_;
  }
}

StatusOr<int64_t> PyBuffer::size() {
  Shape max_buffer_shape = buffer()->on_device_shape();
  if (max_buffer_shape.is_dynamic()) {
    TF_ASSIGN_OR_RETURN(const auto* dynamic_shape, xla_dynamic_shape());
    return ShapeUtil::ElementsIn(*dynamic_shape);
  }
  return ShapeUtil::ElementsIn(max_buffer_shape);
}

StatusOr<const Shape*> PyBuffer::xla_dynamic_shape() {
  CHECK(PyGILState_Check());
  if (buffer_->on_device_shape().is_static()) {
    return &buffer_->on_device_shape();
  }
  // Python buffer protocol references shape data by pointer, therefore we must
  // store a valid copy of the shape.
  if (!dynamic_shape_) {
    Shape dynamic_shape;
    {
      py::gil_scoped_release gil_release;
      TF_ASSIGN_OR_RETURN(dynamic_shape, buffer_->logical_on_device_shape());
    }
    dynamic_shape_ = dynamic_shape;
  }
  return &dynamic_shape_.value();
}
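
// logical_on_device_shape() may block on the device, so the result is
// computed with the GIL released and then cached in dynamic_shape_;
// subsequent calls return the cached copy without touching the device.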

pybind11::tuple PyBuffer::python_shape() const {
  return SpanToTuple(buffer()->on_device_shape().dimensions());
}

pybind11::dtype PyBuffer::python_dtype() const {
  PrimitiveType primitive = buffer()->on_device_shape().element_type();
  return PrimitiveTypeToDtype(primitive).ValueOrDie();
}

ClientAndPtr<PjRtDevice> PyBuffer::device() const {
  return WrapWithClient(client_, buffer_->device());
}

PyBuffer::object PyBuffer::Clone() const {
  auto buffer = Make(client_, buffer_, traceback_);
  buffer.buf()->sticky_device_ = sticky_device_;
  buffer.buf()->aval_ = aval_;
  return buffer;
}

StatusOr<py::object> PyBuffer::CopyToDevice(
    const ClientAndPtr<PjRtDevice>& dst_device) const {
  CHECK(dst_device.get() != nullptr);
  auto transfer_guard_formatter = [this, &dst_device] {
    auto shape = py::cast<std::string>(py::str(python_shape()));
    auto dtype = py::cast<std::string>(py::str(python_dtype()));
    return absl::StrCat("shape=", shape, ", dtype=", dtype,
                        ", device=", device()->DebugString(),
                        ", dst_device=", dst_device->DebugString());
  };
  TF_RETURN_IF_ERROR(
      jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter));

  GlobalPyRefManager()->CollectGarbage();
  std::unique_ptr<PjRtBuffer> out;
  {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(out, buffer_->CopyToDevice(dst_device.get()));
  }
  auto traceback = Traceback::Get();
  return Make(dst_device.client, std::move(out), std::move(traceback));
}

std::pair<Status, bool> PyBuffer::CopyToRemoteDevice(
    absl::string_view serialized_descriptor) const {
  absl::Mutex mu;
  bool done = false;
  Status status;
  bool sends_were_enqueued;
  buffer_->CopyToRemoteDevice(
      serialized_descriptor,
      [&done, &status, &sends_were_enqueued, &mu](Status s, bool dispatched) {
        absl::MutexLock l(&mu);
        done = true;
        status = s;
        sends_were_enqueued = dispatched;
      });
  {
    py::gil_scoped_release gil_release;
    absl::MutexLock l(&mu);
    mu.Await(absl::Condition(
        +[](bool* done) { return *done; }, &done));
  }
  return std::make_pair(status, sends_were_enqueued);
}
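
// CopyToRemoteDevice() adapts PjRtBuffer's asynchronous callback API into a
// synchronous call: the lambda records the result under the mutex, and the
// caller blocks on mu.Await() with the GIL released so other Python threads
// can make progress while the transfer is dispatched.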

Status PyBuffer::BlockHostUntilReady() {
  GlobalPyRefManager()->CollectGarbage();
  py::gil_scoped_release gil_release;
  return buffer_->BlockHostUntilReady();
}

Status PyBuffer::CopyToHostAsync() {
  if (!buffer_->IsOnCpu() && !host_value_) {
    auto transfer_guard_formatter = [this] {
      auto shape = py::cast<std::string>(py::str(python_shape()));
      auto dtype = py::cast<std::string>(py::str(python_dtype()));
      return absl::StrCat("shape=", shape, ", dtype=", dtype,
                          ", device=", device()->DebugString());
    };
    TF_RETURN_IF_ERROR(
        jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter));

    std::shared_ptr<HostValue> host_value = std::make_shared<HostValue>();
    host_value_ = host_value;
    // TODO(b/182461453): This is a blocking call. If we further implemented
    // populating dynamic shape metadata while fetching the literal, we
    // wouldn't need this static approach.
    TF_ASSIGN_OR_RETURN(const auto* dynamic_shape, xla_dynamic_shape());

    py::gil_scoped_release gil;
    host_value->value = std::make_shared<Literal>(
        ShapeUtil::DeviceShapeToHostShape(*dynamic_shape));
    Literal* literal = host_value->value.get();
    buffer_->ToLiteral(literal,
                       [host_value{std::move(host_value)}](Status status) {
                         host_value->status = std::move(status);
                         host_value->ready.Notify();
                       });
  }
  return OkStatus();
}

StatusOr<pybind11::object> PyBuffer::AsNumPyArray(py::handle this_obj) {
  if (buffer_->IsDeleted()) {
    return InvalidArgument("DeviceArray has been deleted.");
  }
  TF_RET_CHECK(buffer_->on_device_shape().IsArray());
  // On CPU, we can return the value in a zero-copy way.
  if (buffer_->IsOnCpu()) {
    TF_ASSIGN_OR_RETURN(const auto* shape, xla_dynamic_shape());
    TF_ASSIGN_OR_RETURN(py::dtype dtype,
                        PrimitiveTypeToDtype(shape->element_type()));
    // Objects that must be kept alive while the array is alive.
    struct Hold {
      py::object buffer;
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold;
    };
    auto hold = std::make_unique<Hold>();
    TF_ASSIGN_OR_RETURN(hold->external_reference_hold,
                        buffer_->AcquireExternalReference());
    hold->buffer = py::reinterpret_borrow<py::object>(this_obj);
    void* data = hold->external_reference_hold->OpaqueDeviceMemoryDataPointer();
    py::capsule hold_capsule(hold.release(),
                             [](void* h) { delete static_cast<Hold*>(h); });
    py::array array(dtype, shape->dimensions(), ByteStridesForShape(*shape),
                    data, hold_capsule);
    array.attr("flags").attr("writeable") = Py_False;
    {
      py::gil_scoped_release gil;
      TF_RETURN_IF_ERROR(buffer_->BlockHostUntilReady());
    }
    return array;
  }

  TF_RETURN_IF_ERROR(CopyToHostAsync());
  if (!host_value_->ready.HasBeenNotified()) {
    py::gil_scoped_release gil;
    host_value_->ready.WaitForNotification();
  }
  TF_RETURN_IF_ERROR(host_value_->status);
  TF_ASSIGN_OR_RETURN(py::object array, LiteralToPython(host_value_->value));
  array.attr("flags").attr("writeable") = Py_False;
  return array;
}
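
// AsNumPyArray() has two paths: on CPU the numpy array aliases device memory
// directly, with a capsule keeping both the Python object and an
// ExternalReference alive for the lifetime of the array; on other platforms
// the data is first staged to a host Literal via CopyToHostAsync() and then
// converted. In both cases the result is marked read-only.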

StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
  return client_->pjrt_client()->UnsafeBufferPointer(buffer_.get());
}

StatusOr<py::dict> PyBuffer::CudaArrayInterface() {
  // TODO(zhangqiaorjc): Differentiate between NVidia and other GPUs.
  if (buffer_->client()->platform_id() != GpuId()) {
    return InvalidArgument(
        "__cuda_array_interface__ is only defined for NVidia GPU buffers.");
  }
  if (!buffer_->on_device_shape().IsArray()) {
    return InvalidArgument(
        "__cuda_array_interface__ is only defined for array buffers.");
  }
  if (buffer_->on_device_shape().element_type() == BF16) {
    return InvalidArgument(
        "__cuda_array_interface__ is not supported for bfloat16 buffers.");
  }
  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
      buffer_->on_device_shape().layout()));

  py::dict result;
  TF_ASSIGN_OR_RETURN(const auto* dynamic_shape, xla_dynamic_shape());
  result["shape"] = SpanToTuple(dynamic_shape->dimensions());
  TF_ASSIGN_OR_RETURN(py::str typestr,
                      TypeDescriptorForPrimitiveType(
                          buffer_->on_device_shape().element_type()));
  result["typestr"] = std::move(typestr);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold,
      buffer_->AcquireExternalReference());
  const void* root_ptr =
      external_reference_hold->OpaqueDeviceMemoryDataPointer();
  py::tuple data(2);
  data[0] = py::int_(absl::bit_cast<std::uintptr_t>(root_ptr));
  data[1] = py::bool_(true);  // read-only
  result["data"] = std::move(data);
  result["version"] = py::int_(2);
  return result;
}
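
// For a float32 buffer of shape (2, 3), the dictionary produced above looks
// roughly like this (the device pointer value is illustrative only):
//   {'shape': (2, 3), 'typestr': '<f4', 'data': (0x7f..., True),
//    'version': 2}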

// PEP 3118 buffer protocol implementation.

namespace {

// Extra data to be kept alive by the consumer of the buffer protocol.
struct ExtraBufferInfo {
  explicit ExtraBufferInfo(
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold)
      : external_reference_hold(std::move(external_reference_hold)) {}

  std::string format;
  std::vector<Py_ssize_t> strides;
  // We keep an external reference hold to the PjRtBuffer. This prevents a
  // use-after-free in the event that Delete() is called on a buffer with a
  // live buffer protocol view. It does, however, mean that Delete() sometimes
  // won't actually delete immediately.
  std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold;
};

int PyBuffer_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) {
  Status status = [&]() {
    TF_ASSIGN_OR_RETURN(PyBuffer * py_buffer, PyBuffer::AsPyBuffer(exporter));
    PjRtBuffer& buffer = *py_buffer->buffer();
    TF_ASSIGN_OR_RETURN(const auto* shape, py_buffer->xla_dynamic_shape());
    // Py_buffer objects are POD C structures, so we don't need to hold the
    // GIL. Additionally we call BlockHostUntilReady() below, which may block.
    py::gil_scoped_release gil_release;

    if (!buffer.IsOnCpu()) {
      return InvalidArgument(
          "Python buffer protocol is only defined for CPU buffers.");
    }
    if (!buffer.on_device_shape().IsArray()) {
      return InvalidArgument(
          "Python buffer protocol is only defined for array buffers.");
    }
    // If we allowed exports of formatted BF16 buffers, consumers would get
    // confused about the type because there is no way to describe BF16 to
    // Python.
    if (buffer.on_device_shape().element_type() == BF16 &&
        ((flags & PyBUF_FORMAT) == PyBUF_FORMAT)) {
      return InvalidArgument(
          "bfloat16 buffer format not supported by Python buffer protocol.");
    }
    if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) {
      return InvalidArgument("XLA buffers are read-only.");
    }
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold,
        buffer.AcquireExternalReference());
    if (buffer.IsDeleted()) {
      return InvalidArgument("Deleted buffer used in buffer protocol.");
    }

    if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS ||
         (flags & PyBUF_STRIDES) == PyBUF_ND) &&
        !LayoutUtil::IsMonotonicWithDim0Major(shape->layout())) {
      return InvalidArgument("Buffer is not in C-contiguous layout.");
    } else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS &&
               !LayoutUtil::IsMonotonicWithDim0Minor(shape->layout())) {
      return InvalidArgument("Buffer is not in F-contiguous layout.");
    } else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS &&
               !LayoutUtil::IsMonotonicWithDim0Major(shape->layout()) &&
               !LayoutUtil::IsMonotonicWithDim0Minor(shape->layout())) {
      return InvalidArgument("Buffer is not in contiguous layout.");
    }
    std::memset(view, 0, sizeof(Py_buffer));
    const void* root_ptr =
        external_reference_hold->OpaqueDeviceMemoryDataPointer();
    view->buf = const_cast<void*>(root_ptr);
    auto extra =
        std::make_unique<ExtraBufferInfo>(std::move(external_reference_hold));
    view->itemsize = ShapeUtil::ByteSizeOfPrimitiveType(shape->element_type());
    view->len = ShapeUtil::ByteSizeOf(*shape);
    view->readonly = 1;
    if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
      TF_ASSIGN_OR_RETURN(extra->format, FormatDescriptorForPrimitiveType(
                                             shape->element_type()));
      view->format = const_cast<char*>(extra->format.c_str());
    }
    if ((flags & PyBUF_ND) == PyBUF_ND) {
      view->ndim = shape->dimensions_size();
      static_assert(sizeof(int64_t) == sizeof(Py_ssize_t),
                    "Py_ssize_t must be 64 bits");
      if (view->ndim != 0) {
        view->shape = reinterpret_cast<Py_ssize_t*>(
            const_cast<int64_t*>(shape->dimensions().data()));
        if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
          extra->strides = ByteStridesForShape(*shape);
          view->strides = extra->strides.data();
        }
      }
    }
    TF_RETURN_IF_ERROR(buffer.BlockHostUntilReady());
    view->internal = extra.release();
    return OkStatus();
  }();
  if (!status.ok()) {
    // numpy.asarray(...) silences the PyExc_BufferError. Adding a log here
    // helps with debugging when the error really occurs.
    VLOG(1) << "Buffer Protocol Error: " << status;
    PyErr_SetString(PyExc_BufferError, status.ToString().c_str());
    return -1;
  }
  view->obj = exporter;
  Py_INCREF(view->obj);
  return 0;
}
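
// From Python, this hook is exercised whenever a consumer requests a view of
// a CPU DeviceArray, e.g.:
//   memoryview(device_array)
//   numpy.asarray(device_array)  # zero-copy for CPU buffers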

void PyBuffer_bf_releasebuffer(PyObject*, Py_buffer* buffer) {
  auto extra = static_cast<ExtraBufferInfo*>(buffer->internal);
  delete extra;
}

PyBufferProcs PyBuffer_tp_as_buffer = []() {
  PyBufferProcs procs;
  procs.bf_getbuffer = &PyBuffer_bf_getbuffer;
  procs.bf_releasebuffer = &PyBuffer_bf_releasebuffer;
  return procs;
}();

}  // namespace

PyObject* PyBuffer::base_type_ = nullptr;
PyObject* PyBuffer::type_ = nullptr;

Status PyBuffer::RegisterTypes(py::module& m) {
  // We do not use pybind11::class_ to build Python wrapper objects because
  // creation, destruction, and casting of buffer objects is performance
  // critical. By using hand-written Python classes, we can avoid extra C heap
  // allocations, and we can avoid pybind11's slow cast<>() implementation
  // during jit dispatch.

  // We need to use heap-allocated type objects because we want to add
  // additional methods dynamically.
  {
    py::str name = py::str("DeviceArrayBase");
    py::str qualname = py::str("DeviceArrayBase");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called. Otherwise the GC might see a half-constructed
    // type object.
    if (!heap_type) {
      return Internal("Unable to create heap type object");
    }
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "DeviceArrayBase";
    type->tp_basicsize = sizeof(PyBufferBasePyObject);
    type->tp_flags =
        Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_BASETYPE;
    TF_RET_CHECK(PyType_Ready(type) == 0);
    base_type_ = reinterpret_cast<PyObject*>(type);
  }
  py::object base_type = py::reinterpret_borrow<py::object>(base_type_);
  base_type.attr("__module__") = m.attr("__name__");
  m.attr("DeviceArrayBase") = base_type;

  {
    py::tuple bases = py::make_tuple(base_type);
    py::str name = py::str("DeviceArray");
    py::str qualname = py::str("DeviceArray");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called below. Otherwise the GC might see a
    // half-constructed type object.
    if (!heap_type) {
      return Internal("Unable to create heap type object");
    }
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "DeviceArray";
    type->tp_basicsize = sizeof(PyBufferPyObject);
    type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE;
    type->tp_bases = bases.release().ptr();
    type->tp_dealloc = PyBuffer_tp_dealloc;
    type->tp_new = PyBuffer_tp_new;
    // Supported protocols
    type->tp_as_number = &heap_type->as_number;
    type->tp_as_sequence = &heap_type->as_sequence;
    type->tp_as_mapping = &heap_type->as_mapping;
    type->tp_as_buffer = &PyBuffer_tp_as_buffer;

    // Allow weak references to DeviceArray objects.
    type->tp_weaklistoffset = offsetof(PyBufferPyObject, weakrefs);

    TF_RET_CHECK(PyType_Ready(type) == 0);
    type_ = reinterpret_cast<PyObject*>(type);
  }
  py::object type = py::reinterpret_borrow<py::object>(type_);
  m.attr("DeviceArray") = type;
  m.attr("PyLocalBuffer") = type;
  m.attr("Buffer") = type;
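
  // PyLocalBuffer and Buffer are older names kept as aliases; all three
  // module attributes refer to the same type object, so isinstance() checks
  // against any of the three names see one class.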

  // Add methods and properties to the class. We use pybind11 and add methods
  // dynamically mostly because this is easy to write and allows us to use
  // pybind11's casting logic. This is most likely slightly slower than
  // hand-writing bindings, but most of these methods are not performance
  // critical.
  using jax::property;
  using jax::property_readonly;
  type.attr("__array_priority__") =
      property_readonly([](py::object self) -> int { return 100; });
  type.attr("_device") = property(
      [](PyBuffer::object self) -> ClientAndPtr<PjRtDevice> {
        return WrapWithClient(self.buf()->client(),
                              self.buf()->sticky_device());
      },
      [](PyBuffer::object self, PjRtDevice* sticky_device) {
        return self.buf()->set_sticky_device(sticky_device);
      });
  type.attr("aval") = property(
      [](PyBuffer::object self) -> py::object { return self.buf()->GetAval(); },
      [](PyBuffer::object self, py::object aval) {
        return self.buf()->SetAval(std::move(aval));
      });
  type.attr("weak_type") = property(
      [](PyBuffer::object self) -> std::optional<bool> {
        return self.buf()->weak_type();
      },
      [](PyBuffer::object self, std::optional<bool> weak_type) {
        return self.buf()->set_weak_type(weak_type);
      });
  type.attr("device_buffer") =
      property_readonly([](py::object self) { return self; });
  type.attr(
      "shape") = property_readonly([](PyBuffer::object self) -> py::tuple {
    return SpanToTuple(self.buf()->buffer()->on_device_shape().dimensions());
  });
  type.attr("dtype") = property_readonly([](PyBuffer::object self) {
    PrimitiveType primitive =
        self.buf()->buffer()->on_device_shape().element_type();
    return PrimitiveTypeToDtype(primitive).ValueOrDie();
  });
  type.attr("size") =
      property_readonly([](PyBuffer::object self) -> StatusOr<int64_t> {
        return self.buf()->size();
      });
  type.attr("ndim") = property_readonly(
      [](PyBuffer::object self) -> int { return self.buf()->ndim(); });
  type.attr("_value") = property_readonly(
      [](PyBuffer::object self) -> StatusOr<pybind11::object> {
        GlobalPyRefManager()->CollectGarbage();
        return self.buf()->AsNumPyArray(self);
      });
  type.attr("copy_to_device") = py::cpp_function(
      [](PyBuffer::object self, const ClientAndPtr<PjRtDevice>& dst_device) {
        return self.buf()->CopyToDevice(dst_device);
      },
      py::is_method(type));
  type.attr("copy_to_remote_device") = py::cpp_function(
      [](PyBuffer::object self, const py::bytes serialized_descriptor) {
        // TODO(phawkins): remove the std::string cast after C++17 is required.
        // py::bytes has a std::string_view cast, but not an absl::string_view
        // cast.
        return self.buf()->CopyToRemoteDevice(
            static_cast<std::string>(serialized_descriptor));
      },
      py::is_method(type));

  type.attr("on_device_size_in_bytes") = py::cpp_function(
      [](PyBuffer::object self) -> StatusOr<size_t> {
        return self.buf()->OnDeviceSizeInBytes();
      },
      py::is_method(type));
  type.attr("delete") = py::cpp_function(
      [](PyBuffer::object self) { self.buf()->Delete(); }, py::is_method(type));
  type.attr("block_host_until_ready") = py::cpp_function(
      [](PyBuffer::object self) {
        // TODO(phawkins): remove 3 months after the release of jaxlib >= 0.3.2.
        PythonDeprecationWarning(
            "block_host_until_ready() on a JAX array object is deprecated, use "
            "block_until_ready() instead.");
        return self.buf()->BlockHostUntilReady();
      },
      py::is_method(type));
  type.attr("is_ready") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->IsReady(); },
      py::is_method(type));
  type.attr("is_known_ready") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->IsKnownReady(); },
      py::is_method(type));
  type.attr("block_until_ready") = py::cpp_function(
      [](PyBuffer::object self) -> StatusOr<PyBuffer::object> {
        TF_RETURN_IF_ERROR(self.buf()->BlockHostUntilReady());
        return std::move(self);
      },
      py::is_method(type));
  type.attr("copy_to_host_async") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->CopyToHostAsync(); },
      py::is_method(type));
  type.attr("to_py") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->AsNumPyArray(self); },
      py::is_method(type));
  type.attr("xla_shape") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->shape(); },
      py::is_method(type));
  type.attr("xla_dynamic_shape") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->xla_dynamic_shape(); },
      py::is_method(type));
  type.attr("client") = property_readonly(
      [](PyBuffer::object self) { return self.buf()->client(); });
  type.attr("device") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->device(); },
      py::is_method(type));
  type.attr("platform") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->platform_name(); },
      py::is_method(type));
  type.attr("is_deleted") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->is_deleted(); },
      py::is_method(type));
  type.attr("unsafe_buffer_pointer") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->UnsafeBufferPointer(); },
      py::is_method(type));
  type.attr("__cuda_array_interface__") = property_readonly(
      [](PyBuffer::object self) { return self.buf()->CudaArrayInterface(); });
  type.attr("traceback") = property_readonly(
      [](PyBuffer::object self) { return self.buf()->traceback(); });
  type.attr("clone") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->Clone(); },
      py::is_method(type));
  type.attr("__module__") = m.attr("__name__");
  return OkStatus();
}

}  // namespace xla