xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_H_
18 
19 #include <stddef.h>
20 #include <stdint.h>
21 
22 // TODO(b/238999986): Remove this.
23 #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h"
24 
// Computes the size in bytes of `struct_type` measured through `last_field`:
// the offset of `last_field` plus the size of `last_field` itself. Trailing
// padding after `last_field` is not included.
//
// The whole expansion is parenthesized so that it acts as a single operand
// when used inside a larger expression (e.g. `2 * PJRT_STRUCT_SIZE(T, f)`).
#define PJRT_STRUCT_SIZE(struct_type, last_field) \
  (offsetof(struct_type, last_field) + sizeof(((struct_type*)0)->last_field))
27 
28 #ifdef __cplusplus
29 extern "C" {
30 #endif
31 
32 // ---------------------------------- Errors -----------------------------------
33 
34 // PJRT C API methods generally return a PJRT_Error*, which is nullptr if there
35 // is no error and set if there is. The implementation allocates any returned
36 // PJRT_Errors, but the caller is always responsible for freeing them via
37 // PJRT_Error_Destroy.
38 
// Opaque error handle. Only forward-declared in this header.
typedef struct PJRT_Error PJRT_Error;

// Arguments for `PJRT_Error_Destroy`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Error* error;  // The error to free; may be nullptr.
} PJRT_Error_Destroy_Args;
// Size of `PJRT_Error_Destroy_Args` measured through its last field, `error`.
const size_t PJRT_Error_Destroy_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Error_Destroy_Args, error);

// Frees `error`. `error` can be nullptr. Note that this function returns void
// rather than a `PJRT_Error*`.
typedef void PJRT_Error_Destroy(PJRT_Error_Destroy_Args* args);
51 
// Arguments for `PJRT_Error_Message`.
typedef struct {
  size_t struct_size;
  void* priv;
  const PJRT_Error* error;
  // Has the lifetime of `error`.
  // NOTE(review): NUL-termination of `message` is not stated here; presumably
  // callers should rely on `message_size` — confirm.
  const char* message;  // out
  size_t message_size;  // out
} PJRT_Error_Message_Args;
// Size of `PJRT_Error_Message_Args` measured through its last field,
// `message_size`.
const size_t PJRT_Error_Message_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Error_Message_Args, message_size);

// Gets the human-readable reason for `error`. `message` has the lifetime of
// `error`. Returns void, so the call itself reports no error.
typedef void PJRT_Error_Message(PJRT_Error_Message_Args* args);
66 
// Error categories reported through `PJRT_Error_GetCode_Args::code`.
// Codes are based on https://abseil.io/docs/cpp/guides/status-codes
// Values are assigned explicitly and start at 1; this enum defines no
// enumerator for 0.
typedef enum {
  PJRT_Error_Code_CANCELLED = 1,
  PJRT_Error_Code_UNKNOWN = 2,
  PJRT_Error_Code_INVALID_ARGUMENT = 3,
  PJRT_Error_Code_DEADLINE_EXCEEDED = 4,
  PJRT_Error_Code_NOT_FOUND = 5,
  PJRT_Error_Code_ALREADY_EXISTS = 6,
  PJRT_Error_Code_PERMISSION_DENIED = 7,
  PJRT_Error_Code_RESOURCE_EXHAUSTED = 8,
  PJRT_Error_Code_FAILED_PRECONDITION = 9,
  PJRT_Error_Code_ABORTED = 10,
  PJRT_Error_Code_OUT_OF_RANGE = 11,
  PJRT_Error_Code_UNIMPLEMENTED = 12,
  PJRT_Error_Code_INTERNAL = 13,
  PJRT_Error_Code_UNAVAILABLE = 14,
  PJRT_Error_Code_DATA_LOSS = 15,
  PJRT_Error_Code_UNAUTHENTICATED = 16
} PJRT_Error_Code;
86 
87 typedef struct {
88   size_t struct_size;
89   void* priv;
90   const PJRT_Error* error;
91   PJRT_Error_Code code;  // out
92 } PJRT_Error_GetCode_Args;
93 const size_t PJRT_Error_GetCode_Args_STRUCT_SIZE =
94     PJRT_STRUCT_SIZE(PJRT_Error_GetCode_Args, error);
95 
96 typedef PJRT_Error* PJRT_Error_GetCode(PJRT_Error_GetCode_Args* args);
97 
98 // ---------------------------------- Events -----------------------------------
99 
// Represents a notifying event that is returned by PJRT APIs that enqueue
// asynchronous work, informing callers when the work is complete and reporting
// a value of type `PJRT_Error*` or `nullptr` as error status.
//
// Callers are always responsible for freeing `PJRT_Event`s by calling
// `PJRT_Event_Destroy`.
//
// The type is opaque: it is only forward-declared in this header.
typedef struct PJRT_Event PJRT_Event;

// Arguments for `PJRT_Event_Destroy`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Event* event;  // The event to free; may be nullptr.
} PJRT_Event_Destroy_Args;
// Size of `PJRT_Event_Destroy_Args` measured through its last field, `event`.
const size_t PJRT_Event_Destroy_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Event_Destroy_Args, event);

