# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities used by tf.distribute.Strategy."""

from collections import abc
import contextlib
import threading

from tensorflow.python.distribute import tpu_values as tpu_values_lib
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest


def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
  """Makes a nest of per-replica values into PerReplica/Mirrored values.

  Args:
    values: Values to regroup.
    wrap_class: Class that `values` will be wrapped in.
    always_wrap: Always wrap `values` in `wrap_class`, even if the values
      are the same object (a DistributedVariable is still returned unwrapped).

  Returns:
    Wrapped `values`.
  """
  v0 = values[0]

  if isinstance(v0, list):
    for v in values[1:]:
      assert isinstance(v, list)
      assert len(v) == len(v0), ("Values to regroup had different lengths: "
                                 "len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
                                 (len(v), len(v0), v, v0))
    return [
        regroup(tuple(v[i] for v in values), wrap_class, always_wrap)
        for i in range(len(v0))
    ]

  if isinstance(v0, tuple):
    for v in values[1:]:
      assert isinstance(v, tuple)
      assert len(v) == len(v0), ("Values to regroup had different lengths: "
                                 f"len(v) == {len(v)}, len(v0) == {len(v0)}, "
                                 f"v: {v}, v0: {v0}")
    regrouped_tuple = tuple(
        regroup(tuple(v[i] for v in values), wrap_class, always_wrap)
        for i in range(len(v0)))
    if hasattr(v0, "_fields"):
      # This tuple is in fact a namedtuple! Create a new namedtuple instance
      # and initialize it with the regrouped values:
      assert hasattr(v0, "_make")
      return v0._make(regrouped_tuple)
    else:
      return regrouped_tuple

  if isinstance(v0, abc.Mapping):
    v0keys = v0.keys()
    for v in values[1:]:
      assert isinstance(v, abc.Mapping), ("v[0]: %r v[i]: %r" % (v0, v))
      assert set(v.keys()) == set(v0keys), ("v[0].keys: %s v[i].keys: %s" %
                                            (set(v0keys), set(v.keys())))
    # Use the actual type in case it is a class inherited from a dict.
    return type(v0)({
        key: regroup(tuple(v[key] for v in values),
                     wrap_class, always_wrap)
        for key in v0keys
    })

  # If exactly the same object across all devices, return it unwrapped.
  same_id = True
  for v in values[1:]:
    if v is not v0:
      same_id = False
      break
  # Consider three cases where same_id is true:
  # * If v0 is a DistributedVariable (a MirroredVariable or
  #   SyncOnReadVariable, and same_id means it is the same across all
  #   devices), we want to return it. We check DistributedVariable
  #   specifically since it can look like it has a
  #   _distributed_container member since its members do.
  if same_id and isinstance(v0, values_lib.DistributedVariable):
    return v0
  # * If v0 is a member of a distributed variable, in which case
  #   hasattr(v0, "_distributed_container") is true, we want to
  #   return the DistributedVariable that contains it using the
  #   _distributed_container logic below. This case can trigger
  #   same_id when there is only one device.
  # * In any other situation, same_id means we return v0 unless `always_wrap`
  #   is true.
  if same_id and not always_wrap and not hasattr(v0, "_distributed_container"):
    return v0

  # Detect the case where each device has a parallel component of the
  # same MirroredVariable (or SyncOnReadVariable). In this case we
  # want to return the containing MirroredVariable, after a bunch of
  # sanity checking. In particular, each component should have the
  # same container, and the devices of the variables should match the
  # keys of the per-replica dictionary.
  if hasattr(v0, "_distributed_container"):
    # pylint: disable=protected-access
    assert not isinstance(v0, values_lib.MirroredVariable), (
        "ids = %s, values = %s" % ([id(v) for v in values], values))
    distributed_container = v0._distributed_container()
    assert distributed_container is not None
    for v in values[1:]:
      assert distributed_container is v._distributed_container()
    return distributed_container
  # pylint: enable=protected-access

  return wrap_class(values)
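

# Example (an illustrative sketch, not part of the public API; t0/t1 are
# hypothetical per-replica tensors): `regroup` preserves the nest structure
# and wraps only the leaves, e.g. with two replicas
#
#   regroup(({"a": t0, "b": [t0]}, {"a": t1, "b": [t1]}))
#   # => {"a": PerReplica((t0, t1)), "b": [PerReplica((t0, t1))]}
#
# whereas a leaf that is the same object on every replica is returned
# unwrapped (unless `always_wrap=True`).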


def select_replica(replica_id, structured):
  """Specialize a nest of regular & per-replica values for one replica."""

  def _get(x):
    # A `DistributedValues` is sliced according to the replica id, unless it
    # is a `DistributedVariable`, because a `DistributedVariable` can be
    # handled directly in the replica context.
    if (isinstance(x, values_lib.DistributedVariable) or
        not isinstance(x, values_lib.DistributedValues)):
      return x
    else:
      return x.values[replica_id]

  return nest.map_structure(_get, structured)


def select_replica_mirrored(replica_id, structured):
  """Specialize a nest of regular & mirrored values for one replica."""
  assert_mirrored(structured)
  return select_replica(replica_id, structured)


def assert_mirrored(structured):
  """Raises if `structured` is not composed of mirrored or regular values."""

  def _assert_mirrored(x):
    if isinstance(x, values_lib.DistributedValues) and not is_mirrored(x):
      raise TypeError(
          "Expected value to be mirrored across replicas: %s in %s." %
          (x, structured))

  nest.map_structure(_assert_mirrored, structured)
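

# Example (an illustrative sketch with hypothetical values): given
# `pr = PerReplica((t0, t1))` and a regular tensor `t`,
#
#   select_replica(1, {"x": pr, "y": t})  # => {"x": t1, "y": t}
#
# i.e. per-replica leaves are sliced to the requested replica and regular
# values pass through unchanged.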


def update_regroup(extended, updates, group):
  """Regroup for an update, with dependencies to ensure all updates execute."""
  if not group:
    regrouped = regroup(updates, values_lib.Mirrored)
    return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access

  def _make_grouped_mirrored(values):
    """Convert per-replica list `values` into Mirrored type with grouping."""
    if len(values) == 1:
      return values_lib.Mirrored(values)

    # Make sure we run all updates. Without this, something like
    # session.run(extended.update(...)) may only update one replica.
    g = control_flow_ops.group(values)

    # If values is just ops, the grouping is enough. Everything in values
    # should have the same type, since we expect every replica to be
    # performing the same computation.
    if not all(tensor_util.is_tf_type(v) for v in values):
      return g

    # Otherwise we need tensors with the same values as `values`, but
    # that have a dependency on `g`.
    with_dep = []
    for v in values:
      with ops.device(v.device), ops.control_dependencies([g]):
        with_dep.append(array_ops.identity(v))

    return values_lib.Mirrored(with_dep)

  return regroup(updates, _make_grouped_mirrored)


def value_container(val):
  """Returns the container that this per-replica `val` belongs to.

  Args:
    val: A value returned by `call_for_each_replica()` or a variable created
      in `scope()`.

  Returns:
    A container that `val` belongs to.
    If `val` does not belong to any container (including the case of the
    container having been destroyed), returns the value itself.
  """
  if (hasattr(val, "_distributed_container") and
      # DistributedVariable has _distributed_container defined
      # but we don't want to return it.
      not isinstance(val, values_lib.DistributedVariable)):
    container = val._distributed_container()  # pylint: disable=protected-access
    if container is not None:
      return container
  return val


def is_distributed_variable(v):
  """Determine if a variable is a distributed or TPU mirrored variable."""
  return getattr(v, "is_distributed_variable", False)


def is_distributed_table(v):
  """Determine if an object is a DistributedTable."""
  return getattr(v, "is_distributed_table", False)


def _validate_colocate_extended(v, extended):
  variable_strategy = v._distribute_strategy  # pylint: disable=protected-access
  if variable_strategy.extended is not extended:
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
        (v, variable_strategy))


def validate_colocate_distributed_variable(v, extended):
  if not isinstance(v, values_lib.DistributedVariable):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)


def validate_colocate(v, extended):
  if not hasattr(v, "_distribute_strategy"):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)
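

# Example (an illustrative sketch with hypothetical update tensors t0/t1):
#
#   update_regroup(extended, (t0, t1), group=True)
#
# returns a Mirrored value whose components each carry a control dependency
# on group((t0, t1)), so evaluating any single component still runs every
# replica's update; with group=False the updates are regrouped and unwrapped
# into per-device local results instead.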


# Variable creation function for sync strategies.
def _validate_synchronization(kwargs):
  """Validate that the given synchronization value is valid."""
  synchronization = kwargs.get("synchronization",
                               vs.VariableSynchronization.AUTO)
  if synchronization == vs.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not supported with "
        "tf.distribute strategy. Please change the `synchronization` for "
        "variable: " + str(kwargs["name"]))
  if synchronization not in (vs.VariableSynchronization.ON_READ,
                             vs.VariableSynchronization.ON_WRITE,
                             vs.VariableSynchronization.AUTO):
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, kwargs["name"]))
  if synchronization == vs.VariableSynchronization.AUTO:
    return vs.VariableSynchronization.ON_WRITE
  return synchronization


def _validate_aggregation(kwargs):
  aggregation = kwargs.get("aggregation", vs.VariableAggregation.NONE)

  if aggregation not in (vs.VariableAggregation.NONE,
                         vs.VariableAggregation.SUM,
                         vs.VariableAggregation.MEAN,
                         vs.VariableAggregation.ONLY_FIRST_REPLICA):
    raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
                     (aggregation, kwargs["name"]))
  return aggregation


def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
  """Create distributed variables with given synchronization and aggregation."""
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  synchronization = _validate_synchronization(kwargs)
  # Update synchronization in kwargs in case it's AUTO, which is converted to
  # ON_WRITE.
  kwargs["synchronization"] = synchronization
  aggregation = _validate_aggregation(kwargs)
  use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    value_list = real_mirrored_creator(**kwargs)
    # MirroredVariable is recreated during saved_model loading, and its
    # component variables (value_list) will have None initializer. We
    # set their initializers to no_op so that consumers like
    # `global_variables_initializer` won't complain, as it groups all
    # variables' initializers and thus requires all variables to have
    # initializers.
    for v in value_list:
      # pylint:disable=protected-access
      if hasattr(v, "_initializer_op") and v._initializer_op is None:
        v._initializer_op = control_flow_ops.no_op()
      # pylint:enable=protected-access
    if use_var_policy:
      var_policy_cls = policy_mapping.get(synchronization)
      var_policy = var_policy_cls(aggregation=aggregation)
      var_cls = class_mapping.get("VariableClass")
      result = var_cls(strategy, value_list, aggregation, var_policy=var_policy)
    else:
      var_cls = class_mapping.get(synchronization)
      result = var_cls(strategy, value_list, aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        for i, trainable_variable in enumerate(l):
          if value is trainable_variable:
            del l[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
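

# Example: `_validate_synchronization` resolves AUTO to ON_WRITE, so
#
#   kwargs = {"name": "v",
#             "synchronization": vs.VariableSynchronization.AUTO}
#   _validate_synchronization(kwargs)  # => VariableSynchronization.ON_WRITE
#
# while synchronization=NONE raises a ValueError, since tf.distribute does
# not support unsynchronized variables.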


# Utility functions.
# Return True if the value is Mirrored, or the variable is replicated and
# kept in sync.
def is_mirrored(val):
  if isinstance(val, values_lib.DistributedVariable):
    if val._policy:  # pylint: disable=protected-access
      return val._policy._is_mirrored()  # pylint: disable=protected-access
  return isinstance(val, values_lib.Mirrored)


def is_sync_on_read(val):
  if isinstance(val, values_lib.DistributedVariable):
    if val._policy:  # pylint: disable=protected-access
      return not val._policy._is_mirrored()  # pylint: disable=protected-access
  return not isinstance(val, values_lib.Mirrored)


class CachingScopeLocal(threading.local):
  """Class for maintaining thread local state for caching scope."""

  def __init__(self):
    super(CachingScopeLocal, self).__init__()
    self.new_cache_scope_count = 0
    self.cache_scope_exited_count = 0

  def enter_scope(self):
    self.new_cache_scope_count += 1

  def exit_scope(self):
    self.cache_scope_exited_count += 1

  def in_caching_scope(self):
    return self.new_cache_scope_count > self.cache_scope_exited_count


caching_scope_local = CachingScopeLocal()


@contextlib.contextmanager
def cache_variable_reads():
  """Scope for caching variable reads for AggregatingVariable.

  The variable reads for AggregatingVariable inside this scope are cached,
  i.e. the first read fetches the value from the (possibly remote) handle,
  but subsequent reads return the locally cached value.

  For example:
    strategy = ParameterServerStrategy...
    with strategy.scope():
      # Variable v is of AggregatingVariable type with the actual variable
      # residing on the PS.
      v = tf.Variable(1.0)

    with distribute_utils.cache_variable_reads():
      v.read_value()  # Reads value 1.0
      v.assign(constant_op.constant(5.0))  # v changes to 5.0
      t1 = v.read_value()
      t2 = v.read_value()  # Both t1 & t2 return cached value 1.0 from local CPU.

  Notes about the cache_variable_reads scope:
  1. Nesting of cache_variable_reads() scopes is not supported.
  2. When the caching scope is enabled, both the thread enabling the cache and
     the mirrored_run._MirroredReplicaThread threads spawned from it have
     caching enabled.

  Yields:
    A context for caching variables.
  """

  if caching_scope_local.in_caching_scope():
    # Nested caching scopes are not supported. Raise before entering the
    # scope so the exit counter stays balanced for the outer scope.
    raise ValueError("cache_variable_reads scope cannot be nested")
  caching_scope_local.enter_scope()
  try:
    yield
  finally:
    caching_scope_local.exit_scope()
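

# Example (an illustrative sketch): the thread-local counters make nesting
# detection cheap:
#
#   caching_scope_local.in_caching_scope()  # => False
#   with cache_variable_reads():
#     caching_scope_local.in_caching_scope()  # => True; entering another
#     # cache_variable_reads() here would raise ValueError.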


# The following mappings indicate the policy that must be used for a given
# variable `synchronization` and `aggregation` pair.
# OnWritePolicy is used for:
#   (synchronization=AUTO, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
#   (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# OnReadPolicy is used for:
#   (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
    vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
}

VARIABLE_CLASS_MAPPING = {
    "VariableClass": values_lib.DistributedVariable,
    vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable,
    vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
}

TPU_VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUOnWritePolicy,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUOnReadPolicy,
}

TPU_VARIABLE_CLASS_MAPPING = {
    "VariableClass": tpu_values_lib.TPUDistributedVariable,
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUMirroredVariable,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUSyncOnReadVariable,
}
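

# Example (an illustrative sketch; the creator and kwargs are hypothetical):
# strategies pass these mappings to `create_mirrored_variable`, e.g.
#
#   create_mirrored_variable(strategy, real_mirrored_creator,
#                            VARIABLE_CLASS_MAPPING, VARIABLE_POLICY_MAPPING,
#                            name="v", synchronization=..., aggregation=...)
#
# which resolves ON_WRITE to MirroredVariable (or OnWritePolicy when variable
# policies are in use) and ON_READ to SyncOnReadVariable (or OnReadPolicy).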