xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/tpu_client.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/pjrt/tpu_client.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/container/inlined_vector.h"
24 #include "absl/status/status.h"
25 #include "tensorflow/compiler/xla/client/client_library.h"
26 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
27 #include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
28 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
29 #include "tensorflow/compiler/xla/pjrt/utils.h"
30 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
31 #include "tensorflow/compiler/xla/service/tpu_computation_placer.h"
32 #include "tensorflow/compiler/xla/shape.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/core/platform/casts.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/tpu/tpu_initializer_helper.h"
39 #include "tensorflow/stream_executor/device_memory.h"
40 #include "tensorflow/stream_executor/lib/statusor.h"
41 #include "tensorflow/stream_executor/stream.h"
42 #include "tensorflow/stream_executor/tpu/tpu_executable.h"
43 #include "tensorflow/stream_executor/tpu/tpu_executable_interface.h"
44 #include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
45 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
46 #include "tensorflow/stream_executor/tpu/tpu_stream.h"
47 
48 namespace tf_tpu = tensorflow::tpu;
49 
50 namespace xla {
51 namespace {
52 
53 class TpuDeviceState : public LocalDeviceState {
54  public:
55   TpuDeviceState(se::StreamExecutor* executor, LocalClient* client,
56                  int max_inflight_computations);
57 
58   Status ThenMemcpyDeviceToDevice(se::Stream* transfer_stream,
59                                   se::Stream* dst_stream,
60                                   se::DeviceMemoryBase src_buffer,
61                                   se::DeviceMemoryBase dst_buffer) override;
62 };
63 
TpuDeviceState(se::StreamExecutor * executor,LocalClient * client,int max_inflight_computations)64 TpuDeviceState::TpuDeviceState(se::StreamExecutor* executor,
65                                LocalClient* client,
66                                int max_inflight_computations)
67     : LocalDeviceState(executor, client, LocalDeviceState::kAsynchronous,
68                        max_inflight_computations,
69                        /*allow_event_reuse=*/false,
70                        /*use_callback_stream=*/true) {}
71 
ThenMemcpyDeviceToDevice(se::Stream * transfer_stream,se::Stream * dst_stream,se::DeviceMemoryBase src_buffer,se::DeviceMemoryBase dst_buffer)72 Status TpuDeviceState::ThenMemcpyDeviceToDevice(
73     se::Stream* transfer_stream, se::Stream* dst_stream,
74     se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
75   auto* transfer_tpu_stream = tensorflow::down_cast<tf_tpu::TpuStream*>(
76       transfer_stream->implementation());
77   TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal(
78       src_buffer, dst_buffer));
79   return OkStatus();
80 }
81 
82 }  // namespace
83 
PjRtTpuClient(LocalClient * client,std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,int process_index)84 PjRtTpuClient::PjRtTpuClient(
85     LocalClient* client,
86     std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
87     int process_index)
88     : PjRtStreamExecutorClient(TpuName(), client, std::move(devices),
89                                process_index,
90                                /*allocator=*/nullptr,
91                                /*host_memory_allocator=*/nullptr,
92                                /*should_stage_host_to_device_transfers=*/false,
93                                /*gpu_run_options=*/nullptr),
94       platform_version_([]() {
95         // Example platform version string:
96         //   libtpu version 0.0.1
97         //   Built on Mar 4 2021 15:25:57 (1614900357) cl/360760169
98         tf_tpu::TpuPlatformInterface* platform =
99             tf_tpu::TpuPlatformInterface::GetRegisteredPlatform();
100         TpuRuntimeVersion version = platform->version();
101         return absl::StrCat(
102             "libtpu version ", absl::StrJoin(version.version, "."), "\n",
103             absl::string_view(version.metadata, version.metadata_size));
104       }()) {
105   // We always initialize the tpu client even if libtpu isn't linked in or
106   // initialized.
107   if (tf_tpu::ExecutorApiFn()->TpuAsyncCollectiveOffloadHelper_InitFn !=
108       nullptr) {
109     tf_tpu::ExecutorApiFn()->TpuAsyncCollectiveOffloadHelper_InitFn();
110   }
111 }
112 
~PjRtTpuClient()113 PjRtTpuClient::~PjRtTpuClient() {
114   if (tf_tpu::ExecutorApiFn()->TpuAsyncCollectiveOffloadHelper_ShutdownFn !=
115       nullptr) {
116     tf_tpu::ExecutorApiFn()->TpuAsyncCollectiveOffloadHelper_ShutdownFn();
117   }
118 }
119 
GetDefaultDeviceAssignment(int num_replicas,int num_partitions) const120 StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
121     int num_replicas, int num_partitions) const {
122   tf_tpu::TpuPlatformInterface* platform =
123       tf_tpu::TpuPlatformInterface::GetRegisteredPlatform();
124   tf_tpu::TpuHostLocationExternal host = platform->GetTpuHostLocation();
125   int num_local_devices = host.Cores(kTensorCore).size();
126   if (num_replicas * num_partitions <= num_local_devices) {
127     return tf_tpu::TpuComputationPlacer::AssignLocalDevices(host, num_replicas,
128                                                             num_partitions);
129   }
130   // Fallback to default global device assignment if we can't run locally.
131   return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
132                                                               num_partitions);
133 }
134 
ExecutableFingerprint(const PjRtLoadedExecutable & executable) const135 StatusOr<std::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
136     const PjRtLoadedExecutable& executable) const {
137   if (executable.client() != this) {
138     return InvalidArgument(
139         "Passed executable from different client (platform '%s') to "
140         "PjRtTpuClient::ExecutableFingerprint",
141         executable.client()->platform_name());
142   }
143   if (executable.num_partitions() > 1) {
144     LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD "
145                  "executables, fingerprint may not be unique.";
146   }
147   xla::TpuExecutableInterface* tpu_executable =
148       tensorflow::down_cast<xla::TpuExecutableInterface*>(
149           tensorflow::down_cast<const PjRtStreamExecutorExecutable*>(
150               &executable)
151               ->executables()[0]
152               ->executable());
153   return std::optional<std::string>(tpu_executable->fingerprint());
154 }
155 
SerializeExecutable(const PjRtLoadedExecutable & executable) const156 StatusOr<std::string> PjRtTpuClient::SerializeExecutable(
157     const PjRtLoadedExecutable& executable) const {
158   const PjRtStreamExecutorExecutable* se_executable =
159       tensorflow::down_cast<const PjRtStreamExecutorExecutable*>(&executable);
160   if (se_executable->executables().size() > 1) {
161     return Unimplemented(
162         "PjRtTpuClient::SerializeExecutable unimplemented for MPMD "
163         "executables");
164   }
165   const TpuExecutable* tpu_executable =
166       tensorflow::down_cast<const TpuExecutable*>(
167           se_executable->executables()[0]->executable());
168   return tpu_executable->Serialize();
169 }
170 
171 StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
DeserializeExecutable(absl::string_view serialized,CompileOptions options)172 PjRtTpuClient::DeserializeExecutable(absl::string_view serialized,
173                                      CompileOptions options) {
174   TF_ASSIGN_OR_RETURN(std::unique_ptr<TpuExecutable> tpu_executable,
175                       TpuExecutable::Deserialize(serialized));
176 
177   TF_ASSIGN_OR_RETURN(ExecutableExtras extras, GetExecutableExtras(&options));
178 
179   // TODO(skyewm): can we streamline this? e.g. removing proto serialization
180   XlaComputation computation(tpu_executable->module().ToProto());
181   TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
182                       computation.GetProgramShape());
183   std::vector<const Shape*> unused_argument_layout_pointers;
184   TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
185       computation,
186       [local_client = client()](Shape shape) {
187         return local_client->backend()
188             .transfer_manager()
189             ->ChooseCompactLayoutForShape(shape);
190       },
191       options.argument_layouts, &options.executable_build_options,
192       &unused_argument_layout_pointers));
193 
194   auto local_executable = std::make_unique<LocalExecutable>(
195       std::move(tpu_executable), client_->mutable_backend(),
196       options.executable_build_options);
197   std::vector<std::unique_ptr<LocalExecutable>> local_executables;
198   local_executables.emplace_back(std::move(local_executable));
199 
200   auto pjrt_executable = std::make_unique<PjRtStreamExecutorExecutable>(
201       std::move(local_executables), options.parameter_is_tupled_arguments,
202       std::move(extras.device_assignment),
203       std::move(extras.addressable_device_logical_ids),
204       std::move(extras.addressable_devices), this);
205   TF_RETURN_IF_ERROR(
206       pjrt_executable->SetUpDonation(options.parameter_is_tupled_arguments));
207   return std::unique_ptr<PjRtLoadedExecutable>(std::move(pjrt_executable));
208 }
209 
210 static StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>>
GetTpuDevices(LocalClient * client,std::vector<std::unique_ptr<LocalDeviceState>> local_device_states)211 GetTpuDevices(
212     LocalClient* client,
213     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
214   std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
215   tf_tpu::TpuTopologyExternal topology =
216       tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();
217 
218   std::map<int, int> core_id_to_device_ordinal;
219   for (int i = 0; i < client->device_count(); ++i) {
220     se::StreamExecutor* executor =
221         client->backend().stream_executor(i).ValueOrDie();
222     tf_tpu::TpuExecutorInterface* tpu_executor =
223         tensorflow::down_cast<tf_tpu::TpuExecutorInterface*>(
224             executor->implementation());
225     core_id_to_device_ordinal[tpu_executor->GetCoreLocationExternal().Id()] = i;
226   }
227 
228   for (const tf_tpu::TpuCoreLocationExternal& core :
229        topology.cores(TpuCoreTypeEnum::kTensorCore)) {
230     auto it = core_id_to_device_ordinal.find(core.Id());
231     int device_ordinal =
232         (it != core_id_to_device_ordinal.end()) ? it->second : -1;
233     int process_index = topology.IdForHost(core.host_coordinates());
234     const tf_tpu::TpuDimensionsExternal coords = core.chip_coordinates();
235     std::array<int, 3> coords_array = {coords.x, coords.y, coords.z};
236     std::unique_ptr<LocalDeviceState> local_device_state;
237     if (device_ordinal >= 0) {
238       local_device_state = std::move(local_device_states[device_ordinal]);
239     }
240     auto device = std::make_unique<PjRtTpuDevice>(
241         core, std::move(local_device_state), process_index, coords_array,
242         std::string(tf_tpu::TpuVersionEnumToString(topology.version())));
243     devices.push_back(std::move(device));
244   }
245   return devices;
246 }
247 
GetTpuClient(int max_inflight_computations,absl::Duration init_retry_timeout)248 StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient(
249     int max_inflight_computations, absl::Duration init_retry_timeout) {
250 #if !defined(PLATFORM_GOOGLE) || defined(LIBTPU_STATIC)
251   TF_RETURN_IF_ERROR(tensorflow::tpu::FindAndLoadTpuLibrary());
252 #endif
253   tf_tpu::TpuPlatformInterface* platform =
254       tf_tpu::TpuPlatformInterface::GetRegisteredPlatform(
255           /*initialize_platform=*/true, /*num_tries=*/1);
256   if (platform == nullptr) {
257     return InvalidArgument("TpuPlatform is not available.");
258   }
259   // NOTE: We retry in a loop since some pod failures are transient (e.g. some
260   // RPCs may timeout waiting for other hosts to come up, but will succeed
261   // at a later point if retried).
262   auto start = absl::Now();
263   while (true) {
264     Status status = platform->Initialize({});
265     if (status.ok()) {
266       break;
267     }
268     // TODO(b/165870356): refactor this loop to be
269     // while(!platform->Initialized()) once the Initialized() function works
270     // correctly, and remove this check. The platform may already be initialized
271     // when running internally.
272     if (status.code() == tensorflow::error::ALREADY_EXISTS) {
273       LOG(INFO) << "TpuPlatform already initialized, continuing...";
274       break;
275     }
276     LOG(INFO) << "TPU platform initialization failed: " << status;
277     if ((absl::Now() - start) >= init_retry_timeout) {
278       return status;
279     }
280     absl::SleepFor(absl::Microseconds(10));
281   }
282   CHECK(platform->Initialized());
283   if (platform->VisibleDeviceCount() <= 0) {
284     return InvalidArgument("No TPU devices found.");
285   }
286   LocalClientOptions options;
287   options.set_platform(platform);
288   TF_ASSIGN_OR_RETURN(LocalClient * client,
289                       ClientLibrary::GetOrCreateLocalClient(options));
290 
291   std::vector<std::unique_ptr<LocalDeviceState>> local_device_states;
292   local_device_states.reserve(client->device_count());
293   for (int i = 0; i < client->device_count(); ++i) {
294     se::StreamExecutor* executor =
295         client->backend().stream_executor(i).ValueOrDie();
296     local_device_states.push_back(std::make_unique<TpuDeviceState>(
297         executor, client, max_inflight_computations));
298   }
299 
300   TF_ASSIGN_OR_RETURN(auto devices,
301                       GetTpuDevices(client, std::move(local_device_states)));
302   int process_index = platform->GetTpuHostLocation().Id();
303 
304   return std::shared_ptr<PjRtClient>(std::make_unique<PjRtTpuClient>(
305       client, std::move(devices), process_index));
306 }
307 
308 }  // namespace xla
309