// Frees `event`. `event` can be `nullptr`.
typedef PJRT_Error* PJRT_Event_Destroy(PJRT_Event_Destroy_Args* args);
118 
// Arguments for `PJRT_Event_IsReady`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Event* event;
  bool is_ready;  // out
} PJRT_Event_IsReady_Args;
// Size of `PJRT_Event_IsReady_Args` measured through its last field,
// `is_ready`.
const size_t PJRT_Event_IsReady_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Event_IsReady_Args, is_ready);

// Reports through the `is_ready` out field whether this PJRT_Event has
// completed, including if an error has occurred. The `PJRT_Error*` return
// value is separate from that result.
typedef PJRT_Error* PJRT_Event_IsReady(PJRT_Event_IsReady_Args* args);
131 
// Arguments for `PJRT_Event_Error`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Event* event;
} PJRT_Event_Error_Args;
// Size of `PJRT_Event_Error_Args` measured through its last field, `event`.
const size_t PJRT_Event_Error_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Event_Error_Args, event);

// Should only be called if PJRT_Event_IsReady returns true.
// Returns `nullptr` if there is no error.
// The returned error should be freed with `PJRT_Error_Destroy`.
//
// If `PJRT_Event_Await` has been called, this will return a pointer to an
// identical error status as that call, as will subsequent calls to
// `PJRT_Event_Error`. However, each of these `PJRT_Error*` pointers is
// independent of the `PJRT_Error*`s returned by other function calls, so each
// must be freed separately using `PJRT_Error_Destroy`.
typedef PJRT_Error* PJRT_Event_Error(PJRT_Event_Error_Args* args);
150 
// Arguments for `PJRT_Event_Await`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Event* event;  // The event to wait on.
} PJRT_Event_Await_Args;

// Size of `PJRT_Event_Await_Args` measured through its last field, `event`.
const size_t PJRT_Event_Await_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Event_Await_Args, event);

// Blocks the calling thread until `event` is ready, then returns the error
// status (with `nullptr` indicating no error). The returned status should be
// freed with `PJRT_Error_Destroy`.
typedef PJRT_Error* PJRT_Event_Await(PJRT_Event_Await_Args* args);
164 
// A callback to be performed once an event is ready. It will be called on the
// event's error state and a pointer to an object of the caller's choice.
// Ownership of `error` is passed to the callback. The callback must destroy
// `error` via `PJRT_Error_Destroy`. The caller retains ownership of `user_arg`.
//
// Note that this is a pointer-to-function typedef, unlike the plain function
// typedefs used for the API entry points in this header.
typedef void (*PJRT_Event_OnReadyCallback)(PJRT_Error* error, void* user_arg);
170 
// Arguments for `PJRT_Event_OnReady`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Event* event;
  // Invoked once `event` is ready.
  PJRT_Event_OnReadyCallback callback;
  // `user_arg` allows `callback` to be called with arbitrary arguments (e.g.
  // via pointers in a struct cast to void*).
  void* user_arg;
} PJRT_Event_OnReady_Args;
// Size of `PJRT_Event_OnReady_Args` measured through its last field,
// `user_arg`.
const size_t PJRT_Event_OnReady_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Event_OnReady_Args, user_arg);

// Registers `callback` to be called once `event` is ready, with `event`'s
// error status and a pointer to an object of the caller's choice as arguments.
typedef PJRT_Error* PJRT_Event_OnReady(PJRT_Event_OnReady_Args* args);
186 
187 // ---------------------------------- Client -----------------------------------
188 
// Opaque handle types. They are only forward-declared in this header and are
// always passed by pointer.
typedef struct PJRT_Client PJRT_Client;
typedef struct PJRT_Device PJRT_Device;
typedef struct PJRT_Executable PJRT_Executable;
192 
// Arguments for `PJRT_Client_Create`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Client* client;  // out
} PJRT_Client_Create_Args;
// Size of `PJRT_Client_Create_Args` measured through its last field, `client`.
const size_t PJRT_Client_Create_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Client_Create_Args, client);

// Creates and initializes a new PJRT_Client and returns it in `client`.
typedef PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args);
203 
// Arguments for `PJRT_Client_Destroy`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Client* client;  // The client to shut down; may be nullptr.
} PJRT_Client_Destroy_Args;
// Size of `PJRT_Client_Destroy_Args` measured through its last field,
// `client`.
const size_t PJRT_Client_Destroy_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Client_Destroy_Args, client);

// Shuts down and frees `client`. `client` can be nullptr.
typedef PJRT_Error* PJRT_Client_Destroy(PJRT_Client_Destroy_Args* args);
214 
// Arguments for `PJRT_Client_PlatformName`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Client* client;
  // `platform_name` has the same lifetime as `client`. It is owned by `client`.
  const char* platform_name;  // out
  size_t platform_name_size;  // out
} PJRT_Client_PlatformName_Args;

// Size of `PJRT_Client_PlatformName_Args` measured through its last field,
// `platform_name_size`.
const size_t PJRT_Client_PlatformName_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Client_PlatformName_Args, platform_name_size);

