xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tpu/kernels/tpu_pod_state.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
16 
17 #include "absl/cleanup/cleanup.h"
18 #include "tensorflow/c/tf_status.h"
19 #include "tensorflow/c/tf_status_helper.h"
20 #include "tensorflow/core/tpu/tpu_api.h"
21 
22 #if defined(LIBTPU_ON_GCE)
23 #include "tensorflow/core/tpu/kernels/tpu_util.h"
24 #else
25 #include "tensorflow/core/tpu/kernels/tpu_util.h"  // copybara"
26 #endif
27 
28 namespace tensorflow {
// Name under which the TpuPodState resource is looked up, created and deleted
// in a ResourceMgr's default container by the functions in this file.
29 const char kTpuPodStateResourceName[] = "tpu_pod_state";
30 
31 namespace {
32 
33 // Attempt to delete resource_name from resource_manager's default_container.
34 // Returns OK if the deletion succeeded, or if the resource was not found. Else
35 // return the deletion error.
36 template <class ResourceT>
DeleteIfExists(ResourceMgr * resource_manager,const char * resource_name)37 Status DeleteIfExists(ResourceMgr* resource_manager,
38                       const char* resource_name) {
39   VLOG(1) << "Removing resource " << resource_name << " if it exists";
40   Status status = resource_manager->Delete<ResourceT>(
41       resource_manager->default_container(), resource_name);
42   if (status.ok()) {
43     VLOG(1) << "Removed existing resource " << resource_name;
44     return OkStatus();
45   }
46   if (status.code() == error::NOT_FOUND) {
47     VLOG(1) << "No resource " << resource_name << " to remove";
48     return OkStatus();
49   }
50   VLOG(1) << "Error removing resource " << resource_name << " : " << status;
51   return status;
52 }
53 
54 xla::StatusOr<std::unique_ptr<TpuCompilationCacheService>>
ConstructCacheService(ResourceMgr * rmgr,int serving_port,tpu::TpuCompilationCacheInterface * compilation_cache)55 ConstructCacheService(ResourceMgr* rmgr, int serving_port,
56                       tpu::TpuCompilationCacheInterface* compilation_cache) {
57   xla::StatusOr<std::unique_ptr<::grpc::ServerBuilder>> server_builder;
58 #if defined(LIBTPU_ON_GCE)
59   server_builder = tpu::CreateServerBuilder(serving_port);
60 #else
61   server_builder = tpu::CreateServerBuilderGoogle(serving_port);
62 #endif
63   TF_RETURN_IF_ERROR(server_builder.status());
64 
65   auto cache_service = absl::make_unique<TpuCompilationCacheService>(
66       server_builder.ValueOrDie().get(), compilation_cache);
67   cache_service->SetMemoryQuota(1ul << 31);  // 2GB
68   cache_service->Start();
69   return cache_service;
70 }
71 }  // namespace
72 
GetServerAddressAndPort(std::string * server_address,int * serving_port)73 Status GetServerAddressAndPort(std::string* server_address, int* serving_port) {
74   TF_Status* status = TF_NewStatus();
75   char* server_address_output = nullptr;
76   auto cleanup = absl::MakeCleanup([&status, &server_address_output]() {
77     TF_DeleteStatus(status);
78     tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(server_address_output);
79   });
80   size_t server_address_output_size;
81   *serving_port = -1;
82 
83   TpuConfigurationApi_GetServerAddressAndPort_Params params;
84   params.struct_size = TpuConfigurationApi_GetServerAddressAndPort_Params_SIZE;
85   params.priv = nullptr;
86   params.server_address_output_size = &server_address_output_size;
87   params.server_address_output = &server_address_output;
88   params.port_output = serving_port;
89   params.status = status;
90 
91   tpu::OpsApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn(&params);
92   TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
93   *server_address =
94       std::string(server_address_output, server_address_output_size);
95   CHECK_NE(*serving_port, -1);
96   return OkStatus();
97 }
98 
TpuPodState(int service_port,std::unique_ptr<TpuCompilationCacheService> cache_service)99 TpuPodState::TpuPodState(
100     int service_port, std::unique_ptr<TpuCompilationCacheService> cache_service)
101     : cache_service_(std::move(cache_service)), service_port_(service_port) {}
102 
~TpuPodState()103 TpuPodState::~TpuPodState() {
104   if (cache_service_) {
105     VLOG(1) << "Shutting down Compilation Cache Service.";
106     if (cache_service_->Shutdown(20)) {
107       if (service_port_ >= 0) {
108         tpu::OpsApiFn()->TpuNetUtil_RecycleUnusedPortFn(service_port_);
109       }
110     } else {
111       LOG(ERROR)
112           << "Failed to shutdown Compilation Cache Service within timeout.";
113     }
114   }
115   VLOG(1) << "Shutting down Compilation Cache Service done.";
116 }
117 
DebugString() const118 string TpuPodState::DebugString() const {
119   return "Wrapper for distributed TPU state";
120 }
121 
GetTPUPodState(const ResourceMgr * rmgr,TpuPodState ** pod_state)122 Status GetTPUPodState(const ResourceMgr* rmgr, TpuPodState** pod_state) {
123   if (!rmgr) {
124     return errors::Internal("No resource manager.");
125   }
126   if (!rmgr->Lookup(rmgr->default_container(), kTpuPodStateResourceName,
127                     pod_state)
128            .ok()) {
129     return errors::FailedPrecondition(
130         "The TPU system has not been initialized.");
131   }
132   return OkStatus();
133 }
134 
HasTPUPodState(const ResourceMgr * rmgr)135 bool HasTPUPodState(const ResourceMgr* rmgr) {
136   TpuPodState* pod_state;
137   if (!rmgr->Lookup(rmgr->default_container(), kTpuPodStateResourceName,
138                     &pod_state)
139            .ok()) {
140     return false;
141   }
142   pod_state->Unref();
143   return true;
144 }
145 
ConstructTpuPodState(ResourceMgr * rmgr,const std::vector<int32_t> & num_devices_per_host,tpu::TpuCompilationCacheInterface * compilation_cache,std::string * host_config_proto)146 Status ConstructTpuPodState(
147     ResourceMgr* rmgr, const std::vector<int32_t>& num_devices_per_host,
148     tpu::TpuCompilationCacheInterface* compilation_cache,
149     std::string* host_config_proto) {
150   TF_Status* status = TF_NewStatus();
151   auto status_cleanup =
152       absl::MakeCleanup([&status]() { TF_DeleteStatus(status); });
153 
154   int serving_port;
155   std::string server_address;
156   TF_RETURN_IF_ERROR(GetServerAddressAndPort(&server_address, &serving_port));
157 
158   char* host_config_output = nullptr;
159   auto host_config_cleanup = absl::MakeCleanup([&host_config_output]() {
160     tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
161   });
162   size_t host_config_output_size;
163 
164   ConfigureDistributedTpuOp_DoWork_Params params;
165   params.struct_size = ConfigureDistributedTpuOp_DoWork_Params_SIZE;
166   params.priv = nullptr;
167   params.num_cores_per_host_size = num_devices_per_host.size();
168   params.num_cores_per_host = num_devices_per_host.data();
169   params.server_address_size = server_address.size();
170   params.server_address = server_address.data();
171   params.host_config_output_size = &host_config_output_size;
172   params.host_config_output = &host_config_output;
173   params.status = status;
174 
175   tpu::OpsApiFn()->ConfigureDistributedTpuOp_DoWorkFn(&params);
176   TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
177   *host_config_proto = std::string(host_config_output, host_config_output_size);
178 
179   TF_ASSIGN_OR_RETURN(
180       std::unique_ptr<TpuCompilationCacheService> cache_service,
181       ConstructCacheService(rmgr, serving_port, compilation_cache));
182 
183   // Delete TpuPodState if it exists, and recreate below.
184   TF_RETURN_IF_ERROR(
185       DeleteIfExists<TpuPodState>(rmgr, kTpuPodStateResourceName));
186   return rmgr->Create(rmgr->default_container(), kTpuPodStateResourceName,
187                       new TpuPodState(serving_port, std::move(cache_service)));
188 }
189 }  // namespace tensorflow
190