xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/dlpack.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/dlpack.h"

#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h"  // from @dlpack
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"

namespace py = pybind11;

namespace xla {
namespace {

const char* const kDlTensorCapsuleName = "dltensor";

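// Owning context for an exported DLManagedTensor: it keeps the device-memory
// reference (and, for shared exports, the Python buffer object) alive, and
// provides stable storage for the shape and strides arrays that the DLTensor
// points into.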
struct DLPackTensor {
  ~DLPackTensor();

  // `buffer_reference` is populated if we have shared (read-only) access.
  py::object buffer_reference;

  // `external_reference` is always populated.
  std::unique_ptr<PjRtBuffer::ExternalReference> external_reference;

  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  DLManagedTensor tensor;
};

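// Hands `buffer_reference` to the global Python reference manager instead of
// dropping it directly, so the Python object can be released safely even if
// the DLPack consumer invokes the deleter on a thread that does not hold the
// GIL.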
DLPackTensor::~DLPackTensor() {
  if (buffer_reference) {
    GlobalPyRefManager()->AddGarbage(
        absl::MakeSpan(&buffer_reference, /*size=*/1));
  }
}

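// Deleter installed on the exported DLManagedTensor; reclaims the DLPackTensor
// stored in `manager_ctx`, which owns the tensor metadata and the device
// memory reference.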
void DLPackTensorDeleter(DLManagedTensor* t) {
  if (t) {
    delete static_cast<DLPackTensor*>(t->manager_ctx);
  }
}

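// Maps an XLA primitive type to the corresponding DLPack data type. PRED
// (boolean) values are exported as unsigned 8-bit integers.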
StatusOr<DLDataType> PrimitiveTypeToDLDataType(PrimitiveType type) {
  switch (type) {
    case S8:
      return DLDataType{kDLInt, 8, 1};
    case S16:
      return DLDataType{kDLInt, 16, 1};
    case S32:
      return DLDataType{kDLInt, 32, 1};
    case S64:
      return DLDataType{kDLInt, 64, 1};
    case U8:
      return DLDataType{kDLUInt, 8, 1};
    case U16:
      return DLDataType{kDLUInt, 16, 1};
    case U32:
      return DLDataType{kDLUInt, 32, 1};
    case U64:
      return DLDataType{kDLUInt, 64, 1};
    case F16:
      return DLDataType{kDLFloat, 16, 1};
    case F32:
      return DLDataType{kDLFloat, 32, 1};
    case F64:
      return DLDataType{kDLFloat, 64, 1};
    case BF16:
      return DLDataType{kDLBfloat, 16, 1};
    case PRED:
      return DLDataType{kDLUInt, 8, 1};
    case C64:
      return DLDataType{kDLComplex, 64, 1};
    case C128:
      return DLDataType{kDLComplex, 128, 1};
    default:
      return Unimplemented("XLA type %s has no DLPack equivalent",
                           PrimitiveType_Name(type));
  }
}

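// Maps a DLPack data type to an XLA primitive type. Vectorized types
// (lanes != 1) are rejected.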
StatusOr<PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
  if (type.lanes != 1) {
    return Unimplemented("DLPack types with lanes != 1 not implemented, got %d",
                         type.lanes);
  }
  switch (type.code) {
    case kDLInt:
      switch (type.bits) {
        case 8:
          return S8;
        case 16:
          return S16;
        case 32:
          return S32;
        case 64:
          return S64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack integer width: %d bits",
              type.bits);
      }
    case kDLUInt:
      switch (type.bits) {
        case 8:
          return U8;
        case 16:
          return U16;
        case 32:
          return U32;
        case 64:
          return U64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack unsigned integer width: %d bits",
              type.bits);
      }
    case kDLFloat:
      switch (type.bits) {
        case 16:
          return F16;
        case 32:
          return F32;
        case 64:
          return F64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack float width: %d bits", type.bits);
      }
    case kDLBfloat:
      switch (type.bits) {
        case 16:
          return BF16;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack Bfloat width: %d bits", type.bits);
      }
    case kDLComplex:
      switch (type.bits) {
        case 64:
          return C64;
        case 128:
          return C128;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack complex width: %d bits",
              type.bits);
      }
    default:
      return Unimplemented("Unknown or invalid DLPack type code %d", type.code);
  }
}

// Returns the dense strides (in elements, not bytes) implied by the layout of
// `shape`.
std::vector<int64_t> StridesForShape(const Shape& shape) {
  std::vector<int64_t> strides;
  CHECK(shape.IsArray());
  CHECK(shape.has_layout());

  strides.resize(shape.dimensions_size());
  int64_t stride = 1;
  for (int i : shape.layout().minor_to_major()) {
    strides.at(i) = stride;
    stride *= shape.dimensions(i);
  }
  return strides;
}

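// Converts DLPack dimensions and strides into an XLA minor-to-major layout
// permutation. Only dense (compact) strides are accepted, i.e. strides that
// describe at most a transposition of the underlying buffer, not broadcasting.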
StatusOr<std::vector<int64_t>> StridesToLayout(
    absl::Span<int64_t const> dims, absl::Span<int64_t const> strides) {
  CHECK_EQ(dims.size(), strides.size());
  std::vector<int64_t> minor_to_major(dims.size());
  std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
  absl::c_sort(minor_to_major, [&](int a, int b) {
    if (strides[a] < strides[b]) {
      return true;
    }
    if (strides[a] > strides[b]) {
      return false;
    }
    // Among dimensions with equal strides, treat size-1 dimensions as more
    // minor than non-degenerate dimensions.
    return dims[a] == 1 && dims[b] != 1;
  });
  int64_t stride = 1;
  for (int64_t d : minor_to_major) {
    if (strides[d] != stride) {
      return Unimplemented(
          "Only DLPack tensors with trivial (compact) striding are supported; "
          "i.e., tensors whose striding represents a transposition of the "
          "underlying buffer but not broadcasting. Dimensions were: [%s], "
          "strides were [%s].",
          absl::StrJoin(dims, ","), absl::StrJoin(strides, ","));
    }
    stride *= dims[d];
  }
  return minor_to_major;
}

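// Maps a PjRt device to the DLPack device type; only the CPU and (CUDA) GPU
// platforms are supported.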
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
  if (device.client()->platform_id() == CpuId()) {
    return kDLCPU;
  } else if (device.client()->platform_id() == GpuId()) {
    return kDLCUDA;
  }
  return InvalidArgument("Device %s cannot be used as a DLPack device.",
                         device.DebugString());
}

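// Builds the full DLDevice (device type plus local device ordinal) for a PjRt
// device.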
StatusOr<DLDevice> DLDeviceForDevice(const PjRtDevice& device) {
  DLDevice context;
  TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
  context.device_id = device.local_hardware_id();
  return context;
}

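// Resolves a DLDevice to a PjRtDevice, using whichever of the CPU or GPU
// clients matches the DLPack device type.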
StatusOr<PjRtDevice*> DeviceForDLDevice(const PjRtClient* cpu_client,
                                        const PjRtClient* gpu_client,
                                        const DLDevice& context) {
  switch (context.device_type) {
    case kDLCPU:
      if (cpu_client == nullptr) {
        return InvalidArgument(
            "DLPack tensor is on CPU, but no CPU backend was provided.");
      }
      TF_RET_CHECK(cpu_client->platform_id() == CpuId());
      return cpu_client->LookupAddressableDevice(context.device_id);
    case kDLCUDA:
      if (gpu_client == nullptr) {
        return InvalidArgument(
            "DLPack tensor is on GPU, but no GPU backend was provided.");
      }
      TF_RET_CHECK(gpu_client->platform_id() == GpuId());
      return gpu_client->LookupAddressableDevice(context.device_id);
    default:
      return InvalidArgument("Unknown/unsupported DLPack device type %d",
                             context.device_type);
  }
}

}  // namespace

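// Exports `py_buffer` as a DLPack capsule named "dltensor". If
// `take_ownership` is true, the buffer's device memory is donated to the
// capsule; otherwise the capsule holds a non-owning external reference and
// keeps the Python buffer alive for the lifetime of the capsule.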
StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
                                                  bool take_ownership) {
  TF_ASSIGN_OR_RETURN(PyBuffer * buffer, PyBuffer::AsPyBuffer(py_buffer));
  auto pack = std::make_unique<DLPackTensor>();
  if (buffer->buffer()->on_device_shape().IsTuple()) {
    return Unimplemented(
        "BufferToDLPackManagedTensor is not implemented for tuple "
        "buffers.");
  }
  if (buffer->buffer()->on_device_shape().is_dynamic()) {
    return Unimplemented("DynamicShape is not implemented in DLPack.");
  }

  DLTensor& dt = pack->tensor.dl_tensor;
  if (take_ownership) {
    // Block on outstanding operations, so that it is safe to read or mutate the
    // returned buffer.
    StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>> buffer_or =
        buffer->buffer()->ReleaseDeviceMemoryOwnership(
            /*wait_for_operations_to_complete=*/true);
    if (!buffer_or.ok()) {
      return InvalidArgument(
          "Buffer synchronization failed converting to DLPack tensor: %s",
          buffer_or.status().ToString());
    }
    pack->external_reference = std::move(buffer_or).value();
    if (!pack->external_reference) {
      return InvalidArgument(
          "Cannot convert deleted/invalid buffer to DLPack tensor.");
    }
  } else {
    // Block on outstanding operations, so that it is safe to read or mutate the
    // returned buffer.
    TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
    pack->buffer_reference = py::reinterpret_borrow<py::object>(py_buffer);
    TF_ASSIGN_OR_RETURN(pack->external_reference,
                        buffer->buffer()->AcquireExternalReference());
  }
  dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer();
  pack->tensor.manager_ctx = pack.get();
  pack->tensor.deleter = DLPackTensorDeleter;
  TF_ASSIGN_OR_RETURN(dt.device,
                      DLDeviceForDevice(*buffer->buffer()->device()));
  dt.device.device_id = buffer->buffer()->device()->local_hardware_id();
  dt.ndim = buffer->buffer()->on_device_shape().dimensions_size();
  TF_ASSIGN_OR_RETURN(dt.dtype,
                      PrimitiveTypeToDLDataType(
                          buffer->buffer()->on_device_shape().element_type()));

  pack->shape = std::vector<int64_t>(
      buffer->buffer()->on_device_shape().dimensions().begin(),
      buffer->buffer()->on_device_shape().dimensions().end());
  pack->strides = StridesForShape(buffer->buffer()->on_device_shape());
  dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
  dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
  dt.byte_offset = 0;

  py::capsule capsule(&pack.release()->tensor, kDlTensorCapsuleName,
                      [](PyObject* obj) {
                        DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(
                            PyCapsule_GetPointer(obj, kDlTensorCapsuleName));
                        if (dlmt) {
                          DLPackTensorDeleter(dlmt);
                        } else {
                          // The tensor has been deleted. Clear any error from
                          // PyCapsule_GetPointer.
                          PyErr_Clear();
                        }
                      });
  return capsule;
}

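// Imports a DLPack capsule as a PyBuffer that aliases the capsule's device
// memory (no copy is made). The capsule is consumed: it is renamed to
// "used_dltensor" so that it cannot be imported a second time.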
StatusOr<PyBuffer::object> DLPackManagedTensorToBuffer(
    const pybind11::capsule& tensor, std::shared_ptr<PyClient> cpu_client,
    std::shared_ptr<PyClient> gpu_client) {
  // Backward compatibility: if only one client is passed, it may be from any
  // platform. Drop this support after dropping support for jax <= 0.2.14.
  if (cpu_client && cpu_client->pjrt_client()->platform_id() == GpuId()) {
    gpu_client = std::move(cpu_client);
    cpu_client = nullptr;
  }
  if (cpu_client && cpu_client->pjrt_client()->platform_id() != CpuId()) {
    return InvalidArgument("DLPack does not support platform %s",
                           cpu_client->pjrt_client()->platform_name());
  }

  if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
    return InvalidArgument(
        "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
        "Note that a DLPack tensor may be consumed at most once.",
        absl::string_view(tensor.name()));
  }
  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor);
  if (dlmt->dl_tensor.ndim < 0) {
    return InvalidArgument(
        "Number of dimensions in DLManagedTensor must be nonnegative, got %d",
        dlmt->dl_tensor.ndim);
  }
  TF_ASSIGN_OR_RETURN(
      PjRtDevice * device,
      DeviceForDLDevice(cpu_client ? cpu_client->pjrt_client() : nullptr,
                        gpu_client ? gpu_client->pjrt_client() : nullptr,
                        dlmt->dl_tensor.device));
  absl::Span<int64_t const> dimensions(
      reinterpret_cast<int64_t*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
  TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
                      DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype));

  std::vector<int64_t> minor_to_major;
  if (dlmt->dl_tensor.strides &&
      absl::c_find(dimensions, 0) == dimensions.end()) {
    absl::Span<int64_t const> strides(
        reinterpret_cast<int64_t*>(dlmt->dl_tensor.strides),
        dlmt->dl_tensor.ndim);
    TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides));
  } else {
    minor_to_major.resize(dlmt->dl_tensor.ndim);
    std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
  }
  Shape shape =
      ShapeUtil::MakeShapeWithLayout(element_type, dimensions, minor_to_major);

  std::function<void()> on_delete_callback;
  if (dlmt->deleter) {
    on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
  }
  TF_ASSIGN_OR_RETURN(auto pjrt_buffer,
                      device->client()->CreateViewOfDeviceBuffer(
                          static_cast<char*>(dlmt->dl_tensor.data) +
                              dlmt->dl_tensor.byte_offset,
                          shape, device, on_delete_callback));
  // We have taken ownership of the array inside the capsule; make sure the
  // capsule cannot be used again.
  PyCapsule_SetName(tensor.ptr(), "used_dltensor");
  PyCapsule_SetDestructor(tensor.ptr(), nullptr);
  // TODO(phawkins): simplify the expression below once we know cpu_client is
  // always non-null.
  return PyBuffer::Make(
      (cpu_client && device->client() == cpu_client->pjrt_client())
          ? std::move(cpu_client)
          : std::move(gpu_client),
      std::move(pjrt_buffer), Traceback::Get());
}
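
// Example (illustrative sketch only, not part of the original file): a round
// trip through DLPack from binding code, inside a function that returns a
// Status or StatusOr. `py_buffer` (a Python handle wrapping a PyBuffer) and
// `cpu_client` (a std::shared_ptr<PyClient> for the CPU platform) are
// hypothetical names assumed for the sketch.
//
//   TF_ASSIGN_OR_RETURN(
//       py::capsule capsule,
//       BufferToDLPackManagedTensor(py_buffer, /*take_ownership=*/false));
//   TF_ASSIGN_OR_RETURN(
//       PyBuffer::object imported,
//       DLPackManagedTensorToBuffer(capsule, cpu_client,
//                                   /*gpu_client=*/nullptr));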

}  // namespace xla