// Returns a string that identifies the platform (e.g. "cpu", "gpu", "tpu").
// The string is delivered through the `platform_name` out field.
typedef PJRT_Error* PJRT_Client_PlatformName(
    PJRT_Client_PlatformName_Args* args);
230 
// Arguments for `PJRT_Client_ProcessIndex`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Client* client;
  int process_index;  // out
} PJRT_Client_ProcessIndex_Args;
// Size of `PJRT_Client_ProcessIndex_Args` measured through its last field,
// `process_index`.
const size_t PJRT_Client_ProcessIndex_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Client_ProcessIndex_Args, process_index);

// Reports the process index of this client through `process_index`. Always 0
// in single-process settings.
typedef PJRT_Error* PJRT_Client_ProcessIndex(
    PJRT_Client_ProcessIndex_Args* args);
244 
// Arguments for `PJRT_Client_PlatformVersion`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Client* client;
  // `platform_version` has the same lifetime as `client`. It's owned by
  // `client`.
  const char* platform_version;  // out
  size_t platform_version_size;  // out
} PJRT_Client_PlatformVersion_Args;

// Size of `PJRT_Client_PlatformVersion_Args` measured through its last field,
// `platform_version_size`.
const size_t PJRT_Client_PlatformVersion_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Client_PlatformVersion_Args, platform_version_size);

// Returns a string containing human-readable, platform-specific version info
// (e.g. the CUDA version on GPU or libtpu version on Cloud TPU). The string is
// delivered through the `platform_version` out field.
typedef PJRT_Error* PJRT_Client_PlatformVersion(
    PJRT_Client_PlatformVersion_Args* args);
262 
// Arguments for `PJRT_Client_Devices`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Client* client;
  // NOTE(review): ownership and lifetime of the `devices` array are not stated
  // here; presumably it is owned by `client` — confirm.
  PJRT_Device** devices;  // out
  size_t num_devices;     // out
} PJRT_Client_Devices_Args;
// Size of `PJRT_Client_Devices_Args` measured through its last field,
// `num_devices`.
const size_t PJRT_Client_Devices_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Client_Devices_Args, num_devices);

// Returns a list of all devices visible to the runtime, including addressable
// and non-addressable devices.
typedef PJRT_Error* PJRT_Client_Devices(PJRT_Client_Devices_Args* args);
276 
// Arguments for `PJRT_Client_AddressableDevices`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Client* client;
  // NOTE(review): ownership and lifetime of the `addressable_devices` array
  // are not stated here; presumably it is owned by `client` — confirm.
  PJRT_Device** addressable_devices;  // out
  size_t num_addressable_devices;     // out
} PJRT_Client_AddressableDevices_Args;
// Size of `PJRT_Client_AddressableDevices_Args` measured through its last
// field, `num_addressable_devices`.
const size_t PJRT_Client_AddressableDevices_Args_STRUCT_SIZE = PJRT_STRUCT_SIZE(
    PJRT_Client_AddressableDevices_Args, num_addressable_devices);

// Returns a list of devices that are addressable from the client.
// Addressable devices are those that the client can issue commands to.
// All devices are addressable in a single-process environment.
typedef PJRT_Error* PJRT_Client_AddressableDevices(
    PJRT_Client_AddressableDevices_Args* args);
292 
// Arguments for `PJRT_Client_LookupDevice`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Client* client;
  // The device ID to look up, as returned by PJRT_Device_Id.
  int id;
  // `device` has the same lifetime as `client`. It is owned by `client`.
  PJRT_Device* device;  // out
} PJRT_Client_LookupDevice_Args;

// Size of `PJRT_Client_LookupDevice_Args` measured through its last field,
// `device`.
const size_t PJRT_Client_LookupDevice_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Client_LookupDevice_Args, device);

// Returns a PJRT_Device* with the specified ID as returned by PJRT_Device_Id.
typedef PJRT_Error* PJRT_Client_LookupDevice(
    PJRT_Client_LookupDevice_Args* args);
308 
// Compilation options passed to `PJRT_Client_Compile` through
// `PJRT_Client_Compile_Args::options`.
// TODO(jieying): add debug_option.
// TODO(b/240560013): consider putting some of option fields in priv.
typedef struct {
  size_t struct_size;
  void* priv;
  // If true, the supplied module expects its arguments to be wrapped in a
  // tuple and passed as a single parameter.
  bool parameter_is_tupled_arguments;
  // If set, this is the device to build the computation for. A value of -1
  // indicates this option has not been set.
  int device_ordinal;
  // The number of replicas of this computation that are to be executed.
  int num_replicas;
  // The number of partitions in this computation.
  int num_partitions;
  // Whether to use SPMD (true) or MPMD (false) partitioning when
  // num_partitions > 1 and XLA is requested to partition the input program.
  bool use_spmd_partitioning;
  // Whether to allow sharding propagation to propagate to the outputs.
  bool allow_spmd_sharding_propagation_to_output;
  const char* device_assignment;  // Serialized device assignment.
  // Presumably the length in bytes of `device_assignment` — TODO confirm.
  size_t device_assignment_size;
} PJRT_CompileOptions;
// Size of `PJRT_CompileOptions` measured through its last field,
// `device_assignment_size`.
const size_t PJRT_CompileOptions_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_CompileOptions, device_assignment_size);
334 
// Arguments for `PJRT_Client_Compile`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Client* client;
  // Serialized MLIR module. Only needs to stay alive for the duration of the
  // Compile call.
  const char* module;
  // Presumably the length in bytes of `module` — TODO confirm.
  size_t module_size;
  // Only needs to stay alive for the duration of the Compile call.
  PJRT_CompileOptions* options;
  // NOTE(review): presumably the caller releases this with
  // PJRT_Executable_Destroy — confirm.
  PJRT_Executable* executable;  // out
} PJRT_Client_Compile_Args;

