# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains LossScale classes."""
import abc

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.trackable import base as trackable
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@deprecation.deprecated_endpoints('mixed_precision.experimental.LossScale',
                                  'train.experimental.LossScale')
@tf_export(
    v1=[
        'mixed_precision.LossScale',
        'mixed_precision.experimental.LossScale',
        'train.experimental.LossScale'
    ])
class LossScale(trackable.Trackable, metaclass=abc.ABCMeta):
  """Base class for all TF1 loss scales.

  This is an abstract base class, so you cannot instantiate it directly.
  Instead, use one of its concrete subclasses:
    * `tf.compat.v1.mixed_precision.DynamicLossScale`
    * `tf.compat.v1.mixed_precision.FixedLossScale`

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but it can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. By multiplying the loss, each intermediate gradient will
  have the same multiplier applied.

  Instances of this class represent a loss scale. Calling an instance of this
  class returns the loss scale as a scalar float32 tensor, while the method
  `update()` updates the loss scale depending on the values of the gradients.
  Optimizers use instances of this class to scale the loss and gradients.

  In most functions that accept a LossScale, you can also pass an int (such as
  8) to create a `FixedLossScale` or the string `"dynamic"` to create a dynamic
  loss scale.
  """

  def __init__(self):
    """Initializes the loss scale class."""
    self._weights = {}

  @abc.abstractmethod
  def __call__(self):
    """Returns the current loss scale as a scalar `float32` tensor."""
    pass

  @abc.abstractmethod
  def update(self, grads):
    """Updates the value of the loss scale.

    The loss scale will potentially be updated, based on the value of `grads`.
    The tensor returned by calling this class is only updated when this
    function is evaluated.

    In eager mode, this directly updates the loss scale, so that calling
    `__call__` will return the newly updated loss scale. In graph mode,
    this returns an op that, when evaluated, updates the loss scale.

    This function also returns a `should_apply_gradients` bool. If False,
    gradients should not be applied to the variables that step, as nonfinite
    gradients were found, and the loss scale has been updated to reduce the
    chance of finding nonfinite gradients in the next step. Some loss scale
    classes will always return True, as they cannot adjust themselves in
    response to nonfinite gradients.

    When a DistributionStrategy is used, this function may only be called in a
    cross-replica context.

    Args:
      grads: A nested structure of unscaled gradients, each of which is the
        gradient of the loss with respect to a weight. The gradients should
        have already been divided by the loss scale before being passed to
        this function. `None` gradients are accepted, and are ignored.

    Returns:
      update_op: In eager mode, None. In graph mode, an op to update the loss
        scale.
      should_apply_gradients: Either a bool or a scalar boolean tensor. If
        False, the caller should skip applying `grads` to the variables this
        step.
    """
    pass
  def _add_weight(self, name, initial_value, dtype=None):
    """Adds a weight to this loss scale.

    Args:
      name: Variable name.
      initial_value: The variable's initial value.
      dtype: The type of the variable.

    Returns:
      A variable.

    Raises:
      RuntimeError: If a weight with `name` has already been added.
    """
    variable = variable_scope.variable(
        initial_value=initial_value,
        name=name,
        dtype=dtype,
        trainable=False,
        use_resource=True,
        synchronization=variables.VariableSynchronization.AUTO,
        # Set aggregation to NONE, as loss scaling variables should never be
        # aggregated.
        aggregation=variables.VariableAggregation.NONE)
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access

    key = (name, graph_key)
    if self._weights.get(key, None) is not None:
      raise RuntimeError('Duplicate variables detected. {}'.format(key))
    self._weights[key] = variable
    self._handle_deferred_dependencies(name=name, trackable=variable)
    return variable

  def _trackable_children(self,
                          save_type=trackable.SaveType.CHECKPOINT,
                          **kwargs):
    """From Trackable. Gather graph-specific weights to save."""
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access
    weights = {}
    for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]):
      if g == graph_key:
        weights[name] = v
    weights.update(
        super(LossScale, self)._trackable_children(save_type, **kwargs))
    return weights

  def _lookup_dependency(self, name):
    """From Trackable. Find a weight in the current graph."""
    unconditional = super(LossScale, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access
    return self._weights.get((name, graph_key), None)

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of this loss scale."""
    pass

  @classmethod
  def from_config(cls, config):
    """Creates the LossScale from its config."""
    return cls(**config)
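

# A minimal sketch (not part of the original module) of the protocol described
# in the `LossScale` class and `update()` docstrings: scale the loss, compute
# gradients, unscale them, then call `update()` and only apply the gradients
# when `should_apply` allows it. It assumes a dynamic loss scale (so
# `should_apply` is a boolean tensor) and hypothetical `compute_loss`,
# `optimizer`, and `var_list` objects supplied by the caller.
#
#   loss_scale = DynamicLossScale()
#   scaled_loss = compute_loss() * loss_scale()
#   grads = tf.gradients(scaled_loss, var_list)
#   unscaled_grads = [g / loss_scale() for g in grads]
#   update_op, should_apply = loss_scale.update(unscaled_grads)
#   maybe_apply_op = tf.cond(
#       should_apply,
#       lambda: optimizer.apply_gradients(list(zip(unscaled_grads, var_list))),
#       tf.no_op)
#   train_op = tf.group(update_op, maybe_apply_op)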


@deprecation.deprecated_endpoints('mixed_precision.experimental.FixedLossScale',
                                  'train.experimental.FixedLossScale')
@tf_export(
    v1=[
        'mixed_precision.FixedLossScale',
        'mixed_precision.experimental.FixedLossScale',
        'train.experimental.FixedLossScale'
    ])
class FixedLossScale(LossScale):
  """Loss scale with a fixed value.

  The loss scale is not updated for the lifetime of instances of this class.
  A given instance of this class always returns the same number when called.
  """

  @deprecation.deprecated(
      None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
      'LossScaleOptimizer now has all the functionality of '
      'FixedLossScale')
  def __init__(self, loss_scale_value):
    """Creates the fixed loss scale.

    Args:
      loss_scale_value: A Python float. Its ideal value varies depending on the
        model. A loss scale that is too small might degrade model quality,
        while one that is too large might cause gradients to overflow to inf
        or nan. There is no single right value to apply, and there is no harm
        in choosing a relatively large number as long as no inf or nan is
        encountered during training.

    Raises:
      ValueError: If loss_scale_value is less than 1.
    """
    super(FixedLossScale, self).__init__()
    if not isinstance(loss_scale_value, (int, float)):
      raise ValueError('loss_scale_value must be a Python int or float.')
    if loss_scale_value < 1:
      raise ValueError('loss_scale_value must be at least 1.')
    # It's important we do not create tensors in the constructor, as such
    # tensors might be created on a different device or in a different
    # tf.function than the one where the tensor is used, which would hurt
    # performance. Therefore, we do not create a tensor from loss_scale_value,
    # but instead leave it as a Python float.
    # TODO(reedwm): Also do not create tensors in the DynamicLossScale
    # constructor.
    self._loss_scale_value = float(loss_scale_value)

  def __call__(self):
    return ops.convert_to_tensor(self._loss_scale_value)

  def update(self, grads):
    del grads
    return control_flow_ops.no_op(), True

  def __repr__(self):
    return 'FixedLossScale(%s)' % self._loss_scale_value

  def get_config(self):
    return {'loss_scale_value': self._loss_scale_value}
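

# A minimal sketch (not part of the original module) of `FixedLossScale`
# behavior: calling the instance always returns the same value, and `update()`
# is a no-op that always allows gradients to be applied. `unscaled_grads` is a
# hypothetical list of gradient tensors.
#
#   loss_scale = FixedLossScale(128.)
#   loss_scale()                                   # scalar float32 tensor, 128.0
#   update_op, should_apply = loss_scale.update(unscaled_grads)
#   should_apply                                   # always the Python bool True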


def _is_all_finite(grads):
  """Returns a scalar boolean tensor indicating if all gradients are finite."""
  is_finite_per_grad = [
      math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None
  ]
  return math_ops.reduce_all(is_finite_per_grad)


def _op_in_graph_mode(tensor):
  """Returns the tensor's op in graph mode, or the tensor in eager mode.

  This is useful because sometimes an op is needed in graph mode instead of a
  tensor. In eager mode, there are no ops.

  Args:
    tensor: A tensor.

  Returns:
    The tensor's op in graph mode. The tensor in eager mode.
  """
  if context.executing_eagerly():
    return tensor
  return tensor.op


def _assign_if_finite(var, value):
  """Assigns a value to a variable if the value is finite."""
  return control_flow_ops.cond(
      math_ops.is_finite(value), lambda: _op_in_graph_mode(var.assign(value)),
      control_flow_ops.no_op)


@deprecation.deprecated_endpoints(
    'mixed_precision.experimental.DynamicLossScale',
    'train.experimental.DynamicLossScale')
@tf_export(
    v1=[
        'mixed_precision.DynamicLossScale',
        'mixed_precision.experimental.DynamicLossScale',
        'train.experimental.DynamicLossScale'
    ])
class DynamicLossScale(LossScale):
  """Loss scale that dynamically adjusts itself.

  Dynamic loss scaling works by adjusting the loss scale as training
  progresses. The goal is to keep the loss scale as high as possible without
  overflowing the gradients. As long as the gradients do not overflow, raising
  the loss scale never hurts.

  The algorithm starts by setting the loss scale to an initial value. Every N
  steps that the gradients are finite, the loss scale is increased by some
  factor. However, if a NaN or Inf gradient is found, the gradients for that
  step are not applied, and the loss scale is decreased by the factor. This
  process tends to keep the loss scale as high as possible without gradients
  overflowing.
  """

  @deprecation.deprecated(
      None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
      'LossScaleOptimizer now has all the functionality of '
      'DynamicLossScale')
  def __init__(self,
               initial_loss_scale=2 ** 15,  # See docstring for why this is big.
               increment_period=2000,
               multiplier=2.):
    """Creates the dynamic loss scale.

    Args:
      initial_loss_scale: A Python float. The loss scale to use at the
        beginning. It's better to start this at a very high number, because a
        loss scale that is too high gets lowered far more quickly than a loss
        scale that is too low gets raised. The default is 2 ** 15, which is
        approximately half the maximum float16 value.
      increment_period: Increases the loss scale every `increment_period`
        consecutive steps that finite gradients are encountered. If a nonfinite
        gradient is encountered, the count is reset back to zero.
      multiplier: The multiplier to use when increasing or decreasing the loss
        scale.
    """
    super(DynamicLossScale, self).__init__()
    self._initial_loss_scale = float(initial_loss_scale)
    self._increment_period = int(increment_period)
    self._multiplier = float(multiplier)

    self._current_loss_scale = self._add_weight(
        name='current_loss_scale',
        dtype=dtypes.float32,
        initial_value=self._initial_loss_scale)
    # The number of consecutive steps with finite gradients since the last
    # nonfinite gradient or change in loss scale.
    self._num_good_steps = self._add_weight(
        name='good_steps', dtype=dtypes.int64, initial_value=0)

  @property
  def initial_loss_scale(self):
    return self._initial_loss_scale

  @property
  def increment_period(self):
    return self._increment_period

  @property
  def multiplier(self):
    return self._multiplier

  def __call__(self):
    return ops.convert_to_tensor(self._current_loss_scale)

  def update(self, grads):
    """Updates the loss scale based on whether gradients are finite."""
    grads = nest.flatten(grads)
    if distribution_strategy_context.has_strategy():
      distribution = distribution_strategy_context.get_cross_replica_context()

      def get_is_finite(grads):
        is_finite = _is_all_finite(grads)
        # We cast to float, because we cannot reduce booleans with
        # DistributionStrategy.
        return math_ops.cast(is_finite, dtypes.float32)

      is_finite_float = distribution.extended.call_for_each_replica(
          get_is_finite, args=(grads,))
      reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM,
                                                    is_finite_float, axis=None)
      is_finite = math_ops.equal(reduced_is_finite_float,
                                 distribution.num_replicas_in_sync)
    else:
      is_finite = _is_all_finite(grads)

    def update_if_finite_grads():
      """Update assuming the gradients are finite."""

      def incr_loss_scale():
        new_loss_scale = self._current_loss_scale * self._multiplier
        return control_flow_ops.group(
            _assign_if_finite(self._current_loss_scale, new_loss_scale),
            self._num_good_steps.assign(0))

      return control_flow_ops.cond(
          self._num_good_steps + 1 >= self._increment_period,
          incr_loss_scale, lambda: _op_in_graph_mode(
              self._num_good_steps.assign_add(1)))

    def update_if_not_finite_grads():
      """Update assuming the gradients are nonfinite."""

      new_loss_scale = math_ops.maximum(
          self._current_loss_scale / self._multiplier, 1)
      return control_flow_ops.group(
          self._num_good_steps.assign(0),
          self._current_loss_scale.assign(new_loss_scale))

    update_op = control_flow_ops.cond(is_finite, update_if_finite_grads,
                                      update_if_not_finite_grads)
    should_apply_gradients = is_finite
    return update_op, should_apply_gradients

  def __repr__(self):
    if context.executing_eagerly():
      return ('DynamicLossScale(current_loss_scale=%s, num_good_steps=%s, '
              'initial_loss_scale=%s, increment_period=%s, multiplier=%s)' %
              (self._current_loss_scale.numpy(), self._num_good_steps.numpy(),
               self.initial_loss_scale, self.increment_period, self.multiplier))
    else:
      return ('DynamicLossScale(initial_loss_scale=%s, increment_period=%s, '
              'multiplier=%s)' %
              (self.initial_loss_scale, self.increment_period, self.multiplier))

  def get_config(self):
    return {
        'initial_loss_scale': self.initial_loss_scale,
        'increment_period': self.increment_period,
        'multiplier': self.multiplier,
    }
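

# A minimal sketch (not part of the original module) of how `DynamicLossScale`
# adjusts itself in eager mode: the scale is multiplied by `multiplier` after
# `increment_period` consecutive steps with finite gradients, and divided by
# `multiplier` (but never below 1) as soon as a nonfinite gradient is seen.
#
#   loss_scale = DynamicLossScale(
#       initial_loss_scale=4., increment_period=2, multiplier=2.)
#   loss_scale.update([ops.convert_to_tensor(1.)])    # 1st good step, scale is 4
#   loss_scale.update([ops.convert_to_tensor(1.)])    # 2nd good step, scale is 8
#   loss_scale.update([ops.convert_to_tensor(float('inf'))])  # scale back to 4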


def get(identifier):
  """Get a loss scale object."""
  if isinstance(identifier, (int, float)):
    return FixedLossScale(identifier)
  if identifier == 'dynamic':
    return DynamicLossScale()
  if isinstance(identifier, LossScale):
    return identifier
  elif identifier is None:
    return None
  else:
    raise ValueError('Could not interpret loss scale identifier: %s' %
                     identifier)
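

# A minimal sketch (not part of the original module) of the identifier forms
# accepted by `get()`:
#
#   get(128)                        # returns FixedLossScale(128.0)
#   get('dynamic')                  # returns DynamicLossScale()
#   get(None)                       # returns None
#   loss_scale = DynamicLossScale()
#   get(loss_scale) is loss_scale   # existing LossScale instances pass through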