/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <utility>

#include "absl/cleanup/cleanup.h"
#include "absl/time/time.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_configuration_ops.h"
#include "tensorflow/core/tpu/kernels/tpu_embedding_engine_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tpu_system_interface.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_topology.h"

// Timeout for waiting for TPU devices to appear.
const absl::Duration dtensor_tpu_init_retry_timeout = absl::Seconds(30);

namespace tensorflow {
namespace dtensor {

// Attempts to delete resource_name from resource_manager's default_container.
// Returns OK if the deletion succeeded or if the resource was not found.
// Otherwise, returns the deletion error.
template <class ResourceT>
Status DeleteIfExists(ResourceMgr* resource_manager,
                      const char* resource_name) {
  VLOG(1) << "Removing resource " << resource_name << " if it exists";
  Status status = resource_manager->Delete<ResourceT>(
      resource_manager->default_container(), resource_name);
  if (status.ok()) {
    VLOG(1) << "Removed existing resource " << resource_name;
    return OkStatus();
  }
  if (status.code() == error::NOT_FOUND) {
    VLOG(1) << "No resource " << resource_name << " to remove";
    return OkStatus();
  }
  VLOG(1) << "Error removing resource " << resource_name << " : " << status;
  return status;
}

class ConfigureAndInitializeGlobalTPUOpKernel : public OpKernel {
 public:
  explicit ConfigureAndInitializeGlobalTPUOpKernel(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    LOG(INFO) << "ConfigureAndInitializeGlobalTPUOpKernel op";

    ResourceMgr* rmgr = GetTPUConfigResourceMgr();
    std::vector<int32> core_id_output_vec;
    auto retry_timeout = dtensor_tpu_init_retry_timeout;

    TpuSystemInterface* tpu_system = GetPreferredTpuSystem();
    if (tpu_system == nullptr) {
      VLOG(1) << "Initializing the default TPU system.";
      OP_REQUIRES_OK(ctx, InitializeInternal(ctx, rmgr, retry_timeout,
                                             &core_id_output_vec));
    } else {
      VLOG(1) << "Initializing a preferred TPU system.";
      OP_REQUIRES_OK(ctx, tpu_system->Initialize(ctx, rmgr, retry_timeout,
                                                 &core_id_output_vec));
    }

    if (VLOG_IS_ON(1)) {
      LOG(INFO) << "core_id_output_vec";
      for (auto i : core_id_output_vec) {
        LOG(INFO) << i;
      }
    }

    // Set output using local core ID vector.
    Tensor* ctx_output;
    auto core_id_output_vec_size = core_id_output_vec.size();
    OP_REQUIRES_OK(
        ctx,
        ctx->allocate_output(
            0, TensorShape({static_cast<long long>(core_id_output_vec_size)}),
            &ctx_output));
    for (size_t i = 0; i < core_id_output_vec_size; ++i) {
      ctx_output->flat<int32>()(i) = core_id_output_vec[i];
    }

    LOG(INFO) << "ConfigureAndInitializeGlobalTPUOpKernel done";
  }

  ~ConfigureAndInitializeGlobalTPUOpKernel() override {}

 private:
  // ConfigureAndInitializeGlobalTPUOpKernel is neither copyable nor movable.
  ConfigureAndInitializeGlobalTPUOpKernel(
      const ConfigureAndInitializeGlobalTPUOpKernel&) = delete;
  ConfigureAndInitializeGlobalTPUOpKernel& operator=(
      const ConfigureAndInitializeGlobalTPUOpKernel&) = delete;

  static Status InitializeInternal(OpKernelContext* ctx, ResourceMgr* rmgr,
                                   absl::Duration retry_timeout,
                                   std::vector<int32>* core_id_output_vec) {
    // Reset the TPU embedding engine interface if we are not the master.
    // We need to reset the interface before initializing the host because
    // the resetting process resets the TPU platform.
    TF_RETURN_IF_ERROR(DeleteIfExists<tpu::TpuEmbeddingEngineStateInterface>(
        rmgr, tpu::kTpuEmbeddingEngineStateInterfaceResourceName));

    // Create the subgraph compilation cache and put it in the local resource
    // manager.
    tpu::TpuCompilationCacheInterface* compilation_cache;
    TF_RETURN_IF_ERROR(CreateTpuCompilationCache(rmgr, &compilation_cache));
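    // CreateTpuCompilationCache returns the cache with a reference held;
    // ScopedUnref releases that reference when this function returns, while
    // the resource manager keeps the cache alive.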
    core::ScopedUnref compilation_cache_ref(compilation_cache);

    // Initialize the global TPU and set `TPUHostConfiguration` with the TPU
    // topology.
    auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform();
    if (tpu_platform == nullptr) {
      return errors::Internal("Could not find registered TPU system.");
    }

    auto start = absl::Now();
    auto init_status = OkStatus();

    // Keep trying to initialize the underlying TPU system until either the
    // TPU system is initialized or initialization times out.
    while (!tpu_platform->Initialized() &&
           (absl::Now() - start < retry_timeout)) {
      VLOG(1) << "Initializing global TPU system.";
      init_status = tpu_platform->Initialize({});
    }
    if (!tpu_platform->Initialized()) {
      return errors::Unavailable("Unable to initialize TPU system.");
    }

    std::string host_config_serialized;
    std::vector<int32> num_device_per_host;
    const auto tpu_topology = tpu_platform->topology();
    num_device_per_host.reserve(tpu_topology.HostCount());
    for (int i = 0; i < tpu_topology.HostCount(); ++i) {
      num_device_per_host.emplace_back(tpu_topology.ChipsPerHost());
    }

    TF_RETURN_IF_ERROR(tensorflow::ConstructTpuPodState(
        rmgr, num_device_per_host, compilation_cache, &host_config_serialized));

    // Turn `host_config_serialized` into `core_id_output_vec` by calling the
    // guts of InitializeHostForDistributedTpuOp.
    TF_Status* status = TF_NewStatus();
    size_t device_id_output_size;
    int32_t* device_id_output = nullptr;
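    // Make sure the TF_Status and the C-API-allocated core ID array are freed
    // on every exit path from this point on.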
    auto cleanup = absl::MakeCleanup([&status, &device_id_output]() {
      TF_DeleteStatus(status);
      tpu::OpsApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
    });

    InitializeHostForDistributedTpuOp_DoWork_Params params;
    params.struct_size = InitializeHostForDistributedTpuOp_DoWork_Params_SIZE;
    params.priv = nullptr;
    params.tpu_host_config_size = host_config_serialized.size();
    params.tpu_host_config = host_config_serialized.data();
    params.enable_whole_mesh_compilations = false;
    params.is_master_worker = true;
    params.core_id_output_size = &device_id_output_size;
    params.core_id_output = &device_id_output;
    params.status = status;

    tpu::OpsApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(&params);
    TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
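    // Copy the per-core device IDs reported by the C API into the output
    // vector.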
    for (size_t i = 0; i < device_id_output_size; ++i) {
      core_id_output_vec->push_back(device_id_output[i]);
    }

    // Create resource containers used for storing TPU topology and HBM buffer
    // configurations.
    auto delete_status = rmgr->Delete<tpu::TpuMeshStateInterface>(
        rmgr->default_container(), tpu::kTpuMeshStateInterfaceResourceName);
    if (!delete_status.ok() && delete_status.code() != error::NOT_FOUND) {
      return errors::InvalidArgument(
          "Failed to delete mesh interface. Please try initializing "
          "again once all TPU devices are allocated.");
    }

    auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
    TF_RETURN_IF_ERROR(rmgr->Create(rmgr->default_container(),
                                    tpu::kTpuMeshStateInterfaceResourceName,
                                    tpu_mesh));

    VLOG(1) << "Removing existing proto compilation cache lookup if it exists";
    Status resource_delete_status =
        rmgr->Delete<tpu::TpuCompilationCacheLookup>(
            rmgr->default_container(), tpu::kCompiledProtoCacheResourceName);

    tpu::TpuCompilationCacheInterface* local_compilation_cache;
    TF_RETURN_IF_ERROR(rmgr->Lookup(rmgr->default_container(),
                                    tpu::kCompilationCacheResourceName,
                                    &local_compilation_cache));
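    // Lookup() adds a reference to the compilation cache; release it here
    // since the resource manager still owns the cache.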
    local_compilation_cache->Unref();

    VLOG(1) << "Creating compilation proto cache resource";
    tpu::TpuCompilationCacheLookup* proto_lookup;
    proto_lookup =
        new tpu::TpuCompilationCacheLocalLookup(local_compilation_cache);
    TF_RETURN_IF_ERROR(rmgr->Create(rmgr->default_container(),
                                    tpu::kCompiledProtoCacheResourceName,
                                    proto_lookup));
    TF_RETURN_IF_ERROR(
        rmgr->Create(rmgr->default_container(),
                     tpu::kTpuEmbeddingEngineStateInterfaceResourceName,
                     tpu::TpuEmbeddingEngineStateInterface::Create()));

    return OkStatus();
  }
};

class ShutdownTPUSystemOpKernel : public OpKernel {
 public:
  explicit ShutdownTPUSystemOpKernel(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    LOG(INFO) << "ShutdownTPUSystemOpKernel op";

    Status status;
    TpuSystemInterface* tpu_system = GetPreferredTpuSystem();
    if (tpu_system == nullptr) {
      VLOG(1) << "Shutting down the default TPU system.";
      // In the current runtime, we reset the TPU platform, which in turn
      // shuts down the tpu::System.
      auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform();
      OP_REQUIRES(ctx, tpu_platform != nullptr,
                  errors::Internal("Could not find registered TPU system."));

      status = tpu_platform->Reset(/*only_tear_down=*/true,
                                   /*reason=*/"ShutdownSystem");
    } else {
      VLOG(1) << "Shutting down a preferred TPU system.";
      status = tpu_system->Shutdown();
    }

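    // Report success or failure through a single boolean output rather than
    // failing the op when shutdown does not succeed.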
    Tensor* output_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({1}), &output_tensor));
    output_tensor->flat<bool>()(0) = status.ok();
  }
};

class SetGlobalTPUArrayOpKernel : public OpKernel {
 public:
  explicit SetGlobalTPUArrayOpKernel(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "SetGlobalTPUArrayOpKernel op";
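    // Input 0 ("topology") holds the serialized TPU topology string, which is
    // passed straight through to the TPU C API below.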
    auto tpu_topology = ctx->input(0).scalar<tstring>()();
    TF_Status* status = TF_NewStatus();

    tpu::OpsApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(),
                                                  tpu_topology.data(), status);
    OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
    TF_DeleteStatus(status);

    VLOG(1) << "SetGlobalTPUArrayOpKernel done";
  }
};

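// All three kernels run on the TPU_SYSTEM device; HostMemory() keeps the
// listed tensors in host memory so the kernels can read and write them
// directly on the host.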
REGISTER_KERNEL_BUILDER(Name("ConfigureAndInitializeGlobalTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("output"),
                        ConfigureAndInitializeGlobalTPUOpKernel);

REGISTER_KERNEL_BUILDER(Name("ShutdownTPUSystem").Device(DEVICE_TPU_SYSTEM),
                        ShutdownTPUSystemOpKernel);

REGISTER_KERNEL_BUILDER(Name("DTensorSetGlobalTPUArray")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("topology"),
                        SetGlobalTPUArrayOpKernel);

}  // namespace dtensor
}  // namespace tensorflow