// Size of `PJRT_Client_Compile_Args` measured through its last field,
// `executable`.
const size_t PJRT_Client_Compile_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Client_Compile_Args, executable);

// Compiles an MLIR module with given `options`. The result is delivered
// through the `executable` out field.
typedef PJRT_Error* PJRT_Client_Compile(PJRT_Client_Compile_Args* args);
353 
354 // --------------------------------- Devices -----------------------------------
355 
// Arguments for `PJRT_Device_Id`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Device* device;
  int id;  // out
} PJRT_Device_Id_Args;
// Size of `PJRT_Device_Id_Args` measured through its last field, `id`.
const size_t PJRT_Device_Id_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Device_Id_Args, id);

// The ID of this device, reported through `id`. IDs are unique among devices
// of this type (e.g. CPUs, GPUs). On multi-host platforms, this will be unique
// across all hosts' devices.
typedef PJRT_Error* PJRT_Device_Id(PJRT_Device_Id_Args* args);
369 
// Arguments for `PJRT_Device_LocalHardwareId`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Device* device;
  int local_hardware_id;  // out
} PJRT_Device_LocalHardwareId_Args;
// Size of `PJRT_Device_LocalHardwareId_Args` measured through its last field,
// `local_hardware_id`.
const size_t PJRT_Device_LocalHardwareId_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Device_LocalHardwareId_Args, local_hardware_id);

// Opaque hardware ID, e.g., the CUDA device number, reported through
// `local_hardware_id`. In general, not guaranteed to be dense, and -1 if
// undefined.
typedef PJRT_Error* PJRT_Device_LocalHardwareId(
    PJRT_Device_LocalHardwareId_Args* args);
383 
// Arguments for `PJRT_Device_ProcessIndex`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Device* device;
  int process_index;  // out
} PJRT_Device_ProcessIndex_Args;
// Size of `PJRT_Device_ProcessIndex_Args` measured through its last field,
// `process_index`.
const size_t PJRT_Device_ProcessIndex_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Device_ProcessIndex_Args, process_index);

// The index of the process that this device belongs to, i.e. is addressable
// from, reported through `process_index`. This is not always identical to
// PJRT_Client_ProcessIndex in a multi-process setting, where each client can
// see devices from all processes, but only a subset of them are addressable
// and have the same process_index as the client.
typedef PJRT_Error* PJRT_Device_ProcessIndex(
    PJRT_Device_ProcessIndex_Args* args);
400 
// Arguments for `PJRT_Device_IsAddressable`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Device* device;
  bool is_addressable;  // out
} PJRT_Device_IsAddressable_Args;
// Size of `PJRT_Device_IsAddressable_Args` measured through its last field,
// `is_addressable`.
const size_t PJRT_Device_IsAddressable_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Device_IsAddressable_Args, is_addressable);

// Reports through `is_addressable` whether the client can issue commands to
// this device.
typedef PJRT_Error* PJRT_Device_IsAddressable(
    PJRT_Device_IsAddressable_Args* args);
413 
// A single named device attribute: a name, a value and the value's type.
// Instances are returned in `PJRT_Device_Attributes_Args::attributes`.
typedef struct {
  size_t struct_size;
  void* priv;
  const char* name;
  // Presumably the length in bytes of `name` — TODO confirm.
  size_t name_size;
  // Kind of value stored in this attribute.
  // NOTE(review): presumably this selects which member of the anonymous union
  // below is meaningful — confirm.
  enum {
    PJRT_Device_Attribute_kString = 0,
    PJRT_Device_Attribute_kInt64,
    PJRT_Device_Attribute_kInt64List
  } type;
  // Anonymous union: its members are accessed directly on the attribute.
  union {
    int64_t int64_value;
    const int64_t* int64_array_value;
    const char* string_value;
  };
  // `value_size` is the number of elements for array/string and 1 for scalar
  // values.
  size_t value_size;
} PJRT_Device_Attribute;
// Size of `PJRT_Device_Attribute` measured through its last field,
// `value_size`.
const size_t PJRT_Device_Attribute_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Device_Attribute, value_size);
435 
// Arguments for `PJRT_Device_Attributes`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Device* device;
  size_t num_attributes;              // out
  // NOTE(review): ownership and lifetime of the `attributes` array are not
  // stated here — confirm.
  PJRT_Device_Attribute* attributes;  // out
} PJRT_Device_Attributes_Args;
// Size of `PJRT_Device_Attributes_Args` measured through its last field,
// `attributes`.
const size_t PJRT_Device_Attributes_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Device_Attributes_Args, attributes);

