xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <variant>
22 #include <vector>
23 
24 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h"
25 #include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
26 #include "tensorflow/compiler/xla/shape.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 // TODO(b/238999986): Remove this.
29 #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
30 
31 namespace pjrt {
32 
CheckMatchingStructSizes(absl::string_view struct_name,size_t expected_size,size_t actual_size)33 xla::Status CheckMatchingStructSizes(absl::string_view struct_name,
34                                      size_t expected_size, size_t actual_size) {
35   if (expected_size != actual_size) {
36     return tensorflow::errors::InvalidArgument(
37         StructSizeErrorMsg(struct_name, expected_size, actual_size));
38   }
39   return tensorflow::OkStatus();
40 }
41 
StructSizeErrorMsg(absl::string_view struct_name,size_t expected_size,size_t actual_size)42 std::string StructSizeErrorMsg(absl::string_view struct_name,
43                                size_t expected_size, size_t actual_size) {
44   return absl::StrCat("Unexpected ", struct_name, " size: expected ",
45                       expected_size, ", got ", actual_size,
46                       ". Check installed software versions.");
47 }
48 
49 // Returns C device from wrapped C++ device.
GetCDevice(const PJRT_Client * client,const xla::PjRtDevice * device)50 static PJRT_Device* GetCDevice(const PJRT_Client* client,
51                                const xla::PjRtDevice* device) {
52   auto c_device_map = client->c_device_from_cpp_device;
53   auto iter = c_device_map.find(device);
54   CHECK(iter != c_device_map.end());
55   return iter->second;
56 }
57 
58 // ---------------------------------- Errors -----------------------------------
59 
PJRT_Error_Destroy(PJRT_Error_Destroy_Args * args)60 void PJRT_Error_Destroy(PJRT_Error_Destroy_Args* args) {
61   xla::Status struct_size_check = CheckMatchingStructSizes(
62       "PJRT_Error_Destroy_Args", PJRT_Error_Destroy_Args_STRUCT_SIZE,
63       args->struct_size);
64   if (!struct_size_check.ok()) {
65     LOG(ERROR) << struct_size_check.error_message();
66   }
67   if (args->struct_size >= PJRT_STRUCT_SIZE(PJRT_Error_Destroy_Args, error)) {
68     delete args->error;
69   }
70 }
71 
PJRT_Error_Message(PJRT_Error_Message_Args * args)72 void PJRT_Error_Message(PJRT_Error_Message_Args* args) {
73   xla::Status struct_size_check = CheckMatchingStructSizes(
74       "PJRT_Error_Message_Args", PJRT_Error_Message_Args_STRUCT_SIZE,
75       args->struct_size);
76   if (!struct_size_check.ok()) {
77     LOG(ERROR) << struct_size_check.error_message();
78   }
79   if (args->struct_size >= PJRT_STRUCT_SIZE(PJRT_Error_Destroy_Args, error)) {
80     const xla::Status* status = &args->error->status;
81     args->message = status->error_message().data();
82     args->message_size = status->error_message().size();
83   }
84 }
85 
PJRT_Error_GetCode(PJRT_Error_GetCode_Args * args)86 PJRT_Error* PJRT_Error_GetCode(PJRT_Error_GetCode_Args* args) {
87   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
88       "PJRT_Error_GetCode_Args", PJRT_Error_GetCode_Args_STRUCT_SIZE,
89       args->struct_size));
90   args->code = StatusCodeToPjrtErrorCode(args->error->status.code());
91   return nullptr;
92 }
93 
94 // ---------------------------------- Client -----------------------------------
95 
PJRT_Client_Destroy(PJRT_Client_Destroy_Args * args)96 PJRT_Error* PJRT_Client_Destroy(PJRT_Client_Destroy_Args* args) {
97   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
98       "PJRT_Client_Destroy_Args", PJRT_Client_Destroy_Args_STRUCT_SIZE,
99       args->struct_size));
100   delete args->client;
101   return nullptr;
102 }
103 
PJRT_Client_ProcessIndex(PJRT_Client_ProcessIndex_Args * args)104 PJRT_Error* PJRT_Client_ProcessIndex(PJRT_Client_ProcessIndex_Args* args) {
105   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
106       "PJRT_CLient_ProcessIndex_Args",
107       PJRT_Client_ProcessIndex_Args_STRUCT_SIZE, args->struct_size));
108   args->process_index = args->client->client->process_index();
109   return nullptr;
110 }
111 
PJRT_Client_PlatformName(PJRT_Client_PlatformName_Args * args)112 PJRT_Error* PJRT_Client_PlatformName(PJRT_Client_PlatformName_Args* args) {
113   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
114       "PJRT_Client_PlatformName_Args",
115       PJRT_Client_PlatformName_Args_STRUCT_SIZE, args->struct_size));
116   absl::string_view platform_name = args->client->client->platform_name();
117   args->platform_name = platform_name.data();
118   args->platform_name_size = platform_name.size();
119   return nullptr;
120 }
121 
PJRT_Client_PlatformVersion(PJRT_Client_PlatformVersion_Args * args)122 PJRT_Error* PJRT_Client_PlatformVersion(
123     PJRT_Client_PlatformVersion_Args* args) {
124   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
125       "PJRT_CLient_PlatformVersion_Args",
126       PJRT_Client_PlatformVersion_Args_STRUCT_SIZE, args->struct_size));
127   absl::string_view platform_version = args->client->client->platform_version();
128   args->platform_version = platform_version.data();
129   args->platform_version_size = platform_version.size();
130   return nullptr;
131 }
132 
PJRT_Client_Devices(PJRT_Client_Devices_Args * args)133 PJRT_Error* PJRT_Client_Devices(PJRT_Client_Devices_Args* args) {
134   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
135       "PJRT_Client_Devices_Args", PJRT_Client_Devices_Args_STRUCT_SIZE,
136       args->struct_size));
137   args->num_devices = args->client->devices.size();
138   args->devices = args->client->devices.data();
139   return nullptr;
140 }
141 
PJRT_Client_AddressableDevices(PJRT_Client_AddressableDevices_Args * args)142 PJRT_Error* PJRT_Client_AddressableDevices(
143     PJRT_Client_AddressableDevices_Args* args) {
144   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
145       "PJRT_Client_AddressableDevices_Args",
146       PJRT_Client_AddressableDevices_Args_STRUCT_SIZE, args->struct_size));
147   args->num_addressable_devices = args->client->addressable_devices.size();
148   args->addressable_devices = args->client->addressable_devices.data();
149   return nullptr;
150 }
151 
PJRT_Client_LookupDevice(PJRT_Client_LookupDevice_Args * args)152 PJRT_Error* PJRT_Client_LookupDevice(PJRT_Client_LookupDevice_Args* args) {
153   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
154       "PJRT_Client_LookupDevice_Args",
155       PJRT_Client_LookupDevice_Args_STRUCT_SIZE, args->struct_size));
156   PJRT_ASSIGN_OR_RETURN(xla::PjRtDevice * device,
157                         args->client->client->LookupDevice(args->id));
158   args->device = GetCDevice(args->client, device);
159   return nullptr;
160 }
161 
162 // Searches `device_list` for a PJRT_Device* that wraps a provided
163 // `xla::PjRtDevice *` (`cpp_device`). If a match is found, that PJRT_Device* is
164 // returned. Otherwise, returns nullptr.
FindDeviceWrapper(xla::PjRtDevice * cpp_device,absl::Span<PJRT_Device * const> device_list)165 static PJRT_Device* FindDeviceWrapper(
166     xla::PjRtDevice* cpp_device, absl::Span<PJRT_Device* const> device_list) {
167   for (PJRT_Device* device : device_list) {
168     if (device->device == cpp_device) {
169       return device;
170     }
171   }
172   return nullptr;
173 }
174 
PopulatePjrtExecutableAddressableDevices(PJRT_Executable * executable)175 static void PopulatePjrtExecutableAddressableDevices(
176     PJRT_Executable* executable) {
177   CHECK(executable->client != nullptr) << ": client was null";
178   absl::Span<xla::PjRtDevice* const> cpp_devices =
179       executable->executable->addressable_devices();
180   const size_t num_addressable_devices = cpp_devices.size();
181   std::vector<PJRT_Device*>& exec_devices = executable->addressable_devices;
182   exec_devices.reserve(num_addressable_devices);
183 
184   const std::vector<PJRT_Device*>& client_devices =
185       executable->client->addressable_devices;
186 
187   CHECK(client_devices.size() >= num_addressable_devices)
188       << ": client->addressable_devices is not bigger than "
189          "executable->addressable_devices()";
190 
191   for (int i = 0; i < num_addressable_devices; ++i) {
192     xla::PjRtDevice* cpp_device = cpp_devices[i];
193     PJRT_Device* device = FindDeviceWrapper(cpp_device, client_devices);
194     CHECK(device != nullptr)
195         << ": No PJRT_Device* found in client->addressable_devices"
196         << " that wraps executable->addressable_devices()[" << i << "] ("
197         << cpp_devices[i] << ")";
198     exec_devices.push_back(device);
199   }
200 }
201 
202 static xla::StatusOr<xla::CompileOptions>
ConvertCCompileOptionstoCppCompileOptions(PJRT_CompileOptions * c_option)203 ConvertCCompileOptionstoCppCompileOptions(PJRT_CompileOptions* c_option) {
204   xla::CompileOptions ret;
205   ret.parameter_is_tupled_arguments = c_option->parameter_is_tupled_arguments;
206   if (c_option->device_ordinal != -1) {
207     ret.executable_build_options.set_device_ordinal(c_option->device_ordinal);
208   }
209   ret.executable_build_options.set_num_replicas(c_option->num_replicas);
210   ret.executable_build_options.set_num_partitions(c_option->num_partitions);
211   ret.executable_build_options.set_use_spmd_partitioning(
212       c_option->use_spmd_partitioning);
213   ret.executable_build_options.set_allow_spmd_sharding_propagation_to_output(
214       c_option->allow_spmd_sharding_propagation_to_output);
215   if (c_option->device_assignment_size > 0) {
216     absl::string_view device_assignment_sv(c_option->device_assignment,
217                                            c_option->device_assignment_size);
218     std::string device_assignment_str(device_assignment_sv);
219     xla::DeviceAssignmentProto proto;
220     proto.ParseFromString(device_assignment_str);
221     TF_ASSIGN_OR_RETURN(
222         std::unique_ptr<xla::DeviceAssignment> device_assignment,
223         xla::DeviceAssignment::Deserialize(proto));
224     ret.executable_build_options.set_device_assignment(*device_assignment);
225   }
226   return ret;
227 }
228 
PJRT_Client_Compile(PJRT_Client_Compile_Args * args)229 PJRT_Error* PJRT_Client_Compile(PJRT_Client_Compile_Args* args) {
230   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
231       "PJRT_Client_Compile_Args", PJRT_Client_Compile_Args_STRUCT_SIZE,
232       args->struct_size));
233   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes("PJRT_CompileOptions",
234                                                 PJRT_CompileOptions_STRUCT_SIZE,
235                                                 args->options->struct_size));
236   PJRT_ASSIGN_OR_RETURN(
237       xla::CompileOptions options,
238       ConvertCCompileOptionstoCppCompileOptions(args->options));
239   absl::string_view module_str(args->module, args->module_size);
240   mlir::MLIRContext context;
241   PJRT_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
242                         xla::ParseMlirModuleString(module_str, context));
243 
244   PJRT_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtLoadedExecutable> executable,
245                         args->client->client->Compile(*module, options));
246   // TODO(b/237545405): Implement creation methods for PJRT_Executable.
247   args->executable = new PJRT_Executable{std::move(executable), args->client};
248   PopulatePjrtExecutableAddressableDevices(args->executable);
249   args->executable->populated = true;
250   return nullptr;
251 }
252 
253 // --------------------------------- Devices -----------------------------------
254 
PJRT_Device_Id(PJRT_Device_Id_Args * args)255 PJRT_Error* PJRT_Device_Id(PJRT_Device_Id_Args* args) {
256   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes("PJRT_Device_Id_Args",
257                                                 PJRT_Device_Id_Args_STRUCT_SIZE,
258                                                 args->struct_size));
259 
260   args->id = args->device->device->id();
261   return nullptr;
262 }
263 
PJRT_Device_ProcessIndex(PJRT_Device_ProcessIndex_Args * args)264 PJRT_Error* PJRT_Device_ProcessIndex(PJRT_Device_ProcessIndex_Args* args) {
265   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
266       "PJRT_Device_ProcessIndex_Args",
267       PJRT_Device_ProcessIndex_Args_STRUCT_SIZE, args->struct_size));
268   args->process_index = args->device->device->process_index();
269   return nullptr;
270 }
271 
PJRT_Device_IsAddressable(PJRT_Device_IsAddressable_Args * args)272 PJRT_Error* PJRT_Device_IsAddressable(PJRT_Device_IsAddressable_Args* args) {
273   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
274       "PJRT_Device_IsAddressable_Args",
275       PJRT_Device_IsAddressable_Args_STRUCT_SIZE, args->struct_size));
276   args->is_addressable = args->device->device->IsAddressable();
277   return nullptr;
278 }
279 
PJRT_Device_Attributes(PJRT_Device_Attributes_Args * args)280 PJRT_Error* PJRT_Device_Attributes(PJRT_Device_Attributes_Args* args) {
281   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
282       "PJRT_Device_Attributes_Args", PJRT_Device_Attributes_Args_STRUCT_SIZE,
283       args->struct_size));
284 
285   // Returns the attributes that were initialized during PJRT_Device creation.
286   args->num_attributes = args->device->attributes.size();
287   args->attributes = args->device->attributes.data();
288 
289   return nullptr;
290 }
291 
PJRT_Device_Kind(PJRT_Device_Kind_Args * args)292 PJRT_Error* PJRT_Device_Kind(PJRT_Device_Kind_Args* args) {
293   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
294       "PJRT_Device_Kind_Args", PJRT_Device_Kind_Args_STRUCT_SIZE,
295       args->struct_size));
296 
297   args->device_kind = args->device->device->device_kind().data();
298   args->device_kind_size = args->device->device->device_kind().size();
299   return nullptr;
300 }
301 
PJRT_Device_LocalHardwareId(PJRT_Device_LocalHardwareId_Args * args)302 PJRT_Error* PJRT_Device_LocalHardwareId(
303     PJRT_Device_LocalHardwareId_Args* args) {
304   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
305       "PJRT_Device_LocalHardwareId_Args",
306       PJRT_Device_LocalHardwareId_Args_STRUCT_SIZE, args->struct_size));
307   args->local_hardware_id = args->device->device->local_hardware_id();
308   return nullptr;
309 }
310 
PJRT_Device_DebugString(PJRT_Device_DebugString_Args * args)311 PJRT_Error* PJRT_Device_DebugString(PJRT_Device_DebugString_Args* args) {
312   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
313       "PJRT_Device_DebugString_Args", PJRT_Device_DebugString_Args_STRUCT_SIZE,
314       args->struct_size));
315 
316   args->debug_string = args->device->device->DebugString().data();
317   args->debug_string_size = args->device->device->DebugString().size();
318   return nullptr;
319 }
320 
PJRT_Device_ToString(PJRT_Device_ToString_Args * args)321 PJRT_Error* PJRT_Device_ToString(PJRT_Device_ToString_Args* args) {
322   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
323       "PJRT_Device_ToString_Args", PJRT_Device_ToString_Args_STRUCT_SIZE,
324       args->struct_size));
325   args->to_string = args->device->device->ToString().data();
326   args->to_string_size = args->device->device->ToString().size();
327   return nullptr;
328 }
329 
330 // ------------------------------- Executables ---------------------------------
331 
PJRT_Executable_Destroy(PJRT_Executable_Destroy_Args * args)332 PJRT_Error* PJRT_Executable_Destroy(PJRT_Executable_Destroy_Args* args) {
333   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
334       "PJRT_Executable_Destroy_Args", PJRT_Executable_Destroy_Args_STRUCT_SIZE,
335       args->struct_size));
336   delete args->executable;
337   return nullptr;
338 }
339 
PJRT_Executable_Name(PJRT_Executable_Name_Args * args)340 PJRT_Error* PJRT_Executable_Name(PJRT_Executable_Name_Args* args) {
341   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
342       "PJRT_Executable_Name_Args", PJRT_Executable_Name_Args_STRUCT_SIZE,
343       args->struct_size));
344   absl::string_view executable_name = args->executable->executable->name();
345   args->executable_name = executable_name.data();
346   args->executable_name_size = executable_name.size();
347   return nullptr;
348 }
349 
PJRT_Executable_AddressableDevices(PJRT_Executable_AddressableDevices_Args * args)350 PJRT_Error* PJRT_Executable_AddressableDevices(
351     PJRT_Executable_AddressableDevices_Args* args) {
352   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
353       "PJRT_Executable_AddressableDevices_Args",
354       PJRT_Executable_AddressableDevices_Args_STRUCT_SIZE, args->struct_size));
355 
356   // TODO(b/237545405): Implement creation methods for PJRT_Executable that can
357   // populate addressable_devices on instantiation,  and use this logic there
358   if (!args->executable->populated) {
359     PopulatePjrtExecutableAddressableDevices(args->executable);
360     args->executable->populated = true;
361   }
362 
363   args->num_addressable_devices = args->executable->addressable_devices.size();
364   args->addressable_devices = args->executable->addressable_devices.data();
365   return nullptr;
366 }
367 
PJRT_Executable_Delete(PJRT_Executable_Delete_Args * args)368 PJRT_Error* PJRT_Executable_Delete(PJRT_Executable_Delete_Args* args) {
369   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
370       "PJRT_Executable_Delete_Args", PJRT_Executable_Delete_Args_STRUCT_SIZE,
371       args->struct_size));
372   args->executable->executable->Delete();
373   return nullptr;
374 }
375 
PJRT_Executable_IsDeleted(PJRT_Executable_IsDeleted_Args * args)376 PJRT_Error* PJRT_Executable_IsDeleted(PJRT_Executable_IsDeleted_Args* args) {
377   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
378       "PJRT_Executable_IsDeleted_Args",
379       PJRT_Executable_IsDeleted_Args_STRUCT_SIZE, args->struct_size));
380   args->is_deleted = args->executable->executable->IsDeleted();
381   return nullptr;
382 }
383 
Convert2DCBuffersToCppBuffers(PJRT_Buffer *** c_lists,size_t outer_size,size_t inner_size)384 static std::vector<std::vector<xla::PjRtBuffer*>> Convert2DCBuffersToCppBuffers(
385     PJRT_Buffer*** c_lists, size_t outer_size, size_t inner_size) {
386   std::vector<std::vector<xla::PjRtBuffer*>> cpp_lists;
387   cpp_lists.reserve(outer_size);
388   for (int i = 0; i < outer_size; ++i) {
389     auto& cpp_list = cpp_lists.emplace_back();
390     cpp_list.reserve(inner_size);
391     for (int j = 0; j < inner_size; ++j) {
392       cpp_list.push_back(c_lists[i][j]->buffer.get());
393     }
394   }
395   return cpp_lists;
396 }
397 
PJRT_Executable_Execute(PJRT_Executable_Execute_Args * args)398 PJRT_Error* PJRT_Executable_Execute(PJRT_Executable_Execute_Args* args) {
399   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
400       "PJRT_Executable_Execute_Args", PJRT_Executable_Execute_Args_STRUCT_SIZE,
401       args->struct_size));
402   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes("PJRT_ExecuteOptions",
403                                                 PJRT_ExecuteOptions_STRUCT_SIZE,
404                                                 args->options->struct_size));
405   xla::ExecuteOptions options;
406   options.launch_id = args->options->launch_id;
407   options.strict_shape_checking = true;
408   options.arguments_are_tupled = false;
409   options.untuple_result = true;
410   options.context = nullptr;
411   options.multi_slice_config = nullptr;
412   std::vector<std::vector<xla::PjRtBuffer*>> cpp_argument_lists =
413       Convert2DCBuffersToCppBuffers(args->argument_lists, args->num_devices,
414                                     args->num_args);
415 
416   PJRT_ASSIGN_OR_RETURN(
417       std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>>
418           cpp_buffer_lists,
419       args->executable->executable->Execute(cpp_argument_lists, options));
420 
421   for (int i = 0; i < cpp_buffer_lists.size(); ++i) {
422     for (int j = 0; j < cpp_buffer_lists[i].size(); ++j) {
423       args->output_lists[i][j] = new PJRT_Buffer{
424           std::move(cpp_buffer_lists[i][j]), args->executable->client};
425     }
426   }
427   return nullptr;
428 }
429 
PJRT_Executable_NumOutputs(PJRT_Executable_NumOutputs_Args * args)430 PJRT_Error* PJRT_Executable_NumOutputs(PJRT_Executable_NumOutputs_Args* args) {
431   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
432       "PJRT_Executable_NumOutputs_Args",
433       PJRT_Executable_NumOutputs_Args_STRUCT_SIZE, args->struct_size));
434   PJRT_ASSIGN_OR_RETURN(
435       std::vector<std::shared_ptr<xla::HloModule>> hlo_modules,
436       args->executable->executable->GetHloModules());
437   if (hlo_modules.empty()) {
438     return new PJRT_Error{
439         xla::InvalidArgument("Can't get number of executable outputs, Hlo "
440                              "modules is empty for executable %s.",
441                              args->executable->executable->name())};
442   }
443   if (hlo_modules.size() != 1) {
444     return new PJRT_Error{
445         xla::Unimplemented("MPMD execution not supported by PJRT C API (in "
446                            "function PJRT_Executable_NumOutputs).")};
447   }
448   xla::Shape shape = hlo_modules[0].get()->result_shape();
449   if (shape.IsTuple()) {
450     args->num_outputs = shape.tuple_shapes_size();
451   } else {
452     // The output size is 1 is it is not a tuple.
453     args->num_outputs = 1;
454   }
455   return nullptr;
456 }
457 
458 // ---------------------------------- Buffers ----------------------------------
459 // TODO(b/238999986): Replace this with decomposed shape methods.
PJRT_Buffer_OnDeviceTrimmedShape(PJRT_Buffer_OnDeviceTrimmedShape_Args * args)460 PJRT_Error* PJRT_Buffer_OnDeviceTrimmedShape(
461     PJRT_Buffer_OnDeviceTrimmedShape_Args* args) {
462   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
463       "PJRT_Buffer_OnDeviceTrimmedShape_Args",
464       PJRT_Buffer_OnDeviceTrimmedShape_Args_STRUCT_SIZE, args->struct_size));
465 
466   const xla::Shape& shape = args->buffer->buffer->on_device_shape();
467   args->element_type = shape.element_type();
468   ApiConverter::CreateVector(shape.dimensions(), &args->dimensions);
469   ApiConverter::CreateVector(shape.dynamic_dimensions(),
470                              &args->dynamic_dimensions);
471 
472   if (shape.has_layout()) {
473     args->has_layout = true;
474     ApiConverter::ToC(shape.layout(), &args->layout);
475   } else {
476     args->has_layout = false;
477   }
478 
479   return nullptr;
480 }
481 
PJRT_Buffer_Destroy(PJRT_Buffer_Destroy_Args * args)482 PJRT_Error* PJRT_Buffer_Destroy(PJRT_Buffer_Destroy_Args* args) {
483   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
484       "PJRT_Buffer_Destroy_Args", PJRT_Buffer_Destroy_Args_STRUCT_SIZE,
485       args->struct_size));
486   delete args->buffer;
487   return nullptr;
488 }
489 
PJRT_Buffer_OnDeviceSizeInBytes(PJRT_Buffer_OnDeviceSizeInBytes_Args * args)490 PJRT_Error* PJRT_Buffer_OnDeviceSizeInBytes(
491     PJRT_Buffer_OnDeviceSizeInBytes_Args* args) {
492   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
493       "PJRT_Buffer_OnDeviceSizeInBytes_Args",
494       PJRT_Buffer_OnDeviceSizeInBytes_Args_STRUCT_SIZE, args->struct_size));
495   PJRT_ASSIGN_OR_RETURN(args->on_device_size_in_bytes,
496                         args->buffer->buffer->GetOnDeviceSizeInBytes());
497   return nullptr;
498 }
499 
PJRT_Buffer_Device(PJRT_Buffer_Device_Args * args)500 PJRT_Error* PJRT_Buffer_Device(PJRT_Buffer_Device_Args* args) {
501   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
502       "PJRT_Buffer_Device_Args", PJRT_Buffer_Device_Args_STRUCT_SIZE,
503       args->struct_size));
504   args->device = FindDeviceWrapper(args->buffer->buffer->device(),
505                                    args->buffer->client->addressable_devices);
506   CHECK(args->device != nullptr)
507       << "No PJRT_Device* found in the client's `addressable_devices` that "
508          "wraps this "
509       << args->buffer->buffer->device()->DebugString();
510   return nullptr;
511 }
512 
PJRT_Buffer_Delete(PJRT_Buffer_Delete_Args * args)513 PJRT_Error* PJRT_Buffer_Delete(PJRT_Buffer_Delete_Args* args) {
514   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
515       "PJRT_Buffer_Delete_Args", PJRT_Buffer_Delete_Args_STRUCT_SIZE,
516       args->struct_size));
517   args->buffer->buffer->Delete();
518   return nullptr;
519 }
520 
PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args * args)521 PJRT_Error* PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args* args) {
522   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
523       "PJRT_Buffer_IsDeleted_Args", PJRT_Buffer_IsDeleted_Args_STRUCT_SIZE,
524       args->struct_size));
525   args->is_deleted = args->buffer->buffer->IsDeleted();
526   return nullptr;
527 }
528 
PJRT_Buffer_CopyToDevice(PJRT_Buffer_CopyToDevice_Args * args)529 PJRT_Error* PJRT_Buffer_CopyToDevice(PJRT_Buffer_CopyToDevice_Args* args) {
530   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
531       "PJRT_Buffer_CopyToDevice_Args",
532       PJRT_Buffer_CopyToDevice_Args_STRUCT_SIZE, args->struct_size));
533   PJRT_ASSIGN_OR_RETURN(
534       std::unique_ptr<xla::PjRtBuffer> dst_buffer,
535       args->buffer->buffer->CopyToDevice(args->dst_device->device));
536   args->dst_buffer =
537       new PJRT_Buffer{std::move(dst_buffer), args->buffer->client};
538   return nullptr;
539 }
540 
PJRT_Buffer_IsOnCpu(PJRT_Buffer_IsOnCpu_Args * args)541 PJRT_Error* PJRT_Buffer_IsOnCpu(PJRT_Buffer_IsOnCpu_Args* args) {
542   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
543       "PJRT_Buffer_IsOnCpu_Args", PJRT_Buffer_IsOnCpu_Args_STRUCT_SIZE,
544       args->struct_size));
545   args->is_on_cpu = args->buffer->buffer->IsOnCpu();
546   return nullptr;
547 }
548 
PJRT_Buffer_ReadyEvent(PJRT_Buffer_ReadyEvent_Args * args)549 PJRT_Error* PJRT_Buffer_ReadyEvent(PJRT_Buffer_ReadyEvent_Args* args) {
550   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
551       "PJRT_Buffer_ReadyEvent_Args", PJRT_Buffer_ReadyEvent_Args_STRUCT_SIZE,
552       args->struct_size));
553   xla::PjRtFuture<xla::Status> wrapped_promise =
554       args->buffer->buffer->GetReadyFuture();
555   args->event = new PJRT_Event{std::move(wrapped_promise)};
556   return nullptr;
557 }
558 
559 // -------------------------------- Events -------------------------------------
560 
PJRT_Event_Destroy(PJRT_Event_Destroy_Args * args)561 PJRT_Error* PJRT_Event_Destroy(PJRT_Event_Destroy_Args* args) {
562   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
563       "PJRT_Event_Destroy", PJRT_Event_Destroy_Args_STRUCT_SIZE,
564       args->struct_size));
565 
566   delete args->event;
567   return nullptr;
568 }
569 
PJRT_Event_IsReady(PJRT_Event_IsReady_Args * args)570 PJRT_Error* PJRT_Event_IsReady(PJRT_Event_IsReady_Args* args) {
571   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
572       "PJRT_Event_IsReady", PJRT_Event_IsReady_Args_STRUCT_SIZE,
573       args->struct_size));
574 
575   args->is_ready = args->event->future.IsReady();
576   return nullptr;
577 }
578 
PJRT_Event_Await(PJRT_Event_Await_Args * args)579 PJRT_Error* PJRT_Event_Await(PJRT_Event_Await_Args* args) {
580   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
581       "PJRT_Event_Await", PJRT_Event_Await_Args_STRUCT_SIZE,
582       args->struct_size));
583 
584   PJRT_Event* event = args->event;
585   event->status.emplace(event->future.Await());
586   PJRT_RETURN_IF_ERROR(event->status.value());
587   return nullptr;
588 }
589 
PJRT_Event_Error(PJRT_Event_Error_Args * args)590 PJRT_Error* PJRT_Event_Error(PJRT_Event_Error_Args* args) {
591   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
592       "PJRT_Event_Error", PJRT_Event_Error_Args_STRUCT_SIZE,
593       args->struct_size));
594 
595   PJRT_Event* event = args->event;
596   CHECK(event->future.IsReady());
597   if (!event->status.has_value()) {
598     PJRT_Event_Await_Args await_args;
599     await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
600     await_args.priv = nullptr;
601     await_args.event = event;
602     return PJRT_Event_Await(&await_args);
603   }
604   PJRT_RETURN_IF_ERROR(event->status.value());
605   return nullptr;
606 }
607 
PJRT_Event_OnReady(PJRT_Event_OnReady_Args * args)608 PJRT_Error* PJRT_Event_OnReady(PJRT_Event_OnReady_Args* args) {
609   PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
610       "PJRT_Event_OnReady", PJRT_Event_OnReady_Args_STRUCT_SIZE,
611       args->struct_size));
612 
613   PJRT_Event_OnReadyCallback callback = args->callback;
614   void* user_arg = args->user_arg;
615   auto impl_callback = [callback, user_arg](xla::Status status) -> void {
616     PJRT_Error* error = new PJRT_Error{status};
617     callback(error, user_arg);
618   };
619   args->event->future.OnReady(impl_callback);
620   return nullptr;
621 }
622 
623 }  // namespace pjrt
624