# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains functions to use mixed precision with the graph rewrite."""

from tensorflow.python.framework import config
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import optimizer
from tensorflow.python.training.experimental import loss_scale_optimizer as loss_scale_optimizer_v1
from tensorflow.python.training.experimental import mixed_precision_global_state
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# A mapping between optimizers and (wrapper_fn, wrapper_cls) pairs. wrapper_cls
# is a loss scale optimizer class, and wrapper_fn is a function that takes in
# an optimizer and LossScale and returns a wrapper_cls instance.
_REGISTERED_WRAPPER_OPTIMIZER_CLS = {
    optimizer.Optimizer:
        (loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer,) * 2,
}


@tf_export('__internal__.mixed_precision.register_loss_scale_wrapper', v1=[])
def register_loss_scale_wrapper(optimizer_cls, wrapper_fn, wrapper_cls=None):
  """Registers a loss scale optimizer wrapper.

  `tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite`
  automatically wraps an optimizer with an optimizer wrapper that performs loss
  scaling. This function registers an
  `(optimizer_cls, wrapper_fn, wrapper_cls)` triple
  that is used by `enable_mixed_precision_graph_rewrite`, where
  `wrapper_fn` is called to create a `wrapper_cls` instance that wraps an
  `optimizer_cls` instance.

  Args:
    optimizer_cls: A base optimizer class, e.g.
      `tf.keras.optimizers.Optimizer`.
    wrapper_fn: A function that takes in arguments "optimizer" and
      "loss_scale", and returns a loss scale optimizer of type "wrapper_cls"
      that wraps "optimizer".
    wrapper_cls: A loss scale optimizer class. Defaults to `wrapper_fn`, in
      which case `wrapper_fn` should be a loss scale optimizer class whose
      constructor takes in arguments "optimizer" and "loss_scale".
  """
  _REGISTERED_WRAPPER_OPTIMIZER_CLS[optimizer_cls] = (
      wrapper_fn, wrapper_cls or wrapper_fn)
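

# Example usage of register_loss_scale_wrapper (an illustrative sketch only;
# `MyOptimizer` and `MyLossScaleOptimizer` are hypothetical user-defined
# classes, not part of TensorFlow):
#
#   class MyOptimizer(optimizer.Optimizer):
#     ...
#
#   class MyLossScaleOptimizer(optimizer.Optimizer):
#     # The constructor takes "optimizer" and "loss_scale", so this class can
#     # serve as both wrapper_fn and wrapper_cls.
#     def __init__(self, optimizer, loss_scale):
#       ...
#
#   register_loss_scale_wrapper(MyOptimizer, MyLossScaleOptimizer)
#
# After the call above, enable_mixed_precision_graph_rewrite() would wrap any
# MyOptimizer instance it is given in a MyLossScaleOptimizer.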


def _wrap_optimizer(opt, loss_scale):
  """Wraps an optimizer with a LossScaleOptimizer."""

  for _, wrapper_optimizer in _REGISTERED_WRAPPER_OPTIMIZER_CLS.values():
    if isinstance(opt, wrapper_optimizer):
      raise ValueError('"opt" must not already be an instance of a {cls}. '
                       '`enable_mixed_precision_graph_rewrite` will '
                       'automatically wrap the optimizer with a '
                       '{cls}.'
                       .format(cls=wrapper_optimizer.__name__))

  for optimizer_cls, (wrapper_fn, _) in (
      _REGISTERED_WRAPPER_OPTIMIZER_CLS.items()):
    if isinstance(opt, optimizer_cls):
      return wrapper_fn(opt, loss_scale)

  raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a '
                   'tf.keras.optimizers.Optimizer, but got: %s' % opt)


@deprecation.deprecated_endpoints(
    'train.experimental.enable_mixed_precision_graph_rewrite')
@tf_export(v1=['mixed_precision.enable_mixed_precision_graph_rewrite',
               'train.experimental.enable_mixed_precision_graph_rewrite'])
def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
  """Enable mixed precision via a graph rewrite.

  Mixed precision is the use of both float32 and float16 data types when
  training a model to improve performance. This is achieved via a graph rewrite
  operation and a loss-scale optimizer.

  Performing arithmetic operations in float16 takes advantage of specialized
  processing units, such as NVIDIA Tensor Cores, for much higher arithmetic
  throughput. However, due to the smaller representable range, performing the
  entire training with float16 can result in gradient underflow, that is, small
  gradient values becoming zeroes. Instead, performing only select arithmetic
  operations in float16 results in higher throughput and decreased training
  time when using compatible hardware accelerators while also reducing memory
  usage, typically without sacrificing model accuracy.

  Note: While the mixed precision rewrite changes the datatype of various
  layers throughout the model, the model is expected to reach the same accuracy
  as it would in float32. If a `NaN` gradient occurs with dynamic loss scaling,
  the model update for that batch is skipped. In this case, the global step
  count is not incremented, and the `LossScaleOptimizer` attempts to decrease
  the loss scaling value to avoid `NaN` values in subsequent iterations. This
  approach has been shown to achieve the same accuracy as float32 and, in most
  cases, better training throughput.

  Example:

  ```python
  model = tf.keras.models.Sequential([
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(64, activation='softmax'),
  ])

  opt = tf.keras.optimizers.SGD()
  opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
  model.compile(loss="mse", optimizer=opt)

  x_train = np.random.random((1024, 64))
  y_train = np.random.random((1024, 64))
  model.fit(x_train, y_train)
  ```

  Calling `enable_mixed_precision_graph_rewrite(opt)` enables the graph rewrite
  operation before computing gradients. The function additionally returns an
  `Optimizer` (`opt`) wrapped with a `LossScaleOptimizer`. This prevents
  underflow in the float16 tensors during the backward pass. An optimizer of
  type `tf.train.Optimizer` or `tf.keras.optimizers.Optimizer` must be passed
  to this function, which will then be wrapped to use loss scaling.

  The graph rewrite operation changes the `dtype` of certain operations in the
  graph from float32 to float16. There are several categories of operations
  that are either included in or excluded from this rewrite operation. The
  following categories of Ops are defined inside corresponding functions under
  the class `AutoMixedPrecisionLists` in
  <a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/
  core/grappler/optimizers/auto_mixed_precision_lists.h">
  auto_mixed_precision_lists.h</a>:

  * `ClearList`: Ops that do not have numerically significant adverse effects.
    E.g. `ArgMax` and `Floor`.
  * `AllowList`: Ops that are considered numerically safe for execution in
    float16, and thus are always converted. E.g. `Conv2D`.
  * `DenyList`: Ops that are numerically unsafe to execute in float16 and
    can negatively affect downstream nodes. E.g. `Softmax`.
  * `GrayList`: Ops that are considered numerically safe for execution in
    float16 unless downstream from a DenyList Op. E.g. `Add` and `AvgPool`.

  When this function is used, gradients should only be computed and applied
  with the returned optimizer, either by calling `opt.minimize()` or
  `opt.compute_gradients()` followed by `opt.apply_gradients()`.
  Gradients should not be computed with `tf.gradients` or `tf.GradientTape`.
  This is because the returned optimizer will apply loss scaling, and
  `tf.gradients` or `tf.GradientTape` will not. If you do directly use
  `tf.gradients` or `tf.GradientTape`, your model may not converge due to
  float16 underflow problems.
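
  For example, gradients can be computed and applied through the returned
  optimizer as follows (a minimal sketch; `model`, `features`, and `labels`
  are placeholders assumed to be defined elsewhere):

  ```python
  loss = tf.reduce_mean(tf.square(model(features) - labels))
  opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01)
  opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
  # Loss scaling is applied inside compute_gradients, and the gradients are
  # unscaled again before apply_gradients updates the variables.
  grads_and_vars = opt.compute_gradients(loss)
  train_op = opt.apply_gradients(grads_and_vars)
  ```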

  When eager execution is enabled, the mixed precision graph rewrite is only
  enabled within `tf.function`s, as outside `tf.function`s, there is no graph.

  For NVIDIA GPUs with Tensor Cores, as a general performance guide, dimensions
  (such as batch size, input size, output size, and channel counts)
  should be powers of two if under 256, or otherwise divisible by 8 if above
  256. For more information, check out the
  [NVIDIA Deep Learning Performance Guide](
  https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html).

  Currently, mixed precision is only enabled on NVIDIA Tensor Core GPUs with
  Compute Capability 7.0 and above (Volta, Turing, or newer architectures). The
  parts of the graph on CPUs and TPUs are untouched by the graph rewrite.

  Raises:
    `ValueError`, if the `tf.keras.mixed_precision` API is also used by calling
    `tf.keras.mixed_precision.set_global_policy`. Only one mixed precision
    API can be used.

  Args:
    opt: An instance of a `tf.keras.optimizers.Optimizer` or a
      `tf.train.Optimizer`.
    loss_scale: Either an int/float, the string `"dynamic"`, or an instance of
      a `tf.mixed_precision.experimental.LossScale`. The loss scale to use. It
      is recommended to keep this as its default value of `"dynamic"`, which
      will adjust the scaling automatically to prevent `Inf` or `NaN` values.

  Returns:
    A version of `opt` that will use loss scaling to prevent underflow.
  """
  if mixed_precision_global_state.is_using_mixed_precision_policy():
    raise ValueError(
        'The mixed precision graph rewrite cannot be enabled, because the '
        'global Keras dtype Policy has been set to a mixed precision policy. '
        'At most, one of the following can be called:\n\n'
        '  1. tf.keras.mixed_precision.set_global_policy() with a mixed '
        'precision policy (You called this first)\n\n'
        '  2. tf.train.experimental.enable_mixed_precision_graph_rewrite() '
        '(You called this second)\n'
        'You called both functions, which is an error, because both functions '
        'enable you to use mixed precision. If in doubt which function to use, '
        'use the first, as it supports Eager execution and is more '
        'customizable.')

  if mixed_precision_global_state.non_mixed_precision_session_created():
    # TODO(reedwm): Give the stacktrace of the existing Sessions. And if the
    # Sessions have already been closed, do not raise this error message.
    tf_logging.warn('You already have existing Sessions that do not use mixed '
                    'precision. enable_mixed_precision_graph_rewrite() will '
                    'not affect these Sessions.')
  opt = _wrap_optimizer(opt, loss_scale)
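  # Setting 'auto_mixed_precision' turns on the grappler graph rewrite pass for
  # Sessions and tf.functions created from this point on, and the global state
  # flag below records that the rewrite is enabled so that the Keras mixed
  # precision API can detect (and disallow) mixing the two APIs.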
  config.set_optimizer_experimental_options({'auto_mixed_precision': True})
  mixed_precision_global_state.set_mixed_precision_graph_rewrite_enabled(True)
  return opt


@deprecation.deprecated_endpoints(
    'train.experimental.disable_mixed_precision_graph_rewrite')
@tf_export(v1=['mixed_precision.disable_mixed_precision_graph_rewrite',
               'train.experimental.disable_mixed_precision_graph_rewrite'])
def disable_mixed_precision_graph_rewrite_v1():
  """Disables the mixed precision graph rewrite.

  After this is called, the mixed precision graph rewrite will no longer run
  for new Sessions, and so float32 operations will no longer be converted to
  float16 in such Sessions. However, any existing Sessions will continue to
  have the graph rewrite enabled if they were created after
  `enable_mixed_precision_graph_rewrite` was called but before
  `disable_mixed_precision_graph_rewrite` was called.

  This does not undo the effects of loss scaling. Any optimizers wrapped with a
  LossScaleOptimizer will continue to do loss scaling, although this loss
  scaling will no longer be useful if the optimizer is used in new Sessions, as
  the graph rewrite no longer converts the graph to use float16.

  This function is useful for unit testing. A unit test can test using the
  mixed precision graph rewrite, then disable it so future unit tests continue
  using float32. If this is done, unit tests should not share a single session,
  as `enable_mixed_precision_graph_rewrite` and
  `disable_mixed_precision_graph_rewrite` have no effect on existing sessions.
  """
  # We only have a separate V1 version of this function, because the V1
  # docstring mentions sessions.
  if (not
      mixed_precision_global_state.is_mixed_precision_graph_rewrite_enabled()):
    tf_logging.warn('disable_mixed_precision_graph_rewrite() called when mixed '
                    'precision is already disabled.')
  config.set_optimizer_experimental_options({'auto_mixed_precision': False})
  mixed_precision_global_state.set_mixed_precision_graph_rewrite_enabled(False)
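

# Example unit-test pattern for the enable/disable pair (an illustrative sketch
# only; `MyTest` and `build_and_run_model` are hypothetical, not part of
# TensorFlow):
#
#   class MyTest(tf.test.TestCase):
#
#     def test_mixed_precision_rewrite(self):
#       opt = tf.compat.v1.train.GradientDescentOptimizer(0.01)
#       opt = enable_mixed_precision_graph_rewrite_v1(opt)
#       try:
#         # Use a fresh Session here; the rewrite only affects Sessions created
#         # while it is enabled.
#         build_and_run_model(opt)
#       finally:
#         # Disable the rewrite so later tests run in float32 again.
#         disable_mixed_precision_graph_rewrite_v1()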