# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class implementing a multi-worker parameter server tf.distribute strategy."""

import copy


from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import input_util
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_setter
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

_LOCAL_CPU = "/device:CPU:0"


@tf_export(v1=["distribute.experimental.ParameterServerStrategy"])  # pylint: disable=missing-docstring
class ParameterServerStrategyV1(distribute_lib.StrategyV1):
  """An asynchronous multi-worker parameter server tf.distribute strategy.

  This strategy requires two roles: workers and parameter servers. Variables
  and updates to those variables are assigned to parameter servers, while
  other operations are assigned to workers.

  When each worker has more than one GPU, operations are replicated on all
  GPUs. Even though operations may be replicated, variables are not, and all
  workers share a common view of which parameter server each variable is
  assigned to.

  By default it uses `TFConfigClusterResolver` to detect configurations for
  multi-worker training. This requires a 'TF_CONFIG' environment variable,
  and the 'TF_CONFIG' must have a cluster spec.
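
  For example, a minimal 'TF_CONFIG' for the first of two workers in a
  cluster with one parameter server might look like this (host names and
  ports are illustrative):

  ```
  os.environ["TF_CONFIG"] = json.dumps({
      "cluster": {
          "worker": ["host1:2222", "host2:2222"],
          "ps": ["host3:2222"]
      },
      "task": {"type": "worker", "index": 0}
  })
  ```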

  This class assumes each worker is running the same code independently, but
  parameter servers are running a standard server. This means that while each
  worker will synchronously compute a single gradient update across all GPUs,
  updates between workers proceed asynchronously. Operations that occur only
  on the first replica (such as incrementing the global step) will occur on
  the first replica *of every worker*.

  You are expected to call `call_for_each_replica(fn, ...)` for any
  operations which can potentially be replicated across replicas (i.e.
  multiple GPUs), even if there is only a CPU or one GPU. When defining the
  `fn`, extra caution needs to be taken:

  1) It is generally not recommended to open a device scope under the
  strategy's scope. A device scope (i.e. calling `tf.device`) will be merged
  with or override the device for operations but will not change the device
  for variables.

  2) It is also not recommended to open a colocation scope (i.e. calling
  `tf.compat.v1.colocate_with`) under the strategy's scope. For colocating
  variables, use `strategy.extended.colocate_vars_with` instead. Colocation
  of ops will possibly create device assignment conflicts.
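
  For example, to colocate one variable with another under the strategy's
  scope (a minimal sketch; the variable names are illustrative):

  ```
  v1 = tf.Variable(1.)
  with strategy.extended.colocate_vars_with(v1):
    # v2 is placed on the same parameter server as v1.
    v2 = tf.Variable(2.)
  ```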

  Note: This strategy only works with the Estimator API. Pass an instance of
  this strategy to the `train_distribute` argument when you create the
  `RunConfig`. This instance of `RunConfig` should then be passed to the
  `Estimator` instance on which `train_and_evaluate` is called.

  For example:
  ```
  strategy = tf.distribute.experimental.ParameterServerStrategy()
  run_config = tf.estimator.RunConfig(
      train_distribute=strategy)
  estimator = tf.estimator.Estimator(config=run_config)
  tf.estimator.train_and_evaluate(estimator, ...)
  ```
  """

  def __init__(self, cluster_resolver=None):
    """Initializes this strategy with an optional `cluster_resolver`.

    Args:
      cluster_resolver: Optional
        `tf.distribute.cluster_resolver.ClusterResolver` object. Defaults to a
        `tf.distribute.cluster_resolver.TFConfigClusterResolver`.
    """
    if cluster_resolver is None:
      cluster_resolver = TFConfigClusterResolver()
    super(ParameterServerStrategyV1, self).__init__(
        ParameterServerStrategyExtended(
            self, cluster_resolver=cluster_resolver))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "ParameterServerStrategy")

  def experimental_distribute_dataset(self, dataset, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`.")
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1,
                 self).experimental_distribute_dataset(dataset=dataset,
                                                       options=options)

  def distribute_datasets_from_function(self, dataset_fn, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1,
                 self).distribute_datasets_from_function(
                     dataset_fn=dataset_fn, options=options)

  def run(self, fn, args=(), kwargs=None, options=None):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).run(
        fn, args=args, kwargs=kwargs, options=options)

  def scope(self):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).scope()

  def _raise_pss_error_if_eager(self):
    if context.executing_eagerly():
      raise NotImplementedError(
          "`tf.compat.v1.distribute.experimental.ParameterServerStrategy` "
          "currently only works with the tf.Estimator API")


# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of ParameterServerStrategy and CentralStorageStrategy."""

  def __init__(self,
               container_strategy,
               cluster_resolver=None,
               compute_devices=None,
               parameter_device=None):
    super(ParameterServerStrategyExtended, self).__init__(container_strategy)
    self._initialize_strategy(
        cluster_resolver=cluster_resolver,
        compute_devices=compute_devices,
        parameter_device=parameter_device)

    # We typically don't need to do all-reduce in this strategy.
    self._cross_device_ops = (
        cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU))

  def _initialize_strategy(self,
                           cluster_resolver=None,
                           compute_devices=None,
                           parameter_device=None):
    if cluster_resolver and cluster_resolver.cluster_spec():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(
          compute_devices, parameter_device, cluster_resolver=cluster_resolver)

  def _initialize_multi_worker(self, cluster_resolver):
    """Initialize devices for multiple workers.

    It creates variable devices and compute devices. Variables and operations
    will be assigned to them respectively. We have one compute device per
    replica. The variable device is a device function or device string. The
    default variable device assigns variables to parameter servers in a
    round-robin fashion.

    Args:
      cluster_resolver: a descendant of `ClusterResolver` object.

    Raises:
      ValueError: if the cluster doesn't have ps jobs.
    """
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs
    # in some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    # Save the num_gpus_per_worker for configure method.
    self._num_gpus_per_worker = num_gpus

    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if not task_type or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`")
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    assert cluster_spec.as_dict()

    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    # Define compute devices, which is a list of device strings, one for each
    # replica. When there are GPUs, replicate operations on these GPUs.
    # Otherwise, place operations on CPU.
    if num_gpus > 0:
      compute_devices = tuple(
          "%s/device:GPU:%d" % (self._worker_device, i)
          for i in range(num_gpus))
    else:
      compute_devices = (self._worker_device,)

    self._compute_devices = [
        device_util.canonicalize(d) for d in compute_devices]

    # In distributed mode, place variables on ps jobs in a round-robin fashion.
    # Note that devices returned from `replica_device_setter` are not
    # canonical and therefore we don't canonicalize all variable devices to
    # make them consistent.
    # TODO(yuefengz): support passing a strategy object to control variable
    # assignment.
    # TODO(yuefengz): merge the logic of replica_device_setter into this
    # class.
    num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
    if num_ps_replicas == 0:
      raise ValueError("The cluster spec needs to have `ps` jobs.")
    self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas,
        worker_device=self._worker_device,
        merge_devices=True,
        cluster=cluster_spec)
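    # For illustration only (a hypothetical two-ps cluster): under this device
    # function, variables created in the strategy's scope are placed on ps
    # tasks round-robin, e.g.
    #   v1 = tf.Variable(...)  # -> /job:ps/task:0
    #   v2 = tf.Variable(...)  # -> /job:ps/task:1
    #   v3 = tf.Variable(...)  # -> /job:ps/task:0 (wraps around)
    # while non-variable ops stay on this worker's device.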

    # The `_parameter_devices` is needed for the `parameter_devices` property
    # and is a list of all variable devices. Here parameter devices are all
    # tasks of the "ps" job.
    self._parameter_devices = tuple(map("/job:ps/task:{}".format,
                                        range(num_ps_replicas)))

    # Add a default device so that ops without specified devices will not end
    # up on other workers.
    self._default_device = self._worker_device

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id

    logging.info(
        "Multi-worker ParameterServerStrategy with "
        "cluster_spec = %r, task_type = %r, task_id = %r, "
        "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, "
        "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
        num_ps_replicas, self._is_chief, self._compute_devices,
        self._variable_device)

  # TODO(yuefengz): get rid of cluster_resolver argument when contrib's
  # version no longer depends on this class.
  def _initialize_local(self,
                        compute_devices,
                        parameter_device,
                        cluster_resolver=None):
    """Initialize local devices for training."""
    self._worker_device = device_util.canonicalize("/device:CPU:0")
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    if compute_devices is None:
      if not cluster_resolver:
        num_gpus = context.num_gpus()
      else:
        num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)
      # Save the num_gpus_per_worker for the configure method, which is used
      # by the contrib version.
      self._num_gpus_per_worker = num_gpus

      compute_devices = device_util.local_devices_from_num_gpus(num_gpus)

    compute_devices = [device_util.canonicalize(d) for d in compute_devices]

    if parameter_device is None:
      # If there is only one GPU, put everything on that GPU. Otherwise, place
      # variables on CPU.
      if len(compute_devices) == 1:
        parameter_device = compute_devices[0]
      else:
        parameter_device = _LOCAL_CPU

    self._variable_device = parameter_device
    self._compute_devices = compute_devices
    self._parameter_devices = (parameter_device,)
    self._is_chief = True
    self._cluster_spec = None
    self._task_type = None
    self._task_id = None

    logging.info(
        "ParameterServerStrategy (CentralStorageStrategy if you are using a "
        "single machine) with compute_devices = %r, variable_device = %r",
        compute_devices, self._variable_device)

  def _input_workers_with_options(self, options=None):
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers(
          [(self._worker_device, self._compute_devices)])
    else:
      return input_lib.InputWorkers(
          [(self._worker_device,
            (self._worker_device,) * len(self._compute_devices))])

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _experimental_distribute_dataset(self, dataset, options):
    return input_util.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        options=options)

  def _make_dataset_iterator(self, dataset):
    return input_lib_v1.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the dataset to each local GPU."""
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1
    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)
    return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
                                              [input_context],
                                              self._container_strategy())
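
  # For illustration only (a minimal sketch of an `input_fn`): the
  # `InputContext` passed to the input function above can be used to shard
  # data between workers, e.g.
  #   def input_fn(ctx):
  #     d = tf.data.Dataset.range(100).batch(4)
  #     return d.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)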

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._input_host_device, session)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1

    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)

    return input_util.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options), [input_context],
        self._container_strategy(),
        options=options)

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(
          value_fn(distribute_lib.ValueContext(replica_id,
                                               self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if not cross_device_ops_lib.check_destinations(destinations):
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = self._compute_devices
    return self._cross_device_ops.broadcast(tensor, destinations)

  def _allow_variable_partition(self):
    return not context.executing_eagerly()

  def _create_var_creator(self, next_creator, **kwargs):
    if self._num_replicas_in_sync > 1:
      aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError("Invalid variable aggregation mode: %s "
                         "for variable: %s" % (aggregation, kwargs["name"]))

      def var_creator(**kwargs):
        """Create an AggregatingVariable and fix up collections."""
        # Record what collections this variable should be added to.
        collections = kwargs.pop("collections", None)
        if collections is None:
          collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        # Create and wrap the variable.
        v = next_creator(**kwargs)
        wrapped = ps_values.AggregatingVariable(self._container_strategy(), v,
                                                aggregation)

        # Add the wrapped variable to the requested collections.
        # The handling of eager mode and the global step matches
        # ResourceVariable._init_from_args().
        if not context.executing_eagerly():
          g = ops.get_default_graph()
          # If "trainable" is True, next_creator() will add the contained
          # variable to the TRAINABLE_VARIABLES collection, so we manually
          # remove it and replace it with the wrapper. We can't set
          # "trainable" to False for next_creator() since that causes
          # functions like implicit_gradients to skip those variables.
          if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            if v in l:
              l.remove(v)
          g.add_to_collections(collections, wrapped)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)

        return wrapped

      return var_creator
    else:
      return next_creator
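
  # For illustration only (a hypothetical usage): with more than one replica,
  # a variable created in the strategy's scope as
  #   v = tf.Variable(0., aggregation=tf.VariableAggregation.MEAN)
  # is wrapped in a `ps_values.AggregatingVariable`, so cross-replica
  # assignments are mean-aggregated before being applied on the parameter
  # server.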

  # TODO(yuefengz): Not all ops in device_setter.STANDARD_PS_OPS will go
  # through this creator, such as "MutableHashTable".
  def _create_variable(self, next_creator, **kwargs):
    var_creator = self._create_var_creator(next_creator, **kwargs)

    if "colocate_with" in kwargs:
      colocate_with = kwargs["colocate_with"]
      if isinstance(colocate_with, numpy_dataset.SingleDevice):
        with ops.device(colocate_with.device):
          return var_creator(**kwargs)
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          return var_creator(**kwargs)

    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._variable_device):
        return var_creator(**kwargs)

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
                                              args, kwargs)

  def _verify_destinations_not_different_worker(self, destinations):
    if not self._cluster_spec:
      return
    if destinations is None:
      return
    for d in cross_device_ops_lib.get_devices_from(destinations):
      d_spec = tf_device.DeviceSpec.from_string(d)
      if d_spec.job == self._task_type and d_spec.task != self._task_id:
        raise ValueError(
            "Cannot reduce to another worker: %r, current worker is %r" %
            (d, self._worker_device))

  def _gather_to_implementation(self, value, destinations, axis, options):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      return value
    return self._cross_device_ops._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=options)

  def _reduce_to(self, reduce_op, value, destinations, options):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      # pylint: disable=protected-access
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)
    return self._cross_device_ops.reduce(
        reduce_op, value, destinations=destinations, options=options)

  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
    for _, destinations in value_destination_pairs:
      self._verify_destinations_not_different_worker(destinations)
    return self._cross_device_ops.batch_reduce(reduce_op,
                                               value_destination_pairs,
                                               options)

  def _select_single_value(self, structured):
    """Select any single value in `structured`."""

    def _select_fn(x):  # pylint: disable=g-missing-docstring
      if isinstance(x, (values.Mirrored, values.PerReplica)):
        return x._primary  # pylint: disable=protected-access
      else:
        return x

    return nest.map_structure(_select_fn, structured)

  def _update(self, var, fn, args, kwargs, group):
    if isinstance(var, ps_values.AggregatingVariable):
      var = var.get()
    if not resource_variable_ops.is_resource_variable(var):
      raise ValueError(
          "You can not update `var` %r. It must be a Variable." % var)
    with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
      result = fn(var, *self._select_single_value(args),
                  **self._select_single_value(kwargs))
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)
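
  # For illustration only (a hypothetical caller): in graph mode,
  #   strategy.extended.update(var, lambda v: v.assign_add(1.))
  # runs the assign colocated with `var` on its parameter-server device;
  # per-replica values in `args`/`kwargs` are first collapsed to a single
  # (primary) value by `_select_single_value`.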

  # TODO(yuefengz): does it need to call _select_single_value?
  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    with ops.device(
        colocate_with.device), distribute_lib.UpdateContext(colocate_with):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def value_container(self, val):
    if (hasattr(val, "_aggregating_container") and
        not isinstance(val, ps_values.AggregatingVariable)):
      wrapper = val._aggregating_container()  # pylint: disable=protected-access
      if wrapper is not None:
        return wrapper
    return val

  def read_var(self, var):
    # No need to distinguish between normal variables and replica-local
    # variables.
    return array_ops.identity(var)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Configures the strategy class with `cluster_spec`.

    The strategy object will be re-initialized if `cluster_spec` is passed to
    `configure` but was not passed when instantiating the strategy.

    Args:
      session_config: Session config object.
      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
        cluster configurations.
      task_type: the current task type.
      task_id: the current task id.

    Raises:
      ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
        not.
    """
    if cluster_spec:
      # Use the num_gpus_per_worker recorded in constructor since _configure
      # doesn't take num_gpus.
      cluster_resolver = SimpleClusterResolver(
          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
          task_type=task_type,
          task_id=task_id,
          num_accelerators={"GPU": self._num_gpus_per_worker})
      self._initialize_multi_worker(cluster_resolver)

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    if not self._cluster_spec:
      updated_config.isolate_session_state = True
      return updated_config

    updated_config.isolate_session_state = False

    assert self._task_type
    assert self._task_id is not None

    # The device filters prevent communication between workers.
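    # For illustration only (a hypothetical task): for a worker with task_id
    # 1, the filters below become ["/job:worker/task:1", "/job:ps"], so the
    # session only sees this worker's own devices and the parameter servers,
    # never other workers' devices.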
    del updated_config.device_filters[:]
    if self._task_type in ["chief", "worker"]:
      updated_config.device_filters.extend(
          ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
    elif self._task_type == "evaluator":
      updated_config.device_filters.append(
          "/job:%s/task:%d" % (self._task_type, self._task_id))
    return updated_config

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return self._cluster_spec is not None

  @property
  def _num_replicas_in_sync(self):
    return len(self._compute_devices)

  @property
  def worker_devices(self):
    return self._compute_devices

  @property
  def worker_devices_by_replica(self):
    return [[d] for d in self._compute_devices]

  @property
  def parameter_devices(self):
    return self._parameter_devices

  def non_slot_devices(self, var_list):
    return min(var_list, key=lambda x: x.name)

  @property
  def experimental_between_graph(self):
    # TODO(yuefengz): Should this return False in the local case?
    return True

  @property
  def experimental_should_init(self):
    return self._is_chief

  @property
  def should_checkpoint(self):
    return self._is_chief

  @property
  def should_save_summary(self):
    return self._is_chief

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id