/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <utility>

#include "absl/cleanup/cleanup.h"
#include "absl/time/time.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_configuration_ops.h"
#include "tensorflow/core/tpu/kernels/tpu_embedding_engine_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tpu_system_interface.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_topology.h"

// Timeout for waiting for TPU devices to appear.
const absl::Duration dtensor_tpu_init_retry_timeout = absl::Seconds(30);

namespace tensorflow {
namespace dtensor {

// Attempts to delete `resource_name` from `resource_manager`'s
// default_container. Returns OK if the deletion succeeded or if the resource
// was not found; otherwise, returns the deletion error.
template <class ResourceT>
Status DeleteIfExists(ResourceMgr* resource_manager,
                      const char* resource_name) {
  VLOG(1) << "Removing resource " << resource_name << " if it exists";
  Status status = resource_manager->Delete<ResourceT>(
      resource_manager->default_container(), resource_name);
  if (status.ok()) {
    VLOG(1) << "Removed existing resource " << resource_name;
    return OkStatus();
  }
  if (status.code() == error::NOT_FOUND) {
    VLOG(1) << "No resource " << resource_name << " to remove";
    return OkStatus();
  }
  VLOG(1) << "Error removing resource " << resource_name << " : " << status;
  return status;
}

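// Initializes the global TPU system, using either a preferred
// TpuSystemInterface (if one is registered) or the default TPU platform, and
// returns the local TPU core IDs as the op's single output.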
class ConfigureAndInitializeGlobalTPUOpKernel : public OpKernel {
 public:
  explicit ConfigureAndInitializeGlobalTPUOpKernel(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    LOG(INFO) << "ConfigureAndInitializeGlobalTPUOpKernel op";

    ResourceMgr* rmgr = GetTPUConfigResourceMgr();
    std::vector<int32> core_id_output_vec;
    auto retry_timeout = dtensor_tpu_init_retry_timeout;

    TpuSystemInterface* tpu_system = GetPreferredTpuSystem();
    if (tpu_system == nullptr) {
      VLOG(1) << "Initializing the default TPU system.";
      OP_REQUIRES_OK(ctx, InitializeInternal(ctx, rmgr, retry_timeout,
                                             &core_id_output_vec));
    } else {
      VLOG(1) << "Initializing a preferred TPU system.";
      OP_REQUIRES_OK(ctx, tpu_system->Initialize(ctx, rmgr, retry_timeout,
                                                 &core_id_output_vec));
    }

    if (VLOG_IS_ON(1)) {
      LOG(INFO) << "core_id_output_vec";
      for (auto i : core_id_output_vec) {
        LOG(INFO) << i;
      }
    }

    // Set output using local core ID vector.
    Tensor* ctx_output;
    auto core_id_output_vec_size = core_id_output_vec.size();
    OP_REQUIRES_OK(
        ctx,
        ctx->allocate_output(
            0, TensorShape({static_cast<long long>(core_id_output_vec_size)}),
            &ctx_output));
    for (size_t i = 0; i < core_id_output_vec_size; ++i) {
      ctx_output->flat<int32>()(i) = core_id_output_vec[i];
    }

    LOG(INFO) << "ConfigureAndInitializeGlobalTPUOpKernel done";
  }

  ~ConfigureAndInitializeGlobalTPUOpKernel() override {}

 private:
  // ConfigureAndInitializeGlobalTPUOpKernel is neither copyable nor movable.
  ConfigureAndInitializeGlobalTPUOpKernel(
      const ConfigureAndInitializeGlobalTPUOpKernel&) = delete;
  ConfigureAndInitializeGlobalTPUOpKernel& operator=(
      const ConfigureAndInitializeGlobalTPUOpKernel&) = delete;

  static Status InitializeInternal(OpKernelContext* ctx, ResourceMgr* rmgr,
                                   absl::Duration retry_timeout,
                                   std::vector<int32>* core_id_output_vec) {
    // Reset the TPU embedding engine interface if we are not the master.
    // We need to reset the interface before initializing the host because the
    // resetting process resets the TPU platform.
    TF_RETURN_IF_ERROR(DeleteIfExists<tpu::TpuEmbeddingEngineStateInterface>(
        rmgr, tpu::kTpuEmbeddingEngineStateInterfaceResourceName));

    // Create the subgraph compilation cache and put it in the local resource
    // manager.
    tpu::TpuCompilationCacheInterface* compilation_cache;
    TF_RETURN_IF_ERROR(CreateTpuCompilationCache(rmgr, &compilation_cache));
    core::ScopedUnref compilation_cache_ref(compilation_cache);

    // Initialize global tpu and set `TPUHostConfiguration` with TPU topology.
    auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform();
    if (tpu_platform == nullptr) {
      return errors::Internal("Could not find registered TPU system.");
    }

    auto start = absl::Now();
    auto init_status = OkStatus();

    // Keep trying to initialize underlying TPU system until either TPU system
    // is initialized or initialization times out.
    while (!tpu_platform->Initialized() &&
           (absl::Now() - start < retry_timeout)) {
      VLOG(1) << "Initializing global TPU system.";
      init_status = tpu_platform->Initialize({});
    }
    if (!tpu_platform->Initialized()) {
      return errors::Unavailable("Unable to initialize TPU system.");
    }

    std::string host_config_serialized;
    std::vector<int32> num_device_per_host;
    const auto tpu_topology = tpu_platform->topology();
    num_device_per_host.reserve(tpu_topology.HostCount());
    for (int i = 0; i < tpu_topology.HostCount(); ++i) {
      num_device_per_host.emplace_back(tpu_topology.ChipsPerHost());
    }

    TF_RETURN_IF_ERROR(tensorflow::ConstructTpuPodState(
        rmgr, num_device_per_host, compilation_cache, &host_config_serialized));

    // Turn `host_config_serialized` into `core_id_output_vec` by calling the
    // guts of InitializeHostForDistributedTpuOp.
    TF_Status* status = TF_NewStatus();
    size_t device_id_output_size;
    int32_t* device_id_output = nullptr;
    auto cleanup = absl::MakeCleanup([&status, &device_id_output]() {
      TF_DeleteStatus(status);
      tpu::OpsApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
    });

    InitializeHostForDistributedTpuOp_DoWork_Params params;
    params.struct_size = InitializeHostForDistributedTpuOp_DoWork_Params_SIZE;
    params.priv = nullptr;
    params.tpu_host_config_size = host_config_serialized.size();
    params.tpu_host_config = host_config_serialized.data();
    params.enable_whole_mesh_compilations = false;
    params.is_master_worker = true;
    params.core_id_output_size = &device_id_output_size;
    params.core_id_output = &device_id_output;
    params.status = status;

    tpu::OpsApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(&params);
    TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
    for (size_t i = 0; i < device_id_output_size; ++i) {
      core_id_output_vec->push_back(device_id_output[i]);
    }

    // Create resource containers used for storing TPU topology and HBM buffer
    // configurations.
    auto delete_status = rmgr->Delete<tpu::TpuMeshStateInterface>(
        rmgr->default_container(), tpu::kTpuMeshStateInterfaceResourceName);
    if (!delete_status.ok() && delete_status.code() != error::NOT_FOUND) {
      return errors::InvalidArgument(
          "Failed to delete mesh interface. Please try initializing "
          "again once all TPU devices are allocated.");
    }

    auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
    TF_RETURN_IF_ERROR(rmgr->Create(rmgr->default_container(),
                                    tpu::kTpuMeshStateInterfaceResourceName,
                                    tpu_mesh));

    VLOG(1) << "Removing existing proto compilation cache lookup if it exists";
    Status resource_delete_status =
        rmgr->Delete<tpu::TpuCompilationCacheLookup>(
            rmgr->default_container(), tpu::kCompiledProtoCacheResourceName);

    tpu::TpuCompilationCacheInterface* local_compilation_cache;
    TF_RETURN_IF_ERROR(rmgr->Lookup(rmgr->default_container(),
                                    tpu::kCompilationCacheResourceName,
                                    &local_compilation_cache));
    local_compilation_cache->Unref();

    VLOG(1) << "Creating compilation proto cache resource";
    tpu::TpuCompilationCacheLookup* proto_lookup;
    proto_lookup =
        new tpu::TpuCompilationCacheLocalLookup(local_compilation_cache);
    TF_RETURN_IF_ERROR(rmgr->Create(rmgr->default_container(),
                                    tpu::kCompiledProtoCacheResourceName,
                                    proto_lookup));
    TF_RETURN_IF_ERROR(
        rmgr->Create(rmgr->default_container(),
                     tpu::kTpuEmbeddingEngineStateInterfaceResourceName,
                     tpu::TpuEmbeddingEngineStateInterface::Create()));

    return OkStatus();
  }
};

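// Shuts down the TPU system, either by resetting the registered TPU platform
// or by delegating to a preferred TpuSystemInterface, and outputs a boolean
// indicating whether the shutdown succeeded.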
class ShutdownTPUSystemOpKernel : public OpKernel {
 public:
  explicit ShutdownTPUSystemOpKernel(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    LOG(INFO) << "ShutdownTPUSystemOpKernel op";

    Status status;
    TpuSystemInterface* tpu_system = GetPreferredTpuSystem();
    if (tpu_system == nullptr) {
      VLOG(1) << "Shutting down the default TPU system.";
      // In the current runtime, we reset the TPU platform, which in turn
      // shuts down the tpu::System.
      auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform();
      OP_REQUIRES(ctx, tpu_platform != nullptr,
                  errors::Internal("Could not find registered TPU system."));

      status = tpu_platform->Reset(/*only_tear_down=*/true,
                                   /*reason=*/"ShutdownSystem");
    } else {
      VLOG(1) << "Shutting down a preferred TPU system.";
      status = tpu_system->Shutdown();
    }

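    // Report whether the shutdown succeeded as the op's single boolean output.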
    Tensor* output_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({1}), &output_tensor));

    if (status.ok()) {
      output_tensor->flat<bool>()(0) = true;
    } else {
      output_tensor->flat<bool>()(0) = false;
    }
  }
};

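// Forwards the serialized TPU topology provided as the op's string input to
// the TPU runtime via SetGlobalTPUArrayOp_DoWorkFn.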
class SetGlobalTPUArrayOpKernel : public OpKernel {
 public:
  explicit SetGlobalTPUArrayOpKernel(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "SetGlobalTPUArrayOpKernel op";
    auto tpu_topology = ctx->input(0).scalar<tstring>()();
    TF_Status* status = TF_NewStatus();

    tpu::OpsApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(),
                                                  tpu_topology.data(), status);
    OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
    TF_DeleteStatus(status);

    VLOG(1) << "SetGlobalTPUArrayOpKernel done";
  }
};

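// Register the DTensor TPU configuration kernels on the TPU_SYSTEM device.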
REGISTER_KERNEL_BUILDER(Name("ConfigureAndInitializeGlobalTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("output"),
                        ConfigureAndInitializeGlobalTPUOpKernel);

REGISTER_KERNEL_BUILDER(Name("ShutdownTPUSystem").Device(DEVICE_TPU_SYSTEM),
                        ShutdownTPUSystemOpKernel);

REGISTER_KERNEL_BUILDER(Name("DTensorSetGlobalTPUArray")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("topology"),
                        SetGlobalTPUArrayOpKernel);

}  // namespace dtensor
}  // namespace tensorflow