1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h"
17
18 #include <memory>
19 #include <optional>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
25 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h"
26 // TODO(skyewm): remove when everything goes through C API
27 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h"
28 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
29 #include "tensorflow/compiler/xla/shape.h"
30 #include "tensorflow/compiler/xla/util.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/tpu/pjrt_api.h"
33
34 // TODO(b/238999986): Remove this when we have decomposed shape.
35 #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
36
37 namespace xla {
38
39 // Helper macros
40
// Evaluates `expr`, which yields a `PJRT_Error*`, and immediately hands that
// pointer to a unique_ptr whose deleter is obtained from `c_api`, so the error
// object is released when the macro's scope ends. The error is converted with
// pjrt::PjrtErrorToStatus; a non-OK result is returned from the enclosing
// function, otherwise execution falls through to the next statement.
#define RETURN_STATUS_IF_ERROR(expr, c_api)                                \
  do {                                                                     \
    PJRT_Error* pjrt_returned_error = (expr);                              \
    std::unique_ptr<PJRT_Error, pjrt::PJRT_ErrorDeleter> pjrt_owned_error( \
        pjrt_returned_error, pjrt::MakeErrorDeleter(c_api));               \
    xla::Status pjrt_converted_status =                                    \
        pjrt::PjrtErrorToStatus(pjrt_owned_error.get(), c_api);            \
    if (pjrt_converted_status.ok()) break;                                 \
    return pjrt_converted_status;                                          \
  } while (false)
53
54 // ---------------------------------- Client -----------------------------------
55
PjRtCApiClient(const PJRT_Api * c_api,PJRT_Client * c_client)56 PjRtCApiClient::PjRtCApiClient(const PJRT_Api* c_api, PJRT_Client* c_client)
57 : c_api_(c_api),
58 c_client_(std::unique_ptr<PJRT_Client, ::pjrt::PJRT_ClientDeleter>(
59 c_client, ::pjrt::MakeClientDeleter(c_api))) {
60 wrapped_ = c_client_->client.get();
61
62 InitDevices();
63 }
64
InitDevices()65 void PjRtCApiClient::InitDevices() {
66 PJRT_Client_Devices_Args devices_args;
67 devices_args.struct_size = PJRT_Client_Devices_Args_STRUCT_SIZE;
68 devices_args.priv = nullptr;
69 devices_args.client = c_client_.get();
70
71 pjrt::LogFatalIfPjrtError(c_api_->PJRT_Client_Devices(&devices_args), c_api_);
72
73 const size_t n = devices_args.num_devices;
74 wrapped_device_map_.reserve(n);
75 c_to_cpp_device_map_.reserve(n);
76 owned_devices_.reserve(n);
77 devices_.reserve(n);
78
79 for (size_t i = 0; i < n; ++i) {
80 PJRT_Device* device = devices_args.devices[i];
81 std::unique_ptr<PjRtCApiDevice>& cpp_device = owned_devices_.emplace_back(
82 std::make_unique<PjRtCApiDevice>(device, this));
83 devices_.push_back(cpp_device.get());
84 c_to_cpp_device_map_[device] = cpp_device.get();
85 // Map the wrapped PjRtDevice* to the PjRtCApiDevice* that wraps it.
86 // TODO(b/237017893): remove `wrapped_device_map_` and replace it with
87 // `c_api_device_map_`
88 wrapped_device_map_[device->device] = cpp_device.get();
89 }
90
91 PJRT_Client_AddressableDevices_Args address_args;
92 address_args.struct_size = PJRT_Client_AddressableDevices_Args_STRUCT_SIZE;
93 address_args.priv = nullptr;
94 address_args.client = c_client_.get();
95
96 pjrt::LogFatalIfPjrtError(
97 c_api_->PJRT_Client_AddressableDevices(&address_args), c_api_);
98
99 const size_t m = address_args.num_addressable_devices;
100 addressable_devices_.reserve(m);
101
102 for (size_t i = 0; i < m; ++i) {
103 PJRT_Device* c_device = address_args.addressable_devices[i];
104 addressable_devices_.push_back(GetCppDevice(c_device));
105 }
106 }
107
device_count() const108 int PjRtCApiClient::device_count() const { return devices_.size(); }
109
addressable_device_count() const110 int PjRtCApiClient::addressable_device_count() const {
111 return addressable_devices_.size();
112 }
113
devices() const114 absl::Span<PjRtDevice* const> PjRtCApiClient::devices() const {
115 return devices_;
116 }
117
addressable_devices() const118 absl::Span<PjRtDevice* const> PjRtCApiClient::addressable_devices() const {
119 return addressable_devices_;
120 }
121
platform_name() const122 absl::string_view PjRtCApiClient::platform_name() const {
123 PJRT_Client_PlatformName_Args args;
124 args.client = c_client_.get();
125 args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE;
126 args.priv = nullptr;
127 pjrt::LogFatalIfPjrtError(c_api_->PJRT_Client_PlatformName(&args), c_api_);
128
129 absl::string_view platform_name(args.platform_name, args.platform_name_size);
130 return platform_name;
131 }
132
process_index() const133 int PjRtCApiClient::process_index() const {
134 PJRT_Client_ProcessIndex_Args process_index_args;
135 process_index_args.struct_size = PJRT_Client_ProcessIndex_Args_STRUCT_SIZE;
136 process_index_args.priv = nullptr;
137 process_index_args.client = c_client_.get();
138 pjrt::LogFatalIfPjrtError(
139 c_api_->PJRT_Client_ProcessIndex(&process_index_args), c_api_);
140
141 return process_index_args.process_index;
142 }
143
platform_version() const144 absl::string_view PjRtCApiClient::platform_version() const {
145 PJRT_Client_PlatformVersion_Args args;
146 args.struct_size = PJRT_Client_PlatformVersion_Args_STRUCT_SIZE;
147 args.priv = nullptr;
148 args.client = c_client_.get();
149 pjrt::LogFatalIfPjrtError(c_api_->PJRT_Client_PlatformVersion(&args), c_api_);
150
151 absl::string_view platform_version(args.platform_version,
152 args.platform_version_size);
153 return platform_version;
154 }
155
ExecutableFingerprint(const PjRtLoadedExecutable & executable) const156 StatusOr<std::optional<std::string>> PjRtCApiClient::ExecutableFingerprint(
157 const PjRtLoadedExecutable& executable) const {
158 return {std::nullopt};
159 }
160
LookupDevice(int device_id) const161 StatusOr<PjRtDevice*> PjRtCApiClient::LookupDevice(int device_id) const {
162 PJRT_Client_LookupDevice_Args args;
163 args.struct_size = PJRT_Client_LookupDevice_Args_STRUCT_SIZE;
164 args.priv = nullptr;
165 args.client = c_client_.get();
166 args.id = device_id;
167 RETURN_STATUS_IF_ERROR(c_api_->PJRT_Client_LookupDevice(&args), c_api_);
168 return GetCppDevice(args.device);
169 }
170
ValidateCompileOption(CompileOptions options)171 static Status ValidateCompileOption(CompileOptions options) {
172 if (options.argument_layouts.has_value()) {
173 return xla::Unimplemented(
174 "argument_layouts in CompileOptions is not supported.");
175 }
176 if (options.compile_portable_executable) {
177 return xla::Unimplemented(
178 "compile_portable_executable in CompileOptions is not supported.");
179 }
180 if (options.profile_version != 0) {
181 return xla::Unimplemented(
182 "profile_version in CompileOptions is not supported.");
183 }
184 if (options.multi_slice_config != nullptr) {
185 return xla::Unimplemented(
186 "multi_slice_config in CompileOptions is not supported.");
187 }
188 return xla::OkStatus();
189 }
190
191 // Convert `CompileOptions` to `PJRT_CompileOptions`. `device_assignment_str`
192 // will be used for serialized DeviceAssignment storage.
ConvertCppCompileOptionsToCCompileOptions(CompileOptions options,std::string * device_assignment_str)193 static StatusOr<PJRT_CompileOptions> ConvertCppCompileOptionsToCCompileOptions(
194 CompileOptions options, std::string* device_assignment_str) {
195 PJRT_CompileOptions c_options;
196 c_options.struct_size = PJRT_CompileOptions_STRUCT_SIZE;
197 c_options.parameter_is_tupled_arguments =
198 options.parameter_is_tupled_arguments;
199 c_options.device_ordinal = options.executable_build_options.device_ordinal();
200 c_options.num_replicas = options.executable_build_options.num_replicas();
201 c_options.num_partitions = options.executable_build_options.num_partitions();
202 c_options.use_spmd_partitioning =
203 options.executable_build_options.use_spmd_partitioning();
204 c_options.allow_spmd_sharding_propagation_to_output =
205 options.executable_build_options
206 .allow_spmd_sharding_propagation_to_output();
207
208 if (options.executable_build_options.has_device_assignment()) {
209 DeviceAssignmentProto device_assignment_proto;
210 TF_RETURN_IF_ERROR(
211 options.executable_build_options.device_assignment().Serialize(
212 &device_assignment_proto));
213 *device_assignment_str = device_assignment_proto.SerializeAsString();
214 c_options.device_assignment = device_assignment_str->c_str();
215 c_options.device_assignment_size = device_assignment_str->size();
216 } else {
217 c_options.device_assignment_size = 0;
218 c_options.device_assignment = nullptr;
219 }
220 return c_options;
221 }
222
Compile(mlir::ModuleOp module,CompileOptions options)223 StatusOr<std::unique_ptr<PjRtLoadedExecutable>> PjRtCApiClient::Compile(
224 mlir::ModuleOp module, CompileOptions options) {
225 TF_RETURN_IF_ERROR(ValidateCompileOption(options));
226 PJRT_Client_Compile_Args args;
227 args.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE;
228 args.priv = nullptr;
229 args.client = c_client_.get();
230 std::string device_assignment_str;
231 TF_ASSIGN_OR_RETURN(PJRT_CompileOptions c_options,
232 ConvertCppCompileOptionsToCCompileOptions(
233 options, &device_assignment_str));
234 args.options = &c_options;
235 std::string module_str = tensorflow::SerializeMlirModule(module);
236 args.module = module_str.c_str();
237 args.module_size = module_str.size();
238
239 RETURN_STATUS_IF_ERROR(c_api_->PJRT_Client_Compile(&args), c_api_);
240 std::unique_ptr<PjRtLoadedExecutable> ret =
241 std::make_unique<PjRtCApiExecutable>(this, args.executable);
242 return ret;
243 }
244
SerializeExecutable(const PjRtLoadedExecutable & executable) const245 StatusOr<std::string> PjRtCApiClient::SerializeExecutable(
246 const PjRtLoadedExecutable& executable) const {
247 #ifdef PJRT_C_API_BYPASS
248 return wrapped_->SerializeExecutable(
249 *PjRtCApiExecutable::GetWrapped(&executable));
250 #endif // PJRT_C_API_BYPASS
251 return Unimplemented("PJRT C API does not support SerializeExecutable");
252 }
253
254 StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
DeserializeExecutable(absl::string_view serialized,CompileOptions options)255 PjRtCApiClient::DeserializeExecutable(absl::string_view serialized,
256 CompileOptions options) {
257 #ifdef PJRT_C_API_BYPASS
258 return WrapExecutable(wrapped_->DeserializeExecutable(serialized, options));
259 #endif // PJRT_C_API_BYPASS
260 return Unimplemented("PJRT C API does not support DeserializeExecutable");
261 }
262
UnsafeBufferPointer(PjRtBuffer * buffer)263 StatusOr<std::uintptr_t> PjRtCApiClient::UnsafeBufferPointer(
264 PjRtBuffer* buffer) {
265 #ifdef PJRT_C_API_BYPASS
266 return wrapped_->UnsafeBufferPointer(PjRtCApiBuffer::GetWrapped(buffer));
267 #endif // PJRT_C_API_BYPASS
268 return Unimplemented("PJRT C API does not support UnsafeBufferPointer");
269 }
270
WrapExecutable(StatusOr<std::unique_ptr<PjRtLoadedExecutable>> to_wrap)271 StatusOr<std::unique_ptr<PjRtLoadedExecutable>> PjRtCApiClient::WrapExecutable(
272 StatusOr<std::unique_ptr<PjRtLoadedExecutable>> to_wrap) {
273 TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> executable,
274 std::move(to_wrap));
275 return std::unique_ptr<PjRtLoadedExecutable>(
276 std::make_unique<PjRtCApiExecutable>(this, std::move(executable)));
277 }
278
WrapBuffer(StatusOr<std::unique_ptr<PjRtBuffer>> to_wrap)279 StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCApiClient::WrapBuffer(
280 StatusOr<std::unique_ptr<PjRtBuffer>> to_wrap) {
281 TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> buffer, std::move(to_wrap));
282 return std::unique_ptr<PjRtBuffer>(std::make_unique<PjRtCApiBuffer>(
283 this, new PJRT_Buffer{std::move(buffer), pjrt_c_client()}));
284 }
285
pjrt_c_api() const286 const PJRT_Api* PjRtCApiClient::pjrt_c_api() const { return c_api_; }
287
288 // --------------------------------- Devices -----------------------------------
289
PjRtCApiDevice(PJRT_Device * device,PjRtCApiClient * client)290 PjRtCApiDevice::PjRtCApiDevice(PJRT_Device* device, PjRtCApiClient* client)
291 : client_(client), device_(device) {
292 wrapped_ = device_->device;
293 InitAttributes();
294 }
295
client() const296 PjRtClient* PjRtCApiDevice::client() const { return client_; }
297
id() const298 int PjRtCApiDevice::id() const {
299 PJRT_Device_Id_Args args;
300 args.struct_size = PJRT_Device_Id_Args_STRUCT_SIZE;
301 args.priv = nullptr;
302 args.device = device_;
303 const PJRT_Api* api = client_->pjrt_c_api();
304 pjrt::LogFatalIfPjrtError(api->PJRT_Device_Id(&args), api);
305 return args.id;
306 }
307
process_index() const308 int PjRtCApiDevice::process_index() const {
309 PJRT_Device_ProcessIndex_Args args;
310 args.struct_size = PJRT_Device_ProcessIndex_Args_STRUCT_SIZE;
311 args.priv = nullptr;
312 args.device = device_;
313 const PJRT_Api* api = client_->pjrt_c_api();
314 pjrt::LogFatalIfPjrtError(api->PJRT_Device_ProcessIndex(&args), api);
315 return args.process_index;
316 }
317
IsAddressable() const318 bool PjRtCApiDevice::IsAddressable() const {
319 PJRT_Device_IsAddressable_Args args;
320 args.struct_size = PJRT_Device_IsAddressable_Args_STRUCT_SIZE;
321 args.priv = nullptr;
322 args.device = device_;
323 const PJRT_Api* api = client_->pjrt_c_api();
324 pjrt::LogFatalIfPjrtError(api->PJRT_Device_IsAddressable(&args), api);
325 return args.is_addressable;
326 }
327
InitAttributes()328 void PjRtCApiDevice::InitAttributes() {
329 attributes_ = {};
330 PJRT_Device_Attributes_Args args;
331 args.struct_size = PJRT_Device_Attributes_Args_STRUCT_SIZE;
332 args.priv = nullptr;
333 args.device = device_;
334 const PJRT_Api* api = client_->pjrt_c_api();
335 pjrt::LogFatalIfPjrtError(api->PJRT_Device_Attributes(&args), api);
336
337 for (int i = 0; i < args.num_attributes; ++i) {
338 const auto& attribute = args.attributes[i];
339 std::string attribute_name(attribute.name, attribute.name_size);
340 switch (attribute.type) {
341 case PJRT_Device_Attribute::PJRT_Device_Attribute_kString: {
342 std::string string_value(attribute.string_value, attribute.value_size);
343 attributes_[attribute_name] = PjRtDeviceAttribute(string_value);
344 break;
345 }
346 case PJRT_Device_Attribute::PJRT_Device_Attribute_kInt64: {
347 attributes_[attribute_name] =
348 PjRtDeviceAttribute(attribute.int64_value);
349 break;
350 }
351 case PJRT_Device_Attribute::PJRT_Device_Attribute_kInt64List: {
352 const int64_t* array_ptr(attribute.int64_array_value);
353 std::vector<int64_t> int64_array(array_ptr,
354 array_ptr + attribute.value_size);
355 attributes_[attribute_name] = PjRtDeviceAttribute(int64_array);
356 break;
357 }
358 }
359 }
360 }
361
362 const absl::flat_hash_map<std::string, PjRtDeviceAttribute>&
Attributes() const363 PjRtCApiDevice::Attributes() const {
364 return attributes_;
365 }
366
device_kind() const367 absl::string_view PjRtCApiDevice::device_kind() const {
368 PJRT_Device_Kind_Args args;
369 args.struct_size = PJRT_Device_Kind_Args_STRUCT_SIZE;
370 args.priv = nullptr;
371 args.device = device_;
372
373 const PJRT_Api* c_api = client_->pjrt_c_api();
374 pjrt::LogFatalIfPjrtError(c_api->PJRT_Device_Kind(&args), c_api);
375
376 absl::string_view device_kind(args.device_kind, args.device_kind_size);
377 return device_kind;
378 }
379
local_hardware_id() const380 int PjRtCApiDevice::local_hardware_id() const {
381 PJRT_Device_LocalHardwareId_Args args;
382 args.struct_size = PJRT_Device_LocalHardwareId_Args_STRUCT_SIZE;
383 args.priv = nullptr;
384 args.device = device_;
385 const PJRT_Api* api = client_->pjrt_c_api();
386 pjrt::LogFatalIfPjrtError(api->PJRT_Device_LocalHardwareId(&args), api);
387 return args.local_hardware_id;
388 }
389
DebugString() const390 absl::string_view PjRtCApiDevice::DebugString() const {
391 PJRT_Device_DebugString_Args args;
392 args.struct_size = PJRT_Device_DebugString_Args_STRUCT_SIZE;
393 args.priv = nullptr;
394 args.device = device_;
395 const PJRT_Api* c_api = client_->pjrt_c_api();
396 pjrt::LogFatalIfPjrtError(c_api->PJRT_Device_DebugString(&args), c_api);
397 absl::string_view debug_string(args.debug_string, args.debug_string_size);
398 return debug_string;
399 }
400
ToString() const401 absl::string_view PjRtCApiDevice::ToString() const {
402 PJRT_Device_ToString_Args args;
403 args.struct_size = PJRT_Device_ToString_Args_STRUCT_SIZE;
404 args.priv = nullptr;
405 args.device = device_;
406 const PJRT_Api* c_api = client_->pjrt_c_api();
407 pjrt::LogFatalIfPjrtError(c_api->PJRT_Device_ToString(&args), c_api);
408 absl::string_view to_string(args.to_string, args.to_string_size);
409 return to_string;
410 }
411
412 // ------------------------------- Executables ---------------------------------
413
PjRtCApiExecutable(PjRtCApiClient * client,std::unique_ptr<PjRtLoadedExecutable> wrapped)414 PjRtCApiExecutable::PjRtCApiExecutable(
415 PjRtCApiClient* client, std::unique_ptr<PjRtLoadedExecutable> wrapped)
416 : client_(client),
417 executable_(
418 new PJRT_Executable{std::move(wrapped), client->pjrt_c_client()}) {
419 InitDevices();
420 }
421
PjRtCApiExecutable(PjRtCApiClient * client,PJRT_Executable * executable)422 PjRtCApiExecutable::PjRtCApiExecutable(PjRtCApiClient* client,
423 PJRT_Executable* executable)
424 : client_(client), executable_(executable) {
425 InitDevices();
426 }
427
InitDevices()428 void PjRtCApiExecutable::InitDevices() {
429 PJRT_Executable_AddressableDevices_Args args;
430 args.struct_size = PJRT_Executable_AddressableDevices_Args_STRUCT_SIZE;
431 args.priv = nullptr;
432 args.executable = executable_;
433 args.addressable_devices = nullptr;
434 args.num_addressable_devices = 0;
435
436 const PJRT_Api* api = pjrt_c_api();
437 pjrt::LogFatalIfPjrtError(api->PJRT_Executable_AddressableDevices(&args),
438 api);
439
440 const size_t num_addressable_devices = args.num_addressable_devices;
441 addressable_devices_.reserve(num_addressable_devices);
442
443 for (size_t i = 0; i < num_addressable_devices; ++i) {
444 PJRT_Device* device = args.addressable_devices[i];
445 PjRtCApiDevice* c_api_device = client_->GetCppDevice(device);
446 addressable_devices_.push_back(c_api_device);
447 }
448 }
449
~PjRtCApiExecutable()450 PjRtCApiExecutable::~PjRtCApiExecutable() {
451 PJRT_Executable_Destroy_Args args;
452 args.struct_size = PJRT_Executable_Destroy_Args_STRUCT_SIZE;
453 args.priv = nullptr;
454 args.executable = executable_;
455 const PJRT_Api* api = pjrt_c_api();
456 pjrt::LogFatalIfPjrtError(api->PJRT_Executable_Destroy(&args), api);
457 }
458
Convert2DCppBuffersToCBuffers(absl::Span<const std::vector<PjRtBuffer * >> cpp_lists)459 static std::vector<std::vector<PJRT_Buffer*>> Convert2DCppBuffersToCBuffers(
460 absl::Span<const std::vector<PjRtBuffer*>> cpp_lists) {
461 std::vector<std::vector<PJRT_Buffer*>> c_lists;
462 c_lists.reserve(cpp_lists.size());
463 for (const auto& cpp_list : cpp_lists) {
464 auto& c_list = c_lists.emplace_back();
465 c_list.reserve(cpp_list.size());
466 for (PjRtBuffer* buffer : cpp_list) {
467 auto* c_api_argument = tensorflow::down_cast<PjRtCApiBuffer*>(buffer);
468 c_list.push_back(c_api_argument->c_buffer());
469 }
470 }
471 return c_lists;
472 }
473
474 static std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>
Convert2DCBuffersToCppBuffers(PJRT_Buffer *** c_lists,size_t outer_size,int inner_size,xla::PjRtCApiClient * client)475 Convert2DCBuffersToCppBuffers(PJRT_Buffer*** c_lists, size_t outer_size,
476 int inner_size, xla::PjRtCApiClient* client) {
477 std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> ret;
478 for (size_t i = 0; i < outer_size; ++i) {
479 auto& output_list = ret.emplace_back();
480 output_list.reserve(inner_size);
481 for (size_t j = 0; j < inner_size; ++j) {
482 output_list.push_back(
483 std::make_unique<PjRtCApiBuffer>(client, c_lists[i][j]));
484 }
485 }
486 return ret;
487 }
488
489 StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer * >> argument_handles,const ExecuteOptions & options,std::optional<std::vector<PjRtFuture<Status>>> & returned_futures)490 PjRtCApiExecutable::Execute(
491 absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
492 const ExecuteOptions& options,
493 std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) {
494 PJRT_Executable_Execute_Args args;
495 args.struct_size = PJRT_Executable_Execute_Args_STRUCT_SIZE;
496 args.priv = nullptr;
497 args.executable = executable_;
498 PJRT_ExecuteOptions c_options;
499 args.options = &c_options;
500 args.options->struct_size = PJRT_ExecuteOptions_STRUCT_SIZE;
501 args.options->launch_id = options.launch_id;
502 args.num_devices = argument_handles.size();
503 CHECK_GT(args.num_devices, 0);
504 args.num_args = argument_handles[0].size();
505
506 // Populates `args.argument_lists` from `argument_handles`.
507 std::vector<std::vector<PJRT_Buffer*>> c_argument_lists =
508 Convert2DCppBuffersToCBuffers(argument_handles);
509 std::vector<PJRT_Buffer**> c_arguments;
510 c_arguments.reserve(c_argument_lists.size());
511 for (auto& argument_list : c_argument_lists) {
512 c_arguments.push_back(argument_list.data());
513 }
514 args.argument_lists = c_arguments.data();
515
516 // Allocates memory for output. `c_buffer_lists_holder` and `c_buffer_lists`
517 // needs to stay alive during the call of `PJRT_Executable_Execute`.
518 PJRT_Executable_NumOutputs_Args numoutputs_args;
519 numoutputs_args.struct_size = PJRT_Executable_NumOutputs_Args_STRUCT_SIZE;
520 numoutputs_args.priv = nullptr;
521 numoutputs_args.executable = executable_;
522 RETURN_STATUS_IF_ERROR(
523 pjrt_c_api()->PJRT_Executable_NumOutputs(&numoutputs_args), pjrt_c_api());
524 size_t outer_size = args.num_devices;
525 size_t inner_size = numoutputs_args.num_outputs;
526 std::vector<std::vector<PJRT_Buffer*>> c_buffer_lists_holder(outer_size);
527 auto c_buffer_lists = std::vector<PJRT_Buffer**>(outer_size);
528 for (int i = 0; i < outer_size; ++i) {
529 c_buffer_lists_holder[i].resize(inner_size);
530 c_buffer_lists[i] = c_buffer_lists_holder[i].data();
531 }
532 args.output_lists = c_buffer_lists.data();
533
534 RETURN_STATUS_IF_ERROR(pjrt_c_api()->PJRT_Executable_Execute(&args),
535 pjrt_c_api());
536
537 return Convert2DCBuffersToCppBuffers(args.output_lists, args.num_devices,
538 numoutputs_args.num_outputs, client_);
539 }
540
541 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer * const> argument_handles,PjRtDevice * device,const ExecuteOptions & options,std::optional<PjRtFuture<Status>> & returned_future,bool fill_future)542 PjRtCApiExecutable::ExecuteSharded(
543 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
544 const ExecuteOptions& options,
545 std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) {
546 #ifdef PJRT_C_API_BYPASS
547 std::vector<PjRtBuffer*> wrapped_args =
548 PjRtCApiBuffer::GetWrappedVector(argument_handles);
549
550 TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> out,
551 wrapped()->ExecuteSharded(
552 wrapped_args, PjRtCApiDevice::GetWrapped(device),
553 options, returned_future, fill_future));
554
555 for (std::unique_ptr<PjRtBuffer>& buffer : out) {
556 buffer = std::make_unique<PjRtCApiBuffer>(
557 client_, new PJRT_Buffer{std::move(buffer), client_->pjrt_c_client()});
558 }
559 return out;
560 #endif // PJRT_C_API_BYPASS
561 return Unimplemented("PJRT C API does not support ExecuteSharded");
562 }
563
564 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer * const> argument_handles,PjRtDevice * device,const ExecuteOptions & options,std::optional<PjRtFuture<Status>> & returned_future,bool fill_future)565 PjRtCApiExecutable::ExecutePortable(
566 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
567 const ExecuteOptions& options,
568 std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) {
569 #ifdef PJRT_C_API_BYPASS
570 std::vector<PjRtBuffer*> wrapped_args =
571 PjRtCApiBuffer::GetWrappedVector(argument_handles);
572
573 TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> out,
574 wrapped()->ExecutePortable(
575 wrapped_args, PjRtCApiDevice::GetWrapped(device),
576 options, returned_future, fill_future));
577
578 for (std::unique_ptr<PjRtBuffer>& buffer : out) {
579 buffer = std::make_unique<PjRtCApiBuffer>(
580 client_, new PJRT_Buffer{std::move(buffer), client_->pjrt_c_client()});
581 }
582 return out;
583 #endif // PJRT_C_API_BYPASS
584 return Unimplemented("PJRT C API does not support ExecutePortable");
585 }
586
wrapped() const587 PjRtLoadedExecutable* PjRtCApiExecutable::wrapped() const {
588 return executable_->executable.get();
589 }
590
name() const591 absl::string_view PjRtCApiExecutable::name() const {
592 const PJRT_Api* c_api = pjrt_c_api();
593 PJRT_Executable_Name_Args args;
594 args.executable = executable_;
595 args.struct_size = PJRT_Executable_Name_Args_STRUCT_SIZE;
596 args.priv = nullptr;
597 pjrt::LogFatalIfPjrtError(c_api->PJRT_Executable_Name(&args), c_api);
598
599 absl::string_view executable_name(args.executable_name,
600 args.executable_name_size);
601 return executable_name;
602 }
603
Delete()604 void PjRtCApiExecutable::Delete() {
605 PJRT_Executable_Delete_Args args;
606 args.struct_size = PJRT_Executable_Delete_Args_STRUCT_SIZE;
607 args.priv = nullptr;
608 args.executable = executable_;
609 const PJRT_Api* c_api = pjrt_c_api();
610 pjrt::LogFatalIfPjrtError(c_api->PJRT_Executable_Delete(&args), c_api);
611 }
612
IsDeleted()613 bool PjRtCApiExecutable::IsDeleted() {
614 PJRT_Executable_IsDeleted_Args args;
615 args.struct_size = PJRT_Executable_IsDeleted_Args_STRUCT_SIZE;
616 args.priv = nullptr;
617 args.executable = executable_;
618
619 const PJRT_Api* c_api = pjrt_c_api();
620 pjrt::LogFatalIfPjrtError(c_api->PJRT_Executable_IsDeleted(&args), c_api);
621 return args.is_deleted;
622 }
623
624 // ---------------------------------- Buffers ----------------------------------
625
PjRtCApiBuffer(PjRtCApiClient * client,PJRT_Buffer * buffer)626 PjRtCApiBuffer::PjRtCApiBuffer(PjRtCApiClient* client, PJRT_Buffer* buffer)
627 : client_(client),
628 buffer_(buffer, ::pjrt::MakeBufferDeleter(client->pjrt_c_api())),
629 wrapped_(buffer_->buffer.get()) {
630 set_shape();
631 }
632
on_device_shape() const633 const Shape& PjRtCApiBuffer::on_device_shape() const {
634 CHECK(shape_.has_value())
635 << "Shape should be initialized in PjRtCApiBuffer constructor.";
636 return shape_.value();
637 }
638
set_shape()639 void PjRtCApiBuffer::set_shape() {
640 PJRT_Buffer_OnDeviceTrimmedShape_Args args;
641 args.struct_size = PJRT_Buffer_OnDeviceTrimmedShape_Args_STRUCT_SIZE;
642 args.priv = nullptr;
643 args.buffer = buffer_.get();
644
645 pjrt::LogFatalIfPjrtError(
646 client_->pjrt_c_api()->PJRT_Buffer_OnDeviceTrimmedShape(&args),
647 client_->pjrt_c_api());
648
649 xla::PrimitiveType element_type =
650 static_cast<xla::PrimitiveType>(args.element_type);
651
652 CHECK_NE(element_type, xla::PrimitiveType::TUPLE);
653
654 absl::Span<const int64_t> dims = ApiConverter::MakeSpan(args.dimensions);
655 absl::Span<const bool> dynamic_dims =
656 ApiConverter::MakeSpan(args.dynamic_dimensions);
657
658 Shape trimmed_shape = Shape(element_type, dims, dynamic_dims, {});
659
660 if (args.has_layout) {
661 *(trimmed_shape.mutable_layout()) = ApiConverter::FromC(&args.layout);
662 }
663
664 shape_ = trimmed_shape;
665
666 // TODO(amangu): Refactor the deletion.
667 if (args.dimensions.size > TPU_C_API_MAX_INLINED) {
668 delete[] args.dimensions.heap;
669 }
670
671 if (args.dynamic_dimensions.size > TPU_C_API_MAX_INLINED) {
672 delete[] args.dynamic_dimensions.heap;
673 }
674
675 if (args.has_layout) {
676 if (args.layout.minor_to_major.size > TPU_C_API_MAX_INLINED) {
677 delete[] args.layout.minor_to_major.heap;
678 }
679
680 if (args.layout.tiles.size > TPU_C_API_MAX_INLINED) {
681 delete[] args.layout.tiles.heap;
682 }
683 }
684 }
685
GetOnDeviceSizeInBytes() const686 StatusOr<size_t> PjRtCApiBuffer::GetOnDeviceSizeInBytes() const {
687 PJRT_Buffer_OnDeviceSizeInBytes_Args args;
688 args.struct_size = PJRT_Buffer_OnDeviceSizeInBytes_Args_STRUCT_SIZE;
689 args.priv = nullptr;
690 args.buffer = buffer_.get();
691 RETURN_STATUS_IF_ERROR(
692 client_->pjrt_c_api()->PJRT_Buffer_OnDeviceSizeInBytes(&args),
693 client_->pjrt_c_api());
694
695 return args.on_device_size_in_bytes;
696 }
697
device() const698 PjRtDevice* PjRtCApiBuffer::device() const {
699 PJRT_Buffer_Device_Args args;
700 args.struct_size = PJRT_Buffer_Device_Args_STRUCT_SIZE;
701 args.priv = nullptr;
702 args.buffer = buffer_.get();
703 const PJRT_Api* api = pjrt_c_api();
704 pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_Device(&args), api);
705 return client_->GetCppDevice(args.device);
706 }
707
Delete()708 void PjRtCApiBuffer::Delete() {
709 PJRT_Buffer_Delete_Args args;
710 args.struct_size = PJRT_Buffer_Delete_Args_STRUCT_SIZE;
711 args.priv = nullptr;
712 args.buffer = buffer_.get();
713 const PJRT_Api* api = pjrt_c_api();
714 pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_Delete(&args), api);
715 }
716
IsDeleted()717 bool PjRtCApiBuffer::IsDeleted() {
718 PJRT_Buffer_IsDeleted_Args args;
719 args.struct_size = PJRT_Buffer_IsDeleted_Args_STRUCT_SIZE;
720 args.priv = nullptr;
721 args.buffer = buffer_.get();
722 const PJRT_Api* api = pjrt_c_api();
723 pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_IsDeleted(&args), api);
724 return args.is_deleted;
725 }
726
CopyToDevice(PjRtDevice * dst_device)727 StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCApiBuffer::CopyToDevice(
728 PjRtDevice* dst_device) {
729 if (dst_device->client() == client_) {
730 PJRT_Buffer_CopyToDevice_Args args;
731 args.struct_size = PJRT_Buffer_CopyToDevice_Args_STRUCT_SIZE;
732 args.priv = nullptr;
733 args.buffer = buffer_.get();
734 args.dst_device =
735 tensorflow::down_cast<PjRtCApiDevice*>(dst_device)->c_device();
736 const PJRT_Api* api = pjrt_c_api();
737 RETURN_STATUS_IF_ERROR(api->PJRT_Buffer_CopyToDevice(&args), api);
738 return std::unique_ptr<PjRtBuffer>(
739 std::make_unique<PjRtCApiBuffer>(client_, args.dst_buffer));
740 } else {
741 // TODO(b/239735405) Copying across different clients where `dst_device` is
742 // not a PjRtCApiDevice raises an error.
743 return wrapped_->CopyToDevice(dst_device);
744 }
745 }
746
IsOnCpu() const747 bool PjRtCApiBuffer::IsOnCpu() const {
748 PJRT_Buffer_IsOnCpu_Args args;
749 args.struct_size = PJRT_Buffer_IsOnCpu_Args_STRUCT_SIZE;
750 args.priv = nullptr;
751 args.buffer = buffer_.get();
752 const PJRT_Api* api = pjrt_c_api();
753 pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_IsOnCpu(&args), api);
754 return args.is_on_cpu;
755 }
756
757 // -------------------------------- API access ---------------------------------
758
GetCApiClient()759 StatusOr<std::unique_ptr<PjRtClient>> GetCApiClient() {
760 const PJRT_Api* c_api = tensorflow::tpu::PjrtApi();
761 // TODO(skyewm): make status
762 CHECK(c_api != nullptr);
763
764 PJRT_Client_Create_Args init_args;
765 init_args.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE;
766 init_args.priv = nullptr;
767 RETURN_STATUS_IF_ERROR(c_api->PJRT_Client_Create(&init_args), c_api);
768 PJRT_Client* c_client = init_args.client;
769
770 return std::unique_ptr<PjRtClient>(
771 std::make_unique<PjRtCApiClient>(c_api, c_client));
772 }
773
774 } // namespace xla
775