1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <variant>
22 #include <vector>
23
24 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h"
25 #include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
26 #include "tensorflow/compiler/xla/shape.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 // TODO(b/238999986): Remove this.
29 #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
30
31 namespace pjrt {
32
CheckMatchingStructSizes(absl::string_view struct_name,size_t expected_size,size_t actual_size)33 xla::Status CheckMatchingStructSizes(absl::string_view struct_name,
34 size_t expected_size, size_t actual_size) {
35 if (expected_size != actual_size) {
36 return tensorflow::errors::InvalidArgument(
37 StructSizeErrorMsg(struct_name, expected_size, actual_size));
38 }
39 return tensorflow::OkStatus();
40 }
41
StructSizeErrorMsg(absl::string_view struct_name,size_t expected_size,size_t actual_size)42 std::string StructSizeErrorMsg(absl::string_view struct_name,
43 size_t expected_size, size_t actual_size) {
44 return absl::StrCat("Unexpected ", struct_name, " size: expected ",
45 expected_size, ", got ", actual_size,
46 ". Check installed software versions.");
47 }
48
49 // Returns C device from wrapped C++ device.
GetCDevice(const PJRT_Client * client,const xla::PjRtDevice * device)50 static PJRT_Device* GetCDevice(const PJRT_Client* client,
51 const xla::PjRtDevice* device) {
52 auto c_device_map = client->c_device_from_cpp_device;
53 auto iter = c_device_map.find(device);
54 CHECK(iter != c_device_map.end());
55 return iter->second;
56 }
57
58 // ---------------------------------- Errors -----------------------------------
59
PJRT_Error_Destroy(PJRT_Error_Destroy_Args * args)60 void PJRT_Error_Destroy(PJRT_Error_Destroy_Args* args) {
61 xla::Status struct_size_check = CheckMatchingStructSizes(
62 "PJRT_Error_Destroy_Args", PJRT_Error_Destroy_Args_STRUCT_SIZE,
63 args->struct_size);
64 if (!struct_size_check.ok()) {
65 LOG(ERROR) << struct_size_check.error_message();
66 }
67 if (args->struct_size >= PJRT_STRUCT_SIZE(PJRT_Error_Destroy_Args, error)) {
68 delete args->error;
69 }
70 }
71
PJRT_Error_Message(PJRT_Error_Message_Args * args)72 void PJRT_Error_Message(PJRT_Error_Message_Args* args) {
73 xla::Status struct_size_check = CheckMatchingStructSizes(
74 "PJRT_Error_Message_Args", PJRT_Error_Message_Args_STRUCT_SIZE,
75 args->struct_size);
76 if (!struct_size_check.ok()) {
77 LOG(ERROR) << struct_size_check.error_message();
78 }
79 if (args->struct_size >= PJRT_STRUCT_SIZE(PJRT_Error_Destroy_Args, error)) {
80 const xla::Status* status = &args->error->status;
81 args->message = status->error_message().data();
82 args->message_size = status->error_message().size();
83 }
84 }
85
PJRT_Error_GetCode(PJRT_Error_GetCode_Args * args)86 PJRT_Error* PJRT_Error_GetCode(PJRT_Error_GetCode_Args* args) {
87 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
88 "PJRT_Error_GetCode_Args", PJRT_Error_GetCode_Args_STRUCT_SIZE,
89 args->struct_size));
90 args->code = StatusCodeToPjrtErrorCode(args->error->status.code());
91 return nullptr;
92 }
93
94 // ---------------------------------- Client -----------------------------------
95
PJRT_Client_Destroy(PJRT_Client_Destroy_Args * args)96 PJRT_Error* PJRT_Client_Destroy(PJRT_Client_Destroy_Args* args) {
97 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
98 "PJRT_Client_Destroy_Args", PJRT_Client_Destroy_Args_STRUCT_SIZE,
99 args->struct_size));
100 delete args->client;
101 return nullptr;
102 }
103
PJRT_Client_ProcessIndex(PJRT_Client_ProcessIndex_Args * args)104 PJRT_Error* PJRT_Client_ProcessIndex(PJRT_Client_ProcessIndex_Args* args) {
105 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
106 "PJRT_CLient_ProcessIndex_Args",
107 PJRT_Client_ProcessIndex_Args_STRUCT_SIZE, args->struct_size));
108 args->process_index = args->client->client->process_index();
109 return nullptr;
110 }
111
PJRT_Client_PlatformName(PJRT_Client_PlatformName_Args * args)112 PJRT_Error* PJRT_Client_PlatformName(PJRT_Client_PlatformName_Args* args) {
113 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
114 "PJRT_Client_PlatformName_Args",
115 PJRT_Client_PlatformName_Args_STRUCT_SIZE, args->struct_size));
116 absl::string_view platform_name = args->client->client->platform_name();
117 args->platform_name = platform_name.data();
118 args->platform_name_size = platform_name.size();
119 return nullptr;
120 }
121
PJRT_Client_PlatformVersion(PJRT_Client_PlatformVersion_Args * args)122 PJRT_Error* PJRT_Client_PlatformVersion(
123 PJRT_Client_PlatformVersion_Args* args) {
124 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
125 "PJRT_CLient_PlatformVersion_Args",
126 PJRT_Client_PlatformVersion_Args_STRUCT_SIZE, args->struct_size));
127 absl::string_view platform_version = args->client->client->platform_version();
128 args->platform_version = platform_version.data();
129 args->platform_version_size = platform_version.size();
130 return nullptr;
131 }
132
PJRT_Client_Devices(PJRT_Client_Devices_Args * args)133 PJRT_Error* PJRT_Client_Devices(PJRT_Client_Devices_Args* args) {
134 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
135 "PJRT_Client_Devices_Args", PJRT_Client_Devices_Args_STRUCT_SIZE,
136 args->struct_size));
137 args->num_devices = args->client->devices.size();
138 args->devices = args->client->devices.data();
139 return nullptr;
140 }
141
PJRT_Client_AddressableDevices(PJRT_Client_AddressableDevices_Args * args)142 PJRT_Error* PJRT_Client_AddressableDevices(
143 PJRT_Client_AddressableDevices_Args* args) {
144 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
145 "PJRT_Client_AddressableDevices_Args",
146 PJRT_Client_AddressableDevices_Args_STRUCT_SIZE, args->struct_size));
147 args->num_addressable_devices = args->client->addressable_devices.size();
148 args->addressable_devices = args->client->addressable_devices.data();
149 return nullptr;
150 }
151
PJRT_Client_LookupDevice(PJRT_Client_LookupDevice_Args * args)152 PJRT_Error* PJRT_Client_LookupDevice(PJRT_Client_LookupDevice_Args* args) {
153 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
154 "PJRT_Client_LookupDevice_Args",
155 PJRT_Client_LookupDevice_Args_STRUCT_SIZE, args->struct_size));
156 PJRT_ASSIGN_OR_RETURN(xla::PjRtDevice * device,
157 args->client->client->LookupDevice(args->id));
158 args->device = GetCDevice(args->client, device);
159 return nullptr;
160 }
161
162 // Searches `device_list` for a PJRT_Device* that wraps a provided
163 // `xla::PjRtDevice *` (`cpp_device`). If a match is found, that PJRT_Device* is
164 // returned. Otherwise, returns nullptr.
FindDeviceWrapper(xla::PjRtDevice * cpp_device,absl::Span<PJRT_Device * const> device_list)165 static PJRT_Device* FindDeviceWrapper(
166 xla::PjRtDevice* cpp_device, absl::Span<PJRT_Device* const> device_list) {
167 for (PJRT_Device* device : device_list) {
168 if (device->device == cpp_device) {
169 return device;
170 }
171 }
172 return nullptr;
173 }
174
PopulatePjrtExecutableAddressableDevices(PJRT_Executable * executable)175 static void PopulatePjrtExecutableAddressableDevices(
176 PJRT_Executable* executable) {
177 CHECK(executable->client != nullptr) << ": client was null";
178 absl::Span<xla::PjRtDevice* const> cpp_devices =
179 executable->executable->addressable_devices();
180 const size_t num_addressable_devices = cpp_devices.size();
181 std::vector<PJRT_Device*>& exec_devices = executable->addressable_devices;
182 exec_devices.reserve(num_addressable_devices);
183
184 const std::vector<PJRT_Device*>& client_devices =
185 executable->client->addressable_devices;
186
187 CHECK(client_devices.size() >= num_addressable_devices)
188 << ": client->addressable_devices is not bigger than "
189 "executable->addressable_devices()";
190
191 for (int i = 0; i < num_addressable_devices; ++i) {
192 xla::PjRtDevice* cpp_device = cpp_devices[i];
193 PJRT_Device* device = FindDeviceWrapper(cpp_device, client_devices);
194 CHECK(device != nullptr)
195 << ": No PJRT_Device* found in client->addressable_devices"
196 << " that wraps executable->addressable_devices()[" << i << "] ("
197 << cpp_devices[i] << ")";
198 exec_devices.push_back(device);
199 }
200 }
201
202 static xla::StatusOr<xla::CompileOptions>
ConvertCCompileOptionstoCppCompileOptions(PJRT_CompileOptions * c_option)203 ConvertCCompileOptionstoCppCompileOptions(PJRT_CompileOptions* c_option) {
204 xla::CompileOptions ret;
205 ret.parameter_is_tupled_arguments = c_option->parameter_is_tupled_arguments;
206 if (c_option->device_ordinal != -1) {
207 ret.executable_build_options.set_device_ordinal(c_option->device_ordinal);
208 }
209 ret.executable_build_options.set_num_replicas(c_option->num_replicas);
210 ret.executable_build_options.set_num_partitions(c_option->num_partitions);
211 ret.executable_build_options.set_use_spmd_partitioning(
212 c_option->use_spmd_partitioning);
213 ret.executable_build_options.set_allow_spmd_sharding_propagation_to_output(
214 c_option->allow_spmd_sharding_propagation_to_output);
215 if (c_option->device_assignment_size > 0) {
216 absl::string_view device_assignment_sv(c_option->device_assignment,
217 c_option->device_assignment_size);
218 std::string device_assignment_str(device_assignment_sv);
219 xla::DeviceAssignmentProto proto;
220 proto.ParseFromString(device_assignment_str);
221 TF_ASSIGN_OR_RETURN(
222 std::unique_ptr<xla::DeviceAssignment> device_assignment,
223 xla::DeviceAssignment::Deserialize(proto));
224 ret.executable_build_options.set_device_assignment(*device_assignment);
225 }
226 return ret;
227 }
228
PJRT_Client_Compile(PJRT_Client_Compile_Args * args)229 PJRT_Error* PJRT_Client_Compile(PJRT_Client_Compile_Args* args) {
230 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
231 "PJRT_Client_Compile_Args", PJRT_Client_Compile_Args_STRUCT_SIZE,
232 args->struct_size));
233 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes("PJRT_CompileOptions",
234 PJRT_CompileOptions_STRUCT_SIZE,
235 args->options->struct_size));
236 PJRT_ASSIGN_OR_RETURN(
237 xla::CompileOptions options,
238 ConvertCCompileOptionstoCppCompileOptions(args->options));
239 absl::string_view module_str(args->module, args->module_size);
240 mlir::MLIRContext context;
241 PJRT_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
242 xla::ParseMlirModuleString(module_str, context));
243
244 PJRT_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtLoadedExecutable> executable,
245 args->client->client->Compile(*module, options));
246 // TODO(b/237545405): Implement creation methods for PJRT_Executable.
247 args->executable = new PJRT_Executable{std::move(executable), args->client};
248 PopulatePjrtExecutableAddressableDevices(args->executable);
249 args->executable->populated = true;
250 return nullptr;
251 }
252
253 // --------------------------------- Devices -----------------------------------
254
PJRT_Device_Id(PJRT_Device_Id_Args * args)255 PJRT_Error* PJRT_Device_Id(PJRT_Device_Id_Args* args) {
256 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes("PJRT_Device_Id_Args",
257 PJRT_Device_Id_Args_STRUCT_SIZE,
258 args->struct_size));
259
260 args->id = args->device->device->id();
261 return nullptr;
262 }
263
PJRT_Device_ProcessIndex(PJRT_Device_ProcessIndex_Args * args)264 PJRT_Error* PJRT_Device_ProcessIndex(PJRT_Device_ProcessIndex_Args* args) {
265 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
266 "PJRT_Device_ProcessIndex_Args",
267 PJRT_Device_ProcessIndex_Args_STRUCT_SIZE, args->struct_size));
268 args->process_index = args->device->device->process_index();
269 return nullptr;
270 }
271
PJRT_Device_IsAddressable(PJRT_Device_IsAddressable_Args * args)272 PJRT_Error* PJRT_Device_IsAddressable(PJRT_Device_IsAddressable_Args* args) {
273 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
274 "PJRT_Device_IsAddressable_Args",
275 PJRT_Device_IsAddressable_Args_STRUCT_SIZE, args->struct_size));
276 args->is_addressable = args->device->device->IsAddressable();
277 return nullptr;
278 }
279
PJRT_Device_Attributes(PJRT_Device_Attributes_Args * args)280 PJRT_Error* PJRT_Device_Attributes(PJRT_Device_Attributes_Args* args) {
281 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
282 "PJRT_Device_Attributes_Args", PJRT_Device_Attributes_Args_STRUCT_SIZE,
283 args->struct_size));
284
285 // Returns the attributes that were initialized during PJRT_Device creation.
286 args->num_attributes = args->device->attributes.size();
287 args->attributes = args->device->attributes.data();
288
289 return nullptr;
290 }
291
PJRT_Device_Kind(PJRT_Device_Kind_Args * args)292 PJRT_Error* PJRT_Device_Kind(PJRT_Device_Kind_Args* args) {
293 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
294 "PJRT_Device_Kind_Args", PJRT_Device_Kind_Args_STRUCT_SIZE,
295 args->struct_size));
296
297 args->device_kind = args->device->device->device_kind().data();
298 args->device_kind_size = args->device->device->device_kind().size();
299 return nullptr;
300 }
301
PJRT_Device_LocalHardwareId(PJRT_Device_LocalHardwareId_Args * args)302 PJRT_Error* PJRT_Device_LocalHardwareId(
303 PJRT_Device_LocalHardwareId_Args* args) {
304 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
305 "PJRT_Device_LocalHardwareId_Args",
306 PJRT_Device_LocalHardwareId_Args_STRUCT_SIZE, args->struct_size));
307 args->local_hardware_id = args->device->device->local_hardware_id();
308 return nullptr;
309 }
310
PJRT_Device_DebugString(PJRT_Device_DebugString_Args * args)311 PJRT_Error* PJRT_Device_DebugString(PJRT_Device_DebugString_Args* args) {
312 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
313 "PJRT_Device_DebugString_Args", PJRT_Device_DebugString_Args_STRUCT_SIZE,
314 args->struct_size));
315
316 args->debug_string = args->device->device->DebugString().data();
317 args->debug_string_size = args->device->device->DebugString().size();
318 return nullptr;
319 }
320
PJRT_Device_ToString(PJRT_Device_ToString_Args * args)321 PJRT_Error* PJRT_Device_ToString(PJRT_Device_ToString_Args* args) {
322 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
323 "PJRT_Device_ToString_Args", PJRT_Device_ToString_Args_STRUCT_SIZE,
324 args->struct_size));
325 args->to_string = args->device->device->ToString().data();
326 args->to_string_size = args->device->device->ToString().size();
327 return nullptr;
328 }
329
330 // ------------------------------- Executables ---------------------------------
331
PJRT_Executable_Destroy(PJRT_Executable_Destroy_Args * args)332 PJRT_Error* PJRT_Executable_Destroy(PJRT_Executable_Destroy_Args* args) {
333 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
334 "PJRT_Executable_Destroy_Args", PJRT_Executable_Destroy_Args_STRUCT_SIZE,
335 args->struct_size));
336 delete args->executable;
337 return nullptr;
338 }
339
PJRT_Executable_Name(PJRT_Executable_Name_Args * args)340 PJRT_Error* PJRT_Executable_Name(PJRT_Executable_Name_Args* args) {
341 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
342 "PJRT_Executable_Name_Args", PJRT_Executable_Name_Args_STRUCT_SIZE,
343 args->struct_size));
344 absl::string_view executable_name = args->executable->executable->name();
345 args->executable_name = executable_name.data();
346 args->executable_name_size = executable_name.size();
347 return nullptr;
348 }
349
PJRT_Executable_AddressableDevices(PJRT_Executable_AddressableDevices_Args * args)350 PJRT_Error* PJRT_Executable_AddressableDevices(
351 PJRT_Executable_AddressableDevices_Args* args) {
352 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
353 "PJRT_Executable_AddressableDevices_Args",
354 PJRT_Executable_AddressableDevices_Args_STRUCT_SIZE, args->struct_size));
355
356 // TODO(b/237545405): Implement creation methods for PJRT_Executable that can
357 // populate addressable_devices on instantiation, and use this logic there
358 if (!args->executable->populated) {
359 PopulatePjrtExecutableAddressableDevices(args->executable);
360 args->executable->populated = true;
361 }
362
363 args->num_addressable_devices = args->executable->addressable_devices.size();
364 args->addressable_devices = args->executable->addressable_devices.data();
365 return nullptr;
366 }
367
PJRT_Executable_Delete(PJRT_Executable_Delete_Args * args)368 PJRT_Error* PJRT_Executable_Delete(PJRT_Executable_Delete_Args* args) {
369 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
370 "PJRT_Executable_Delete_Args", PJRT_Executable_Delete_Args_STRUCT_SIZE,
371 args->struct_size));
372 args->executable->executable->Delete();
373 return nullptr;
374 }
375
PJRT_Executable_IsDeleted(PJRT_Executable_IsDeleted_Args * args)376 PJRT_Error* PJRT_Executable_IsDeleted(PJRT_Executable_IsDeleted_Args* args) {
377 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
378 "PJRT_Executable_IsDeleted_Args",
379 PJRT_Executable_IsDeleted_Args_STRUCT_SIZE, args->struct_size));
380 args->is_deleted = args->executable->executable->IsDeleted();
381 return nullptr;
382 }
383
Convert2DCBuffersToCppBuffers(PJRT_Buffer *** c_lists,size_t outer_size,size_t inner_size)384 static std::vector<std::vector<xla::PjRtBuffer*>> Convert2DCBuffersToCppBuffers(
385 PJRT_Buffer*** c_lists, size_t outer_size, size_t inner_size) {
386 std::vector<std::vector<xla::PjRtBuffer*>> cpp_lists;
387 cpp_lists.reserve(outer_size);
388 for (int i = 0; i < outer_size; ++i) {
389 auto& cpp_list = cpp_lists.emplace_back();
390 cpp_list.reserve(inner_size);
391 for (int j = 0; j < inner_size; ++j) {
392 cpp_list.push_back(c_lists[i][j]->buffer.get());
393 }
394 }
395 return cpp_lists;
396 }
397
PJRT_Executable_Execute(PJRT_Executable_Execute_Args * args)398 PJRT_Error* PJRT_Executable_Execute(PJRT_Executable_Execute_Args* args) {
399 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
400 "PJRT_Executable_Execute_Args", PJRT_Executable_Execute_Args_STRUCT_SIZE,
401 args->struct_size));
402 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes("PJRT_ExecuteOptions",
403 PJRT_ExecuteOptions_STRUCT_SIZE,
404 args->options->struct_size));
405 xla::ExecuteOptions options;
406 options.launch_id = args->options->launch_id;
407 options.strict_shape_checking = true;
408 options.arguments_are_tupled = false;
409 options.untuple_result = true;
410 options.context = nullptr;
411 options.multi_slice_config = nullptr;
412 std::vector<std::vector<xla::PjRtBuffer*>> cpp_argument_lists =
413 Convert2DCBuffersToCppBuffers(args->argument_lists, args->num_devices,
414 args->num_args);
415
416 PJRT_ASSIGN_OR_RETURN(
417 std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>>
418 cpp_buffer_lists,
419 args->executable->executable->Execute(cpp_argument_lists, options));
420
421 for (int i = 0; i < cpp_buffer_lists.size(); ++i) {
422 for (int j = 0; j < cpp_buffer_lists[i].size(); ++j) {
423 args->output_lists[i][j] = new PJRT_Buffer{
424 std::move(cpp_buffer_lists[i][j]), args->executable->client};
425 }
426 }
427 return nullptr;
428 }
429
PJRT_Executable_NumOutputs(PJRT_Executable_NumOutputs_Args * args)430 PJRT_Error* PJRT_Executable_NumOutputs(PJRT_Executable_NumOutputs_Args* args) {
431 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
432 "PJRT_Executable_NumOutputs_Args",
433 PJRT_Executable_NumOutputs_Args_STRUCT_SIZE, args->struct_size));
434 PJRT_ASSIGN_OR_RETURN(
435 std::vector<std::shared_ptr<xla::HloModule>> hlo_modules,
436 args->executable->executable->GetHloModules());
437 if (hlo_modules.empty()) {
438 return new PJRT_Error{
439 xla::InvalidArgument("Can't get number of executable outputs, Hlo "
440 "modules is empty for executable %s.",
441 args->executable->executable->name())};
442 }
443 if (hlo_modules.size() != 1) {
444 return new PJRT_Error{
445 xla::Unimplemented("MPMD execution not supported by PJRT C API (in "
446 "function PJRT_Executable_NumOutputs).")};
447 }
448 xla::Shape shape = hlo_modules[0].get()->result_shape();
449 if (shape.IsTuple()) {
450 args->num_outputs = shape.tuple_shapes_size();
451 } else {
452 // The output size is 1 is it is not a tuple.
453 args->num_outputs = 1;
454 }
455 return nullptr;
456 }
457
458 // ---------------------------------- Buffers ----------------------------------
459 // TODO(b/238999986): Replace this with decomposed shape methods.
PJRT_Buffer_OnDeviceTrimmedShape(PJRT_Buffer_OnDeviceTrimmedShape_Args * args)460 PJRT_Error* PJRT_Buffer_OnDeviceTrimmedShape(
461 PJRT_Buffer_OnDeviceTrimmedShape_Args* args) {
462 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
463 "PJRT_Buffer_OnDeviceTrimmedShape_Args",
464 PJRT_Buffer_OnDeviceTrimmedShape_Args_STRUCT_SIZE, args->struct_size));
465
466 const xla::Shape& shape = args->buffer->buffer->on_device_shape();
467 args->element_type = shape.element_type();
468 ApiConverter::CreateVector(shape.dimensions(), &args->dimensions);
469 ApiConverter::CreateVector(shape.dynamic_dimensions(),
470 &args->dynamic_dimensions);
471
472 if (shape.has_layout()) {
473 args->has_layout = true;
474 ApiConverter::ToC(shape.layout(), &args->layout);
475 } else {
476 args->has_layout = false;
477 }
478
479 return nullptr;
480 }
481
PJRT_Buffer_Destroy(PJRT_Buffer_Destroy_Args * args)482 PJRT_Error* PJRT_Buffer_Destroy(PJRT_Buffer_Destroy_Args* args) {
483 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
484 "PJRT_Buffer_Destroy_Args", PJRT_Buffer_Destroy_Args_STRUCT_SIZE,
485 args->struct_size));
486 delete args->buffer;
487 return nullptr;
488 }
489
PJRT_Buffer_OnDeviceSizeInBytes(PJRT_Buffer_OnDeviceSizeInBytes_Args * args)490 PJRT_Error* PJRT_Buffer_OnDeviceSizeInBytes(
491 PJRT_Buffer_OnDeviceSizeInBytes_Args* args) {
492 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
493 "PJRT_Buffer_OnDeviceSizeInBytes_Args",
494 PJRT_Buffer_OnDeviceSizeInBytes_Args_STRUCT_SIZE, args->struct_size));
495 PJRT_ASSIGN_OR_RETURN(args->on_device_size_in_bytes,
496 args->buffer->buffer->GetOnDeviceSizeInBytes());
497 return nullptr;
498 }
499
PJRT_Buffer_Device(PJRT_Buffer_Device_Args * args)500 PJRT_Error* PJRT_Buffer_Device(PJRT_Buffer_Device_Args* args) {
501 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
502 "PJRT_Buffer_Device_Args", PJRT_Buffer_Device_Args_STRUCT_SIZE,
503 args->struct_size));
504 args->device = FindDeviceWrapper(args->buffer->buffer->device(),
505 args->buffer->client->addressable_devices);
506 CHECK(args->device != nullptr)
507 << "No PJRT_Device* found in the client's `addressable_devices` that "
508 "wraps this "
509 << args->buffer->buffer->device()->DebugString();
510 return nullptr;
511 }
512
PJRT_Buffer_Delete(PJRT_Buffer_Delete_Args * args)513 PJRT_Error* PJRT_Buffer_Delete(PJRT_Buffer_Delete_Args* args) {
514 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
515 "PJRT_Buffer_Delete_Args", PJRT_Buffer_Delete_Args_STRUCT_SIZE,
516 args->struct_size));
517 args->buffer->buffer->Delete();
518 return nullptr;
519 }
520
PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args * args)521 PJRT_Error* PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args* args) {
522 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
523 "PJRT_Buffer_IsDeleted_Args", PJRT_Buffer_IsDeleted_Args_STRUCT_SIZE,
524 args->struct_size));
525 args->is_deleted = args->buffer->buffer->IsDeleted();
526 return nullptr;
527 }
528
PJRT_Buffer_CopyToDevice(PJRT_Buffer_CopyToDevice_Args * args)529 PJRT_Error* PJRT_Buffer_CopyToDevice(PJRT_Buffer_CopyToDevice_Args* args) {
530 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
531 "PJRT_Buffer_CopyToDevice_Args",
532 PJRT_Buffer_CopyToDevice_Args_STRUCT_SIZE, args->struct_size));
533 PJRT_ASSIGN_OR_RETURN(
534 std::unique_ptr<xla::PjRtBuffer> dst_buffer,
535 args->buffer->buffer->CopyToDevice(args->dst_device->device));
536 args->dst_buffer =
537 new PJRT_Buffer{std::move(dst_buffer), args->buffer->client};
538 return nullptr;
539 }
540
PJRT_Buffer_IsOnCpu(PJRT_Buffer_IsOnCpu_Args * args)541 PJRT_Error* PJRT_Buffer_IsOnCpu(PJRT_Buffer_IsOnCpu_Args* args) {
542 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
543 "PJRT_Buffer_IsOnCpu_Args", PJRT_Buffer_IsOnCpu_Args_STRUCT_SIZE,
544 args->struct_size));
545 args->is_on_cpu = args->buffer->buffer->IsOnCpu();
546 return nullptr;
547 }
548
PJRT_Buffer_ReadyEvent(PJRT_Buffer_ReadyEvent_Args * args)549 PJRT_Error* PJRT_Buffer_ReadyEvent(PJRT_Buffer_ReadyEvent_Args* args) {
550 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
551 "PJRT_Buffer_ReadyEvent_Args", PJRT_Buffer_ReadyEvent_Args_STRUCT_SIZE,
552 args->struct_size));
553 xla::PjRtFuture<xla::Status> wrapped_promise =
554 args->buffer->buffer->GetReadyFuture();
555 args->event = new PJRT_Event{std::move(wrapped_promise)};
556 return nullptr;
557 }
558
559 // -------------------------------- Events -------------------------------------
560
PJRT_Event_Destroy(PJRT_Event_Destroy_Args * args)561 PJRT_Error* PJRT_Event_Destroy(PJRT_Event_Destroy_Args* args) {
562 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
563 "PJRT_Event_Destroy", PJRT_Event_Destroy_Args_STRUCT_SIZE,
564 args->struct_size));
565
566 delete args->event;
567 return nullptr;
568 }
569
PJRT_Event_IsReady(PJRT_Event_IsReady_Args * args)570 PJRT_Error* PJRT_Event_IsReady(PJRT_Event_IsReady_Args* args) {
571 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
572 "PJRT_Event_IsReady", PJRT_Event_IsReady_Args_STRUCT_SIZE,
573 args->struct_size));
574
575 args->is_ready = args->event->future.IsReady();
576 return nullptr;
577 }
578
PJRT_Event_Await(PJRT_Event_Await_Args * args)579 PJRT_Error* PJRT_Event_Await(PJRT_Event_Await_Args* args) {
580 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
581 "PJRT_Event_Await", PJRT_Event_Await_Args_STRUCT_SIZE,
582 args->struct_size));
583
584 PJRT_Event* event = args->event;
585 event->status.emplace(event->future.Await());
586 PJRT_RETURN_IF_ERROR(event->status.value());
587 return nullptr;
588 }
589
PJRT_Event_Error(PJRT_Event_Error_Args * args)590 PJRT_Error* PJRT_Event_Error(PJRT_Event_Error_Args* args) {
591 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
592 "PJRT_Event_Error", PJRT_Event_Error_Args_STRUCT_SIZE,
593 args->struct_size));
594
595 PJRT_Event* event = args->event;
596 CHECK(event->future.IsReady());
597 if (!event->status.has_value()) {
598 PJRT_Event_Await_Args await_args;
599 await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
600 await_args.priv = nullptr;
601 await_args.event = event;
602 return PJRT_Event_Await(&await_args);
603 }
604 PJRT_RETURN_IF_ERROR(event->status.value());
605 return nullptr;
606 }
607
PJRT_Event_OnReady(PJRT_Event_OnReady_Args * args)608 PJRT_Error* PJRT_Event_OnReady(PJRT_Event_OnReady_Args* args) {
609 PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
610 "PJRT_Event_OnReady", PJRT_Event_OnReady_Args_STRUCT_SIZE,
611 args->struct_size));
612
613 PJRT_Event_OnReadyCallback callback = args->callback;
614 void* user_arg = args->user_arg;
615 auto impl_callback = [callback, user_arg](xla::Status status) -> void {
616 PJRT_Error* error = new PJRT_Error{status};
617 callback(error, user_arg);
618 };
619 args->event->future.OnReady(impl_callback);
620 return nullptr;
621 }
622
623 } // namespace pjrt
624