// xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/tpu_client.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_
18 
19 #include <array>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/stream_executor/tpu/tpu_topology.h"
28 
29 namespace xla {
30 
31 class PjRtTpuDevice : public PjRtStreamExecutorDevice {
32  public:
PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core,std::unique_ptr<LocalDeviceState> local_device_state,int process_index,const std::array<int,3> & coords,std::string device_kind)33   PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core,
34                 std::unique_ptr<LocalDeviceState> local_device_state,
35                 int process_index, const std::array<int, 3>& coords,
36                 std::string device_kind)
37       : PjRtStreamExecutorDevice(core.Id(), std::move(local_device_state),
38                                  std::move(device_kind), process_index),
39         core_(core),
40         coords_(coords) {
41     std::vector<int64_t> v_coords(coords_.begin(), coords_.end());
42     int64_t core_index = core_on_chip();
43     attributes_ = {
44         {"coords", xla::PjRtDeviceAttribute(v_coords)},
45         {"core_on_chip", xla::PjRtDeviceAttribute(core_index)},
46     };
47     debug_string_ = absl::StrFormat("TPU_%i(process=%i,(%i,%i,%i,%i))",
48                                     core_.Id(), process_index, coords_[0],
49                                     coords_[1], coords_[2], core_.index());
50     to_string_ = absl::StrFormat(
51         "TpuDevice(id=%i, process_index=%i, coords=(%s), core_on_chip=%i)",
52         id(), process_index, absl::StrJoin(coords_, ","), core_on_chip());
53   }
54 
coords()55   const std::array<int, 3>& coords() const { return coords_; }
core_on_chip()56   int core_on_chip() const { return core_.index(); }
core()57   const tensorflow::tpu::TpuCoreLocationExternal core() const { return core_; }
58 
ToString()59   absl::string_view ToString() const override { return to_string_; }
60 
DebugString()61   absl::string_view DebugString() const override { return debug_string_; }
62 
63  private:
64   const tensorflow::tpu::TpuCoreLocationExternal core_;
65   const std::array<int, 3> coords_;
66   std::string debug_string_;
67   std::string to_string_;
68 };
69 
70 class PjRtTpuClient : public PjRtStreamExecutorClient {
71  public:
72   PjRtTpuClient(LocalClient* client,
73                 std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
74                 int process_index);
75   ~PjRtTpuClient() override;
76 
platform_version()77   absl::string_view platform_version() const override {
78     return platform_version_;
79   }
80 
81   StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
82       int num_replicas, int num_partitions) const override;
83 
EnqueueD2DTransfersOnSrcStream()84   bool EnqueueD2DTransfersOnSrcStream() const override { return false; }
85 
86   StatusOr<std::optional<std::string>> ExecutableFingerprint(
87       const PjRtLoadedExecutable& executable) const override;
88 
89   StatusOr<std::string> SerializeExecutable(
90       const PjRtLoadedExecutable& executable) const override;
91 
92   StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
93       absl::string_view serialized, CompileOptions options) override;
94 
95  private:
96   const std::string platform_version_;
97 };
98 
99 StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient(
100     int max_inflight_computations,
101     absl::Duration init_retry_timeout = absl::ZeroDuration());
102 
103 }  // namespace xla
104 
105 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_
106