# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A component for running distributed TensorFlow."""

import copy
import json
import os
import threading
import time

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib


_thread_local = threading.local()


class _TaskType(object):
  PS = "ps"
  WORKER = "worker"
  CHIEF = "chief"
  EVALUATOR = "evaluator"
  CLIENT = "client"


# TODO(yuefengz): support another mode where the client colocates with one
# worker.
class CoordinatorMode(object):
  """Specifies how the distribute coordinator runs."""
  # The default mode, where the distribute coordinator runs as a standalone
  # client and connects to remote servers for training. Each remote server can
  # run the distribute coordinator binary with `task_type` set correctly,
  # which then turns it into a standard server.
  STANDALONE_CLIENT = "standalone_client"

  # The distribute coordinator runs on each worker. It runs a standard server
  # on each worker and optionally runs the `worker_fn` that is configured to
  # talk to its standard server.
  INDEPENDENT_WORKER = "independent_worker"

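# A minimal sketch of how the two modes are typically launched (illustrative
# only; the cluster dict, `my_worker_fn` and `my_strategy` below are
# hypothetical):
#
#   cluster = {"chief": ["host0:2222"], "worker": ["host1:2222"]}
#
#   # STANDALONE_CLIENT: a single client process drives the remote servers.
#   run_distribute_coordinator(my_worker_fn, my_strategy,
#                              mode=CoordinatorMode.STANDALONE_CLIENT,
#                              cluster_spec=cluster)
#
#   # INDEPENDENT_WORKER: every task runs the coordinator itself, with its
#   # task_type/task_id typically parsed from the TF_CONFIG env variable.
#   run_distribute_coordinator(my_worker_fn, my_strategy,
#                              mode=CoordinatorMode.INDEPENDENT_WORKER,
#                              cluster_spec=cluster, task_type="worker",
#                              task_id=0)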

class _Barrier(object):
  """A reusable barrier class for worker synchronization."""

  def __init__(self, num_participants):
    """Initializes the barrier object.

    Args:
      num_participants: an integer, the expected number of calls to `wait`
        that pass through this barrier.
    """
    self._num_participants = num_participants
    self._counter = 0
    self._flag = False
    self._local_sense = threading.local()
    self._lock = threading.Lock()
    self._condition = threading.Condition()

  def wait(self):
    """Waits until all other callers reach the same wait call."""
    self._local_sense.value = not self._flag
    with self._lock:
      self._counter += 1
      if self._counter == self._num_participants:
        self._counter = 0
        self._flag = self._local_sense.value
    with self._condition:
      while self._flag != self._local_sense.value:
        self._condition.wait()
      self._condition.notify_all()

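# A minimal usage sketch of `_Barrier` (illustrative only; the thread body
# below is hypothetical):
#
#   barrier = _Barrier(num_participants=2)
#
#   def _thread_body():
#     # ... do some per-worker work ...
#     barrier.wait()  # Blocks until both participants reach this point.
#     # ... continue once everyone has arrived; the barrier is reusable.
#
#   threads = [threading.Thread(target=_thread_body) for _ in range(2)]
#   for t in threads:
#     t.start()
#   for t in threads:
#     t.join()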

def _get_num_workers(cluster_spec):
  """Gets number of workers including chief."""
  if not cluster_spec:
    return 0
  return len(cluster_spec.as_dict().get(_TaskType.WORKER, [])) + len(
      cluster_spec.as_dict().get(_TaskType.CHIEF, []))


class _WorkerContext(object):
  """The worker context class.

  This context object provides configuration information for each task. One
  context manager with a worker context object is created per invocation of
  `worker_fn`, inside which `get_current_worker_context` can be called to
  access the worker context object.
  """

  def __init__(self,
               strategy,
               cluster_spec,
               task_type,
               task_id,
               session_config=None,
               rpc_layer="grpc",
               worker_barrier=None):
    """Initialize the worker context object.

    Args:
      strategy: a `DistributionStrategy` object.
      cluster_spec: a ClusterSpec object. It can be empty or None in the local
        training case.
      task_type: a string indicating the role of the corresponding task, such as
        "worker" or "ps". It can be None if it is local training or in-graph
        replicated training.
      task_id: an integer indicating the id of the corresponding task. It can be
        None if it is local training or in-graph replicated training.
      session_config: an optional `tf.compat.v1.ConfigProto` object.
      rpc_layer: optional string specifying the RPC protocol for communication
        with worker masters. If None or empty, hosts in the `cluster_spec` will
        be used directly.
      worker_barrier: optional, the barrier object for worker synchronization.
    """
    self._strategy = strategy
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
    self._session_config = session_config
    self._worker_barrier = worker_barrier
    self._rpc_layer = rpc_layer
    self._master_target = self._get_master_target()
    self._num_workers = _get_num_workers(cluster_spec)
    self._is_chief_node = self._is_chief()

  def _debug_message(self):
    if self._cluster_spec:
      return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
          self._cluster_spec, self.task_type, self.task_id)
    else:
      return "[local]"

  def __enter__(self):
    old_context = distribute_coordinator_context.get_current_worker_context()
    if old_context:
      raise ValueError(
          "You cannot run distribute coordinator in a `worker_fn`.\t" +
          self._debug_message())
    # pylint: disable=protected-access
    distribute_coordinator_context._worker_context.current = self

  def __exit__(self, unused_exception_type, unused_exception_value,
               unused_traceback):
    # pylint: disable=protected-access
    distribute_coordinator_context._worker_context.current = None

  def _get_master_target(self):
    """Return the master target for a task."""
    # If cluster_spec is None or empty, we use local master.
    if not self._cluster_spec or self._task_type == _TaskType.EVALUATOR:
      return ""

    # If task_type is None, then it is in-graph replicated training. In this
    # case we use the chief or first worker's master target.
    if not self._task_type:
      if _TaskType.CHIEF in self._cluster_spec.jobs:
        task_type = _TaskType.CHIEF
        task_id = 0
      else:
        assert _TaskType.WORKER in self._cluster_spec.jobs
        task_type = _TaskType.WORKER
        task_id = 0
    else:
      task_type = self._task_type
      task_id = self._task_id

    prefix = ""
    if self._rpc_layer:
      prefix = self._rpc_layer + "://"
    return prefix + self._cluster_spec.job_tasks(task_type)[task_id or 0]

  def _is_chief(self):
    """Return whether the task is the chief worker."""
    if (not self._cluster_spec or
        self._task_type in [_TaskType.CHIEF, _TaskType.EVALUATOR, None]):
      return True

    # If not local and chief not in the cluster_spec, use the first worker as
    # chief.
    if (_TaskType.CHIEF not in self._cluster_spec.jobs and
        self._task_type == _TaskType.WORKER and self._task_id == 0):
      return True
    return False

  def wait_for_other_workers(self):
    """Waits for other workers to reach the same call to this method.

    Raises:
      ValueError: if `worker_barrier` is not passed to the __init__ method.
    """
    if not self._worker_barrier:
      # TODO(yuefengz): we should throw an error in independent worker mode.
      return
    self._worker_barrier.wait()

  def session_creator(self,
                      scaffold=None,
                      config=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      max_wait_secs=7200):
    """Returns a session creator.

    The returned session creator will be configured with the correct master
    target and session configs. It will also run either init ops or ready ops
    by querying the `strategy` object when `create_session` is called on it.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified, a default one is created. It's used to finalize the
        graph.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string. Optional path to a directory from which to
        restore variables.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
        Only one of `checkpoint_dir` or `checkpoint_filename_with_path` can be
        specified.
      max_wait_secs: Maximum time to wait for the session to become available.

    Returns:
      a descendant of SessionCreator.
    """
    if config:
      session_config = copy.deepcopy(config)
      session_config.MergeFrom(self._session_config)
    else:
      session_config = self._session_config

    if not self._strategy or self._strategy.extended.experimental_should_init:
      logging.info("Creating chief session creator with config: %r", config)
      return monitored_session.ChiefSessionCreator(
          scaffold,
          master=self.master_target,
          config=session_config,
          checkpoint_dir=checkpoint_dir,
          checkpoint_filename_with_path=checkpoint_filename_with_path)
    else:
      logging.info("Creating worker session creator with config: %r", config)
      return monitored_session.WorkerSessionCreator(
          scaffold,
          master=self.master_target,
          config=session_config,
          max_wait_secs=max_wait_secs)

  @property
  def session_config(self):
    return copy.deepcopy(self._session_config)

  @property
  def has_barrier(self):
    """Whether the barrier is set or not."""
    return self._worker_barrier is not None

  @property
  def distributed_mode(self):
    """Whether it is distributed training or not."""
    return bool(self._cluster_spec) and self._task_type != _TaskType.EVALUATOR

  @property
  def cluster_spec(self):
    """Returns a copy of the cluster_spec object."""
    return copy.deepcopy(self._cluster_spec)

  @property
  def task_type(self):
    """Returns the role of the corresponding task."""
    return self._task_type

  @property
  def task_id(self):
    """Returns the id or index of the corresponding task."""
    return self._task_id

  @property
  def master_target(self):
    """Returns the session master for the corresponding task to connect to."""
    return self._master_target

  @property
  def is_chief(self):
    """Returns whether the task is a chief node."""
    return self._is_chief_node

  @property
  def num_workers(self):
    """Returns number of workers in the cluster, including chief."""
    return self._num_workers

  @property
  def experimental_should_init(self):
    """Whether to run init ops."""
    return self._strategy.extended.experimental_should_init

  @property
  def should_checkpoint(self):
    """Whether to save checkpoint."""
    return self._strategy.extended.should_checkpoint

  @property
  def should_save_summary(self):
    """Whether to save summaries."""
    return self._strategy.extended.should_save_summary

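# A minimal sketch of a `worker_fn` that reads its configuration from the
# worker context (illustrative only; `train_op` is hypothetical):
#
#   def my_worker_fn(strategy):
#     context = distribute_coordinator_context.get_current_worker_context()
#     logging.info("task %s/%s, master = %s, is_chief = %s",
#                  context.task_type, context.task_id,
#                  context.master_target, context.is_chief)
#     session_creator = context.session_creator()
#     with monitored_session.MonitoredSession(
#         session_creator=session_creator) as sess:
#       sess.run(train_op)
#     context.wait_for_other_workers()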

def _run_single_worker(worker_fn,
                       strategy,
                       cluster_spec,
                       task_type,
                       task_id,
                       session_config,
                       rpc_layer="",
                       worker_barrier=None,
                       coord=None):
  """Runs a single worker by calling `worker_fn` under context."""
  session_config = copy.deepcopy(session_config)
  strategy = copy.deepcopy(strategy)
  # If there is an EVALUATOR task, we run single-machine eval on that task.
  if task_type == _TaskType.EVALUATOR:
    # It is possible to not have a strategy object for the EVALUATOR task.
    if strategy:
      strategy.configure(session_config)
  else:
    assert strategy
    strategy.configure(session_config, cluster_spec, task_type, task_id)

  context = _WorkerContext(
      strategy,
      cluster_spec,
      task_type,
      task_id,
      session_config=session_config,
      rpc_layer=rpc_layer,
      worker_barrier=worker_barrier)
  with context:
    if coord:
      with coord.stop_on_exception():
        return worker_fn(strategy)
    else:
      return worker_fn(strategy)


def _split_cluster_for_evaluator(cluster_spec, task_type):
  """Split the cluster for evaluator since it needn't talk to other tasks."""
  # Splitting the cluster is important to prevent the evaluator from talking to
  # other tasks in the cluster: we allow the evaluator not to use a
  # distribution strategy, so ops in the evaluator task may have unspecified
  # devices, and those ops may end up on other tasks if we don't split the
  # cluster.
  # Note: if you bypass distribute coordinator and bring up the cluster
  # yourself, you can equivalently set device filters to split clusters. This
  # is already done by distribution strategy's `update_config_proto` method.
  new_cluster_spec = multi_worker_util.normalize_cluster_spec(
      cluster_spec).as_dict()
  if task_type == _TaskType.EVALUATOR:
    assert _TaskType.EVALUATOR in new_cluster_spec
    new_cluster_spec = {
        _TaskType.EVALUATOR: new_cluster_spec[_TaskType.EVALUATOR]
    }
  else:
    new_cluster_spec.pop(_TaskType.EVALUATOR, None)
  return multi_worker_util.normalize_cluster_spec(new_cluster_spec)

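# An illustrative sketch of what `_split_cluster_for_evaluator` does, assuming
# a hypothetical cluster with one task of each type (the result is returned as
# a `ClusterSpec`; the dicts in the comments show its equivalent `as_dict()`):
#
#   cluster = {"chief": ["host0:2222"], "worker": ["host1:2222"],
#              "evaluator": ["host2:2222"]}
#
#   # For the evaluator task, only the "evaluator" job is kept:
#   #   {"evaluator": ["host2:2222"]}
#   _split_cluster_for_evaluator(cluster, "evaluator")
#
#   # For any other task type, the "evaluator" job is dropped:
#   #   {"chief": ["host0:2222"], "worker": ["host1:2222"]}
#   _split_cluster_for_evaluator(cluster, "worker")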

def _run_std_server(cluster_spec=None,
                    task_type=None,
                    task_id=None,
                    session_config=None,
                    rpc_layer=None,
                    environment=None):
  """Runs a standard server."""
  # Check if the Server is already running. If so, assert that no configuration
  # options have changed, and return the existing Server. This allows us to
  # call `run_distribute_coordinator` multiple times.
  if getattr(_thread_local, "server", None) is not None:
    assert _thread_local.cluster_spec == cluster_spec
    assert _thread_local.task_type == task_type
    assert _thread_local.task_id == task_id
    assert _thread_local.session_config_str == repr(session_config)
    assert _thread_local.rpc_layer == rpc_layer
    assert _thread_local.environment == environment
    return _thread_local.server
  else:
    # This method is not thread-safe.
    _thread_local.server_started = True
    _thread_local.cluster_spec = cluster_spec
    _thread_local.task_type = task_type
    _thread_local.task_id = task_id
    _thread_local.session_config_str = repr(session_config)
    _thread_local.rpc_layer = rpc_layer
    _thread_local.environment = environment

  assert cluster_spec
  target = cluster_spec.task_address(task_type, task_id)
  if rpc_layer:
    target = rpc_layer + "://" + target

  class _FakeServer(object):
    """A fake server that runs a master session."""

    def start(self):
      # A TensorFlow server starts when a remote session is created.
      logging.info(
          "Creating a remote session to start a TensorFlow server, "
          "target = %r, session_config=%r", target, session_config)
      session.Session(target=target, config=session_config)

    def join(self):
      while True:
        time.sleep(5)

  if environment == "google":
    server = _FakeServer()
  else:
    if session_config:
      logging.info(
          "Starting standard TensorFlow server, target = %r, session_config= "
          "%r", target, session_config)
    else:
      logging.info("Starting standard TensorFlow server, target = %r", target)
    cluster_spec = _split_cluster_for_evaluator(cluster_spec, task_type)
    server = server_lib.Server(
        cluster_spec,
        job_name=task_type,
        task_index=task_id,
        config=session_config,
        protocol=rpc_layer)

  server.start()
  _thread_local.server = server
  return server


def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                              cluster_spec, session_config, rpc_layer):
  """Runs a standalone client for between-graph replication."""
  coord = coordinator.Coordinator()
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
        args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
            "coord": coord,
        })
    eval_thread.start()

  threads = []
  worker_barrier = _Barrier(_get_num_workers(cluster_spec))
  for task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
    for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
      t = threading.Thread(
          target=_run_single_worker,
          args=(worker_fn, strategy, cluster_spec, task_type, task_id,
                session_config),
          kwargs={
              "rpc_layer": rpc_layer,
              "worker_barrier": worker_barrier,
              "coord": coord,
          })
      t.start()
      threads.append(t)

  if eval_thread:
    # TODO(yuefengz): is it necessary to join eval thread?
    threads_to_join = threads + [eval_thread]
  else:
    threads_to_join = threads
  coord.join(threads_to_join)

  # TODO(yuefengz): we probably want to return results from all workers?
  return None


def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                         cluster_spec, session_config, rpc_layer):
  """Runs a standalone client for in-graph replication."""
  coord = coordinator.Coordinator()
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
        args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
            "coord": coord,
        })
    eval_thread.start()

  worker_result = _run_single_worker(
      worker_fn,
      strategy,
      cluster_spec,
      None,
      None,
      session_config,
      rpc_layer=rpc_layer,
      coord=coord)

  if eval_thread:
    coord.join([eval_thread])

  return worker_result


def _configure_session_config_for_std_servers(
    strategy, eval_strategy, session_config, cluster_spec, task_type, task_id):
  # pylint: disable=g-doc-args
  """Call strategy's `configure` to mutate the session_config.

  The session_config is currently needed as default config for a TensorFlow
  server. In the future, we should be able to remove this method and only pass
  the session config to a client session.
  """
  if task_type == _TaskType.EVALUATOR:
    if eval_strategy:
      eval_strategy.configure(session_config=session_config)
  else:
    # The strategy may be shared in standalone client mode.
    strategy = copy.deepcopy(strategy)
    strategy.configure(
        session_config=session_config,
        cluster_spec=cluster_spec,
        task_type=task_type,
        task_id=task_id)
  # Remove the device filters specific to the strategy, so that the
  # TensorFlow server brought up with one strategy can be used by other
  # strategies. The device filters can be set on the client side as well.
  del session_config.device_filters[:]


def run_standard_tensorflow_server(session_config=None):
  """Starts a standard TensorFlow server.

  This method parses configurations from the "TF_CONFIG" environment variable
  and starts a TensorFlow server. The "TF_CONFIG" is typically a json string
  and must contain information about the cluster and the role of the server in
  the cluster. One example is:

  TF_CONFIG='{
      "cluster": {
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "worker", "index": 1}
  }'

  This "TF_CONFIG" specifies that there are 3 workers and 2 ps tasks in the
  cluster and that the current role is worker 1.

  Valid task types are "chief", "worker", "ps" and "evaluator" and you can have
  at most one "chief" and at most one "evaluator".

  An optional key that can be specified is "rpc_layer". The default value is
  "grpc".

  Args:
    session_config: an optional `tf.compat.v1.ConfigProto` object. Users can
      pass in the session config object to configure server-local devices.

  Returns:
    a `tf.distribute.Server` object which has already been started.

  Raises:
    ValueError: if the "TF_CONFIG" environment variable is not complete.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  if "cluster" not in tf_config:
    raise ValueError("\"cluster\" is not found in TF_CONFIG.")
  cluster_spec = multi_worker_util.normalize_cluster_spec(tf_config["cluster"])
  if "task" not in tf_config:
    raise ValueError("\"task\" is not found in TF_CONFIG.")
  task_env = tf_config["task"]
  if "type" not in task_env:
    raise ValueError(
        "\"task_type\" is not found in the `task` part of TF_CONFIG.")
  task_type = task_env["type"]
  task_id = int(task_env.get("index", 0))

  rpc_layer = tf_config.get("rpc_layer", "grpc")

  session_config = session_config or config_pb2.ConfigProto()
  # Set the collective group leader for collective ops to initialize collective
  # ops when the server starts.
  if "chief" in cluster_spec.jobs:
    session_config.experimental.collective_group_leader = (
        "/job:chief/replica:0/task:0")
  else:
    if "worker" not in cluster_spec.jobs:
      raise ValueError(
          "You must have `chief` or `worker` jobs in the `cluster_spec`.")
    session_config.experimental.collective_group_leader = (
        "/job:worker/replica:0/task:0")

  server = _run_std_server(
      cluster_spec=cluster_spec,
      task_type=task_type,
      task_id=task_id,
      session_config=session_config,
      rpc_layer=rpc_layer)
  server.start()
  return server

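# A minimal usage sketch of `run_standard_tensorflow_server` (illustrative
# only; the hosts in the TF_CONFIG below are hypothetical):
#
#   os.environ["TF_CONFIG"] = json.dumps({
#       "cluster": {
#           "worker": ["host1:2222", "host2:2222"],
#           "ps": ["host3:2222"],
#       },
#       "task": {"type": "worker", "index": 0},
#   })
#   server = run_standard_tensorflow_server()
#   server.join()  # Block forever, serving requests from other tasks.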

# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current task
# is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn,
                               strategy,
                               eval_fn=None,
                               eval_strategy=None,
                               mode=CoordinatorMode.STANDALONE_CLIENT,
                               cluster_spec=None,
                               task_type=None,
                               task_id=None,
                               session_config=None,
                               rpc_layer="grpc"):
  """Runs the coordinator for distributed TensorFlow.

  This function runs a split coordinator for distributed TensorFlow in its
  default mode, i.e. the STANDALONE_CLIENT mode. Given a `cluster_spec`
  specifying server addresses and their roles in a cluster, this coordinator
  will figure out how to set them up, give the underlying function the right
  targets for master sessions via a scope object and coordinate their training.
  The cluster consisting of standard servers needs to be brought up either with
  the standard server binary or with a binary running distribute coordinator
  with `task_type` set to a non-client type, which will then turn them into
  standard servers.

  In addition to being the distribute coordinator, this is also the source of
  configurations for each job in the distributed training. As there are multiple
  ways to configure a distributed TensorFlow cluster, its context object
  provides these configurations so that users or higher-level APIs don't have to
  figure out the configuration for each job by themselves.

  In between-graph replicated training, this coordinator will create multiple
  threads and each calls the `worker_fn`, which is supposed to create its own
  graph and connect to one worker master given by its context object. In
  in-graph replicated training, it has only one thread calling this
  `worker_fn`.

  Another mode is the INDEPENDENT_WORKER mode, where each server runs a
  distribute coordinator which will start a standard server and optionally run
  `worker_fn`, depending on whether it is between-graph training or in-graph
  replicated training.

  The `strategy` object is expected to be a DistributionStrategy object which
  has implemented the methods needed by the distribute coordinator, such as
  `configure(session_config, cluster_spec, task_type, task_id)`, which
  configures the strategy object for a specific task, and the
  `experimental_should_init` property, which instructs the distribute
  coordinator whether to run init ops for a task. The distribute coordinator
  will make a copy of the `strategy` object, call its `configure` method and
  pass it to `worker_fn` as an argument.

  The `worker_fn` defines the training logic and is called under its own
  worker context which can be accessed via `get_current_worker_context`. A
  worker context provides access to configurations for each task, e.g. the
  task_type, task_id, master target and so on. Since `worker_fn` will be called
  in a thread and possibly multiple times, the caller should be careful when it
  accesses global data. For example, it is unsafe to define flags in a
  `worker_fn` or to define different environment variables for different
  `worker_fn`s.

  The `worker_fn` for between-graph replication is defined as if there is
  only one worker corresponding to the `worker_fn` and possibly ps jobs. For
  example, when training with parameter servers, it assigns variables to
  parameter servers and all other operations to that worker. In the in-graph
  replication case, the `worker_fn` has to define operations for all worker
  jobs. Using a distribution strategy can simplify the `worker_fn` by not having
  to worry about the replication and device assignment of variables and
  operations.

  This method is intended to be invoked by high-level APIs so that users don't
  have to explicitly call it to run this coordinator. For those who don't use
  high-level APIs, to change a program to use this coordinator, wrap everything
  in the program after global data definitions, such as command-line flag
  definitions, into the `worker_fn` and get task-specific configurations from
  the worker context.

  The `cluster_spec` can either be passed as an argument or parsed from the
  "TF_CONFIG" environment variable. Example of a TF_CONFIG:
  ```
    cluster = {'chief': ['host0:2222'],
               'ps': ['host1:2222', 'host2:2222'],
               'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
    os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster})
  ```

  If `cluster_spec` is not given in any format, it becomes local training and
  this coordinator will connect to a local session.

  For evaluation, if "evaluator" exists in the cluster_spec, a separate thread
  will be created to call `eval_fn` with its `task_type` set to "evaluator". If
  `eval_fn` is not defined, it falls back to `worker_fn`. This implies that
  evaluation will be done on a single machine if there is an "evaluator" task.
  If "evaluator" doesn't exist in the cluster_spec, it entirely depends on the
  `worker_fn` for how to do evaluation.

  Args:
    worker_fn: the function to be called. The function should accept a
      `strategy` object and will be given access to a context object via a
      context manager scope.
    strategy: a DistributionStrategy object specifying whether it should
      run between-graph replicated training or not, whether to run init ops,
      etc. This object will also be configured given `session_config`,
      `cluster_spec`, `task_type` and `task_id`.
    eval_fn: optional function for "evaluator" task. If `eval_fn` is not passed
      in but an "evaluator" task is found in the `cluster_spec`, the `worker_fn`
      will be used for this task.
    eval_strategy: optional DistributionStrategy object for "evaluator" task.
    mode: the mode in which this distribute coordinator runs.
    cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
      in a cluster. If not set or empty, fall back to local training.
    task_type: the current task type, optional if this is a client.
    task_id: the current task id, optional if this is a client.
    session_config: an optional `tf.compat.v1.ConfigProto` object which will be
      passed to `strategy`'s `configure` method and used to create a session.
    rpc_layer: optional string, the protocol for RPC, e.g. "grpc".

  Raises:
    ValueError: if `cluster_spec` is supplied but not a dict or a ClusterDef or
      a ClusterSpec.

  Returns:
    In the client job, return the value returned by `worker_fn` if
    it is in-graph replication or INDEPENDENT_WORKER mode; return None
    otherwise.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  rpc_layer = tf_config.get("rpc_layer", rpc_layer)
  environment = tf_config.get("environment", None)

  if not cluster_spec:
    cluster_spec = tf_config.get("cluster", {})
    task_env = tf_config.get("task", {})
    if task_env:
      task_type = task_env.get("type", task_type)
      task_id = int(task_env.get("index", task_id))

  if cluster_spec:
    # TODO(yuefengz): validate cluster_spec.
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
  elif hasattr(strategy.extended, "_cluster_resolver"):
    cluster_resolver = strategy.extended._cluster_resolver  # pylint: disable=protected-access
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    rpc_layer = cluster_resolver.rpc_layer or rpc_layer
    environment = cluster_resolver.environment
    cluster_spec = cluster_resolver.cluster_spec()

  # Setting the session config is necessary for some strategies such as
  # CollectiveAllReduceStrategy.
  session_config = session_config or config_pb2.ConfigProto(
      allow_soft_placement=True)

  if cluster_spec:
    logging.info(
        "Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
        "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode,
        cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer)

  if not cluster_spec:
    # `mode` is ignored in the local case.
    logging.info("Running local Distribute Coordinator.")
    _run_single_worker(worker_fn, strategy, None, None, None, session_config,
                       rpc_layer)
    if eval_fn:
      _run_single_worker(eval_fn, eval_strategy, None, None, None,
                         session_config, rpc_layer)
    else:
      logging.warning("Skipped evaluation since `eval_fn` is not passed in.")
  elif mode == CoordinatorMode.STANDALONE_CLIENT:
    if not eval_fn:
      logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
                      "used if an \"evaluator\" task exists in the cluster.")
    eval_fn = eval_fn or worker_fn
    if not eval_strategy:
      logging.warning("`eval_strategy` is not passed in. No distribution "
                      "strategy will be used for evaluation.")

    # The client must know the cluster but servers in the cluster don't have to
    # know the client.
    if task_type in [_TaskType.CLIENT, None]:
      if strategy.extended.experimental_between_graph:
        return _run_between_graph_client(worker_fn, strategy, eval_fn,
                                         eval_strategy, cluster_spec,
                                         session_config, rpc_layer)
      else:
        return _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                                    cluster_spec, session_config, rpc_layer)
    else:
      # If not a client job, run the standard server.
      _configure_session_config_for_std_servers(strategy, eval_strategy,
                                                session_config, cluster_spec,
                                                task_type, task_id)
      server = _run_std_server(
          cluster_spec=cluster_spec,
          task_type=task_type,
          task_id=task_id,
          session_config=session_config,
          rpc_layer=rpc_layer,
          environment=environment)
      server.join()
  else:
    if mode != CoordinatorMode.INDEPENDENT_WORKER:
      raise ValueError("Unexpected coordinator mode: %r" % mode)

    if not eval_fn:
      logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
                      "used if an \"evaluator\" task exists in the cluster.")
    eval_fn = eval_fn or worker_fn
    if not eval_strategy:
      logging.warning("`eval_strategy` is not passed in. No distribution "
                      "strategy will be used for evaluation.")

    # Every task starts a standard server; get the session config from the
    # strategy's `configure` method.
    _configure_session_config_for_std_servers(strategy, eval_strategy,
                                              session_config, cluster_spec,
                                              task_type, task_id)

    if (task_type != _TaskType.EVALUATOR and
        not getattr(strategy.extended, "_std_server_started", False)):
      # Right now, with eager mode, the context is configured with a std server
      # at the very beginning, while with graph mode the std server is started
      # when distribute coordinator is called. We should consolidate these two
      # paths.
      server = _run_std_server(
          cluster_spec=cluster_spec,
          task_type=task_type,
          task_id=task_id,
          session_config=session_config,
          rpc_layer=rpc_layer,
          environment=environment)
    if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
      if strategy.extended.experimental_between_graph:
        # All jobs run `worker_fn` if between-graph.
        return _run_single_worker(worker_fn, strategy, cluster_spec, task_type,
                                  task_id, session_config, rpc_layer)
      else:
        # Only one node runs `worker_fn` if in-graph.
        context = _WorkerContext(strategy, cluster_spec, task_type, task_id)
        if context.is_chief:
          return _run_single_worker(worker_fn, strategy, cluster_spec, None,
                                    None, session_config, rpc_layer)
        else:
          server.join()
    elif task_type == _TaskType.EVALUATOR:
      return _run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type,
                                task_id, session_config, rpc_layer)
    else:
      if task_type != _TaskType.PS:
        raise ValueError("Unexpected task_type: %r" % task_type)
      server.join()
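

# An end-to-end sketch for INDEPENDENT_WORKER mode, where each task sets its
# own TF_CONFIG and runs this same program (illustrative only; the hosts,
# `my_worker_fn` and `my_strategy` are hypothetical):
#
#   os.environ["TF_CONFIG"] = json.dumps({
#       "cluster": {"chief": ["host0:2222"], "worker": ["host1:2222"]},
#       "task": {"type": "worker", "index": 0},
#   })
#   # `cluster_spec`, `task_type` and `task_id` are parsed from TF_CONFIG.
#   run_distribute_coordinator(my_worker_fn, my_strategy,
#                              mode=CoordinatorMode.INDEPENDENT_WORKER)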