// Returns an array of device specific attributes with attribute name, value
// and value type.
typedef PJRT_Error* PJRT_Device_Attributes(PJRT_Device_Attributes_Args* args);
449 
// Arguments for `PJRT_Device_Kind`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Device* device;
  // `device_kind` string is owned by `device` and has same lifetime as
  // `device`.
  const char* device_kind;  // out
  size_t device_kind_size;  // out
} PJRT_Device_Kind_Args;
// Size of `PJRT_Device_Kind_Args` measured through its last field,
// `device_kind_size`.
const size_t PJRT_Device_Kind_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Device_Kind_Args, device_kind_size);

// A vendor-dependent string that uniquely identifies the kind of device,
// e.g., "Tesla V100-SXM2-16GB". Delivered through `device_kind`.
typedef PJRT_Error* PJRT_Device_Kind(PJRT_Device_Kind_Args* args);
465 
// Arguments for `PJRT_Device_DebugString`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Device* device;
  // NOTE(review): the lifetime of `debug_string` is not stated here;
  // presumably it matches `device` like `device_kind` does — confirm.
  const char* debug_string;  // out
  size_t debug_string_size;  // out
} PJRT_Device_DebugString_Args;
// Size of `PJRT_Device_DebugString_Args` measured through its last field,
// `debug_string_size`.
const size_t PJRT_Device_DebugString_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Device_DebugString_Args, debug_string_size);

// Debug string suitable for logging when errors occur. Should be verbose
// enough to describe the current device unambiguously.
typedef PJRT_Error* PJRT_Device_DebugString(PJRT_Device_DebugString_Args* args);
479 
// Arguments for `PJRT_Device_ToString`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Device* device;
  // NOTE(review): the lifetime of `to_string` is not stated here; presumably
  // it matches `device` like `device_kind` does — confirm.
  const char* to_string;  // out
  size_t to_string_size;  // out
} PJRT_Device_ToString_Args;
// Size of `PJRT_Device_ToString_Args` measured through its last field,
// `to_string_size`.
const size_t PJRT_Device_ToString_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Device_ToString_Args, to_string_size);

// Debug string suitable for reading by end users, should be reasonably terse,
// for example: "CpuDevice(id=0)".
typedef PJRT_Error* PJRT_Device_ToString(PJRT_Device_ToString_Args* args);
493 
494 // ------------------------------- Executables ---------------------------------
495 
// Opaque buffer handle. Forward-declared here because the executable API
// below (`PJRT_Executable_Execute_Args`) refers to it before the Buffers
// section.
typedef struct PJRT_Buffer PJRT_Buffer;

// Arguments for `PJRT_Executable_Destroy`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Executable* executable;  // The executable to free; may be nullptr.
} PJRT_Executable_Destroy_Args;
// Size of `PJRT_Executable_Destroy_Args` measured through its last field,
// `executable`.
const size_t PJRT_Executable_Destroy_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Executable_Destroy_Args, executable);

// Frees `executable` and deletes the underlying runtime object as if
// `PJRT_Executable_Delete` were called. `executable` can be nullptr.
typedef PJRT_Error* PJRT_Executable_Destroy(PJRT_Executable_Destroy_Args* args);
509 
// Arguments for `PJRT_Executable_Name`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Executable* executable;
  // `executable_name` has the same lifetime as `executable`. It is owned by
  // `executable`.
  const char* executable_name;  // out
  size_t executable_name_size;  // out
} PJRT_Executable_Name_Args;

// Size of `PJRT_Executable_Name_Args` measured through its last field,
// `executable_name_size`.
const size_t PJRT_Executable_Name_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Executable_Name_Args, executable_name_size);

// Returns a string that identifies the executable, delivered through
// `executable_name`.
typedef PJRT_Error* PJRT_Executable_Name(PJRT_Executable_Name_Args* args);
525 
// Arguments for `PJRT_Executable_AddressableDevices`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Executable* executable;
  // NOTE(review): ownership and lifetime of the `addressable_devices` array
  // are not stated here — confirm.
  PJRT_Device** addressable_devices;  // out
  size_t num_addressable_devices;     // out
} PJRT_Executable_AddressableDevices_Args;

// Size of `PJRT_Executable_AddressableDevices_Args` measured through its last
// field, `num_addressable_devices`.
const size_t PJRT_Executable_AddressableDevices_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Executable_AddressableDevices_Args,
                     num_addressable_devices);

// Returns a list of devices this executable will run on.
typedef PJRT_Error* PJRT_Executable_AddressableDevices(
    PJRT_Executable_AddressableDevices_Args* args);
541 
// Arguments for `PJRT_Executable_Delete`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Executable* executable;
} PJRT_Executable_Delete_Args;
// Size of `PJRT_Executable_Delete_Args` measured through its last field,
// `executable`.
const size_t PJRT_Executable_Delete_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Executable_Delete_Args, executable);

