xref: /aosp_15_r20/external/tensorflow/tensorflow/python/training/experimental/loss_scale_optimizer.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Contains the MixedPrecisionLossScaleOptimizer class."""
16from tensorflow.python.distribute import distribution_strategy_context
17from tensorflow.python.framework import indexed_slices
18from tensorflow.python.framework import smart_cond
19from tensorflow.python.ops import control_flow_ops
20from tensorflow.python.ops import math_ops
21from tensorflow.python.training import optimizer
22from tensorflow.python.training.experimental import loss_scale as loss_scale_module
23from tensorflow.python.util import deprecation
24from tensorflow.python.util.tf_export import tf_export
25
26
@deprecation.deprecated_endpoints(
    'train.experimental.MixedPrecisionLossScaleOptimizer')
@tf_export(v1=['mixed_precision.MixedPrecisionLossScaleOptimizer',
               'train.experimental.MixedPrecisionLossScaleOptimizer'])
class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
  """An optimizer that applies loss scaling.

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. By multiplying the loss, each intermediate gradient will
  have the same multiplier applied.

  The loss scale can either be a fixed constant, chosen by the user, or be
  dynamically determined. Dynamically determining the loss scale is convenient
  as a loss scale does not have to be explicitly chosen. However it reduces
  performance.

  This optimizer wraps another optimizer and applies loss scaling to it via a
  `LossScale`. Loss scaling is applied whenever gradients are computed, such as
  through `minimize()`. When the `LossScale` is a `DynamicLossScale`,
  `apply_gradients()` additionally updates the loss scale and only applies the
  gradients when the `LossScale` reports that they should be applied.
  """

  def __init__(self, opt, loss_scale):
    """Wraps `opt` so that its gradients are computed with loss scaling.

    Args:
      opt: The `Optimizer` to wrap. Its `use_locking` flag and its name are
        reused for this optimizer.
      loss_scale: Value passed to `loss_scale_module.get()` to obtain the
        `LossScale` object used by this optimizer. It must not resolve to
        `None`.

    Raises:
      ValueError: If `opt` is not an `Optimizer`, or if `loss_scale` resolves
        to `None`.
    """
    if not isinstance(opt, optimizer.Optimizer):
      raise ValueError('"opt" must be an instance of Optimizer, but got: %s' %
                       type(opt))
    self._optimizer = opt

    use_locking = opt._use_locking  # pylint: disable=protected-access
    name = opt.get_name()
    super(MixedPrecisionLossScaleOptimizer, self).__init__(use_locking, name)

    self._loss_scale = loss_scale_module.get(loss_scale)
    if self._loss_scale is None:
      raise ValueError('loss_scale cannot be None')
    # Register the wrapped optimizer and the loss scale as trackable
    # dependencies under fixed names (presumably so they are checkpointed
    # together with this optimizer -- confirm against the Trackable base).
    self._track_trackable(self._optimizer, 'base_optimizer')
    self._track_trackable(self._loss_scale, 'loss_scale')

  def _doing_dynamic_loss_scaling(self):
    """Check if `_loss_scale` dynamically manages the loss scale."""
    return isinstance(self._loss_scale, loss_scale_module.DynamicLossScale)

  def compute_gradients(self,
                        loss,
                        var_list=None,
                        gate_gradients=optimizer.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This adjusts the dynamic range of the gradient evaluation by scaling up
    the `loss` value. The gradient values are then scaled back down by the
    reciprocal of the loss scale. This is useful in reduced precision training
    where small gradient values would otherwise underflow the representable
    range.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking no
        arguments which returns the value to minimize. When eager execution is
        enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph under
        the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with the
        corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`. Gradients have already been divided by the loss
      scale.
    """
    loss = self._scale_loss(loss)
    grads_and_vars = self._optimizer.compute_gradients(
        loss=loss,
        var_list=var_list,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    grads = [g for g, _ in grads_and_vars]
    variables = [v for _, v in grads_and_vars]
    unscaled_grads = self._unscale_grads(grads)
    return list(zip(unscaled_grads, variables))

  def _scale_loss(self, loss):
    """Multiplies `loss` (a tensor or a zero-argument callable) by the loss scale.

    The loss scale is cast to the loss dtype before multiplying. A callable
    loss is wrapped in a new callable rather than being evaluated here.
    """
    loss_scale = self._loss_scale()
    if callable(loss):
      def new_loss():
        loss_val = loss()
        return loss_val * math_ops.cast(loss_scale, loss_val.dtype)
      return new_loss
    else:
      return loss * math_ops.cast(loss_scale, loss.dtype)

  def _unscale_grads(self, grads):
    """Divides each non-`None` gradient in `grads` by the current loss scale.

    `None` entries are passed through unchanged so the result lines up
    element-wise with `grads`.
    """
    loss_scale = self._loss_scale()
    loss_scale_reciprocal = 1 / loss_scale
    return [
        None if g is None else self._scale_grad(g, loss_scale_reciprocal)
        for g in grads
    ]

  def _scale_grad(self, grad, loss_scale_reciprocal):
    """Multiplies a dense or `IndexedSlices` gradient by `loss_scale_reciprocal`.

    The multiplier is cast to the gradient's dtype first, mirroring the cast
    done in `_scale_loss`. Without the cast, a gradient whose dtype differs
    from the loss scale's would fail with a dtype-mismatch error. When the
    dtypes already agree the cast is a no-op.
    NOTE(review): in a narrow dtype the reciprocal of a very large loss scale
    can round to zero -- confirm acceptable loss-scale ranges for such dtypes.
    """
    if isinstance(grad, indexed_slices.IndexedSlices):
      scale = math_ops.cast(loss_scale_reciprocal, grad.values.dtype)
      return indexed_slices.IndexedSlices(grad.values * scale, grad.indices,
                                          grad.dense_shape)
    return grad * math_ops.cast(loss_scale_reciprocal, grad.dtype)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. With a fixed (non-dynamic) loss
    scale, the call is forwarded directly to the wrapped optimizer. With a
    `DynamicLossScale`, it returns an `Operation` that updates the loss scale
    and conditionally applies gradients if all gradient values are finite.
    Otherwise no update is performed (nor is `global_step` incremented).

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation.  Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that (conditionally, for dynamic loss scaling) applies the
      specified gradients. If `global_step` was not None, that operation also
      increments `global_step` whenever the gradients are applied.

    Raises:
      ValueError: If called in a cross-replica context.
    """
    if distribution_strategy_context.in_cross_replica_context():
      raise ValueError('apply_gradients() must be called in a replica context.')

    if not self._doing_dynamic_loss_scaling():
      return self._optimizer.apply_gradients(grads_and_vars, global_step, name)

    replica_context = distribution_strategy_context.get_replica_context()
    # Materialize the pairs so they can be forwarded through merge_call (a
    # generator would otherwise be consumed only once).
    grads_and_vars = tuple(grads_and_vars)

    # TODO(nluehr) cleanup GraphKeys.TRAIN_OP
    return replica_context.merge_call(
        self._distributed_apply, args=(grads_and_vars, global_step, name))

  def _distributed_apply(self,
                         distribution,
                         grads_and_vars,
                         global_step=None,
                         name=None):
    """Cross-replica part of `apply_gradients()` for dynamic loss scaling.

    `apply_gradients()` invokes this through the replica context's
    `merge_call`. It updates the loss scale from the gradients. When the loss
    scale reports that the gradients should be applied, it applies them with
    the wrapped optimizer on every replica.

    Args:
      distribution: a `DistributionStrategy` object.
      grads_and_vars: Tuple of (gradient, variable) pairs forwarded from
        `apply_gradients()`.
      global_step: Optional (mirrored) `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Default to the name passed
        to the `Optimizer` constructor.

    Returns:
      An `Operation` grouping the loss scale update with the conditional
      application of the gradients across all replicas. If `global_step` was
      not None, it is incremented only when the gradients are applied.
    """
    name = name if name is not None else self.get_name()
    grads = [g for g, _ in grads_and_vars]
    loss_scale_update_op, should_apply_grads = self._loss_scale.update(grads)

    def apply_fn():
      return self._apply_gradients(distribution, grads_and_vars, global_step,
                                   name + '-wrapped')

    maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                           control_flow_ops.no_op)
    return control_flow_ops.group(
        maybe_apply_op, loss_scale_update_op, name=name)

  def _apply_gradients(self, distribution, grads_and_vars, global_step, name):
    """Unconditionally apply gradients in cross replica context."""
    update_ops = distribution.extended.call_for_each_replica(
        self._optimizer.apply_gradients,
        args=(grads_and_vars, global_step, name))
    return distribution.group(update_ops)

  # The per-variable update hooks below are never reached because
  # apply_gradients() always delegates to the wrapped optimizer.

  def _apply_sparse(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _apply_dense(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_sparse(self, grad, handle, indices):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_dense(self, grad, handle):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def variables(self):
    """Returns the variables of the Optimizer.

    These are the wrapped optimizer's variables followed by the loss scale's
    weights.
    """
    return (self._optimizer.variables() +
            list(self._loss_scale._weights.values()))  # pylint: disable=protected-access
251            list(self._loss_scale._weights.values()))  # pylint: disable=protected-access
252