1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"

#include <array>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/pjrt/utils.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/tpu_computation_placer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/tpu/tpu_initializer_helper.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/tpu/tpu_executable.h"
#include "tensorflow/stream_executor/tpu/tpu_executable_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_stream.h"
47
48 namespace tf_tpu = tensorflow::tpu;
49
50 namespace xla {
51 namespace {
52
53 class TpuDeviceState : public LocalDeviceState {
54 public:
55 TpuDeviceState(se::StreamExecutor* executor, LocalClient* client,
56 int max_inflight_computations);
57
58 Status ThenMemcpyDeviceToDevice(se::Stream* transfer_stream,
59 se::Stream* dst_stream,
60 se::DeviceMemoryBase src_buffer,
61 se::DeviceMemoryBase dst_buffer) override;
62 };
63
TpuDeviceState(se::StreamExecutor * executor,LocalClient * client,int max_inflight_computations)64 TpuDeviceState::TpuDeviceState(se::StreamExecutor* executor,
65 LocalClient* client,
66 int max_inflight_computations)
67 : LocalDeviceState(executor, client, LocalDeviceState::kAsynchronous,
68 max_inflight_computations,
69 /*allow_event_reuse=*/false,
70 /*use_callback_stream=*/true) {}
71
ThenMemcpyDeviceToDevice(se::Stream * transfer_stream,se::Stream * dst_stream,se::DeviceMemoryBase src_buffer,se::DeviceMemoryBase dst_buffer)72 Status TpuDeviceState::ThenMemcpyDeviceToDevice(
73 se::Stream* transfer_stream, se::Stream* dst_stream,
74 se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
75 auto* transfer_tpu_stream = tensorflow::down_cast<tf_tpu::TpuStream*>(
76 transfer_stream->implementation());
77 TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal(
78 src_buffer, dst_buffer));
79 return OkStatus();
80 }
81
82 } // namespace
83
PjRtTpuClient(LocalClient * client,std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,int process_index)84 PjRtTpuClient::PjRtTpuClient(
85 LocalClient* client,
86 std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
87 int process_index)
88 : PjRtStreamExecutorClient(TpuName(), client, std::move(devices),
89 process_index,
90 /*allocator=*/nullptr,
91 /*host_memory_allocator=*/nullptr,
92 /*should_stage_host_to_device_transfers=*/false,
93 /*gpu_run_options=*/nullptr),
94 platform_version_([]() {
95 // Example platform version string:
96 // libtpu version 0.0.1
97 // Built on Mar 4 2021 15:25:57 (1614900357) cl/360760169
98 tf_tpu::TpuPlatformInterface* platform =
99 tf_tpu::TpuPlatformInterface::GetRegisteredPlatform();
100 TpuRuntimeVersion version = platform->version();
101 return absl::StrCat(
102 "libtpu version ", absl::StrJoin(version.version, "."), "\n",
103 absl::string_view(version.metadata, version.metadata_size));
104 }()) {
105 // We always initialize the tpu client even if libtpu isn't linked in or
106 // initialized.
107 if (tf_tpu::ExecutorApiFn()->TpuAsyncCollectiveOffloadHelper_InitFn !=
108 nullptr) {
109 tf_tpu::ExecutorApiFn()->TpuAsyncCollectiveOffloadHelper_InitFn();
110 }
111 }
112
~PjRtTpuClient()113 PjRtTpuClient::~PjRtTpuClient() {
114 if (tf_tpu::ExecutorApiFn()->TpuAsyncCollectiveOffloadHelper_ShutdownFn !=
115 nullptr) {
116 tf_tpu::ExecutorApiFn()->TpuAsyncCollectiveOffloadHelper_ShutdownFn();
117 }
118 }
119
GetDefaultDeviceAssignment(int num_replicas,int num_partitions) const120 StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
121 int num_replicas, int num_partitions) const {
122 tf_tpu::TpuPlatformInterface* platform =
123 tf_tpu::TpuPlatformInterface::GetRegisteredPlatform();
124 tf_tpu::TpuHostLocationExternal host = platform->GetTpuHostLocation();
125 int num_local_devices = host.Cores(kTensorCore).size();
126 if (num_replicas * num_partitions <= num_local_devices) {
127 return tf_tpu::TpuComputationPlacer::AssignLocalDevices(host, num_replicas,
128 num_partitions);
129 }
130 // Fallback to default global device assignment if we can't run locally.
131 return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
132 num_partitions);
133 }
134
ExecutableFingerprint(const PjRtLoadedExecutable & executable) const135 StatusOr<std::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
136 const PjRtLoadedExecutable& executable) const {
137 if (executable.client() != this) {
138 return InvalidArgument(
139 "Passed executable from different client (platform '%s') to "
140 "PjRtTpuClient::ExecutableFingerprint",
141 executable.client()->platform_name());
142 }
143 if (executable.num_partitions() > 1) {
144 LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD "
145 "executables, fingerprint may not be unique.";
146 }
147 xla::TpuExecutableInterface* tpu_executable =
148 tensorflow::down_cast<xla::TpuExecutableInterface*>(
149 tensorflow::down_cast<const PjRtStreamExecutorExecutable*>(
150 &executable)
151 ->executables()[0]
152 ->executable());
153 return std::optional<std::string>(tpu_executable->fingerprint());
154 }
155
SerializeExecutable(const PjRtLoadedExecutable & executable) const156 StatusOr<std::string> PjRtTpuClient::SerializeExecutable(
157 const PjRtLoadedExecutable& executable) const {
158 const PjRtStreamExecutorExecutable* se_executable =
159 tensorflow::down_cast<const PjRtStreamExecutorExecutable*>(&executable);
160 if (se_executable->executables().size() > 1) {
161 return Unimplemented(
162 "PjRtTpuClient::SerializeExecutable unimplemented for MPMD "
163 "executables");
164 }
165 const TpuExecutable* tpu_executable =
166 tensorflow::down_cast<const TpuExecutable*>(
167 se_executable->executables()[0]->executable());
168 return tpu_executable->Serialize();
169 }
170
171 StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
DeserializeExecutable(absl::string_view serialized,CompileOptions options)172 PjRtTpuClient::DeserializeExecutable(absl::string_view serialized,
173 CompileOptions options) {
174 TF_ASSIGN_OR_RETURN(std::unique_ptr<TpuExecutable> tpu_executable,
175 TpuExecutable::Deserialize(serialized));
176
177 TF_ASSIGN_OR_RETURN(ExecutableExtras extras, GetExecutableExtras(&options));
178
179 // TODO(skyewm): can we streamline this? e.g. removing proto serialization
180 XlaComputation computation(tpu_executable->module().ToProto());
181 TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
182 computation.GetProgramShape());
183 std::vector<const Shape*> unused_argument_layout_pointers;
184 TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
185 computation,
186 [local_client = client()](Shape shape) {
187 return local_client->backend()
188 .transfer_manager()
189 ->ChooseCompactLayoutForShape(shape);
190 },
191 options.argument_layouts, &options.executable_build_options,
192 &unused_argument_layout_pointers));
193
194 auto local_executable = std::make_unique<LocalExecutable>(
195 std::move(tpu_executable), client_->mutable_backend(),
196 options.executable_build_options);
197 std::vector<std::unique_ptr<LocalExecutable>> local_executables;
198 local_executables.emplace_back(std::move(local_executable));
199
200 auto pjrt_executable = std::make_unique<PjRtStreamExecutorExecutable>(
201 std::move(local_executables), options.parameter_is_tupled_arguments,
202 std::move(extras.device_assignment),
203 std::move(extras.addressable_device_logical_ids),
204 std::move(extras.addressable_devices), this);
205 TF_RETURN_IF_ERROR(
206 pjrt_executable->SetUpDonation(options.parameter_is_tupled_arguments));
207 return std::unique_ptr<PjRtLoadedExecutable>(std::move(pjrt_executable));
208 }
209
210 static StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>>
GetTpuDevices(LocalClient * client,std::vector<std::unique_ptr<LocalDeviceState>> local_device_states)211 GetTpuDevices(
212 LocalClient* client,
213 std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
214 std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
215 tf_tpu::TpuTopologyExternal topology =
216 tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();
217
218 std::map<int, int> core_id_to_device_ordinal;
219 for (int i = 0; i < client->device_count(); ++i) {
220 se::StreamExecutor* executor =
221 client->backend().stream_executor(i).ValueOrDie();
222 tf_tpu::TpuExecutorInterface* tpu_executor =
223 tensorflow::down_cast<tf_tpu::TpuExecutorInterface*>(
224 executor->implementation());
225 core_id_to_device_ordinal[tpu_executor->GetCoreLocationExternal().Id()] = i;
226 }
227
228 for (const tf_tpu::TpuCoreLocationExternal& core :
229 topology.cores(TpuCoreTypeEnum::kTensorCore)) {
230 auto it = core_id_to_device_ordinal.find(core.Id());
231 int device_ordinal =
232 (it != core_id_to_device_ordinal.end()) ? it->second : -1;
233 int process_index = topology.IdForHost(core.host_coordinates());
234 const tf_tpu::TpuDimensionsExternal coords = core.chip_coordinates();
235 std::array<int, 3> coords_array = {coords.x, coords.y, coords.z};
236 std::unique_ptr<LocalDeviceState> local_device_state;
237 if (device_ordinal >= 0) {
238 local_device_state = std::move(local_device_states[device_ordinal]);
239 }
240 auto device = std::make_unique<PjRtTpuDevice>(
241 core, std::move(local_device_state), process_index, coords_array,
242 std::string(tf_tpu::TpuVersionEnumToString(topology.version())));
243 devices.push_back(std::move(device));
244 }
245 return devices;
246 }
247
GetTpuClient(int max_inflight_computations,absl::Duration init_retry_timeout)248 StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient(
249 int max_inflight_computations, absl::Duration init_retry_timeout) {
250 #if !defined(PLATFORM_GOOGLE) || defined(LIBTPU_STATIC)
251 TF_RETURN_IF_ERROR(tensorflow::tpu::FindAndLoadTpuLibrary());
252 #endif
253 tf_tpu::TpuPlatformInterface* platform =
254 tf_tpu::TpuPlatformInterface::GetRegisteredPlatform(
255 /*initialize_platform=*/true, /*num_tries=*/1);
256 if (platform == nullptr) {
257 return InvalidArgument("TpuPlatform is not available.");
258 }
259 // NOTE: We retry in a loop since some pod failures are transient (e.g. some
260 // RPCs may timeout waiting for other hosts to come up, but will succeed
261 // at a later point if retried).
262 auto start = absl::Now();
263 while (true) {
264 Status status = platform->Initialize({});
265 if (status.ok()) {
266 break;
267 }
268 // TODO(b/165870356): refactor this loop to be
269 // while(!platform->Initialized()) once the Initialized() function works
270 // correctly, and remove this check. The platform may already be initialized
271 // when running internally.
272 if (status.code() == tensorflow::error::ALREADY_EXISTS) {
273 LOG(INFO) << "TpuPlatform already initialized, continuing...";
274 break;
275 }
276 LOG(INFO) << "TPU platform initialization failed: " << status;
277 if ((absl::Now() - start) >= init_retry_timeout) {
278 return status;
279 }
280 absl::SleepFor(absl::Microseconds(10));
281 }
282 CHECK(platform->Initialized());
283 if (platform->VisibleDeviceCount() <= 0) {
284 return InvalidArgument("No TPU devices found.");
285 }
286 LocalClientOptions options;
287 options.set_platform(platform);
288 TF_ASSIGN_OR_RETURN(LocalClient * client,
289 ClientLibrary::GetOrCreateLocalClient(options));
290
291 std::vector<std::unique_ptr<LocalDeviceState>> local_device_states;
292 local_device_states.reserve(client->device_count());
293 for (int i = 0; i < client->device_count(); ++i) {
294 se::StreamExecutor* executor =
295 client->backend().stream_executor(i).ValueOrDie();
296 local_device_states.push_back(std::make_unique<TpuDeviceState>(
297 executor, client, max_inflight_computations));
298 }
299
300 TF_ASSIGN_OR_RETURN(auto devices,
301 GetTpuDevices(client, std::move(local_device_states)));
302 int process_index = platform->GetTpuHostLocation().Id();
303
304 return std::shared_ptr<PjRtClient>(std::make_unique<PjRtTpuClient>(
305 client, std::move(devices), process_index));
306 }
307
308 } // namespace xla
309