// Drops `executable`'s reference to the internal runtime object and
// associated resources, without freeing the `executable` object itself.
// `executable` can only be used with PJRT_Executable_IsDeleted and
// PJRT_Executable_Destroy after calling this method. The internal runtime
// executable will be freed after the last execution completes.
typedef PJRT_Error* PJRT_Executable_Delete(PJRT_Executable_Delete_Args* args);
556 
// Arguments for `PJRT_Executable_IsDeleted`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Executable* executable;
  bool is_deleted;  // out
} PJRT_Executable_IsDeleted_Args;
// Size of `PJRT_Executable_IsDeleted_Args` measured through its last field,
// `is_deleted`.
const size_t PJRT_Executable_IsDeleted_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Executable_IsDeleted_Args, is_deleted);

// Sets `is_deleted` to true if and only if PJRT_Executable_Delete has
// previously been called.
typedef PJRT_Error* PJRT_Executable_IsDeleted(
    PJRT_Executable_IsDeleted_Args* args);
569 
// Execution options passed to `PJRT_Executable_Execute` through
// `PJRT_Executable_Execute_Args::options`.
typedef struct {
  size_t struct_size;
  void* priv;
  // If non-zero, identifies this execution as part of a potentially
  // multi-device launch. This can be used to detect scheduling errors, e.g. if
  // multi-host programs are launched in different orders on different hosts,
  // the launch IDs may be used by the runtime to detect the mismatch.
  int launch_id;
} PJRT_ExecuteOptions;
// Size of `PJRT_ExecuteOptions` measured through its last field, `launch_id`.
const size_t PJRT_ExecuteOptions_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_ExecuteOptions, launch_id);
581 
// Arguments for `PJRT_Executable_Execute`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Executable* executable;
  // Only needs to stay alive for the duration of the Execute call.
  PJRT_ExecuteOptions* options;
  // Execution input of size [`num_devices`, `num_args`].
  PJRT_Buffer*** argument_lists;
  size_t num_devices;
  size_t num_args;
  // Execution output of size [`num_devices`, `num_outputs`], where
  // `num_outputs` is the number of outputs returned by this executable per
  // device (see `PJRT_Executable_NumOutputs`). Both the outer
  // (`PJRT_Buffer***`) and inner lists (`PJRT_Buffer**`) must be allocated and
  // deallocated by the caller. PJRT_Buffer_Destroy must be called on the
  // output PJRT_Buffer*.
  PJRT_Buffer*** output_lists;  // in/out
} PJRT_Executable_Execute_Args;
// Size of `PJRT_Executable_Execute_Args` measured through its last field,
// `output_lists`.
const size_t PJRT_Executable_Execute_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Executable_Execute_Args, output_lists);

// Executes on devices addressable by the client.
typedef PJRT_Error* PJRT_Executable_Execute(PJRT_Executable_Execute_Args* args);
604 
// Arguments for `PJRT_Executable_NumOutputs`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Executable* executable;
  size_t num_outputs;  // out
} PJRT_Executable_NumOutputs_Args;
// Size of `PJRT_Executable_NumOutputs_Args` measured through its last field,
// `num_outputs`.
const size_t PJRT_Executable_NumOutputs_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Executable_NumOutputs_Args, num_outputs);

// Gets the number of outputs per device produced by `executable`, reported
// through `num_outputs`.
typedef PJRT_Error* PJRT_Executable_NumOutputs(
    PJRT_Executable_NumOutputs_Args* args);
617 
618 // ---------------------------------- Buffers ----------------------------------
619 
// Arguments for `PJRT_Buffer_Destroy`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Buffer* buffer;  // The buffer to free; may be nullptr.
} PJRT_Buffer_Destroy_Args;
// Size of `PJRT_Buffer_Destroy_Args` measured through its last field,
// `buffer`.
const size_t PJRT_Buffer_Destroy_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Buffer_Destroy_Args, buffer);

// Deletes the underlying runtime objects as if `PJRT_Buffer_Delete` were
// called and frees `buffer`. `buffer` can be nullptr.
typedef PJRT_Error* PJRT_Buffer_Destroy(PJRT_Buffer_Destroy_Args* args);
631 
// Arguments for `PJRT_Buffer_OnDeviceTrimmedShape`.
// This trimmed shape doesn't have any Tuple information. In case of Tuple,
// assert is triggered from the C API Client.
// `Int64List`, `BoolList` and `XLA_Layout` are not declared in this header;
// presumably they come from the included c_api_decl.h — confirm.
// TODO(b/238999986): This is a temporary solution. Remove this later.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Buffer* buffer;
  int element_type;             // out
  Int64List dimensions;         // out
  BoolList dynamic_dimensions;  // out
  // NOTE(review): not marked `// out` unlike its neighbours; presumably also
  // an output that tells whether `layout` is populated — confirm.
  bool has_layout;
  XLA_Layout layout;            // out
} PJRT_Buffer_OnDeviceTrimmedShape_Args;
// Size of `PJRT_Buffer_OnDeviceTrimmedShape_Args` measured through its last
// field, `layout`.
const size_t PJRT_Buffer_OnDeviceTrimmedShape_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Buffer_OnDeviceTrimmedShape_Args, layout);

