# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains LossScale classes."""
import abc

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.trackable import base as trackable
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@deprecation.deprecated_endpoints('mixed_precision.experimental.LossScale',
                                  'train.experimental.LossScale')
@tf_export(
    v1=[
        'mixed_precision.LossScale',
        'mixed_precision.experimental.LossScale',
        'train.experimental.LossScale'
    ])
class LossScale(trackable.Trackable, metaclass=abc.ABCMeta):
  """Base class for all TF1 loss scales.

  This is an abstract base class, so you cannot instantiate it directly.
  Instead, use one of its concrete subclasses:
    * `tf.compat.v1.mixed_precision.DynamicLossScale`
    * `tf.compat.v1.mixed_precision.FixedLossScale`

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. By multiplying the loss, each intermediate gradient will
  have the same multiplier applied.

  Instances of this class represent a loss scale. Calling an instance returns
  the loss scale as a scalar float32 tensor, while the `update()` method
  updates the loss scale depending on the values of the gradients. Optimizers
  use instances of this class to scale loss and gradients.

  In most functions that accept a LossScale, you can also pass an int (such as
  8) to create a `FixedLossScale` or the string `"dynamic"` to create a dynamic
  loss scale.
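
  For example, a rough sketch of how an optimizer might use a `LossScale`
  (illustrative only; `loss`, `vars` and `gradients` are placeholders, as in
  the pseudocode above):

  ```
  loss_scale = tf.compat.v1.mixed_precision.DynamicLossScale()
  scaled_loss = loss * loss_scale()
  grads = gradients(scaled_loss, vars)
  grads = [g / loss_scale() for g in grads]
  update_op, should_apply = loss_scale.update(grads)
  # Apply `grads` only if `should_apply` is True. In graph mode, also run
  # `update_op` every step so the loss scale can adjust itself.
  ```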
  """

  def __init__(self):
    """Initializes the loss scale class."""
    self._weights = {}

  @abc.abstractmethod
  def __call__(self):
    """Returns the current loss scale as a scalar `float32` tensor."""
    pass

  @abc.abstractmethod
  def update(self, grads):
    """Updates the value of the loss scale.

    The loss scale may be updated, based on the value of `grads`. The tensor
    returned by calling this instance is only updated when this function is
    evaluated.

    In eager mode, this directly updates the loss scale, so that calling
    `__call__` will return the newly updated loss scale. In graph mode,
    this returns an op that, when evaluated, updates the loss scale.

    This function also returns a `should_apply_gradients` bool. If False,
    gradients should not be applied to the variables that step, as nonfinite
    gradients were found, and the loss scale has been updated to reduce the
    chance of finding nonfinite gradients in the next step. Some loss scale
    classes will always return True, as they cannot adjust themselves in
    response to nonfinite gradients.

    When a DistributionStrategy is used, this function may only be called in a
    cross-replica context.

    Args:
      grads: A nested structure of unscaled gradients, each of which is the
        gradient of the loss with respect to a weight. The gradients should
        have already been divided by the loss scale before being passed to
        this function. `None` gradients are accepted and are ignored.

    Returns:
      update_op: In eager mode, None. In graph mode, an op to update the loss
        scale.
      should_apply_gradients: Either a bool or a scalar boolean tensor. If
        False, the caller should skip applying `grads` to the variables this
        step.
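
    For example, a typical call pattern (illustrative; `loss_scale` and
    `unscaled_grads` are placeholder names) is:

    ```
    update_op, should_apply = loss_scale.update(unscaled_grads)
    # Only apply `unscaled_grads` when `should_apply` evaluates to True. In
    # graph mode, also run `update_op` every step.
    ```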
    """
    pass

  def _add_weight(self, name, initial_value, dtype=None):
    """Adds a weight to this loss scale.

    Args:
      name: Variable name.
      initial_value: The variable's initial value.
      dtype: The type of the variable.

    Returns:
      A variable.

    Raises:
      RuntimeError: If a weight with `name` has already been added.
    """
    variable = variable_scope.variable(
        initial_value=initial_value,
        name=name,
        dtype=dtype,
        trainable=False,
        use_resource=True,
        synchronization=variables.VariableSynchronization.AUTO,
        # Set aggregation to NONE, as loss scaling variables should never be
        # aggregated.
        aggregation=variables.VariableAggregation.NONE)
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access

    key = (name, graph_key)
    if self._weights.get(key, None) is not None:
      raise RuntimeError('Duplicate variables detected. {}'.format(key))
    self._weights[key] = variable
    self._handle_deferred_dependencies(name=name, trackable=variable)
    return variable

  def _trackable_children(self,
                          save_type=trackable.SaveType.CHECKPOINT,
                          **kwargs):
    """From Trackable. Gather graph-specific weights to save."""
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access
    weights = {}
    for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]):
      if g == graph_key:
        weights[name] = v
    weights.update(
        super(LossScale, self)._trackable_children(save_type, **kwargs))
    return weights

  def _lookup_dependency(self, name):
    """From Trackable. Find a weight in the current graph."""
    unconditional = super(LossScale, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access
    return self._weights.get((name, graph_key), None)

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of this loss scale."""
    pass

  @classmethod
  def from_config(cls, config):
    """Creates the LossScale from its config."""
    return cls(**config)


@deprecation.deprecated_endpoints('mixed_precision.experimental.FixedLossScale',
                                  'train.experimental.FixedLossScale')
@tf_export(
    v1=[
        'mixed_precision.FixedLossScale',
        'mixed_precision.experimental.FixedLossScale',
        'train.experimental.FixedLossScale'
    ])
class FixedLossScale(LossScale):
  """Loss scale with a fixed value.

  The loss scale is not updated for the lifetime of instances of this class.
  A given instance of this class always returns the same number when called.
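
  For example (an illustrative sketch; `grads` is a placeholder for gradients
  that have already been unscaled):

  ```
  loss_scale = tf.compat.v1.mixed_precision.FixedLossScale(128.)
  loss_scale()  # Always a scalar float32 tensor equal to 128.
  loss_scale.update(grads)  # A no-op; the second return value is always True.
  ```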
  """

  @deprecation.deprecated(
      None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
            'LossScaleOptimizer now has all the functionality of '
            'FixedLossScale')
  def __init__(self, loss_scale_value):
    """Creates the fixed loss scale.

    Args:
      loss_scale_value: A Python float. Its ideal value varies depending on the
        model. A loss scale that is too small might degrade model quality; one
        that is too big might cause the gradients to overflow to inf or nan.
        There is no single right value to apply, and there is no harm in
        choosing a relatively big number as long as no nan or inf is
        encountered during training.

    Raises:
      ValueError: If loss_scale_value is less than 1.
    """
    super(FixedLossScale, self).__init__()
    if not isinstance(loss_scale_value, (int, float)):
      raise ValueError('loss_scale_value must be a Python int or float.')
    if loss_scale_value < 1:
      raise ValueError('loss_scale_value must be at least 1.')
    # It's important we do not create tensors in the constructor, as such
    # tensors might end up on a different device, or in a different
    # tf.function, than the one in which the tensor is used. This would hurt
    # performance. Therefore, we do not create a tensor from loss_scale_value,
    # but instead leave it as a Python float.
    # TODO(reedwm): Also do not create tensors in the DynamicLossScale
    # constructor.
    self._loss_scale_value = float(loss_scale_value)

  def __call__(self):
    return ops.convert_to_tensor(self._loss_scale_value)

  def update(self, grads):
    del grads
    return control_flow_ops.no_op(), True

  def __repr__(self):
    return 'FixedLossScale(%s)' % self._loss_scale_value

  def get_config(self):
    return {'loss_scale_value': self._loss_scale_value}


def _is_all_finite(grads):
  """Returns a scalar boolean tensor indicating if all gradients are finite."""
  is_finite_per_grad = [
      math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None
  ]
  return math_ops.reduce_all(is_finite_per_grad)


def _op_in_graph_mode(tensor):
  """Returns the tensor's op in graph mode, or the tensor in eager mode.

  This is useful because sometimes an op is needed in graph mode instead of a
  tensor. In eager mode, there are no ops.

  Args:
    tensor: A tensor.

  Returns:
    The tensor's op in graph mode. The tensor in eager mode.
  """
  if context.executing_eagerly():
    return tensor
  return tensor.op


def _assign_if_finite(var, value):
  """Assigns a value to a variable if the value is finite."""
  return control_flow_ops.cond(
      math_ops.is_finite(value), lambda: _op_in_graph_mode(var.assign(value)),
      control_flow_ops.no_op)


@deprecation.deprecated_endpoints(
    'mixed_precision.experimental.DynamicLossScale',
    'train.experimental.DynamicLossScale')
@tf_export(
    v1=[
        'mixed_precision.DynamicLossScale',
        'mixed_precision.experimental.DynamicLossScale',
        'train.experimental.DynamicLossScale'
    ])
class DynamicLossScale(LossScale):
  """Loss scale that dynamically adjusts itself.

  Dynamic loss scaling works by adjusting the loss scale as training progresses.
  The goal is to keep the loss scale as high as possible without overflowing the
  gradients. As long as the gradients do not overflow, raising the loss scale
  never hurts.

  The algorithm starts by setting the loss scale to an initial value. Every N
  steps that the gradients are finite, the loss scale is increased by some
  factor. However, if a NaN or Inf gradient is found, the gradients for that
  step are not applied, and the loss scale is decreased by the factor. This
  process tends to keep the loss scale as high as possible without gradients
  overflowing.
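
  As a rough sketch, the update rule amounts to the following pseudocode
  (illustrative only; `N` stands for `increment_period` and `m` for
  `multiplier`):

  ```
  if all gradients this step are finite:
    good_steps += 1
    if good_steps >= N:
      loss_scale *= m
      good_steps = 0
  else:
    # The gradients are not applied this step.
    loss_scale = max(loss_scale / m, 1)
    good_steps = 0
  ```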
  """

  @deprecation.deprecated(
      None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
            'LossScaleOptimizer now has all the functionality of '
            'DynamicLossScale')
  def __init__(self,
               initial_loss_scale=2 ** 15,  # See docstring for why this is big.
               increment_period=2000,
               multiplier=2.):
    """Creates the dynamic loss scale.

    Args:
      initial_loss_scale: A Python float.  The loss scale to use at the
        beginning. It's better to start this at a very high number, because a
        loss scale that is too high gets lowered far more quickly than a loss
        scale that is too low gets raised. The default is 2 ** 15, which is
        approximately half the maximum float16 value.
      increment_period: Increases loss scale every `increment_period`
        consecutive steps that finite gradients are encountered. If a nonfinite
        gradient is encountered, the count is reset back to zero.
      multiplier: The multiplier to use when increasing or decreasing the loss
        scale.
    """
    super(DynamicLossScale, self).__init__()
    self._initial_loss_scale = float(initial_loss_scale)
    self._increment_period = int(increment_period)
    self._multiplier = float(multiplier)

    self._current_loss_scale = self._add_weight(
        name='current_loss_scale',
        dtype=dtypes.float32,
        initial_value=self._initial_loss_scale)
    # The number of consecutive steps with finite gradients since the last
    # nonfinite gradient or change in loss scale.
    self._num_good_steps = self._add_weight(
        name='good_steps', dtype=dtypes.int64, initial_value=0)

  @property
  def initial_loss_scale(self):
    return self._initial_loss_scale

  @property
  def increment_period(self):
    return self._increment_period

  @property
  def multiplier(self):
    return self._multiplier

  def __call__(self):
    return ops.convert_to_tensor(self._current_loss_scale)

  def update(self, grads):
    """Updates loss scale based on whether gradients are finite this step."""
    grads = nest.flatten(grads)
    if distribution_strategy_context.has_strategy():
      distribution = distribution_strategy_context.get_cross_replica_context()

      def get_is_finite(grads):
        is_finite = _is_all_finite(grads)
        # We cast to float, because we cannot reduce booleans with
        # DistributionStrategy.
        return math_ops.cast(is_finite, dtypes.float32)

      is_finite_float = distribution.extended.call_for_each_replica(
          get_is_finite, args=(grads,))
      reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM,
                                                    is_finite_float, axis=None)
      is_finite = math_ops.equal(reduced_is_finite_float,
                                 distribution.num_replicas_in_sync)
    else:
      is_finite = _is_all_finite(grads)

    def update_if_finite_grads():
      """Update assuming the gradients are finite."""

      def incr_loss_scale():
        new_loss_scale = self._current_loss_scale * self._multiplier
        return control_flow_ops.group(
            _assign_if_finite(self._current_loss_scale, new_loss_scale),
            self._num_good_steps.assign(0))

      return control_flow_ops.cond(
          self._num_good_steps + 1 >= self._increment_period,
          incr_loss_scale, lambda: _op_in_graph_mode(
              self._num_good_steps.assign_add(1)))

    def update_if_not_finite_grads():
      """Update assuming the gradients are nonfinite."""

      new_loss_scale = math_ops.maximum(
          self._current_loss_scale / self._multiplier, 1)
      return control_flow_ops.group(
          self._num_good_steps.assign(0),
          self._current_loss_scale.assign(new_loss_scale))

    update_op = control_flow_ops.cond(is_finite, update_if_finite_grads,
                                      update_if_not_finite_grads)
    should_apply_gradients = is_finite
    return update_op, should_apply_gradients

  def __repr__(self):
    if context.executing_eagerly():
      return ('DynamicLossScale(current_loss_scale=%s, num_good_steps=%s, '
              'initial_loss_scale=%s, increment_period=%s, multiplier=%s)' %
              (self._current_loss_scale.numpy(), self._num_good_steps.numpy(),
               self.initial_loss_scale, self.increment_period, self.multiplier))
    else:
      return ('DynamicLossScale(initial_loss_scale=%s, increment_period=%s, '
              'multiplier=%s)' %
              (self.initial_loss_scale, self.increment_period, self.multiplier))

  def get_config(self):
    return {
        'initial_loss_scale': self.initial_loss_scale,
        'increment_period': self.increment_period,
        'multiplier': self.multiplier,
    }


def get(identifier):
  """Get a loss scale object."""
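  # Illustrative mapping (derived from the branches below): get(128) returns
  # FixedLossScale(128), get('dynamic') returns a DynamicLossScale with default
  # arguments, get(None) returns None, and an existing LossScale instance is
  # returned unchanged.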
  if isinstance(identifier, (int, float)):
    return FixedLossScale(identifier)
  if identifier == 'dynamic':
    return DynamicLossScale()
  if isinstance(identifier, LossScale):
    return identifier
  elif identifier is None:
    return None
  else:
    raise ValueError('Could not interpret loss scale identifier: %s' %
                     identifier)
