# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains functions to use mixed precision with the graph rewrite."""

from tensorflow.python.framework import config
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import optimizer
from tensorflow.python.training.experimental import loss_scale_optimizer as loss_scale_optimizer_v1
from tensorflow.python.training.experimental import mixed_precision_global_state
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# A mapping between optimizers and (wrapper_fn, wrapper_cls) pairs. wrapper_cls
# is a loss scale optimizer class, and wrapper_fn is a function that takes in
# an optimizer and LossScale and returns a wrapper_cls instance.
_REGISTERED_WRAPPER_OPTIMIZER_CLS = {
    optimizer.Optimizer:
        (loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer,) * 2,
}


@tf_export('__internal__.mixed_precision.register_loss_scale_wrapper', v1=[])
def register_loss_scale_wrapper(optimizer_cls, wrapper_fn, wrapper_cls=None):
  """Registers a loss scale optimizer wrapper.

  `tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite`
  automatically wraps an optimizer with an optimizer wrapper that performs loss
  scaling. This function registers an
  `(optimizer_cls, wrapper_fn, wrapper_cls)` triple that is used by
  `enable_mixed_precision_graph_rewrite`, where `wrapper_fn` is called to
  create a `wrapper_cls` instance that wraps an `optimizer_cls` instance.

  Args:
    optimizer_cls: A base optimizer class, e.g. `tf.keras.optimizers.Optimizer`.
    wrapper_fn: A function that takes in arguments "optimizer" and
      "loss_scale", and returns a loss scale optimizer of type "wrapper_cls"
      that wraps "optimizer".
    wrapper_cls: A loss scale optimizer class. Defaults to `wrapper_fn`, in
      which case `wrapper_fn` should be a loss scale optimizer class whose
      constructor takes in arguments "optimizer" and "loss_scale".
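
  For illustration, a minimal sketch of registering a hypothetical wrapper
  (`MyOptimizer` and `MyLossScaleOptimizer` are placeholder classes, not part
  of TensorFlow):

  ```python
  # MyLossScaleOptimizer(optimizer, loss_scale) wraps a MyOptimizer instance,
  # so it can serve as both wrapper_fn and wrapper_cls.
  tf.__internal__.mixed_precision.register_loss_scale_wrapper(
      MyOptimizer, MyLossScaleOptimizer)
  ```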
55  """
56  _REGISTERED_WRAPPER_OPTIMIZER_CLS[optimizer_cls] = (
57      wrapper_fn, wrapper_cls or wrapper_fn)
58
59
60def _wrap_optimizer(opt, loss_scale):
61  """Wraps an optimizer with a LossScaleOptimizer."""
62
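  # Refuse to double-wrap: "opt" must not already be an instance of any
  # registered loss scale optimizer wrapper class.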
  for _, wrapper_optimizer in _REGISTERED_WRAPPER_OPTIMIZER_CLS.values():
    if isinstance(opt, wrapper_optimizer):
      raise ValueError('"opt" must not already be an instance of a {cls}. '
                       '`enable_mixed_precision_graph_rewrite` will '
                       'automatically wrap the optimizer with a '
                       '{cls}.'
                       .format(cls=wrapper_optimizer.__name__))

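  # Wrap "opt" using the wrapper_fn registered for the first base optimizer
  # class that "opt" is an instance of.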
  for optimizer_cls, (wrapper_fn, _) in (
      _REGISTERED_WRAPPER_OPTIMIZER_CLS.items()):
    if isinstance(opt, optimizer_cls):
      return wrapper_fn(opt, loss_scale)

  raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a '
                   'tf.keras.optimizers.Optimizer, but got: %s' % opt)


@deprecation.deprecated_endpoints(
    'train.experimental.enable_mixed_precision_graph_rewrite')
@tf_export(v1=['mixed_precision.enable_mixed_precision_graph_rewrite',
               'train.experimental.enable_mixed_precision_graph_rewrite'])
def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
  """Enable mixed precision via a graph rewrite.

  Mixed precision is the use of both float32 and float16 data types when
  training a model to improve performance. This is achieved via a graph rewrite
  operation and a loss-scale optimizer.

  Performing arithmetic operations in float16 takes advantage of specialized
  processing units, such as NVIDIA Tensor Cores, for much higher arithmetic
  throughput. However, due to the smaller representable range, performing the
  entire training with float16 can result in gradient underflow, that is, small
  gradient values becoming zeroes. Instead, performing only select arithmetic
  operations in float16 results in higher throughput and decreased training
  time when using compatible hardware accelerators while also reducing memory
  usage, typically without sacrificing model accuracy.

  Note: While the mixed precision rewrite changes the datatype of various
  layers throughout the model, the same accuracy reached in float32 is
  expected. If a `NaN` gradient occurs with dynamic loss scaling, the model
  update for that batch is skipped. In this case, the global step count is not
  incremented, and the `LossScaleOptimizer` attempts to decrease the loss
  scaling value to avoid `NaN` values in subsequent iterations. This approach
  has been shown to achieve the same accuracy as float32 and, in most cases,
  better training throughput.

  Example:

  ```python
  model = tf.keras.models.Sequential([
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(64, activation='softmax'),
  ])

  opt = tf.keras.optimizers.SGD()
  opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
  model.compile(loss="mse", optimizer=opt)

  x_train = np.random.random((1024, 64))
  y_train = np.random.random((1024, 64))
  model.fit(x_train, y_train)
  ```

  Calling `enable_mixed_precision_graph_rewrite(opt)` enables the graph rewrite
  operation before computing gradients. The function additionally returns an
  `Optimizer` (`opt`) wrapped with a `LossScaleOptimizer`. This prevents
  underflow in the float16 tensors during the backward pass. An optimizer of
  type `tf.train.Optimizer` or `tf.keras.optimizers.Optimizer` must be passed
  to this function, which will then be wrapped to use loss scaling.

  The graph rewrite operation changes the `dtype` of certain operations in the
  graph from float32 to float16. There are several categories of operations
  that are either included or excluded by this rewrite operation. The following
  categories of Ops are defined inside corresponding functions under the class
  `AutoMixedPrecisionLists` in
  <a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/
  core/grappler/optimizers/auto_mixed_precision_lists.h">
  auto_mixed_precision_lists.h</a>:

  * `ClearList`: Ops that do not have numerically significant adverse effects.
  E.g. `ArgMax` and `Floor`.
  * `AllowList`: Ops that are considered numerically safe for execution in
  float16, and thus are always converted. E.g. `Conv2D`.
  * `DenyList`: Ops that are numerically unsafe to execute in float16 and
  can negatively affect downstream nodes. E.g. `Softmax`.
  * `GrayList`: Ops that are considered numerically safe for execution in
  float16 unless downstream from a DenyList Op. E.g. `Add` and `AvgPool`.

  When this function is used, gradients should only be computed and applied
  with the returned optimizer, either by calling `opt.minimize()` or
  `opt.compute_gradients()` followed by `opt.apply_gradients()`.
  Gradients should not be computed with `tf.gradients` or `tf.GradientTape`.
  This is because the returned optimizer will apply loss scaling, and
  `tf.gradients` or `tf.GradientTape` will not. If you do directly use
  `tf.gradients` or `tf.GradientTape`, your model may not converge due to
  float16 underflow problems.
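
  For example, a minimal sketch of this pattern (`loss` and `var_list` are
  placeholders for your model's loss tensor and variable list):

  ```python
  opt = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
  # Loss scaling is applied inside compute_gradients/apply_gradients.
  grads_and_vars = opt.compute_gradients(loss, var_list=var_list)
  train_op = opt.apply_gradients(grads_and_vars)
  ```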

  When eager execution is enabled, the mixed precision graph rewrite is only
  enabled within `tf.function`s, as outside `tf.function`s, there is no graph.

  For NVIDIA GPUs with Tensor cores, as a general performance guide, dimensions
  (such as batch size, input size, output size, and channel counts)
  should be powers of two if under 256, or otherwise divisible by 8 if above
  256. For more information, check out the
  [NVIDIA Deep Learning Performance Guide](
  https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html).

  Currently, mixed precision is only enabled on NVIDIA Tensor Core GPUs with
  Compute Capability 7.0 and above (Volta, Turing, or newer architectures). The
  parts of the graph on CPUs and TPUs are untouched by the graph rewrite.

  Args:
    opt: An instance of a `tf.keras.optimizers.Optimizer` or a
      `tf.train.Optimizer`.
    loss_scale: Either an int/float, the string `"dynamic"`, or an instance of
      a `tf.mixed_precision.experimental.LossScale`. The loss scale to use. It
      is recommended to keep this as its default value of `"dynamic"`, which
      will adjust the scaling automatically to prevent `Inf` or `NaN` values.

  Returns:
    A version of `opt` that will use loss scaling to prevent underflow.

  Raises:
    ValueError: If the `tf.keras.mixed_precision` API is also used by calling
      `tf.keras.mixed_precision.set_global_policy`. Only one mixed precision
      API can be used.
  """
  if mixed_precision_global_state.is_using_mixed_precision_policy():
    raise ValueError(
        'The mixed precision graph rewrite cannot be enabled, because the '
        'global Keras dtype Policy has been set to a mixed precision policy. '
        'At most, one of the following can be called:\n\n'
        '  1. tf.keras.mixed_precision.set_global_policy() with a mixed '
        'precision policy (You called this first)\n\n'
        '  2. tf.train.experimental.enable_mixed_precision_graph_rewrite() '
        '(You called this second)\n'
        'You called both functions, which is an error, because both functions '
        'enable you to use mixed precision. If in doubt which function to use, '
        'use the first, as it supports Eager execution and is more '
        'customizable.')

  if mixed_precision_global_state.non_mixed_precision_session_created():
    # TODO(reedwm): Give the stacktrace of the existing Sessions. And if the
    # Sessions have already been closed, do not raise this error message.
    tf_logging.warn('You already have existing Sessions that do not use mixed '
                    'precision. enable_mixed_precision_graph_rewrite() will '
                    'not affect these Sessions.')
  opt = _wrap_optimizer(opt, loss_scale)
  config.set_optimizer_experimental_options({'auto_mixed_precision': True})
  mixed_precision_global_state.set_mixed_precision_graph_rewrite_enabled(True)
  return opt


@deprecation.deprecated_endpoints(
    'train.experimental.disable_mixed_precision_graph_rewrite')
@tf_export(v1=['mixed_precision.disable_mixed_precision_graph_rewrite',
               'train.experimental.disable_mixed_precision_graph_rewrite'])
def disable_mixed_precision_graph_rewrite_v1():
  """Disables the mixed precision graph rewrite.

  After this is called, the mixed precision graph rewrite will no longer run for
  new Sessions, and so float32 operations will no longer be converted to float16
  in such Sessions. However, any existing Sessions will continue to have the
  graph rewrite enabled if they were created after
  `enable_mixed_precision_graph_rewrite` was called but before
  `disable_mixed_precision_graph_rewrite` was called.

  This does not undo the effects of loss scaling. Any optimizers wrapped with a
  LossScaleOptimizer will continue to do loss scaling, although this loss
  scaling will no longer be useful if the optimizer is used in new Sessions, as
  the graph rewrite no longer converts the graph to use float16.

  This function is useful for unit testing. A unit test can test using the
  mixed precision graph rewrite, then disable it so future unit tests continue
  using float32. If this is done, unit tests should not share a single session,
  as `enable_mixed_precision_graph_rewrite` and
  `disable_mixed_precision_graph_rewrite` have no effect on existing sessions.
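
  For example, a rough sketch of that pattern (the test class and the model it
  builds are placeholders):

  ```python
  class MixedPrecisionTest(tf.test.TestCase):

    def test_with_graph_rewrite(self):
      opt = tf.train.GradientDescentOptimizer(learning_rate=0.01)
      opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
      # ... build a model and run training steps in a new Session ...
      tf.train.experimental.disable_mixed_precision_graph_rewrite()
  ```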
240  """
241  # We only have a separate V1 version of this function, because the V1
242  # docstring mentions sessions.
243  if (not
244      mixed_precision_global_state.is_mixed_precision_graph_rewrite_enabled()):
245    tf_logging.warn('disable_mixed_precision_graph_rewrite() called when mixed '
246                    'precision is already disabled.')
247  config.set_optimizer_experimental_options({'auto_mixed_precision': False})
248  mixed_precision_global_state.set_mixed_precision_graph_rewrite_enabled(False)
249