// Return the trimmed shape from PjRtBuffer.
// TODO(b/238999986): Replace this with decomposed shape methods.
typedef PJRT_Error* PJRT_Buffer_OnDeviceTrimmedShape(
    PJRT_Buffer_OnDeviceTrimmedShape_Args* args);
652 
// Arguments for `PJRT_Buffer_OnDeviceSizeInBytes`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Buffer* buffer;
  size_t on_device_size_in_bytes;  // out
} PJRT_Buffer_OnDeviceSizeInBytes_Args;
// Size of `PJRT_Buffer_OnDeviceSizeInBytes_Args` measured through its last
// field, `on_device_size_in_bytes`.
const size_t PJRT_Buffer_OnDeviceSizeInBytes_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Buffer_OnDeviceSizeInBytes_Args,
                     on_device_size_in_bytes);

// Gets the number of bytes of the buffer storage on the device, reported
// through `on_device_size_in_bytes`.
typedef PJRT_Error* PJRT_Buffer_OnDeviceSizeInBytes(
    PJRT_Buffer_OnDeviceSizeInBytes_Args* args);
666 
// Arguments for `PJRT_Buffer_Delete`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Buffer* buffer;
} PJRT_Buffer_Delete_Args;
// Size of `PJRT_Buffer_Delete_Args` measured through its last field, `buffer`.
const size_t PJRT_Buffer_Delete_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Buffer_Delete_Args, buffer);

// Drop the buffer's reference to its associated device memory, without freeing
// the `buffer` object itself. `buffer` can only be used with
// PJRT_Buffer_IsDeleted and PJRT_Buffer_Destroy after calling this method. The
// device memory will be freed when all async operations using the buffer have
// completed, according to the allocation semantics of the underlying platform.
typedef PJRT_Error* PJRT_Buffer_Delete(PJRT_Buffer_Delete_Args* args);
681 
// Arguments for `PJRT_Buffer_IsDeleted`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Buffer* buffer;
  bool is_deleted;  // out
} PJRT_Buffer_IsDeleted_Args;
// Size of `PJRT_Buffer_IsDeleted_Args` measured through its last field,
// `is_deleted`.
const size_t PJRT_Buffer_IsDeleted_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Buffer_IsDeleted_Args, is_deleted);

// Sets `is_deleted` to true if and only if PJRT_Buffer_Delete has previously
// been called.
typedef PJRT_Error* PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args* args);
693 
// Arguments for `PJRT_Buffer_CopyToDevice`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Buffer* buffer;      // The source buffer.
  PJRT_Device* dst_device;  // The device to copy to.
  PJRT_Buffer* dst_buffer;  // out
} PJRT_Buffer_CopyToDevice_Args;
// Size of `PJRT_Buffer_CopyToDevice_Args` measured through its last field,
// `dst_buffer`.
const size_t PJRT_Buffer_CopyToDevice_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Buffer_CopyToDevice_Args, dst_buffer);

// Copies the buffer to device `dst_device`. Caller is responsible for freeing
// returned `dst_buffer` with PJRT_Buffer_Destroy. Returns an error if the
// buffer is already on `dst_device`.
typedef PJRT_Error* PJRT_Buffer_CopyToDevice(
    PJRT_Buffer_CopyToDevice_Args* args);
709 
// Arguments for `PJRT_Buffer_IsOnCpu`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Buffer* buffer;
  bool is_on_cpu;  // out
} PJRT_Buffer_IsOnCpu_Args;
// Size of `PJRT_Buffer_IsOnCpu_Args` measured through its last field,
// `is_on_cpu`.
const size_t PJRT_Buffer_IsOnCpu_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Buffer_IsOnCpu_Args, is_on_cpu);

// Reports through `is_on_cpu` whether this buffer is on CPU and thus allows
// for certain optimizations.
typedef PJRT_Error* PJRT_Buffer_IsOnCpu(PJRT_Buffer_IsOnCpu_Args* args);
721 
// Arguments for `PJRT_Buffer_Device`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Buffer* buffer;
  // NOTE(review): ownership of the returned `device` is not stated here;
  // presumably it is not owned by the caller — confirm.
  PJRT_Device* device;  // out
} PJRT_Buffer_Device_Args;
// Size of `PJRT_Buffer_Device_Args` measured through its last field, `device`.
const size_t PJRT_Buffer_Device_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Buffer_Device_Args, device);

// Returns this buffer's storage device through the `device` out field.
typedef PJRT_Error* PJRT_Buffer_Device(PJRT_Buffer_Device_Args* args);
733 
// Arguments for `PJRT_Buffer_ReadyEvent`.
typedef struct {
  size_t struct_size;
  void* priv;
  PJRT_Buffer* buffer;
  // The caller is responsible for calling PJRT_Event_Destroy on `event`.
  PJRT_Event* event;  // out
} PJRT_Buffer_ReadyEvent_Args;
// Size of `PJRT_Buffer_ReadyEvent_Args` measured through its last field,
// `event`.
const size_t PJRT_Buffer_ReadyEvent_Args_STRUCT_SIZE =
    PJRT_STRUCT_SIZE(PJRT_Buffer_ReadyEvent_Args, event);

