xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_WRAPPER_IMPL_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_WRAPPER_IMPL_H_
18 
#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_future.h"
27 
28 struct PJRT_Error {
29   xla::Status status;
30 };
31 
32 struct PJRT_Client {
33   std::unique_ptr<xla::PjRtClient> client;
34   std::vector<PJRT_Device> owned_devices;
35   // `devices` contains the addresses of the contents of `owned_devices`.
36   std::vector<PJRT_Device*> devices;
37   // `addressable_devices` contains pointers to the `owned_devices` that the
38   // client can issue commands to.
39   std::vector<PJRT_Device*> addressable_devices;
40   // Map from wrapped C++ devices to C devices. The values are the same as
41   // `owned_devices`.
42   absl::flat_hash_map<xla::PjRtDevice*, PJRT_Device*> c_device_from_cpp_device;
43 };
44 
45 // PJRT_Devices are owned by their corresponding PJRT_Client.
46 struct PJRT_Device {
47   // The xla::PjRtDevice* is owned by the corresponding xla::PjRtClient.
48   xla::PjRtDevice* device;
49   // The device specific attributes which are initialized once per device.
50   std::vector<PJRT_Device_Attribute> attributes;
51 };
52 
53 struct PJRT_Executable {
54   std::unique_ptr<xla::PjRtLoadedExecutable> executable;
55   PJRT_Client* client;
56   // These pointers are a subset of `client`'s `addressable_devices`, i.e. those
57   // addressed by the compiled executable program. `client` owns the objects
58   // these point to.
59   std::vector<PJRT_Device*> addressable_devices;
60   // TODO(b/237545405): Remove `populated` once we implement creation methods
61   // for PJRT_Executable that can populate addressable_devices on instantiation.
62   bool populated = false;
63 };
64 
65 struct PJRT_Buffer {
66   std::unique_ptr<xla::PjRtBuffer> buffer;
67   PJRT_Client* client;
68 };
69 
70 struct PJRT_Event {
71   xla::PjRtFuture<xla::Status> future;
72   // Set and stored upon future.Await(), as PjRtFuture only allows its result to
73   // be queried through Await() and Await() can only safely be called once. This
74   // variable allows C API users to check for error status any time after
75   // Await() has been called.
76   std::optional<xla::Status> status;
77 };
78 
79 namespace pjrt {
80 
81 // C API definitions
82 
83 void PJRT_Error_Destroy(PJRT_Error_Destroy_Args* args);
84 void PJRT_Error_Message(PJRT_Error_Message_Args* args);
85 PJRT_Error* PJRT_Error_GetCode(PJRT_Error_GetCode_Args* args);
86 
87 PJRT_Error* PJRT_Event_Destroy(PJRT_Event_Destroy_Args* args);
88 PJRT_Error* PJRT_Event_IsReady(PJRT_Event_IsReady_Args* args);
89 PJRT_Error* PJRT_Event_Error(PJRT_Event_Error_Args* args);
90 PJRT_Error* PJRT_Event_Await(PJRT_Event_Await_Args* args);
91 PJRT_Error* PJRT_Event_OnReady(PJRT_Event_OnReady_Args* args);
92 
93 PJRT_Error* PJRT_Client_Destroy(PJRT_Client_Destroy_Args* args);
94 PJRT_Error* PJRT_Client_PlatformName(PJRT_Client_PlatformName_Args* args);
95 PJRT_Error* PJRT_Client_ProcessIndex(PJRT_Client_ProcessIndex_Args* args);
96 PJRT_Error* PJRT_Client_PlatformVersion(PJRT_Client_PlatformVersion_Args* args);
97 PJRT_Error* PJRT_Client_Devices(PJRT_Client_Devices_Args* args);
98 PJRT_Error* PJRT_Client_AddressableDevices(
99     PJRT_Client_AddressableDevices_Args* args);
100 PJRT_Error* PJRT_Client_LookupDevice(PJRT_Client_LookupDevice_Args* args);
101 PJRT_Error* PJRT_Client_Compile(PJRT_Client_Compile_Args* args);
102 
103 PJRT_Error* PJRT_Device_Id(PJRT_Device_Id_Args* args);
104 PJRT_Error* PJRT_Device_ProcessIndex(PJRT_Device_ProcessIndex_Args* args);
105 PJRT_Error* PJRT_Device_IsAddressable(PJRT_Device_IsAddressable_Args* args);
106 PJRT_Error* PJRT_Device_Attributes(PJRT_Device_Attributes_Args* args);
107 PJRT_Error* PJRT_Device_Kind(PJRT_Device_Kind_Args* args);
108 PJRT_Error* PJRT_Device_LocalHardwareId(PJRT_Device_LocalHardwareId_Args* args);
109 PJRT_Error* PJRT_Device_DebugString(PJRT_Device_DebugString_Args* args);
110 PJRT_Error* PJRT_Device_ToString(PJRT_Device_ToString_Args* args);
111 
112 PJRT_Error* PJRT_Executable_Destroy(PJRT_Executable_Destroy_Args* args);
113 PJRT_Error* PJRT_Executable_Name(PJRT_Executable_Name_Args* args);
114 PJRT_Error* PJRT_Executable_AddressableDevices(
115     PJRT_Executable_AddressableDevices_Args* args);
116 PJRT_Error* PJRT_Executable_NumOutputs(PJRT_Executable_NumOutputs_Args* args);
117 PJRT_Error* PJRT_Executable_Delete(PJRT_Executable_Delete_Args* args);
118 PJRT_Error* PJRT_Executable_IsDeleted(PJRT_Executable_IsDeleted_Args* args);
119 PJRT_Error* PJRT_Executable_Execute(PJRT_Executable_Execute_Args* args);
120 
121 PJRT_Error* PJRT_Buffer_Destroy(PJRT_Buffer_Destroy_Args* args);
122 PJRT_Error* PJRT_Buffer_OnDeviceTrimmedShape(
123     PJRT_Buffer_OnDeviceTrimmedShape_Args* args);
124 PJRT_Error* PJRT_Buffer_OnDeviceSizeInBytes(
125     PJRT_Buffer_OnDeviceSizeInBytes_Args* args);
126 PJRT_Error* PJRT_Buffer_Device(PJRT_Buffer_Device_Args* args);
127 PJRT_Error* PJRT_Buffer_Delete(PJRT_Buffer_Delete_Args* args);
128 PJRT_Error* PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args* args);
129 PJRT_Error* PJRT_Buffer_CopyToDevice(PJRT_Buffer_CopyToDevice_Args* args);
130 PJRT_Error* PJRT_Buffer_IsOnCpu(PJRT_Buffer_IsOnCpu_Args* args);
131 PJRT_Error* PJRT_Buffer_ReadyEvent(PJRT_Buffer_ReadyEvent_Args* args);
132 
133 // Helper macros and functions
134 
135 #define PJRT_RETURN_IF_ERROR(expr)                                \
136   do {                                                            \
137     xla::Status _status = (expr);                                 \
138     if (!_status.ok()) {                                          \
139       PJRT_Error* _c_status = new PJRT_Error{std::move(_status)}; \
140       return _c_status;                                           \
141     }                                                             \
142   } while (false)
143 
144 #define PJRT_ASSIGN_OR_RETURN(lhs, rexpr)                                  \
145   _PJRT_ASSIGN_OR_RETURN_IMPL(_PJRT_CONCAT(_status_or_value, __COUNTER__), \
146                               lhs, rexpr,                                  \
147                               _PJRT_CONCAT(_c_status, __COUNTER__));
148 
149 #define _PJRT_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr, c_status) \
150   auto statusor = (rexpr);                                          \
151   if (!statusor.ok()) {                                             \
152     PJRT_Error* c_status = new PJRT_Error();                        \
153     c_status->status = statusor.status();                           \
154     return c_status;                                                \
155   }                                                                 \
156   lhs = std::move(*statusor)
157 
158 #define _PJRT_CONCAT(x, y) _PJRT_CONCAT_IMPL(x, y)
159 #define _PJRT_CONCAT_IMPL(x, y) x##y
160 
161 // Helper function for checking C API argument struct sizes. Returns a non-OK
162 // status if the expected and actual sizes aren't equal (i.e. no ABI
163 // compatibility guarantees).
164 xla::Status CheckMatchingStructSizes(absl::string_view struct_name,
165                                      size_t expected_size, size_t actual_size);
166 
167 // Helper function
168 std::string StructSizeErrorMsg(absl::string_view struct_name,
169                                size_t expected_size, size_t actual_size);
170 
171 }  // namespace pjrt
172 
173 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_WRAPPER_IMPL_H_
174