1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
17
18 #include <cstring>
19 #include <limits>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "grpcpp/grpcpp.h"
26 #include "grpcpp/security/credentials.h"
27 #include "grpcpp/server_builder.h"
28 #include "tensorflow/core/common_runtime/device_factory.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/process_util.h"
31 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
32 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
33 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
34 #include "tensorflow/core/distributed_runtime/local_master.h"
35 #include "tensorflow/core/distributed_runtime/master.h"
36 #include "tensorflow/core/distributed_runtime/master_env.h"
37 #include "tensorflow/core/distributed_runtime/master_session.h"
38 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
39 #include "tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h"
40 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h"
41 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
42 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
43 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
44 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
45 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
46 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
47 #include "tensorflow/core/distributed_runtime/server_lib.h"
48 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
49 #include "tensorflow/core/distributed_runtime/worker_env.h"
50 #include "tensorflow/core/framework/op.h"
51 #include "tensorflow/core/lib/core/errors.h"
52 #include "tensorflow/core/lib/strings/strcat.h"
53 #include "tensorflow/core/nccl/collective_communicator.h"
54 #include "tensorflow/core/platform/cpu_info.h"
55 #include "tensorflow/core/platform/env.h"
56 #include "tensorflow/core/platform/errors.h"
57 #include "tensorflow/core/platform/mem.h"
58 #include "tensorflow/core/platform/mutex.h"
59 #include "tensorflow/core/platform/threadpool.h"
60 #include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
61 #include "tensorflow/core/public/session_options.h"
62 #include "tensorflow/core/util/env_var.h"
63
64 namespace tensorflow {
65
66 namespace {
67
68 // Define an option subclass in order to disable SO_REUSEPORT for the
69 // server socket.
70 class NoReusePortOption : public ::grpc::ServerBuilderOption {
71 public:
UpdateArguments(::grpc::ChannelArguments * args)72 void UpdateArguments(::grpc::ChannelArguments* args) override {
73 args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
74 }
75
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)76 void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
77 plugins) override {}
78 };
79
80 // Define an option subclass in order to enable SO_REUSEPORT for the
81 // server socket.
82 class ReusePortOption : public ::grpc::ServerBuilderOption {
83 public:
UpdateArguments(::grpc::ChannelArguments * args)84 void UpdateArguments(::grpc::ChannelArguments* args) override {
85 args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 1);
86 }
87
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)88 void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
89 plugins) override {}
90 };
91
92 // static utility function
NewRpcRendezvousMgr(const WorkerEnv * env)93 RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
94 return new RpcRendezvousMgr(env);
95 }
96
97 } // namespace
98
// Stores the environment and a copy of `server_def`, and puts the server in
// the NEW state. The constructor body is empty: no devices, services or ports
// are set up here.
GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
    : env_(env), state_(NEW), server_def_(server_def) {}
