1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_WRAPPER_IMPL_H_ 17 #define TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_WRAPPER_IMPL_H_ 18 19 #include <memory> 20 #include <optional> 21 #include <string> 22 #include <vector> 23 24 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h" 25 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" 26 #include "tensorflow/compiler/xla/pjrt/pjrt_future.h" 27 28 struct PJRT_Error { 29 xla::Status status; 30 }; 31 32 struct PJRT_Client { 33 std::unique_ptr<xla::PjRtClient> client; 34 std::vector<PJRT_Device> owned_devices; 35 // `devices` contains the addresses of the contents of `owned_devices`. 36 std::vector<PJRT_Device*> devices; 37 // `addressable_devices` contains pointers to the `owned_devices` that the 38 // client can issue commands to. 39 std::vector<PJRT_Device*> addressable_devices; 40 // Map from wrapped C++ devices to C devices. The values are the same as 41 // `owned_devices`. 42 absl::flat_hash_map<xla::PjRtDevice*, PJRT_Device*> c_device_from_cpp_device; 43 }; 44 45 // PJRT_Devices are owned by their corresponding PJRT_Client. 46 struct PJRT_Device { 47 // The xla::PjRtDevice* is owned by the corresponding xla::PjRtClient. 48 xla::PjRtDevice* device; 49 // The device specific attributes which are initialized once per device. 50 std::vector<PJRT_Device_Attribute> attributes; 51 }; 52 53 struct PJRT_Executable { 54 std::unique_ptr<xla::PjRtLoadedExecutable> executable; 55 PJRT_Client* client; 56 // These pointers are a subset of `client`'s `addressable_devices`, i.e. those 57 // addressed by the compiled executable program. `client` owns the objects 58 // these point to. 59 std::vector<PJRT_Device*> addressable_devices; 60 // TODO(b/237545405): Remove `populated` once we implement creation methods 61 // for PJRT_Executable that can populate addressable_devices on instantiation. 62 bool populated = false; 63 }; 64 65 struct PJRT_Buffer { 66 std::unique_ptr<xla::PjRtBuffer> buffer; 67 PJRT_Client* client; 68 }; 69 70 struct PJRT_Event { 71 xla::PjRtFuture<xla::Status> future; 72 // Set and stored upon future.Await(), as PjRtFuture only allows its result to 73 // be queried through Await() and Await() can only safely be called once. This 74 // variable allows C API users to check for error status any time after 75 // Await() has been called. 76 std::optional<xla::Status> status; 77 }; 78 79 namespace pjrt { 80 81 // C API definitions 82 83 void PJRT_Error_Destroy(PJRT_Error_Destroy_Args* args); 84 void PJRT_Error_Message(PJRT_Error_Message_Args* args); 85 PJRT_Error* PJRT_Error_GetCode(PJRT_Error_GetCode_Args* args); 86 87 PJRT_Error* PJRT_Event_Destroy(PJRT_Event_Destroy_Args* args); 88 PJRT_Error* PJRT_Event_IsReady(PJRT_Event_IsReady_Args* args); 89 PJRT_Error* PJRT_Event_Error(PJRT_Event_Error_Args* args); 90 PJRT_Error* PJRT_Event_Await(PJRT_Event_Await_Args* args); 91 PJRT_Error* PJRT_Event_OnReady(PJRT_Event_OnReady_Args* args); 92 93 PJRT_Error* PJRT_Client_Destroy(PJRT_Client_Destroy_Args* args); 94 PJRT_Error* PJRT_Client_PlatformName(PJRT_Client_PlatformName_Args* args); 95 PJRT_Error* PJRT_Client_ProcessIndex(PJRT_Client_ProcessIndex_Args* args); 96 PJRT_Error* PJRT_Client_PlatformVersion(PJRT_Client_PlatformVersion_Args* args); 97 PJRT_Error* PJRT_Client_Devices(PJRT_Client_Devices_Args* args); 98 PJRT_Error* PJRT_Client_AddressableDevices( 99 PJRT_Client_AddressableDevices_Args* args); 100 PJRT_Error* PJRT_Client_LookupDevice(PJRT_Client_LookupDevice_Args* args); 101 PJRT_Error* PJRT_Client_Compile(PJRT_Client_Compile_Args* args); 102 103 PJRT_Error* PJRT_Device_Id(PJRT_Device_Id_Args* args); 104 PJRT_Error* PJRT_Device_ProcessIndex(PJRT_Device_ProcessIndex_Args* args); 105 PJRT_Error* PJRT_Device_IsAddressable(PJRT_Device_IsAddressable_Args* args); 106 PJRT_Error* PJRT_Device_Attributes(PJRT_Device_Attributes_Args* args); 107 PJRT_Error* PJRT_Device_Kind(PJRT_Device_Kind_Args* args); 108 PJRT_Error* PJRT_Device_LocalHardwareId(PJRT_Device_LocalHardwareId_Args* args); 109 PJRT_Error* PJRT_Device_DebugString(PJRT_Device_DebugString_Args* args); 110 PJRT_Error* PJRT_Device_ToString(PJRT_Device_ToString_Args* args); 111 112 PJRT_Error* PJRT_Executable_Destroy(PJRT_Executable_Destroy_Args* args); 113 PJRT_Error* PJRT_Executable_Name(PJRT_Executable_Name_Args* args); 114 PJRT_Error* PJRT_Executable_AddressableDevices( 115 PJRT_Executable_AddressableDevices_Args* args); 116 PJRT_Error* PJRT_Executable_NumOutputs(PJRT_Executable_NumOutputs_Args* args); 117 PJRT_Error* PJRT_Executable_Delete(PJRT_Executable_Delete_Args* args); 118 PJRT_Error* PJRT_Executable_IsDeleted(PJRT_Executable_IsDeleted_Args* args); 119 PJRT_Error* PJRT_Executable_Execute(PJRT_Executable_Execute_Args* args); 120 121 PJRT_Error* PJRT_Buffer_Destroy(PJRT_Buffer_Destroy_Args* args); 122 PJRT_Error* PJRT_Buffer_OnDeviceTrimmedShape( 123 PJRT_Buffer_OnDeviceTrimmedShape_Args* args); 124 PJRT_Error* PJRT_Buffer_OnDeviceSizeInBytes( 125 PJRT_Buffer_OnDeviceSizeInBytes_Args* args); 126 PJRT_Error* PJRT_Buffer_Device(PJRT_Buffer_Device_Args* args); 127 PJRT_Error* PJRT_Buffer_Delete(PJRT_Buffer_Delete_Args* args); 128 PJRT_Error* PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args* args); 129 PJRT_Error* PJRT_Buffer_CopyToDevice(PJRT_Buffer_CopyToDevice_Args* args); 130 PJRT_Error* PJRT_Buffer_IsOnCpu(PJRT_Buffer_IsOnCpu_Args* args); 131 PJRT_Error* PJRT_Buffer_ReadyEvent(PJRT_Buffer_ReadyEvent_Args* args); 132 133 // Helper macros and functions 134 135 #define PJRT_RETURN_IF_ERROR(expr) \ 136 do { \ 137 xla::Status _status = (expr); \ 138 if (!_status.ok()) { \ 139 PJRT_Error* _c_status = new PJRT_Error{std::move(_status)}; \ 140 return _c_status; \ 141 } \ 142 } while (false) 143 144 #define PJRT_ASSIGN_OR_RETURN(lhs, rexpr) \ 145 _PJRT_ASSIGN_OR_RETURN_IMPL(_PJRT_CONCAT(_status_or_value, __COUNTER__), \ 146 lhs, rexpr, \ 147 _PJRT_CONCAT(_c_status, __COUNTER__)); 148 149 #define _PJRT_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr, c_status) \ 150 auto statusor = (rexpr); \ 151 if (!statusor.ok()) { \ 152 PJRT_Error* c_status = new PJRT_Error(); \ 153 c_status->status = statusor.status(); \ 154 return c_status; \ 155 } \ 156 lhs = std::move(*statusor) 157 158 #define _PJRT_CONCAT(x, y) _PJRT_CONCAT_IMPL(x, y) 159 #define _PJRT_CONCAT_IMPL(x, y) x##y 160 161 // Helper function for checking C API argument struct sizes. Returns a non-OK 162 // status if the expected and actual sizes aren't equal (i.e. no ABI 163 // compatibility guarantees). 164 xla::Status CheckMatchingStructSizes(absl::string_view struct_name, 165 size_t expected_size, size_t actual_size); 166 167 // Helper function 168 std::string StructSizeErrorMsg(absl::string_view struct_name, 169 size_t expected_size, size_t actual_size); 170 171 } // namespace pjrt 172 173 #endif // TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_WRAPPER_IMPL_H_ 174