xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"

#include <cstdlib>
#include <limits>
#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "grpcpp/create_channel.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
38 #include "tensorflow/core/util/device_name_utils.h"
39 
40 namespace tensorflow {
41 
42 namespace {
43 
// Builds the fully-qualified task name "/job:<job>/replica:0/task:<task>".
// The replica component is always 0.
std::string MakeAddress(const std::string& job, int task) {
  std::string address = "/job:";
  address += job;
  address += "/replica:0/task:";
  address += std::to_string(task);
  return address;
}
47 
48 // Allows the host to be a raw IP (either v4 or v6).
ValidateHostPortPair(const string & host_port)49 Status ValidateHostPortPair(const string& host_port) {
50   string bns_prefix = "/bns/";
51   if (host_port.substr(0, bns_prefix.length()) == bns_prefix) {
52     return OkStatus();
53   }
54   uint32 port;
55   auto colon_index = host_port.find_last_of(':');
56   if (!strings::safe_strtou32(host_port.substr(colon_index + 1), &port) ||
57       host_port.substr(0, colon_index).find('/') != string::npos) {
58     return errors::InvalidArgument("Could not interpret \"", host_port,
59                                    "\" as a host-port pair.");
60   }
61   return OkStatus();
62 }
63 
CreateDefaultChannelArguments()64 ::grpc::ChannelArguments* CreateDefaultChannelArguments() {
65   ::grpc::ChannelArguments* args = new ::grpc::ChannelArguments();
66   const char* env = std::getenv("TF_GRPC_DEFAULT_OPTIONS");
67   if (env != nullptr) {
68     for (auto& grpc_option : absl::StrSplit(env, ',')) {
69       std::vector<string> name_value = absl::StrSplit(grpc_option, '=');
70       if (name_value.size() != 2) {
71         LOG(ERROR) << "Invalid GRPC options format: " << grpc_option;
72         continue;
73       }
74       VLOG(3) << "Setting GRPC default for '" << name_value[0] << "' to '"
75               << name_value[1] << "'";
76       if (name_value[1].size() >= 2 && name_value[1][0] == '"') {
77         string ue_value = name_value[1].substr(1, name_value[1].size() - 2);
78         string value;
79         string error;
80         if (!absl::CUnescape(ue_value, &value, &error)) {
81           LOG(ERROR) << "Failed to parse escaped string for " << grpc_option
82                      << ": " << error;
83         } else {
84           args->SetString(name_value[0], value);
85         }
86       } else {
87         int64_t value;
88         if (strings::safe_strto64(name_value[1], &value)) {
89           args->SetInt(name_value[0], value);
90         } else {
91           LOG(ERROR) << "Invalid integer value: " << grpc_option;
92         }
93       }
94     }
95   }
96   return args;
97 }
98 
GetDefaultChannelArguments()99 const ::grpc::ChannelArguments* GetDefaultChannelArguments() {
100   static const ::grpc::ChannelArguments* args = CreateDefaultChannelArguments();
101   return args;
102 }
103 
104 }  // namespace
105 
GetChannelArguments(const RPCOptions * rpc_options)106 ::grpc::ChannelArguments GetChannelArguments(const RPCOptions* rpc_options) {
107   // TODO(mrry): Implement secure channels.
108   ::grpc::ChannelArguments args = *GetDefaultChannelArguments();
109   args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
110   // NOTE(mrry): Some versions of gRPC use a 20-second minimum backoff
111   // on connection failure, which makes our tests time out.
112   args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
113   if (rpc_options != nullptr) {
114     if (rpc_options->compression_algorithm() == "deflate") {
115       args.SetCompressionAlgorithm(GRPC_COMPRESS_DEFLATE);
116       args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
117                   rpc_options->compression_level());
118       VLOG(5) << "Setting GRPC compression : algo='"
119               << rpc_options->compression_algorithm()
120               << "' level=" << rpc_options->compression_level();
121     } else if (rpc_options->compression_algorithm() == "gzip") {
122       args.SetCompressionAlgorithm(GRPC_COMPRESS_GZIP);
123       args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
124                   rpc_options->compression_level());
125       VLOG(5) << "Setting GRPC compression : algo='"
126               << rpc_options->compression_algorithm()
127               << "' level=" << rpc_options->compression_level();
128     } else if (!rpc_options->compression_algorithm().empty()) {
129       LOG(ERROR) << "Invalid compression algorithm: "
130                  << rpc_options->compression_algorithm();
131     }
132     if (rpc_options->disable_session_connection_sharing()) {
133       VLOG(5) << "Disabling TCP connection sharing";
134       args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true);
135     }
136   }
137   return args;
138 }
139 
NewHostPortGrpcChannel(const string & target,const RPCOptions * rpc_options,SharedGrpcChannelPtr * channel_pointer)140 Status NewHostPortGrpcChannel(const string& target,
141                               const RPCOptions* rpc_options,
142                               SharedGrpcChannelPtr* channel_pointer) {
143   // Minimally ensure that the target is valid
144   TF_RETURN_IF_ERROR(ValidateHostPortPair(target));
145 
146   ::grpc::ChannelArguments args = GetChannelArguments(rpc_options);
147   *channel_pointer = ::grpc::CreateCustomChannel(
148       "dns:///" + target, ::grpc::InsecureChannelCredentials(), args);
149   return OkStatus();
150 }
151 
ConvertToChannelCreationFunction(const std::function<Status (string,const RPCOptions *,SharedGrpcChannelPtr *)> & new_channel_func_ptr)152 ChannelCreationFunction ConvertToChannelCreationFunction(
153     const std::function<Status(string, const RPCOptions*,
154                                SharedGrpcChannelPtr*)>& new_channel_func_ptr) {
155   return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr {
156     SharedGrpcChannelPtr channel_ptr;
157     if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr)
158             .ok()) {
159       return channel_ptr;
160     } else {
161       return nullptr;
162     }
163   };
164 }
165 
AddHostPortsJob(const string & job_id,const std::vector<string> & host_ports)166 Status GrpcChannelSpec::AddHostPortsJob(const string& job_id,
167                                         const std::vector<string>& host_ports) {
168   std::map<int, string> host_ports_map;
169   for (size_t i = 0; i < host_ports.size(); ++i) {
170     host_ports_map[i] = host_ports[i];
171   }
172   return AddHostPortsJob(job_id, host_ports_map);
173 }
174 
AddHostPortsJob(const string & job_id,const std::map<int,string> & host_ports)175 Status GrpcChannelSpec::AddHostPortsJob(
176     const string& job_id, const std::map<int, string>& host_ports) {
177   if (!job_ids_.insert(job_id).second) {
178     return errors::InvalidArgument(
179         "Duplicate job ID in cluster specification: ", job_id);
180   }
181   for (const auto& id_host_port : host_ports) {
182     TF_RETURN_IF_ERROR(ValidateHostPortPair(id_host_port.second));
183   }
184   host_ports_jobs_.emplace_back(job_id, host_ports);
185   return OkStatus();
186 }
187 
188 namespace {
189 
190 // GrpcChannelCache that caches results to FindWorkerChannel() calls.
191 using CachingGrpcChannelCache = GenericCachingChannelCache<GrpcChannelCache>;
192 
193 // A ChannelCache that is the union of multiple ChannelCaches.
194 // Takes ownership of the caches passed to the constructor.
195 class MultiGrpcChannelCache : public CachingGrpcChannelCache {
196  public:
MultiGrpcChannelCache(const std::vector<GrpcChannelCache * > & caches,int num_channels_per_target)197   explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches,
198                                  int num_channels_per_target)
199       : CachingGrpcChannelCache(num_channels_per_target), caches_(caches) {}
200 
~MultiGrpcChannelCache()201   ~MultiGrpcChannelCache() override {
202     for (GrpcChannelCache* cache : caches_) {
203       delete cache;
204     }
205   }
206 
ListWorkers(std::vector<string> * workers)207   void ListWorkers(std::vector<string>* workers) override {
208     for (GrpcChannelCache* cache : caches_) {
209       cache->ListWorkers(workers);
210     }
211   }
212 
ListWorkersInJob(const string & job_name,std::vector<string> * workers)213   void ListWorkersInJob(const string& job_name,
214                         std::vector<string>* workers) override {
215     for (GrpcChannelCache* cache : caches_) {
216       cache->ListWorkersInJob(job_name, workers);
217     }
218   }
219 
TranslateTask(const string & target)220   string TranslateTask(const string& target) override {
221     mutex_lock l(mu_);  // could use reader lock
222     GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
223     if (cache == nullptr) {
224       for (GrpcChannelCache* c : caches_) {
225         string r = c->TranslateTask(target);
226         if (!r.empty()) {
227           target_caches_.insert({target, c});
228           cache = c;
229           break;
230         }
231       }
232     }
233     CHECK(cache) << "Could not find GrpcChannelCache holding channel for "
234                  << target;
235     return cache->TranslateTask(target);
236   }
237 
238  protected:
FindChannelOnce(const string & target)239   SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
240     for (GrpcChannelCache* cache : caches_) {
241       SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
242       if (ch) {
243         mutex_lock l(mu_);
244         target_caches_.insert({target, cache});
245         return ch;
246       }
247     }
248     return nullptr;
249   }
250 
251  private:
252   // List of channels used by this MultiGrpcChannelCache.
253   const std::vector<GrpcChannelCache*> caches_;
254 
255   mutex mu_;
256   // Cache of channels keyed by the target they are handling.
257   // The same GrpcChannelCache can appear multiple times in the cache.
258   std::unordered_map<string, GrpcChannelCache*> target_caches_
259       TF_GUARDED_BY(mu_);
260 };
261 
262 class SparseGrpcChannelCache : public CachingGrpcChannelCache {
263  public:
SparseGrpcChannelCache(const string & job_id,const std::map<int,string> & host_ports,ChannelCreationFunction channel_func,int num_channels_per_target)264   SparseGrpcChannelCache(const string& job_id,
265                          const std::map<int, string>& host_ports,
266                          ChannelCreationFunction channel_func,
267                          int num_channels_per_target)
268       : CachingGrpcChannelCache(num_channels_per_target),
269         job_id_(job_id),
270         host_ports_(host_ports),
271         channel_func_(std::move(channel_func)) {
272     VLOG(2) << "Initialize GrpcChannelCache for job " << ToString();
273   }
~SparseGrpcChannelCache()274   ~SparseGrpcChannelCache() override {}
275 
ListWorkers(std::vector<string> * workers)276   void ListWorkers(std::vector<string>* workers) override {
277     workers->reserve(workers->size() + host_ports_.size());
278     for (const auto& id_host_port : host_ports_) {
279       workers->emplace_back(MakeAddress(job_id_, id_host_port.first));
280     }
281   }
282 
ListWorkersInJob(const string & job_name,std::vector<string> * workers)283   void ListWorkersInJob(const string& job_name,
284                         std::vector<string>* workers) override {
285     if (job_name == job_id_) {
286       ListWorkers(workers);
287     }
288   }
289 
TranslateTask(const string & target)290   string TranslateTask(const string& target) override {
291     DeviceNameUtils::ParsedName parsed;
292     if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
293       LOG(WARNING) << "Invalid target: " << target;
294       return "";
295     }
296 
297     if (!parsed.has_job || parsed.job != job_id_) {
298       return "";
299     }
300     if (!parsed.has_replica || parsed.replica != 0) {
301       LOG(WARNING) << "Replica ID must be 0 in target: " << target;
302       return "";
303     }
304     int32_t task = parsed.has_task ? parsed.task : -1;
305     auto iter = host_ports_.find(task);
306     if (iter == host_ports_.end()) {
307       LOG(WARNING) << "Task " << task << " was not defined in sparse job "
308                    << job_id_ << ": " << target;
309       return "";
310     }
311     return iter->second;
312   }
313 
314  protected:
FindChannelOnce(const string & target)315   SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
316     const string host_port = TranslateTask(target);
317     if (host_port.empty()) {
318       return nullptr;
319     }
320     auto chan_ptr = channel_func_(host_port);
321     VLOG(5) << "Channel created for: job: " << job_id_
322             << " host_port: " << host_port << " target : " << target
323             << " Ptr: " << chan_ptr.get();
324     return chan_ptr;
325   }
326 
327  private:
ToString()328   string ToString() {
329     std::vector<string> task_strings;
330     task_strings.reserve(host_ports_.size());
331     for (const auto& id_host_port : host_ports_) {
332       task_strings.emplace_back(
333           strings::StrCat(id_host_port.first, " -> ", id_host_port.second));
334     }
335     return strings::StrCat(job_id_, " -> {", absl::StrJoin(task_strings, ", "),
336                            "}");
337   }
338 
339   const string job_id_;
340   const std::map<int, string> host_ports_;
341   const ChannelCreationFunction channel_func_;
342   TF_DISALLOW_COPY_AND_ASSIGN(SparseGrpcChannelCache);
343 };
344 
345 }  // namespace
346 
NewGrpcChannelCache(const GrpcChannelSpec & spec,ChannelCreationFunction channel_func,const RPCOptions & options)347 GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec,
348                                       ChannelCreationFunction channel_func,
349                                       const RPCOptions& options) {
350   const int num_jobs = spec.host_ports_jobs().size();
351   if (!num_jobs) {
352     LOG(ERROR) << "Empty channel spec.";
353     return nullptr;
354   }
355   std::vector<GrpcChannelCache*> caches;
356   caches.reserve(num_jobs);
357   for (auto& job : spec.host_ports_jobs()) {
358     VLOG(2) << "Creating Grpc Channel Cache for: " << job.job_id;
359     caches.push_back(
360         new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func,
361                                    options.num_channels_per_target()));
362   }
363   return caches.size() == 1 ? caches[0]
364                             : new MultiGrpcChannelCache(
365                                   caches, options.num_channels_per_target());
366 }
367 
368 }  // end namespace tensorflow
369