1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_C_API_CLIENT_H_ 17 #define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_C_API_CLIENT_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <optional> 22 #include <string> 23 #include <utility> 24 #include <vector> 25 26 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h" 27 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h" 28 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" 29 30 namespace xla { 31 32 class PjRtCApiClient; 33 34 class PjRtCApiDevice : public PjRtDevice { 35 public: 36 explicit PjRtCApiDevice(PJRT_Device* device, PjRtCApiClient* client); 37 38 PjRtClient* client() const override; 39 40 bool IsAddressable() const override; 41 42 int id() const override; 43 44 int process_index() const override; 45 46 int local_hardware_id() const override; 47 48 absl::string_view device_kind() const override; 49 50 absl::string_view DebugString() const override; 51 52 absl::string_view ToString() const override; 53 TransferToInfeed(const LiteralSlice & literal)54 Status TransferToInfeed(const LiteralSlice& literal) override { 55 #ifdef PJRT_C_API_BYPASS 56 return wrapped_->TransferToInfeed(literal); 57 #endif // PJRT_C_API_BYPASS 58 return Unimplemented("PJRT C API does not support TransferToInfeed"); 59 } 60 TransferFromOutfeed(MutableBorrowingLiteral literal)61 Status TransferFromOutfeed(MutableBorrowingLiteral literal) override { 62 #ifdef PJRT_C_API_BYPASS 63 return wrapped_->TransferFromOutfeed(std::move(literal)); 64 #endif // PJRT_C_API_BYPASS 65 return Unimplemented("PJRT C API does not support TransferFromOutfeed"); 66 } 67 CreateAsyncTrackingEvent(absl::string_view description)68 std::unique_ptr<ScopedAsyncTrackingEvent> CreateAsyncTrackingEvent( 69 absl::string_view description) const override { 70 #ifdef PJRT_C_API_BYPASS 71 return wrapped_->CreateAsyncTrackingEvent(description); 72 #endif // PJRT_C_API_BYPASS 73 LOG(WARNING) << "PJRT C API does not support CreateAsyncTrackingEvent"; 74 return nullptr; 75 } 76 77 const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes() 78 const override; 79 c_device()80 PJRT_Device* c_device() const { return device_; } 81 wrapped()82 PjRtDevice* wrapped() const { return wrapped_; } 83 GetWrapped(PjRtDevice * c_api_device)84 static PjRtDevice* GetWrapped(PjRtDevice* c_api_device) { 85 return tensorflow::down_cast<PjRtCApiDevice*>(c_api_device)->wrapped(); 86 } 87 88 private: 89 PjRtCApiClient* client_ = nullptr; 90 // `device_` is owned by the `PJRT_Client` wrapped by `client_` 91 PJRT_Device* device_; 92 // TODO(shahrokhi): wrapped_ is a non-C API pointer that was used to bypass 93 // the C API calls until all the C API's got implemented. Remove it when it's 94 // usage is reduced to zero. 95 PjRtDevice* wrapped_; 96 // Device specific attributes with corresponding values. 97 absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_; 98 99 // Initializes device specific attributes. 100 void InitAttributes(); 101 }; 102 103 class PjRtCApiClient : public PjRtClient { 104 public: 105 PjRtCApiClient(const PJRT_Api* c_api, PJRT_Client* c_client); 106 107 int process_index() const override; 108 109 int device_count() const override; 110 int addressable_device_count() const override; 111 112 absl::Span<PjRtDevice* const> devices() const override; 113 absl::Span<PjRtDevice* const> addressable_devices() const override; 114 115 StatusOr<PjRtDevice*> LookupDevice(int device_id) const override; 116 LookupAddressableDevice(int local_hardware_id)117 StatusOr<PjRtDevice*> LookupAddressableDevice( 118 int local_hardware_id) const override { 119 #ifdef PJRT_C_API_BYPASS 120 TF_ASSIGN_OR_RETURN(PjRtDevice * wrapped_device, 121 wrapped_->LookupAddressableDevice(local_hardware_id)); 122 return GetCApiDevice(wrapped_device); 123 #endif // PJRT_C_API_BYPASS 124 return Unimplemented("PJRT C API does not support LookupAddressableDevice"); 125 } 126 platform_id()127 PjRtPlatformId platform_id() const override { 128 #ifdef PJRT_C_API_BYPASS 129 return wrapped_->platform_id(); 130 #endif // PJRT_C_API_BYPASS 131 CHECK(false) << "PJRT C API does not support platform_id."; 132 } 133 134 absl::string_view platform_name() const override; 135 136 absl::string_view platform_version() const override; 137 runtime_type()138 PjRtRuntimeType runtime_type() const override { 139 #ifdef PJRT_C_API_BYPASS 140 return wrapped_->runtime_type(); 141 #endif // PJRT_C_API_BYPASS 142 CHECK(false) << "PJRT C API does not support runtime_type."; 143 } 144 GetDefaultDeviceAssignment(int num_replicas,int num_partitions)145 StatusOr<DeviceAssignment> GetDefaultDeviceAssignment( 146 int num_replicas, int num_partitions) const override { 147 #ifdef PJRT_C_API_BYPASS 148 return wrapped_->GetDefaultDeviceAssignment(num_replicas, num_partitions); 149 #endif // PJRT_C_API_BYPASS 150 return Unimplemented( 151 "PJRT C API does not support GetDefaultDeviceAssignment"); 152 } 153 GetHloCostAnalysis()154 StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() override { 155 #ifdef PJRT_C_API_BYPASS 156 return wrapped_->GetHloCostAnalysis(); 157 #endif // PJRT_C_API_BYPASS 158 return Unimplemented("PJRT C API does not support GetHloCostAnalysis"); 159 } 160 Compile(const XlaComputation & computation,CompileOptions options)161 StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile( 162 const XlaComputation& computation, CompileOptions options) override { 163 #ifdef PJRT_C_API_BYPASS 164 return WrapExecutable(wrapped_->Compile(computation, options)); 165 #endif // PJRT_C_API_BYPASS 166 return Unimplemented( 167 "PJRT C API does not support Compile with XlaComputation"); 168 } 169 170 StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile( 171 mlir::ModuleOp module, CompileOptions options) override; 172 173 StatusOr<std::optional<std::string>> ExecutableFingerprint( 174 const PjRtLoadedExecutable& executable) const override; 175 176 StatusOr<std::string> SerializeExecutable( 177 const PjRtLoadedExecutable& executable) const override; 178 179 StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable( 180 absl::string_view serialized, CompileOptions options) override; 181 CreateUninitializedBuffer(const Shape & shape,PjRtDevice * device)182 StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer( 183 const Shape& shape, PjRtDevice* device) override { 184 return Unimplemented( 185 "PJRT C API does not support CreateUninitializedBuffer"); 186 } 187 188 StatusOr<std::unique_ptr<AsyncBufferTransferManager>> CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes,PjRtDevice * device)189 CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes, 190 PjRtDevice* device) override { 191 return Unimplemented( 192 "PJRT C API does not support CreateBuffersForAsyncTransfer"); 193 } 194 BufferFromHostBuffer(const void * data,PrimitiveType type,absl::Span<int64_t const> dims,std::optional<absl::Span<int64_t const>> byte_strides,HostBufferSemantics host_buffer_semantics,std::function<void ()> on_done_with_host_buffer,PjRtDevice * device)195 StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer( 196 const void* data, PrimitiveType type, absl::Span<int64_t const> dims, 197 std::optional<absl::Span<int64_t const>> byte_strides, 198 HostBufferSemantics host_buffer_semantics, 199 std::function<void()> on_done_with_host_buffer, 200 PjRtDevice* device) override { 201 #ifdef PJRT_C_API_BYPASS 202 return WrapBuffer(wrapped_->BufferFromHostBuffer( 203 data, type, dims, byte_strides, host_buffer_semantics, 204 on_done_with_host_buffer, PjRtCApiDevice::GetWrapped(device))); 205 #endif // PJRT_C_API_BYPASS 206 return Unimplemented("PJRT C API does not support BufferFromHostBuffer"); 207 } 208 BufferFromHostLiteral(const LiteralSlice & literal,PjRtDevice * device)209 StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral( 210 const LiteralSlice& literal, PjRtDevice* device) override { 211 #ifdef PJRT_C_API_BYPASS 212 return WrapBuffer(wrapped_->BufferFromHostLiteral( 213 literal, PjRtCApiDevice::GetWrapped(device))); 214 #endif // PJRT_C_API_BYPASS 215 return Unimplemented("PJRT C API does not support BufferFromHostLiteral"); 216 } 217 CreateViewOfDeviceBuffer(void * device_ptr,const Shape & shape,PjRtDevice * device,std::function<void ()> on_delete_callback)218 StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer( 219 void* device_ptr, const Shape& shape, PjRtDevice* device, 220 std::function<void()> on_delete_callback) override { 221 #ifdef PJRT_C_API_BYPASS 222 return WrapBuffer(wrapped_->CreateViewOfDeviceBuffer( 223 device_ptr, shape, PjRtCApiDevice::GetWrapped(device), 224 on_delete_callback)); 225 #endif // PJRT_C_API_BYPASS 226 return Unimplemented( 227 "PJRT C API does not support CreateViewOfDeviceBuffer"); 228 } 229 230 StatusOr<std::uintptr_t> UnsafeBufferPointer(PjRtBuffer* buffer) override; 231 232 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,PjRtDevice * device,PjRtCrossHostRecvNotifier notifier)233 MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes, 234 PjRtDevice* device, 235 PjRtCrossHostRecvNotifier notifier) override { 236 return Unimplemented( 237 "PJRT C API does not support MakeCrossHostReceiveBuffers"); 238 } 239 240 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> MakeCrossHostReceiveBuffersForGather(absl::Span<const Shape> shapes,std::vector<GatherDetails> gather_details,PjRtDevice * device,PjRtCrossHostRecvNotifier notifier)241 MakeCrossHostReceiveBuffersForGather( 242 absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details, 243 PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) override { 244 return Unimplemented( 245 "PJRT C API does not support MakeCrossHostReceiveBuffers"); 246 } 247 CreateChannelHandle()248 StatusOr<ChannelHandle> CreateChannelHandle() override { 249 return Unimplemented("PJRT C API does not support CreateChannelHandle"); 250 } 251 CreateDeviceToHostChannelHandle()252 StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override { 253 return Unimplemented( 254 "PJRT C API does not support CreateDeviceToHostChannelHandle"); 255 } 256 CreateHostToDeviceChannelHandle()257 StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override { 258 return Unimplemented( 259 "PJRT C API does not support CreateHostToDeviceChannelHandle"); 260 } 261 Defragment()262 Status Defragment() override { return wrapped_->Defragment(); } 263 GetCApiDevice(PjRtDevice * wrapped_device)264 PjRtDevice* GetCApiDevice(PjRtDevice* wrapped_device) const { 265 auto it = wrapped_device_map_.find(wrapped_device); 266 CHECK(it != wrapped_device_map_.end()); 267 return it->second; 268 } 269 270 StatusOr<std::unique_ptr<PjRtLoadedExecutable>> WrapExecutable( 271 StatusOr<std::unique_ptr<PjRtLoadedExecutable>> to_wrap); 272 273 StatusOr<std::unique_ptr<PjRtBuffer>> WrapBuffer( 274 StatusOr<std::unique_ptr<PjRtBuffer>> to_wrap); 275 276 const PJRT_Api* pjrt_c_api() const; 277 pjrt_c_client()278 PJRT_Client* pjrt_c_client() { return c_client_.get(); } 279 GetCppDevice(PJRT_Device * c_device)280 PjRtCApiDevice* GetCppDevice(PJRT_Device* c_device) const { 281 auto it = c_to_cpp_device_map_.find(c_device); 282 CHECK(it != c_to_cpp_device_map_.end()); 283 return it->second; 284 } 285 286 private: 287 const PJRT_Api* c_api_; 288 std::unique_ptr<PJRT_Client, ::pjrt::PJRT_ClientDeleter> c_client_; 289 290 std::vector<std::unique_ptr<PjRtCApiDevice>> owned_devices_; 291 std::vector<PjRtDevice*> devices_; 292 std::vector<PjRtDevice*> addressable_devices_; 293 absl::flat_hash_map<PJRT_Device*, PjRtCApiDevice*> c_to_cpp_device_map_; 294 295 // TODO(skyewm): this is a shim so we can run PjRtCApiClient code without the 296 // C API being fully implemented. All methods using wrapped_ should either be 297 // marked unimplemented or implemented in terms of the C API, at which point 298 // wrapped_ and related functionality should be removed. 299 PjRtClient* wrapped_; 300 absl::flat_hash_map<PjRtDevice*, PjRtCApiDevice*> wrapped_device_map_; 301 302 void InitDevices(); 303 }; 304 305 class PjRtCApiBuffer : public PjRtBuffer { 306 public: 307 PjRtCApiBuffer(PjRtCApiClient* client, PJRT_Buffer* buffer); 308 309 const Shape& on_device_shape() const override; 310 logical_on_device_shape()311 StatusOr<Shape> logical_on_device_shape() override { 312 #ifdef PJRT_C_API_BYPASS 313 return wrapped_->logical_on_device_shape(); 314 #endif // PJRT_C_API_BYPASS 315 return Unimplemented("PJRT C API does not support logical_on_device_shape"); 316 } 317 318 PjRtDevice* device() const override; 319 client()320 PjRtClient* client() const override { return client_; } 321 AcquireExternalReference()322 StatusOr<std::unique_ptr<ExternalReference>> AcquireExternalReference() 323 override { 324 #ifdef PJRT_C_API_BYPASS 325 return wrapped_->AcquireExternalReference(); 326 #endif // PJRT_C_API_BYPASS 327 return Unimplemented( 328 "PJRT C API does not support AcquireExternalReference"); 329 } 330 ToLiteral(MutableLiteralBase * literal)331 PjRtFuture<Status> ToLiteral(MutableLiteralBase* literal) override { 332 #ifdef PJRT_C_API_BYPASS 333 return wrapped_->ToLiteral(literal); 334 #endif // PJRT_C_API_BYPASS 335 return PjRtFuture<Status>( 336 Unimplemented("PJRT C API does not support ToLiteral")); 337 } 338 339 StatusOr<size_t> GetOnDeviceSizeInBytes() const override; 340 CopyRawToHost(void * dst,int64_t offset,int64_t transfer_size)341 PjRtFuture<Status> CopyRawToHost(void* dst, int64_t offset, 342 int64_t transfer_size) override { 343 #ifdef PJRT_C_API_BYPASS 344 return wrapped_->CopyRawToHost(dst, offset, transfer_size); 345 #endif // PJRT_C_API_BYPASS 346 return PjRtFuture<Status>( 347 Unimplemented("PJRT C API does not support CopyRawToHost")); 348 } 349 350 void Delete() override; 351 ReleaseDeviceMemoryOwnership(bool wait_for_operations_to_complete)352 StatusOr<std::unique_ptr<ExternalReference>> ReleaseDeviceMemoryOwnership( 353 bool wait_for_operations_to_complete) override { 354 #ifdef PJRT_C_API_BYPASS 355 return wrapped_->ReleaseDeviceMemoryOwnership( 356 wait_for_operations_to_complete); 357 #endif // PJRT_C_API_BYPASS 358 return Unimplemented( 359 "PJRT C API does not support ReleaseDeviceMemoryOwnership"); 360 } 361 362 bool IsDeleted() override; 363 364 StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice( 365 PjRtDevice* dst_device) override; 366 CopyToRemoteDevice(absl::string_view serialized_descriptor,RemoteSendCallback on_done)367 void CopyToRemoteDevice(absl::string_view serialized_descriptor, 368 RemoteSendCallback on_done) override { 369 LOG(ERROR) << "PJRT C API does not support CopyToRemoteDevice"; 370 } 371 CopyToRemoteDeviceScattered(absl::Span<const std::pair<std::string,RemoteSendCallback>> serialized_descriptors_and_callbacks,const ScatterDetails & scatter_details)372 void CopyToRemoteDeviceScattered( 373 absl::Span<const std::pair<std::string, RemoteSendCallback>> 374 serialized_descriptors_and_callbacks, 375 const ScatterDetails& scatter_details) override { 376 LOG(ERROR) << "PJRT C API does not support CopyToRemoteDeviceScattered"; 377 } 378 GetReadyFuture()379 PjRtFuture<Status> GetReadyFuture() override { 380 #ifdef PJRT_C_API_BYPASS 381 return wrapped_->GetReadyFuture(); 382 #endif // PJRT_C_API_BYPASS 383 return PjRtFuture<Status>( 384 Unimplemented("PJRT C API does not support GetReadyFuture")); 385 } 386 387 bool IsOnCpu() const override; 388 wrapped()389 PjRtBuffer* wrapped() const { return wrapped_; } 390 c_buffer()391 PJRT_Buffer* c_buffer() const { return buffer_.get(); } 392 GetWrapped(PjRtBuffer * c_api_buffer)393 static PjRtBuffer* GetWrapped(PjRtBuffer* c_api_buffer) { 394 return tensorflow::down_cast<PjRtCApiBuffer*>(c_api_buffer)->wrapped(); 395 } 396 GetWrappedVector(absl::Span<PjRtBuffer * const> c_api_buffers)397 static std::vector<PjRtBuffer*> GetWrappedVector( 398 absl::Span<PjRtBuffer* const> c_api_buffers) { 399 std::vector<PjRtBuffer*> wrapped; 400 wrapped.reserve(c_api_buffers.size()); 401 for (PjRtBuffer* c_api_buf : c_api_buffers) { 402 wrapped.push_back(GetWrapped(c_api_buf)); 403 } 404 return wrapped; 405 } 406 pjrt_c_api()407 const PJRT_Api* pjrt_c_api() const { return client_->pjrt_c_api(); } 408 409 private: 410 PjRtCApiClient* client_; 411 std::unique_ptr<PJRT_Buffer, ::pjrt::PJRT_BufferDeleter> buffer_; 412 std::optional<xla::Shape> shape_; 413 414 // TODO(amangu): _wrapped is a non-C API pointer that was used to bypass the 415 // C API calls until all the C API's got implemented. Remove it when it's 416 // usage is reduced to zero. 417 PjRtBuffer* wrapped_; 418 419 // TODO(b/238999986): Refactor or Remove. 420 void set_shape(); 421 }; 422 423 class PjRtCApiExecutable : public PjRtLoadedExecutable { 424 public: 425 PjRtCApiExecutable(PjRtCApiClient* client, 426 std::unique_ptr<PjRtLoadedExecutable> wrapped); 427 PjRtCApiExecutable(PjRtCApiClient* client, PJRT_Executable* executable); 428 429 ~PjRtCApiExecutable() override; 430 client()431 PjRtClient* client() const override { return client_; } 432 absl::string_view name() const override; num_replicas()433 int num_replicas() const override { return wrapped()->num_replicas(); } num_partitions()434 int num_partitions() const override { return wrapped()->num_partitions(); } 435 SizeOfGeneratedCodeInBytes()436 int64_t SizeOfGeneratedCodeInBytes() const override { 437 #ifdef PJRT_C_API_BYPASS 438 return wrapped()->SizeOfGeneratedCodeInBytes(); 439 #endif // PJRT_C_API_BYPASS 440 CHECK(false) << "PJRT C API does not support SizeOfGeneratedCodeInBytes"; 441 } 442 device_assignment()443 const DeviceAssignment& device_assignment() const override { 444 #ifdef PJRT_C_API_BYPASS 445 return wrapped()->device_assignment(); 446 #endif // PJRT_C_API_BYPASS 447 CHECK(false) << "PJRT C API does not support device_assignment"; 448 } 449 addressable_device_logical_ids()450 absl::Span<const LogicalDeviceIds> addressable_device_logical_ids() 451 const override { 452 #ifdef PJRT_C_API_BYPASS 453 return wrapped()->addressable_device_logical_ids(); 454 #endif // PJRT_C_API_BYPASS 455 CHECK(false) 456 << "PJRT C API does not support addressable_device_logical_ids"; 457 } 458 addressable_devices()459 absl::Span<PjRtDevice* const> addressable_devices() const override { 460 return addressable_devices_; 461 } 462 GetHloModules()463 StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules() 464 const override { 465 #ifdef PJRT_C_API_BYPASS 466 return wrapped()->GetHloModules(); 467 #endif // PJRT_C_API_BYPASS 468 return Unimplemented("PJRT C API does not support GetHloModules"); 469 } 470 471 StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute( 472 absl::Span<const std::vector<PjRtBuffer*>> argument_handles, 473 const ExecuteOptions& options, 474 std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) 475 override; 476 477 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded( 478 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device, 479 const ExecuteOptions& options, 480 std::optional<PjRtFuture<Status>>& returned_future, 481 bool fill_future) override; 482 483 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable( 484 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device, 485 const ExecuteOptions& options, 486 std::optional<PjRtFuture<Status>>& returned_future, 487 bool fill_future) override; 488 489 void Delete() override; 490 bool IsDeleted() override; 491 492 PjRtLoadedExecutable* wrapped() const; 493 GetWrapped(const PjRtLoadedExecutable * c_api_executable)494 static PjRtLoadedExecutable* GetWrapped( 495 const PjRtLoadedExecutable* c_api_executable) { 496 return tensorflow::down_cast<const PjRtCApiExecutable*>(c_api_executable) 497 ->wrapped(); 498 } 499 pjrt_c_api()500 const PJRT_Api* pjrt_c_api() const { return client_->pjrt_c_api(); } 501 502 private: 503 PjRtCApiClient* client_; 504 PJRT_Executable* executable_; 505 std::vector<PjRtDevice*> addressable_devices_; 506 507 void InitDevices(); 508 }; 509 510 // Takes ownership of wrapped. 511 StatusOr<std::unique_ptr<PjRtClient>> GetCApiClient(); 512 513 } // namespace xla 514 515 #endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_C_API_CLIENT_H_ 516