/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/py_client.h"

#include <memory>
#include <string>
#include <utility>

#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/numbers.h"
#include "tensorflow/compiler/xla/pjrt/host_callback.h"
#include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/python/callback.h"
#include "tensorflow/compiler/xla/python/exceptions.h"
#include "tensorflow/compiler/xla/python/pprof_profile_builder.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/python/transfer_guard_lib.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/core/platform/statusor.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/compiler/xla/python/py_client_gpu.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace xla {

namespace py = pybind11;

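// Both constructors funnel into the shared_ptr overload, which stores the
// PjRtClient and sizes buffers_ so it can be indexed by device id, growing it
// if an addressable device id exceeds device_count().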
PyClient::PyClient(std::unique_ptr<PjRtClient> pjrt_client)
    : PyClient(std::shared_ptr<PjRtClient>(std::move(pjrt_client))) {}

PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
    : pjrt_client_(std::move(pjrt_client)) {
  CHECK(pjrt_client_ != nullptr);
  buffers_.resize(pjrt_client_->device_count());
  for (PjRtDevice* device : pjrt_client_->addressable_devices()) {
    if (device->id() >= buffers_.size()) {
      buffers_.resize(device->id() + 1);
    }
  }
}

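// Release the GIL while dropping the PjRtClient, since client teardown may
// block on outstanding device work.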
PyClient::~PyClient() {
  py::gil_scoped_release gil;
  pjrt_client_ = nullptr;
}

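// Returns every device known to the client (addressable or not), each wrapped
// with a reference to this PyClient so the client outlives any device object
// handed to Python. LocalDevices() below is the addressable-only variant.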
std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
  std::vector<ClientAndPtr<PjRtDevice>> devices;
  auto span = pjrt_client_->devices();
  devices.reserve(span.size());
  for (PjRtDevice* device : span) {
    devices.push_back(WrapWithClient(shared_from_this(), device));
  }
  return devices;
}

std::vector<ClientAndPtr<PjRtDevice>> PyClient::LocalDevices() {
  std::vector<ClientAndPtr<PjRtDevice>> devices;
  devices.reserve(pjrt_client_->addressable_devices().size());
  for (PjRtDevice* device : pjrt_client_->addressable_devices()) {
    devices.push_back(WrapWithClient(shared_from_this(), device));
  }
  return devices;
}

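// Live PyBuffers are tracked in per-device intrusive linked lists: buffers_ is
// indexed by device id and each PyBuffer chains to the next via next_. The
// accessors below walk those lists and return new references to buffers that
// have not been deleted.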
std::vector<py::object> PyClient::LiveBuffers() {
  CHECK(PyGILState_Check());
  std::vector<py::object> buffers;
  for (PyBuffer* device_buffers : buffers_) {
    for (PyBuffer* buffer = device_buffers; buffer; buffer = buffer->next_) {
      if (!buffer->is_deleted()) {
        buffers.push_back(
            py::reinterpret_borrow<py::object>(buffer->AsHandle()));
      }
    }
  }
  return buffers;
}

std::vector<py::object> PyClient::LiveBuffersOnDevice(PjRtDevice* device) {
  CHECK_EQ(device->client(), pjrt_client());
  CHECK(PyGILState_Check());
  std::vector<py::object> buffers;
  for (PyBuffer* buffer = buffers_[device->id()]; buffer;
       buffer = buffer->next_) {
    if (!buffer->is_deleted()) {
      buffers.push_back(py::reinterpret_borrow<py::object>(buffer->AsHandle()));
    }
  }
  return buffers;
}

std::vector<std::shared_ptr<PyExecutable>> PyClient::LiveExecutables() {
  CHECK(PyGILState_Check());
  std::vector<std::shared_ptr<PyExecutable>> executables;
  for (PyExecutable* exec = executables_; exec; exec = exec->next_) {
    if (!exec->is_deleted()) {
      executables.push_back(exec->shared_from_this());
    }
  }
  return executables;
}

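// Defragmentation has two paths: the TFRT runtime implements it natively, so
// the call is simply forwarded; for the StreamExecutor runtime it is emulated
// by copying every live buffer to host, releasing the device copies, and then
// re-uploading the host literals in place.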
Status PyClient::Defragment() {
  CHECK(PyGILState_Check());
  switch (pjrt_client_->runtime_type()) {
    case PjRtRuntimeType::kTfrt:
      return pjrt_client_->Defragment();
    case PjRtRuntimeType::kStreamExecutor:
      struct TmpBuffer {
        PyBuffer* py_buffer;
        // TODO(skyewm): maybe use py_buffer's HostValue
        std::shared_ptr<Literal> host_copy;
      };

      // Synchronously copy all buffers to host.
      std::vector<TmpBuffer> tmp_buffers;
      for (PyBuffer* device_buffers : buffers_) {
        for (PyBuffer* buffer = device_buffers; buffer;
             buffer = buffer->next_) {
          if (!buffer->is_deleted()) {
            TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal,
                                buffer->buffer_->ToLiteralSync());
            tmp_buffers.push_back({buffer, literal});
          }
        }
      }

      // All buffers were successfully copied to host; now delete the
      // on-device copies.
      //
      // Use a blocking delete operation to ensure all memory is actually
      // cleared before we start rewriting buffers.
      //
      // Die instead of returning a bad status because the program presumably
      // cannot continue if we fail to reconstitute device buffers.
      for (TmpBuffer& tmp_buffer : tmp_buffers) {
        TF_CHECK_OK(tensorflow::down_cast<PjRtStreamExecutorBuffer*>(
                        tmp_buffer.py_buffer->buffer_.get())
                        ->Release(/*wait_for_operations_to_complete=*/true)
                        .status());
      }

      // Copy host copies back to device and update PyBuffers in-place.
      for (TmpBuffer& tmp_buffer : tmp_buffers) {
        std::unique_ptr<PjRtBuffer> new_copy =
            pjrt_client_
                ->BufferFromHostLiteral(*tmp_buffer.host_copy,
                                        tmp_buffer.py_buffer->buffer_->device())
                .ValueOrDie();
        TF_CHECK_OK(new_copy->BlockHostUntilReady());
        tmp_buffer.py_buffer->buffer_.reset(new_copy.release());
      }

      // TODO(skyewm): delete executables?
  }
  return OkStatus();
}

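// Expands the client's default (replica, partition) -> device assignment into
// client-wrapped device objects, indexed as result[replica][partition].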
StatusOr<std::vector<std::vector<ClientAndPtr<PjRtDevice>>>>
PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
  TF_ASSIGN_OR_RETURN(
      DeviceAssignment device_assignment,
      pjrt_client_->GetDefaultDeviceAssignment(num_replicas, num_partitions));
  std::vector<std::vector<ClientAndPtr<PjRtDevice>>> result;
  result.resize(num_replicas);
  for (int r = 0; r < num_replicas; ++r) {
    result[r].resize(num_partitions);
    for (int p = 0; p < num_partitions; ++p) {
      int device_id = device_assignment(r, p);
      TF_ASSIGN_OR_RETURN(PjRtDevice * device,
                          pjrt_client_->LookupDevice(device_id));
      result[r][p] = WrapWithClient(shared_from_this(), device);
    }
  }
  return result;
}

StatusOr<std::vector<ClientAndPtr<PjRtDevice>>>
PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
  TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
                      pjrt_client_->GetDefaultDeviceAssignment(
                          num_replicas, /*num_partitions=*/1));
  std::vector<ClientAndPtr<PjRtDevice>> result;
  for (int i = 0; i < num_replicas; ++i) {
    int device_id = device_assignment(i, 0);
    TF_ASSIGN_OR_RETURN(PjRtDevice * device,
                        pjrt_client_->LookupDevice(device_id));
    result.push_back(WrapWithClient(shared_from_this(), device));
  }
  return result;
}

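// Copies a Python value to a device buffer. If no device is given, the first
// addressable device is used. The transfer is subject to the jax transfer
// guard, and zero-copy sharing of the host memory is only allowed when the
// caller did not force a copy and requested kZeroCopy semantics.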
StatusOr<py::object> PyClient::BufferFromPyval(
    pybind11::handle argument, PjRtDevice* device, bool force_copy,
    PjRtClient::HostBufferSemantics host_buffer_semantics) {
  if (device == nullptr) {
    TF_RET_CHECK(!pjrt_client_->addressable_devices().empty());
    device = pjrt_client_->addressable_devices().front();
  }
  CHECK(device != nullptr);

  auto transfer_guard_formatter = [&argument, dst_device = device] {
    auto type = py::cast<std::string>(py::str(argument.get_type()));
    // Catch exceptions because shape and dtype properties convertible to str
    // are not guaranteed to be present on an arbitrary argument.
    std::string shape;
    std::string dtype;
    try {
      shape = py::cast<std::string>(py::str(argument.attr("shape")));
    } catch (const std::exception& e) {
      shape = "<unknown>";
    }
    try {
      dtype = py::cast<std::string>(py::str(argument.attr("dtype")));
    } catch (const std::exception& e) {
      dtype = "<unknown>";
    }
    return absl::StrCat("type=", type, ", shape=", shape, ", dtype=", dtype,
                        ", dst_device=", dst_device->DebugString());
  };
  TF_RETURN_IF_ERROR(
      jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter));

  TF_ASSIGN_OR_RETURN(PjRtDevice * found_device,
                      pjrt_client_->LookupDevice(device->id()));
  if (found_device != device) {
    return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
                           device->DebugString(),
                           pjrt_client_->platform_name());
  }
  GlobalPyRefManager()->CollectGarbage();

  DevicePutOptions options;
  options.squash_64bit_types = false;
  options.allow_zero_copy =
      (!force_copy &&
       (host_buffer_semantics == PjRtClient::HostBufferSemantics::kZeroCopy));
  TF_ASSIGN_OR_RETURN(DevicePutResult put,
                      DevicePut(argument, device, options));

  if (put.owned_buffer) {
    auto traceback = Traceback::Get();
    return PyBuffer::Make(shared_from_this(), std::move(put.owned_buffer),
                          std::move(traceback));
  } else {
    return py::reinterpret_borrow<py::object>(put.owning_pybuffer);
  }
}

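// Allocates device buffers that a remote host can target with a cross-host
// send. The PjRt callback reports one serialized descriptor per buffer; we
// block (with the GIL released) until it fires and then pair each descriptor
// with its PyBuffer so the descriptors can be passed to the sending side.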
StatusOr<std::vector<std::pair<pybind11::bytes, pybind11::object>>>
PyClient::MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
                                      PjRtDevice* device) {
  CHECK(device != nullptr);
  absl::Mutex mu;
  StatusOr<std::vector<PjRtCrossHostRecvDescriptors>> recv_descriptors_or;
  bool done = false;

  TF_ASSIGN_OR_RETURN(
      auto buffers, pjrt_client_->MakeCrossHostReceiveBuffers(
                        shapes, device,
                        [&done, &recv_descriptors_or,
                         &mu](StatusOr<PjRtCrossHostRecvState> recv_state_or) {
                          absl::MutexLock l(&mu);
                          if (recv_state_or.ok()) {
                            py::gil_scoped_acquire gil;
                            recv_descriptors_or =
                                std::move(recv_state_or->descriptors);
                          } else {
                            recv_descriptors_or = recv_state_or.status();
                          }
                          done = true;
                        }));

  {
    py::gil_scoped_release gil_release;
    absl::MutexLock l(&mu);
    mu.Await(absl::Condition(&done));
  }

  TF_RETURN_IF_ERROR(recv_descriptors_or.status());
  CHECK_EQ(buffers.size(), recv_descriptors_or->size());
  std::vector<std::pair<pybind11::bytes, pybind11::object>> result;
  result.reserve(buffers.size());
  for (int i = 0; i < buffers.size(); ++i) {
    auto& descriptors = recv_descriptors_or->at(i);
    CHECK_EQ(descriptors.serialized_descriptors.size(), 1);
    const std::string& desc = descriptors.serialized_descriptors[0];
    pybind11::bytes py_desc = pybind11::bytes(desc);
    auto traceback = Traceback::Get();
    auto py_buf =
        PyBuffer::Make(shared_from_this(), std::move(buffers[i]), traceback);
    result.push_back(std::make_pair(std::move(py_desc), std::move(py_buf)));
  }
  return result;
}

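// Compiles an XlaComputation into a PyExecutable. Compilation and fingerprint
// lookup run with the GIL released; the host-callback capsules are stored on
// the resulting PyExecutable so they stay alive with it.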
StatusOr<std::shared_ptr<PyExecutable>> PyClient::Compile(
    const XlaComputation& computation, CompileOptions options,
    std::vector<pybind11::capsule> host_callbacks) {
  std::unique_ptr<PjRtLoadedExecutable> executable;
  std::optional<std::string> fingerprint;
  {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(executable,
                        pjrt_client_->Compile(computation, std::move(options)));
    TF_ASSIGN_OR_RETURN(fingerprint,
                        pjrt_client_->ExecutableFingerprint(*executable));
  }
  auto traceback = Traceback::Get();
  return std::make_shared<PyExecutable>(
      shared_from_this(), std::move(executable), std::move(traceback),
      std::move(fingerprint), std::move(host_callbacks));
}

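// Same as Compile() above, but takes the program as an MLIR module in textual
// form, which is parsed before being handed to the PjRt compiler.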
StatusOr<std::shared_ptr<PyExecutable>> PyClient::CompileMlir(
    std::string mlir_module, CompileOptions options,
    std::vector<pybind11::capsule> host_callbacks) {
  std::unique_ptr<PjRtLoadedExecutable> executable;
  std::optional<std::string> fingerprint;
  {
    py::gil_scoped_release gil_release;
    mlir::MLIRContext context;
    TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
                        ParseMlirModuleString(mlir_module, context));
    TF_ASSIGN_OR_RETURN(
        executable, pjrt_client_->Compile(module.get(), std::move(options)));
    TF_ASSIGN_OR_RETURN(fingerprint,
                        pjrt_client_->ExecutableFingerprint(*executable));
  }
  auto traceback = Traceback::Get();
  return std::make_shared<PyExecutable>(
      shared_from_this(), std::move(executable), std::move(traceback),
      std::move(fingerprint), std::move(host_callbacks));
}

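// SerializeExecutable/DeserializeExecutable round-trip a compiled executable
// through the PjRt serialization API, e.g. so that compilation results can be
// cached and reloaded without recompiling.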
StatusOr<py::bytes> PyClient::SerializeExecutable(
    const PyExecutable& executable) const {
  return pjrt_client_->SerializeExecutable(executable.pjrt_executable());
}

StatusOr<std::shared_ptr<PyExecutable>> PyClient::DeserializeExecutable(
    const std::string& serialized, CompileOptions options,
    std::vector<pybind11::capsule> host_callbacks) {
  std::unique_ptr<PjRtLoadedExecutable> executable;
  std::optional<std::string> fingerprint;
  {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(executable, pjrt_client_->DeserializeExecutable(
                                        serialized, std::move(options)));
    TF_ASSIGN_OR_RETURN(fingerprint,
                        pjrt_client_->ExecutableFingerprint(*executable));
  }
  auto traceback = Traceback::Get();
  return std::make_shared<PyExecutable>(
      shared_from_this(), std::move(executable), std::move(traceback),
      std::move(fingerprint), std::move(host_callbacks));
}

namespace {

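// Heap-profile samples are aggregated by (traceback, size, device). Keys with
// identical frames compare equal even if they hold different Traceback
// objects, so allocations from the same Python call site collapse into one
// sample.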
struct HeapProfileKey {
  Traceback* traceback;
  int64_t size;
  PjRtDevice* device;
  bool operator==(const HeapProfileKey& other) const;
};

bool HeapProfileKey::operator==(const HeapProfileKey& other) const {
  if (size != other.size || device != other.device) {
    return false;
  }
  if ((traceback == nullptr) != (other.traceback == nullptr)) {
    return false;
  }
  if (traceback && traceback->raw_frames() != other.traceback->raw_frames()) {
    return false;
  }
  return true;
}

template <typename H>
H AbslHashValue(H h, const HeapProfileKey& key) {
  if (key.traceback) {
    h = H::combine(std::move(h), key.traceback->raw_frames());
  }
  h = H::combine(std::move(h), key.size, key.device);
  return h;
}

}  // namespace

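// Builds a pprof heap profile of live buffers and executables. Buffer samples
// are keyed by allocation traceback, on-device size, and device; executable
// samples use the generated-code size and a null device. Returns the
// serialized pprof Profile proto as Python bytes.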
StatusOr<py::bytes> PyClient::HeapProfile() {
  CHECK(PyGILState_Check());
  absl::flat_hash_set<PjRtBuffer*> buffer_set;
  absl::flat_hash_map<HeapProfileKey, int64_t> entries;
  for (PyBuffer* device_buffers : buffers_) {
    for (PyBuffer* buffer = device_buffers; buffer; buffer = buffer->next_) {
      // We only wish to count each PjRtBuffer once, even though it may be
      // shared by multiple PyBuffers.
      if (!buffer->is_deleted() && buffer_set.insert(buffer->buffer()).second) {
        TF_ASSIGN_OR_RETURN(size_t size,
                            buffer->buffer()->GetOnDeviceSizeInBytes());
        HeapProfileKey key{buffer->traceback().get(),
                           static_cast<int64_t>(size),
                           buffer->buffer()->device()};
        ++entries[key];
      }
    }
  }

  for (PyExecutable* executable = executables_; executable;
       executable = executable->next_) {
    if (!executable->is_deleted()) {
      HeapProfileKey key{executable->traceback(),
                         executable->SizeOfGeneratedCodeInBytes(), nullptr};
      ++entries[key];
    }
  }

  PprofProfileBuilder builder;
  auto* allocations = builder.profile().add_sample_type();
  allocations->set_type(builder.StringId("allocations"));
  allocations->set_unit(builder.StringId("count"));
  auto* space = builder.profile().add_sample_type();
  space->set_type(builder.StringId("space"));
  space->set_unit(builder.StringId("bytes"));

  const int kind_string_id = builder.StringId("kind");
  const int buffer_string_id = builder.StringId("buffer");
  const int executable_string_id = builder.StringId("executable");
  const int device_string_id = builder.StringId("device");
  for (const auto& entry : entries) {
    auto* sample = builder.profile().add_sample();
    if (entry.first.traceback) {
      for (const auto& frame : entry.first.traceback->raw_frames()) {
        sample->add_location_id(builder.LocationId(frame.first, frame.second));
      }
    }
    sample->add_value(entry.second);
    sample->add_value(entry.first.size * entry.second);

    auto* kind_label = sample->add_label();
    kind_label->set_key(kind_string_id);
    if (entry.first.device) {
      kind_label->set_str(buffer_string_id);
      auto* device_label = sample->add_label();
      device_label->set_key(device_string_id);
      device_label->set_str(
          builder.StringId(std::string(entry.first.device->DebugString())));
    } else {
      kind_label->set_str(executable_string_id);
    }
  }
  return py::bytes(builder.profile().SerializeAsString());
}

namespace {

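// Converts the operand shapes of a Python callback into the argument metadata
// CpuCallback expects (dims, byte strides, element type, numpy dtype). Only
// array and token shapes are supported.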
StatusOr<std::vector<CpuCallback::Arg>> CreateCallbackArgs(
    absl::Span<Shape const> operand_shapes) {
  std::vector<CpuCallback::Arg> callback_args(operand_shapes.size());
  for (int i = 0; i < operand_shapes.size(); ++i) {
    Shape shape = operand_shapes[i];

    if (shape.IsArray()) {
      Shape layout =
          (shape.has_layout() ? shape
                              : LayoutUtil::GetWithDefaultLayout(shape));
      callback_args[i].dims.resize(shape.dimensions_size());
      absl::c_copy(shape.dimensions(), callback_args[i].dims.begin());
      callback_args[i].strides = ByteStridesForShape(layout);
      callback_args[i].type = shape.element_type();
      callback_args[i].size_in_bytes = ShapeUtil::ByteSizeOf(layout);
      TF_ASSIGN_OR_RETURN(callback_args[i].dtype,
                          PrimitiveTypeToDtype(shape.element_type()));
    } else if (shape.IsToken()) {
      callback_args[i].type = TOKEN;
    } else {
      return InvalidArgument(
          "Only array and token arguments to Python callbacks are supported, "
          "got %s",
          shape.ToString());
    }
  }
  return callback_args;
}

StatusOr<std::vector<CpuCallback::Result>> CreateCallbackResults(
    absl::Span<Shape const> result_shapes) {
  std::vector<CpuCallback::Result> callback_results(result_shapes.size());
  for (int i = 0; i < result_shapes.size(); ++i) {
    if (result_shapes[i].IsArray()) {
      const Shape& shape =
          result_shapes[i].has_layout()
              ? result_shapes[i]
              : LayoutUtil::GetWithDefaultLayout(result_shapes[i]);
      callback_results[i].expected_dims.resize(shape.dimensions_size());
      absl::c_copy(shape.dimensions(),
                   callback_results[i].expected_dims.begin());
      callback_results[i].expected_strides = ByteStridesForShapeInt64(shape);
      callback_results[i].type = shape.element_type();
      callback_results[i].size_in_bytes = ShapeUtil::ByteSizeOf(shape);
      callback_results[i].reversed_layout.resize(shape.dimensions_size());
      absl::c_reverse_copy(shape.layout().minor_to_major(),
                           callback_results[i].reversed_layout.begin());
    } else if (result_shapes[i].IsToken()) {
      callback_results[i].type = TOKEN;
    } else {
      return InvalidArgument(
          "Only array and token return values from Python callbacks are "
          "supported, got %s",
          result_shapes[i].ToString());
    }
  }
  return callback_results;
}

}  // namespace

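// Wraps a Python callable as a HostCallback that talks to the device through
// host send/recv channels: operands arrive over send_channel_ids and results
// are returned over recv_channel_ids. The returned capsule owns the
// HostCallback and deletes it when the capsule is destroyed.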
StatusOr<pybind11::object> PyClient::MakePythonCallbackUsingHostSendAndRecv(
    pybind11::function callable, absl::Span<Shape const> operand_shapes,
    absl::Span<Shape const> result_shapes,
    absl::Span<uint16_t const> send_channel_ids,
    absl::Span<uint16_t const> recv_channel_ids) {
  static_assert(sizeof(uintptr_t) == sizeof(uint64_t),
                "Expected 64-bit pointers");

  TF_ASSIGN_OR_RETURN(auto callback_args, CreateCallbackArgs(operand_shapes));
  TF_ASSIGN_OR_RETURN(auto callback_results,
                      CreateCallbackResults(result_shapes));

  auto callback = std::make_shared<CpuCallback>(
      std::move(callable), callback_args, callback_results);

  auto* host_callback = new HostCallback;

  auto assign_arg_info = [](absl::Span<Shape const> shapes,
                            absl::Span<uint16_t const> channel_ids,
                            std::vector<HostCallbackArgInfo>& arg_infos) {
    DCHECK_EQ(shapes.size(), channel_ids.size());
    arg_infos.reserve(shapes.size());
    for (int i = 0; i < shapes.size(); ++i) {
      HostCallbackArgInfo host_callback_arg_info;
      host_callback_arg_info.channel_id = channel_ids[i];
      const auto& shape = shapes[i];
      Shape layout =
          (shape.has_layout() ? shape
                              : LayoutUtil::GetWithDefaultLayout(shape));
      host_callback_arg_info.shape = layout;
      arg_infos.push_back(std::move(host_callback_arg_info));
    }
  };

  assign_arg_info(operand_shapes, send_channel_ids, host_callback->operands);
  assign_arg_info(result_shapes, recv_channel_ids, host_callback->results);

  host_callback->callback = [callback = std::move(callback)](void** outputs,
                                                             void** inputs) {
    return callback->PrepareAndCall(outputs, inputs);
  };

  py::capsule callback_capsule(
      host_callback, [](void* ptr) { delete static_cast<HostCallback*>(ptr); });

  return callback_capsule;
}

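// Builds a CpuCallback for `callable` and returns (descriptor, keepalive): the
// descriptor is the callback's address bit-cast to a uint64 for embedding into
// the program, and the keepalive capsule owns the callback, so the caller must
// keep it alive while the compiled computation may still run. Only CPU and GPU
// platforms are supported.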
StatusOr<std::pair<uint64_t, pybind11::object>>
PyClient::GetEmitPythonCallbackDescriptor(
    pybind11::function callable, absl::Span<Shape const> operand_shapes,
    absl::Span<Shape const> result_shapes) {
  PjRtPlatformId platform_id = pjrt_client_->platform_id();
  if (platform_id != GpuId() && platform_id != CpuId()) {
    return Unimplemented(
        "EmitPythonCallback is only implemented on CPU and GPU");
  }

  static_assert(sizeof(uintptr_t) == sizeof(uint64_t),
                "Expected 64-bit pointers");

  TF_ASSIGN_OR_RETURN(auto callback_args, CreateCallbackArgs(operand_shapes));
  TF_ASSIGN_OR_RETURN(auto callback_results,
                      CreateCallbackResults(result_shapes));

  auto callback = std::make_unique<CpuCallback>(
      std::move(callable), callback_args, callback_results);
  uint64_t descriptor = absl::bit_cast<std::uint64_t>(callback.get());

  py::capsule callback_capsule(callback.release(), [](void* ptr) {
    delete reinterpret_cast<CpuCallback*>(ptr);
  });
  return std::make_pair(descriptor, py::object(std::move(callback_capsule)));
}

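// Emits a CustomCall that invokes the callback identified by `descriptor`. The
// descriptor is passed both as an extra scalar u64 operand and as the opaque
// string; operands keep their given layouts (or default layouts when none are
// supplied). The call targets the xla_python_cpu_callback or
// xla_python_gpu_callback handler registered at the end of this file.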
StatusOr<XlaOp> PyClient::EmitPythonCallbackFromDescriptor(
    XlaBuilder& builder, uint64_t descriptor, absl::Span<XlaOp const> operands,
    absl::Span<Shape const> result_shapes,
    std::optional<std::vector<Shape>> operand_layouts, bool has_side_effect) {
  std::vector<Shape> custom_call_arg_layouts(operands.size() + 1);
  custom_call_arg_layouts[0] =
      ShapeUtil::MakeShapeWithDescendingLayout(U64, {});
  std::vector<XlaOp> custom_call_args(operands.size() + 1);
  custom_call_args[0] = ConstantR0<std::uint64_t>(&builder, descriptor);
  absl::c_copy(operands, custom_call_args.begin() + 1);

  if (operand_layouts && operand_layouts->size() != operands.size()) {
    return InvalidArgument(
        "Mismatched number of operands (%d) and operand_layouts (%d)",
        operands.size(), operand_layouts->size());
  }

  for (int i = 0; i < operands.size(); ++i) {
    TF_ASSIGN_OR_RETURN(Shape shape, builder.GetShape(operands[i]));
    Shape layout = LayoutUtil::GetWithDefaultLayout(shape);
    if (shape.IsArray() && operand_layouts) {
      if (!(*operand_layouts)[i].has_layout()) {
        return InvalidArgument(
            "operand_layout shapes for callback must have "
            "layouts, got %s",
            (*operand_layouts)[i].ToString(/*print_layout=*/true));
      }
      if (!ShapeUtil::Compatible(shape, (*operand_layouts)[i])) {
        return InvalidArgument(
            "Incompatible shapes for Python callback argument %d: %s vs %s", i,
            shape.ToString(),
            (*operand_layouts)[i].ToString(/*print_layout=*/true));
      }
      layout = (*operand_layouts)[i];
    }
    custom_call_arg_layouts[i + 1] = layout;
  }

  std::vector<Shape> result_shapes_with_layout(result_shapes.size());
  for (int i = 0; i < result_shapes.size(); ++i) {
    if (result_shapes[i].IsArray()) {
      result_shapes_with_layout[i] =
          result_shapes[i].has_layout()
              ? result_shapes[i]
              : LayoutUtil::GetWithDefaultLayout(result_shapes[i]);
    } else if (result_shapes[i].IsToken()) {
      result_shapes_with_layout[i] = result_shapes[i];
    } else {
      return InvalidArgument(
          "Only array and token return values from Python callbacks are "
          "supported, got %s",
          result_shapes[i].ToString());
    }
  }
  custom_call_args[0] = ConstantR0<std::uint64_t>(&builder, descriptor);
  Shape result_shape = ShapeUtil::MakeTupleShape(result_shapes_with_layout);
  std::string callback_str = std::to_string(descriptor);
  std::string callback_name = "xla_python_cpu_callback";
  if (pjrt_client_->platform_id() == GpuId()) {
    callback_name = "xla_python_gpu_callback";
  }
  XlaOp result =
      CustomCallWithLayout(&builder, callback_name, custom_call_args,
                           result_shape, custom_call_arg_layouts,
                           /*opaque=*/callback_str.data(), has_side_effect,
                           /*output_operand_aliasing=*/{},
                           /*literal=*/nullptr,
                           /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                           /*api_version=*/API_VERSION_STATUS_RETURNING);
  return result;
}

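// Convenience wrapper: creates the callback descriptor and emits the custom
// call in one step, returning the op together with the keepalive object the
// caller must hold while the computation may run.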
StatusOr<std::pair<XlaOp, pybind11::object>> PyClient::EmitPythonCallback(
    pybind11::function callable, XlaBuilder& builder,
    absl::Span<XlaOp const> operands, absl::Span<Shape const> result_shapes,
    std::optional<std::vector<Shape>> operand_layouts, bool has_side_effect) {
  std::vector<Shape> operand_shapes(operands.size());
  for (int i = 0; i < operands.size(); ++i) {
    TF_ASSIGN_OR_RETURN(Shape shape, builder.GetShape(operands[i]));
    operand_shapes[i] =
        (operand_layouts ? (*operand_layouts)[i]
                         : LayoutUtil::GetWithDefaultLayout(shape));
  }
  StatusOr<std::pair<uint64_t, pybind11::object>> result_sor =
      GetEmitPythonCallbackDescriptor(callable, operand_shapes, result_shapes);
  TF_ASSIGN_OR_RETURN(auto result, result_sor);
  uint64_t descriptor = result.first;
  pybind11::object keepalive = result.second;
  TF_ASSIGN_OR_RETURN(XlaOp callback_op,
                      EmitPythonCallbackFromDescriptor(
                          builder, descriptor, operands, result_shapes,
                          operand_shapes, has_side_effect));
  return std::make_pair(callback_op, keepalive);
}

XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback",
                                             &XlaPythonCpuCallback);

#if TENSORFLOW_USE_ROCM
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_gpu_callback",
                                         &XlaPythonGpuCallback, "ROCM");
#elif defined(GOOGLE_CUDA)
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_gpu_callback",
                                         &XlaPythonGpuCallback, "CUDA");
#endif  // TENSORFLOW_USE_ROCM

}  // namespace xla