# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions used by values.py and ps_values.py."""

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.training.saving import saveable_object


def write_object_proto(var, proto, options):
  """Update a SavedObject proto for the caller.

  If a DistributedVariable object supports this method, it will be called when
  saving with a pre-built `SavedObject` proto representing the object, plus an
  instance of `SaveOptions`. This method is then free to modify that proto
  instance.

  A `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally
  writes out information about its components to the
  `experimental_distributed_variable_components` field of a
  `SavedVariable` (depending on the `SaveOptions` variable policy).

  Args:
    var: The DistributedVariable object.
    proto: A pre-built `SavedObject` proto for this object. It is assumed this
      will be a `SavedVariable` instance.
    options: A `SaveOptions` instance.
  """
  if options.experimental_variable_policy._expand_distributed_variables(  # pylint: disable=protected-access
  ):
    for component in var.values:
      var_proto = (
          proto.variable.experimental_distributed_variable_components.add())
      var_proto.name = component.name.split(":")[0]
      var_proto.device = component.device
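

# A minimal sketch (hypothetical names, not part of this module) of how
# `write_object_proto` is driven by the saving code, assuming a
# `DistributedVariable` `dv` and a pre-built `SavedObject` proto for it:
#
#   opts = save_options.SaveOptions(
#       experimental_variable_policy=save_options.VariablePolicy
#       .EXPAND_DISTRIBUTED_VARIABLES)
#   write_object_proto(dv, saved_object_proto, opts)
#
# Under this policy, each component's name (without the ":0" output suffix)
# and device are appended to
# `proto.variable.experimental_distributed_variable_components`.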


def get_on_write_saveable(var, primary_var, name):
  """Return saveable spec for AUTO and ON_WRITE variables."""
  # We use a callable so that we don't have to evaluate this expression
  # in the case where we are trying to restore instead of save.
  def tensor():
    if context.executing_eagerly() and not primary_var.is_initialized():
      # A SaveSpec tensor value of `None` indicates that the variable is
      # uninitialized.
      return None
    strategy = var.distribute_strategy
    return strategy.extended.read_var(var)

  spec = saveable_object.SaveSpec(
      tensor=tensor,
      slice_spec="",
      name=name,
      dtype=var.dtype,
      device=primary_var.device)

  return tensor, [spec]


def get_on_write_restore_ops(var, tensor):
  """Return restore ops for AUTO and ON_WRITE variables."""
  packed_var = var._packed_variable  # pylint: disable=protected-access
  if packed_var is not None:
    return control_flow_ops.group(
        tuple(
            assign_on_device(d, packed_var, tensor)
            for d in packed_var.devices))
  return control_flow_ops.group(
      tuple(
          assign_on_device(v.device, v, tensor)
          for v in var.values))


def get_on_read_saveable(var, primary_var, name):
  """Return saveables for ON_READ variables."""

  # We use a callable so that we don't have to evaluate this expression
  # in the case where we are trying to restore instead of save.
  def tensor():
    return var._get_cross_replica()  # pylint: disable=protected-access

  spec = saveable_object.SaveSpec(
      tensor=tensor,
      slice_spec="",
      name=name,
      dtype=var.dtype,
      device=primary_var.device)

  return tensor, [spec]


def get_on_read_restore_ops(var, tensor, aggregation):
  """Return restore ops for ON_READ variables."""
  # To preserve the sum across save and restore, we have to divide the
  # total across all devices when restoring a variable that was summed
  # when saving.
  if aggregation == vs.VariableAggregation.SUM:
    strategy = var.distribute_strategy
    tensor = math_ops.cast(tensor / strategy.num_replicas_in_sync,
                           var.dtype)
  return control_flow_ops.group(
      tuple(
          assign_on_device(v.device, v, tensor)
          for v in var.values))
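

# Worked example of the SUM round trip above (numbers are illustrative): with
# `num_replicas_in_sync == 4` and each replica holding 2.0,
# `get_on_read_saveable` saves the cross-replica total 8.0, and
# `get_on_read_restore_ops` assigns 8.0 / 4 = 2.0 back to every replica, so a
# later cross-replica read again sums to 8.0.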


# Utility function that indicates if you are in an UpdateContext when running
# in a replica fn.
def in_replica_update_context():
  return distribute_lib.get_update_replica_id() is not None


def on_write_assign(var, value, use_locking=False, name=None, read_value=True):
  assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def on_write_assign_add(var, value, use_locking=False, name=None,
                        read_value=True):
  assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_add_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def on_write_assign_sub(var, value, use_locking=False, name=None,
                        read_value=True):
  assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_sub_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def assign_on_each_device(var, assign_func, value, read_value):
  """Update the variable on each replica with the given assign_func and value."""
  if var._packed_variable is not None:  # pylint: disable=protected-access
    update = control_flow_ops.group(
        tuple(
            assign_func(d, var._packed_variable, value) for d in var._devices))  # pylint: disable=protected-access
  else:
    update = control_flow_ops.group(
        tuple(assign_func(v.device, v, value) for v in var._values))  # pylint: disable=protected-access
  if not read_value:
    return update
  with ops.control_dependencies([update] if update else []):
    return var.read_value()


def on_read_assign_sub_cross_replica(var, value, read_value=True):
  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
    if ds_context.in_cross_replica_context():
      if var.aggregation == vs.VariableAggregation.SUM:
        raise ValueError(
            "SyncOnReadVariable does not support `assign_sub` in "
            "cross-replica context when aggregation is set to "
            "`tf.VariableAggregation.SUM`.")
      return assign_on_each_device(var, assign_sub_on_device,
                                   value, read_value)


def on_read_assign_add_cross_replica(var, value, read_value=True):
  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
    if ds_context.in_cross_replica_context():
      if var.aggregation == vs.VariableAggregation.SUM:
        raise ValueError(
            "SyncOnReadVariable does not support `assign_add` in "
            "cross-replica context when aggregation is set to "
            "`tf.VariableAggregation.SUM`.")
      return assign_on_each_device(var, assign_add_on_device,
                                   value, read_value)
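

# A hedged sketch of the ON_READ cross-replica assign helpers, assuming a
# hypothetical SyncOnReadVariable `sv` with MEAN aggregation under a strategy
# with two replicas (under `strategy.scope()` we are in cross-replica
# context):
#
#   with strategy.scope():
#     on_read_assign_add_cross_replica(sv, 1.0)  # adds 1.0 on every replica
#
# With SUM aggregation the same call raises a ValueError, presumably because
# a cross-replica delta is ambiguous: applying it per replica would inflate
# the aggregated total by `num_replicas_in_sync` times the delta.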


def on_read_assign_cross_replica(var, value, read_value=True):
  """Assign `value` to the variable in cross-replica context."""
  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
    if ds_context.in_cross_replica_context():
      # To preserve the sum across save and restore, we have to divide the
      # total across all devices when restoring a variable that was summed
      # when saving.
      tensor = value
      if var.aggregation == vs.VariableAggregation.SUM:
        strategy = var._distribute_strategy  # pylint: disable=protected-access
        tensor = math_ops.cast(tensor / strategy.num_replicas_in_sync,
                               var.dtype)
      return assign_on_each_device(var, assign_on_device, tensor,
                                   read_value)


def scatter_sub(var, sparse_delta, use_locking=False, name=None):
  scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_sub_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_add(var, sparse_delta, use_locking=False, name=None):
  scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_add_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_mul(var, sparse_delta, use_locking=False, name=None):
  scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_mul_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_div(var, sparse_delta, use_locking=False, name=None):
  scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_div_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_min(var, sparse_delta, use_locking=False, name=None):
  scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_min_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_max(var, sparse_delta, use_locking=False, name=None):
  scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_max_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_update(var, sparse_delta, use_locking=False, name=None):
  scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_update_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
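

# Each scatter wrapper above forwards a `tf.IndexedSlices` delta to the
# per-component scatter method via `var._update`. A minimal sketch, assuming
# a hypothetical mirrored variable `mv` of shape [3]:
#
#   delta = tf.IndexedSlices(
#       values=tf.constant([1.0]), indices=tf.constant([0]),
#       dense_shape=tf.constant([3]))
#   scatter_add(mv, delta)  # adds 1.0 to element 0 on every replica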


def get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or `None`."""
  replica_context = ds_context.get_replica_context()
  if replica_context:
    replica_id = replica_context._replica_id  # pylint: disable=protected-access
    if not isinstance(replica_id, int):
      replica_id = tensor_util.constant_value(replica_id)
  else:
    replica_id = distribute_lib.get_update_replica_id()
  return replica_id


def assign_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign(tensor)


def assign_add_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign_add(tensor)


def assign_sub_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign_sub(tensor)


def assert_replica_context(strategy):
  replica_context = ds_context.get_replica_context()
  if not replica_context:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")
  if replica_context.strategy is not strategy:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")


def apply_aggregation(strategy, value, aggregation, destinations):
  if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
    return strategy.extended.broadcast_to(
        strategy.experimental_local_results(value)[0],
        destinations=destinations)
  reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
  return strategy.extended.reduce_to(reduce_op, value, destinations)


aggregation_error_msg = (
    "You must specify an aggregation method to update a "
    "{variable_type} in replica context. You can do so by passing "
    "an explicit value for the `aggregation` argument of tf.Variable(..), "
    "e.g. `tf.Variable(..., aggregation=tf.VariableAggregation.SUM)`. "
    "`tf.VariableAggregation` lists the possible aggregation methods. "
    "This is required because a {variable_type} should always be "
    "kept in sync. When updating it or assigning to it in a "
    "replica context, we automatically try to aggregate the values "
    "before updating the variable. For this aggregation, we need to "
    "know the aggregation method. "
    "An alternative is to not update the {variable_type} in replica "
    "context, but in cross-replica context. You can enter cross-replica "
    "context by calling "
    "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`. "
    "Inside `merge_fn`, you can then update the {variable_type} "
    "using `tf.distribute.StrategyExtended.update()`.")


scatter_error_msg = ("{op_name} is only supported for mirrored "
                     "variables (variables created within certain "
                     "`tf.distribute.Strategy` scopes) with `NONE` or "
                     "`ONLY_FIRST_REPLICA` aggregation, got: {aggregation}.")


def is_saving_non_distributed():
  """Returns whether we're saving a non-distributed version of the model.

  It returns True iff we are in a saving context and are saving a
  non-distributed version of the model, i.e.
  `SaveOptions.experimental_variable_policy` is not
  `EXPAND_DISTRIBUTED_VARIABLES`.

  Returns:
    A boolean.
  """
  if not save_context.in_save_context():
    return False
  options = save_context.get_save_options()
  return (options.experimental_variable_policy !=
          save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)


def mark_as_unsaveable():
  """Marks the function as unsaveable if not inside save context."""
  if ops.inside_function() and not save_context.in_save_context():
    ops.get_default_graph().mark_as_unsaveable("""
A ConcreteFunction that uses distributed variables in certain ways cannot be
saved. If you're saving with

tf.saved_model.save(..., signatures=f.get_concrete_function())

do

@tf.function(input_signature=...)
def f_with_input_signature():
  ...

tf.saved_model.save(..., signatures=f_with_input_signature)

instead.""")
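

# A minimal sketch (hypothetical class, not part of this module) of how
# `is_saving_non_distributed` is intended to be used by distributed-variable
# code: short-circuit to the primary component so the saved model looks like
# a plain, non-distributed variable.
#
#   class DistributedVariable(...):
#
#     def value(self):
#       if is_saving_non_distributed():
#         return self._primary.value()
#       ...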