1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"

#include <cstdlib>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>

#include "grpcpp/create_channel.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
39
40 namespace tensorflow {
41
42 namespace {
43
// Builds the task name "/job:<job>/replica:0/task:<task>".
// Only replica 0 is ever produced here; SparseGrpcChannelCache::TranslateTask
// likewise rejects any other replica.
string MakeAddress(const string& job, int task) {
  return strings::StrCat("/job:", job, "/replica:", 0, "/task:", task);
}
47
48 // Allows the host to be a raw IP (either v4 or v6).
ValidateHostPortPair(const string & host_port)49 Status ValidateHostPortPair(const string& host_port) {
50 string bns_prefix = "/bns/";
51 if (host_port.substr(0, bns_prefix.length()) == bns_prefix) {
52 return OkStatus();
53 }
54 uint32 port;
55 auto colon_index = host_port.find_last_of(':');
56 if (!strings::safe_strtou32(host_port.substr(colon_index + 1), &port) ||
57 host_port.substr(0, colon_index).find('/') != string::npos) {
58 return errors::InvalidArgument("Could not interpret \"", host_port,
59 "\" as a host-port pair.");
60 }
61 return OkStatus();
62 }
63
CreateDefaultChannelArguments()64 ::grpc::ChannelArguments* CreateDefaultChannelArguments() {
65 ::grpc::ChannelArguments* args = new ::grpc::ChannelArguments();
66 const char* env = std::getenv("TF_GRPC_DEFAULT_OPTIONS");
67 if (env != nullptr) {
68 for (auto& grpc_option : absl::StrSplit(env, ',')) {
69 std::vector<string> name_value = absl::StrSplit(grpc_option, '=');
70 if (name_value.size() != 2) {
71 LOG(ERROR) << "Invalid GRPC options format: " << grpc_option;
72 continue;
73 }
74 VLOG(3) << "Setting GRPC default for '" << name_value[0] << "' to '"
75 << name_value[1] << "'";
76 if (name_value[1].size() >= 2 && name_value[1][0] == '"') {
77 string ue_value = name_value[1].substr(1, name_value[1].size() - 2);
78 string value;
79 string error;
80 if (!absl::CUnescape(ue_value, &value, &error)) {
81 LOG(ERROR) << "Failed to parse escaped string for " << grpc_option
82 << ": " << error;
83 } else {
84 args->SetString(name_value[0], value);
85 }
86 } else {
87 int64_t value;
88 if (strings::safe_strto64(name_value[1], &value)) {
89 args->SetInt(name_value[0], value);
90 } else {
91 LOG(ERROR) << "Invalid integer value: " << grpc_option;
92 }
93 }
94 }
95 }
96 return args;
97 }
98
GetDefaultChannelArguments()99 const ::grpc::ChannelArguments* GetDefaultChannelArguments() {
100 static const ::grpc::ChannelArguments* args = CreateDefaultChannelArguments();
101 return args;
102 }
103
104 } // namespace
105
GetChannelArguments(const RPCOptions * rpc_options)106 ::grpc::ChannelArguments GetChannelArguments(const RPCOptions* rpc_options) {
107 // TODO(mrry): Implement secure channels.
108 ::grpc::ChannelArguments args = *GetDefaultChannelArguments();
109 args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
110 // NOTE(mrry): Some versions of gRPC use a 20-second minimum backoff
111 // on connection failure, which makes our tests time out.
112 args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
113 if (rpc_options != nullptr) {
114 if (rpc_options->compression_algorithm() == "deflate") {
115 args.SetCompressionAlgorithm(GRPC_COMPRESS_DEFLATE);
116 args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
117 rpc_options->compression_level());
118 VLOG(5) << "Setting GRPC compression : algo='"
119 << rpc_options->compression_algorithm()
120 << "' level=" << rpc_options->compression_level();
121 } else if (rpc_options->compression_algorithm() == "gzip") {
122 args.SetCompressionAlgorithm(GRPC_COMPRESS_GZIP);
123 args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
124 rpc_options->compression_level());
125 VLOG(5) << "Setting GRPC compression : algo='"
126 << rpc_options->compression_algorithm()
127 << "' level=" << rpc_options->compression_level();
128 } else if (!rpc_options->compression_algorithm().empty()) {
129 LOG(ERROR) << "Invalid compression algorithm: "
130 << rpc_options->compression_algorithm();
131 }
132 if (rpc_options->disable_session_connection_sharing()) {
133 VLOG(5) << "Disabling TCP connection sharing";
134 args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true);
135 }
136 }
137 return args;
138 }
139
NewHostPortGrpcChannel(const string & target,const RPCOptions * rpc_options,SharedGrpcChannelPtr * channel_pointer)140 Status NewHostPortGrpcChannel(const string& target,
141 const RPCOptions* rpc_options,
142 SharedGrpcChannelPtr* channel_pointer) {
143 // Minimally ensure that the target is valid
144 TF_RETURN_IF_ERROR(ValidateHostPortPair(target));
145
146 ::grpc::ChannelArguments args = GetChannelArguments(rpc_options);
147 *channel_pointer = ::grpc::CreateCustomChannel(
148 "dns:///" + target, ::grpc::InsecureChannelCredentials(), args);
149 return OkStatus();
150 }
151
ConvertToChannelCreationFunction(const std::function<Status (string,const RPCOptions *,SharedGrpcChannelPtr *)> & new_channel_func_ptr)152 ChannelCreationFunction ConvertToChannelCreationFunction(
153 const std::function<Status(string, const RPCOptions*,
154 SharedGrpcChannelPtr*)>& new_channel_func_ptr) {
155 return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr {
156 SharedGrpcChannelPtr channel_ptr;
157 if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr)
158 .ok()) {
159 return channel_ptr;
160 } else {
161 return nullptr;
162 }
163 };
164 }
165
AddHostPortsJob(const string & job_id,const std::vector<string> & host_ports)166 Status GrpcChannelSpec::AddHostPortsJob(const string& job_id,
167 const std::vector<string>& host_ports) {
168 std::map<int, string> host_ports_map;
169 for (size_t i = 0; i < host_ports.size(); ++i) {
170 host_ports_map[i] = host_ports[i];
171 }
172 return AddHostPortsJob(job_id, host_ports_map);
173 }
174
AddHostPortsJob(const string & job_id,const std::map<int,string> & host_ports)175 Status GrpcChannelSpec::AddHostPortsJob(
176 const string& job_id, const std::map<int, string>& host_ports) {
177 if (!job_ids_.insert(job_id).second) {
178 return errors::InvalidArgument(
179 "Duplicate job ID in cluster specification: ", job_id);
180 }
181 for (const auto& id_host_port : host_ports) {
182 TF_RETURN_IF_ERROR(ValidateHostPortPair(id_host_port.second));
183 }
184 host_ports_jobs_.emplace_back(job_id, host_ports);
185 return OkStatus();
186 }
187
188 namespace {
189
190 // GrpcChannelCache that caches results to FindWorkerChannel() calls.
191 using CachingGrpcChannelCache = GenericCachingChannelCache<GrpcChannelCache>;
192
193 // A ChannelCache that is the union of multiple ChannelCaches.
194 // Takes ownership of the caches passed to the constructor.
195 class MultiGrpcChannelCache : public CachingGrpcChannelCache {
196 public:
MultiGrpcChannelCache(const std::vector<GrpcChannelCache * > & caches,int num_channels_per_target)197 explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches,
198 int num_channels_per_target)
199 : CachingGrpcChannelCache(num_channels_per_target), caches_(caches) {}
200
~MultiGrpcChannelCache()201 ~MultiGrpcChannelCache() override {
202 for (GrpcChannelCache* cache : caches_) {
203 delete cache;
204 }
205 }
206
ListWorkers(std::vector<string> * workers)207 void ListWorkers(std::vector<string>* workers) override {
208 for (GrpcChannelCache* cache : caches_) {
209 cache->ListWorkers(workers);
210 }
211 }
212
ListWorkersInJob(const string & job_name,std::vector<string> * workers)213 void ListWorkersInJob(const string& job_name,
214 std::vector<string>* workers) override {
215 for (GrpcChannelCache* cache : caches_) {
216 cache->ListWorkersInJob(job_name, workers);
217 }
218 }
219
TranslateTask(const string & target)220 string TranslateTask(const string& target) override {
221 mutex_lock l(mu_); // could use reader lock
222 GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
223 if (cache == nullptr) {
224 for (GrpcChannelCache* c : caches_) {
225 string r = c->TranslateTask(target);
226 if (!r.empty()) {
227 target_caches_.insert({target, c});
228 cache = c;
229 break;
230 }
231 }
232 }
233 CHECK(cache) << "Could not find GrpcChannelCache holding channel for "
234 << target;
235 return cache->TranslateTask(target);
236 }
237
238 protected:
FindChannelOnce(const string & target)239 SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
240 for (GrpcChannelCache* cache : caches_) {
241 SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
242 if (ch) {
243 mutex_lock l(mu_);
244 target_caches_.insert({target, cache});
245 return ch;
246 }
247 }
248 return nullptr;
249 }
250
251 private:
252 // List of channels used by this MultiGrpcChannelCache.
253 const std::vector<GrpcChannelCache*> caches_;
254
255 mutex mu_;
256 // Cache of channels keyed by the target they are handling.
257 // The same GrpcChannelCache can appear multiple times in the cache.
258 std::unordered_map<string, GrpcChannelCache*> target_caches_
259 TF_GUARDED_BY(mu_);
260 };
261
262 class SparseGrpcChannelCache : public CachingGrpcChannelCache {
263 public:
SparseGrpcChannelCache(const string & job_id,const std::map<int,string> & host_ports,ChannelCreationFunction channel_func,int num_channels_per_target)264 SparseGrpcChannelCache(const string& job_id,
265 const std::map<int, string>& host_ports,
266 ChannelCreationFunction channel_func,
267 int num_channels_per_target)
268 : CachingGrpcChannelCache(num_channels_per_target),
269 job_id_(job_id),
270 host_ports_(host_ports),
271 channel_func_(std::move(channel_func)) {
272 VLOG(2) << "Initialize GrpcChannelCache for job " << ToString();
273 }
~SparseGrpcChannelCache()274 ~SparseGrpcChannelCache() override {}
275
ListWorkers(std::vector<string> * workers)276 void ListWorkers(std::vector<string>* workers) override {
277 workers->reserve(workers->size() + host_ports_.size());
278 for (const auto& id_host_port : host_ports_) {
279 workers->emplace_back(MakeAddress(job_id_, id_host_port.first));
280 }
281 }
282
ListWorkersInJob(const string & job_name,std::vector<string> * workers)283 void ListWorkersInJob(const string& job_name,
284 std::vector<string>* workers) override {
285 if (job_name == job_id_) {
286 ListWorkers(workers);
287 }
288 }
289
TranslateTask(const string & target)290 string TranslateTask(const string& target) override {
291 DeviceNameUtils::ParsedName parsed;
292 if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
293 LOG(WARNING) << "Invalid target: " << target;
294 return "";
295 }
296
297 if (!parsed.has_job || parsed.job != job_id_) {
298 return "";
299 }
300 if (!parsed.has_replica || parsed.replica != 0) {
301 LOG(WARNING) << "Replica ID must be 0 in target: " << target;
302 return "";
303 }
304 int32_t task = parsed.has_task ? parsed.task : -1;
305 auto iter = host_ports_.find(task);
306 if (iter == host_ports_.end()) {
307 LOG(WARNING) << "Task " << task << " was not defined in sparse job "
308 << job_id_ << ": " << target;
309 return "";
310 }
311 return iter->second;
312 }
313
314 protected:
FindChannelOnce(const string & target)315 SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
316 const string host_port = TranslateTask(target);
317 if (host_port.empty()) {
318 return nullptr;
319 }
320 auto chan_ptr = channel_func_(host_port);
321 VLOG(5) << "Channel created for: job: " << job_id_
322 << " host_port: " << host_port << " target : " << target
323 << " Ptr: " << chan_ptr.get();
324 return chan_ptr;
325 }
326
327 private:
ToString()328 string ToString() {
329 std::vector<string> task_strings;
330 task_strings.reserve(host_ports_.size());
331 for (const auto& id_host_port : host_ports_) {
332 task_strings.emplace_back(
333 strings::StrCat(id_host_port.first, " -> ", id_host_port.second));
334 }
335 return strings::StrCat(job_id_, " -> {", absl::StrJoin(task_strings, ", "),
336 "}");
337 }
338
339 const string job_id_;
340 const std::map<int, string> host_ports_;
341 const ChannelCreationFunction channel_func_;
342 TF_DISALLOW_COPY_AND_ASSIGN(SparseGrpcChannelCache);
343 };
344
345 } // namespace
346
NewGrpcChannelCache(const GrpcChannelSpec & spec,ChannelCreationFunction channel_func,const RPCOptions & options)347 GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec,
348 ChannelCreationFunction channel_func,
349 const RPCOptions& options) {
350 const int num_jobs = spec.host_ports_jobs().size();
351 if (!num_jobs) {
352 LOG(ERROR) << "Empty channel spec.";
353 return nullptr;
354 }
355 std::vector<GrpcChannelCache*> caches;
356 caches.reserve(num_jobs);
357 for (auto& job : spec.host_ports_jobs()) {
358 VLOG(2) << "Creating Grpc Channel Cache for: " << job.job_id;
359 caches.push_back(
360 new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func,
361 options.num_channels_per_target()));
362 }
363 return caches.size() == 1 ? caches[0]
364 : new MultiGrpcChannelCache(
365 caches, options.num_channels_per_target());
366 }
367
368 } // end namespace tensorflow
369