/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/dlpack.h"

#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h"  // from @dlpack
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"

namespace py = pybind11;

namespace xla {
namespace {

const char* const kDlTensorCapsuleName = "dltensor";

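// Owns the buffer and metadata backing an exported DLManagedTensor. The
// DLManagedTensor handed to the consumer points back at this struct via
// `manager_ctx`, and DLPackTensorDeleter deletes it when the consumer invokes
// the tensor's deleter.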
struct DLPackTensor {
  ~DLPackTensor();

  // `buffer_reference` is populated if we have shared (read-only) access.
  py::object buffer_reference;

  // `external_reference` is always populated.
  std::unique_ptr<PjRtBuffer::ExternalReference> external_reference;

  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  DLManagedTensor tensor;
};

DLPackTensor::~DLPackTensor() {
  if (buffer_reference) {
    GlobalPyRefManager()->AddGarbage(
        absl::MakeSpan(&buffer_reference, /*size=*/1));
  }
}

void DLPackTensorDeleter(DLManagedTensor* t) {
  if (t) {
    delete static_cast<DLPackTensor*>(t->manager_ctx);
  }
}

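// Maps an XLA primitive type to the corresponding DLPack data type, e.g. F32
// becomes DLDataType{kDLFloat, 32, /*lanes=*/1}. Note that PRED (boolean)
// values are exported as unsigned 8-bit integers.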
StatusOr<DLDataType> PrimitiveTypeToDLDataType(PrimitiveType type) {
  switch (type) {
    case S8:
      return DLDataType{kDLInt, 8, 1};
    case S16:
      return DLDataType{kDLInt, 16, 1};
    case S32:
      return DLDataType{kDLInt, 32, 1};
    case S64:
      return DLDataType{kDLInt, 64, 1};
    case U8:
      return DLDataType{kDLUInt, 8, 1};
    case U16:
      return DLDataType{kDLUInt, 16, 1};
    case U32:
      return DLDataType{kDLUInt, 32, 1};
    case U64:
      return DLDataType{kDLUInt, 64, 1};
    case F16:
      return DLDataType{kDLFloat, 16, 1};
    case F32:
      return DLDataType{kDLFloat, 32, 1};
    case F64:
      return DLDataType{kDLFloat, 64, 1};
    case BF16:
      return DLDataType{kDLBfloat, 16, 1};
    case PRED:
      return DLDataType{kDLUInt, 8, 1};
    case C64:
      return DLDataType{kDLComplex, 64, 1};
    case C128:
      return DLDataType{kDLComplex, 128, 1};
    default:
      return Unimplemented("XLA type %s has no DLPack equivalent",
                           PrimitiveType_Name(type));
  }
}

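// Inverse of PrimitiveTypeToDLDataType. Vectorized DLPack types (lanes != 1)
// are rejected, and kDLUInt with 8 bits maps to U8 rather than PRED, so
// booleans do not round-trip.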
StatusOr<PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
  if (type.lanes != 1) {
    return Unimplemented("DLPack types with lanes != 1 not implemented, got %d",
                         type.lanes);
  }
  switch (type.code) {
    case kDLInt:
      switch (type.bits) {
        case 8:
          return S8;
        case 16:
          return S16;
        case 32:
          return S32;
        case 64:
          return S64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack integer width: %d bits",
              type.bits);
      }
    case kDLUInt:
      switch (type.bits) {
        case 8:
          return U8;
        case 16:
          return U16;
        case 32:
          return U32;
        case 64:
          return U64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack unsigned integer width: %d bits",
              type.bits);
      }
    case kDLFloat:
      switch (type.bits) {
        case 16:
          return F16;
        case 32:
          return F32;
        case 64:
          return F64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack float width: %d bits", type.bits);
      }
    case kDLBfloat:
      switch (type.bits) {
        case 16:
          return BF16;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack Bfloat width: %d bits", type.bits);
      }
    case kDLComplex:
      switch (type.bits) {
        case 64:
          return C64;
        case 128:
          return C128;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack complex width: %d bits",
              type.bits);
      }
    default:
      return Unimplemented("Unknown or invalid DLPack type code %d", type.code);
  }
}

// Returns the strides for `shape`.
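// Strides are expressed in elements rather than bytes, as DLPack requires. For
// example, a row-major f32[2,3] shape (minor_to_major = {1, 0}) yields strides
// {3, 1}.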
std::vector<int64_t> StridesForShape(const Shape& shape) {
  std::vector<int64_t> strides;
  CHECK(shape.IsArray());
  CHECK(shape.has_layout());

  strides.resize(shape.dimensions_size());
  int64_t stride = 1;
  for (int i : shape.layout().minor_to_major()) {
    strides.at(i) = stride;
    stride *= shape.dimensions(i);
  }
  return strides;
}

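// Inverse of StridesForShape: infers an XLA minor-to-major layout from a
// DLPack dims/strides pair, accepting only dense (non-broadcasting) strides.
// Ties between equal strides are broken by treating size-1 dimensions as more
// minor. For example, dims = {2, 3} with strides = {3, 1} yields
// minor_to_major = {1, 0}, while strides = {0, 1} (a broadcast view) is
// rejected.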
StatusOr<std::vector<int64_t>> StridesToLayout(
    absl::Span<int64_t const> dims, absl::Span<int64_t const> strides) {
  CHECK_EQ(dims.size(), strides.size());
  std::vector<int64_t> minor_to_major(dims.size());
  std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
  absl::c_sort(minor_to_major, [&](int a, int b) {
    if (strides[a] < strides[b]) {
      return true;
    }
    if (strides[a] > strides[b]) {
      return false;
    }
    return dims[a] == 1 && dims[b] != 1;
  });
  int64_t stride = 1;
  for (int64_t d : minor_to_major) {
    if (strides[d] != stride) {
      return Unimplemented(
          "Only DLPack tensors with trivial (compact) striding are supported; "
          "i.e., tensors whose striding represents a transposition of the "
          "underlying buffer but not broadcasting. Dimensions were: [%s], "
          "strides were [%s].",
          absl::StrJoin(dims, ","), absl::StrJoin(strides, ","));
    }
    stride *= dims[d];
  }
  return minor_to_major;
}

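// Maps a PjRt device to a DLPack device type: CPU devices become kDLCPU and
// GPU devices become kDLCUDA. Other platforms (e.g. TPU) have no DLPack
// equivalent and are rejected.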
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
  if (device.client()->platform_id() == CpuId()) {
    return kDLCPU;
  } else if (device.client()->platform_id() == GpuId()) {
    return kDLCUDA;
  }
  return InvalidArgument("Device %s cannot be used as a DLPack device.",
                         device.DebugString());
}

StatusOr<DLDevice> DLDeviceForDevice(const PjRtDevice& device) {
  DLDevice context;
  TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
  context.device_id = device.local_hardware_id();
  return context;
}

StatusOr<PjRtDevice*> DeviceForDLDevice(const PjRtClient* cpu_client,
                                        const PjRtClient* gpu_client,
                                        const DLDevice& context) {
  switch (context.device_type) {
    case kDLCPU:
      if (cpu_client == nullptr) {
        return InvalidArgument(
            "DLPack tensor is on CPU, but no CPU backend was provided.");
      }
      TF_RET_CHECK(cpu_client->platform_id() == CpuId());
      return cpu_client->LookupAddressableDevice(context.device_id);
    case kDLCUDA:
      if (gpu_client == nullptr) {
        return InvalidArgument(
            "DLPack tensor is on GPU, but no GPU backend was provided.");
      }
      TF_RET_CHECK(gpu_client->platform_id() == GpuId());
      return gpu_client->LookupAddressableDevice(context.device_id);
    default:
      return InvalidArgument("Unknown/unsupported DLPack device type %d",
                             context.device_type);
  }
}

}  // namespace

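// Exports a PyBuffer as a DLPack capsule named "dltensor". If `take_ownership`
// is true, ownership of the device memory is transferred to the capsule and
// the original buffer is invalidated; otherwise the capsule holds a shared
// (read-only) reference and keeps the Python buffer object alive. Typically
// reached from Python via jax.dlpack.to_dlpack. Illustrative usage (names of
// surrounding variables are hypothetical):
//
//   StatusOr<py::capsule> capsule =
//       BufferToDLPackManagedTensor(py_buffer, /*take_ownership=*/false);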
StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
                                                  bool take_ownership) {
  TF_ASSIGN_OR_RETURN(PyBuffer * buffer, PyBuffer::AsPyBuffer(py_buffer));
  auto pack = std::make_unique<DLPackTensor>();
  if (buffer->buffer()->on_device_shape().IsTuple()) {
    return Unimplemented(
        "BufferToDLPackManagedTensor is not implemented for tuple buffers.");
  }
  if (buffer->buffer()->on_device_shape().is_dynamic()) {
    return Unimplemented("DynamicShape is not implemented in DLPack.");
  }

  DLTensor& dt = pack->tensor.dl_tensor;
  if (take_ownership) {
    // Block on outstanding operations, so that it is safe to read or mutate
    // the returned buffer.
    StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>> buffer_or =
        buffer->buffer()->ReleaseDeviceMemoryOwnership(
            /*wait_for_operations_to_complete=*/true);
    if (!buffer_or.ok()) {
      return InvalidArgument(
          "Buffer synchronization failed converting to DLPack tensor: %s",
          buffer_or.status().ToString());
    }
    pack->external_reference = std::move(buffer_or).value();
    if (!pack->external_reference) {
      return InvalidArgument(
          "Cannot convert deleted/invalid buffer to DLPack tensor.");
    }
  } else {
    // Block on outstanding operations, so that it is safe to read the returned
    // buffer, to which we only have shared (read-only) access.
    TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
    pack->buffer_reference = py::reinterpret_borrow<py::object>(py_buffer);
    TF_ASSIGN_OR_RETURN(pack->external_reference,
                        buffer->buffer()->AcquireExternalReference());
  }
  dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer();
  pack->tensor.manager_ctx = pack.get();
  pack->tensor.deleter = DLPackTensorDeleter;
  TF_ASSIGN_OR_RETURN(dt.device,
                      DLDeviceForDevice(*buffer->buffer()->device()));
  dt.device.device_id = buffer->buffer()->device()->local_hardware_id();
  dt.ndim = buffer->buffer()->on_device_shape().dimensions_size();
  TF_ASSIGN_OR_RETURN(dt.dtype,
                      PrimitiveTypeToDLDataType(
                          buffer->buffer()->on_device_shape().element_type()));

  pack->shape = std::vector<int64_t>(
      buffer->buffer()->on_device_shape().dimensions().begin(),
      buffer->buffer()->on_device_shape().dimensions().end());
  pack->strides = StridesForShape(buffer->buffer()->on_device_shape());
  dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
  dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
  dt.byte_offset = 0;

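  // Wrap the DLManagedTensor in a capsule named "dltensor". Per the DLPack
  // convention, a consumer that takes ownership renames the capsule to
  // "used_dltensor", so the destructor below frees the tensor only if the
  // capsule was never consumed.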
  py::capsule capsule(&pack.release()->tensor, kDlTensorCapsuleName,
                      [](PyObject* obj) {
                        DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(
                            PyCapsule_GetPointer(obj, kDlTensorCapsuleName));
                        if (dlmt) {
                          DLPackTensorDeleter(dlmt);
                        } else {
                          // The tensor has been deleted. Clear any error from
                          // PyCapsule_GetPointer.
                          PyErr_Clear();
                        }
                      });
  return capsule;
}

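// Imports a DLPack capsule as an XLA buffer without copying: the resulting
// buffer is a view of the producer's device memory, and the producer's deleter
// runs when that buffer is destroyed. Typically reached from Python via
// jax.dlpack.from_dlpack. Illustrative usage (names of surrounding variables
// are hypothetical):
//
//   StatusOr<PyBuffer::object> buffer =
//       DLPackManagedTensorToBuffer(capsule, cpu_client, gpu_client);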
StatusOr<PyBuffer::object> DLPackManagedTensorToBuffer(
    const pybind11::capsule& tensor, std::shared_ptr<PyClient> cpu_client,
    std::shared_ptr<PyClient> gpu_client) {
  // Backward compatibility: if only one client is passed, it may be from any
  // platform. Drop this support after dropping support for jax <= 0.2.14.
  if (cpu_client && cpu_client->pjrt_client()->platform_id() == GpuId()) {
    gpu_client = std::move(cpu_client);
    cpu_client = nullptr;
  }
  if (cpu_client && cpu_client->pjrt_client()->platform_id() != CpuId()) {
    return InvalidArgument("DLPack does not support platform %s",
                           cpu_client->pjrt_client()->platform_name());
  }

  if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
    return InvalidArgument(
        "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
        "Note that a DLPack tensor may be consumed at most once.",
        absl::string_view(tensor.name()));
  }
  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor);
  if (dlmt->dl_tensor.ndim < 0) {
    return InvalidArgument(
        "Number of dimensions in DLManagedTensor must be nonnegative, got %d",
        dlmt->dl_tensor.ndim);
  }
  TF_ASSIGN_OR_RETURN(
      PjRtDevice * device,
      DeviceForDLDevice(cpu_client ? cpu_client->pjrt_client() : nullptr,
                        gpu_client ? gpu_client->pjrt_client() : nullptr,
                        dlmt->dl_tensor.device));
  absl::Span<int64_t const> dimensions(
      reinterpret_cast<int64_t*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
  TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
                      DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype));

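  // DLPack permits `strides` to be null, meaning the data is compact and
  // row-major. In that case (or when any dimension is zero, in which case the
  // strides carry no information) assume a row-major, descending
  // minor_to_major layout; otherwise derive the layout from the producer's
  // strides.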
  std::vector<int64_t> minor_to_major;
  if (dlmt->dl_tensor.strides &&
      absl::c_find(dimensions, 0) == dimensions.end()) {
    absl::Span<int64_t const> strides(
        reinterpret_cast<int64_t*>(dlmt->dl_tensor.strides),
        dlmt->dl_tensor.ndim);
    TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides));
  } else {
    minor_to_major.resize(dlmt->dl_tensor.ndim);
    std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
  }
  Shape shape =
      ShapeUtil::MakeShapeWithLayout(element_type, dimensions, minor_to_major);

  std::function<void()> on_delete_callback;
  if (dlmt->deleter) {
    on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
  }
  TF_ASSIGN_OR_RETURN(auto pjrt_buffer,
                      device->client()->CreateViewOfDeviceBuffer(
                          static_cast<char*>(dlmt->dl_tensor.data) +
                              dlmt->dl_tensor.byte_offset,
                          shape, device, on_delete_callback));
  // We have taken ownership of the array inside the capsule; make sure the
  // capsule cannot be used again.
  PyCapsule_SetName(tensor.ptr(), "used_dltensor");
  PyCapsule_SetDestructor(tensor.ptr(), nullptr);
  // TODO(phawkins): simplify the expression below once we know cpu_client is
  // always non-null.
  return PyBuffer::Make(
      (cpu_client && device->client() == cpu_client->pjrt_client())
          ? std::move(cpu_client)
          : std::move(gpu_client),
      std::move(pjrt_buffer), Traceback::Get());
}

}  // namespace xla