xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/sharded_device_array.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/sharded_device_array.h"

#include <optional>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/python_utils.h"
#include "tensorflow/core/platform/statusor.h"

namespace jax {

namespace py = pybind11;

namespace {

struct ShardedDeviceArrayBaseObject {
  PyObject_HEAD;
};
static_assert(std::is_standard_layout<ShardedDeviceArrayBaseObject>::value,
              "ShardedDeviceArrayBaseObject must be standard layout");

struct ShardedDeviceArrayObject {
  ShardedDeviceArrayBaseObject base;
  ShardedDeviceArray sda;
  // Used by the Python interpreter to maintain a list of weak references to
  // this object.
  PyObject* weakrefs;
};
static_assert(std::is_standard_layout<ShardedDeviceArrayObject>::value,
              "ShardedDeviceArrayObject must be standard layout");
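
// Standard layout matters here: offsetof(ShardedDeviceArrayObject, ...) is
// used below (for tp_weaklistoffset and in AsHandle()), and offsetof is only
// guaranteed to be well-defined for standard-layout types.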
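
// tp_new only allocates and zero-initializes the Python object; the embedded
// C++ ShardedDeviceArray member is placement-constructed later, in Make().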
PyObject* sharded_device_array_tp_new(PyTypeObject* subtype, PyObject* args,
                                      PyObject* kwds) {
  ShardedDeviceArrayObject* self = reinterpret_cast<ShardedDeviceArrayObject*>(
      subtype->tp_alloc(subtype, 0));
  if (!self) return nullptr;
  self->weakrefs = nullptr;
  return reinterpret_cast<PyObject*>(self);
}

void sharded_device_array_tp_dealloc(PyObject* self) {
  PyTypeObject* tp = Py_TYPE(self);
  ShardedDeviceArrayObject* o =
      reinterpret_cast<ShardedDeviceArrayObject*>(self);
  if (o->weakrefs) {
    PyObject_ClearWeakRefs(self);
  }
  o->sda.~ShardedDeviceArray();
  tp->tp_free(self);
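  // Instances of heap types hold a strong reference to their type (taken in
  // tp_alloc), so release it once the instance itself has been freed.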
  Py_DECREF(tp);
}

}  // namespace

void ShardedDeviceArray::Delete() {
  // If already deleted, do nothing.
  if (is_deleted_) {
    return;
  }
  // We can't inline this expression into the for loop! Here, .value()
  // returns an rvalue reference to the Span embedded in the StatusOr.
  // Binding the reference would extend the lifetime of the Span itself,
  // but not of the StatusOr, causing stack-use-after-scope errors. Also see
  // https://en.cppreference.com/w/cpp/language/range-for#Temporary_range_expression
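  // For illustration, the problematic inlined form would be:
  //   for (xla::PjRtBuffer* b : GetPjRtBuffers().value()) { ... }
  // where the temporary StatusOr is destroyed at the end of the
  // range-initializer, leaving the Span dangling.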
  auto buffers = GetPjRtBuffers().value();
  for (xla::PjRtBuffer* pjrt_buffer : buffers) {
    pjrt_buffer->Delete();
  }
  device_buffers_ = std::nullopt;
  cpp_device_buffers_ = std::nullopt;
  npy_value_ = std::nullopt;
  is_deleted_ = true;
}

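// Lazily extracts (and caches in cpp_device_buffers_) the raw PjRtBuffer
// pointers backing the Python-level device_buffers_ list.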
xla::StatusOr<absl::Span<xla::PjRtBuffer* const>>
ShardedDeviceArray::GetPjRtBuffers() {
  if (cpp_device_buffers_.has_value()) {
    return absl::MakeConstSpan(cpp_device_buffers_.value());
  }

  if (!device_buffers_.has_value()) {
    return xla::InvalidArgument("ShardedDeviceArray has been deleted.");
  }
  const int num_devices = device_buffers_->size();
  std::vector<xla::PjRtBuffer*> cpp_device_buffers;
  cpp_device_buffers.reserve(num_devices);
  for (auto& handle : device_buffers_.value()) {
    // Class invariants guarantee that this cast never fails.
    TF_ASSIGN_OR_RETURN(xla::PyBuffer * pybuffer,
                        xla::PyBuffer::AsPyBuffer(handle));
    cpp_device_buffers.push_back(pybuffer->buffer());
  }
  cpp_device_buffers_ = std::move(cpp_device_buffers);
  return absl::MakeConstSpan(cpp_device_buffers_.value());
}

PyObject* ShardedDeviceArray::base_type_ = nullptr;
PyObject* ShardedDeviceArray::type_ = nullptr;

/*static*/ ShardedDeviceArray::object ShardedDeviceArray::Make(
    py::object aval, ShardingSpec sharding_spec, py::list device_buffers,
    py::object indices, bool weak_type) {
  py::object obj =
      py::reinterpret_steal<py::object>(sharded_device_array_tp_new(
          reinterpret_cast<PyTypeObject*>(type_), nullptr, nullptr));
  ShardedDeviceArrayObject* sda =
      reinterpret_cast<ShardedDeviceArrayObject*>(obj.ptr());
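  // Placement-new the C++ object into the storage tp_new allocated; the
  // matching explicit destructor call is in sharded_device_array_tp_dealloc.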
  new (&sda->sda)
      ShardedDeviceArray(aval, std::move(sharding_spec),
                         std::move(device_buffers), indices, weak_type);
  return py::reinterpret_borrow<ShardedDeviceArray::object>(obj);
}

bool ShardedDeviceArray::IsShardedDeviceArray(py::handle handle) {
  return handle.get_type() == ShardedDeviceArray::type();
}

/*static*/ ShardedDeviceArray*
ShardedDeviceArray::AsShardedDeviceArrayUnchecked(py::handle handle) {
  return &(reinterpret_cast<ShardedDeviceArrayObject*>(handle.ptr())->sda);
}

/*static*/ xla::StatusOr<ShardedDeviceArray*>
ShardedDeviceArray::AsShardedDeviceArray(py::handle handle) {
  if (!IsShardedDeviceArray(handle)) {
    return xla::InvalidArgument("Expected a ShardedDeviceArray");
  }
  return AsShardedDeviceArrayUnchecked(handle);
}

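// Recovers the owning PyObject* from this ShardedDeviceArray by walking back
// from the sda field to the start of the enclosing ShardedDeviceArrayObject;
// this is one reason the struct must be standard layout.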
py::handle ShardedDeviceArray::AsHandle() {
  return reinterpret_cast<PyObject*>(reinterpret_cast<char*>(this) -
                                     offsetof(ShardedDeviceArrayObject, sda));
}

/*static*/ xla::Status ShardedDeviceArray::RegisterTypes(py::module& m) {
  // We use heap-allocated type objects because we want to add additional
  // methods to them dynamically, similar to py_buffer.cc.
  {
    py::str name = py::str("ShardedDeviceArrayBase");
    py::str qualname = py::str("ShardedDeviceArrayBase");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called. Otherwise the GC might see a half-constructed
    // type object.
    if (!heap_type) {
      return xla::Internal("Unable to create heap type object");
    }
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "ShardedDeviceArrayBase";
    type->tp_basicsize = sizeof(ShardedDeviceArrayBaseObject);
    type->tp_flags =
        Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_BASETYPE;
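    // Py_TPFLAGS_BASETYPE lets Python code subclass ShardedDeviceArrayBase.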
    TF_RET_CHECK(PyType_Ready(type) == 0);
    base_type_ = reinterpret_cast<PyObject*>(type);
  }
  py::object base_type = py::reinterpret_borrow<py::object>(base_type_);
  base_type.attr("__module__") = m.attr("__name__");
  m.attr("ShardedDeviceArrayBase") = base_type;

  {
    py::tuple bases = py::make_tuple(base_type);
    py::str name = py::str("ShardedDeviceArray");
    py::str qualname = py::str("ShardedDeviceArray");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called below. Otherwise the GC might see a
    // half-constructed type object.
    if (!heap_type) {
      return xla::Internal("Unable to create heap type object");
    }
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "ShardedDeviceArray";
    type->tp_basicsize = sizeof(ShardedDeviceArrayObject);
    type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE;
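    // Note: no Py_TPFLAGS_BASETYPE here, so ShardedDeviceArray itself cannot
    // be subclassed from Python.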
    type->tp_bases = bases.release().ptr();
    type->tp_dealloc = sharded_device_array_tp_dealloc;
    type->tp_new = sharded_device_array_tp_new;
    // Supported protocols.
    type->tp_as_number = &heap_type->as_number;
    type->tp_as_sequence = &heap_type->as_sequence;
    type->tp_as_mapping = &heap_type->as_mapping;
    type->tp_as_buffer = nullptr;

    // Allow weak references to ShardedDeviceArray objects.
    type->tp_weaklistoffset = offsetof(ShardedDeviceArrayObject, weakrefs);

    TF_RET_CHECK(PyType_Ready(type) == 0);
    type_ = reinterpret_cast<PyObject*>(type);
  }
  py::object type = py::reinterpret_borrow<py::object>(type_);
  type.attr("__module__") = m.attr("__name__");
  m.attr("ShardedDeviceArray") = type;

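  // Methods and properties exposed to Python. The def_static, property, and
  // property_readonly helpers come from the headers included above.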
  type.attr("make") = def_static([](py::object aval, ShardingSpec sharding_spec,
                                    py::list device_buffers, py::object indices,
                                    bool weak_type) {
    return ShardedDeviceArray::Make(aval, sharding_spec, device_buffers,
                                    indices, weak_type);
  });
  type.attr("aval") =
      property_readonly([](ShardedDeviceArray::object self) -> py::object {
        return self.sda()->aval();
      });
  type.attr("indices") =
      property_readonly([](ShardedDeviceArray::object self) -> py::object {
        return self.sda()->indices();
      });
  type.attr("sharding_spec") =
      property_readonly([](ShardedDeviceArray::object self) {
        return self.sda()->GetShardingSpec();
      });
  type.attr("device_buffers") =
      property_readonly([](ShardedDeviceArray::object self) {
        return self.sda()->device_buffers();
      });
  type.attr("_npy_value") = property(
      [](ShardedDeviceArray::object self) { return self.sda()->npy_value(); },
      [](ShardedDeviceArray::object self, py::object npy_value) {
        return self.sda()->set_npy_value(npy_value);
      });
  type.attr("_one_replica_buffer_indices") = property(
      [](ShardedDeviceArray::object self) {
        return self.sda()->one_replica_buffer_indices();
      },
      [](ShardedDeviceArray::object self, py::object obj) {
        return self.sda()->set_one_replica_buffer_indices(obj);
      });
  type.attr("shape") = property_readonly([](ShardedDeviceArray::object self) {
    return self.sda()->aval().attr("shape");
  });
  type.attr("dtype") = property_readonly([](ShardedDeviceArray::object self) {
    return self.sda()->aval().attr("dtype");
  });
  type.attr("size") = property_readonly([](ShardedDeviceArray::object self) {
    py::tuple shape = py::cast<py::tuple>(self.sda()->aval().attr("shape"));
    int64_t size = 1;
    for (auto dim : shape) {
      size *= py::cast<int64_t>(dim);
    }
    return size;
  });
  type.attr("ndim") = property_readonly([](ShardedDeviceArray::object self) {
    return py::len(self.sda()->aval().attr("shape"));
  });

  type.attr("delete") = py::cpp_function(
      [](ShardedDeviceArray::object self) { self.sda()->Delete(); },
      py::is_method(type));
  type.attr("is_deleted") = py::cpp_function(
      [](ShardedDeviceArray::object self) { return self.sda()->is_deleted(); },
      py::is_method(type));

  return ::tensorflow::OkStatus();
}

}  // namespace jax