xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
17 
18 #include <cstring>
19 #include <limits>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "grpcpp/grpcpp.h"
26 #include "grpcpp/security/credentials.h"
27 #include "grpcpp/server_builder.h"
28 #include "tensorflow/core/common_runtime/device_factory.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/process_util.h"
31 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
32 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
33 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
34 #include "tensorflow/core/distributed_runtime/local_master.h"
35 #include "tensorflow/core/distributed_runtime/master.h"
36 #include "tensorflow/core/distributed_runtime/master_env.h"
37 #include "tensorflow/core/distributed_runtime/master_session.h"
38 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
39 #include "tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h"
40 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h"
41 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
42 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
43 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
44 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
45 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
46 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
47 #include "tensorflow/core/distributed_runtime/server_lib.h"
48 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
49 #include "tensorflow/core/distributed_runtime/worker_env.h"
50 #include "tensorflow/core/framework/op.h"
51 #include "tensorflow/core/lib/core/errors.h"
52 #include "tensorflow/core/lib/strings/strcat.h"
53 #include "tensorflow/core/nccl/collective_communicator.h"
54 #include "tensorflow/core/platform/cpu_info.h"
55 #include "tensorflow/core/platform/env.h"
56 #include "tensorflow/core/platform/errors.h"
57 #include "tensorflow/core/platform/mem.h"
58 #include "tensorflow/core/platform/mutex.h"
59 #include "tensorflow/core/platform/threadpool.h"
60 #include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
61 #include "tensorflow/core/public/session_options.h"
62 #include "tensorflow/core/util/env_var.h"
63 
64 namespace tensorflow {
65 
66 namespace {
67 
68 // Define an option subclass in order to disable SO_REUSEPORT for the
69 // server socket.
70 class NoReusePortOption : public ::grpc::ServerBuilderOption {
71  public:
UpdateArguments(::grpc::ChannelArguments * args)72   void UpdateArguments(::grpc::ChannelArguments* args) override {
73     args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
74   }
75 
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)76   void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
77                          plugins) override {}
78 };
79 
80 // Define an option subclass in order to enable SO_REUSEPORT for the
81 // server socket.
82 class ReusePortOption : public ::grpc::ServerBuilderOption {
83  public:
UpdateArguments(::grpc::ChannelArguments * args)84   void UpdateArguments(::grpc::ChannelArguments* args) override {
85     args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 1);
86   }
87 
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)88   void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
89                          plugins) override {}
90 };
91 
92 // static utility function
NewRpcRendezvousMgr(const WorkerEnv * env)93 RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
94   return new RpcRendezvousMgr(env);
95 }
96 
97 }  // namespace
98 
GrpcServer(const ServerDef & server_def,Env * env)99 GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
100     : env_(env), state_(NEW), server_def_(server_def) {}
101 
~GrpcServer()102 GrpcServer::~GrpcServer() {
103   TF_CHECK_OK(Stop());
104   TF_CHECK_OK(Join());
105 
106   delete master_service_;
107   delete worker_service_;
108   delete eager_service_;
109 
110   for (auto& kv : extra_services_) {
111     AsyncServiceInterface* service = kv.second;
112     delete service;
113   }
114 
115   // TODO(mrry): Refactor the *Env classes so that it is less fiddly
116   // to destroy them.
117 
118   // Shut down all outstanding rendezvous.
119   delete worker_env_.rendezvous_mgr;
120 
121   // We must delete graph_mgr before device_mgr, due to shared
122   // ownership of OpKernels in the executors. (The graph_mgr will
123   // free all stateless OpKernels, and pass over borrowed stateful
124   // OpKernels, which are also held in their respective devices'
125   // OpSegments.)
126   if (worker_env_.session_mgr != nullptr) {
127     delete worker_env_.session_mgr;  // Deletes graph_mgr's.
128   }
129 
130   // Do not delete (as these are not owned by the server):
131   // - master_env_.env
132   // - worker_env_.env
133   // - worker_env_.compute_pool
134 }
135 
136 // Look up the requested host name and port for this task in `server_def`.
GetHostAndPort(const ServerDef & server_def,string * host_name,int * port) const137 Status GrpcServer::GetHostAndPort(const ServerDef& server_def,
138                                   string* host_name, int* port) const {
139   *port = -1;
140   *host_name = "localhost";
141   for (const auto& job : server_def.cluster().job()) {
142     if (job.name() == server_def.job_name()) {
143       auto iter = job.tasks().find(server_def.task_index());
144       if (iter == job.tasks().end()) {
145         return errors::Internal("Task ", server_def.task_index(),
146                                 " was not defined in job \"",
147                                 server_def.job_name(), "\"");
148       }
149 
150       if (server_def.port() != 0) {
151         *port = server_def.port();
152       } else {
153         auto colon_index = iter->second.find_last_of(':');
154         if (!strings::safe_strto32(iter->second.substr(colon_index + 1),
155                                    port)) {
156           return errors::InvalidArgument(
157               "Could not parse port for local server from \"", iter->second,
158               "\".");
159         }
160 
161         if (colon_index != string::npos &&
162             !iter->second.substr(0, colon_index).empty()) {
163           *host_name = iter->second.substr(0, colon_index);
164         }
165       }
166       break;
167     }
168   }
169   if (*port == -1) {
170     return errors::Internal("Job \"", server_def.job_name(),
171                             "\" was not defined in cluster");
172   }
173 
174   return OkStatus();
175 }
176 
Init(const GrpcServerOptions & opts)177 Status GrpcServer::Init(const GrpcServerOptions& opts) {
178   mutex_lock l(mu_);
179   CHECK_EQ(state_, NEW);
180   master_env_.env = env_;
181   worker_env_.env = env_;
182 
183   // Check parameters before DeviceFactory::AddDevices,
184   // otherwise if 'task_index=-1' the program will abort.
185 
186   int requested_port;
187   TF_RETURN_IF_ERROR(GetHostAndPort(server_def_, &host_name_, &requested_port));
188 
189   SessionOptions sess_opts;
190   VLOG(3) << "Grpc Server Init Definition: " << server_def_.DebugString();
191   ConfigProto config = server_def_.default_session_config();
192   sess_opts.config = config;
193 
194   // Configure shared devices between master and worker.
195   string name_prefix =
196       strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
197                       "/task:", server_def_.task_index());
198   if (opts.local_device_mgr == nullptr) {
199     std::vector<std::unique_ptr<Device>> devices;
200     TF_RETURN_IF_ERROR(
201         DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
202     worker_env_.device_mgr = new DynamicDeviceMgr(std::move(devices));
203     owned_device_manager_.reset(worker_env_.device_mgr);
204   } else {
205     worker_env_.device_mgr = opts.local_device_mgr;
206     owned_device_manager_.reset(nullptr);
207   }
208   worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
209   master_env_.local_devices = worker_env_.device_mgr->ListDevices();
210   worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr
211                                    ? new RpcRendezvousMgr(&worker_env_)
212                                    : opts.rendezvous_mgr_func(&worker_env_);
213   string unused;
214   string default_worker_name;
215   if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
216                                         &default_worker_name, &unused)) {
217     return errors::Internal("Could not parse worker name.");
218   }
219 
220   // N.B. The order of initialization here is intricate, because we
221   // wish to allow `requested_port == 0` (for choosing any port,
222   // mostly for testing). Therefore, the construction of the channel
223   // and worker caches depends on `bound_port_`, which is not set
224   // until we call `builder.BuildAndStart()`. We must create the
225   // service objects before calling `builder.BuildAndStart()`, but
226   // `master_env_` and `worker_env_` are only partially
227   // configured. However, this is not dangerous, because we do not
228   // start serving requests until `this->Start()` is called, which
229   // happens after this method returns.
230   //
231   // TODO(mrry): Provide a general mechanism for dynamically setting
232   // the identities of tasks in the worker pool after the service is
233   // running.
234   ::grpc::ServerBuilder builder;
235   builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port),
236                            GetServerCredentials(server_def_), &bound_port_);
237   builder.SetMaxMessageSize(std::numeric_limits<int32>::max());
238 
239   bool reuse_port = false;
240   const Status status =
241       ReadBoolFromEnvVar("TF_GRPC_REUSE_PORT", false, &reuse_port);
242   if (!status.ok()) {
243     LOG(ERROR) << status.error_message();
244   }
245   auto server_build_option =
246       reuse_port
247           ? std::unique_ptr<::grpc::ServerBuilderOption>(new ReusePortOption)
248           : std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption);
249   builder.SetOption(std::move(server_build_option));
250 
251   // Allow subclasses to specify more args to pass to the gRPC server.
252   MaybeMutateBuilder(&builder, requested_port);
253   master_impl_ = CreateMaster(&master_env_);
254   master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder);
255   worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
256                                   : NewGrpcWorker(&worker_env_, config);
257   worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
258                                          opts.worker_service_options)
259                         .release();
260   eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);
261   thread::ThreadPool* compute_pool = ComputePool(sess_opts);
262   coordination_service_ =
263       new GrpcCoordinationServiceImpl(compute_pool, &builder);
264 
265   profiler_service_ = profiler::CreateProfilerService();
266   builder.RegisterService(profiler_service_.get());
267 
268   // Add any extra services to be started.
269   extra_services_ = ExtraServices(&builder);
270 
271   // extra service:
272   if (opts.service_func != nullptr) {
273     opts.service_func(&worker_env_, &builder);
274   }
275   server_ = builder.BuildAndStart();
276 
277   if (!server_) {
278     return errors::Unknown("Could not start gRPC server");
279   }
280   // Create the execution environment for the GRPC workers cache.
281   grpc_worker_env_.reset(CreateGrpcWorkerEnv());
282 
283   WorkerCacheInterface* worker_cache;
284   WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
285   TF_RETURN_IF_ERROR(
286       WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
287   CHECK_NE(nullptr, worker_cache);
288 
289   if (opts.collective_mgr_func) {
290     worker_env_.collective_executor_mgr.reset(
291         opts.collective_mgr_func(config, &worker_env_, worker_cache));
292     if (worker_env_.collective_executor_mgr == nullptr) {
293       return errors::Internal(
294           "collective_mgr_func did not return CollectiveExecutorMgr");
295     }
296   } else {
297     worker_env_.collective_executor_mgr = CreateProdRpcCollectiveExecutorMgr(
298         config, worker_env_.device_mgr, MaybeCreateNcclCommunicator(config),
299         worker_cache, default_worker_name);
300   }
301 
302   // Set up worker environment.
303   worker_env_.session_mgr = new SessionMgr(
304       &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
305       std::unique_ptr<WorkerCacheInterface>(worker_cache),
306       [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
307         WorkerCacheFactoryOptions options(server_def);
308         return WorkerCacheFactory(options, worker_cache);
309       });
310   worker_env_.compute_pool = compute_pool;
311 
312   // Finish setting up master environment.
313   master_env_.ops = OpRegistry::Global();
314   master_env_.worker_cache = worker_cache;
315   master_env_.collective_executor_mgr =
316       worker_env_.collective_executor_mgr.get();
317   StatsPublisherFactory stats_factory = opts.stats_factory;
318   master_env_.master_session_factory =
319       [config, stats_factory](
320           SessionOptions options, const MasterEnv* env,
321           std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
322           std::unique_ptr<WorkerCacheInterface> worker_cache,
323           std::unique_ptr<DeviceSet> device_set,
324           std::vector<string> filtered_worker_list) {
325         options.config.MergeFrom(config);
326         return new MasterSession(options, env, std::move(remote_devs),
327                                  std::move(worker_cache), std::move(device_set),
328                                  std::move(filtered_worker_list),
329                                  stats_factory);
330       };
331   master_env_.worker_cache_factory =
332       [this](const WorkerCacheFactoryOptions& options,
333              WorkerCacheInterface** worker_cache) {
334         return WorkerCacheFactory(options, worker_cache);
335       };
336 
337   // Provide direct access to the master from in-process clients.
338   LocalMaster::Register(target(), master_impl_.get(),
339                         config.operation_timeout_in_ms());
340 
341   return OkStatus();
342 }
343 
ParseChannelSpec(const WorkerCacheFactoryOptions & options,GrpcChannelSpec * channel_spec)344 Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
345                                     GrpcChannelSpec* channel_spec) {
346   for (const auto& job : options.cluster_def->job()) {
347     std::map<int, string> host_ports;
348     for (const auto& task : job.tasks()) {
349       string& host_port = host_ports[task.first];
350       if (!host_port.empty()) {
351         return errors::InvalidArgument("JobDef for job \"", job.name(),
352                                        "\" specified two addresses for task \"",
353                                        task.first, "\": ", host_port, " and ",
354                                        task.second);
355       }
356       if (job.name() == *options.job_name && task.first == options.task_index) {
357         host_port = strings::StrCat(host_name_, ":", bound_port_);
358       } else {
359         host_port = task.second;
360       }
361     }
362     TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
363   }
364   return OkStatus();
365 }
366 
WorkerCacheFactory(const WorkerCacheFactoryOptions & options,WorkerCacheInterface ** worker_cache)367 Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
368                                       WorkerCacheInterface** worker_cache) {
369   if (options.job_name == nullptr || options.job_name->empty()) {
370     Status s = errors::InvalidArgument(
371         "The master (current machine) is not included in the provided "
372         "cluster_def. ",
373         options.cluster_def->DebugString());
374     LOG(WARNING) << s;
375     return s;
376   }
377 
378   GrpcChannelSpec channel_spec;
379   TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
380 
381   if (options.rpc_options == nullptr) {
382     return errors::InvalidArgument(
383         "rpc_options not set in WorkerCacheFactoryOptions");
384   }
385   std::shared_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
386       channel_spec, GetChannelCreationFunction(), *options.rpc_options));
387 
388   string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
389                                        "/task:", options.task_index);
390 
391   const string host_port = channel_cache->TranslateTask(name_prefix);
392   int requested_port;
393 
394   auto colon_index = host_port.find_last_of(':');
395   if (!strings::safe_strto32(host_port.substr(colon_index + 1),
396                              &requested_port)) {
397     return errors::Internal("Could not parse port for local server from \"",
398                             host_port, "\".");
399   }
400   if (requested_port != bound_port_) {
401     return errors::InvalidArgument("Requested port ", requested_port,
402                                    " differs from expected port ", bound_port_);
403   }
404   *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
405       channel_cache, grpc_worker_env(), worker_impl(), name_prefix);
406   return OkStatus();
407 }
408 
Start()409 Status GrpcServer::Start() {
410   mutex_lock l(mu_);
411   switch (state_) {
412     case NEW: {
413       master_thread_.reset(
414           env_->StartThread(ThreadOptions(), "TF_master_service",
415                             [this] { master_service_->HandleRPCsLoop(); }));
416       worker_thread_.reset(
417           env_->StartThread(ThreadOptions(), "TF_worker_service",
418                             [this] { worker_service_->HandleRPCsLoop(); }));
419       eager_thread_.reset(
420           env_->StartThread(ThreadOptions(), "TF_eager_service",
421                             [this] { eager_service_->HandleRPCsLoop(); }));
422       coordination_thread_.reset(env_->StartThread(
423           ThreadOptions(), "TF_coordination_service",
424           [this] { coordination_service_->HandleRPCsLoop(); }));
425 
426       for (const auto& kv : extra_services_) {
427         const std::string& service_name = kv.first;
428         AsyncServiceInterface* service = kv.second;
429         std::unique_ptr<Thread> extra_service_thread;
430         extra_service_thread.reset(env_->StartThread(
431             ThreadOptions(), service_name,
432             [service = service] { service->HandleRPCsLoop(); }));
433         extra_service_threads_.push_back(std::move(extra_service_thread));
434         VLOG(3) << "Started extra service: " << service_name;
435       }
436 
437       state_ = STARTED;
438       LOG(INFO) << "Started server with target: " << target();
439       return OkStatus();
440     }
441     case STARTED:
442       LOG(INFO) << "Server already started (target: " << target() << ")";
443       return OkStatus();
444     case STOPPED:
445       return errors::FailedPrecondition("Server has stopped.");
446     default:
447       LOG(FATAL);
448   }
449 }
450 
AddMasterEagerContextToEagerService(const tensorflow::uint64 context_id,tensorflow::EagerContext * context)451 Status GrpcServer::AddMasterEagerContextToEagerService(
452     const tensorflow::uint64 context_id, tensorflow::EagerContext* context) {
453   auto* eager_service =
454       static_cast<eager::GrpcEagerServiceImpl*>(eager_service_);
455   return eager_service->CreateMasterContext(context_id, context);
456 }
457 
UpdateServerDef(const ServerDef & server_def)458 Status GrpcServer::UpdateServerDef(const ServerDef& server_def) {
459   mutex_lock l(mu_);
460   server_def_ = server_def;
461   WorkerCacheInterface* worker_cache;
462   WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
463   TF_RETURN_IF_ERROR(
464       WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
465   if (worker_cache == nullptr) {
466     return errors::InvalidArgument(
467         "Failed to build worker cache with the provided server def.");
468   }
469   // Transfer ownership of worker_cache to worker_env_.session_mgr.
470   worker_env_.session_mgr->ResetDefaultWorkerCache(worker_cache);
471 
472   string default_worker_name;
473   string unused;
474   if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
475                                         &default_worker_name, &unused)) {
476     return errors::Internal("Could not parse worker name.");
477   }
478   worker_env_.collective_executor_mgr = CreateProdRpcCollectiveExecutorMgr(
479       server_def_.default_session_config(), worker_env_.device_mgr,
480       MaybeCreateNcclCommunicator(server_def_.default_session_config()),
481       worker_cache, default_worker_name);
482 
483   master_env_.worker_cache = worker_cache;
484   master_env_.collective_executor_mgr =
485       worker_env_.collective_executor_mgr.get();
486   return OkStatus();
487 }
488 
489 // TODO(haoyuzhang): Remove this method once we have a mechanism to directly set
490 // field inside the RPC coordination service handler.
SetCoordinationServiceAgentInstance(CoordinationServiceAgent * agent)491 Status GrpcServer::SetCoordinationServiceAgentInstance(
492     CoordinationServiceAgent* agent) {
493   auto* coord_service =
494       static_cast<GrpcCoordinationServiceImpl*>(coordination_service_);
495   coord_service->SetCoordinationServiceAgentInstance(agent);
496   return OkStatus();
497 }
498 
StopCoordinationService()499 Status GrpcServer::StopCoordinationService() {
500   // Note: the sequence of events is important here.
501   // 1. Agent must be torn down before the service as it needs to notify the
502   // service.
503   // 2. Remove RPC handlers' access to agent/service first before destructing
504   // them within the session manager to prevent data races.
505   TF_RETURN_IF_ERROR(SetCoordinationServiceAgentInstance(nullptr));
506   worker_env()->session_mgr->TeardownCoordinationServiceAgent();
507   coordination_service_->Shutdown();
508   worker_env()->session_mgr->TeardownCoordinationService();
509   return OkStatus();
510 }
511 
Stop()512 Status GrpcServer::Stop() {
513   mutex_lock l(mu_);
514   switch (state_) {
515     case NEW:
516       state_ = STOPPED;
517       return OkStatus();
518     case STARTED:
519       return errors::Unimplemented(
520           "Clean shutdown is not currently implemented");
521     case STOPPED:
522       LOG(INFO) << "Server already stopped (target: " << target() << ")";
523       return OkStatus();
524     default:
525       LOG(FATAL);
526   }
527 }
528 
Join()529 Status GrpcServer::Join() {
530   mutex_lock l(mu_);
531   switch (state_) {
532     case NEW:
533       // Prevent the server from being started subsequently.
534       state_ = STOPPED;
535       return OkStatus();
536     case STARTED:
537     case STOPPED:
538       master_thread_.reset();
539       worker_thread_.reset();
540       eager_thread_.reset();
541       for (auto& thread : extra_service_threads_) {
542         thread.reset();
543       }
544       return OkStatus();
545     default:
546       LOG(FATAL);
547   }
548 }
549 
target() const550 const string GrpcServer::target() const {
551   return strings::StrCat("grpc://", host_name_, ":", bound_port_);
552 }
553 
GetServerCredentials(const ServerDef & server_def) const554 std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
555     const ServerDef& server_def) const {
556   return ::grpc::InsecureServerCredentials();
557 }
558 
GetChannelCreationFunction() const559 ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
560   // We can do this because SparseGrpcChannelCache is robust to nullptr being
561   // returned by the channel creation function
562   return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
563 }
564 
CreateMaster(MasterEnv * master_env)565 std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) {
566   return std::unique_ptr<Master>(new Master(master_env, 0.0));
567 }
568 
569 /* static */
Create(const ServerDef & server_def,Env * env,DeviceMgr * local_device_mgr,std::unique_ptr<ServerInterface> * out_server)570 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
571                           DeviceMgr* local_device_mgr,
572                           std::unique_ptr<ServerInterface>* out_server) {
573   std::unique_ptr<GrpcServer> ret(
574       new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
575   GrpcServerOptions options;
576   options.rendezvous_mgr_func = NewRpcRendezvousMgr;
577   options.local_device_mgr = local_device_mgr;
578   Status s = ret->Init(options);
579   if (!s.ok()) {
580     LOG(ERROR) << s;
581     return s;
582   }
583   *out_server = std::move(ret);
584   return OkStatus();
585 }
586 
587 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<ServerInterface> * out_server)588 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
589                           std::unique_ptr<ServerInterface>* out_server) {
590   return Create(server_def, env, nullptr, out_server);
591 }
592 
593 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<GrpcServer> * out_server)594 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
595                           std::unique_ptr<GrpcServer>* out_server) {
596   std::unique_ptr<ServerInterface> server;
597   Status s = Create(server_def, env, nullptr, &server);
598   if (!s.ok()) {
599     return s;
600   }
601   out_server->reset(dynamic_cast<GrpcServer*>(server.release()));
602   return OkStatus();
603 }
604 
605 namespace {
606 
607 class GrpcServerFactory : public ServerFactory {
608  public:
AcceptsOptions(const ServerDef & server_def)609   bool AcceptsOptions(const ServerDef& server_def) override {
610     return server_def.protocol() == "grpc";
611   }
612 
NewServer(const ServerDef & server_def,const Options & options,std::unique_ptr<ServerInterface> * out_server)613   Status NewServer(const ServerDef& server_def, const Options& options,
614                    std::unique_ptr<ServerInterface>* out_server) override {
615     return GrpcServer::Create(server_def, Env::Default(),
616                               options.local_device_mgr, out_server);
617   }
618 };
619 
620 // Registers a `ServerFactory` for `GrpcServer` instances.
621 class GrpcServerRegistrar {
622  public:
GrpcServerRegistrar()623   GrpcServerRegistrar() {
624     ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory());
625   }
626 };
627 static GrpcServerRegistrar registrar;
628 
629 }  // namespace
630 }  // namespace tensorflow
631