101
~GrpcServer()102 GrpcServer::~GrpcServer() {
103 TF_CHECK_OK(Stop());
104 TF_CHECK_OK(Join());
105
106 delete master_service_;
107 delete worker_service_;
108 delete eager_service_;
109
110 for (auto& kv : extra_services_) {
111 AsyncServiceInterface* service = kv.second;
112 delete service;
113 }
114
115 // TODO(mrry): Refactor the *Env classes so that it is less fiddly
116 // to destroy them.
117
118 // Shut down all outstanding rendezvous.
119 delete worker_env_.rendezvous_mgr;
120
121 // We must delete graph_mgr before device_mgr, due to shared
122 // ownership of OpKernels in the executors. (The graph_mgr will
123 // free all stateless OpKernels, and pass over borrowed stateful
124 // OpKernels, which are also held in their respective devices'
125 // OpSegments.)
126 if (worker_env_.session_mgr != nullptr) {
127 delete worker_env_.session_mgr; // Deletes graph_mgr's.
128 }
129
130 // Do not delete (as these are not owned by the server):
131 // - master_env_.env
132 // - worker_env_.env
133 // - worker_env_.compute_pool
134 }
135
136 // Look up the requested host name and port for this task in `server_def`.
GetHostAndPort(const ServerDef & server_def,string * host_name,int * port) const137 Status GrpcServer::GetHostAndPort(const ServerDef& server_def,
138 string* host_name, int* port) const {
139 *port = -1;
140 *host_name = "localhost";
141 for (const auto& job : server_def.cluster().job()) {
142 if (job.name() == server_def.job_name()) {
143 auto iter = job.tasks().find(server_def.task_index());
144 if (iter == job.tasks().end()) {
145 return errors::Internal("Task ", server_def.task_index(),
146 " was not defined in job \"",
147 server_def.job_name(), "\"");
148 }
149
150 if (server_def.port() != 0) {
151 *port = server_def.port();
152 } else {
153 auto colon_index = iter->second.find_last_of(':');
154 if (!strings::safe_strto32(iter->second.substr(colon_index + 1),
155 port)) {
156 return errors::InvalidArgument(
157 "Could not parse port for local server from \"", iter->second,
158 "\".");
159 }
160
161 if (colon_index != string::npos &&
162 !iter->second.substr(0, colon_index).empty()) {
163 *host_name = iter->second.substr(0, colon_index);
164 }
165 }
166 break;
167 }
168 }
169 if (*port == -1) {
170 return errors::Internal("Job \"", server_def.job_name(),
171 "\" was not defined in cluster");
172 }
173
174 return OkStatus();
175 }
176
Init(const GrpcServerOptions & opts)177 Status GrpcServer::Init(const GrpcServerOptions& opts) {
178 mutex_lock l(mu_);
179 CHECK_EQ(state_, NEW);
180 master_env_.env = env_;
181 worker_env_.env = env_;
182
183 // Check parameters before DeviceFactory::AddDevices,
184 // otherwise if 'task_index=-1' the program will abort.
185
186 int requested_port;
187 TF_RETURN_IF_ERROR(GetHostAndPort(server_def_, &host_name_, &requested_port));
188
189 SessionOptions sess_opts;
190 VLOG(3) << "Grpc Server Init Definition: " << server_def_.DebugString();
191 ConfigProto config = server_def_.default_session_config();
192 sess_opts.config = config;
193
194 // Configure shared devices between master and worker.
195 string name_prefix =
196 strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
197 "/task:", server_def_.task_index());
198 if (opts.local_device_mgr == nullptr) {
199 std::vector<std::unique_ptr<Device>> devices;
200 TF_RETURN_IF_ERROR(
201 DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
202 worker_env_.device_mgr = new DynamicDeviceMgr(std::move(devices));
203 owned_device_manager_.reset(worker_env_.device_mgr);
204 } else {
205 worker_env_.device_mgr = opts.local_device_mgr;
206 owned_device_manager_.reset(nullptr);
207 }
208 worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
209 master_env_.local_devices = worker_env_.device_mgr->ListDevices();
210 worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr
211 ? new RpcRendezvousMgr(&worker_env_)
212 : opts.rendezvous_mgr_func(&worker_env_);
213 string unused;
214 string default_worker_name;
215 if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
216 &default_worker_name, &unused)) {
217 return errors::Internal("Could not parse worker name.");
218 }
219
220 // N.B. The order of initialization here is intricate, because we
221 // wish to allow `requested_port == 0` (for choosing any port,
222 // mostly for testing). Therefore, the construction of the channel
223 // and worker caches depends on `bound_port_`, which is not set
224 // until we call `builder.BuildAndStart()`. We must create the
225 // service objects before calling `builder.BuildAndStart()`, but
226 // `master_env_` and `worker_env_` are only partially
227 // configured. However, this is not dangerous, because we do not
228 // start serving requests until `this->Start()` is called, which
229 // happens after this method returns.
230 //
231 // TODO(mrry): Provide a general mechanism for dynamically setting
232 // the identities of tasks in the worker pool after the service is
233 // running.
234 ::grpc::ServerBuilder builder;
235 builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port),
236 GetServerCredentials(server_def_), &bound_port_);
237 builder.SetMaxMessageSize(std::numeric_limits<int32>::max());
238
239 bool reuse_port = false;
240 const Status status =
241 ReadBoolFromEnvVar("TF_GRPC_REUSE_PORT", false, &reuse_port);
242 if (!status.ok()) {
243 LOG(ERROR) << status.error_message();
244 }
245 auto server_build_option =
246 reuse_port
247 ? std::unique_ptr<::grpc::ServerBuilderOption>(new ReusePortOption)
248 : std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption);
249 builder.SetOption(std::move(server_build_option));
250
251 // Allow subclasses to specify more args to pass to the gRPC server.
252 MaybeMutateBuilder(&builder, requested_port);
253 master_impl_ = CreateMaster(&master_env_);
254 master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder);
255 worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
256 : NewGrpcWorker(&worker_env_, config);
257 worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
258 opts.worker_service_options)
259 .release();
260 eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);
261 thread::ThreadPool* compute_pool = ComputePool(sess_opts);
262 coordination_service_ =
263 new GrpcCoordinationServiceImpl(compute_pool, &builder);
264
265 profiler_service_ = profiler::CreateProfilerService();
266 builder.RegisterService(profiler_service_.get());
267
268 // Add any extra services to be started.
269 extra_services_ = ExtraServices(&builder);
270
271 // extra service:
272 if (opts.service_func != nullptr) {
273 opts.service_func(&worker_env_, &builder);
274 }
275 server_ = builder.BuildAndStart();
276
277 if (!server_) {
278 return errors::Unknown("Could not start gRPC server");
279 }
280 // Create the execution environment for the GRPC workers cache.
281 grpc_worker_env_.reset(CreateGrpcWorkerEnv());
282
283 WorkerCacheInterface* worker_cache;
284 WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
285 TF_RETURN_IF_ERROR(
286 WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
287 CHECK_NE(nullptr, worker_cache);
288
289 if (opts.collective_mgr_func) {
290 worker_env_.collective_executor_mgr.reset(
291 opts.collective_mgr_func(config, &worker_env_, worker_cache));
292 if (worker_env_.collective_executor_mgr == nullptr) {
293 return errors::Internal(
294 "collective_mgr_func did not return CollectiveExecutorMgr");
295 }
296 } else {
297 worker_env_.collective_executor_mgr = CreateProdRpcCollectiveExecutorMgr(
298 config, worker_env_.device_mgr, MaybeCreateNcclCommunicator(config),
299 worker_cache, default_worker_name);
300 }
301
302 // Set up worker environment.
303 worker_env_.session_mgr = new SessionMgr(
304 &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
305 std::unique_ptr<WorkerCacheInterface>(worker_cache),
306 [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
307 WorkerCacheFactoryOptions options(server_def);
308 return WorkerCacheFactory(options, worker_cache);
309 });
310 worker_env_.compute_pool = compute_pool;
311
312 // Finish setting up master environment.
313 master_env_.ops = OpRegistry::Global();
314 master_env_.worker_cache = worker_cache;
315 master_env_.collective_executor_mgr =
316 worker_env_.collective_executor_mgr.get();
317 StatsPublisherFactory stats_factory = opts.stats_factory;
318 master_env_.master_session_factory =
319 [config, stats_factory](
320 SessionOptions options, const MasterEnv* env,
321 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
322 std::unique_ptr<WorkerCacheInterface> worker_cache,
323 std::unique_ptr<DeviceSet> device_set,
324 std::vector<string> filtered_worker_list) {
325 options.config.MergeFrom(config);
326 return new MasterSession(options, env, std::move(remote_devs),
327 std::move(worker_cache), std::move(device_set),
328 std::move(filtered_worker_list),
329 stats_factory);
330 };
331 master_env_.worker_cache_factory =
332 [this](const WorkerCacheFactoryOptions& options,
333 WorkerCacheInterface** worker_cache) {
334 return WorkerCacheFactory(options, worker_cache);
335 };
336
337 // Provide direct access to the master from in-process clients.
338 LocalMaster::Register(target(), master_impl_.get(),
339 config.operation_timeout_in_ms());
340
341 return OkStatus();
342 }
343
ParseChannelSpec(const WorkerCacheFactoryOptions & options,GrpcChannelSpec * channel_spec)344 Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
345 GrpcChannelSpec* channel_spec) {
346 for (const auto& job : options.cluster_def->job()) {
347 std::map<int, string> host_ports;
348 for (const auto& task : job.tasks()) {
349 string& host_port = host_ports[task.first];
350 if (!host_port.empty()) {
351 return errors::InvalidArgument("JobDef for job \"", job.name(),
352 "\" specified two addresses for task \"",
353 task.first, "\": ", host_port, " and ",
354 task.second);
355 }
356 if (job.name() == *options.job_name && task.first == options.task_index) {
357 host_port = strings::StrCat(host_name_, ":", bound_port_);
358 } else {
359 host_port = task.second;
360 }
361 }
362 TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
363 }
364 return OkStatus();
365 }
366
WorkerCacheFactory(const WorkerCacheFactoryOptions & options,WorkerCacheInterface ** worker_cache)367 Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
368 WorkerCacheInterface** worker_cache) {
369 if (options.job_name == nullptr || options.job_name->empty()) {
370 Status s = errors::InvalidArgument(
371 "The master (current machine) is not included in the provided "
372 "cluster_def. ",
373 options.cluster_def->DebugString());
374 LOG(WARNING) << s;
375 return s;
376 }
377
378 GrpcChannelSpec channel_spec;
379 TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
380
381 if (options.rpc_options == nullptr) {
382 return errors::InvalidArgument(
383 "rpc_options not set in WorkerCacheFactoryOptions");
384 }
385 std::shared_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
386 channel_spec, GetChannelCreationFunction(), *options.rpc_options));
387
388 string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
389 "/task:", options.task_index);
390
391 const string host_port = channel_cache->TranslateTask(name_prefix);
392 int requested_port;
393
394 auto colon_index = host_port.find_last_of(':');
395 if (!strings::safe_strto32(host_port.substr(colon_index + 1),
396 &requested_port)) {
397 return errors::Internal("Could not parse port for local server from \"",
398 host_port, "\".");
399 }
400 if (requested_port != bound_port_) {
401 return errors::InvalidArgument("Requested port ", requested_port,
402 " differs from expected port ", bound_port_);
403 }
404 *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
405 channel_cache, grpc_worker_env(), worker_impl(), name_prefix);
406 return OkStatus();
407 }
408
Start()409 Status GrpcServer::Start() {
410 mutex_lock l(mu_);
411 switch (state_) {
412 case NEW: {
413 master_thread_.reset(
414 env_->StartThread(ThreadOptions(), "TF_master_service",
415 [this] { master_service_->HandleRPCsLoop(); }));
416 worker_thread_.reset(
417 env_->StartThread(ThreadOptions(), "TF_worker_service",
418 [this] { worker_service_->HandleRPCsLoop(); }));
419 eager_thread_.reset(
420 env_->StartThread(ThreadOptions(), "TF_eager_service",
421 [this] { eager_service_->HandleRPCsLoop(); }));
422 coordination_thread_.reset(env_->StartThread(
423 ThreadOptions(), "TF_coordination_service",
424 [this] { coordination_service_->HandleRPCsLoop(); }));
425
426 for (const auto& kv : extra_services_) {
427 const std::string& service_name = kv.first;
428 AsyncServiceInterface* service = kv.second;
429 std::unique_ptr<Thread> extra_service_thread;
430 extra_service_thread.reset(env_->StartThread(
431 ThreadOptions(), service_name,
432 [service = service] { service->HandleRPCsLoop(); }));
433 extra_service_threads_.push_back(std::move(extra_service_thread));
434 VLOG(3) << "Started extra service: " << service_name;
435 }
436
437 state_ = STARTED;
438 LOG(INFO) << "Started server with target: " << target();
439 return OkStatus();
440 }
441 case STARTED:
442 LOG(INFO) << "Server already started (target: " << target() << ")";
443 return OkStatus();
444 case STOPPED:
445 return errors::FailedPrecondition("Server has stopped.");
446 default:
447 LOG(FATAL);
448 }
449 }
450
AddMasterEagerContextToEagerService(const tensorflow::uint64 context_id,tensorflow::EagerContext * context)451 Status GrpcServer::AddMasterEagerContextToEagerService(
452 const tensorflow::uint64 context_id, tensorflow::EagerContext* context) {
453 auto* eager_service =
454 static_cast<eager::GrpcEagerServiceImpl*>(eager_service_);
455 return eager_service->CreateMasterContext(context_id, context);
456 }
457
UpdateServerDef(const ServerDef & server_def)458 Status GrpcServer::UpdateServerDef(const ServerDef& server_def) {
459 mutex_lock l(mu_);
460 server_def_ = server_def;
461 WorkerCacheInterface* worker_cache;
462 WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
463 TF_RETURN_IF_ERROR(
464 WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
465 if (worker_cache == nullptr) {
466 return errors::InvalidArgument(
467 "Failed to build worker cache with the provided server def.");
468 }
469 // Transfer ownership of worker_cache to worker_env_.session_mgr.
470 worker_env_.session_mgr->ResetDefaultWorkerCache(worker_cache);
471
472 string default_worker_name;
473 string unused;
474 if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
475 &default_worker_name, &unused)) {
476 return errors::Internal("Could not parse worker name.");
477 }
478 worker_env_.collective_executor_mgr = CreateProdRpcCollectiveExecutorMgr(
479 server_def_.default_session_config(), worker_env_.device_mgr,
480 MaybeCreateNcclCommunicator(server_def_.default_session_config()),
481 worker_cache, default_worker_name);
482
483 master_env_.worker_cache = worker_cache;
484 master_env_.collective_executor_mgr =
485 worker_env_.collective_executor_mgr.get();
486 return OkStatus();
487 }
488
489 // TODO(haoyuzhang): Remove this method once we have a mechanism to directly set
490 // field inside the RPC coordination service handler.
SetCoordinationServiceAgentInstance(CoordinationServiceAgent * agent)491 Status GrpcServer::SetCoordinationServiceAgentInstance(
492 CoordinationServiceAgent* agent) {
493 auto* coord_service =
494 static_cast<GrpcCoordinationServiceImpl*>(coordination_service_);
495 coord_service->SetCoordinationServiceAgentInstance(agent);
496 return OkStatus();
497 }
498
StopCoordinationService()499 Status GrpcServer::StopCoordinationService() {
500 // Note: the sequence of events is important here.
501 // 1. Agent must be torn down before the service as it needs to notify the
502 // service.
503 // 2. Remove RPC handlers' access to agent/service first before destructing
504 // them within the session manager to prevent data races.
505 TF_RETURN_IF_ERROR(SetCoordinationServiceAgentInstance(nullptr));
506 worker_env()->session_mgr->TeardownCoordinationServiceAgent();
507 coordination_service_->Shutdown();
508 worker_env()->session_mgr->TeardownCoordinationService();
509 return OkStatus();
510 }
511
Stop()512 Status GrpcServer::Stop() {
513 mutex_lock l(mu_);
514 switch (state_) {
515 case NEW:
516 state_ = STOPPED;
517 return OkStatus();
518 case STARTED:
519 return errors::Unimplemented(
520 "Clean shutdown is not currently implemented");
521 case STOPPED:
522 LOG(INFO) << "Server already stopped (target: " << target() << ")";
523 return OkStatus();
524 default:
525 LOG(FATAL);
526 }
527 }
528
Join()529 Status GrpcServer::Join() {
530 mutex_lock l(mu_);
531 switch (state_) {
532 case NEW:
533 // Prevent the server from being started subsequently.
534 state_ = STOPPED;
535 return OkStatus();
536 case STARTED:
537 case STOPPED:
538 master_thread_.reset();
539 worker_thread_.reset();
540 eager_thread_.reset();
541 for (auto& thread : extra_service_threads_) {
542 thread.reset();
543 }
544 return OkStatus();
545 default:
546 LOG(FATAL);
547 }
548 }
549
target() const550 const string GrpcServer::target() const {
551 return strings::StrCat("grpc://", host_name_, ":", bound_port_);
552 }
553
GetServerCredentials(const ServerDef & server_def) const554 std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
555 const ServerDef& server_def) const {
556 return ::grpc::InsecureServerCredentials();
557 }
558
GetChannelCreationFunction() const559 ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
560 // We can do this because SparseGrpcChannelCache is robust to nullptr being
561 // returned by the channel creation function
562 return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
563 }
564
CreateMaster(MasterEnv * master_env)565 std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) {
566 return std::unique_ptr<Master>(new Master(master_env, 0.0));
567 }
568
569 /* static */
Create(const ServerDef & server_def,Env * env,DeviceMgr * local_device_mgr,std::unique_ptr<ServerInterface> * out_server)570 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
571 DeviceMgr* local_device_mgr,
572 std::unique_ptr<ServerInterface>* out_server) {
573 std::unique_ptr<GrpcServer> ret(
574 new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
575 GrpcServerOptions options;
576 options.rendezvous_mgr_func = NewRpcRendezvousMgr;
577 options.local_device_mgr = local_device_mgr;
578 Status s = ret->Init(options);
579 if (!s.ok()) {
580 LOG(ERROR) << s;
581 return s;
582 }
583 *out_server = std::move(ret);
584 return OkStatus();
585 }
586
587 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<ServerInterface> * out_server)588 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
589 std::unique_ptr<ServerInterface>* out_server) {
590 return Create(server_def, env, nullptr, out_server);
591 }
592
593 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<GrpcServer> * out_server)594 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
595 std::unique_ptr<GrpcServer>* out_server) {
596 std::unique_ptr<ServerInterface> server;
597 Status s = Create(server_def, env, nullptr, &server);
598 if (!s.ok()) {
599 return s;
600 }
601 out_server->reset(dynamic_cast<GrpcServer*>(server.release()));
602 return OkStatus();
603 }
604
605 namespace {
606
607 class GrpcServerFactory : public ServerFactory {
608 public:
AcceptsOptions(const ServerDef & server_def)609 bool AcceptsOptions(const ServerDef& server_def) override {
610 return server_def.protocol() == "grpc";
611 }
612
NewServer(const ServerDef & server_def,const Options & options,std::unique_ptr<ServerInterface> * out_server)613 Status NewServer(const ServerDef& server_def, const Options& options,
614 std::unique_ptr<ServerInterface>* out_server) override {
615 return GrpcServer::Create(server_def, Env::Default(),
616 options.local_device_mgr, out_server);
617 }
618 };
619
620 // Registers a `ServerFactory` for `GrpcServer` instances.
621 class GrpcServerRegistrar {
622 public:
GrpcServerRegistrar()623 GrpcServerRegistrar() {
624 ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory());
625 }
626 };
627 static GrpcServerRegistrar registrar;
628
629 } // namespace
630 } // namespace tensorflow
631