# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class implementing a multi-worker parameter server tf.distribute strategy."""

import copy


from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import input_util
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_setter
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

_LOCAL_CPU = "/device:CPU:0"


@tf_export(v1=["distribute.experimental.ParameterServerStrategy"])  # pylint: disable=missing-docstring
class ParameterServerStrategyV1(distribute_lib.StrategyV1):
  """An asynchronous multi-worker parameter server tf.distribute strategy.

  This strategy requires two roles: workers and parameter servers. Variables
  and updates to those variables are assigned to parameter servers; other
  operations are assigned to workers.

  When each worker has more than one GPU, operations are replicated across
  these GPUs. Even though operations may be replicated, variables are not,
  and all workers share a common view of which parameter server each variable
  is assigned to.

  By default it uses `TFConfigClusterResolver` to detect configurations for
  multi-worker training. This requires a 'TF_CONFIG' environment variable
  whose value contains a cluster spec.
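
  For example, a 'TF_CONFIG' for a cluster with two workers and one parameter
  server could look like the following (the addresses are placeholders):

  ```
  {
      "cluster": {
          "worker": ["host1:2222", "host2:2222"],
          "ps": ["host3:2222"]
      },
      "task": {"type": "worker", "index": 0}
  }
  ```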

  This class assumes each worker is running the same code independently, while
  parameter servers run a standard server. This means that although each
  worker synchronously computes a single gradient update across all of its
  GPUs, updates between workers proceed asynchronously. Operations that occur
  only on the first replica (such as incrementing the global step) will occur
  on the first replica *of every worker*.

  You are expected to call `call_for_each_replica(fn, ...)` for any operations
  that can potentially be replicated across replicas (i.e. multiple GPUs),
  even if there is only a CPU or a single GPU. When defining `fn`, extra
  caution needs to be taken:

  1) It is generally not recommended to open a device scope under the
  strategy's scope. A device scope (i.e. calling `tf.device`) will be merged
  with or override the device for operations but will not change the device
  for variables.

  2) It is also not recommended to open a colocation scope (i.e. calling
  `tf.compat.v1.colocate_with`) under the strategy's scope. For colocating
  variables, use `strategy.extended.colocate_vars_with` instead, as sketched
  below. Colocating ops may create device assignment conflicts.
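
  For example (a sketch; `v1` and `v2` are placeholder variables):

  ```
  with strategy.scope():
    v1 = tf.Variable(1.0)
    # `v2` is placed on the same parameter server as `v1`.
    with strategy.extended.colocate_vars_with(v1):
      v2 = tf.Variable(2.0)
  ```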

  Note: This strategy only works with the Estimator API. Pass an instance of
  this strategy to the `train_distribute` argument when you create the
  `RunConfig`. This instance of `RunConfig` should then be passed to the
  `Estimator` instance on which `train_and_evaluate` is called.

  For example:
  ```
  strategy = tf.distribute.experimental.ParameterServerStrategy()
  run_config = tf.estimator.RunConfig(
      train_distribute=strategy)
  estimator = tf.estimator.Estimator(config=run_config)
  tf.estimator.train_and_evaluate(estimator, ...)
  ```
  """

  def __init__(self, cluster_resolver=None):
    """Initializes this strategy with an optional `cluster_resolver`.

    Args:
      cluster_resolver: Optional
        `tf.distribute.cluster_resolver.ClusterResolver` object. Defaults to a
        `tf.distribute.cluster_resolver.TFConfigClusterResolver`.
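
    For example, to build the strategy from an explicit cluster spec instead
    of 'TF_CONFIG' (a sketch; the addresses are placeholders):

    ```
    cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
        tf.train.ClusterSpec({
            "worker": ["host1:2222", "host2:2222"],
            "ps": ["host3:2222"]
        }),
        task_type="worker",
        task_id=0)
    strategy = tf.compat.v1.distribute.experimental.ParameterServerStrategy(
        cluster_resolver)
    ```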
    """
    if cluster_resolver is None:
      cluster_resolver = TFConfigClusterResolver()
    super(ParameterServerStrategyV1, self).__init__(
        ParameterServerStrategyExtended(
            self, cluster_resolver=cluster_resolver))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "ParameterServerStrategy")

  def experimental_distribute_dataset(self, dataset, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1,
                 self).experimental_distribute_dataset(dataset=dataset,
                                                       options=options)

  def distribute_datasets_from_function(self, dataset_fn, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).distribute_datasets_from_function(
        dataset_fn=dataset_fn, options=options)

  def run(self, fn, args=(), kwargs=None, options=None):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).run(
        fn, args=args, kwargs=kwargs, options=options)

  def scope(self):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).scope()

  def _raise_pss_error_if_eager(self):
    if context.executing_eagerly():
      raise NotImplementedError(
          "`tf.compat.v1.distribute.experimental.ParameterServerStrategy` "
          "currently only works with the tf.Estimator API")


# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of ParameterServerStrategy and CentralStorageStrategy."""

  def __init__(self,
               container_strategy,
               cluster_resolver=None,
               compute_devices=None,
               parameter_device=None):
    super(ParameterServerStrategyExtended, self).__init__(container_strategy)
    self._initialize_strategy(
        cluster_resolver=cluster_resolver,
        compute_devices=compute_devices,
        parameter_device=parameter_device)

    # We typically don't need to do all-reduce in this strategy.
    self._cross_device_ops = (
        cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU))

  def _initialize_strategy(self,
                           cluster_resolver=None,
                           compute_devices=None,
                           parameter_device=None):
    if cluster_resolver and cluster_resolver.cluster_spec():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(
          compute_devices, parameter_device, cluster_resolver=cluster_resolver)

  def _initialize_multi_worker(self, cluster_resolver):
    """Initialize devices for multiple workers.

    It creates variable devices and compute devices; variables and operations
    are assigned to them respectively. There is one compute device per
    replica. The variable device is a device function or device string. The
    default variable device assigns variables to parameter servers in a
    round-robin fashion.
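
    For example, with two ps tasks, successively created variables are
    typically placed round-robin on `/job:ps/task:0`, `/job:ps/task:1`,
    `/job:ps/task:0`, and so on.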

    Args:
      cluster_resolver: an instance of a `ClusterResolver` subclass.

    Raises:
      ValueError: if the cluster doesn't have ps jobs.
    """
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
    # some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    # Save the num_gpus_per_worker for configure method.
    self._num_gpus_per_worker = num_gpus

    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if not task_type or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`")
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    assert cluster_spec.as_dict()

    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    # Define compute devices, which is a list of device strings, one for each
    # replica. When there are GPUs, replicate operations on these GPUs.
    # Otherwise, place operations on the CPU.
    if num_gpus > 0:
      compute_devices = tuple(
          "%s/device:GPU:%d" % (self._worker_device, i)
          for i in range(num_gpus))
    else:
      compute_devices = (self._worker_device,)

    self._compute_devices = [
        device_util.canonicalize(d) for d in compute_devices]

    # In distributed mode, place variables on ps jobs in a round-robin fashion.
    # Note that devices returned from `replica_device_setter` are not
    # canonical and therefore we don't canonicalize all variable devices to
    # make them consistent.
    # TODO(yuefengz): support passing a strategy object to control variable
    # assignment.
    # TODO(yuefengz): merge the logic of replica_device_setter into this
    # class.
    num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
    if num_ps_replicas == 0:
      raise ValueError("The cluster spec needs to have `ps` jobs.")
    self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas,
        worker_device=self._worker_device,
        merge_devices=True,
        cluster=cluster_spec)

    # The `_parameter_devices` is needed for the `parameter_devices` property
    # and is a list of all variable devices. Here parameter devices are all
    # tasks of the "ps" job.
    self._parameter_devices = tuple(map("/job:ps/task:{}".format,
                                        range(num_ps_replicas)))

    # Add a default device so that ops without specified devices will not end up
    # on other workers.
    self._default_device = self._worker_device

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id

    logging.info(
        "Multi-worker ParameterServerStrategy with "
        "cluster_spec = %r, task_type = %r, task_id = %r, "
        "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, "
        "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
        num_ps_replicas, self._is_chief, self._compute_devices,
        self._variable_device)

  # TODO(yuefengz): get rid of cluster_resolver argument when contrib's
  # version no longer depends on this class.
  def _initialize_local(self,
                        compute_devices,
                        parameter_device,
                        cluster_resolver=None):
    """Initialize local devices for training."""
    self._worker_device = device_util.canonicalize("/device:CPU:0")
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    if compute_devices is None:
      if not cluster_resolver:
        num_gpus = context.num_gpus()
      else:
        num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)
      # Save the num_gpus_per_worker for configure method which is used by the
      # contrib version.
      self._num_gpus_per_worker = num_gpus

      compute_devices = device_util.local_devices_from_num_gpus(num_gpus)

    compute_devices = [device_util.canonicalize(d) for d in compute_devices]

    if parameter_device is None:
      # If there is only one GPU, put everything on that GPU. Otherwise, place
      # variables on CPU.
      if len(compute_devices) == 1:
        parameter_device = compute_devices[0]
      else:
        parameter_device = _LOCAL_CPU

    self._variable_device = parameter_device
    self._compute_devices = compute_devices
    self._parameter_devices = (parameter_device,)
    self._is_chief = True
    self._cluster_spec = None
    self._task_type = None
    self._task_id = None

    logging.info(
        "ParameterServerStrategy (CentralStorageStrategy if you are using a "
        "single machine) with compute_devices = %r, variable_device = %r",
        compute_devices, self._variable_device)

  def _input_workers_with_options(self, options=None):
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers(
          [(self._worker_device, self._compute_devices)])
    else:
      return input_lib.InputWorkers(
          [(self._worker_device,
            (self._worker_device,) * len(self._compute_devices))])

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _experimental_distribute_dataset(self, dataset, options):
    return input_util.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        options=options)

  def _make_dataset_iterator(self, dataset):
    return input_lib_v1.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the dataset to each local GPU."""
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1
    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)
    return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
                                              [input_context],
                                              self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._input_host_device, session)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1

    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)

    return input_util.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options), [input_context],
        self._container_strategy(),
        options=options)

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(
          value_fn(distribute_lib.ValueContext(replica_id,
                                               self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if not cross_device_ops_lib.check_destinations(destinations):
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = self._compute_devices
    return self._cross_device_ops.broadcast(tensor, destinations)

  def _allow_variable_partition(self):
    return not context.executing_eagerly()

  def _create_var_creator(self, next_creator, **kwargs):
    if self._num_replicas_in_sync > 1:
      aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError(
            "Invalid variable aggregation mode: %s for variable: %s" %
            (aggregation, kwargs["name"]))

      def var_creator(**kwargs):
        """Create an AggregatingVariable and fix up collections."""
        # Record what collections this variable should be added to.
        collections = kwargs.pop("collections", None)
        if collections is None:
          collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        # Create and wrap the variable.
        v = next_creator(**kwargs)
        wrapped = ps_values.AggregatingVariable(self._container_strategy(), v,
                                                aggregation)

        # Add the wrapped variable to the requested collections.
        # The handling of eager mode and the global step matches
        # ResourceVariable._init_from_args().
        if not context.executing_eagerly():
          g = ops.get_default_graph()
          # If "trainable" is True, next_creator() will add the contained
          # variable to the TRAINABLE_VARIABLES collection, so we manually
          # remove it and replace with the wrapper. We can't set "trainable"
          # to False for next_creator() since that causes functions like
          # implicit_gradients to skip those variables.
          if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            if v in l:
              l.remove(v)
          g.add_to_collections(collections, wrapped)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)

        return wrapped
      return var_creator
    else:
      return next_creator

  # TODO(yuefengz): Not all ops in device_setter.STANDARD_PS_OPS will go through
  # this creator, such as "MutableHashTable".
  def _create_variable(self, next_creator, **kwargs):
    var_creator = self._create_var_creator(next_creator, **kwargs)

    if "colocate_with" in kwargs:
      colocate_with = kwargs["colocate_with"]
      if isinstance(colocate_with, numpy_dataset.SingleDevice):
        with ops.device(colocate_with.device):
          return var_creator(**kwargs)
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          return var_creator(**kwargs)

    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._variable_device):
        return var_creator(**kwargs)

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
                                              args, kwargs)

  def _verify_destinations_not_different_worker(self, destinations):
    if not self._cluster_spec:
      return
    if destinations is None:
      return
    for d in cross_device_ops_lib.get_devices_from(destinations):
      d_spec = tf_device.DeviceSpec.from_string(d)
      if d_spec.job == self._task_type and d_spec.task != self._task_id:
        raise ValueError(
            "Cannot reduce to another worker: %r, current worker is %r" %
            (d, self._worker_device))

  def _gather_to_implementation(self, value, destinations, axis,
                                options):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      return value
    return self._cross_device_ops._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=options)

  def _reduce_to(self, reduce_op, value, destinations, options):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      # pylint: disable=protected-access
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)
    return self._cross_device_ops.reduce(
        reduce_op, value, destinations=destinations, options=options)

  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
    for _, destinations in value_destination_pairs:
      self._verify_destinations_not_different_worker(destinations)
    return self._cross_device_ops.batch_reduce(reduce_op,
                                               value_destination_pairs, options)

  def _select_single_value(self, structured):
    """Select any single value in `structured`."""

    def _select_fn(x):  # pylint: disable=g-missing-docstring
      if isinstance(x, values.Mirrored) or isinstance(x, values.PerReplica):
        return x._primary  # pylint: disable=protected-access
      else:
        return x

    return nest.map_structure(_select_fn, structured)

  def _update(self, var, fn, args, kwargs, group):
    if isinstance(var, ps_values.AggregatingVariable):
      var = var.get()
    if not resource_variable_ops.is_resource_variable(var):
      raise ValueError(
          "You cannot update `var` %r. It must be a Variable." % var)
    with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
      result = fn(var, *self._select_single_value(args),
                  **self._select_single_value(kwargs))
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  # TODO(yuefengz): does it need to call _select_single_value?
  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    with ops.device(
        colocate_with.device), distribute_lib.UpdateContext(colocate_with):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def value_container(self, val):
    if (hasattr(val, "_aggregating_container") and
        not isinstance(val, ps_values.AggregatingVariable)):
      wrapper = val._aggregating_container()  # pylint: disable=protected-access
      if wrapper is not None:
        return wrapper
    return val

  def read_var(self, var):
    # No need to distinguish between normal variables and replica-local
    # variables.
    return array_ops.identity(var)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Configures the strategy class with `cluster_spec`.

    The strategy object will be re-initialized if `cluster_spec` is passed to
    `configure` but was not passed when instantiating the strategy.

    Args:
      session_config: Session config object.
      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
        cluster configurations.
      task_type: the current task type.
      task_id: the current task id.

    Raises:
      ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
        not.
    """
    if cluster_spec:
      # Use the num_gpus_per_worker recorded in constructor since _configure
      # doesn't take num_gpus.
      cluster_resolver = SimpleClusterResolver(
          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
          task_type=task_type,
          task_id=task_id,
          num_accelerators={"GPU": self._num_gpus_per_worker})
      self._initialize_multi_worker(cluster_resolver)

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    if not self._cluster_spec:
      updated_config.isolate_session_state = True
      return updated_config

    updated_config.isolate_session_state = False

    assert self._task_type
    assert self._task_id is not None

    # The device filters prevent communication between workers.
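    # For example, a worker with task_type "worker" and task_id 1 ends up with
    # device_filters ["/job:worker/task:1", "/job:ps"], so it only sees itself
    # and the ps tasks.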
    del updated_config.device_filters[:]
    if self._task_type in ["chief", "worker"]:
      updated_config.device_filters.extend(
          ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
    elif self._task_type == "evaluator":
      updated_config.device_filters.append(
          "/job:%s/task:%d" % (self._task_type, self._task_id))
    return updated_config

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return self._cluster_spec is not None

  @property
  def _num_replicas_in_sync(self):
    return len(self._compute_devices)

  @property
  def worker_devices(self):
    return self._compute_devices

  @property
  def worker_devices_by_replica(self):
    return [[d] for d in self._compute_devices]

  @property
  def parameter_devices(self):
    return self._parameter_devices

  def non_slot_devices(self, var_list):
    return min(var_list, key=lambda x: x.name)

  @property
  def experimental_between_graph(self):
    # TODO(yuefengz): Should this return False in the local case?
    return True

  @property
  def experimental_should_init(self):
    return self._is_chief

  @property
  def should_checkpoint(self):
    return self._is_chief

  @property
  def should_save_summary(self):
    return self._is_chief

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id
693