/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/sharded_device_array.h"

#include <optional>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/python_utils.h"
#include "tensorflow/core/platform/statusor.h"

namespace jax {

namespace py = pybind11;

namespace {

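// C-level layouts of the Python objects backing ShardedDeviceArrayBase and
// ShardedDeviceArray. Both must be standard layout so that the
// reinterpret_casts and the offsetof() computations below are well-defined.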
struct ShardedDeviceArrayBaseObject {
  PyObject_HEAD;
};
static_assert(std::is_standard_layout<ShardedDeviceArrayBaseObject>::value,
              "ShardedDeviceArrayBaseObject must be standard layout");

struct ShardedDeviceArrayObject {
  ShardedDeviceArrayBaseObject base;
  ShardedDeviceArray sda;
  // Used by the Python interpreter to maintain a list of weak references to
  // this object.
  PyObject* weakrefs;
};
static_assert(std::is_standard_layout<ShardedDeviceArrayObject>::value,
              "ShardedDeviceArrayObject must be standard layout");

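// tp_new for ShardedDeviceArray: allocates the Python object and zeroes the
// weak-reference list, but leaves the embedded C++ ShardedDeviceArray
// uninitialized; ShardedDeviceArray::Make() placement-news it afterwards.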
PyObject* sharded_device_array_tp_new(PyTypeObject* subtype, PyObject* args,
                                      PyObject* kwds) {
  ShardedDeviceArrayObject* self = reinterpret_cast<ShardedDeviceArrayObject*>(
      subtype->tp_alloc(subtype, 0));
  if (!self) return nullptr;
  self->weakrefs = nullptr;
  return reinterpret_cast<PyObject*>(self);
}

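// tp_dealloc mirrors tp_new: clear outstanding weak references, run the C++
// destructor on the placement-new'd ShardedDeviceArray, free the storage, and
// drop the reference the instance holds on its heap-allocated type.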
void sharded_device_array_tp_dealloc(PyObject* self) {
  PyTypeObject* tp = Py_TYPE(self);
  ShardedDeviceArrayObject* o =
      reinterpret_cast<ShardedDeviceArrayObject*>(self);
  if (o->weakrefs) {
    PyObject_ClearWeakRefs(self);
  }
  o->sda.~ShardedDeviceArray();
  tp->tp_free(self);
  Py_DECREF(tp);
}

}  // namespace

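// Deletes the underlying device buffers and drops the cached Python and C++
// state. Calling Delete() on an already-deleted array is a no-op.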
void ShardedDeviceArray::Delete() {
  // If already deleted, do nothing.
  if (is_deleted_) {
    return;
  }
  // We can't inline this expression into the for loop! Calling .value() on
  // the temporary StatusOr returns a reference to the Span it contains; the
  // temporary is destroyed at the end of the range-initialization expression,
  // so the loop would read through a dangling reference and trigger a
  // stack-use-after-scope error. Copying the Span out first avoids this. See
  // https://en.cppreference.com/w/cpp/language/range-for#Temporary_range_expression
  auto buffers = GetPjRtBuffers().value();
  for (xla::PjRtBuffer* pjrt_buffer : buffers) {
    pjrt_buffer->Delete();
  }
  device_buffers_ = std::nullopt;
  cpp_device_buffers_ = std::nullopt;
  npy_value_ = std::nullopt;
  is_deleted_ = true;
}

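// Returns the C++-level PjRtBuffer* for each shard, lazily extracting them
// from the Python-level `device_buffers_` list on first use and caching the
// result in `cpp_device_buffers_`.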
xla::StatusOr<absl::Span<xla::PjRtBuffer* const>>
ShardedDeviceArray::GetPjRtBuffers() {
  if (cpp_device_buffers_.has_value()) {
    return absl::MakeConstSpan(cpp_device_buffers_.value());
  }

  if (!device_buffers_.has_value()) {
    return xla::InvalidArgument("ShardedDeviceArray has been deleted.");
  }
  const int num_devices = device_buffers_->size();
  std::vector<xla::PjRtBuffer*> cpp_device_buffers;
  cpp_device_buffers.reserve(num_devices);
  for (auto& handle : device_buffers_.value()) {
    // The class invariants guarantee that this cast never fails.
    TF_ASSIGN_OR_RETURN(xla::PyBuffer * pybuffer,
                        xla::PyBuffer::AsPyBuffer(handle));
    cpp_device_buffers.push_back(pybuffer->buffer());
  }
  cpp_device_buffers_ = std::move(cpp_device_buffers);
  return absl::MakeConstSpan(cpp_device_buffers_.value());
}

PyObject* ShardedDeviceArray::base_type_ = nullptr;
PyObject* ShardedDeviceArray::type_ = nullptr;

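// Builds a new Python ShardedDeviceArray: allocates the object through
// sharded_device_array_tp_new() and then constructs the embedded C++
// ShardedDeviceArray in place with placement new.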
/*static*/ ShardedDeviceArray::object ShardedDeviceArray::Make(
    py::object aval, ShardingSpec sharding_spec, py::list device_buffers,
    py::object indices, bool weak_type) {
  py::object obj =
      py::reinterpret_steal<py::object>(sharded_device_array_tp_new(
          reinterpret_cast<PyTypeObject*>(type_), nullptr, nullptr));
  ShardedDeviceArrayObject* sda =
      reinterpret_cast<ShardedDeviceArrayObject*>(obj.ptr());
  new (&sda->sda)
      ShardedDeviceArray(aval, std::move(sharding_spec),
                         std::move(device_buffers), indices, weak_type);
  return py::reinterpret_borrow<ShardedDeviceArray::object>(obj);
}

bool ShardedDeviceArray::IsShardedDeviceArray(py::handle handle) {
  return handle.get_type() == ShardedDeviceArray::type();
}

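// Returns the C++ object embedded in `handle` without any type check; callers
// must first ensure that `handle` really is a ShardedDeviceArray.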
/*static*/ ShardedDeviceArray*
ShardedDeviceArray::AsShardedDeviceArrayUnchecked(py::handle handle) {
  return &(reinterpret_cast<ShardedDeviceArrayObject*>(handle.ptr())->sda);
}

/*static*/ xla::StatusOr<ShardedDeviceArray*>
ShardedDeviceArray::AsShardedDeviceArray(py::handle handle) {
  if (!IsShardedDeviceArray(handle)) {
    return xla::InvalidArgument("Expected a ShardedDeviceArray");
  }
  return AsShardedDeviceArrayUnchecked(handle);
}

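// Recovers the owning PyObject* from the embedded C++ object by subtracting
// the member offset; this is valid because ShardedDeviceArrayObject is
// standard layout.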
py::handle ShardedDeviceArray::AsHandle() {
  return reinterpret_cast<PyObject*>(reinterpret_cast<char*>(this) -
                                     offsetof(ShardedDeviceArrayObject, sda));
}

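// Creates the ShardedDeviceArrayBase and ShardedDeviceArray heap types,
// exposes them on module `m`, and attaches their Python-visible properties
// and methods.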
/*static*/ xla::Status ShardedDeviceArray::RegisterTypes(py::module& m) {
  // We need to use heap-allocated type objects because we want to add
  // additional methods to them dynamically after creation, similar to
  // py_buffer.cc.
  {
    py::str name = py::str("ShardedDeviceArrayBase");
    py::str qualname = py::str("ShardedDeviceArrayBase");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called. Otherwise the GC might see a half-constructed
    // type object.
    if (!heap_type) {
      return xla::Internal("Unable to create heap type object");
    }
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "ShardedDeviceArrayBase";
    type->tp_basicsize = sizeof(ShardedDeviceArrayBaseObject);
    type->tp_flags =
        Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_BASETYPE;
    TF_RET_CHECK(PyType_Ready(type) == 0);
    base_type_ = reinterpret_cast<PyObject*>(type);
  }
  py::object base_type = py::reinterpret_borrow<py::object>(base_type_);
  base_type.attr("__module__") = m.attr("__name__");
  m.attr("ShardedDeviceArrayBase") = base_type;

  {
    py::tuple bases = py::make_tuple(base_type);
    py::str name = py::str("ShardedDeviceArray");
    py::str qualname = py::str("ShardedDeviceArray");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called below. Otherwise the GC might see a
    // half-constructed type object.
    if (!heap_type) {
      return xla::Internal("Unable to create heap type object");
    }
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "ShardedDeviceArray";
    type->tp_basicsize = sizeof(ShardedDeviceArrayObject);
    type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE;
    type->tp_bases = bases.release().ptr();
    type->tp_dealloc = sharded_device_array_tp_dealloc;
    type->tp_new = sharded_device_array_tp_new;
    // Supported protocols.
    type->tp_as_number = &heap_type->as_number;
    type->tp_as_sequence = &heap_type->as_sequence;
    type->tp_as_mapping = &heap_type->as_mapping;
    type->tp_as_buffer = nullptr;

    // Allow weak references to ShardedDeviceArray objects.
    type->tp_weaklistoffset = offsetof(ShardedDeviceArrayObject, weakrefs);

    TF_RET_CHECK(PyType_Ready(type) == 0);
    type_ = reinterpret_cast<PyObject*>(type);
  }
  py::object type = py::reinterpret_borrow<py::object>(type_);
  type.attr("__module__") = m.attr("__name__");
  m.attr("ShardedDeviceArray") = type;

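  // Python-visible constructor, properties, and methods. Each one forwards to
  // the corresponding accessor on the embedded C++ ShardedDeviceArray, e.g.
  // (roughly) ShardedDeviceArray.make(aval, sharding_spec, device_buffers,
  // indices, weak_type) from Python ends up in ShardedDeviceArray::Make().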
  type.attr("make") = def_static([](py::object aval, ShardingSpec sharding_spec,
                                    py::list device_buffers, py::object indices,
                                    bool weak_type) {
    return ShardedDeviceArray::Make(aval, sharding_spec, device_buffers,
                                    indices, weak_type);
  });
  type.attr("aval") =
      property_readonly([](ShardedDeviceArray::object self) -> py::object {
        return self.sda()->aval();
      });
  type.attr("indices") =
      property_readonly([](ShardedDeviceArray::object self) -> py::object {
        return self.sda()->indices();
      });
  type.attr("sharding_spec") =
      property_readonly([](ShardedDeviceArray::object self) {
        return self.sda()->GetShardingSpec();
      });
  type.attr("device_buffers") =
      property_readonly([](ShardedDeviceArray::object self) {
        return self.sda()->device_buffers();
      });
  type.attr("_npy_value") = property(
      [](ShardedDeviceArray::object self) { return self.sda()->npy_value(); },
      [](ShardedDeviceArray::object self, py::object npy_value) {
        return self.sda()->set_npy_value(npy_value);
      });
  type.attr("_one_replica_buffer_indices") = property(
      [](ShardedDeviceArray::object self) {
        return self.sda()->one_replica_buffer_indices();
      },
      [](ShardedDeviceArray::object self, py::object obj) {
        return self.sda()->set_one_replica_buffer_indices(obj);
      });
  type.attr("shape") = property_readonly([](ShardedDeviceArray::object self) {
    return self.sda()->aval().attr("shape");
  });
  type.attr("dtype") = property_readonly([](ShardedDeviceArray::object self) {
    return self.sda()->aval().attr("dtype");
  });
  type.attr("size") = property_readonly([](ShardedDeviceArray::object self) {
    py::tuple shape = py::cast<py::tuple>(self.sda()->aval().attr("shape"));
    int64_t size = 1;
    for (auto dim : shape) {
      size *= py::cast<int64_t>(dim);
    }
    return size;
  });
  type.attr("ndim") = property_readonly([](ShardedDeviceArray::object self) {
    return py::len(self.sda()->aval().attr("shape"));
  });

  type.attr("delete") = py::cpp_function(
      [](ShardedDeviceArray::object self) { self.sda()->Delete(); },
      py::is_method(type));
  type.attr("is_deleted") = py::cpp_function(
      [](ShardedDeviceArray::object self) { return self.sda()->is_deleted(); },
      py::is_method(type));

  return ::tensorflow::OkStatus();
}

}  // namespace jax