1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_H_ 17 #define TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_H_ 18 19 #include <stddef.h> 20 #include <stdint.h> 21 22 // TODO(b/238999986): Remove this. 23 #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h" 24 25 #define PJRT_STRUCT_SIZE(struct_type, last_field) \ 26 offsetof(struct_type, last_field) + sizeof(((struct_type*)0)->last_field) 27 28 #ifdef __cplusplus 29 extern "C" { 30 #endif 31 32 // ---------------------------------- Errors ----------------------------------- 33 34 // PJRT C API methods generally return a PJRT_Error*, which is nullptr if there 35 // is no error and set if there is. The implementation allocates any returned 36 // PJRT_Errors, but the caller is always responsible for freeing them via 37 // PJRT_Error_Destroy. 38 39 typedef struct PJRT_Error PJRT_Error; 40 41 typedef struct { 42 size_t struct_size; 43 void* priv; 44 PJRT_Error* error; 45 } PJRT_Error_Destroy_Args; 46 const size_t PJRT_Error_Destroy_Args_STRUCT_SIZE = 47 PJRT_STRUCT_SIZE(PJRT_Error_Destroy_Args, error); 48 49 // Frees `error`. `error` can be nullptr. 50 typedef void PJRT_Error_Destroy(PJRT_Error_Destroy_Args* args); 51 52 typedef struct { 53 size_t struct_size; 54 void* priv; 55 const PJRT_Error* error; 56 // Has the lifetime of `error`. 57 const char* message; // out 58 size_t message_size; // out 59 } PJRT_Error_Message_Args; 60 const size_t PJRT_Error_Message_Args_STRUCT_SIZE = 61 PJRT_STRUCT_SIZE(PJRT_Error_Message_Args, message_size); 62 63 // Gets the human-readable reason for `error`. `message` has the lifetime of 64 // `error`. 65 typedef void PJRT_Error_Message(PJRT_Error_Message_Args* args); 66 67 // Codes are based on https://abseil.io/docs/cpp/guides/status-codes 68 typedef enum { 69 PJRT_Error_Code_CANCELLED = 1, 70 PJRT_Error_Code_UNKNOWN = 2, 71 PJRT_Error_Code_INVALID_ARGUMENT = 3, 72 PJRT_Error_Code_DEADLINE_EXCEEDED = 4, 73 PJRT_Error_Code_NOT_FOUND = 5, 74 PJRT_Error_Code_ALREADY_EXISTS = 6, 75 PJRT_Error_Code_PERMISSION_DENIED = 7, 76 PJRT_Error_Code_RESOURCE_EXHAUSTED = 8, 77 PJRT_Error_Code_FAILED_PRECONDITION = 9, 78 PJRT_Error_Code_ABORTED = 10, 79 PJRT_Error_Code_OUT_OF_RANGE = 11, 80 PJRT_Error_Code_UNIMPLEMENTED = 12, 81 PJRT_Error_Code_INTERNAL = 13, 82 PJRT_Error_Code_UNAVAILABLE = 14, 83 PJRT_Error_Code_DATA_LOSS = 15, 84 PJRT_Error_Code_UNAUTHENTICATED = 16 85 } PJRT_Error_Code; 86 87 typedef struct { 88 size_t struct_size; 89 void* priv; 90 const PJRT_Error* error; 91 PJRT_Error_Code code; // out 92 } PJRT_Error_GetCode_Args; 93 const size_t PJRT_Error_GetCode_Args_STRUCT_SIZE = 94 PJRT_STRUCT_SIZE(PJRT_Error_GetCode_Args, error); 95 96 typedef PJRT_Error* PJRT_Error_GetCode(PJRT_Error_GetCode_Args* args); 97 98 // ---------------------------------- Events ----------------------------------- 99 100 // Represents a notifying event that is returned by PJRT APIs that enqueue 101 // asynchronous work, informing callers when the work is complete and reporting 102 // a value of type `PJRT_Error*` or `nullptr` as error status. 103 // 104 // Callers are always responsible for freeing `PJRT_Event`s by calling 105 // `PJRT_Event_Destroy`. 106 typedef struct PJRT_Event PJRT_Event; 107 108 typedef struct { 109 size_t struct_size; 110 void* priv; 111 PJRT_Event* event; 112 } PJRT_Event_Destroy_Args; 113 const size_t PJRT_Event_Destroy_Args_STRUCT_SIZE = 114 PJRT_STRUCT_SIZE(PJRT_Event_Destroy_Args, event); 115 116 // Frees `event`. `event` can be `nullptr`. 117 typedef PJRT_Error* PJRT_Event_Destroy(PJRT_Event_Destroy_Args* args); 118 119 typedef struct { 120 size_t struct_size; 121 void* priv; 122 PJRT_Event* event; 123 bool is_ready; // out 124 } PJRT_Event_IsReady_Args; 125 const size_t PJRT_Event_IsReady_Args_STRUCT_SIZE = 126 PJRT_STRUCT_SIZE(PJRT_Event_IsReady_Args, is_ready); 127 128 // Returns true if this PJRT_Event has completed, including if an error has 129 // occurred. 130 typedef PJRT_Error* PJRT_Event_IsReady(PJRT_Event_IsReady_Args* args); 131 132 typedef struct { 133 size_t struct_size; 134 void* priv; 135 PJRT_Event* event; 136 } PJRT_Event_Error_Args; 137 const size_t PJRT_Event_Error_Args_STRUCT_SIZE = 138 PJRT_STRUCT_SIZE(PJRT_Event_Error_Args, event); 139 140 // Should only be called if PJRT_Event_IsReady returns true. 141 // Returns `nullptr` if there is no error. 142 // The returned error should be freed with `PJRT_Error_Destroy`. 143 // 144 // If `PJRT_Event_Await` has been called, this will return a pointer to an 145 // identical error status as that call, as will subsequent calls to 146 // `PJRT_Event_Error`. However, each of these `PJRT_Error *` pointers are 147 // independent of `PJRT_Error *`s returned by other function calls, so they must 148 // each be freed separately using `PJRT_Error_Destroy`. 149 typedef PJRT_Error* PJRT_Event_Error(PJRT_Event_Error_Args* args); 150 151 typedef struct { 152 size_t struct_size; 153 void* priv; 154 PJRT_Event* event; 155 } PJRT_Event_Await_Args; 156 157 const size_t PJRT_Event_Await_Args_STRUCT_SIZE = 158 PJRT_STRUCT_SIZE(PJRT_Event_Await_Args, event); 159 160 // Blocks the calling thread until `event` is ready, then returns the error 161 // status (with `nullptr` indicating no error). The returned status should be 162 // freed with `PJRT_Error_Destroy`. 163 typedef PJRT_Error* PJRT_Event_Await(PJRT_Event_Await_Args* args); 164 165 // A callback to be performed once an event is ready. It will be called on the 166 // event's error state and a pointer to an object of the caller's choice. 167 // Ownership of `error` is passed to the callback. The callback must destroy 168 // `error` via `PJRT_Error_Destroy`. The caller retains ownership of `user_arg`. 169 typedef void (*PJRT_Event_OnReadyCallback)(PJRT_Error* error, void* user_arg); 170 171 typedef struct { 172 size_t struct_size; 173 void* priv; 174 PJRT_Event* event; 175 PJRT_Event_OnReadyCallback callback; 176 // `user_arg` allows `callback` to be called with arbitrary arguments (e.g. 177 // via pointers in a struct cast to void*). 178 void* user_arg; 179 } PJRT_Event_OnReady_Args; 180 const size_t PJRT_Event_OnReady_Args_STRUCT_SIZE = 181 PJRT_STRUCT_SIZE(PJRT_Event_OnReady_Args, user_arg); 182 183 // Registers `callback` to be called once `event` is ready, with `event`'s 184 // error status and a pointer to an object of the caller's choice as arguments. 185 typedef PJRT_Error* PJRT_Event_OnReady(PJRT_Event_OnReady_Args* args); 186 187 // ---------------------------------- Client ----------------------------------- 188 189 typedef struct PJRT_Client PJRT_Client; 190 typedef struct PJRT_Device PJRT_Device; 191 typedef struct PJRT_Executable PJRT_Executable; 192 193 typedef struct { 194 size_t struct_size; 195 void* priv; 196 PJRT_Client* client; // out 197 } PJRT_Client_Create_Args; 198 const size_t PJRT_Client_Create_Args_STRUCT_SIZE = 199 PJRT_STRUCT_SIZE(PJRT_Client_Create_Args, client); 200 201 // Creates and initializes a new PJRT_Client and returns in `client`. 202 typedef PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args); 203 204 typedef struct { 205 size_t struct_size; 206 void* priv; 207 PJRT_Client* client; 208 } PJRT_Client_Destroy_Args; 209 const size_t PJRT_Client_Destroy_Args_STRUCT_SIZE = 210 PJRT_STRUCT_SIZE(PJRT_Client_Destroy_Args, client); 211 212 // Shuts down and frees `client`. `client` can be nullptr. 213 typedef PJRT_Error* PJRT_Client_Destroy(PJRT_Client_Destroy_Args* args); 214 215 typedef struct { 216 size_t struct_size; 217 void* priv; 218 PJRT_Client* client; 219 // `platform_name` has the same lifetime as `client`. It is owned by `client`. 220 const char* platform_name; // out 221 size_t platform_name_size; // out 222 } PJRT_Client_PlatformName_Args; 223 224 const size_t PJRT_Client_PlatformName_Args_STRUCT_SIZE = 225 PJRT_STRUCT_SIZE(PJRT_Client_PlatformName_Args, platform_name_size); 226 227 // Returns a string that identifies the platform (e.g. "cpu", "gpu", "tpu"). 228 typedef PJRT_Error* PJRT_Client_PlatformName( 229 PJRT_Client_PlatformName_Args* args); 230 231 typedef struct { 232 size_t struct_size; 233 void* priv; 234 PJRT_Client* client; 235 int process_index; // out 236 } PJRT_Client_ProcessIndex_Args; 237 const size_t PJRT_Client_ProcessIndex_Args_STRUCT_SIZE = 238 PJRT_STRUCT_SIZE(PJRT_Client_ProcessIndex_Args, process_index); 239 240 // Return the process index of this client. Always 0 in single-process 241 // settings. 242 typedef PJRT_Error* PJRT_Client_ProcessIndex( 243 PJRT_Client_ProcessIndex_Args* args); 244 245 typedef struct { 246 size_t struct_size; 247 void* priv; 248 PJRT_Client* client; 249 // `platform_version` has the same lifetime as `client`. It's owned by 250 // `client`. 251 const char* platform_version; // out 252 size_t platform_version_size; // out 253 } PJRT_Client_PlatformVersion_Args; 254 255 const size_t PJRT_Client_PlatformVersion_Args_STRUCT_SIZE = 256 PJRT_STRUCT_SIZE(PJRT_Client_PlatformVersion_Args, platform_version_size); 257 258 // Returns a string containing human-readable, platform-specific version info 259 // (e.g. the CUDA version on GPU or libtpu version on Cloud TPU). 260 typedef PJRT_Error* PJRT_Client_PlatformVersion( 261 PJRT_Client_PlatformVersion_Args* args); 262 263 typedef struct { 264 size_t struct_size; 265 void* priv; 266 PJRT_Client* client; 267 PJRT_Device** devices; // out 268 size_t num_devices; // out 269 } PJRT_Client_Devices_Args; 270 const size_t PJRT_Client_Devices_Args_STRUCT_SIZE = 271 PJRT_STRUCT_SIZE(PJRT_Client_Devices_Args, num_devices); 272 273 // Returns a list of all devices visible to the runtime, including addressable 274 // and non-addressable devices. 275 typedef PJRT_Error* PJRT_Client_Devices(PJRT_Client_Devices_Args* args); 276 277 typedef struct { 278 size_t struct_size; 279 void* priv; 280 PJRT_Client* client; 281 PJRT_Device** addressable_devices; // out 282 size_t num_addressable_devices; // out 283 } PJRT_Client_AddressableDevices_Args; 284 const size_t PJRT_Client_AddressableDevices_Args_STRUCT_SIZE = PJRT_STRUCT_SIZE( 285 PJRT_Client_AddressableDevices_Args, num_addressable_devices); 286 287 // Returns a list of devices that are addressable from the client. 288 // Addressable devices are those that the client can issue commands to. 289 // All devices are addressable in a single-process environment. 290 typedef PJRT_Error* PJRT_Client_AddressableDevices( 291 PJRT_Client_AddressableDevices_Args* args); 292 293 typedef struct { 294 size_t struct_size; 295 void* priv; 296 PJRT_Client* client; 297 int id; 298 // `device` has the same lifetime as `client`. It is owned by `client`. 299 PJRT_Device* device; // out 300 } PJRT_Client_LookupDevice_Args; 301 302 const size_t PJRT_Client_LookupDevice_Args_STRUCT_SIZE = 303 PJRT_STRUCT_SIZE(PJRT_Client_LookupDevice_Args, device); 304 305 // Returns a PJRT_Device* with the specified ID as returned by PJRT_Device_Id. 306 typedef PJRT_Error* PJRT_Client_LookupDevice( 307 PJRT_Client_LookupDevice_Args* args); 308 309 // TODO(jieying): add debug_option. 310 // TODO(b/240560013): consider putting some of option fields in priv. 311 typedef struct { 312 size_t struct_size; 313 void* priv; 314 // If true, the supplied module expects its arguments to be wrapped in a 315 // tuple and passed as a single parameter. 316 bool parameter_is_tupled_arguments; 317 // If set, this is the device to build the computation for. A value of -1 318 // indicates this option has not been set. 319 int device_ordinal; 320 // The number of replicas of this computation that are to be executed. 321 int num_replicas; 322 // The number of partitions in this computation. 323 int num_partitions; 324 // Whether to use SPMD (true) or MPMD (false) partitioning when 325 // num_partitions > 1 and XLA is requested to partition the input program. 326 bool use_spmd_partitioning; 327 // Whether to allow sharding propagation to propagate to the outputs. 328 bool allow_spmd_sharding_propagation_to_output; 329 const char* device_assignment; // Serialized device assignment. 330 size_t device_assignment_size; 331 } PJRT_CompileOptions; 332 const size_t PJRT_CompileOptions_STRUCT_SIZE = 333 PJRT_STRUCT_SIZE(PJRT_CompileOptions, device_assignment_size); 334 335 typedef struct { 336 size_t struct_size; 337 void* priv; 338 PJRT_Client* client; 339 // Serialized MLIR module. Only needs to stay alive for the duration of the 340 // Compile call. 341 const char* module; 342 size_t module_size; 343 // Only needs to stay alive for the duration of the Compile call. 344 PJRT_CompileOptions* options; 345 PJRT_Executable* executable; // out 346 } PJRT_Client_Compile_Args; 347 348 const size_t PJRT_Client_Compile_Args_STRUCT_SIZE = 349 PJRT_STRUCT_SIZE(PJRT_Client_Compile_Args, executable); 350 351 // Compiles an MLIR module with given `options`. 352 typedef PJRT_Error* PJRT_Client_Compile(PJRT_Client_Compile_Args* args); 353 354 // --------------------------------- Devices ----------------------------------- 355 356 typedef struct { 357 size_t struct_size; 358 void* priv; 359 PJRT_Device* device; 360 int id; // out 361 } PJRT_Device_Id_Args; 362 const size_t PJRT_Device_Id_Args_STRUCT_SIZE = 363 PJRT_STRUCT_SIZE(PJRT_Device_Id_Args, id); 364 365 // The ID of this device. IDs are unique among devices of this type 366 // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all 367 // hosts' devices. 368 typedef PJRT_Error* PJRT_Device_Id(PJRT_Device_Id_Args* args); 369 370 typedef struct { 371 size_t struct_size; 372 void* priv; 373 PJRT_Device* device; 374 int local_hardware_id; // out 375 } PJRT_Device_LocalHardwareId_Args; 376 const size_t PJRT_Device_LocalHardwareId_Args_STRUCT_SIZE = 377 PJRT_STRUCT_SIZE(PJRT_Device_LocalHardwareId_Args, local_hardware_id); 378 379 // Opaque hardware ID, e.g., the CUDA device number. In general, not guaranteed 380 // to be dense, and -1 if undefined. 381 typedef PJRT_Error* PJRT_Device_LocalHardwareId( 382 PJRT_Device_LocalHardwareId_Args* args); 383 384 typedef struct { 385 size_t struct_size; 386 void* priv; 387 PJRT_Device* device; 388 int process_index; // out 389 } PJRT_Device_ProcessIndex_Args; 390 const size_t PJRT_Device_ProcessIndex_Args_STRUCT_SIZE = 391 PJRT_STRUCT_SIZE(PJRT_Device_ProcessIndex_Args, process_index); 392 393 // The index of the process that this device belongs to, i.e. is addressable 394 // from. This is not always identical to PJRT_Client_ProcessIndex in a 395 // multi-process setting, where each client can see devices from all 396 // processes, but only a subset of them are addressable and have the same 397 // process_index as the client. 398 typedef PJRT_Error* PJRT_Device_ProcessIndex( 399 PJRT_Device_ProcessIndex_Args* args); 400 401 typedef struct { 402 size_t struct_size; 403 void* priv; 404 PJRT_Device* device; 405 bool is_addressable; // out 406 } PJRT_Device_IsAddressable_Args; 407 const size_t PJRT_Device_IsAddressable_Args_STRUCT_SIZE = 408 PJRT_STRUCT_SIZE(PJRT_Device_IsAddressable_Args, is_addressable); 409 410 // Whether client can issue command to this device. 411 typedef PJRT_Error* PJRT_Device_IsAddressable( 412 PJRT_Device_IsAddressable_Args* args); 413 414 typedef struct { 415 size_t struct_size; 416 void* priv; 417 const char* name; 418 size_t name_size; 419 enum { 420 PJRT_Device_Attribute_kString = 0, 421 PJRT_Device_Attribute_kInt64, 422 PJRT_Device_Attribute_kInt64List 423 } type; 424 union { 425 int64_t int64_value; 426 const int64_t* int64_array_value; 427 const char* string_value; 428 }; 429 // `value_size` is the number of elements for array/string and 1 for scalar 430 // values. 431 size_t value_size; 432 } PJRT_Device_Attribute; 433 const size_t PJRT_Device_Attribute_STRUCT_SIZE = 434 PJRT_STRUCT_SIZE(PJRT_Device_Attribute, value_size); 435 436 typedef struct { 437 size_t struct_size; 438 void* priv; 439 PJRT_Device* device; 440 size_t num_attributes; // out 441 PJRT_Device_Attribute* attributes; // out 442 } PJRT_Device_Attributes_Args; 443 const size_t PJRT_Device_Attributes_Args_STRUCT_SIZE = 444 PJRT_STRUCT_SIZE(PJRT_Device_Attributes_Args, attributes); 445 446 // Returns an array of device specific attributes with attribute name, value 447 // and value type. 448 typedef PJRT_Error* PJRT_Device_Attributes(PJRT_Device_Attributes_Args* args); 449 450 typedef struct { 451 size_t struct_size; 452 void* priv; 453 PJRT_Device* device; 454 // `device_kind` string is owned by `device` and has same lifetime as 455 // `device`. 456 const char* device_kind; // out 457 size_t device_kind_size; // out 458 } PJRT_Device_Kind_Args; 459 const size_t PJRT_Device_Kind_Args_STRUCT_SIZE = 460 PJRT_STRUCT_SIZE(PJRT_Device_Kind_Args, device_kind_size); 461 462 // A vendor-dependent string that uniquely identifies the kind of device, 463 // e.g., "Tesla V100-SXM2-16GB". 464 typedef PJRT_Error* PJRT_Device_Kind(PJRT_Device_Kind_Args* args); 465 466 typedef struct { 467 size_t struct_size; 468 void* priv; 469 PJRT_Device* device; 470 const char* debug_string; // out 471 size_t debug_string_size; // out 472 } PJRT_Device_DebugString_Args; 473 const size_t PJRT_Device_DebugString_Args_STRUCT_SIZE = 474 PJRT_STRUCT_SIZE(PJRT_Device_DebugString_Args, debug_string_size); 475 476 // Debug string suitable for logging when errors occur. Should be verbose 477 // enough to describe the current device unambiguously. 478 typedef PJRT_Error* PJRT_Device_DebugString(PJRT_Device_DebugString_Args* args); 479 480 typedef struct { 481 size_t struct_size; 482 void* priv; 483 PJRT_Device* device; 484 const char* to_string; // out 485 size_t to_string_size; // out 486 } PJRT_Device_ToString_Args; 487 const size_t PJRT_Device_ToString_Args_STRUCT_SIZE = 488 PJRT_STRUCT_SIZE(PJRT_Device_ToString_Args, to_string_size); 489 490 // Debug string suitable for reading by end users, should be reasonably terse, 491 // for example: "CpuDevice(id=0)". 492 typedef PJRT_Error* PJRT_Device_ToString(PJRT_Device_ToString_Args* args); 493 494 // ------------------------------- Executables --------------------------------- 495 496 typedef struct PJRT_Buffer PJRT_Buffer; 497 498 typedef struct { 499 size_t struct_size; 500 void* priv; 501 PJRT_Executable* executable; 502 } PJRT_Executable_Destroy_Args; 503 const size_t PJRT_Executable_Destroy_Args_STRUCT_SIZE = 504 PJRT_STRUCT_SIZE(PJRT_Executable_Destroy_Args, executable); 505 506 // Frees `executable` and deletes the underlying runtime object as if 507 // `PJRT_Executable_Delete` were called. `executable` can be nullptr. 508 typedef PJRT_Error* PJRT_Executable_Destroy(PJRT_Executable_Destroy_Args* args); 509 510 typedef struct { 511 size_t struct_size; 512 void* priv; 513 PJRT_Executable* executable; 514 // `executable_name` has the same lifetime as `executable`. It is owned by 515 // `executable`. 516 const char* executable_name; // out 517 size_t executable_name_size; // out 518 } PJRT_Executable_Name_Args; 519 520 const size_t PJRT_Executable_Name_Args_STRUCT_SIZE = 521 PJRT_STRUCT_SIZE(PJRT_Executable_Name_Args, executable_name_size); 522 523 // Returns a string that identifies the executable. 524 typedef PJRT_Error* PJRT_Executable_Name(PJRT_Executable_Name_Args* args); 525 526 typedef struct { 527 size_t struct_size; 528 void* priv; 529 PJRT_Executable* executable; 530 PJRT_Device** addressable_devices; // out 531 size_t num_addressable_devices; // out 532 } PJRT_Executable_AddressableDevices_Args; 533 534 const size_t PJRT_Executable_AddressableDevices_Args_STRUCT_SIZE = 535 PJRT_STRUCT_SIZE(PJRT_Executable_AddressableDevices_Args, 536 num_addressable_devices); 537 538 // Returns a list of devices this executable will run on. 539 typedef PJRT_Error* PJRT_Executable_AddressableDevices( 540 PJRT_Executable_AddressableDevices_Args* args); 541 542 typedef struct { 543 size_t struct_size; 544 void* priv; 545 PJRT_Executable* executable; 546 } PJRT_Executable_Delete_Args; 547 const size_t PJRT_Executable_Delete_Args_STRUCT_SIZE = 548 PJRT_STRUCT_SIZE(PJRT_Executable_Delete_Args, executable); 549 550 // Drops `executable`'s reference to the internal runtime object and 551 // associated resources, without freeing the `executable` object itself. 552 // `executable` can only be used with PJRT_Executable_IsDeleted and 553 // PJRT_Executable_Destroy after calling this method. The internal runtime 554 // executable will be freed after the last execution completes. 555 typedef PJRT_Error* PJRT_Executable_Delete(PJRT_Executable_Delete_Args* args); 556 557 typedef struct { 558 size_t struct_size; 559 void* priv; 560 PJRT_Executable* executable; 561 bool is_deleted; // out 562 } PJRT_Executable_IsDeleted_Args; 563 const size_t PJRT_Executable_IsDeleted_Args_STRUCT_SIZE = 564 PJRT_STRUCT_SIZE(PJRT_Executable_IsDeleted_Args, is_deleted); 565 566 // True if and only if PJRT_Executable_Delete has previously been called. 567 typedef PJRT_Error* PJRT_Executable_IsDeleted( 568 PJRT_Executable_IsDeleted_Args* args); 569 570 typedef struct { 571 size_t struct_size; 572 void* priv; 573 // If non-zero, identifies this execution as part of a potentially 574 // multi-device launch. This can be used to detect scheduling errors, e.g. if 575 // multi-host programs are launched in different orders on different hosts, 576 // the launch IDs may be used by the runtime to detect the mismatch. 577 int launch_id; 578 } PJRT_ExecuteOptions; 579 const size_t PJRT_ExecuteOptions_STRUCT_SIZE = 580 PJRT_STRUCT_SIZE(PJRT_ExecuteOptions, launch_id); 581 582 typedef struct { 583 size_t struct_size; 584 void* priv; 585 PJRT_Executable* executable; 586 // Only needs to stay alive for the duration of the Execute call. 587 PJRT_ExecuteOptions* options; 588 // Execution input of size [`num_devices`, `num_args`]. 589 PJRT_Buffer*** argument_lists; 590 size_t num_devices; 591 size_t num_args; 592 // Execution output of size [`num_devices`, num_outputs`], where `num_outputs` 593 // is the number of outputs returned by this executable per device. Both the 594 // outer (`PJRT_Buffer***`) and inner lists (`PJRT_Buffer**`) must be 595 // allocated and deallocated by the caller. PJRT_Buffer_Destroy must be called 596 // on the output PJRT_Buffer*. 597 PJRT_Buffer*** output_lists; // in/out 598 } PJRT_Executable_Execute_Args; 599 const size_t PJRT_Executable_Execute_Args_STRUCT_SIZE = 600 PJRT_STRUCT_SIZE(PJRT_Executable_Execute_Args, output_lists); 601 602 // Executes on devices addressable by the client. 603 typedef PJRT_Error* PJRT_Executable_Execute(PJRT_Executable_Execute_Args* args); 604 605 typedef struct { 606 size_t struct_size; 607 void* priv; 608 PJRT_Executable* executable; 609 size_t num_outputs; // out 610 } PJRT_Executable_NumOutputs_Args; 611 const size_t PJRT_Executable_NumOutputs_Args_STRUCT_SIZE = 612 PJRT_STRUCT_SIZE(PJRT_Executable_NumOutputs_Args, num_outputs); 613 614 // Gets the number of outputs per device produced by `executable`. 615 typedef PJRT_Error* PJRT_Executable_NumOutputs( 616 PJRT_Executable_NumOutputs_Args* args); 617 618 // ---------------------------------- Buffers ---------------------------------- 619 620 typedef struct { 621 size_t struct_size; 622 void* priv; 623 PJRT_Buffer* buffer; 624 } PJRT_Buffer_Destroy_Args; 625 const size_t PJRT_Buffer_Destroy_Args_STRUCT_SIZE = 626 PJRT_STRUCT_SIZE(PJRT_Buffer_Destroy_Args, buffer); 627 628 // Deletes the underlying runtime objects as if 'PJRT_Buffer_Delete' were 629 // called and frees `buffer`. `buffer` can be nullptr. 630 typedef PJRT_Error* PJRT_Buffer_Destroy(PJRT_Buffer_Destroy_Args* args); 631 632 // This trimmed shape doesn't have any Tuple information. In case of Tuple, 633 // assert is triggered from the C API Client. 634 // TODO(b/238999986): This is a temporary solution. Remove this later. 635 typedef struct { 636 size_t struct_size; 637 void* priv; 638 PJRT_Buffer* buffer; 639 int element_type; // out 640 Int64List dimensions; // out 641 BoolList dynamic_dimensions; // out 642 bool has_layout; 643 XLA_Layout layout; // out 644 } PJRT_Buffer_OnDeviceTrimmedShape_Args; 645 const size_t PJRT_Buffer_OnDeviceTrimmedShape_Args_STRUCT_SIZE = 646 PJRT_STRUCT_SIZE(PJRT_Buffer_OnDeviceTrimmedShape_Args, layout); 647 648 // Return the trimmed shape from PjRtBuffer. 649 // TODO(b/238999986): Replace this with decomposed shape methods. 650 typedef PJRT_Error* PJRT_Buffer_OnDeviceTrimmedShape( 651 PJRT_Buffer_OnDeviceTrimmedShape_Args* args); 652 653 typedef struct { 654 size_t struct_size; 655 void* priv; 656 PJRT_Buffer* buffer; 657 size_t on_device_size_in_bytes; // out 658 } PJRT_Buffer_OnDeviceSizeInBytes_Args; 659 const size_t PJRT_Buffer_OnDeviceSizeInBytes_Args_STRUCT_SIZE = 660 PJRT_STRUCT_SIZE(PJRT_Buffer_OnDeviceSizeInBytes_Args, 661 on_device_size_in_bytes); 662 663 // Gets the number of bytes of the buffer storage on the device 664 typedef PJRT_Error* PJRT_Buffer_OnDeviceSizeInBytes( 665 PJRT_Buffer_OnDeviceSizeInBytes_Args* args); 666 667 typedef struct { 668 size_t struct_size; 669 void* priv; 670 PJRT_Buffer* buffer; 671 } PJRT_Buffer_Delete_Args; 672 const size_t PJRT_Buffer_Delete_Args_STRUCT_SIZE = 673 PJRT_STRUCT_SIZE(PJRT_Buffer_Delete_Args, buffer); 674 675 // Drop the buffer's reference to its associated device memory, without freeing 676 // the `buffer` object itself. `buffer` can only be used with 677 // PJRT_Buffer_IsDeleted and PJRT_Buffer_Destroy after calling this method. The 678 // device memory will be freed when all async operations using the buffer have 679 // completed, according to the allocation semantics of the underlying platform. 680 typedef PJRT_Error* PJRT_Buffer_Delete(PJRT_Buffer_Delete_Args* args); 681 682 typedef struct { 683 size_t struct_size; 684 void* priv; 685 PJRT_Buffer* buffer; 686 bool is_deleted; // out 687 } PJRT_Buffer_IsDeleted_Args; 688 const size_t PJRT_Buffer_IsDeleted_Args_STRUCT_SIZE = 689 PJRT_STRUCT_SIZE(PJRT_Buffer_IsDeleted_Args, is_deleted); 690 691 // True if and only if PJRT_Buffer_Delete has previously been called. 692 typedef PJRT_Error* PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args* args); 693 694 typedef struct { 695 size_t struct_size; 696 void* priv; 697 PJRT_Buffer* buffer; 698 PJRT_Device* dst_device; 699 PJRT_Buffer* dst_buffer; // out 700 } PJRT_Buffer_CopyToDevice_Args; 701 const size_t PJRT_Buffer_CopyToDevice_Args_STRUCT_SIZE = 702 PJRT_STRUCT_SIZE(PJRT_Buffer_CopyToDevice_Args, dst_buffer); 703 704 // Copies the buffer to device `dst_device`. Caller is responsible for freeing 705 // returned `dst_buffer` with PJRT_Buffer_Destroy. Returns an error if the 706 // buffer is already on `dst_device`. 707 typedef PJRT_Error* PJRT_Buffer_CopyToDevice( 708 PJRT_Buffer_CopyToDevice_Args* args); 709 710 typedef struct { 711 size_t struct_size; 712 void* priv; 713 PJRT_Buffer* buffer; 714 bool is_on_cpu; // out 715 } PJRT_Buffer_IsOnCpu_Args; 716 const size_t PJRT_Buffer_IsOnCpu_Args_STRUCT_SIZE = 717 PJRT_STRUCT_SIZE(PJRT_Buffer_IsOnCpu_Args, is_on_cpu); 718 719 // Whether this buffer is on CPU and thus allows for certain optimizations. 720 typedef PJRT_Error* PJRT_Buffer_IsOnCpu(PJRT_Buffer_IsOnCpu_Args* args); 721 722 typedef struct { 723 size_t struct_size; 724 void* priv; 725 PJRT_Buffer* buffer; 726 PJRT_Device* device; // out 727 } PJRT_Buffer_Device_Args; 728 const size_t PJRT_Buffer_Device_Args_STRUCT_SIZE = 729 PJRT_STRUCT_SIZE(PJRT_Buffer_Device_Args, device); 730 731 // Returns this buffer's storage device. 732 typedef PJRT_Error* PJRT_Buffer_Device(PJRT_Buffer_Device_Args* args); 733 734 typedef struct { 735 size_t struct_size; 736 void* priv; 737 PJRT_Buffer* buffer; 738 // The caller is responsible for calling PJRT_Event_Destroy on `event`. 739 PJRT_Event* event; // out 740 } PJRT_Buffer_ReadyEvent_Args; 741 const size_t PJRT_Buffer_ReadyEvent_Args_STRUCT_SIZE = 742 PJRT_STRUCT_SIZE(PJRT_Buffer_ReadyEvent_Args, event); 743 744 // Returns an event that is triggered when either of the following happens: 745 // * the data in the PJRT_Buffer becomes ready, or 746 // * an error has occurred. 747 // 748 // TODO(b/241967811): change these weird semantics 749 // If the buffer has been deleted or donated, the returned event will 750 // immediately indicate an error. However, if PJRT_Buffer_ReadyEvent() is 751 // called on the buffer before PJRT_Buffer_Delete() is, the returned event will 752 // not transition to an error state after PJRT_Buffer_Delete() is called. 753 typedef PJRT_Error* PJRT_Buffer_ReadyEvent(PJRT_Buffer_ReadyEvent_Args* args); 754 755 // -------------------------------- API access --------------------------------- 756 757 #define _PJRT_API_STRUCT_FIELD(fn_type) fn_type* fn_type 758 759 // Please modify PJRT_Api_STRUCT_SIZE if the last field of PJRT_Api is changed. 760 typedef struct { 761 size_t struct_size; 762 void* priv; 763 764 _PJRT_API_STRUCT_FIELD(PJRT_Error_Destroy); 765 _PJRT_API_STRUCT_FIELD(PJRT_Error_Message); 766 _PJRT_API_STRUCT_FIELD(PJRT_Error_GetCode); 767 768 _PJRT_API_STRUCT_FIELD(PJRT_Event_Destroy); 769 _PJRT_API_STRUCT_FIELD(PJRT_Event_IsReady); 770 _PJRT_API_STRUCT_FIELD(PJRT_Event_Error); 771 _PJRT_API_STRUCT_FIELD(PJRT_Event_Await); 772 _PJRT_API_STRUCT_FIELD(PJRT_Event_OnReady); 773 774 _PJRT_API_STRUCT_FIELD(PJRT_Client_Create); 775 _PJRT_API_STRUCT_FIELD(PJRT_Client_Destroy); 776 _PJRT_API_STRUCT_FIELD(PJRT_Client_PlatformName); 777 _PJRT_API_STRUCT_FIELD(PJRT_Client_ProcessIndex); 778 _PJRT_API_STRUCT_FIELD(PJRT_Client_PlatformVersion); 779 _PJRT_API_STRUCT_FIELD(PJRT_Client_Devices); 780 _PJRT_API_STRUCT_FIELD(PJRT_Client_AddressableDevices); 781 _PJRT_API_STRUCT_FIELD(PJRT_Client_LookupDevice); 782 _PJRT_API_STRUCT_FIELD(PJRT_Client_Compile); 783 784 _PJRT_API_STRUCT_FIELD(PJRT_Device_Id); 785 _PJRT_API_STRUCT_FIELD(PJRT_Device_ProcessIndex); 786 _PJRT_API_STRUCT_FIELD(PJRT_Device_IsAddressable); 787 _PJRT_API_STRUCT_FIELD(PJRT_Device_Attributes); 788 _PJRT_API_STRUCT_FIELD(PJRT_Device_Kind); 789 _PJRT_API_STRUCT_FIELD(PJRT_Device_LocalHardwareId); 790 _PJRT_API_STRUCT_FIELD(PJRT_Device_DebugString); 791 _PJRT_API_STRUCT_FIELD(PJRT_Device_ToString); 792 793 _PJRT_API_STRUCT_FIELD(PJRT_Executable_Destroy); 794 _PJRT_API_STRUCT_FIELD(PJRT_Executable_Name); 795 _PJRT_API_STRUCT_FIELD(PJRT_Executable_AddressableDevices); 796 _PJRT_API_STRUCT_FIELD(PJRT_Executable_NumOutputs); 797 _PJRT_API_STRUCT_FIELD(PJRT_Executable_Delete); 798 _PJRT_API_STRUCT_FIELD(PJRT_Executable_IsDeleted); 799 _PJRT_API_STRUCT_FIELD(PJRT_Executable_Execute); 800 801 _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Destroy); 802 _PJRT_API_STRUCT_FIELD(PJRT_Buffer_OnDeviceTrimmedShape); 803 _PJRT_API_STRUCT_FIELD(PJRT_Buffer_OnDeviceSizeInBytes); 804 _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Device); 805 _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Delete); 806 _PJRT_API_STRUCT_FIELD(PJRT_Buffer_IsDeleted); 807 _PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyToDevice); 808 _PJRT_API_STRUCT_FIELD(PJRT_Buffer_IsOnCpu); 809 _PJRT_API_STRUCT_FIELD(PJRT_Buffer_ReadyEvent); 810 } PJRT_Api; 811 812 const size_t PJRT_Api_STRUCT_SIZE = 813 PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Buffer_ReadyEvent); 814 815 #undef _PJRT_API_STRUCT_FIELD 816 817 #ifdef __cplusplus 818 } 819 #endif 820 821 #endif // TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_H_ 822