1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_ 17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_ 18 19 #include <memory> 20 #include <optional> 21 #include <stdexcept> 22 #include <vector> 23 24 #include "absl/strings/string_view.h" 25 #include "absl/synchronization/notification.h" 26 #include "pybind11/numpy.h" 27 #include "pybind11/pybind11.h" 28 #include "tensorflow/compiler/xla/python/py_client.h" 29 #include "tensorflow/compiler/xla/python/traceback.h" 30 #include "tensorflow/compiler/xla/statusor.h" 31 #include "tensorflow/compiler/xla/types.h" 32 33 namespace xla { 34 35 // Python wrapper around PjRtBuffer. We use a wrapper class: 36 // a) to keep the PjRtClient alive via a std::shared_ptr<> 37 // b) to add Python-specific functionality. 38 // 39 // A `PyBuffer` can be used from Python without being wrapped in a Python 40 // `DeviceArray` object. 41 class PyBuffer { 42 public: 43 // pybind11::object typed subclass for PyBuffer objects. 44 class pyobject : public pybind11::object { 45 public: 46 PYBIND11_OBJECT(pyobject, // NOLINT 47 pybind11::object, PyBuffer::IsPyBuffer); 48 pyobject() = default; buf()49 PyBuffer* buf() const { return PyBuffer::AsPyBufferUnchecked(*this); } 50 }; 51 using object = pyobject; 52 53 static object Make(std::shared_ptr<PyClient> client, 54 std::shared_ptr<PjRtBuffer> buffer, 55 std::shared_ptr<Traceback> traceback); 56 57 // Returns true if `h` is a PyBuffer. 58 static bool IsPyBuffer(pybind11::handle handle); 59 // Converts `handle` to a PyBuffer*. Does not do any checking. 60 static PyBuffer* AsPyBufferUnchecked(pybind11::handle handle); 61 // Converts `handle` to a PyBuffer*. Returns an error status if 62 // !IsPyBuffer(handle) 63 static StatusOr<PyBuffer*> AsPyBuffer(pybind11::handle handle); 64 65 // Gets a Python handle to an existing PyBuffer. Assumes the PyObject was 66 // allocated on the Python heap, which is the case if Make() was used. 67 pybind11::handle AsHandle(); 68 69 ~PyBuffer(); 70 client()71 std::shared_ptr<PyClient> client() const { return client_; } buffer()72 PjRtBuffer* buffer() const { return buffer_.get(); } shared_ptr_buffer()73 std::shared_ptr<PjRtBuffer> shared_ptr_buffer() const { return buffer_; } 74 75 ClientAndPtr<PjRtDevice> device() const; platform_name()76 absl::string_view platform_name() const { 77 return buffer_->client()->platform_name(); 78 } is_deleted()79 bool is_deleted() const { return buffer_->IsDeleted(); } 80 81 StatusOr<pybind11::object> CopyToDevice( 82 const ClientAndPtr<PjRtDevice>& dst_device) const; 83 std::pair<Status, bool> CopyToRemoteDevice( 84 absl::string_view serialized_descriptor) const; 85 OnDeviceSizeInBytes()86 StatusOr<size_t> OnDeviceSizeInBytes() { 87 return buffer_->GetOnDeviceSizeInBytes(); 88 } 89 Delete()90 void Delete() { 91 buffer_->Delete(); 92 host_value_ = nullptr; 93 } 94 95 // Makes a copy of this PyBuffer object that shares the underlying PjRtBuffer. 96 // This is useful because we may wish to change JAX metadata (e.g., the sticky 97 // device) without copying the buffer. 98 object Clone() const; 99 100 // Returns xla::InvalidArgument if the buffer has been deleted. 101 // See `PjRtFuture` for the semantics of `IsReady` and `IsKnownReady`. IsReady()102 StatusOr<bool> IsReady() { 103 if (buffer_->IsDeleted()) { 104 return InvalidArgument("DeviceArray has been deleted."); 105 } 106 return buffer_->GetReadyFuture().IsReady(); 107 } IsKnownReady()108 StatusOr<bool> IsKnownReady() { 109 if (buffer_->IsDeleted()) { 110 return InvalidArgument("DeviceArray has been deleted."); 111 } 112 return buffer_->GetReadyFuture().IsKnownReady(); 113 } 114 115 // Returns xla::InvalidArgument if the buffer has been deleted. 116 Status BlockHostUntilReady(); 117 Status CopyToHostAsync(); 118 shape()119 const Shape& shape() { return buffer_->on_device_shape(); } 120 121 StatusOr<std::uintptr_t> UnsafeBufferPointer() const; 122 123 // Implementation of the CUDA array interface for sharing GPU buffers with 124 // other Python libraries. 125 StatusOr<pybind11::dict> CudaArrayInterface(); 126 traceback()127 const std::shared_ptr<Traceback>& traceback() const { return traceback_; } 128 129 // Returns the size (i.e. number of elements) of the (host) numpy array. 130 StatusOr<int64_t> size(); 131 132 // Returns the number of dimensions of the (host) numpy array. ndim()133 int ndim() const { return buffer()->on_device_shape().dimensions_size(); } 134 135 pybind11::tuple python_shape() const; 136 pybind11::dtype python_dtype() const; 137 138 // Representing the logical view of the underlying dynamic shapes. 139 StatusOr<const Shape*> xla_dynamic_shape(); 140 set_sticky_device(PjRtDevice * sticky_device)141 Status set_sticky_device(PjRtDevice* sticky_device) { 142 TF_RET_CHECK(sticky_device == nullptr || 143 sticky_device == buffer_->device()); 144 sticky_device_ = sticky_device; 145 return OkStatus(); 146 } sticky_device()147 PjRtDevice* sticky_device() const { return sticky_device_; } 148 set_weak_type(std::optional<bool> weak_type)149 void set_weak_type(std::optional<bool> weak_type) { weak_type_ = weak_type; } weak_type()150 std::optional<bool> weak_type() const { return weak_type_; } 151 152 StatusOr<pybind11::object> AsNumPyArray(pybind11::handle this_obj); 153 SetAval(pybind11::object aval)154 void SetAval(pybind11::object aval) { aval_ = aval; } GetAval()155 pybind11::object GetAval() const { return aval_; } 156 157 static Status RegisterTypes(pybind11::module& m); base_type()158 static PyObject* base_type() { return base_type_; } type()159 static PyObject* type() { return type_; } 160 161 private: 162 // PyBuffer objects must not be allocated directly since they must always live 163 // on the Python heap. Use Make() instead. 164 PyBuffer(std::shared_ptr<PyClient> client, std::shared_ptr<PjRtBuffer> buffer, 165 std::shared_ptr<Traceback> traceback); 166 167 static PyObject* base_type_; 168 static PyObject* type_; 169 170 friend class PyClient; 171 172 struct HostValue { 173 absl::Notification ready; 174 Status status; 175 std::shared_ptr<xla::Literal> value; 176 }; 177 std::shared_ptr<PyClient> client_; 178 std::shared_ptr<PjRtBuffer> buffer_; 179 std::shared_ptr<Traceback> traceback_; 180 std::shared_ptr<HostValue> host_value_; // Protected by the GIL. 181 182 // JAX uses this field to record whether a buffer is committed to a particular 183 // device by the user (https://github.com/google/jax/pull/1916). 184 PjRtDevice* sticky_device_ = nullptr; 185 186 // TODO(phawkins): consider not keeping an explicit aval on C++ buffer 187 // objects. 188 pybind11::object aval_ = pybind11::none(); 189 190 // An optional weak type. If absent, the JAX jit code computes the weak_type 191 // from the aval_.weak_type attribute. This is a backwards compatibility 192 // measure for older Python code that does not set weak_type explicitly. 193 // TODO(phawkins): drop support for older jax Python versions and make 194 // weak_type mandatory. 195 std::optional<bool> weak_type_ = std::nullopt; 196 197 std::optional<Shape> dynamic_shape_ = std::nullopt; 198 // Doubly-linked list of all PyBuffers known to the client. Protected by the 199 // GIL. Since multiple PyBuffers may share the same PjRtBuffer, there may be 200 // duplicate PjRtBuffers in this list. 201 PyBuffer* next_; 202 PyBuffer* prev_; 203 }; 204 205 } // namespace xla 206 207 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_ 208