/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "pybind11/attr.h"
#include "pybind11/cast.h"
#include "pybind11/detail/common.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl_bind.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
#include "tensorflow/compiler/xla/pjrt/distributed/distributed.h"
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
#include "tensorflow/core/distributed_runtime/preemption/preemption_sync_manager.h"
#ifdef XLA_PYTHON_ENABLE_GPU
#include "tensorflow/compiler/xla/pjrt/gpu_device.h"
#endif  // XLA_PYTHON_ENABLE_GPU
#include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
#ifdef XLA_PYTHON_ENABLE_TPU
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
#endif  // XLA_PYTHON_ENABLE_TPU
#include "tensorflow/compiler/xla/python/dlpack.h"
#include "tensorflow/compiler/xla/python/jax_jit.h"
#include "tensorflow/compiler/xla/python/mlir.h"
#include "tensorflow/compiler/xla/python/ops.h"
#include "tensorflow/compiler/xla/python/outfeed_receiver_py.h"
#include "tensorflow/compiler/xla/python/pmap_lib.h"
#include "tensorflow/compiler/xla/python/pprof_profile_builder.h"
#include "tensorflow/compiler/xla/python/profiler.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/pytree.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/python/transfer_guard_lib.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/python/weakref_lru_cache.h"
#include "tensorflow/compiler/xla/python/xla_compiler.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/python/lib/core/bfloat16.h"

// TODO(phawkins): remove host_id properties after JAX is updated to avoid them.

namespace xla {
namespace {

namespace py = pybind11;

bool IsOptimizedBuild() {
#if NDEBUG
  return true;
#else
  return false;
#endif  // NDEBUG
}

}  // namespace
PYBIND11_MODULE(xla_extension, m) {
  CHECK(tensorflow::RegisterNumpyBfloat16());

  // Exceptions
  py::register_exception<XlaRuntimeError>(m, "XlaRuntimeError",
                                          PyExc_RuntimeError);

  // Types
  py::enum_<PrimitiveType>(m, "PrimitiveType")
      .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID)
      .value("PRED", PRED)
      .value("S8", S8)
      .value("S16", S16)
      .value("S32", S32)
      .value("S64", S64)
      .value("U8", U8)
      .value("U16", U16)
      .value("U32", U32)
      .value("U64", U64)
      .value("F16", F16)
      .value("BF16", BF16)
      .value("F32", F32)
      .value("F64", F64)
      .value("C64", C64)
      .value("C128", C128)
      .value("TUPLE", TUPLE)
      .value("OPAQUE_TYPE", OPAQUE_TYPE)
      .value("TOKEN", TOKEN);

  m.def("bfloat16_dtype",
        []() { return py::handle(tensorflow::Bfloat16Dtype()); });
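
  // Illustrative Python-side usage of the type bindings above (a sketch only,
  // not part of this module; the import name `xla_extension` is assumed from
  // the PYBIND11_MODULE declaration and may differ in a packaged build):
  //
  //   import xla_extension
  //   f32 = xla_extension.PrimitiveType.F32   # element-type enum bound above
  //   bf16 = xla_extension.bfloat16_dtype()   # numpy dtype object for bfloat16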

  // Must be before PyClient.compile.
  BuildXlaCompilerSubmodule(m);

  py::class_<PjRtDevice, ClientAndPtr<PjRtDevice>>(
      m, "Device",
      "A descriptor of an available device.\n\nSubclasses are used to "
      "represent specific types of devices, e.g. CPUs, GPUs. Subclasses may "
      "have additional properties specific to that device type.")
      .def_property_readonly(
          "id", &PjRtDevice::id,
          "Integer ID of this device.\n\nUnique across all available devices "
          "of this type, including remote devices on multi-host platforms.")
      .def_property_readonly(
          "process_index", &PjRtDevice::process_index,
          "Integer index of this device's process.\n\n"
          "This is always 0 except on multi-process platforms.")
      .def_property_readonly("host_id", &PjRtDevice::process_index,
                             "Deprecated; please use process_index")
      .def_property_readonly("task_id", &PjRtDevice::process_index,
                             "Deprecated; please use process_index")
      .def_property_readonly("platform",
                             [](const PjRtDevice& device) {
                               return device.client()->platform_name();
                             })
      .def_property_readonly("device_kind", &PjRtDevice::device_kind)
      .def_property_readonly(
          "client",
          [](const ClientAndPtr<PjRtDevice>& device) { return device.client; })
      .def("__str__", &PjRtDevice::DebugString)
      .def("__repr__", &PjRtDevice::ToString)
      .def("transfer_to_infeed",
           [](PjRtDevice& device, const LiteralSlice& literal) {
             GlobalPyRefManager()->CollectGarbage();
             py::gil_scoped_release gil_release;
             return device.TransferToInfeed(literal);
           })
      .def("transfer_from_outfeed",
           [](PjRtDevice& device, const Shape& shape) -> StatusOr<py::object> {
             GlobalPyRefManager()->CollectGarbage();
             std::shared_ptr<Literal> literal;
             {
               py::gil_scoped_release gil_release;
               Shape shape_with_layout = shape;
               ShapeUtil::ForEachMutableSubshape(
                   &shape_with_layout, [](Shape* subshape, const ShapeIndex&) {
                     if (!subshape->has_layout()) {
                       LayoutUtil::SetToDefaultLayout(subshape);
                     }
                   });
               literal = std::make_shared<Literal>(shape_with_layout);
               TF_RETURN_IF_ERROR(device.TransferFromOutfeed(literal.get()));
             }
             return LiteralToPython(std::move(literal));
           })
      .def("live_buffers",
           [](const ClientAndPtr<PjRtDevice>& device) {
             return device.client->LiveBuffersOnDevice(device.get());
           })
      .def(
          "__getattr__",
          [](PjRtDevice& device, std::string name) -> py::object {
            const auto& attrs = device.Attributes();
            auto it = attrs.find(name);
            if (it != attrs.end()) {
              return std::visit([](auto&& v) { return py::cast(v); },
                                it->second);
            }
            throw py::attribute_error(absl::StrCat("Unknown attribute ", name));
          });
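
  // Illustrative Python-side usage of the Device binding above (a sketch only;
  // `client` is assumed to be a Client created via one of the factory
  // functions defined later in this module):
  //
  //   dev = client.local_devices()[0]
  //   print(dev.id, dev.process_index, dev.platform, dev.device_kind)
  //   alive = dev.live_buffers()   # buffers currently alive on this device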

  // Local XLA client methods.

  py::enum_<PjRtClient::HostBufferSemantics>(m, "HostBufferSemantics")
      .value("IMMUTABLE_ONLY_DURING_CALL",
             PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall)
      .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
             PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes)
      .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kZeroCopy);

  jax::BuildWeakrefLRUCacheAPI(m);

  py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
  py_local_client.def_property_readonly("platform", &PyClient::platform_name)
      .def_property_readonly("platform_version", &PyClient::platform_version)
      .def_property_readonly("runtime_type", &PyClient::runtime_type)
      .def("device_count", &PyClient::device_count)
      .def("local_device_count", &PyClient::addressable_device_count)
      .def("devices", &PyClient::Devices)
      .def("local_devices", &PyClient::LocalDevices)
      .def("live_buffers", &PyClient::LiveBuffers)
      .def("live_executables", &PyClient::LiveExecutables)
      .def("process_index", &PyClient::process_index)
      .def("host_id", &PyClient::process_index)
      .def("task_id", &PyClient::process_index)
      .def("get_default_device_assignment",
           &PyClient::GetDefaultDeviceAssignment)
      // TODO(skye): delete after all callers can handle 2D output
      .def("get_default_device_assignment",
           &PyClient::GetDefaultDeviceAssignment1D)
      .def("create_channel_handle", &PyClient::CreateChannelHandle)
      .def("create_device_to_host_channel_handle",
           &PyClient::CreateDeviceToHostChannelHandle)
      .def("create_host_to_device_channel_handle",
           &PyClient::CreateHostToDeviceChannelHandle)
      .def("buffer_from_pyval", &PyClient::BufferFromPyval, py::arg("argument"),
           py::arg("device") = nullptr, py::arg("force_copy") = false,
           py::arg("host_buffer_semantics") =
               PjRtClient::HostBufferSemantics::kZeroCopy)
      .def("make_cross_host_receive_buffers",
           &PyClient::MakeCrossHostReceiveBuffers, py::arg("shapes"),
           py::arg("device"))
      .def("compile", &PyClient::Compile, py::arg("computation"),
           py::arg("compile_options") = CompileOptions(),
           py::arg("host_callbacks") = std::vector<py::capsule>())
      .def("compile", &PyClient::CompileMlir, py::arg("computation"),
           py::arg("compile_options") = CompileOptions(),
           py::arg("host_callbacks") = std::vector<py::capsule>())
      .def("serialize_executable", &PyClient::SerializeExecutable)
      .def("deserialize_executable",
           py::overload_cast<const std::string&, CompileOptions,
                             std::vector<py::capsule>>(
               &PyClient::DeserializeExecutable),
           py::arg("serialized"), py::arg("compile_options"),
           py::arg("host_callbacks") = std::vector<py::capsule>())
      // TODO(skyewm): remove when jax stops providing hlo_module
      .def("deserialize_executable",
           py::overload_cast<const std::string&, std::shared_ptr<HloModule>,
                             CompileOptions, std::vector<py::capsule>>(
               &PyClient::DeserializeExecutable),
           py::arg("serialized"), py::arg("hlo_module"),
           py::arg("compile_options"),
           py::arg("host_callbacks") = std::vector<py::capsule>())
      .def("heap_profile", &PyClient::HeapProfile)
      // TODO(zhangqiaorjc): Experimental.
      .def("defragment", &PyClient::Defragment)
      .def("get_emit_python_callback_descriptor",
           &PyClient::GetEmitPythonCallbackDescriptor, py::arg("callable"),
           py::arg("operand_shapes"), py::arg("result_shapes") = std::nullopt)
      .def("make_python_callback_from_host_send_and_recv",
           &PyClient::MakePythonCallbackUsingHostSendAndRecv,
           py::arg("callable"), py::arg("operand_shapes"),
           py::arg("result_shapes"), py::arg("send_channel_ids"),
           py::arg("recv_channel_ids"))
      // Deprecated: please use `get_emit_python_callback_descriptor` instead.
      .def("emit_python_callback", &PyClient::EmitPythonCallback,
           py::arg("callable"), py::arg("builder"), py::arg("operands"),
           py::arg("result_shapes"), py::arg("operand_layouts") = std::nullopt,
           py::arg("has_side_effects") = false);
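
  // Illustrative Python-side usage of the Client binding above (a sketch only;
  // `client` comes from a factory function below, `computation` is assumed to
  // be an XlaComputation or MLIR module, and `x` a numpy array):
  //
  //   buf = client.buffer_from_pyval(x, device=client.local_devices()[0])
  //   executable = client.compile(computation)
  //   print(client.platform, client.device_count(), client.process_index())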

  m.def(
      "get_cpu_client",
      [](bool asynchronous) -> StatusOr<std::shared_ptr<PyClient>> {
        py::gil_scoped_release gil_release;
        TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
                            GetCpuClient(asynchronous));
        return std::make_shared<PyClient>(std::move(client));
      },
      py::arg("asynchronous") = true);
  m.def(
      "get_tfrt_cpu_client",
      [](bool asynchronous) -> StatusOr<std::shared_ptr<PyClient>> {
        py::gil_scoped_release gil_release;
        TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
                            GetTfrtCpuClient(asynchronous));
        return std::make_shared<PyClient>(std::move(client));
      },
      py::arg("asynchronous") = true);
  m.def("get_interpreter_client", []() -> StatusOr<std::shared_ptr<PyClient>> {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
                        GetInterpreterClient());
    return std::make_shared<PyClient>(std::move(client));
  });
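
  // Illustrative Python-side usage of the CPU client factories above (a sketch
  // only; the TFRT CPU client is the usual choice, while the interpreter
  // client is mainly useful for debugging):
  //
  //   client = xla_extension.get_tfrt_cpu_client(asynchronous=True)
  //   debug_client = xla_extension.get_interpreter_client()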

#ifdef XLA_PYTHON_ENABLE_GPU
  py::class_<GpuAllocatorConfig> alloc_config(m, "GpuAllocatorConfig");
  alloc_config.def(py::init<>())
      .def_readwrite("kind", &GpuAllocatorConfig::kind)
      .def_readwrite("memory_fraction", &GpuAllocatorConfig::memory_fraction)
      .def_readwrite("preallocate", &GpuAllocatorConfig::preallocate);
  py::enum_<GpuAllocatorConfig::Kind>(alloc_config, "Kind")
      .value("DEFAULT", GpuAllocatorConfig::Kind::kDefault)
      .value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
      .value("BFC", GpuAllocatorConfig::Kind::kBFC)
      .value("CUDA_ASYNC", GpuAllocatorConfig::Kind::kCudaAsync);

  // TODO(tomhennigan): Remove these types.
  py::class_<GpuDevice, PjRtDevice, ClientAndPtr<GpuDevice>> gpu_device(
      m, "GpuDevice");
  m.def(
      "get_gpu_client",
      [](bool asynchronous, const GpuAllocatorConfig& allocator_config,
         std::shared_ptr<DistributedRuntimeClient> distributed_client,
         int node_id, std::optional<std::set<int>> allowed_devices,
         std::optional<std::string> platform_name)
          -> StatusOr<std::shared_ptr<PyClient>> {
        py::gil_scoped_release gil_release;
        TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
                            GetGpuClient(asynchronous, allocator_config,
                                         std::move(distributed_client), node_id,
                                         allowed_devices, platform_name));
        return std::make_shared<PyClient>(std::move(client));
      },
      py::arg("asynchronous") = true,
      py::arg("allocator_config") = GpuAllocatorConfig(),
      py::arg("distributed_client") = nullptr, py::arg("node_id") = 0,
      py::arg("allowed_devices") = std::nullopt,
      py::arg("platform_name") = std::nullopt);
#endif  // XLA_PYTHON_ENABLE_GPU

#ifdef XLA_PYTHON_ENABLE_TPU
  // TODO(tomhennigan): Remove these types.
  py::class_<PjRtTpuDevice, PjRtDevice, ClientAndPtr<PjRtTpuDevice>> tpu_device(
      m, "TpuDevice");
  m.def(
      "get_tpu_client",
      [](int max_inflight_computations) -> StatusOr<std::shared_ptr<PyClient>> {
        py::gil_scoped_release gil_release;
        TF_ASSIGN_OR_RETURN(std::shared_ptr<PjRtClient> client,
                            GetTpuClient(max_inflight_computations));
        return std::make_shared<PyClient>(std::move(client));
      },
      py::arg("max_inflight_computations") = 32);
#endif  // XLA_PYTHON_ENABLE_TPU

  TF_CHECK_OK(PyBuffer::RegisterTypes(m));

  py::class_<CompiledMemoryStats>(m, "CompiledMemoryStats")
      .def_readwrite("generated_code_size_in_bytes",
                     &CompiledMemoryStats::generated_code_size_in_bytes)
      .def_readwrite("argument_size_in_bytes",
                     &CompiledMemoryStats::argument_size_in_bytes)
      .def_readwrite("output_size_in_bytes",
                     &CompiledMemoryStats::output_size_in_bytes)
      .def_readwrite("alias_size_in_bytes",
                     &CompiledMemoryStats::alias_size_in_bytes)
      .def_readwrite("temp_size_in_bytes",
                     &CompiledMemoryStats::temp_size_in_bytes)
      .def("__str__", &CompiledMemoryStats::DebugString);

  py::class_<PyExecutable, std::shared_ptr<PyExecutable>> executable(
      m, "Executable");
  executable.def_property_readonly("client", &PyExecutable::client)
      .def("local_logical_device_ids",
           [](PyExecutable* exec) {
             auto span = exec->addressable_device_logical_ids();
             // Not on dispatch critical path, so ok to have heap allocation.
             std::vector<std::pair<int, int>> addressable_device_logic_ids;
             addressable_device_logic_ids.reserve(span.size());
             for (const auto& logical_device_id : span) {
               addressable_device_logic_ids.push_back(std::make_pair(
                   logical_device_id.replica, logical_device_id.partition));
             }
             return addressable_device_logic_ids;
           })
      .def("local_devices", &PyExecutable::AddressableDevices)
      .def("size_of_generated_code_in_bytes",
           &PyExecutable::SizeOfGeneratedCodeInBytes)
      .def("get_compiled_memory_stats", &PyExecutable::GetCompiledMemoryStats)
      .def("delete", &PyExecutable::Delete)
      .def("execute", &PyExecutable::Execute, py::arg("arguments"),
           py::arg("device") = std::nullopt)
      // TODO(chky): Change execute() to always return a token rather than
      // having two API entry points.
      .def("execute_with_token", &PyExecutable::ExecuteWithToken,
           py::arg("arguments"), py::arg("device") = std::nullopt)
      .def("execute_sharded_on_local_devices",
           &PyExecutable::ExecuteShardedOnLocalDevices, py::arg("arguments"))
      .def("execute_sharded_on_local_devices_with_tokens",
           &PyExecutable::ExecuteShardedOnLocalDevicesWithTokens,
           py::arg("arguments"))
      .def("hlo_modules", &PyExecutable::HloModules)
      .def("keep_alive", &PyExecutable::KeepAlive)
      .def_property_readonly("traceback", &PyExecutable::traceback)
      .def_property_readonly("fingerprint",
                             [](PyExecutable* exec) -> py::object {
                               if (exec->fingerprint().has_value()) {
                                 return py::bytes(*exec->fingerprint());
                               } else {
                                 return py::none();
                               }
                             });
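
  // Illustrative Python-side usage of the Executable binding above (a sketch
  // only; `executable` is assumed to come from Client.compile and `args` to be
  // a list of device buffers of the expected shapes):
  //
  //   outputs = executable.execute(args)
  //   stats = executable.get_compiled_memory_stats()
  //   print(executable.size_of_generated_code_in_bytes(), str(stats))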
  py::class_<PyToken> token(m, "Token");
  token.def("block_until_ready", &PyToken::Await);
  py::class_<PyShardedToken> sharded_token(m, "ShardedToken");
  sharded_token.def("block_until_ready", &PyShardedToken::Await);
  sharded_token.def("get_token", &PyShardedToken::GetPyToken);

  m.def("buffer_to_dlpack_managed_tensor", BufferToDLPackManagedTensor,
        py::arg("buffer"), py::arg("take_ownership") = true);
  m.def("dlpack_managed_tensor_to_buffer", DLPackManagedTensorToBuffer,
        py::arg("dlpack"), py::arg("cpu_backend") = nullptr,
        py::arg("gpu_backend") = nullptr);
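
  // Illustrative Python-side round trip through DLPack using the functions
  // above (a sketch only; `buf` is assumed to be a device buffer owned by
  // `client`):
  //
  //   capsule = xla_extension.buffer_to_dlpack_managed_tensor(
  //       buf, take_ownership=True)
  //   buf2 = xla_extension.dlpack_managed_tensor_to_buffer(
  //       capsule, cpu_backend=client)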

  BuildProfilerSubmodule(&m);
  BuildOpsSubmodule(&m);
  BuildOutfeedReceiverSubmodule(&m);
  BuildPytreeSubmodule(m);
  jax::BuildJaxjitSubmodule(m);
  jax::BuildPmapSubmodule(m);
  jax::BuildTransferGuardSubmodule(m);
  BuildTracebackSubmodule(m);
  BuildMlirSubmodule(m);

  py::class_<tensorflow::PreemptionSyncManager,
             std::unique_ptr<tensorflow::PreemptionSyncManager>>
      preemption_sync_manager(m, "PreemptionSyncManager");
  preemption_sync_manager
      .def(
          "initialize",
          [](tensorflow::PreemptionSyncManager& manager,
             DistributedRuntimeClient* client) { manager.Initialize(client); },
          py::arg("distributed_client"))
      .def("reached_sync_point",
           [](tensorflow::PreemptionSyncManager& manager, int step_counter) {
             return manager.ReachedSyncPoint(step_counter);
           });
  m.def("create_preemption_sync_manager",
        []() { return tensorflow::CreatePreemptionSyncManager(); });
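
  // Illustrative Python-side usage of the preemption sync manager bound above
  // (a sketch only; `dist_client` is assumed to be a connected
  // DistributedRuntimeClient and `save_checkpoint` a placeholder for
  // application logic):
  //
  //   mgr = xla_extension.create_preemption_sync_manager()
  //   mgr.initialize(dist_client)
  //   if mgr.reached_sync_point(step):
  //     save_checkpoint()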

  py::class_<DistributedRuntimeService,
             std::unique_ptr<DistributedRuntimeService>>
      distributed_runtime_service(m, "DistributedRuntimeService");
  distributed_runtime_service.def("shutdown",
                                  &DistributedRuntimeService::Shutdown,
                                  py::call_guard<py::gil_scoped_release>());
  py::class_<DistributedRuntimeClient,
             std::shared_ptr<DistributedRuntimeClient>>
      distributed_runtime_client(m, "DistributedRuntimeClient");
  distributed_runtime_client
      .def("connect", &DistributedRuntimeClient::Connect,
           py::call_guard<py::gil_scoped_release>())
      .def("shutdown", &DistributedRuntimeClient::Shutdown,
           py::call_guard<py::gil_scoped_release>())
      .def(
          "blocking_key_value_get",
          [](DistributedRuntimeClient& client, std::string key,
             int64_t timeout_in_ms) {
            py::gil_scoped_release gil_release;
            return client.BlockingKeyValueGet(
                key, absl::Milliseconds(timeout_in_ms));
          },
          py::arg("key"), py::arg("timeout_in_ms"))
      .def(
          "wait_at_barrier",
          [](DistributedRuntimeClient& client, std::string barrier_id,
             int64_t timeout_in_ms) {
            py::gil_scoped_release gil_release;
            return client.WaitAtBarrier(barrier_id,
                                        absl::Milliseconds(timeout_in_ms));
          },
          py::arg("barrier_id"), py::arg("timeout_in_ms"))
      .def(
          "key_value_set",
          [](DistributedRuntimeClient& client, std::string key,
             std::string value) {
            py::gil_scoped_release gil_release;
            return client.KeyValueSet(key, value);
          },
          py::arg("key"), py::arg("value"));
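
  // Illustrative Python-side usage of the distributed runtime client bound
  // above (a sketch only; `dist_client` is assumed to come from
  // get_distributed_runtime_client below, and the key and address strings are
  // placeholders):
  //
  //   dist_client.connect()
  //   dist_client.key_value_set("worker/0/addr", "10.0.0.1:8470")
  //   addr = dist_client.blocking_key_value_get("worker/0/addr",
  //                                             timeout_in_ms=10000)
  //   dist_client.wait_at_barrier("init_done", timeout_in_ms=60000)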

  m.def(
      "get_distributed_runtime_service",
      [](std::string address, int num_nodes, bool use_coordination_service,
         std::optional<int> heartbeat_interval,
         std::optional<int> max_missing_heartbeats,
         std::optional<int> enumerate_devices_timeout,
         std::optional<int> shutdown_timeout)
          -> StatusOr<std::unique_ptr<DistributedRuntimeService>> {
        DistributedRuntimeServiceImpl::Options options;
        options.num_nodes = num_nodes;
        if (heartbeat_interval.has_value()) {
          options.heartbeat_interval = absl::Seconds(*heartbeat_interval);
        }
        if (max_missing_heartbeats.has_value()) {
          options.max_missing_heartbeats = *max_missing_heartbeats;
        }
        if (enumerate_devices_timeout.has_value()) {
          options.enumerate_devices_timeout =
              absl::Seconds(*enumerate_devices_timeout);
        }
        if (shutdown_timeout.has_value()) {
          options.shutdown_timeout = absl::Seconds(*shutdown_timeout);
        }
        TF_ASSIGN_OR_RETURN(std::unique_ptr<DistributedRuntimeService> service,
                            GetDistributedRuntimeService(
                                address, options, use_coordination_service));
        return service;
      },
      py::arg("address"), py::arg("num_nodes"),
      py::arg("use_coordination_service"), py::kw_only(),
      py::arg("heartbeat_interval") = std::nullopt,
      py::arg("max_missing_heartbeats") = std::nullopt,
      py::arg("enumerate_devices_timeout") = std::nullopt,
      py::arg("shutdown_timeout") = std::nullopt);

  m.def(
      "get_distributed_runtime_client",
      [](std::string address, int node_id, bool use_coordination_service,
         std::optional<int> rpc_timeout, std::optional<int> init_timeout,
         std::optional<int> shutdown_timeout,
         std::optional<int> heartbeat_interval,
         std::optional<int> max_missing_heartbeats,
         std::optional<std::function<void(xla::Status,
                                          bool coordinator_reported_failure)>>
             missed_heartbeat_callback,
         std::optional<bool> shutdown_on_destruction)
          -> StatusOr<std::shared_ptr<DistributedRuntimeClient>> {
        DistributedRuntimeClient::Options options;
        options.node_id = node_id;
        if (rpc_timeout.has_value()) {
          options.rpc_timeout = absl::Seconds(*rpc_timeout);
        }
        if (init_timeout.has_value()) {
          options.init_timeout = absl::Seconds(*init_timeout);
        }
        if (shutdown_timeout.has_value()) {
          options.shutdown_timeout = absl::Seconds(*shutdown_timeout);
        }
        if (heartbeat_interval.has_value()) {
          options.heartbeat_interval = absl::Seconds(*heartbeat_interval);
        }
        if (max_missing_heartbeats.has_value()) {
          options.max_missing_heartbeats = *max_missing_heartbeats;
        }
        if (missed_heartbeat_callback.has_value()) {
          options.missed_heartbeat_callback =
              std::move(*missed_heartbeat_callback);
        }
        if (shutdown_on_destruction.has_value()) {
          options.shutdown_on_destruction = *shutdown_on_destruction;
        }
        return GetDistributedRuntimeClient(address, options,
                                           use_coordination_service);
      },
      py::arg("address"), py::arg("node_id"),
      py::arg("use_coordination_service"), py::kw_only(),
      py::arg("rpc_timeout") = std::nullopt,
      py::arg("init_timeout") = std::nullopt,
      py::arg("shutdown_timeout") = std::nullopt,
      py::arg("heartbeat_interval") = std::nullopt,
      py::arg("max_missing_heartbeats") = std::nullopt,
      py::arg("missed_heartbeat_callback") = std::nullopt,
      py::arg("shutdown_on_destruction") = std::nullopt);
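
  // Illustrative two-process setup combining the service and client factories
  // above (a sketch only; the addresses, node count, and node id are
  // placeholders):
  //
  //   # On the coordinator process:
  //   service = xla_extension.get_distributed_runtime_service(
  //       "[::]:8476", 2, use_coordination_service=True)
  //   # On each participating process (node_id in [0, num_nodes)):
  //   dist_client = xla_extension.get_distributed_runtime_client(
  //       "coordinator-host:8476", 0, use_coordination_service=True)
  //   dist_client.connect()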

  m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); });

  m.def("is_optimized_build", &IsOptimizedBuild);

  m.def("json_to_pprof_profile", &JsonToPprofProfile,
        "Encodes the JSON representation of a pprof Profile into its binary "
        "protocol buffer encoding.");
  m.def("pprof_profile_to_json", &PprofProfileToJson,
        "Decodes an uncompressed pprof Profile protocol buffer into a JSON "
        "representation.");
}  // NOLINT(readability/fn_size)

}  // namespace xla