xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_C_API_CLIENT_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_C_API_CLIENT_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <optional>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h"
27 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h"
28 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
29 
30 namespace xla {
31 
32 class PjRtCApiClient;
33 
34 class PjRtCApiDevice : public PjRtDevice {
35  public:
36   explicit PjRtCApiDevice(PJRT_Device* device, PjRtCApiClient* client);
37 
38   PjRtClient* client() const override;
39 
40   bool IsAddressable() const override;
41 
42   int id() const override;
43 
44   int process_index() const override;
45 
46   int local_hardware_id() const override;
47 
48   absl::string_view device_kind() const override;
49 
50   absl::string_view DebugString() const override;
51 
52   absl::string_view ToString() const override;
53 
TransferToInfeed(const LiteralSlice & literal)54   Status TransferToInfeed(const LiteralSlice& literal) override {
55 #ifdef PJRT_C_API_BYPASS
56     return wrapped_->TransferToInfeed(literal);
57 #endif  // PJRT_C_API_BYPASS
58     return Unimplemented("PJRT C API does not support TransferToInfeed");
59   }
60 
TransferFromOutfeed(MutableBorrowingLiteral literal)61   Status TransferFromOutfeed(MutableBorrowingLiteral literal) override {
62 #ifdef PJRT_C_API_BYPASS
63     return wrapped_->TransferFromOutfeed(std::move(literal));
64 #endif  // PJRT_C_API_BYPASS
65     return Unimplemented("PJRT C API does not support TransferFromOutfeed");
66   }
67 
CreateAsyncTrackingEvent(absl::string_view description)68   std::unique_ptr<ScopedAsyncTrackingEvent> CreateAsyncTrackingEvent(
69       absl::string_view description) const override {
70 #ifdef PJRT_C_API_BYPASS
71     return wrapped_->CreateAsyncTrackingEvent(description);
72 #endif  // PJRT_C_API_BYPASS
73     LOG(WARNING) << "PJRT C API does not support CreateAsyncTrackingEvent";
74     return nullptr;
75   }
76 
77   const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
78       const override;
79 
c_device()80   PJRT_Device* c_device() const { return device_; }
81 
wrapped()82   PjRtDevice* wrapped() const { return wrapped_; }
83 
GetWrapped(PjRtDevice * c_api_device)84   static PjRtDevice* GetWrapped(PjRtDevice* c_api_device) {
85     return tensorflow::down_cast<PjRtCApiDevice*>(c_api_device)->wrapped();
86   }
87 
88  private:
89   PjRtCApiClient* client_ = nullptr;
90   // `device_` is owned by the `PJRT_Client` wrapped by `client_`
91   PJRT_Device* device_;
92   // TODO(shahrokhi): wrapped_ is a non-C API pointer that was used to bypass
93   // the C API calls until all the C API's got implemented. Remove it when it's
94   // usage is reduced to zero.
95   PjRtDevice* wrapped_;
96   // Device specific attributes with corresponding values.
97   absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_;
98 
99   // Initializes device specific attributes.
100   void InitAttributes();
101 };
102 
103 class PjRtCApiClient : public PjRtClient {
104  public:
105   PjRtCApiClient(const PJRT_Api* c_api, PJRT_Client* c_client);
106 
107   int process_index() const override;
108 
109   int device_count() const override;
110   int addressable_device_count() const override;
111 
112   absl::Span<PjRtDevice* const> devices() const override;
113   absl::Span<PjRtDevice* const> addressable_devices() const override;
114 
115   StatusOr<PjRtDevice*> LookupDevice(int device_id) const override;
116 
LookupAddressableDevice(int local_hardware_id)117   StatusOr<PjRtDevice*> LookupAddressableDevice(
118       int local_hardware_id) const override {
119 #ifdef PJRT_C_API_BYPASS
120     TF_ASSIGN_OR_RETURN(PjRtDevice * wrapped_device,
121                         wrapped_->LookupAddressableDevice(local_hardware_id));
122     return GetCApiDevice(wrapped_device);
123 #endif  // PJRT_C_API_BYPASS
124     return Unimplemented("PJRT C API does not support LookupAddressableDevice");
125   }
126 
platform_id()127   PjRtPlatformId platform_id() const override {
128 #ifdef PJRT_C_API_BYPASS
129     return wrapped_->platform_id();
130 #endif  // PJRT_C_API_BYPASS
131     CHECK(false) << "PJRT C API does not support platform_id.";
132   }
133 
134   absl::string_view platform_name() const override;
135 
136   absl::string_view platform_version() const override;
137 
runtime_type()138   PjRtRuntimeType runtime_type() const override {
139 #ifdef PJRT_C_API_BYPASS
140     return wrapped_->runtime_type();
141 #endif  // PJRT_C_API_BYPASS
142     CHECK(false) << "PJRT C API does not support runtime_type.";
143   }
144 
GetDefaultDeviceAssignment(int num_replicas,int num_partitions)145   StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
146       int num_replicas, int num_partitions) const override {
147 #ifdef PJRT_C_API_BYPASS
148     return wrapped_->GetDefaultDeviceAssignment(num_replicas, num_partitions);
149 #endif  // PJRT_C_API_BYPASS
150     return Unimplemented(
151         "PJRT C API does not support GetDefaultDeviceAssignment");
152   }
153 
GetHloCostAnalysis()154   StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() override {
155 #ifdef PJRT_C_API_BYPASS
156     return wrapped_->GetHloCostAnalysis();
157 #endif  // PJRT_C_API_BYPASS
158     return Unimplemented("PJRT C API does not support GetHloCostAnalysis");
159   }
160 
Compile(const XlaComputation & computation,CompileOptions options)161   StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
162       const XlaComputation& computation, CompileOptions options) override {
163 #ifdef PJRT_C_API_BYPASS
164     return WrapExecutable(wrapped_->Compile(computation, options));
165 #endif  // PJRT_C_API_BYPASS
166     return Unimplemented(
167         "PJRT C API does not support Compile with XlaComputation");
168   }
169 
170   StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
171       mlir::ModuleOp module, CompileOptions options) override;
172 
173   StatusOr<std::optional<std::string>> ExecutableFingerprint(
174       const PjRtLoadedExecutable& executable) const override;
175 
176   StatusOr<std::string> SerializeExecutable(
177       const PjRtLoadedExecutable& executable) const override;
178 
179   StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
180       absl::string_view serialized, CompileOptions options) override;
181 
CreateUninitializedBuffer(const Shape & shape,PjRtDevice * device)182   StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
183       const Shape& shape, PjRtDevice* device) override {
184     return Unimplemented(
185         "PJRT C API does not support CreateUninitializedBuffer");
186   }
187 
188   StatusOr<std::unique_ptr<AsyncBufferTransferManager>>
CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes,PjRtDevice * device)189   CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes,
190                                 PjRtDevice* device) override {
191     return Unimplemented(
192         "PJRT C API does not support CreateBuffersForAsyncTransfer");
193   }
194 
BufferFromHostBuffer(const void * data,PrimitiveType type,absl::Span<int64_t const> dims,std::optional<absl::Span<int64_t const>> byte_strides,HostBufferSemantics host_buffer_semantics,std::function<void ()> on_done_with_host_buffer,PjRtDevice * device)195   StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
196       const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
197       std::optional<absl::Span<int64_t const>> byte_strides,
198       HostBufferSemantics host_buffer_semantics,
199       std::function<void()> on_done_with_host_buffer,
200       PjRtDevice* device) override {
201 #ifdef PJRT_C_API_BYPASS
202     return WrapBuffer(wrapped_->BufferFromHostBuffer(
203         data, type, dims, byte_strides, host_buffer_semantics,
204         on_done_with_host_buffer, PjRtCApiDevice::GetWrapped(device)));
205 #endif  // PJRT_C_API_BYPASS
206     return Unimplemented("PJRT C API does not support BufferFromHostBuffer");
207   }
208 
BufferFromHostLiteral(const LiteralSlice & literal,PjRtDevice * device)209   StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
210       const LiteralSlice& literal, PjRtDevice* device) override {
211 #ifdef PJRT_C_API_BYPASS
212     return WrapBuffer(wrapped_->BufferFromHostLiteral(
213         literal, PjRtCApiDevice::GetWrapped(device)));
214 #endif  // PJRT_C_API_BYPASS
215     return Unimplemented("PJRT C API does not support BufferFromHostLiteral");
216   }
217 
CreateViewOfDeviceBuffer(void * device_ptr,const Shape & shape,PjRtDevice * device,std::function<void ()> on_delete_callback)218   StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
219       void* device_ptr, const Shape& shape, PjRtDevice* device,
220       std::function<void()> on_delete_callback) override {
221 #ifdef PJRT_C_API_BYPASS
222     return WrapBuffer(wrapped_->CreateViewOfDeviceBuffer(
223         device_ptr, shape, PjRtCApiDevice::GetWrapped(device),
224         on_delete_callback));
225 #endif  // PJRT_C_API_BYPASS
226     return Unimplemented(
227         "PJRT C API does not support CreateViewOfDeviceBuffer");
228   }
229 
230   StatusOr<std::uintptr_t> UnsafeBufferPointer(PjRtBuffer* buffer) override;
231 
232   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,PjRtDevice * device,PjRtCrossHostRecvNotifier notifier)233   MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
234                               PjRtDevice* device,
235                               PjRtCrossHostRecvNotifier notifier) override {
236     return Unimplemented(
237         "PJRT C API does not support MakeCrossHostReceiveBuffers");
238   }
239 
240   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
MakeCrossHostReceiveBuffersForGather(absl::Span<const Shape> shapes,std::vector<GatherDetails> gather_details,PjRtDevice * device,PjRtCrossHostRecvNotifier notifier)241   MakeCrossHostReceiveBuffersForGather(
242       absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details,
243       PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) override {
244     return Unimplemented(
245         "PJRT C API does not support MakeCrossHostReceiveBuffers");
246   }
247 
CreateChannelHandle()248   StatusOr<ChannelHandle> CreateChannelHandle() override {
249     return Unimplemented("PJRT C API does not support CreateChannelHandle");
250   }
251 
CreateDeviceToHostChannelHandle()252   StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
253     return Unimplemented(
254         "PJRT C API does not support CreateDeviceToHostChannelHandle");
255   }
256 
CreateHostToDeviceChannelHandle()257   StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
258     return Unimplemented(
259         "PJRT C API does not support CreateHostToDeviceChannelHandle");
260   }
261 
Defragment()262   Status Defragment() override { return wrapped_->Defragment(); }
263 
GetCApiDevice(PjRtDevice * wrapped_device)264   PjRtDevice* GetCApiDevice(PjRtDevice* wrapped_device) const {
265     auto it = wrapped_device_map_.find(wrapped_device);
266     CHECK(it != wrapped_device_map_.end());
267     return it->second;
268   }
269 
270   StatusOr<std::unique_ptr<PjRtLoadedExecutable>> WrapExecutable(
271       StatusOr<std::unique_ptr<PjRtLoadedExecutable>> to_wrap);
272 
273   StatusOr<std::unique_ptr<PjRtBuffer>> WrapBuffer(
274       StatusOr<std::unique_ptr<PjRtBuffer>> to_wrap);
275 
276   const PJRT_Api* pjrt_c_api() const;
277 
pjrt_c_client()278   PJRT_Client* pjrt_c_client() { return c_client_.get(); }
279 
GetCppDevice(PJRT_Device * c_device)280   PjRtCApiDevice* GetCppDevice(PJRT_Device* c_device) const {
281     auto it = c_to_cpp_device_map_.find(c_device);
282     CHECK(it != c_to_cpp_device_map_.end());
283     return it->second;
284   }
285 
286  private:
287   const PJRT_Api* c_api_;
288   std::unique_ptr<PJRT_Client, ::pjrt::PJRT_ClientDeleter> c_client_;
289 
290   std::vector<std::unique_ptr<PjRtCApiDevice>> owned_devices_;
291   std::vector<PjRtDevice*> devices_;
292   std::vector<PjRtDevice*> addressable_devices_;
293   absl::flat_hash_map<PJRT_Device*, PjRtCApiDevice*> c_to_cpp_device_map_;
294 
295   // TODO(skyewm): this is a shim so we can run PjRtCApiClient code without the
296   // C API being fully implemented. All methods using wrapped_ should either be
297   // marked unimplemented or implemented in terms of the C API, at which point
298   // wrapped_ and related functionality should be removed.
299   PjRtClient* wrapped_;
300   absl::flat_hash_map<PjRtDevice*, PjRtCApiDevice*> wrapped_device_map_;
301 
302   void InitDevices();
303 };
304 
305 class PjRtCApiBuffer : public PjRtBuffer {
306  public:
307   PjRtCApiBuffer(PjRtCApiClient* client, PJRT_Buffer* buffer);
308 
309   const Shape& on_device_shape() const override;
310 
logical_on_device_shape()311   StatusOr<Shape> logical_on_device_shape() override {
312 #ifdef PJRT_C_API_BYPASS
313     return wrapped_->logical_on_device_shape();
314 #endif  // PJRT_C_API_BYPASS
315     return Unimplemented("PJRT C API does not support logical_on_device_shape");
316   }
317 
318   PjRtDevice* device() const override;
319 
client()320   PjRtClient* client() const override { return client_; }
321 
AcquireExternalReference()322   StatusOr<std::unique_ptr<ExternalReference>> AcquireExternalReference()
323       override {
324 #ifdef PJRT_C_API_BYPASS
325     return wrapped_->AcquireExternalReference();
326 #endif  // PJRT_C_API_BYPASS
327     return Unimplemented(
328         "PJRT C API does not support AcquireExternalReference");
329   }
330 
ToLiteral(MutableLiteralBase * literal)331   PjRtFuture<Status> ToLiteral(MutableLiteralBase* literal) override {
332 #ifdef PJRT_C_API_BYPASS
333     return wrapped_->ToLiteral(literal);
334 #endif  // PJRT_C_API_BYPASS
335     return PjRtFuture<Status>(
336         Unimplemented("PJRT C API does not support ToLiteral"));
337   }
338 
339   StatusOr<size_t> GetOnDeviceSizeInBytes() const override;
340 
CopyRawToHost(void * dst,int64_t offset,int64_t transfer_size)341   PjRtFuture<Status> CopyRawToHost(void* dst, int64_t offset,
342                                    int64_t transfer_size) override {
343 #ifdef PJRT_C_API_BYPASS
344     return wrapped_->CopyRawToHost(dst, offset, transfer_size);
345 #endif  // PJRT_C_API_BYPASS
346     return PjRtFuture<Status>(
347         Unimplemented("PJRT C API does not support CopyRawToHost"));
348   }
349 
350   void Delete() override;
351 
ReleaseDeviceMemoryOwnership(bool wait_for_operations_to_complete)352   StatusOr<std::unique_ptr<ExternalReference>> ReleaseDeviceMemoryOwnership(
353       bool wait_for_operations_to_complete) override {
354 #ifdef PJRT_C_API_BYPASS
355     return wrapped_->ReleaseDeviceMemoryOwnership(
356         wait_for_operations_to_complete);
357 #endif  // PJRT_C_API_BYPASS
358     return Unimplemented(
359         "PJRT C API does not support ReleaseDeviceMemoryOwnership");
360   }
361 
362   bool IsDeleted() override;
363 
364   StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
365       PjRtDevice* dst_device) override;
366 
CopyToRemoteDevice(absl::string_view serialized_descriptor,RemoteSendCallback on_done)367   void CopyToRemoteDevice(absl::string_view serialized_descriptor,
368                           RemoteSendCallback on_done) override {
369     LOG(ERROR) << "PJRT C API does not support CopyToRemoteDevice";
370   }
371 
CopyToRemoteDeviceScattered(absl::Span<const std::pair<std::string,RemoteSendCallback>> serialized_descriptors_and_callbacks,const ScatterDetails & scatter_details)372   void CopyToRemoteDeviceScattered(
373       absl::Span<const std::pair<std::string, RemoteSendCallback>>
374           serialized_descriptors_and_callbacks,
375       const ScatterDetails& scatter_details) override {
376     LOG(ERROR) << "PJRT C API does not support CopyToRemoteDeviceScattered";
377   }
378 
GetReadyFuture()379   PjRtFuture<Status> GetReadyFuture() override {
380 #ifdef PJRT_C_API_BYPASS
381     return wrapped_->GetReadyFuture();
382 #endif  // PJRT_C_API_BYPASS
383     return PjRtFuture<Status>(
384         Unimplemented("PJRT C API does not support GetReadyFuture"));
385   }
386 
387   bool IsOnCpu() const override;
388 
wrapped()389   PjRtBuffer* wrapped() const { return wrapped_; }
390 
c_buffer()391   PJRT_Buffer* c_buffer() const { return buffer_.get(); }
392 
GetWrapped(PjRtBuffer * c_api_buffer)393   static PjRtBuffer* GetWrapped(PjRtBuffer* c_api_buffer) {
394     return tensorflow::down_cast<PjRtCApiBuffer*>(c_api_buffer)->wrapped();
395   }
396 
GetWrappedVector(absl::Span<PjRtBuffer * const> c_api_buffers)397   static std::vector<PjRtBuffer*> GetWrappedVector(
398       absl::Span<PjRtBuffer* const> c_api_buffers) {
399     std::vector<PjRtBuffer*> wrapped;
400     wrapped.reserve(c_api_buffers.size());
401     for (PjRtBuffer* c_api_buf : c_api_buffers) {
402       wrapped.push_back(GetWrapped(c_api_buf));
403     }
404     return wrapped;
405   }
406 
pjrt_c_api()407   const PJRT_Api* pjrt_c_api() const { return client_->pjrt_c_api(); }
408 
409  private:
410   PjRtCApiClient* client_;
411   std::unique_ptr<PJRT_Buffer, ::pjrt::PJRT_BufferDeleter> buffer_;
412   std::optional<xla::Shape> shape_;
413 
414   // TODO(amangu): _wrapped is a non-C API pointer that was used to bypass the
415   // C API calls until all the C API's got implemented. Remove it when it's
416   // usage is reduced to zero.
417   PjRtBuffer* wrapped_;
418 
419   // TODO(b/238999986): Refactor or Remove.
420   void set_shape();
421 };
422 
423 class PjRtCApiExecutable : public PjRtLoadedExecutable {
424  public:
425   PjRtCApiExecutable(PjRtCApiClient* client,
426                      std::unique_ptr<PjRtLoadedExecutable> wrapped);
427   PjRtCApiExecutable(PjRtCApiClient* client, PJRT_Executable* executable);
428 
429   ~PjRtCApiExecutable() override;
430 
client()431   PjRtClient* client() const override { return client_; }
432   absl::string_view name() const override;
num_replicas()433   int num_replicas() const override { return wrapped()->num_replicas(); }
num_partitions()434   int num_partitions() const override { return wrapped()->num_partitions(); }
435 
SizeOfGeneratedCodeInBytes()436   int64_t SizeOfGeneratedCodeInBytes() const override {
437 #ifdef PJRT_C_API_BYPASS
438     return wrapped()->SizeOfGeneratedCodeInBytes();
439 #endif  // PJRT_C_API_BYPASS
440     CHECK(false) << "PJRT C API does not support SizeOfGeneratedCodeInBytes";
441   }
442 
device_assignment()443   const DeviceAssignment& device_assignment() const override {
444 #ifdef PJRT_C_API_BYPASS
445     return wrapped()->device_assignment();
446 #endif  // PJRT_C_API_BYPASS
447     CHECK(false) << "PJRT C API does not support device_assignment";
448   }
449 
addressable_device_logical_ids()450   absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
451       const override {
452 #ifdef PJRT_C_API_BYPASS
453     return wrapped()->addressable_device_logical_ids();
454 #endif  // PJRT_C_API_BYPASS
455     CHECK(false)
456         << "PJRT C API does not support addressable_device_logical_ids";
457   }
458 
addressable_devices()459   absl::Span<PjRtDevice* const> addressable_devices() const override {
460     return addressable_devices_;
461   }
462 
GetHloModules()463   StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
464       const override {
465 #ifdef PJRT_C_API_BYPASS
466     return wrapped()->GetHloModules();
467 #endif  // PJRT_C_API_BYPASS
468     return Unimplemented("PJRT C API does not support GetHloModules");
469   }
470 
471   StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
472       absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
473       const ExecuteOptions& options,
474       std::optional<std::vector<PjRtFuture<Status>>>& returned_futures)
475       override;
476 
477   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
478       absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
479       const ExecuteOptions& options,
480       std::optional<PjRtFuture<Status>>& returned_future,
481       bool fill_future) override;
482 
483   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
484       absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
485       const ExecuteOptions& options,
486       std::optional<PjRtFuture<Status>>& returned_future,
487       bool fill_future) override;
488 
489   void Delete() override;
490   bool IsDeleted() override;
491 
492   PjRtLoadedExecutable* wrapped() const;
493 
GetWrapped(const PjRtLoadedExecutable * c_api_executable)494   static PjRtLoadedExecutable* GetWrapped(
495       const PjRtLoadedExecutable* c_api_executable) {
496     return tensorflow::down_cast<const PjRtCApiExecutable*>(c_api_executable)
497         ->wrapped();
498   }
499 
pjrt_c_api()500   const PJRT_Api* pjrt_c_api() const { return client_->pjrt_c_api(); }
501 
502  private:
503   PjRtCApiClient* client_;
504   PJRT_Executable* executable_;
505   std::vector<PjRtDevice*> addressable_devices_;
506 
507   void InitDevices();
508 };
509 
510 // Takes ownership of wrapped.
511 StatusOr<std::unique_ptr<PjRtClient>> GetCApiClient();
512 
513 }  // namespace xla
514 
515 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_C_API_CLIENT_H_
516