/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/py_client.h"

#include <memory>
#include <string>
#include <utility>

#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/numbers.h"
#include "tensorflow/compiler/xla/pjrt/host_callback.h"
#include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/python/callback.h"
#include "tensorflow/compiler/xla/python/exceptions.h"
#include "tensorflow/compiler/xla/python/pprof_profile_builder.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/python/transfer_guard_lib.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/core/platform/statusor.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/compiler/xla/python/py_client_gpu.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace xla {

namespace py = pybind11;

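// The unique_ptr constructor delegates to the shared_ptr overload below.
// buffers_ is indexed by device id, so it is resized to cover the id of
// every addressable device.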
PyClient::PyClient(std::unique_ptr<PjRtClient> pjrt_client)
    : PyClient(std::shared_ptr<PjRtClient>(std::move(pjrt_client))) {}

PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
    : pjrt_client_(std::move(pjrt_client)) {
  CHECK(pjrt_client_ != nullptr);
  buffers_.resize(pjrt_client_->device_count());
  for (PjRtDevice* device : pjrt_client_->addressable_devices()) {
    if (device->id() >= buffers_.size()) {
      buffers_.resize(device->id() + 1);
    }
  }
}

PyClient::~PyClient() {
  py::gil_scoped_release gil;
  pjrt_client_ = nullptr;
}

std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
  std::vector<ClientAndPtr<PjRtDevice>> devices;
  auto span = pjrt_client_->devices();
  devices.reserve(span.size());
  for (PjRtDevice* device : span) {
    devices.push_back(WrapWithClient(shared_from_this(), device));
  }
  return devices;
}

std::vector<ClientAndPtr<PjRtDevice>> PyClient::LocalDevices() {
  std::vector<ClientAndPtr<PjRtDevice>> devices;
  devices.reserve(pjrt_client_->addressable_devices().size());
  for (PjRtDevice* device : pjrt_client_->addressable_devices()) {
    devices.push_back(WrapWithClient(shared_from_this(), device));
  }
  return devices;
}

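// Returns handles to every live (not-yet-deleted) buffer tracked by this
// client. Must be called with the GIL held.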
std::vector<py::object> PyClient::LiveBuffers() {
  CHECK(PyGILState_Check());
  std::vector<py::object> buffers;
  for (PyBuffer* device_buffers : buffers_) {
    for (PyBuffer* buffer = device_buffers; buffer; buffer = buffer->next_) {
      if (!buffer->is_deleted()) {
        buffers.push_back(
            py::reinterpret_borrow<py::object>(buffer->AsHandle()));
      }
    }
  }
  return buffers;
}

std::vector<py::object> PyClient::LiveBuffersOnDevice(PjRtDevice* device) {
  CHECK_EQ(device->client(), pjrt_client());
  CHECK(PyGILState_Check());
  std::vector<py::object> buffers;
  for (PyBuffer* buffer = buffers_[device->id()]; buffer;
       buffer = buffer->next_) {
    if (!buffer->is_deleted()) {
      buffers.push_back(py::reinterpret_borrow<py::object>(buffer->AsHandle()));
    }
  }
  return buffers;
}

std::vector<std::shared_ptr<PyExecutable>> PyClient::LiveExecutables() {
  CHECK(PyGILState_Check());
  std::vector<std::shared_ptr<PyExecutable>> executables;
  for (PyExecutable* exec = executables_; exec; exec = exec->next_) {
    if (!exec->is_deleted()) {
      executables.push_back(exec->shared_from_this());
    }
  }
  return executables;
}

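// Defragments device memory. TFRT clients delegate to the runtime; for
// stream-executor clients, every live buffer is copied to the host, its
// device memory is released, and the contents are uploaded again so the
// allocator can lay the buffers out contiguously.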
Status PyClient::Defragment() {
  CHECK(PyGILState_Check());
  switch (pjrt_client_->runtime_type()) {
    case PjRtRuntimeType::kTfrt:
      return pjrt_client_->Defragment();
    case PjRtRuntimeType::kStreamExecutor:
      struct TmpBuffer {
        PyBuffer* py_buffer;
        // TODO(skyewm): maybe use py_buffer's HostValue
        std::shared_ptr<Literal> host_copy;
      };

      // Synchronously copy all buffers to the host.
      std::vector<TmpBuffer> tmp_buffers;
      for (PyBuffer* device_buffers : buffers_) {
        for (PyBuffer* buffer = device_buffers; buffer;
             buffer = buffer->next_) {
          if (!buffer->is_deleted()) {
            TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal,
                                buffer->buffer_->ToLiteralSync());
            tmp_buffers.push_back({buffer, literal});
          }
        }
      }

      // All buffers successfully copied to host, delete on-device copies.
      //
      // Use a blocking delete operation to ensure all memory is actually
      // cleared before we start rewriting buffers.
      //
      // Die instead of returning a bad status because the program presumably
      // can't continue if we fail to reconstitute device buffers.
      for (TmpBuffer& tmp_buffer : tmp_buffers) {
        TF_CHECK_OK(tensorflow::down_cast<PjRtStreamExecutorBuffer*>(
                        tmp_buffer.py_buffer->buffer_.get())
                        ->Release(/*wait_for_operations_to_complete=*/true)
                        .status());
      }

      // Copy host copies back to device and update PyBuffers in-place.
      for (TmpBuffer& tmp_buffer : tmp_buffers) {
        std::unique_ptr<PjRtBuffer> new_copy =
            pjrt_client_
                ->BufferFromHostLiteral(*tmp_buffer.host_copy,
                                        tmp_buffer.py_buffer->buffer_->device())
                .ValueOrDie();
        TF_CHECK_OK(new_copy->BlockHostUntilReady());
        tmp_buffer.py_buffer->buffer_.reset(new_copy.release());
      }

      // TODO(skyewm): delete executables?
  }
  return OkStatus();
}

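// Queries the client's default (replica, partition) -> device assignment and
// returns it as Python-visible device wrappers.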
StatusOr<std::vector<std::vector<ClientAndPtr<PjRtDevice>>>>
PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
  TF_ASSIGN_OR_RETURN(
      DeviceAssignment device_assignment,
      pjrt_client_->GetDefaultDeviceAssignment(num_replicas, num_partitions));
  std::vector<std::vector<ClientAndPtr<PjRtDevice>>> result;
  result.resize(num_replicas);
  for (int r = 0; r < num_replicas; ++r) {
    result[r].resize(num_partitions);
    for (int p = 0; p < num_partitions; ++p) {
      int device_id = device_assignment(r, p);
      TF_ASSIGN_OR_RETURN(PjRtDevice * device,
                          pjrt_client_->LookupDevice(device_id));
      result[r][p] = WrapWithClient(shared_from_this(), device);
    }
  }
  return result;
}

StatusOr<std::vector<ClientAndPtr<PjRtDevice>>>
PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
  TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
                      pjrt_client_->GetDefaultDeviceAssignment(
                          num_replicas, /*num_partitions=*/1));
  std::vector<ClientAndPtr<PjRtDevice>> result;
  for (int i = 0; i < num_replicas; ++i) {
    int device_id = device_assignment(i, 0);
    TF_ASSIGN_OR_RETURN(PjRtDevice * device,
                        pjrt_client_->LookupDevice(device_id));
    result.push_back(WrapWithClient(shared_from_this(), device));
  }
  return result;
}

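// Transfers a Python value to `device`, returning a new PyBuffer. If no
// device is given, the first addressable device is used. Zero-copy sharing of
// the host buffer is attempted only when force_copy is false and the host
// buffer semantics are kZeroCopy.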
StatusOr<py::object> PyClient::BufferFromPyval(
    pybind11::handle argument, PjRtDevice* device, bool force_copy,
    PjRtClient::HostBufferSemantics host_buffer_semantics) {
  if (device == nullptr) {
    TF_RET_CHECK(!pjrt_client_->addressable_devices().empty());
    device = pjrt_client_->addressable_devices().front();
  }
  CHECK(device != nullptr);

  auto transfer_guard_formatter = [&argument, dst_device = device] {
    auto type = py::cast<std::string>(py::str(argument.get_type()));
    // Catch exceptions because shape and dtype properties convertible to str
    // are not guaranteed to be present in an arbitrary argument.
    std::string shape;
    std::string dtype;
    try {
      shape = py::cast<std::string>(py::str(argument.attr("shape")));
    } catch (const std::exception& e) {
      shape = "<unknown>";
    }
    try {
      dtype = py::cast<std::string>(py::str(argument.attr("dtype")));
    } catch (const std::exception& e) {
      dtype = "<unknown>";
    }
    return absl::StrCat("type=", type, ", shape=", shape, ", dtype=", dtype,
                        ", dst_device=", dst_device->DebugString());
  };
  TF_RETURN_IF_ERROR(
      jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter));

  TF_ASSIGN_OR_RETURN(PjRtDevice * found_device,
                      pjrt_client_->LookupDevice(device->id()));
  if (found_device != device) {
    return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
                           device->DebugString(),
                           pjrt_client_->platform_name());
  }
  GlobalPyRefManager()->CollectGarbage();

  DevicePutOptions options;
  options.squash_64bit_types = false;
  options.allow_zero_copy =
      (!force_copy &&
       (host_buffer_semantics == PjRtClient::HostBufferSemantics::kZeroCopy));
  TF_ASSIGN_OR_RETURN(DevicePutResult put,
                      DevicePut(argument, device, options));

  if (put.owned_buffer) {
    auto traceback = Traceback::Get();
    return PyBuffer::Make(shared_from_this(), std::move(put.owned_buffer),
                          std::move(traceback));
  } else {
    return py::reinterpret_borrow<py::object>(put.owning_pybuffer);
  }
}

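// Allocates cross-host receive buffers on `device` and blocks (with the GIL
// released) until the runtime reports their serialized descriptors. Each
// result pairs the descriptor bytes, which the sending host needs, with the
// receiving PyBuffer.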
StatusOr<std::vector<std::pair<pybind11::bytes, pybind11::object>>>
PyClient::MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
                                      PjRtDevice* device) {
  CHECK(device != nullptr);
  absl::Mutex mu;
  StatusOr<std::vector<PjRtCrossHostRecvDescriptors>> recv_descriptors_or;
  bool done = false;

  TF_ASSIGN_OR_RETURN(
      auto buffers, pjrt_client_->MakeCrossHostReceiveBuffers(
                        shapes, device,
                        [&done, &recv_descriptors_or,
                         &mu](StatusOr<PjRtCrossHostRecvState> recv_state_or) {
                          absl::MutexLock l(&mu);
                          if (recv_state_or.ok()) {
                            py::gil_scoped_acquire gil;
                            recv_descriptors_or =
                                std::move(recv_state_or->descriptors);
                          } else {
                            recv_descriptors_or = recv_state_or.status();
                          }
                          done = true;
                        }));

  {
    py::gil_scoped_release gil_release;
    absl::MutexLock l(&mu);
    mu.Await(absl::Condition(&done));
  }

  TF_RETURN_IF_ERROR(recv_descriptors_or.status());
  CHECK_EQ(buffers.size(), recv_descriptors_or->size());
  std::vector<std::pair<pybind11::bytes, pybind11::object>> result;
  result.reserve(buffers.size());
  for (int i = 0; i < buffers.size(); ++i) {
    auto& descriptors = recv_descriptors_or->at(i);
    CHECK_EQ(descriptors.serialized_descriptors.size(), 1);
    const std::string& desc = descriptors.serialized_descriptors[0];
    pybind11::bytes py_desc = pybind11::bytes(desc);
    auto traceback = Traceback::Get();
    auto py_buf =
        PyBuffer::Make(shared_from_this(), std::move(buffers[i]), traceback);
    result.push_back(std::make_pair(std::move(py_desc), std::move(py_buf)));
  }
  return result;
}

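// Compiles an XlaComputation (or, below, an MLIR module) with the GIL
// released and wraps the resulting PjRtLoadedExecutable, its fingerprint,
// a traceback, and any host callbacks in a PyExecutable.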
StatusOr<std::shared_ptr<PyExecutable>> PyClient::Compile(
    const XlaComputation& computation, CompileOptions options,
    std::vector<pybind11::capsule> host_callbacks) {
  std::unique_ptr<PjRtLoadedExecutable> executable;
  std::optional<std::string> fingerprint;
  {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(executable,
                        pjrt_client_->Compile(computation, std::move(options)));
    TF_ASSIGN_OR_RETURN(fingerprint,
                        pjrt_client_->ExecutableFingerprint(*executable));
  }
  auto traceback = Traceback::Get();
  return std::make_shared<PyExecutable>(
      shared_from_this(), std::move(executable), std::move(traceback),
      std::move(fingerprint), std::move(host_callbacks));
}

StatusOr<std::shared_ptr<PyExecutable>> PyClient::CompileMlir(
    std::string mlir_module, CompileOptions options,
    std::vector<pybind11::capsule> host_callbacks) {
  std::unique_ptr<PjRtLoadedExecutable> executable;
  std::optional<std::string> fingerprint;
  {
    py::gil_scoped_release gil_release;
    mlir::MLIRContext context;
    TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
                        ParseMlirModuleString(mlir_module, context));
    TF_ASSIGN_OR_RETURN(
        executable, pjrt_client_->Compile(module.get(), std::move(options)));
    TF_ASSIGN_OR_RETURN(fingerprint,
                        pjrt_client_->ExecutableFingerprint(*executable));
  }
  auto traceback = Traceback::Get();
  return std::make_shared<PyExecutable>(
      shared_from_this(), std::move(executable), std::move(traceback),
      std::move(fingerprint), std::move(host_callbacks));
}

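// Executable serialization round-trip: SerializeExecutable delegates to the
// PjRt client, and DeserializeExecutable rebuilds a PyExecutable from the
// serialized bytes, recomputing its fingerprint with the GIL released.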
StatusOr<py::bytes> PyClient::SerializeExecutable(
    const PyExecutable& executable) const {
  return pjrt_client_->SerializeExecutable(executable.pjrt_executable());
}

StatusOr<std::shared_ptr<PyExecutable>> PyClient::DeserializeExecutable(
    const std::string& serialized, CompileOptions options,
    std::vector<pybind11::capsule> host_callbacks) {
  std::unique_ptr<PjRtLoadedExecutable> executable;
  std::optional<std::string> fingerprint;
  {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(executable, pjrt_client_->DeserializeExecutable(
                                        serialized, std::move(options)));
    TF_ASSIGN_OR_RETURN(fingerprint,
                        pjrt_client_->ExecutableFingerprint(*executable));
  }
  auto traceback = Traceback::Get();
  return std::make_shared<PyExecutable>(
      shared_from_this(), std::move(executable), std::move(traceback),
      std::move(fingerprint), std::move(host_callbacks));
}

namespace {

struct HeapProfileKey {
  Traceback* traceback;
  int64_t size;
  PjRtDevice* device;
  bool operator==(const HeapProfileKey& other) const;
};

bool HeapProfileKey::operator==(const HeapProfileKey& other) const {
  if (size != other.size || device != other.device) {
    return false;
  }
  if ((traceback == nullptr) != (other.traceback == nullptr)) {
    return false;
  }
  if (traceback && traceback->raw_frames() != other.traceback->raw_frames()) {
    return false;
  }
  return true;
}

template <typename H>
H AbslHashValue(H h, const HeapProfileKey& key) {
  if (key.traceback) {
    h = H::combine(std::move(h), key.traceback->raw_frames());
  }
  h = H::combine(std::move(h), key.size, key.device);
  return h;
}

}  // namespace

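// Builds a pprof-format heap profile over all live buffers and executables.
// Samples are keyed by (traceback, size, device) so allocations from the same
// Python call site are aggregated; each PjRtBuffer is counted once even if
// several PyBuffers share it.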
StatusOr<py::bytes> PyClient::HeapProfile() {
  CHECK(PyGILState_Check());
  absl::flat_hash_set<PjRtBuffer*> buffer_set;
  absl::flat_hash_map<HeapProfileKey, int64_t> entries;
  for (PyBuffer* device_buffers : buffers_) {
    for (PyBuffer* buffer = device_buffers; buffer; buffer = buffer->next_) {
      // We only wish to count each PjRtBuffer once, even though they may be
      // shared by multiple PyBuffers.
      if (!buffer->is_deleted() && buffer_set.insert(buffer->buffer()).second) {
        TF_ASSIGN_OR_RETURN(size_t size,
                            buffer->buffer()->GetOnDeviceSizeInBytes());
        HeapProfileKey key{buffer->traceback().get(),
                           static_cast<int64_t>(size),
                           buffer->buffer()->device()};
        ++entries[key];
      }
    }
  }

  for (PyExecutable* executable = executables_; executable;
       executable = executable->next_) {
    if (!executable->is_deleted()) {
      HeapProfileKey key{executable->traceback(),
                         executable->SizeOfGeneratedCodeInBytes(), nullptr};
      ++entries[key];
    }
  }

  PprofProfileBuilder builder;
  auto* allocations = builder.profile().add_sample_type();
  allocations->set_type(builder.StringId("allocations"));
  allocations->set_unit(builder.StringId("count"));
  auto* space = builder.profile().add_sample_type();
  space->set_type(builder.StringId("space"));
  space->set_unit(builder.StringId("bytes"));

  const int kind_string_id = builder.StringId("kind");
  const int buffer_string_id = builder.StringId("buffer");
  const int executable_string_id = builder.StringId("executable");
  const int device_string_id = builder.StringId("device");
  for (const auto& entry : entries) {
    auto* sample = builder.profile().add_sample();
    if (entry.first.traceback) {
      for (const auto& frame : entry.first.traceback->raw_frames()) {
        sample->add_location_id(builder.LocationId(frame.first, frame.second));
      }
    }
    sample->add_value(entry.second);
    sample->add_value(entry.first.size * entry.second);

    auto* kind_label = sample->add_label();
    kind_label->set_key(kind_string_id);
    if (entry.first.device) {
      kind_label->set_str(buffer_string_id);
      auto* device_label = sample->add_label();
      device_label->set_key(device_string_id);
      device_label->set_str(
          builder.StringId(std::string(entry.first.device->DebugString())));
    } else {
      kind_label->set_str(executable_string_id);
    }
  }
  return py::bytes(builder.profile().SerializeAsString());
}

namespace {

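// Converts XLA operand shapes into CpuCallback argument descriptors (dims,
// byte strides, element type, byte size, and numpy dtype). Only array and
// token shapes are supported.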
StatusOr<std::vector<CpuCallback::Arg>> CreateCallbackArgs(
    absl::Span<Shape const> operand_shapes) {
  std::vector<CpuCallback::Arg> callback_args(operand_shapes.size());
  for (int i = 0; i < operand_shapes.size(); ++i) {
    Shape shape = operand_shapes[i];

    if (shape.IsArray()) {
      Shape layout =
          (shape.has_layout() ? shape
                              : LayoutUtil::GetWithDefaultLayout(shape));
      callback_args[i].dims.resize(shape.dimensions_size());
      absl::c_copy(shape.dimensions(), callback_args[i].dims.begin());
      callback_args[i].strides = ByteStridesForShape(layout);
      callback_args[i].type = shape.element_type();
      callback_args[i].size_in_bytes = ShapeUtil::ByteSizeOf(layout);
      TF_ASSIGN_OR_RETURN(callback_args[i].dtype,
                          PrimitiveTypeToDtype(shape.element_type()));
    } else if (shape.IsToken()) {
      callback_args[i].type = TOKEN;
    } else {
      return InvalidArgument(
          "Only array and token arguments to Python callbacks are supported, "
          "got %s",
          shape.ToString());
    }
  }
  return callback_args;
}

StatusOr<std::vector<CpuCallback::Result>> CreateCallbackResults(
    absl::Span<Shape const> result_shapes) {
  std::vector<CpuCallback::Result> callback_results(result_shapes.size());
  for (int i = 0; i < result_shapes.size(); ++i) {
    if (result_shapes[i].IsArray()) {
      const Shape& shape =
          result_shapes[i].has_layout()
              ? result_shapes[i]
              : LayoutUtil::GetWithDefaultLayout(result_shapes[i]);
      callback_results[i].expected_dims.resize(shape.dimensions_size());
      absl::c_copy(shape.dimensions(),
                   callback_results[i].expected_dims.begin());
      callback_results[i].expected_strides = ByteStridesForShapeInt64(shape);
      callback_results[i].type = shape.element_type();
      callback_results[i].size_in_bytes = ShapeUtil::ByteSizeOf(shape);
      callback_results[i].reversed_layout.resize(shape.dimensions_size());
      absl::c_reverse_copy(shape.layout().minor_to_major(),
                           callback_results[i].reversed_layout.begin());
    } else if (result_shapes[i].IsToken()) {
      callback_results[i].type = TOKEN;
    } else {
      return InvalidArgument(
          "Only array and token return values from Python callbacks are "
          "supported, got %s",
          result_shapes[i].ToString());
    }
  }
  return callback_results;
}

}  // namespace

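// Wraps a Python callable in a HostCallback driven by XLA host send/recv:
// operands arrive over send_channel_ids, results are returned over
// recv_channel_ids, and the returned capsule owns the HostCallback so the
// caller can keep it alive alongside the executable.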
StatusOr<pybind11::object> PyClient::MakePythonCallbackUsingHostSendAndRecv(
    pybind11::function callable, absl::Span<Shape const> operand_shapes,
    absl::Span<Shape const> result_shapes,
    absl::Span<uint16_t const> send_channel_ids,
    absl::Span<uint16_t const> recv_channel_ids) {
  static_assert(sizeof(uintptr_t) == sizeof(uint64_t),
                "Expected 64-bit pointers");

  TF_ASSIGN_OR_RETURN(auto callback_args, CreateCallbackArgs(operand_shapes));
  TF_ASSIGN_OR_RETURN(auto callback_results,
                      CreateCallbackResults(result_shapes));

  auto callback = std::make_shared<CpuCallback>(
      std::move(callable), callback_args, callback_results);

  auto* host_callback = new HostCallback;

  auto assign_arg_info = [](absl::Span<Shape const> shapes,
                            absl::Span<uint16_t const> channel_ids,
                            std::vector<HostCallbackArgInfo>& arg_infos) {
    DCHECK_EQ(shapes.size(), channel_ids.size());
    arg_infos.reserve(shapes.size());
    for (int i = 0; i < shapes.size(); ++i) {
      HostCallbackArgInfo host_callback_arg_info;
      host_callback_arg_info.channel_id = channel_ids[i];
      const auto& shape = shapes[i];
      Shape layout =
          (shape.has_layout() ? shape
                              : LayoutUtil::GetWithDefaultLayout(shape));
      host_callback_arg_info.shape = layout;
      arg_infos.push_back(std::move(host_callback_arg_info));
    }
  };

  assign_arg_info(operand_shapes, send_channel_ids, host_callback->operands);
  assign_arg_info(result_shapes, recv_channel_ids, host_callback->results);

  host_callback->callback = [callback = std::move(callback)](void** outputs,
                                                             void** inputs) {
    return callback->PrepareAndCall(outputs, inputs);
  };

  py::capsule callback_capsule(
      host_callback, [](void* ptr) { delete static_cast<HostCallback*>(ptr); });

  return callback_capsule;
}

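// Registers `callable` as a CpuCallback and returns its address as the
// descriptor that the emitted custom call passes back at runtime, together
// with a capsule that owns the callback. Only CPU and GPU platforms are
// supported.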
StatusOr<std::pair<uint64_t, pybind11::object>>
PyClient::GetEmitPythonCallbackDescriptor(
    pybind11::function callable, absl::Span<Shape const> operand_shapes,
    absl::Span<Shape const> result_shapes) {
  PjRtPlatformId platform_id = pjrt_client_->platform_id();
  if (platform_id != GpuId() && platform_id != CpuId()) {
    return Unimplemented(
        "EmitPythonCallback is only implemented on CPU and GPU");
  }

  static_assert(sizeof(uintptr_t) == sizeof(uint64_t),
                "Expected 64-bit pointers");

  TF_ASSIGN_OR_RETURN(auto callback_args, CreateCallbackArgs(operand_shapes));
  TF_ASSIGN_OR_RETURN(auto callback_results,
                      CreateCallbackResults(result_shapes));

  auto callback = std::make_unique<CpuCallback>(
      std::move(callable), callback_args, callback_results);
  uint64_t descriptor = absl::bit_cast<std::uint64_t>(callback.get());

  py::capsule callback_capsule(callback.release(), [](void* ptr) {
    delete reinterpret_cast<CpuCallback*>(ptr);
  });
  return std::make_pair(descriptor, py::object(std::move(callback_capsule)));
}

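// Emits the custom call that invokes a callback registered via
// GetEmitPythonCallbackDescriptor. The descriptor is prepended as an extra
// u64 operand, operand/result layouts are validated or defaulted, and the
// call targets "xla_python_cpu_callback" or "xla_python_gpu_callback"
// depending on the platform.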
StatusOr<XlaOp> PyClient::EmitPythonCallbackFromDescriptor(
    XlaBuilder& builder, uint64_t descriptor, absl::Span<XlaOp const> operands,
    absl::Span<Shape const> result_shapes,
    std::optional<std::vector<Shape>> operand_layouts, bool has_side_effect) {
  std::vector<Shape> custom_call_arg_layouts(operands.size() + 1);
  custom_call_arg_layouts[0] =
      ShapeUtil::MakeShapeWithDescendingLayout(U64, {});
  std::vector<XlaOp> custom_call_args(operands.size() + 1);
  custom_call_args[0] = ConstantR0<std::uint64_t>(&builder, descriptor);
  absl::c_copy(operands, custom_call_args.begin() + 1);

  if (operand_layouts && operand_layouts->size() != operands.size()) {
    return InvalidArgument(
        "Mismatched number of operands (%d) and operand_layouts (%d)",
        operands.size(), operand_layouts->size());
  }

  for (int i = 0; i < operands.size(); ++i) {
    TF_ASSIGN_OR_RETURN(Shape shape, builder.GetShape(operands[i]));
    Shape layout = LayoutUtil::GetWithDefaultLayout(shape);
    if (shape.IsArray() && operand_layouts) {
      if (!(*operand_layouts)[i].has_layout()) {
        return InvalidArgument(
            "operand_layout shapes for callback must have "
            "layouts, got %s",
            (*operand_layouts)[i].ToString(/*print_layout=*/true));
      }
      if (!ShapeUtil::Compatible(shape, (*operand_layouts)[i])) {
        return InvalidArgument(
            "Incompatible shapes for Python callback argument %d: %s vs %s", i,
            shape.ToString(),
            (*operand_layouts)[i].ToString(/*print_layout=*/true));
      }
      layout = (*operand_layouts)[i];
    }
    custom_call_arg_layouts[i + 1] = layout;
  }

  std::vector<Shape> result_shapes_with_layout(result_shapes.size());
  for (int i = 0; i < result_shapes.size(); ++i) {
    if (result_shapes[i].IsArray()) {
      result_shapes_with_layout[i] =
          result_shapes[i].has_layout()
              ? result_shapes[i]
              : LayoutUtil::GetWithDefaultLayout(result_shapes[i]);
    } else if (result_shapes[i].IsToken()) {
      result_shapes_with_layout[i] = result_shapes[i];
    } else {
      return InvalidArgument(
          "Only array and token return values from Python callbacks are "
          "supported, got %s",
          result_shapes[i].ToString());
    }
  }
  custom_call_args[0] = ConstantR0<std::uint64_t>(&builder, descriptor);
  Shape result_shape = ShapeUtil::MakeTupleShape(result_shapes_with_layout);
  std::string callback_str = std::to_string(descriptor);
  std::string callback_name = "xla_python_cpu_callback";
  if (pjrt_client_->platform_id() == GpuId()) {
    callback_name = "xla_python_gpu_callback";
  }
  XlaOp result =
      CustomCallWithLayout(&builder, callback_name, custom_call_args,
                           result_shape, custom_call_arg_layouts,
                           /*opaque=*/callback_str.data(), has_side_effect,
                           /*output_operand_aliasing=*/{},
                           /*literal=*/nullptr,
                           /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                           /*api_version=*/API_VERSION_STATUS_RETURNING);
  return result;
}

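// Convenience wrapper that registers `callable` and emits the corresponding
// custom call in one step. The returned keepalive object owns the underlying
// CpuCallback and should stay alive for as long as the built computation may
// run.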
StatusOr<std::pair<XlaOp, pybind11::object>> PyClient::EmitPythonCallback(
    pybind11::function callable, XlaBuilder& builder,
    absl::Span<XlaOp const> operands, absl::Span<Shape const> result_shapes,
    std::optional<std::vector<Shape>> operand_layouts, bool has_side_effect) {
  std::vector<Shape> operand_shapes(operands.size());
  for (int i = 0; i < operands.size(); ++i) {
    TF_ASSIGN_OR_RETURN(Shape shape, builder.GetShape(operands[i]));
    operand_shapes[i] =
        (operand_layouts ? (*operand_layouts)[i]
                         : LayoutUtil::GetWithDefaultLayout(shape));
  }
  StatusOr<std::pair<uint64_t, pybind11::object>> result_sor =
      GetEmitPythonCallbackDescriptor(callable, operand_shapes, result_shapes);
  TF_ASSIGN_OR_RETURN(auto result, result_sor);
  uint64_t descriptor = result.first;
  pybind11::object keepalive = result.second;
  TF_ASSIGN_OR_RETURN(XlaOp callback_op,
                      EmitPythonCallbackFromDescriptor(
                          builder, descriptor, operands, result_shapes,
                          operand_shapes, has_side_effect));
  return std::make_pair(callback_op, keepalive);
}

XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback",
                                             &XlaPythonCpuCallback);

#if TENSORFLOW_USE_ROCM
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_gpu_callback",
                                         &XlaPythonGpuCallback, "ROCM");
#elif defined(GOOGLE_CUDA)
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_gpu_callback",
                                         &XlaPythonGpuCallback, "CUDA");
#endif  // TENSORFLOW_USE_ROCM

}  // namespace xla