xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/xla.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstdint>
17 #include <functional>
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/strings/str_format.h"
23 #include "absl/strings/str_join.h"
24 #include "absl/types/span.h"
25 #include "pybind11/attr.h"
26 #include "pybind11/cast.h"
27 #include "pybind11/detail/common.h"
28 #include "pybind11/numpy.h"
29 #include "pybind11/pybind11.h"
30 #include "pybind11/pytypes.h"
31 #include "pybind11/stl_bind.h"
32 #include "tensorflow/compiler/xla/layout_util.h"
33 #include "tensorflow/compiler/xla/pjrt/cpu_device.h"
34 #include "tensorflow/compiler/xla/pjrt/distributed/client.h"
35 #include "tensorflow/compiler/xla/pjrt/distributed/distributed.h"
36 #include "tensorflow/compiler/xla/pjrt/distributed/service.h"
37 #include "tensorflow/core/distributed_runtime/preemption/preemption_sync_manager.h"
38 #ifdef XLA_PYTHON_ENABLE_GPU
39 #include "tensorflow/compiler/xla/pjrt/gpu_device.h"
40 #endif  // XLA_PYTHON_ENABLE_GPU
41 #include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
42 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
43 #include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
44 #ifdef XLA_PYTHON_ENABLE_TPU
45 #include "tensorflow/compiler/xla/pjrt/tpu_client.h"
46 #endif  // XLA_PYTHON_ENABLE_TPU
47 #include "tensorflow/compiler/xla/python/dlpack.h"
48 #include "tensorflow/compiler/xla/python/jax_jit.h"
49 #include "tensorflow/compiler/xla/python/mlir.h"
50 #include "tensorflow/compiler/xla/python/ops.h"
51 #include "tensorflow/compiler/xla/python/outfeed_receiver_py.h"
52 #include "tensorflow/compiler/xla/python/pmap_lib.h"
53 #include "tensorflow/compiler/xla/python/pprof_profile_builder.h"
54 #include "tensorflow/compiler/xla/python/profiler.h"
55 #include "tensorflow/compiler/xla/python/py_buffer.h"
56 #include "tensorflow/compiler/xla/python/py_executable.h"
57 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
58 #include "tensorflow/compiler/xla/python/pytree.h"
59 #include "tensorflow/compiler/xla/python/traceback.h"
60 #include "tensorflow/compiler/xla/python/transfer_guard_lib.h"
61 #include "tensorflow/compiler/xla/python/types.h"
62 #include "tensorflow/compiler/xla/python/weakref_lru_cache.h"
63 #include "tensorflow/compiler/xla/python/xla_compiler.h"
64 #include "tensorflow/compiler/xla/shape.h"
65 #include "tensorflow/compiler/xla/shape_util.h"
66 #include "tensorflow/compiler/xla/statusor.h"
67 #include "tensorflow/compiler/xla/util.h"
68 #include "tensorflow/python/lib/core/bfloat16.h"
69 
// TODO(phawkins): remove host_id properties after JAX is updated to avoid them.
71 
72 namespace xla {
73 namespace {
74 
75 namespace py = pybind11;
76 
// Returns true iff this extension was compiled with NDEBUG defined, i.e. with
// assertions disabled (the conventional marker of an optimized build).
//
// Uses `#ifdef` rather than `#if NDEBUG`: <cassert> keys off whether NDEBUG is
// *defined*, whereas `#if NDEBUG` is ill-formed when NDEBUG is defined with an
// empty value (`#define NDEBUG`) and reports a debug build for `-DNDEBUG=0`.
constexpr bool IsOptimizedBuild() {
#ifdef NDEBUG
  return true;
#else
  return false;
#endif  // NDEBUG
}
84 
85 }  // namespace
86 
PYBIND11_MODULE(xla_extension,m)87 PYBIND11_MODULE(xla_extension, m) {
88   CHECK(tensorflow::RegisterNumpyBfloat16());
89 
90   // Exceptions
91   py::register_exception<XlaRuntimeError>(m, "XlaRuntimeError",
92                                           PyExc_RuntimeError);
93 
94   // Types
95   py::enum_<PrimitiveType>(m, "PrimitiveType")
96       .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID)
97       .value("PRED", PRED)
98       .value("S8", S8)
99       .value("S16", S16)
100       .value("S32", S32)
101       .value("S64", S64)
102       .value("U8", U8)
103       .value("U16", U16)
104       .value("U32", U32)
105       .value("U64", U64)
106       .value("F16", F16)
107       .value("BF16", BF16)
108       .value("F32", F32)
109       .value("F64", F64)
110       .value("C64", C64)
111       .value("C128", C128)
112       .value("TUPLE", TUPLE)
113       .value("OPAQUE_TYPE", OPAQUE_TYPE)
114       .value("TOKEN", TOKEN);
115 
116   m.def("bfloat16_dtype",
117         []() { return py::handle(tensorflow::Bfloat16Dtype()); });
118 
119   // Must be before PyClient.compile.
120   BuildXlaCompilerSubmodule(m);
121 
122   py::class_<PjRtDevice, ClientAndPtr<PjRtDevice>>(
123       m, "Device",
124       "A descriptor of an available device.\n\nSubclasses are used to "
125       "represent specific types of devices, e.g. CPUs, GPUs. Subclasses may "
126       "have additional properties specific to that device type.")
127       .def_property_readonly(
128           "id", &PjRtDevice::id,
129           "Integer ID of this device.\n\nUnique across all available devices "
130           "of this type, including remote devices on multi-host platforms.")
131       .def_property_readonly(
132           "process_index", &PjRtDevice::process_index,
133           "Integer index of this device's process.\n\n"
134           "This is always 0 except on multi-process platforms.")
135       .def_property_readonly("host_id", &PjRtDevice::process_index,
136                              "Deprecated; please use process_index")
137       .def_property_readonly("task_id", &PjRtDevice::process_index,
138                              "Deprecated; please use process_index")
139       .def_property_readonly("platform",
140                              [](const PjRtDevice& device) {
141                                return device.client()->platform_name();
142                              })
143       .def_property_readonly("device_kind", &PjRtDevice::device_kind)
144       .def_property_readonly(
145           "client",
146           [](const ClientAndPtr<PjRtDevice>& device) { return device.client; })
147       .def("__str__", &PjRtDevice::DebugString)
148       .def("__repr__", &PjRtDevice::ToString)
149       .def("transfer_to_infeed",
150            [](PjRtDevice& device, const LiteralSlice& literal) {
151              GlobalPyRefManager()->CollectGarbage();
152              py::gil_scoped_release gil_release;
153              return device.TransferToInfeed(literal);
154            })
155       .def("transfer_from_outfeed",
156            [](PjRtDevice& device, const Shape& shape) -> StatusOr<py::object> {
157              GlobalPyRefManager()->CollectGarbage();
158              std::shared_ptr<Literal> literal;
159              {
160                py::gil_scoped_release gil_release;
161                Shape shape_with_layout = shape;
162                ShapeUtil::ForEachMutableSubshape(
163                    &shape_with_layout, [](Shape* subshape, const ShapeIndex&) {
164                      if (!subshape->has_layout()) {
165                        LayoutUtil::SetToDefaultLayout(subshape);
166                      }
167                    });
168                literal = std::make_shared<Literal>(shape_with_layout);
169                TF_RETURN_IF_ERROR(device.TransferFromOutfeed(literal.get()));
170              }
171              return LiteralToPython(std::move(literal));
172            })
173       .def("live_buffers",
174            [](const ClientAndPtr<PjRtDevice>& device) {
175              return device.client->LiveBuffersOnDevice(device.get());
176            })
177       .def(
178           "__getattr__",
179           [](PjRtDevice& device, std::string name) -> py::object {
180             const auto& attrs = device.Attributes();
181             auto it = attrs.find(name);
182             if (it != attrs.end()) {
183               return std::visit([](auto&& v) { return py::cast(v); },
184                                 it->second);
185             }
186             throw py::attribute_error(absl::StrCat("Unknown attribute ", name));
187           });
188 
189   // Local XLA client methods.
190 
191   py::enum_<PjRtClient::HostBufferSemantics>(m, "HostBufferSemantics")
192       .value("IMMUTABLE_ONLY_DURING_CALL",
193              PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall)
194       .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
195              PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes)
196       .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kZeroCopy);
197 
198   jax::BuildWeakrefLRUCacheAPI(m);
199 
200   py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
201   py_local_client.def_property_readonly("platform", &PyClient::platform_name)
202       .def_property_readonly("platform_version", &PyClient::platform_version)
203       .def_property_readonly("runtime_type", &PyClient::runtime_type)
204       .def("device_count", &PyClient::device_count)
205       .def("local_device_count", &PyClient::addressable_device_count)
206       .def("devices", &PyClient::Devices)
207       .def("local_devices", &PyClient::LocalDevices)
208       .def("live_buffers", &PyClient::LiveBuffers)
209       .def("live_executables", &PyClient::LiveExecutables)
210       .def("process_index", &PyClient::process_index)
211       .def("host_id", &PyClient::process_index)
212       .def("task_id", &PyClient::process_index)
213       .def("get_default_device_assignment",
214            &PyClient::GetDefaultDeviceAssignment)
215       // TODO(skye): delete after all callers can handle 2D output
216       .def("get_default_device_assignment",
217            &PyClient::GetDefaultDeviceAssignment1D)
218       .def("create_channel_handle", &PyClient::CreateChannelHandle)
219       .def("create_device_to_host_channel_handle",
220            &PyClient::CreateDeviceToHostChannelHandle)
221       .def("create_host_to_device_channel_handle",
222            &PyClient::CreateHostToDeviceChannelHandle)
223       .def("buffer_from_pyval", &PyClient::BufferFromPyval, py::arg("argument"),
224            py::arg("device") = nullptr, py::arg("force_copy") = false,
225            py::arg("host_buffer_semantics") =
226                PjRtClient::HostBufferSemantics::kZeroCopy)
227       .def("make_cross_host_receive_buffers",
228            &PyClient::MakeCrossHostReceiveBuffers, py::arg("shapes"),
229            py::arg("device"))
230       .def("compile", &PyClient::Compile, py::arg("computation"),
231            py::arg("compile_options") = CompileOptions(),
232            py::arg("host_callbacks") = std::vector<py::capsule>())
233       .def("compile", &PyClient::CompileMlir, py::arg("computation"),
234            py::arg("compile_options") = CompileOptions(),
235            py::arg("host_callbacks") = std::vector<py::capsule>())
236       .def("serialize_executable", &PyClient::SerializeExecutable)
237       .def("deserialize_executable",
238            py::overload_cast<const std::string&, CompileOptions,
239                              std::vector<py::capsule>>(
240                &PyClient::DeserializeExecutable),
241            py::arg("serialized"), py::arg("compile_options"),
242            py::arg("host_callbacks") = std::vector<py::capsule>())
243       // TODO(skyewm): remove when jax stop providing hlo_module
244       .def("deserialize_executable",
245            py::overload_cast<const std::string&, std::shared_ptr<HloModule>,
246                              CompileOptions, std::vector<py::capsule>>(
247                &PyClient::DeserializeExecutable),
248            py::arg("serialized"), py::arg("hlo_module"),
249            py::arg("compile_options"),
250            py::arg("host_callbacks") = std::vector<py::capsule>())
251       .def("heap_profile", &PyClient::HeapProfile)
252       // TODO(zhangqiaorjc): Experimental.
253       .def("defragment", &PyClient::Defragment)
254       .def("get_emit_python_callback_descriptor",
255            &PyClient::GetEmitPythonCallbackDescriptor, py::arg("callable"),
256            py::arg("operand_shapes"), py::arg("result_shapes") = std::nullopt)
257       .def("make_python_callback_from_host_send_and_recv",
258            &PyClient::MakePythonCallbackUsingHostSendAndRecv,
259            py::arg("callable"), py::arg("operand_shapes"),
260            py::arg("result_shapes"), py::arg("send_channel_ids"),
261            py::arg("recv_channel_ids"))
262       // Deprecated: please use `get_emit_python_callback_descriptor` instead.
263       .def("emit_python_callback", &PyClient::EmitPythonCallback,
264            py::arg("callable"), py::arg("builder"), py::arg("operands"),
265            py::arg("result_shapes"), py::arg("operand_layouts") = std::nullopt,
266            py::arg("has_side_effects") = false);
267 
268   m.def(
269       "get_cpu_client",
270       [](bool asynchronous) -> StatusOr<std::shared_ptr<PyClient>> {
271         py::gil_scoped_release gil_release;
272         TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
273                             GetCpuClient(asynchronous));
274         return std::make_shared<PyClient>(std::move(client));
275       },
276       py::arg("asynchronous") = true);
277   m.def(
278       "get_tfrt_cpu_client",
279       [](bool asynchronous) -> StatusOr<std::shared_ptr<PyClient>> {
280         py::gil_scoped_release gil_release;
281         TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
282                             GetTfrtCpuClient(asynchronous));
283         return std::make_shared<PyClient>(std::move(client));
284       },
285       py::arg("asynchronous") = true);
286   m.def("get_interpreter_client", []() -> StatusOr<std::shared_ptr<PyClient>> {
287     py::gil_scoped_release gil_release;
288     TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
289                         GetInterpreterClient());
290     return std::make_shared<PyClient>(std::move(client));
291   });
292 
293 #ifdef XLA_PYTHON_ENABLE_GPU
294   py::class_<GpuAllocatorConfig> alloc_config(m, "GpuAllocatorConfig");
295   alloc_config.def(py::init<>())
296       .def_readwrite("kind", &GpuAllocatorConfig::kind)
297       .def_readwrite("memory_fraction", &GpuAllocatorConfig::memory_fraction)
298       .def_readwrite("preallocate", &GpuAllocatorConfig::preallocate);
299   py::enum_<GpuAllocatorConfig::Kind>(alloc_config, "Kind")
300       .value("DEFAULT", GpuAllocatorConfig::Kind::kDefault)
301       .value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
302       .value("BFC", GpuAllocatorConfig::Kind::kBFC)
303       .value("CUDA_ASYNC", GpuAllocatorConfig::Kind::kCudaAsync);
304 
305   // TODO(tomhennigan): Remove this types.
306   py::class_<GpuDevice, PjRtDevice, ClientAndPtr<GpuDevice>> gpu_device(
307       m, "GpuDevice");
308   m.def(
309       "get_gpu_client",
310       [](bool asynchronous, const GpuAllocatorConfig& allocator_config,
311          std::shared_ptr<DistributedRuntimeClient> distributed_client,
312          int node_id, std::optional<std::set<int>> allowed_devices,
313          std::optional<std::string> platform_name)
314           -> StatusOr<std::shared_ptr<PyClient>> {
315         py::gil_scoped_release gil_release;
316         TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
317                             GetGpuClient(asynchronous, allocator_config,
318                                          std::move(distributed_client), node_id,
319                                          allowed_devices, platform_name));
320         return std::make_shared<PyClient>(std::move(client));
321       },
322       py::arg("asynchronous") = true,
323       py::arg("allocator_config") = GpuAllocatorConfig(),
324       py::arg("distributed_client") = nullptr, py::arg("node_id") = 0,
325       py::arg("allowed_devices") = std::nullopt,
326       py::arg("platform_name") = std::nullopt);
327 #endif  // XLA_PYTHON_ENABLE_GPU
328 
329 #ifdef XLA_PYTHON_ENABLE_TPU
330   // TODO(tomhennigan): Remove this types.
331   py::class_<PjRtTpuDevice, PjRtDevice, ClientAndPtr<PjRtTpuDevice>> tpu_device(
332       m, "TpuDevice");
333   m.def(
334       "get_tpu_client",
335       [](int max_inflight_computations) -> StatusOr<std::shared_ptr<PyClient>> {
336         py::gil_scoped_release gil_release;
337         TF_ASSIGN_OR_RETURN(std::shared_ptr<PjRtClient> client,
338                             GetTpuClient(max_inflight_computations));
339         return std::make_shared<PyClient>(std::move(client));
340       },
341       py::arg("max_inflight_computations") = 32);
342 #endif  // XLA_PYTHON_ENABLE_TPU
343 
344   TF_CHECK_OK(PyBuffer::RegisterTypes(m));
345 
346   py::class_<CompiledMemoryStats>(m, "CompiledMemoryStats")
347       .def_readwrite("generated_code_size_in_bytes",
348                      &CompiledMemoryStats::generated_code_size_in_bytes)
349       .def_readwrite("argument_size_in_bytes",
350                      &CompiledMemoryStats::argument_size_in_bytes)
351       .def_readwrite("output_size_in_bytes",
352                      &CompiledMemoryStats::output_size_in_bytes)
353       .def_readwrite("alias_size_in_bytes",
354                      &CompiledMemoryStats::alias_size_in_bytes)
355       .def_readwrite("temp_size_in_bytes",
356                      &CompiledMemoryStats::temp_size_in_bytes)
357       .def("__str__", &CompiledMemoryStats::DebugString);
358 
359   py::class_<PyExecutable, std::shared_ptr<PyExecutable>> executable(
360       m, "Executable");
361   executable.def_property_readonly("client", &PyExecutable::client)
362       .def("local_logical_device_ids",
363            [](PyExecutable* exec) {
364              auto span = exec->addressable_device_logical_ids();
365              // Not on dispatch critical path, so ok to have heap allocation.
366              std::vector<std::pair<int, int>> addressable_device_logic_ids;
367              addressable_device_logic_ids.reserve(span.size());
368              for (const auto& logical_device_id : span) {
369                addressable_device_logic_ids.push_back(std::make_pair(
370                    logical_device_id.replica, logical_device_id.partition));
371              }
372            })
373       .def("local_devices", &PyExecutable::AddressableDevices)
374       .def("size_of_generated_code_in_bytes",
375            &PyExecutable::SizeOfGeneratedCodeInBytes)
376       .def("get_compiled_memory_stats", &PyExecutable::GetCompiledMemoryStats)
377       .def("delete", &PyExecutable::Delete)
378       .def("execute", &PyExecutable::Execute, py::arg("arguments"),
379            py::arg("device") = std::nullopt)
380       // TODO(chky): Change execute() to always return token rather than hanving
381       // two API entry points.
382       .def("execute_with_token", &PyExecutable::ExecuteWithToken,
383            py::arg("arguments"), py::arg("device") = std::nullopt)
384       .def("execute_sharded_on_local_devices",
385            &PyExecutable::ExecuteShardedOnLocalDevices, py::arg("arguments"))
386       .def("execute_sharded_on_local_devices_with_tokens",
387            &PyExecutable::ExecuteShardedOnLocalDevicesWithTokens,
388            py::arg("arguments"))
389       .def("hlo_modules", &PyExecutable::HloModules)
390       .def("keep_alive", &PyExecutable::KeepAlive)
391       .def_property_readonly("traceback", &PyExecutable::traceback)
392       .def_property_readonly("fingerprint",
393                              [](PyExecutable* exec) -> py::object {
394                                if (exec->fingerprint().has_value()) {
395                                  return py::bytes(*exec->fingerprint());
396                                } else {
397                                  return py::none();
398                                }
399                              });
400   py::class_<PyToken> token(m, "Token");
401   token.def("block_until_ready", &PyToken::Await);
402   py::class_<PyShardedToken> sharded_token(m, "ShardedToken");
403   sharded_token.def("block_until_ready", &PyShardedToken::Await);
404   sharded_token.def("get_token", &PyShardedToken::GetPyToken);
405 
406   m.def("buffer_to_dlpack_managed_tensor", BufferToDLPackManagedTensor,
407         py::arg("buffer"), py::arg("take_ownership") = true);
408   m.def("dlpack_managed_tensor_to_buffer", DLPackManagedTensorToBuffer,
409         py::arg("dlpack"), py::arg("cpu_backend") = nullptr,
410         py::arg("gpu_backend") = nullptr);
411 
412   BuildProfilerSubmodule(&m);
413   BuildOpsSubmodule(&m);
414   BuildOutfeedReceiverSubmodule(&m);
415   BuildPytreeSubmodule(m);
416   jax::BuildJaxjitSubmodule(m);
417   jax::BuildPmapSubmodule(m);
418   jax::BuildTransferGuardSubmodule(m);
419   BuildTracebackSubmodule(m);
420   BuildMlirSubmodule(m);
421 
422   py::class_<tensorflow::PreemptionSyncManager,
423              std::unique_ptr<tensorflow::PreemptionSyncManager>>
424       preemption_sync_manager(m, "PreemptionSyncManager");
425   preemption_sync_manager
426       .def(
427           "initialize",
428           [](tensorflow::PreemptionSyncManager& manager,
429              DistributedRuntimeClient* client) { manager.Initialize(client); },
430           py::arg("distributed_client"))
431       .def("reached_sync_point",
432            [](tensorflow::PreemptionSyncManager& manager, int step_counter) {
433              return manager.ReachedSyncPoint(step_counter);
434            });
435   m.def("create_preemption_sync_manager",
436         []() { return tensorflow::CreatePreemptionSyncManager(); });
437 
438   py::class_<DistributedRuntimeService,
439              std::unique_ptr<DistributedRuntimeService>>
440       distributed_runtime_service(m, "DistributedRuntimeService");
441   distributed_runtime_service.def("shutdown",
442                                   &DistributedRuntimeService::Shutdown,
443                                   py::call_guard<py::gil_scoped_release>());
444   py::class_<DistributedRuntimeClient,
445              std::shared_ptr<DistributedRuntimeClient>>
446       distributed_runtime_client(m, "DistributedRuntimeClient");
447   distributed_runtime_client
448       .def("connect", &DistributedRuntimeClient::Connect,
449            py::call_guard<py::gil_scoped_release>())
450       .def("shutdown", &DistributedRuntimeClient::Shutdown,
451            py::call_guard<py::gil_scoped_release>())
452       .def(
453           "blocking_key_value_get",
454           [](DistributedRuntimeClient& client, std::string key,
455              int64_t timeout_in_ms) {
456             py::gil_scoped_release gil_release;
457             return client.BlockingKeyValueGet(
458                 key, absl::Milliseconds(timeout_in_ms));
459           },
460           py::arg("key"), py::arg("timeout_in_ms"))
461       .def(
462           "wait_at_barrier",
463           [](DistributedRuntimeClient& client, std::string barrier_id,
464              int64_t timeout_in_ms) {
465             py::gil_scoped_release gil_release;
466             return client.WaitAtBarrier(barrier_id,
467                                         absl::Milliseconds(timeout_in_ms));
468           },
469           py::arg("barrier_id"), py::arg("timeout_in_ms"))
470       .def(
471           "key_value_set",
472           [](DistributedRuntimeClient& client, std::string key,
473              std::string value) {
474             py::gil_scoped_release gil_release;
475             return client.KeyValueSet(key, value);
476           },
477           py::arg("key"), py::arg("value"));
478 
479   m.def(
480       "get_distributed_runtime_service",
481       [](std::string address, int num_nodes, bool use_coordination_service,
482          std::optional<int> heartbeat_interval,
483          std::optional<int> max_missing_heartbeats,
484          std::optional<int> enumerate_devices_timeout,
485          std::optional<int> shutdown_timeout)
486           -> StatusOr<std::unique_ptr<DistributedRuntimeService>> {
487         DistributedRuntimeServiceImpl::Options options;
488         options.num_nodes = num_nodes;
489         if (heartbeat_interval.has_value()) {
490           options.heartbeat_interval = absl::Seconds(*heartbeat_interval);
491         }
492         if (max_missing_heartbeats.has_value()) {
493           options.max_missing_heartbeats = *max_missing_heartbeats;
494         }
495         if (enumerate_devices_timeout.has_value()) {
496           options.enumerate_devices_timeout =
497               absl::Seconds(*enumerate_devices_timeout);
498         }
499         if (shutdown_timeout.has_value()) {
500           options.shutdown_timeout = absl::Seconds(*shutdown_timeout);
501         }
502         TF_ASSIGN_OR_RETURN(std::unique_ptr<DistributedRuntimeService> service,
503                             GetDistributedRuntimeService(
504                                 address, options, use_coordination_service));
505         return service;
506       },
507       py::arg("address"), py::arg("num_nodes"),
508       py::arg("use_coordination_service"), py::kw_only(),
509       py::arg("heartbeat_interval") = std::nullopt,
510       py::arg("max_missing_heartbeats") = std::nullopt,
511       py::arg("enumerate_devices_timeout") = std::nullopt,
512       py::arg("shutdown_timeout") = std::nullopt);
513 
514   m.def(
515       "get_distributed_runtime_client",
516       [](std::string address, int node_id, bool use_coordination_service,
517          std::optional<int> rpc_timeout, std::optional<int> init_timeout,
518          std::optional<int> shutdown_timeout,
519          std::optional<int> heartbeat_interval,
520          std::optional<int> max_missing_heartbeats,
521          std::optional<std::function<void(xla::Status,
522                                           bool coordinator_reported_failure)>>
523              missed_heartbeat_callback,
524          std::optional<bool> shutdown_on_destruction)
525           -> StatusOr<std::shared_ptr<DistributedRuntimeClient>> {
526         DistributedRuntimeClient::Options options;
527         options.node_id = node_id;
528         if (rpc_timeout.has_value()) {
529           options.rpc_timeout = absl::Seconds(*rpc_timeout);
530         }
531         if (init_timeout.has_value()) {
532           options.init_timeout = absl::Seconds(*init_timeout);
533         }
534         if (shutdown_timeout.has_value()) {
535           options.shutdown_timeout = absl::Seconds(*shutdown_timeout);
536         }
537         if (heartbeat_interval.has_value()) {
538           options.heartbeat_interval = absl::Seconds(*heartbeat_interval);
539         }
540         if (max_missing_heartbeats.has_value()) {
541           options.max_missing_heartbeats = *max_missing_heartbeats;
542         }
543         if (missed_heartbeat_callback.has_value()) {
544           options.missed_heartbeat_callback =
545               std::move(*missed_heartbeat_callback);
546         }
547         if (shutdown_on_destruction.has_value()) {
548           options.shutdown_on_destruction = *shutdown_on_destruction;
549         }
550         return GetDistributedRuntimeClient(address, options,
551                                            use_coordination_service);
552       },
553       py::arg("address"), py::arg("node_id"),
554       py::arg("use_coordination_service"), py::kw_only(),
555       py::arg("rpc_timeout") = std::nullopt,
556       py::arg("init_timeout") = std::nullopt,
557       py::arg("shutdown_timeout") = std::nullopt,
558       py::arg("heartbeat_interval") = std::nullopt,
559       py::arg("max_missing_heartbeats") = std::nullopt,
560       py::arg("missed_heartbeat_callback") = std::nullopt,
561       py::arg("shutdown_on_destruction") = std::nullopt);
562 
563   m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); });
564 
565   m.def("is_optimized_build", &IsOptimizedBuild);
566 
567   m.def("json_to_pprof_profile", &JsonToPprofProfile,
568         "Encodes the JSON representation of a pprof Profile into its binary "
569         "protocol buffer encoding.");
570   m.def("pprof_profile_to_json", &PprofProfileToJson,
571         "Decodes an uncompressed pprof Profile protocol buffer into a JSON "
572         "representation");
573 }  // NOLINT(readability/fn_size)
574 
575 }  // namespace xla
576