# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tf.distribute.Strategy for running on a single device."""

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import input_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TODO(josh11b): Do we wrap values in types to generate errors if you are
# doing something that won't work with other DistributionStrategy
# implementations?


@tf_export("distribute.OneDeviceStrategy", v1=[])
class OneDeviceStrategy(distribute_lib.Strategy):
  """A distribution strategy for running on a single device.

  Using this strategy will place any variables created in its scope on the
  specified device. Input distributed through this strategy will be
  prefetched to the specified device. Moreover, any functions called via
  `strategy.run` will also be placed on the specified device.

  Typical usage of this strategy could be testing your code with the
  tf.distribute.Strategy API before switching to other strategies which
  actually distribute to multiple devices/machines.

  For example:
  ```
  strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

  with strategy.scope():
    v = tf.Variable(1.0)
    print(v.device)  # /job:localhost/replica:0/task:0/device:GPU:0

  def step_fn(x):
    return x * 2

  result = 0
  for i in range(10):
    result += strategy.run(step_fn, args=(i,))
  print(result)  # 90
  ```
  """

  def __init__(self, device):
    """Creates a `OneDeviceStrategy`.

    Args:
      device: Device string identifier for the device on which the variables
        should be placed. See class docs for more details on how the device is
        used. Examples: "/cpu:0", "/gpu:0", "/device:CPU:0", "/device:GPU:0"
    """
    super(OneDeviceStrategy, self).__init__(OneDeviceExtended(self, device))
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "OneDeviceStrategy")

  def experimental_distribute_dataset(self, dataset, options=None):  # pylint: disable=useless-super-delegation
    """Distributes a tf.data.Dataset instance provided via dataset.

    In this case, there is only one device, so this is only a thin wrapper
    around the input dataset. It will, however, prefetch the input data to the
    specified device. The returned distributed dataset can be iterated over
    similar to how regular datasets can.

    NOTE: Currently, the user cannot add any more transformations to a
    distributed dataset.

    Example:
    ```
    strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
    dataset = tf.data.Dataset.range(10).batch(2)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    for x in dist_dataset:
      print(x)  # [0, 1], [2, 3],...
    ```
    Args:
      dataset: `tf.data.Dataset` to be prefetched to device.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.
    Returns:
      A "distributed `Dataset`" that the caller can iterate over.
    """
    return super(OneDeviceStrategy, self).experimental_distribute_dataset(
        dataset, options)

  def distribute_datasets_from_function(
      self,
      dataset_fn,  # pylint: disable=useless-super-delegation
      options=None):
    """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.

    `dataset_fn` will be called once for each worker in the strategy. In this
    case, we only have one worker and one device so `dataset_fn` is called
    once.

    The `dataset_fn` should take a `tf.distribute.InputContext` instance where
    information about batching and input replication can be accessed:

    ```
    def dataset_fn(input_context):
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
      return d.shard(
          input_context.num_input_pipelines, input_context.input_pipeline_id)

    inputs = strategy.distribute_datasets_from_function(dataset_fn)

    for batch in inputs:
      replica_results = strategy.run(replica_fn, args=(batch,))
    ```

    IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
    per-replica batch size, unlike `experimental_distribute_dataset`, which
    uses the global batch size. This may be computed using
    `input_context.get_per_replica_batch_size`.

    Args:
      dataset_fn: A function taking a `tf.distribute.InputContext` instance and
        returning a `tf.data.Dataset`.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.

    Returns:
      A "distributed `Dataset`", which the caller can iterate over like regular
      datasets.
    """
    return super(OneDeviceStrategy,
                 self).distribute_datasets_from_function(dataset_fn, options)

  def experimental_local_results(self, value):  # pylint: disable=useless-super-delegation
    """Returns the list of all local per-replica values contained in `value`.

    In `OneDeviceStrategy`, the `value` is always expected to be a single
    value, so the result is just the value in a tuple.

    Args:
      value: A value returned by `experimental_run()`, `run()`,
        `extended.call_for_each_replica()`, or a variable created in `scope`.

    Returns:
      A tuple of values contained in `value`. If `value` represents a single
      value, this returns `(value,)`.
    """
    return super(OneDeviceStrategy, self).experimental_local_results(value)

  def run(self, fn, args=(), kwargs=None, options=None):  # pylint: disable=useless-super-delegation
    """Run `fn` on each replica, with the given arguments.

    In `OneDeviceStrategy`, `fn` is simply called within a device scope for the
    given device, with the provided arguments.

    Args:
      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
      args: (Optional) Positional arguments to `fn`.
      kwargs: (Optional) Keyword arguments to `fn`.
      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
        the options to run `fn`.

    Returns:
      Return value from running `fn`.
    """
    return super(OneDeviceStrategy, self).run(fn, args, kwargs, options)
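
  # A minimal usage sketch of how `run` and `experimental_local_results`
  # compose under this strategy (assuming eager execution and a CPU device;
  # the device string and values below are illustrative):
  #
  #   strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
  #   out = strategy.run(lambda x: x * 2, args=(tf.constant(3),))
  #   # With a single device, `out` is the plain return value of the function.
  #   strategy.experimental_local_results(out)  # -> a 1-tuple: (out,)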

  def reduce(self, reduce_op, value, axis):  # pylint: disable=useless-super-delegation
    """Reduce `value` across replicas.

    In `OneDeviceStrategy`, there is only one replica, so if axis=None, value
    is simply returned. If axis is specified as something other than None,
    such as axis=0, value is reduced along that axis and returned.

    Example:
    ```
    t = tf.range(10)

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, t, axis=None).numpy()
    # result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, t, axis=0).numpy()
    # result: 45
    ```

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined.
      value: A "per replica" value, e.g. returned by `run` to
        be combined into a single tensor.
      axis: Specifies the dimension to reduce along within each
        replica's tensor. Should typically be set to the batch dimension, or
        `None` to only reduce across replicas (e.g. if the tensor has no batch
        dimension).

    Returns:
      A `Tensor`.
    """
    return super(OneDeviceStrategy, self).reduce(reduce_op, value, axis)

  def scope(self):  # pylint: disable=useless-super-delegation
    """Returns a context manager selecting this Strategy as current.

    Inside a `with strategy.scope():` code block, this thread
    will use a variable creator set by `strategy`, and will
    enter its "cross-replica context".

    In `OneDeviceStrategy`, all variables created inside `strategy.scope()`
    will be on the `device` specified at strategy construction time.
    See example in the docs for this class.

    Returns:
      A context manager to use for creating variables with this strategy.
    """
    return super(OneDeviceStrategy, self).scope()


@tf_export(v1=["distribute.OneDeviceStrategy"])  # pylint: disable=empty-docstring
class OneDeviceStrategyV1(distribute_lib.StrategyV1):

  __doc__ = OneDeviceStrategy.__doc__.replace(
      "For example:\n  ```",
      "For example:\n  ```\n  tf.enable_eager_execution()")

  def __init__(self, device):
    super(OneDeviceStrategyV1, self).__init__(OneDeviceExtended(self, device))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "OneDeviceStrategy")
  __init__.__doc__ = OneDeviceStrategy.__init__.__doc__


# TODO(josh11b): Switch to V2 after callers have been updated to only V2 APIs.
class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of OneDeviceStrategy."""

  def __init__(self, container_strategy, device):
    super(OneDeviceExtended, self).__init__(container_strategy)
    self._device = device_util.resolve(device)
    self._input_device = device_util.get_host_for_device(self._device)

  def _input_workers_with_options(self, options=None):
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers([(self._input_device, (self._device,))])
    else:
      return input_lib.InputWorkers([(self._input_device,
                                      (self._input_device,))])

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _create_variable(self, next_creator, **kwargs):
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      with ops.device(self._device):
        return next_creator(**kwargs)
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      with ops.colocate_with(colocate_with):
        return next_creator(**kwargs)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    """Make iterator from dataset without splitting the batch."""
    # Note that split_batch_by argument is not passed because it is always 1 in
    # this strategy, and adding it adds unnecessary overhead to the dataset.
    return input_lib_v1.DatasetIterator(dataset, self._input_workers,
                                        self._container_strategy())

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
                                              [distribute_lib.InputContext()],
                                              self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, numpy_dataset.SingleDevice(self._input_device), session)

  def _broadcast_to(self, tensor, destinations):
    del destinations
    return tensor
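
  # A minimal sketch of how `tf.distribute.InputOptions` interacts with the
  # input workers above (assuming a TF version where `InputOptions` exposes
  # `experimental_fetch_to_device`, and a host with a GPU): with the default
  # `experimental_fetch_to_device=True`, batches are prefetched to the
  # strategy's device, while `False` keeps them on the host input device.
  #
  #   strategy = tf.distribute.OneDeviceStrategy("/gpu:0")
  #   options = tf.distribute.InputOptions(experimental_fetch_to_device=False)
  #   dist_dataset = strategy.experimental_distribute_dataset(
  #       tf.data.Dataset.range(8).batch(2), options=options)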

  def _experimental_distribute_dataset(self, dataset, options):
    # Note that split_batch_by argument is not passed because it is always 1 in
    # this strategy, and adding it adds unnecessary overhead to the dataset.
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    return input_util.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        options=options)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    return input_util.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options),
        [distribute_lib.InputContext()],
        self._container_strategy(),
        options=options)

  def _experimental_distribute_values_from_function(self, value_fn):
    # TODO(b/137795644): This should return a PerReplica value but other
    # methods like run in OneDeviceStrategy need to be modified
    # to do the same.
    return value_fn(distribute_lib.ValueContext())

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things,
    # e.g. to create an op which should be evaluated only once at the end of
    # the loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    # TODO(priyag): Use max_iterations instead of an explicit counter.
    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx

  def _call_for_each_replica(self, fn, args, kwargs):
    strategy = self._container_strategy()
    with ops.device(self._device), _OneDeviceReplicaContext(strategy):
      return fn(*args, **kwargs)

  def _reduce_to(self, reduce_op, value, destinations, options):
    del reduce_op, destinations, options
    return value

  def _gather_to_implementation(self, value, destinations, axis, options):
    del destinations, axis, options
    return value

  def _update(self, var, fn, args, kwargs, group):
    # The implementations of _update() and _update_non_slot() are identical
    # except _update() passes `var` as the first argument to `fn()`.
    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    del colocate_with
    with ops.device(self._device), distribute_lib.UpdateContext(self._device):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
    return array_ops.identity(replica_local_var)

  def _local_results(self, value):
    return (value,)

  def value_container(self, value):
    return value

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return False

  @property
  def _num_replicas_in_sync(self):
    return 1

  @property
  def worker_devices(self):
    return (self._device,)

  @property
  def parameter_devices(self):
    return (self._device,)

  def non_slot_devices(self, var_list):
    del var_list
    return (self._device,)

  @property
  def experimental_should_init(self):
    return True

  @property
  def experimental_between_graph(self):
    return False

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """Global and per-replica batching are equivalent for OneDeviceStrategy."""
    return True

  @property
  def _support_per_replica_values(self):
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group


class _OneDeviceReplicaContext(distribute_lib.ReplicaContext):
  """ReplicaContext for OneDeviceStrategy."""

  def __init__(self, strategy):
    distribute_lib.ReplicaContext.__init__(
        self, strategy, replica_id_in_sync_group=0)

  @property
  def devices(self):
    return self._strategy.extended.worker_devices
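

# A minimal sketch of the replica context observed inside `strategy.run` with
# this strategy (assuming eager execution; the id/count returned below reflect
# the single-replica setup established by `_OneDeviceReplicaContext`):
#
#   strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
#
#   def fn(x):
#     ctx = tf.distribute.get_replica_context()
#     # One replica in sync, with replica id 0.
#     return x, ctx.replica_id_in_sync_group, ctx.num_replicas_in_sync
#
#   strategy.run(fn, args=(tf.constant(1.0),))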