xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h"
17 
18 #include <memory>
19 #include <optional>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
25 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h"
26 // TODO(skyewm): remove when everything goes through C API
27 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h"
28 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
29 #include "tensorflow/compiler/xla/shape.h"
30 #include "tensorflow/compiler/xla/util.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/tpu/pjrt_api.h"
33 
34 // TODO(b/238999986): Remove this when we have decomposed shape.
35 #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
36 
37 namespace xla {
38 
39 // Helper macros
40 
// Evaluates `expr`, which must yield a `PJRT_Error*`, and takes ownership of
// the result so that it is released through the deleter built from `c_api`
// when the macro's scope ends. If the error converts to a non-OK
// `xla::Status`, that status is returned from the enclosing function.
#define RETURN_STATUS_IF_ERROR(expr, c_api)                               \
  do {                                                                    \
    PJRT_Error* _c_error = (expr);                                        \
    std::unique_ptr<PJRT_Error, pjrt::PJRT_ErrorDeleter> _error_owner(    \
        _c_error, pjrt::MakeErrorDeleter(c_api));                         \
    xla::Status _converted =                                              \
        pjrt::PjrtErrorToStatus(_error_owner.get(), c_api);               \
    if (!_converted.ok()) {                                               \
      return _converted;                                                  \
    }                                                                     \
  } while (false)
53 
54 // ---------------------------------- Client -----------------------------------
55 
PjRtCApiClient(const PJRT_Api * c_api,PJRT_Client * c_client)56 PjRtCApiClient::PjRtCApiClient(const PJRT_Api* c_api, PJRT_Client* c_client)
57     : c_api_(c_api),
58       c_client_(std::unique_ptr<PJRT_Client, ::pjrt::PJRT_ClientDeleter>(
59           c_client, ::pjrt::MakeClientDeleter(c_api))) {
60   wrapped_ = c_client_->client.get();
61 
62   InitDevices();
63 }
64 
InitDevices()65 void PjRtCApiClient::InitDevices() {
66   PJRT_Client_Devices_Args devices_args;
67   devices_args.struct_size = PJRT_Client_Devices_Args_STRUCT_SIZE;
68   devices_args.priv = nullptr;
69   devices_args.client = c_client_.get();
70 
71   pjrt::LogFatalIfPjrtError(c_api_->PJRT_Client_Devices(&devices_args), c_api_);
72 
73   const size_t n = devices_args.num_devices;
74   wrapped_device_map_.reserve(n);
75   c_to_cpp_device_map_.reserve(n);
76   owned_devices_.reserve(n);
77   devices_.reserve(n);
78 
79   for (size_t i = 0; i < n; ++i) {
80     PJRT_Device* device = devices_args.devices[i];
81     std::unique_ptr<PjRtCApiDevice>& cpp_device = owned_devices_.emplace_back(
82         std::make_unique<PjRtCApiDevice>(device, this));
83     devices_.push_back(cpp_device.get());
84     c_to_cpp_device_map_[device] = cpp_device.get();
85     // Map the wrapped PjRtDevice* to the PjRtCApiDevice* that wraps it.
86     // TODO(b/237017893): remove `wrapped_device_map_` and replace it with
87     // `c_api_device_map_`
88     wrapped_device_map_[device->device] = cpp_device.get();
89   }
90 
91   PJRT_Client_AddressableDevices_Args address_args;
92   address_args.struct_size = PJRT_Client_AddressableDevices_Args_STRUCT_SIZE;
93   address_args.priv = nullptr;
94   address_args.client = c_client_.get();
95 
96   pjrt::LogFatalIfPjrtError(
97       c_api_->PJRT_Client_AddressableDevices(&address_args), c_api_);
98 
99   const size_t m = address_args.num_addressable_devices;
100   addressable_devices_.reserve(m);
101 
102   for (size_t i = 0; i < m; ++i) {
103     PJRT_Device* c_device = address_args.addressable_devices[i];
104     addressable_devices_.push_back(GetCppDevice(c_device));
105   }
106 }
107 
device_count() const108 int PjRtCApiClient::device_count() const { return devices_.size(); }
109 
addressable_device_count() const110 int PjRtCApiClient::addressable_device_count() const {
111   return addressable_devices_.size();
112 }
113 
devices() const114 absl::Span<PjRtDevice* const> PjRtCApiClient::devices() const {
115   return devices_;
116 }
117 
addressable_devices() const118 absl::Span<PjRtDevice* const> PjRtCApiClient::addressable_devices() const {
119   return addressable_devices_;
120 }
121 
platform_name() const122 absl::string_view PjRtCApiClient::platform_name() const {
123   PJRT_Client_PlatformName_Args args;
124   args.client = c_client_.get();
125   args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE;
126   args.priv = nullptr;
127   pjrt::LogFatalIfPjrtError(c_api_->PJRT_Client_PlatformName(&args), c_api_);
128 
129   absl::string_view platform_name(args.platform_name, args.platform_name_size);
130   return platform_name;
131 }
132 
process_index() const133 int PjRtCApiClient::process_index() const {
134   PJRT_Client_ProcessIndex_Args process_index_args;
135   process_index_args.struct_size = PJRT_Client_ProcessIndex_Args_STRUCT_SIZE;
136   process_index_args.priv = nullptr;
137   process_index_args.client = c_client_.get();
138   pjrt::LogFatalIfPjrtError(
139       c_api_->PJRT_Client_ProcessIndex(&process_index_args), c_api_);
140 
141   return process_index_args.process_index;
142 }
143 
platform_version() const144 absl::string_view PjRtCApiClient::platform_version() const {
145   PJRT_Client_PlatformVersion_Args args;
146   args.struct_size = PJRT_Client_PlatformVersion_Args_STRUCT_SIZE;
147   args.priv = nullptr;
148   args.client = c_client_.get();
149   pjrt::LogFatalIfPjrtError(c_api_->PJRT_Client_PlatformVersion(&args), c_api_);
150 
151   absl::string_view platform_version(args.platform_version,
152                                      args.platform_version_size);
153   return platform_version;
154 }
155 
ExecutableFingerprint(const PjRtLoadedExecutable & executable) const156 StatusOr<std::optional<std::string>> PjRtCApiClient::ExecutableFingerprint(
157     const PjRtLoadedExecutable& executable) const {
158   return {std::nullopt};
159 }
160 
LookupDevice(int device_id) const161 StatusOr<PjRtDevice*> PjRtCApiClient::LookupDevice(int device_id) const {
162   PJRT_Client_LookupDevice_Args args;
163   args.struct_size = PJRT_Client_LookupDevice_Args_STRUCT_SIZE;
164   args.priv = nullptr;
165   args.client = c_client_.get();
166   args.id = device_id;
167   RETURN_STATUS_IF_ERROR(c_api_->PJRT_Client_LookupDevice(&args), c_api_);
168   return GetCppDevice(args.device);
169 }
170 
ValidateCompileOption(CompileOptions options)171 static Status ValidateCompileOption(CompileOptions options) {
172   if (options.argument_layouts.has_value()) {
173     return xla::Unimplemented(
174         "argument_layouts in CompileOptions is not supported.");
175   }
176   if (options.compile_portable_executable) {
177     return xla::Unimplemented(
178         "compile_portable_executable in CompileOptions is not supported.");
179   }
180   if (options.profile_version != 0) {
181     return xla::Unimplemented(
182         "profile_version in CompileOptions is not supported.");
183   }
184   if (options.multi_slice_config != nullptr) {
185     return xla::Unimplemented(
186         "multi_slice_config in CompileOptions is not supported.");
187   }
188   return xla::OkStatus();
189 }
190 
191 // Convert `CompileOptions` to `PJRT_CompileOptions`. `device_assignment_str`
192 // will be used for serialized DeviceAssignment storage.
ConvertCppCompileOptionsToCCompileOptions(CompileOptions options,std::string * device_assignment_str)193 static StatusOr<PJRT_CompileOptions> ConvertCppCompileOptionsToCCompileOptions(
194     CompileOptions options, std::string* device_assignment_str) {
195   PJRT_CompileOptions c_options;
196   c_options.struct_size = PJRT_CompileOptions_STRUCT_SIZE;
197   c_options.parameter_is_tupled_arguments =
198       options.parameter_is_tupled_arguments;
199   c_options.device_ordinal = options.executable_build_options.device_ordinal();
200   c_options.num_replicas = options.executable_build_options.num_replicas();
201   c_options.num_partitions = options.executable_build_options.num_partitions();
202   c_options.use_spmd_partitioning =
203       options.executable_build_options.use_spmd_partitioning();
204   c_options.allow_spmd_sharding_propagation_to_output =
205       options.executable_build_options
206           .allow_spmd_sharding_propagation_to_output();
207 
208   if (options.executable_build_options.has_device_assignment()) {
209     DeviceAssignmentProto device_assignment_proto;
210     TF_RETURN_IF_ERROR(
211         options.executable_build_options.device_assignment().Serialize(
212             &device_assignment_proto));
213     *device_assignment_str = device_assignment_proto.SerializeAsString();
214     c_options.device_assignment = device_assignment_str->c_str();
215     c_options.device_assignment_size = device_assignment_str->size();
216   } else {
217     c_options.device_assignment_size = 0;
218     c_options.device_assignment = nullptr;
219   }
220   return c_options;
221 }
222 
Compile(mlir::ModuleOp module,CompileOptions options)223 StatusOr<std::unique_ptr<PjRtLoadedExecutable>> PjRtCApiClient::Compile(
224     mlir::ModuleOp module, CompileOptions options) {
225   TF_RETURN_IF_ERROR(ValidateCompileOption(options));
226   PJRT_Client_Compile_Args args;
227   args.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE;
228   args.priv = nullptr;
229   args.client = c_client_.get();
230   std::string device_assignment_str;
231   TF_ASSIGN_OR_RETURN(PJRT_CompileOptions c_options,
232                       ConvertCppCompileOptionsToCCompileOptions(
233                           options, &device_assignment_str));
234   args.options = &c_options;
235   std::string module_str = tensorflow::SerializeMlirModule(module);
236   args.module = module_str.c_str();
237   args.module_size = module_str.size();
238 
239   RETURN_STATUS_IF_ERROR(c_api_->PJRT_Client_Compile(&args), c_api_);
240   std::unique_ptr<PjRtLoadedExecutable> ret =
241       std::make_unique<PjRtCApiExecutable>(this, args.executable);
242   return ret;
243 }
244 
SerializeExecutable(const PjRtLoadedExecutable & executable) const245 StatusOr<std::string> PjRtCApiClient::SerializeExecutable(
246     const PjRtLoadedExecutable& executable) const {
247 #ifdef PJRT_C_API_BYPASS
248   return wrapped_->SerializeExecutable(
249       *PjRtCApiExecutable::GetWrapped(&executable));
250 #endif  // PJRT_C_API_BYPASS
251   return Unimplemented("PJRT C API does not support SerializeExecutable");
252 }
253 
254 StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
DeserializeExecutable(absl::string_view serialized,CompileOptions options)255 PjRtCApiClient::DeserializeExecutable(absl::string_view serialized,
256                                       CompileOptions options) {
257 #ifdef PJRT_C_API_BYPASS
258   return WrapExecutable(wrapped_->DeserializeExecutable(serialized, options));
259 #endif  // PJRT_C_API_BYPASS
260   return Unimplemented("PJRT C API does not support DeserializeExecutable");
261 }
262 
UnsafeBufferPointer(PjRtBuffer * buffer)263 StatusOr<std::uintptr_t> PjRtCApiClient::UnsafeBufferPointer(
264     PjRtBuffer* buffer) {
265 #ifdef PJRT_C_API_BYPASS
266   return wrapped_->UnsafeBufferPointer(PjRtCApiBuffer::GetWrapped(buffer));
267 #endif  // PJRT_C_API_BYPASS
268   return Unimplemented("PJRT C API does not support UnsafeBufferPointer");
269 }
270 
WrapExecutable(StatusOr<std::unique_ptr<PjRtLoadedExecutable>> to_wrap)271 StatusOr<std::unique_ptr<PjRtLoadedExecutable>> PjRtCApiClient::WrapExecutable(
272     StatusOr<std::unique_ptr<PjRtLoadedExecutable>> to_wrap) {
273   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> executable,
274                       std::move(to_wrap));
275   return std::unique_ptr<PjRtLoadedExecutable>(
276       std::make_unique<PjRtCApiExecutable>(this, std::move(executable)));
277 }
278 
WrapBuffer(StatusOr<std::unique_ptr<PjRtBuffer>> to_wrap)279 StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCApiClient::WrapBuffer(
280     StatusOr<std::unique_ptr<PjRtBuffer>> to_wrap) {
281   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> buffer, std::move(to_wrap));
282   return std::unique_ptr<PjRtBuffer>(std::make_unique<PjRtCApiBuffer>(
283       this, new PJRT_Buffer{std::move(buffer), pjrt_c_client()}));
284 }
285 
pjrt_c_api() const286 const PJRT_Api* PjRtCApiClient::pjrt_c_api() const { return c_api_; }
287 
288 // --------------------------------- Devices -----------------------------------
289 
PjRtCApiDevice(PJRT_Device * device,PjRtCApiClient * client)290 PjRtCApiDevice::PjRtCApiDevice(PJRT_Device* device, PjRtCApiClient* client)
291     : client_(client), device_(device) {
292   wrapped_ = device_->device;
293   InitAttributes();
294 }
295 
client() const296 PjRtClient* PjRtCApiDevice::client() const { return client_; }
297 
id() const298 int PjRtCApiDevice::id() const {
299   PJRT_Device_Id_Args args;
300   args.struct_size = PJRT_Device_Id_Args_STRUCT_SIZE;
301   args.priv = nullptr;
302   args.device = device_;
303   const PJRT_Api* api = client_->pjrt_c_api();
304   pjrt::LogFatalIfPjrtError(api->PJRT_Device_Id(&args), api);
305   return args.id;
306 }
307 
process_index() const308 int PjRtCApiDevice::process_index() const {
309   PJRT_Device_ProcessIndex_Args args;
310   args.struct_size = PJRT_Device_ProcessIndex_Args_STRUCT_SIZE;
311   args.priv = nullptr;
312   args.device = device_;
313   const PJRT_Api* api = client_->pjrt_c_api();
314   pjrt::LogFatalIfPjrtError(api->PJRT_Device_ProcessIndex(&args), api);
315   return args.process_index;
316 }
317 
IsAddressable() const318 bool PjRtCApiDevice::IsAddressable() const {
319   PJRT_Device_IsAddressable_Args args;
320   args.struct_size = PJRT_Device_IsAddressable_Args_STRUCT_SIZE;
321   args.priv = nullptr;
322   args.device = device_;
323   const PJRT_Api* api = client_->pjrt_c_api();
324   pjrt::LogFatalIfPjrtError(api->PJRT_Device_IsAddressable(&args), api);
325   return args.is_addressable;
326 }
327 
InitAttributes()328 void PjRtCApiDevice::InitAttributes() {
329   attributes_ = {};
330   PJRT_Device_Attributes_Args args;
331   args.struct_size = PJRT_Device_Attributes_Args_STRUCT_SIZE;
332   args.priv = nullptr;
333   args.device = device_;
334   const PJRT_Api* api = client_->pjrt_c_api();
335   pjrt::LogFatalIfPjrtError(api->PJRT_Device_Attributes(&args), api);
336 
337   for (int i = 0; i < args.num_attributes; ++i) {
338     const auto& attribute = args.attributes[i];
339     std::string attribute_name(attribute.name, attribute.name_size);
340     switch (attribute.type) {
341       case PJRT_Device_Attribute::PJRT_Device_Attribute_kString: {
342         std::string string_value(attribute.string_value, attribute.value_size);
343         attributes_[attribute_name] = PjRtDeviceAttribute(string_value);
344         break;
345       }
346       case PJRT_Device_Attribute::PJRT_Device_Attribute_kInt64: {
347         attributes_[attribute_name] =
348             PjRtDeviceAttribute(attribute.int64_value);
349         break;
350       }
351       case PJRT_Device_Attribute::PJRT_Device_Attribute_kInt64List: {
352         const int64_t* array_ptr(attribute.int64_array_value);
353         std::vector<int64_t> int64_array(array_ptr,
354                                          array_ptr + attribute.value_size);
355         attributes_[attribute_name] = PjRtDeviceAttribute(int64_array);
356         break;
357       }
358     }
359   }
360 }
361 
362 const absl::flat_hash_map<std::string, PjRtDeviceAttribute>&
Attributes() const363 PjRtCApiDevice::Attributes() const {
364   return attributes_;
365 }
366 
device_kind() const367 absl::string_view PjRtCApiDevice::device_kind() const {
368   PJRT_Device_Kind_Args args;
369   args.struct_size = PJRT_Device_Kind_Args_STRUCT_SIZE;
370   args.priv = nullptr;
371   args.device = device_;
372 
373   const PJRT_Api* c_api = client_->pjrt_c_api();
374   pjrt::LogFatalIfPjrtError(c_api->PJRT_Device_Kind(&args), c_api);
375 
376   absl::string_view device_kind(args.device_kind, args.device_kind_size);
377   return device_kind;
378 }
379 
local_hardware_id() const380 int PjRtCApiDevice::local_hardware_id() const {
381   PJRT_Device_LocalHardwareId_Args args;
382   args.struct_size = PJRT_Device_LocalHardwareId_Args_STRUCT_SIZE;
383   args.priv = nullptr;
384   args.device = device_;
385   const PJRT_Api* api = client_->pjrt_c_api();
386   pjrt::LogFatalIfPjrtError(api->PJRT_Device_LocalHardwareId(&args), api);
387   return args.local_hardware_id;
388 }
389 
DebugString() const390 absl::string_view PjRtCApiDevice::DebugString() const {
391   PJRT_Device_DebugString_Args args;
392   args.struct_size = PJRT_Device_DebugString_Args_STRUCT_SIZE;
393   args.priv = nullptr;
394   args.device = device_;
395   const PJRT_Api* c_api = client_->pjrt_c_api();
396   pjrt::LogFatalIfPjrtError(c_api->PJRT_Device_DebugString(&args), c_api);
397   absl::string_view debug_string(args.debug_string, args.debug_string_size);
398   return debug_string;
399 }
400 
ToString() const401 absl::string_view PjRtCApiDevice::ToString() const {
402   PJRT_Device_ToString_Args args;
403   args.struct_size = PJRT_Device_ToString_Args_STRUCT_SIZE;
404   args.priv = nullptr;
405   args.device = device_;
406   const PJRT_Api* c_api = client_->pjrt_c_api();
407   pjrt::LogFatalIfPjrtError(c_api->PJRT_Device_ToString(&args), c_api);
408   absl::string_view to_string(args.to_string, args.to_string_size);
409   return to_string;
410 }
411 
412 // ------------------------------- Executables ---------------------------------
413 
PjRtCApiExecutable(PjRtCApiClient * client,std::unique_ptr<PjRtLoadedExecutable> wrapped)414 PjRtCApiExecutable::PjRtCApiExecutable(
415     PjRtCApiClient* client, std::unique_ptr<PjRtLoadedExecutable> wrapped)
416     : client_(client),
417       executable_(
418           new PJRT_Executable{std::move(wrapped), client->pjrt_c_client()}) {
419   InitDevices();
420 }
421 
PjRtCApiExecutable(PjRtCApiClient * client,PJRT_Executable * executable)422 PjRtCApiExecutable::PjRtCApiExecutable(PjRtCApiClient* client,
423                                        PJRT_Executable* executable)
424     : client_(client), executable_(executable) {
425   InitDevices();
426 }
427 
InitDevices()428 void PjRtCApiExecutable::InitDevices() {
429   PJRT_Executable_AddressableDevices_Args args;
430   args.struct_size = PJRT_Executable_AddressableDevices_Args_STRUCT_SIZE;
431   args.priv = nullptr;
432   args.executable = executable_;
433   args.addressable_devices = nullptr;
434   args.num_addressable_devices = 0;
435 
436   const PJRT_Api* api = pjrt_c_api();
437   pjrt::LogFatalIfPjrtError(api->PJRT_Executable_AddressableDevices(&args),
438                             api);
439 
440   const size_t num_addressable_devices = args.num_addressable_devices;
441   addressable_devices_.reserve(num_addressable_devices);
442 
443   for (size_t i = 0; i < num_addressable_devices; ++i) {
444     PJRT_Device* device = args.addressable_devices[i];
445     PjRtCApiDevice* c_api_device = client_->GetCppDevice(device);
446     addressable_devices_.push_back(c_api_device);
447   }
448 }
449 
~PjRtCApiExecutable()450 PjRtCApiExecutable::~PjRtCApiExecutable() {
451   PJRT_Executable_Destroy_Args args;
452   args.struct_size = PJRT_Executable_Destroy_Args_STRUCT_SIZE;
453   args.priv = nullptr;
454   args.executable = executable_;
455   const PJRT_Api* api = pjrt_c_api();
456   pjrt::LogFatalIfPjrtError(api->PJRT_Executable_Destroy(&args), api);
457 }
458 
Convert2DCppBuffersToCBuffers(absl::Span<const std::vector<PjRtBuffer * >> cpp_lists)459 static std::vector<std::vector<PJRT_Buffer*>> Convert2DCppBuffersToCBuffers(
460     absl::Span<const std::vector<PjRtBuffer*>> cpp_lists) {
461   std::vector<std::vector<PJRT_Buffer*>> c_lists;
462   c_lists.reserve(cpp_lists.size());
463   for (const auto& cpp_list : cpp_lists) {
464     auto& c_list = c_lists.emplace_back();
465     c_list.reserve(cpp_list.size());
466     for (PjRtBuffer* buffer : cpp_list) {
467       auto* c_api_argument = tensorflow::down_cast<PjRtCApiBuffer*>(buffer);
468       c_list.push_back(c_api_argument->c_buffer());
469     }
470   }
471   return c_lists;
472 }
473 
474 static std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>
Convert2DCBuffersToCppBuffers(PJRT_Buffer *** c_lists,size_t outer_size,int inner_size,xla::PjRtCApiClient * client)475 Convert2DCBuffersToCppBuffers(PJRT_Buffer*** c_lists, size_t outer_size,
476                               int inner_size, xla::PjRtCApiClient* client) {
477   std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> ret;
478   for (size_t i = 0; i < outer_size; ++i) {
479     auto& output_list = ret.emplace_back();
480     output_list.reserve(inner_size);
481     for (size_t j = 0; j < inner_size; ++j) {
482       output_list.push_back(
483           std::make_unique<PjRtCApiBuffer>(client, c_lists[i][j]));
484     }
485   }
486   return ret;
487 }
488 
489 StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer * >> argument_handles,const ExecuteOptions & options,std::optional<std::vector<PjRtFuture<Status>>> & returned_futures)490 PjRtCApiExecutable::Execute(
491     absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
492     const ExecuteOptions& options,
493     std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) {
494   PJRT_Executable_Execute_Args args;
495   args.struct_size = PJRT_Executable_Execute_Args_STRUCT_SIZE;
496   args.priv = nullptr;
497   args.executable = executable_;
498   PJRT_ExecuteOptions c_options;
499   args.options = &c_options;
500   args.options->struct_size = PJRT_ExecuteOptions_STRUCT_SIZE;
501   args.options->launch_id = options.launch_id;
502   args.num_devices = argument_handles.size();
503   CHECK_GT(args.num_devices, 0);
504   args.num_args = argument_handles[0].size();
505 
506   // Populates `args.argument_lists` from `argument_handles`.
507   std::vector<std::vector<PJRT_Buffer*>> c_argument_lists =
508       Convert2DCppBuffersToCBuffers(argument_handles);
509   std::vector<PJRT_Buffer**> c_arguments;
510   c_arguments.reserve(c_argument_lists.size());
511   for (auto& argument_list : c_argument_lists) {
512     c_arguments.push_back(argument_list.data());
513   }
514   args.argument_lists = c_arguments.data();
515 
516   // Allocates memory for output. `c_buffer_lists_holder` and `c_buffer_lists`
517   // needs to stay alive during the call of `PJRT_Executable_Execute`.
518   PJRT_Executable_NumOutputs_Args numoutputs_args;
519   numoutputs_args.struct_size = PJRT_Executable_NumOutputs_Args_STRUCT_SIZE;
520   numoutputs_args.priv = nullptr;
521   numoutputs_args.executable = executable_;
522   RETURN_STATUS_IF_ERROR(
523       pjrt_c_api()->PJRT_Executable_NumOutputs(&numoutputs_args), pjrt_c_api());
524   size_t outer_size = args.num_devices;
525   size_t inner_size = numoutputs_args.num_outputs;
526   std::vector<std::vector<PJRT_Buffer*>> c_buffer_lists_holder(outer_size);
527   auto c_buffer_lists = std::vector<PJRT_Buffer**>(outer_size);
528   for (int i = 0; i < outer_size; ++i) {
529     c_buffer_lists_holder[i].resize(inner_size);
530     c_buffer_lists[i] = c_buffer_lists_holder[i].data();
531   }
532   args.output_lists = c_buffer_lists.data();
533 
534   RETURN_STATUS_IF_ERROR(pjrt_c_api()->PJRT_Executable_Execute(&args),
535                          pjrt_c_api());
536 
537   return Convert2DCBuffersToCppBuffers(args.output_lists, args.num_devices,
538                                        numoutputs_args.num_outputs, client_);
539 }
540 
541 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer * const> argument_handles,PjRtDevice * device,const ExecuteOptions & options,std::optional<PjRtFuture<Status>> & returned_future,bool fill_future)542 PjRtCApiExecutable::ExecuteSharded(
543     absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
544     const ExecuteOptions& options,
545     std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) {
546 #ifdef PJRT_C_API_BYPASS
547   std::vector<PjRtBuffer*> wrapped_args =
548       PjRtCApiBuffer::GetWrappedVector(argument_handles);
549 
550   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> out,
551                       wrapped()->ExecuteSharded(
552                           wrapped_args, PjRtCApiDevice::GetWrapped(device),
553                           options, returned_future, fill_future));
554 
555   for (std::unique_ptr<PjRtBuffer>& buffer : out) {
556     buffer = std::make_unique<PjRtCApiBuffer>(
557         client_, new PJRT_Buffer{std::move(buffer), client_->pjrt_c_client()});
558   }
559   return out;
560 #endif  // PJRT_C_API_BYPASS
561   return Unimplemented("PJRT C API does not support ExecuteSharded");
562 }
563 
564 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer * const> argument_handles,PjRtDevice * device,const ExecuteOptions & options,std::optional<PjRtFuture<Status>> & returned_future,bool fill_future)565 PjRtCApiExecutable::ExecutePortable(
566     absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
567     const ExecuteOptions& options,
568     std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) {
569 #ifdef PJRT_C_API_BYPASS
570   std::vector<PjRtBuffer*> wrapped_args =
571       PjRtCApiBuffer::GetWrappedVector(argument_handles);
572 
573   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> out,
574                       wrapped()->ExecutePortable(
575                           wrapped_args, PjRtCApiDevice::GetWrapped(device),
576                           options, returned_future, fill_future));
577 
578   for (std::unique_ptr<PjRtBuffer>& buffer : out) {
579     buffer = std::make_unique<PjRtCApiBuffer>(
580         client_, new PJRT_Buffer{std::move(buffer), client_->pjrt_c_client()});
581   }
582   return out;
583 #endif  // PJRT_C_API_BYPASS
584   return Unimplemented("PJRT C API does not support ExecutePortable");
585 }
586 
wrapped() const587 PjRtLoadedExecutable* PjRtCApiExecutable::wrapped() const {
588   return executable_->executable.get();
589 }
590 
name() const591 absl::string_view PjRtCApiExecutable::name() const {
592   const PJRT_Api* c_api = pjrt_c_api();
593   PJRT_Executable_Name_Args args;
594   args.executable = executable_;
595   args.struct_size = PJRT_Executable_Name_Args_STRUCT_SIZE;
596   args.priv = nullptr;
597   pjrt::LogFatalIfPjrtError(c_api->PJRT_Executable_Name(&args), c_api);
598 
599   absl::string_view executable_name(args.executable_name,
600                                     args.executable_name_size);
601   return executable_name;
602 }
603 
Delete()604 void PjRtCApiExecutable::Delete() {
605   PJRT_Executable_Delete_Args args;
606   args.struct_size = PJRT_Executable_Delete_Args_STRUCT_SIZE;
607   args.priv = nullptr;
608   args.executable = executable_;
609   const PJRT_Api* c_api = pjrt_c_api();
610   pjrt::LogFatalIfPjrtError(c_api->PJRT_Executable_Delete(&args), c_api);
611 }
612 
IsDeleted()613 bool PjRtCApiExecutable::IsDeleted() {
614   PJRT_Executable_IsDeleted_Args args;
615   args.struct_size = PJRT_Executable_IsDeleted_Args_STRUCT_SIZE;
616   args.priv = nullptr;
617   args.executable = executable_;
618 
619   const PJRT_Api* c_api = pjrt_c_api();
620   pjrt::LogFatalIfPjrtError(c_api->PJRT_Executable_IsDeleted(&args), c_api);
621   return args.is_deleted;
622 }
623 
624 // ---------------------------------- Buffers ----------------------------------
625 
PjRtCApiBuffer(PjRtCApiClient * client,PJRT_Buffer * buffer)626 PjRtCApiBuffer::PjRtCApiBuffer(PjRtCApiClient* client, PJRT_Buffer* buffer)
627     : client_(client),
628       buffer_(buffer, ::pjrt::MakeBufferDeleter(client->pjrt_c_api())),
629       wrapped_(buffer_->buffer.get()) {
630   set_shape();
631 }
632 
on_device_shape() const633 const Shape& PjRtCApiBuffer::on_device_shape() const {
634   CHECK(shape_.has_value())
635       << "Shape should be initialized in PjRtCApiBuffer constructor.";
636   return shape_.value();
637 }
638 
set_shape()639 void PjRtCApiBuffer::set_shape() {
640   PJRT_Buffer_OnDeviceTrimmedShape_Args args;
641   args.struct_size = PJRT_Buffer_OnDeviceTrimmedShape_Args_STRUCT_SIZE;
642   args.priv = nullptr;
643   args.buffer = buffer_.get();
644 
645   pjrt::LogFatalIfPjrtError(
646       client_->pjrt_c_api()->PJRT_Buffer_OnDeviceTrimmedShape(&args),
647       client_->pjrt_c_api());
648 
649   xla::PrimitiveType element_type =
650       static_cast<xla::PrimitiveType>(args.element_type);
651 
652   CHECK_NE(element_type, xla::PrimitiveType::TUPLE);
653 
654   absl::Span<const int64_t> dims = ApiConverter::MakeSpan(args.dimensions);
655   absl::Span<const bool> dynamic_dims =
656       ApiConverter::MakeSpan(args.dynamic_dimensions);
657 
658   Shape trimmed_shape = Shape(element_type, dims, dynamic_dims, {});
659 
660   if (args.has_layout) {
661     *(trimmed_shape.mutable_layout()) = ApiConverter::FromC(&args.layout);
662   }
663 
664   shape_ = trimmed_shape;
665 
666   // TODO(amangu): Refactor the deletion.
667   if (args.dimensions.size > TPU_C_API_MAX_INLINED) {
668     delete[] args.dimensions.heap;
669   }
670 
671   if (args.dynamic_dimensions.size > TPU_C_API_MAX_INLINED) {
672     delete[] args.dynamic_dimensions.heap;
673   }
674 
675   if (args.has_layout) {
676     if (args.layout.minor_to_major.size > TPU_C_API_MAX_INLINED) {
677       delete[] args.layout.minor_to_major.heap;
678     }
679 
680     if (args.layout.tiles.size > TPU_C_API_MAX_INLINED) {
681       delete[] args.layout.tiles.heap;
682     }
683   }
684 }
685 
GetOnDeviceSizeInBytes() const686 StatusOr<size_t> PjRtCApiBuffer::GetOnDeviceSizeInBytes() const {
687   PJRT_Buffer_OnDeviceSizeInBytes_Args args;
688   args.struct_size = PJRT_Buffer_OnDeviceSizeInBytes_Args_STRUCT_SIZE;
689   args.priv = nullptr;
690   args.buffer = buffer_.get();
691   RETURN_STATUS_IF_ERROR(
692       client_->pjrt_c_api()->PJRT_Buffer_OnDeviceSizeInBytes(&args),
693       client_->pjrt_c_api());
694 
695   return args.on_device_size_in_bytes;
696 }
697 
device() const698 PjRtDevice* PjRtCApiBuffer::device() const {
699   PJRT_Buffer_Device_Args args;
700   args.struct_size = PJRT_Buffer_Device_Args_STRUCT_SIZE;
701   args.priv = nullptr;
702   args.buffer = buffer_.get();
703   const PJRT_Api* api = pjrt_c_api();
704   pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_Device(&args), api);
705   return client_->GetCppDevice(args.device);
706 }
707 
Delete()708 void PjRtCApiBuffer::Delete() {
709   PJRT_Buffer_Delete_Args args;
710   args.struct_size = PJRT_Buffer_Delete_Args_STRUCT_SIZE;
711   args.priv = nullptr;
712   args.buffer = buffer_.get();
713   const PJRT_Api* api = pjrt_c_api();
714   pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_Delete(&args), api);
715 }
716 
IsDeleted()717 bool PjRtCApiBuffer::IsDeleted() {
718   PJRT_Buffer_IsDeleted_Args args;
719   args.struct_size = PJRT_Buffer_IsDeleted_Args_STRUCT_SIZE;
720   args.priv = nullptr;
721   args.buffer = buffer_.get();
722   const PJRT_Api* api = pjrt_c_api();
723   pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_IsDeleted(&args), api);
724   return args.is_deleted;
725 }
726 
CopyToDevice(PjRtDevice * dst_device)727 StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCApiBuffer::CopyToDevice(
728     PjRtDevice* dst_device) {
729   if (dst_device->client() == client_) {
730     PJRT_Buffer_CopyToDevice_Args args;
731     args.struct_size = PJRT_Buffer_CopyToDevice_Args_STRUCT_SIZE;
732     args.priv = nullptr;
733     args.buffer = buffer_.get();
734     args.dst_device =
735         tensorflow::down_cast<PjRtCApiDevice*>(dst_device)->c_device();
736     const PJRT_Api* api = pjrt_c_api();
737     RETURN_STATUS_IF_ERROR(api->PJRT_Buffer_CopyToDevice(&args), api);
738     return std::unique_ptr<PjRtBuffer>(
739         std::make_unique<PjRtCApiBuffer>(client_, args.dst_buffer));
740   } else {
741     // TODO(b/239735405) Copying across different clients where `dst_device` is
742     // not a PjRtCApiDevice raises an error.
743     return wrapped_->CopyToDevice(dst_device);
744   }
745 }
746 
IsOnCpu() const747 bool PjRtCApiBuffer::IsOnCpu() const {
748   PJRT_Buffer_IsOnCpu_Args args;
749   args.struct_size = PJRT_Buffer_IsOnCpu_Args_STRUCT_SIZE;
750   args.priv = nullptr;
751   args.buffer = buffer_.get();
752   const PJRT_Api* api = pjrt_c_api();
753   pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_IsOnCpu(&args), api);
754   return args.is_on_cpu;
755 }
756 
757 // -------------------------------- API access ---------------------------------
758 
GetCApiClient()759 StatusOr<std::unique_ptr<PjRtClient>> GetCApiClient() {
760   const PJRT_Api* c_api = tensorflow::tpu::PjrtApi();
761   // TODO(skyewm): make status
762   CHECK(c_api != nullptr);
763 
764   PJRT_Client_Create_Args init_args;
765   init_args.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE;
766   init_args.priv = nullptr;
767   RETURN_STATUS_IF_ERROR(c_api->PJRT_Client_Create(&init_args), c_api);
768   PJRT_Client* c_client = init_args.client;
769 
770   return std::unique_ptr<PjRtClient>(
771       std::make_unique<PjRtCApiClient>(c_api, c_client));
772 }
773 
774 }  // namespace xla
775