# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A component for running distributed TensorFlow."""

import copy
import json
import os
import threading
import time

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib


_thread_local = threading.local()


class _TaskType(object):
  PS = "ps"
  WORKER = "worker"
  CHIEF = "chief"
  EVALUATOR = "evaluator"
  CLIENT = "client"


# TODO(yuefengz): support another mode where the client colocates with one
# worker.
class CoordinatorMode(object):
  """Specify how distribute coordinator runs."""
  # The default mode where the distribute coordinator runs as a standalone
  # client and connects to remote servers for training. Each remote server can
  # run the distribute coordinator binary with `task_type` set correctly,
  # which then turns it into a standard server.
  STANDALONE_CLIENT = "standalone_client"

  # The distribute coordinator runs on each worker. It will run a standard
  # server on each worker and optionally run the `worker_fn` that is
  # configured to talk to its standard server.
  INDEPENDENT_WORKER = "independent_worker"


class _Barrier(object):
  """A reusable barrier class for worker synchronization."""

  def __init__(self, num_participants):
    """Initializes the barrier object.

    Args:
      num_participants: an integer which is the expected number of calls to
        `wait` passing through this barrier.
    """
    self._num_participants = num_participants
    self._counter = 0
    self._flag = False
    self._local_sense = threading.local()
    self._lock = threading.Lock()
    self._condition = threading.Condition()

  def wait(self):
    """Waits until all other callers reach the same wait call."""
    self._local_sense.value = not self._flag
    with self._lock:
      self._counter += 1
      if self._counter == self._num_participants:
        self._counter = 0
        self._flag = self._local_sense.value
    with self._condition:
      while self._flag != self._local_sense.value:
        self._condition.wait()
      self._condition.notify_all()


def _get_num_workers(cluster_spec):
  """Gets number of workers including chief."""
  if not cluster_spec:
    return 0
  return len(cluster_spec.as_dict().get(_TaskType.WORKER, [])) + len(
      cluster_spec.as_dict().get(_TaskType.CHIEF, []))


class _WorkerContext(object):
  """The worker context class.

  This context object provides configuration information for each task. One
  context manager with a worker context object will be created per invocation
  to the `worker_fn` where `get_current_worker_context` can be called to
  access the worker context object.
  """

  def __init__(self,
               strategy,
               cluster_spec,
               task_type,
               task_id,
               session_config=None,
               rpc_layer="grpc",
               worker_barrier=None):
    """Initialize the worker context object.

    Args:
      strategy: a `DistributionStrategy` object.
      cluster_spec: a ClusterSpec object. It can be empty or None in the local
        training case.
      task_type: a string indicating the role of the corresponding task, such
        as "worker" or "ps". It can be None if it is local training or
        in-graph replicated training.
      task_id: an integer indicating id of the corresponding task. It can be
        None if it is local training or in-graph replicated training.
      session_config: an optional `tf.compat.v1.ConfigProto` object.
      rpc_layer: optional string specifying the RPC protocol for communication
        with worker masters. If None or empty, hosts in the `cluster_spec`
        will be used directly.
      worker_barrier: optional, the barrier object for worker synchronization.
    """
    self._strategy = strategy
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
    self._session_config = session_config
    self._worker_barrier = worker_barrier
    self._rpc_layer = rpc_layer
    self._master_target = self._get_master_target()
    self._num_workers = _get_num_workers(cluster_spec)
    self._is_chief_node = self._is_chief()

  def _debug_message(self):
    if self._cluster_spec:
      return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
          self._cluster_spec, self.task_type, self.task_id)
    else:
      return "[local]"

  def __enter__(self):
    old_context = distribute_coordinator_context.get_current_worker_context()
    if old_context:
      raise ValueError(
          "You cannot run distribute coordinator in a `worker_fn`.\t" +
          self._debug_message())
    # pylint: disable=protected-access
    distribute_coordinator_context._worker_context.current = self

  def __exit__(self, unused_exception_type, unused_exception_value,
               unused_traceback):
    # pylint: disable=protected-access
    distribute_coordinator_context._worker_context.current = None

  def _get_master_target(self):
    """Return the master target for a task."""
    # If cluster_spec is None or empty, we use local master.
    if not self._cluster_spec or self._task_type == _TaskType.EVALUATOR:
      return ""

    # If task_type is None, then it is in-graph replicated training. In this
    # case we use the chief or first worker's master target.
    if not self._task_type:
      if _TaskType.CHIEF in self._cluster_spec.jobs:
        task_type = _TaskType.CHIEF
        task_id = 0
      else:
        assert _TaskType.WORKER in self._cluster_spec.jobs
        task_type = _TaskType.WORKER
        task_id = 0
    else:
      task_type = self._task_type
      task_id = self._task_id

    prefix = ""
    if self._rpc_layer:
      prefix = self._rpc_layer + "://"
    return prefix + self._cluster_spec.job_tasks(task_type)[task_id or 0]

  def _is_chief(self):
    """Return whether the task is the chief worker."""
    if (not self._cluster_spec or
        self._task_type in [_TaskType.CHIEF, _TaskType.EVALUATOR, None]):
      return True

    # If not local and chief not in the cluster_spec, use the first worker as
    # chief.
    if (_TaskType.CHIEF not in self._cluster_spec.jobs and
        self._task_type == _TaskType.WORKER and self._task_id == 0):
      return True
    return False

  def wait_for_other_workers(self):
    """Waits for other workers to reach the same call to this method.

    Raises:
      ValueError: if `worker_barrier` is not passed to the __init__ method.
    """
    if not self._worker_barrier:
      # TODO(yuefengz): we should throw an error in independent worker mode.
      return
    self._worker_barrier.wait()

  def session_creator(self,
                      scaffold=None,
                      config=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      max_wait_secs=7200):
    """Returns a session creator.

    The returned session creator will be configured with the correct master
    target and session configs. It will also run either init ops or ready ops
    by querying the `strategy` object when `create_session` is called on it.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the
        graph.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string. Optional path to a directory where to restore
        variables.
      checkpoint_filename_with_path: Full file name path to the checkpoint
        file. Only one of `checkpoint_dir` or `checkpoint_filename_with_path`
        can be specified.
      max_wait_secs: Maximum time to wait for the session to become available.

    Returns:
      a descendant of SessionCreator.
    """
    if config:
      session_config = copy.deepcopy(config)
      session_config.MergeFrom(self._session_config)
    else:
      session_config = self._session_config

    if not self._strategy or self._strategy.extended.experimental_should_init:
      logging.info("Creating chief session creator with config: %r", config)
      return monitored_session.ChiefSessionCreator(
          scaffold,
          master=self.master_target,
          config=session_config,
          checkpoint_dir=checkpoint_dir,
          checkpoint_filename_with_path=checkpoint_filename_with_path)
    else:
      logging.info("Creating worker session creator with config: %r", config)
      return monitored_session.WorkerSessionCreator(
          scaffold,
          master=self.master_target,
          config=session_config,
          max_wait_secs=max_wait_secs)

  @property
  def session_config(self):
    return copy.deepcopy(self._session_config)

  @property
  def has_barrier(self):
    """Whether the barrier is set or not."""
    return self._worker_barrier is not None

  @property
  def distributed_mode(self):
    """Whether it is distributed training or not."""
    return bool(self._cluster_spec) and self._task_type != _TaskType.EVALUATOR

  @property
  def cluster_spec(self):
    """Returns a copy of the cluster_spec object."""
    return copy.deepcopy(self._cluster_spec)

  @property
  def task_type(self):
    """Returns the role of the corresponding task."""
    return self._task_type

  @property
  def task_id(self):
    """Returns the id or index of the corresponding task."""
    return self._task_id

  @property
  def master_target(self):
    """Returns the session master for the corresponding task to connect to."""
    return self._master_target

  @property
  def is_chief(self):
    """Returns whether the task is a chief node."""
    return self._is_chief_node

  @property
  def num_workers(self):
    """Returns number of workers in the cluster, including chief."""
chief.""" 304 return self._num_workers 305 306 @property 307 def experimental_should_init(self): 308 """Whether to run init ops.""" 309 return self._strategy.extended.experimental_should_init 310 311 @property 312 def should_checkpoint(self): 313 """Whether to save checkpoint.""" 314 return self._strategy.extended.should_checkpoint 315 316 @property 317 def should_save_summary(self): 318 """Whether to save summaries.""" 319 return self._strategy.extended.should_save_summary 320 321 322def _run_single_worker(worker_fn, 323 strategy, 324 cluster_spec, 325 task_type, 326 task_id, 327 session_config, 328 rpc_layer="", 329 worker_barrier=None, 330 coord=None): 331 """Runs a single worker by calling `worker_fn` under context.""" 332 session_config = copy.deepcopy(session_config) 333 strategy = copy.deepcopy(strategy) 334 # If there is an EVALUATOR task, we run single-machine eval on that task. 335 if task_type == _TaskType.EVALUATOR: 336 # It is possible to not have a strategy object for EVALUATOR task. 337 if strategy: 338 strategy.configure(session_config) 339 else: 340 assert strategy 341 strategy.configure(session_config, cluster_spec, task_type, task_id) 342 343 context = _WorkerContext( 344 strategy, 345 cluster_spec, 346 task_type, 347 task_id, 348 session_config=session_config, 349 rpc_layer=rpc_layer, 350 worker_barrier=worker_barrier) 351 with context: 352 if coord: 353 with coord.stop_on_exception(): 354 return worker_fn(strategy) 355 else: 356 return worker_fn(strategy) 357 358 359def _split_cluster_for_evaluator(cluster_spec, task_type): 360 """Split the cluster for evaluator since it needn't talk to other tasks.""" 361 # Splitting the cluster is important to prevent the evaluator from talking to 362 # other tasks in the cluster. Since we allow evaluator not to use 363 # distribution strategies and as a result ops in the evaluator task may have 364 # unspecified devices. Those ops may end up on other tasks if we don't split 365 # the cluster. 366 # Note: if you bypass distribute coordinator and bring the cluster yourself, 367 # you can equivalently set device filters to split clusters. This is already 368 # done by distribution strategy's `update_config_proto` method. 369 new_cluster_spec = multi_worker_util.normalize_cluster_spec( 370 cluster_spec).as_dict() 371 if task_type == _TaskType.EVALUATOR: 372 assert _TaskType.EVALUATOR in new_cluster_spec 373 new_cluster_spec = { 374 _TaskType.EVALUATOR: new_cluster_spec[_TaskType.EVALUATOR] 375 } 376 else: 377 new_cluster_spec.pop(_TaskType.EVALUATOR, None) 378 return multi_worker_util.normalize_cluster_spec(new_cluster_spec) 379 380 381def _run_std_server(cluster_spec=None, 382 task_type=None, 383 task_id=None, 384 session_config=None, 385 rpc_layer=None, 386 environment=None): 387 """Runs a standard server.""" 388 # Check if the Server is already running. If so, assert that no configuration 389 # options have changed, and return the existing Server. This allows us to 390 # call `run_distribute_coordinator` multiple times. 391 if getattr(_thread_local, "server", None) is not None: 392 assert _thread_local.cluster_spec == cluster_spec 393 assert _thread_local.task_type == task_type 394 assert _thread_local.task_id == task_id 395 assert _thread_local.session_config_str == repr(session_config) 396 assert _thread_local.rpc_layer == rpc_layer 397 assert _thread_local.environment == environment 398 return _thread_local.server 399 else: 400 # This method is not thread-safe. 
    _thread_local.server_started = True
    _thread_local.cluster_spec = cluster_spec
    _thread_local.task_type = task_type
    _thread_local.task_id = task_id
    _thread_local.session_config_str = repr(session_config)
    _thread_local.rpc_layer = rpc_layer
    _thread_local.environment = environment

  assert cluster_spec
  target = cluster_spec.task_address(task_type, task_id)
  if rpc_layer:
    target = rpc_layer + "://" + target

  class _FakeServer(object):
    """A fake server that runs a master session."""

    def start(self):
      # A TensorFlow server starts when a remote session is created.
      logging.info(
          "Creating a remote session to start a TensorFlow server, "
          "target = %r, session_config=%r", target, session_config)
      session.Session(target=target, config=session_config)

    def join(self):
      while True:
        time.sleep(5)

  if environment == "google":
    server = _FakeServer()
  else:
    if session_config:
      logging.info(
          "Starting standard TensorFlow server, target = %r, session_config= "
          "%r", target, session_config)
    else:
      logging.info("Starting standard TensorFlow server, target = %r", target)
    cluster_spec = _split_cluster_for_evaluator(cluster_spec, task_type)
    server = server_lib.Server(
        cluster_spec,
        job_name=task_type,
        task_index=task_id,
        config=session_config,
        protocol=rpc_layer)

  server.start()
  _thread_local.server = server
  return server


def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                              cluster_spec, session_config, rpc_layer):
  """Runs a standalone client for between-graph replication."""
  coord = coordinator.Coordinator()
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
        args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
            "coord": coord,
        })
    eval_thread.start()

  threads = []
  worker_barrier = _Barrier(_get_num_workers(cluster_spec))
  for task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
    for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
      t = threading.Thread(
          target=_run_single_worker,
          args=(worker_fn, strategy, cluster_spec, task_type, task_id,
                session_config),
          kwargs={
              "rpc_layer": rpc_layer,
              "worker_barrier": worker_barrier,
              "coord": coord,
          })
      t.start()
      threads.append(t)

  if eval_thread:
    # TODO(yuefengz): is it necessary to join eval thread?
    threads_to_join = threads + [eval_thread]
  else:
    threads_to_join = threads
  coord.join(threads_to_join)

  # TODO(yuefengz): we probably want to return results from all workers?
  return None


def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                         cluster_spec, session_config, rpc_layer):
  """Runs a standalone client for in-graph replication."""
  coord = coordinator.Coordinator()
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
        args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
            "coord": coord,
        })
    eval_thread.start()

  worker_result = _run_single_worker(
      worker_fn,
      strategy,
      cluster_spec,
      None,
      None,
      session_config,
      rpc_layer=rpc_layer,
      coord=coord)

  if eval_thread:
    coord.join([eval_thread])

  return worker_result


def _configure_session_config_for_std_servers(
    strategy, eval_strategy, session_config, cluster_spec, task_type, task_id):
  # pylint: disable=g-doc-args
  """Call strategy's `configure` to mutate the session_config.

  The session_config is currently needed as the default config for a
  TensorFlow server. In the future, we should be able to remove this method
  and only pass the session config to a client session.
  """
  if task_type == _TaskType.EVALUATOR:
    if eval_strategy:
      eval_strategy.configure(session_config=session_config)
  else:
    # The strategy may be shared in standalone client mode.
    strategy = copy.deepcopy(strategy)
    strategy.configure(
        session_config=session_config,
        cluster_spec=cluster_spec,
        task_type=task_type,
        task_id=task_id)
  # Remove the device filters specific to the strategy, so that the
  # TensorFlow server brought up with one strategy can be used by other
  # strategies. The device filters can be set in the client side as well.
  del session_config.device_filters[:]


def run_standard_tensorflow_server(session_config=None):
  """Starts a standard TensorFlow server.

  This method parses configurations from the "TF_CONFIG" environment variable
  and starts a TensorFlow server. The "TF_CONFIG" is typically a json string
  and must contain information about the cluster and the role of the server
  in the cluster. One example is:

  TF_CONFIG='{
      "cluster": {
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "worker", "index": 1}
  }'

  This "TF_CONFIG" specifies that there are 3 workers and 2 ps tasks in the
  cluster and that the current role is worker 1.

  Valid task types are "chief", "worker", "ps" and "evaluator" and you can
  have at most one "chief" and at most one "evaluator".

  An optional key-value pair that can be specified is "rpc_layer"; the default
  value is "grpc".

  Args:
    session_config: an optional `tf.compat.v1.ConfigProto` object. Users can
      pass in the session config object to configure server-local devices.

  Returns:
    a `tf.distribute.Server` object which has already been started.

  Raises:
    ValueError: if the "TF_CONFIG" environment is not complete.
585 """ 586 tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) 587 if "cluster" not in tf_config: 588 raise ValueError("\"cluster\" is not found in TF_CONFIG.") 589 cluster_spec = multi_worker_util.normalize_cluster_spec(tf_config["cluster"]) 590 if "task" not in tf_config: 591 raise ValueError("\"task\" is not found in TF_CONFIG.") 592 task_env = tf_config["task"] 593 if "type" not in task_env: 594 raise ValueError( 595 "\"task_type\" is not found in the `task` part of TF_CONFIG.") 596 task_type = task_env["type"] 597 task_id = int(task_env.get("index", 0)) 598 599 rpc_layer = tf_config.get("rpc_layer", "grpc") 600 601 session_config = session_config or config_pb2.ConfigProto() 602 # Set the collective group leader for collective ops to initialize collective 603 # ops when server starts. 604 if "chief" in cluster_spec.jobs: 605 session_config.experimental.collective_group_leader = ( 606 "/job:chief/replica:0/task:0") 607 else: 608 if "worker" not in cluster_spec.jobs: 609 raise ValueError( 610 "You must have `chief` or `worker` jobs in the `cluster_spec`.") 611 session_config.experimental.collective_group_leader = ( 612 "/job:worker/replica:0/task:0") 613 614 server = _run_std_server( 615 cluster_spec=cluster_spec, 616 task_type=task_type, 617 task_id=task_id, 618 session_config=session_config, 619 rpc_layer=rpc_layer) 620 server.start() 621 return server 622 623 624# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode. 625# TODO(yuefengz): we may need a smart way to figure out whether the current task 626# is the special task when we support cluster_spec propagation. 627def run_distribute_coordinator(worker_fn, 628 strategy, 629 eval_fn=None, 630 eval_strategy=None, 631 mode=CoordinatorMode.STANDALONE_CLIENT, 632 cluster_spec=None, 633 task_type=None, 634 task_id=None, 635 session_config=None, 636 rpc_layer="grpc"): 637 """Runs the coordinator for distributed TensorFlow. 638 639 This function runs a split coordinator for distributed TensorFlow in its 640 default mode, i.e the STANDALONE_CLIENT mode. Given a `cluster_spec` 641 specifying server addresses and their roles in a cluster, this coordinator 642 will figure out how to set them up, give the underlying function the right 643 targets for master sessions via a scope object and coordinate their training. 644 The cluster consisting of standard servers needs to be brought up either with 645 the standard server binary or with a binary running distribute coordinator 646 with `task_type` set to non-client type which will then turn into standard 647 servers. 648 649 In addition to be the distribute coordinator, this is also the source of 650 configurations for each job in the distributed training. As there are multiple 651 ways to configure a distributed TensorFlow cluster, its context object 652 provides these configurations so that users or higher-level APIs don't have to 653 figure out the configuration for each job by themselves. 654 655 In the between-graph replicated training, this coordinator will create 656 multiple threads and each calls the `worker_fn` which is supposed to create 657 its own graph and connect to one worker master given by its context object. In 658 the in-graph replicated training, it has only one thread calling this 659 `worker_fn`. 660 661 Another mode is the INDEPENDENT_WORKER mode where each server runs a 662 distribute coordinator which will start a standard server and optionally runs 663 `worker_fn` depending whether it is between-graph training or in-graph 664 replicated training. 

  The `strategy` object is expected to be a DistributionStrategy object which
  has implemented methods needed by the distribute coordinator such as
  `configure(session_config, cluster_spec, task_type, task_id)`, which
  configures the strategy object for a specific task, and the
  `experimental_should_init` property, which instructs the distribute
  coordinator whether to run init ops for a task. The distribute coordinator
  will make a copy of the `strategy` object, call its `configure` method and
  pass it to `worker_fn` as an argument.

  The `worker_fn` defines the training logic and is called under its own
  worker context which can be accessed via `get_current_worker_context`. A
  worker context provides access to configurations for each task, e.g. the
  task_type, task_id, master target and so on. Since `worker_fn` will be
  called in a thread and possibly multiple times, callers should be careful
  when accessing global data. For example, it is unsafe to define flags in a
  `worker_fn` or to define different environment variables for different
  `worker_fn`s.

  The `worker_fn` for the between-graph replication is defined as if there is
  only one worker corresponding to the `worker_fn` and possibly ps jobs. For
  example, when training with parameter servers, it assigns variables to
  parameter servers and all other operations to that worker. In the in-graph
  replication case, the `worker_fn` has to define operations for all worker
  jobs. Using a distribution strategy can simplify the `worker_fn` by not
  having to worry about the replication and device assignment of variables
  and operations.

  This method is intended to be invoked by high-level APIs so that users don't
  have to explicitly call it to run this coordinator. For those who don't use
  high-level APIs, to change a program to use this coordinator, wrap
  everything in the program after global data definitions, such as command
  line flag definitions, into the `worker_fn` and get task-specific
  configurations from the worker context.

  The `cluster_spec` can be either passed by the argument or parsed from the
  "TF_CONFIG" environment variable. Example of a TF_CONFIG:
  ```
    cluster = {'chief': ['host0:2222'],
               'ps': ['host1:2222', 'host2:2222'],
               'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
    os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster})
  ```

  If `cluster_spec` is not given in any format, it becomes local training and
  this coordinator will connect to a local session.

  For evaluation, if "evaluator" exists in the cluster_spec, a separate thread
  will be created to call `eval_fn` with its `task_type` set to "evaluator".
  If `eval_fn` is not defined, fall back to `worker_fn`. This implies that
  evaluation will be done on a single machine if there is an "evaluator" task.
  If "evaluator" doesn't exist in the cluster_spec, it entirely depends on the
  `worker_fn` for how to do evaluation.

  Args:
    worker_fn: the function to be called. The function should accept a
      `strategy` object and will be given access to a context object via a
      context manager scope.
    strategy: a DistributionStrategy object specifying whether it should
      run between-graph replicated training or not, whether to run init ops,
      etc. This object will also be configured given `session_config`,
      `cluster_spec`, `task_type` and `task_id`.
    eval_fn: optional function for the "evaluator" task. If `eval_fn` is not
      passed in but an "evaluator" task is found in the `cluster_spec`, the
      `worker_fn` will be used for this task.
    eval_strategy: optional DistributionStrategy object for "evaluator" task.
    mode: in which mode this distribute coordinator runs.
    cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and
      roles in a cluster. If not set or empty, fall back to local training.
    task_type: the current task type, optional if this is a client.
    task_id: the current task id, optional if this is a client.
    session_config: an optional `tf.compat.v1.ConfigProto` object which will
      be passed to `strategy`'s `configure` method and used to create a
      session.
    rpc_layer: optional string, the protocol for RPC, e.g. "grpc".

  Raises:
    ValueError: if `cluster_spec` is supplied but not a dict or a ClusterDef
      or a ClusterSpec.

  Returns:
    In the client job, return the value returned by `worker_fn` if
    it is in-graph replication or INDEPENDENT_WORKER mode; return None
    otherwise.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  rpc_layer = tf_config.get("rpc_layer", rpc_layer)
  environment = tf_config.get("environment", None)

  if not cluster_spec:
    cluster_spec = tf_config.get("cluster", {})
    task_env = tf_config.get("task", {})
    if task_env:
      task_type = task_env.get("type", task_type)
      task_id = int(task_env.get("index", task_id))

  if cluster_spec:
    # TODO(yuefengz): validate cluster_spec.
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
  elif hasattr(strategy.extended, "_cluster_resolver"):
    cluster_resolver = strategy.extended._cluster_resolver  # pylint: disable=protected-access
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    rpc_layer = cluster_resolver.rpc_layer or rpc_layer
    environment = cluster_resolver.environment
    cluster_spec = cluster_resolver.cluster_spec()

  # Setting the session config is necessary for some strategies such as
  # CollectiveAllReduceStrategy.
  session_config = session_config or config_pb2.ConfigProto(
      allow_soft_placement=True)

  if cluster_spec:
    logging.info(
        "Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
        "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode,
        cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer)

  if not cluster_spec:
    # `mode` is ignored in the local case.
    logging.info("Running local Distribute Coordinator.")
    _run_single_worker(worker_fn, strategy, None, None, None, session_config,
                       rpc_layer)
    if eval_fn:
      _run_single_worker(eval_fn, eval_strategy, None, None, None,
                         session_config, rpc_layer)
    else:
      logging.warning("Skipped evaluation since `eval_fn` is not passed in.")
  elif mode == CoordinatorMode.STANDALONE_CLIENT:
    if not eval_fn:
      logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
                      "used if an \"evaluator\" task exists in the cluster.")
    eval_fn = eval_fn or worker_fn
    if not eval_strategy:
      logging.warning("`eval_strategy` is not passed in. No distribution "
                      "strategy will be used for evaluation.")

    # The client must know the cluster but servers in the cluster don't have
    # to know the client.
    if task_type in [_TaskType.CLIENT, None]:
      if strategy.extended.experimental_between_graph:
        return _run_between_graph_client(worker_fn, strategy, eval_fn,
                                         eval_strategy, cluster_spec,
                                         session_config, rpc_layer)
      else:
        return _run_in_graph_client(worker_fn, strategy, eval_fn,
                                    eval_strategy, cluster_spec,
                                    session_config, rpc_layer)
    else:
      # If not a client job, run the standard server.
      _configure_session_config_for_std_servers(strategy, eval_strategy,
                                                session_config, cluster_spec,
                                                task_type, task_id)
      server = _run_std_server(
          cluster_spec=cluster_spec,
          task_type=task_type,
          task_id=task_id,
          session_config=session_config,
          rpc_layer=rpc_layer,
          environment=environment)
      server.join()
  else:
    if mode != CoordinatorMode.INDEPENDENT_WORKER:
      raise ValueError("Unexpected coordinator mode: %r" % mode)

    if not eval_fn:
      logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
                      "used if an \"evaluator\" task exists in the cluster.")
    eval_fn = eval_fn or worker_fn
    if not eval_strategy:
      logging.warning("`eval_strategy` is not passed in. No distribution "
                      "strategy will be used for evaluation.")

    # Everyone starts a standard server and gets the session config from the
    # `configure` method.
    _configure_session_config_for_std_servers(strategy, eval_strategy,
                                              session_config, cluster_spec,
                                              task_type, task_id)

    if (task_type != _TaskType.EVALUATOR and
        not getattr(strategy.extended, "_std_server_started", False)):
      # Right now, with eager mode, context is configured with a std server at
      # the very beginning while with graph mode the std server is started
      # when distribute coordinator is called. We should consolidate these two
      # paths.
      server = _run_std_server(
          cluster_spec=cluster_spec,
          task_type=task_type,
          task_id=task_id,
          session_config=session_config,
          rpc_layer=rpc_layer,
          environment=environment)
    if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
      if strategy.extended.experimental_between_graph:
        # All jobs run `worker_fn` if between-graph.
        return _run_single_worker(worker_fn, strategy, cluster_spec, task_type,
                                  task_id, session_config, rpc_layer)
      else:
        # Only one node runs `worker_fn` if in-graph.
        context = _WorkerContext(strategy, cluster_spec, task_type, task_id)
        if context.is_chief:
          return _run_single_worker(worker_fn, strategy, cluster_spec, None,
                                    None, session_config, rpc_layer)
        else:
          server.join()
    elif task_type == _TaskType.EVALUATOR:
      return _run_single_worker(eval_fn, eval_strategy, cluster_spec,
                                task_type, task_id, session_config, rpc_layer)
    else:
      if task_type != _TaskType.PS:
        raise ValueError("Unexpected task_type: %r" % task_type)
      server.join()
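

# ------------------------------------------------------------------------------
# The sketch below is an illustrative, hypothetical usage example added for
# documentation purposes only; it is not part of the original module and is
# never called by library code. It assumes the caller already has a
# `DistributionStrategy` object and a reachable cluster, and simply mirrors
# the usage described in the `run_distribute_coordinator` docstring above.
def _example_standalone_client_usage(strategy, cluster_spec):
  """Hypothetical sketch of driving `run_distribute_coordinator`."""

  def worker_fn(strategy):
    # Each invocation runs under a worker context; task-specific
    # configuration such as the master target and task identity is read from
    # that context rather than from global state.
    context = distribute_coordinator_context.get_current_worker_context()
    logging.info("Running worker_fn for task %r/%r with master %r",
                 context.task_type, context.task_id, context.master_target)
    # A session creator wired to the right master target could be obtained
    # from the context, e.g.:
    #   with monitored_session.MonitoredSession(
    #       session_creator=context.session_creator()) as sess:
    #     ...  # run training ops here
    return None

  # In STANDALONE_CLIENT mode this client coordinates the remote standard
  # servers described by `cluster_spec` (a dict, ClusterDef or ClusterSpec).
  return run_distribute_coordinator(
      worker_fn,
      strategy,
      mode=CoordinatorMode.STANDALONE_CLIENT,
      cluster_spec=cluster_spec)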