// Returns an event that is triggered when either of the following happens:
// * the data in the PJRT_Buffer becomes ready, or
// * an error has occurred.
//
// TODO(b/241967811): change these weird semantics
// If the buffer has been deleted or donated, the returned event will
// immediately indicate an error. However, if PJRT_Buffer_ReadyEvent() is
// called on the buffer before PJRT_Buffer_Delete() is, the returned event will
// not transition to an error state after PJRT_Buffer_Delete() is called.
typedef PJRT_Error* PJRT_Buffer_ReadyEvent(PJRT_Buffer_ReadyEvent_Args* args);
754 
755 // -------------------------------- API access ---------------------------------
756 
757 #define _PJRT_API_STRUCT_FIELD(fn_type) fn_type* fn_type
758 
759 // Please modify PJRT_Api_STRUCT_SIZE if the last field of PJRT_Api is changed.
760 typedef struct {
761   size_t struct_size;
762   void* priv;
763 
764   _PJRT_API_STRUCT_FIELD(PJRT_Error_Destroy);
765   _PJRT_API_STRUCT_FIELD(PJRT_Error_Message);
766   _PJRT_API_STRUCT_FIELD(PJRT_Error_GetCode);
767 
768   _PJRT_API_STRUCT_FIELD(PJRT_Event_Destroy);
769   _PJRT_API_STRUCT_FIELD(PJRT_Event_IsReady);
770   _PJRT_API_STRUCT_FIELD(PJRT_Event_Error);
771   _PJRT_API_STRUCT_FIELD(PJRT_Event_Await);
772   _PJRT_API_STRUCT_FIELD(PJRT_Event_OnReady);
773 
774   _PJRT_API_STRUCT_FIELD(PJRT_Client_Create);
775   _PJRT_API_STRUCT_FIELD(PJRT_Client_Destroy);
776   _PJRT_API_STRUCT_FIELD(PJRT_Client_PlatformName);
777   _PJRT_API_STRUCT_FIELD(PJRT_Client_ProcessIndex);
778   _PJRT_API_STRUCT_FIELD(PJRT_Client_PlatformVersion);
779   _PJRT_API_STRUCT_FIELD(PJRT_Client_Devices);
780   _PJRT_API_STRUCT_FIELD(PJRT_Client_AddressableDevices);
781   _PJRT_API_STRUCT_FIELD(PJRT_Client_LookupDevice);
782   _PJRT_API_STRUCT_FIELD(PJRT_Client_Compile);
783 
784   _PJRT_API_STRUCT_FIELD(PJRT_Device_Id);
785   _PJRT_API_STRUCT_FIELD(PJRT_Device_ProcessIndex);
786   _PJRT_API_STRUCT_FIELD(PJRT_Device_IsAddressable);
787   _PJRT_API_STRUCT_FIELD(PJRT_Device_Attributes);
788   _PJRT_API_STRUCT_FIELD(PJRT_Device_Kind);
789   _PJRT_API_STRUCT_FIELD(PJRT_Device_LocalHardwareId);
790   _PJRT_API_STRUCT_FIELD(PJRT_Device_DebugString);
791   _PJRT_API_STRUCT_FIELD(PJRT_Device_ToString);
792 
793   _PJRT_API_STRUCT_FIELD(PJRT_Executable_Destroy);
794   _PJRT_API_STRUCT_FIELD(PJRT_Executable_Name);
795   _PJRT_API_STRUCT_FIELD(PJRT_Executable_AddressableDevices);
796   _PJRT_API_STRUCT_FIELD(PJRT_Executable_NumOutputs);
797   _PJRT_API_STRUCT_FIELD(PJRT_Executable_Delete);
798   _PJRT_API_STRUCT_FIELD(PJRT_Executable_IsDeleted);
799   _PJRT_API_STRUCT_FIELD(PJRT_Executable_Execute);
800 
801   _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Destroy);
802   _PJRT_API_STRUCT_FIELD(PJRT_Buffer_OnDeviceTrimmedShape);
803   _PJRT_API_STRUCT_FIELD(PJRT_Buffer_OnDeviceSizeInBytes);
804   _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Device);
805   _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Delete);
806   _PJRT_API_STRUCT_FIELD(PJRT_Buffer_IsDeleted);
807   _PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyToDevice);
808   _PJRT_API_STRUCT_FIELD(PJRT_Buffer_IsOnCpu);
809   _PJRT_API_STRUCT_FIELD(PJRT_Buffer_ReadyEvent);
810 } PJRT_Api;
811 
812 const size_t PJRT_Api_STRUCT_SIZE =
813     PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Buffer_ReadyEvent);
814 
815 #undef _PJRT_API_STRUCT_FIELD
816 
817 #ifdef __cplusplus
818 }
819 #endif
820 
821 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_H_
822