# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the MixedPrecisionLossScaleOptimizer class."""
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import smart_cond
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@deprecation.deprecated_endpoints(
    'train.experimental.MixedPrecisionLossScaleOptimizer')
@tf_export(v1=['mixed_precision.MixedPrecisionLossScaleOptimizer',
               'train.experimental.MixedPrecisionLossScaleOptimizer'])
class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
  """An optimizer that applies loss scaling.

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but it can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. By multiplying the loss, each intermediate gradient will
  have the same multiplier applied.

  The loss scale can either be a fixed constant, chosen by the user, or be
  dynamically determined. Dynamically determining the loss scale is convenient
  as a loss scale does not have to be explicitly chosen. However, it reduces
  performance.

  This optimizer wraps another optimizer and applies loss scaling to it via a
  `LossScale`. Loss scaling is applied whenever gradients are computed, such as
  through `minimize()`.
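
  For example, a minimal sketch of wrapping an optimizer (the names `opt`,
  `loss`, and `train_op` are illustrative, not part of this API):

  ```
  opt = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  # 'dynamic' selects dynamic loss scaling; a fixed float such as 128. can be
  # passed instead for a constant loss scale.
  opt = tf.compat.v1.mixed_precision.MixedPrecisionLossScaleOptimizer(
      opt, loss_scale='dynamic')
  # minimize() scales the loss, computes gradients, unscales them, and (with a
  # dynamic loss scale) only applies them if they are all finite.
  train_op = opt.minimize(loss)
  ```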
  """

  def __init__(self, opt, loss_scale):
    if not isinstance(opt, optimizer.Optimizer):
      raise ValueError('"opt" must be an instance of Optimizer, but got: %s' %
                       type(opt))
    self._optimizer = opt

    use_locking = opt._use_locking  # pylint: disable=protected-access
    name = opt.get_name()
    super(MixedPrecisionLossScaleOptimizer, self).__init__(use_locking, name)

    self._loss_scale = loss_scale_module.get(loss_scale)
    if self._loss_scale is None:
      raise ValueError('loss_scale cannot be None')
    self._track_trackable(self._optimizer, 'base_optimizer')
    self._track_trackable(self._loss_scale, 'loss_scale')

  def _doing_dynamic_loss_scaling(self):
    """Check if `_loss_scale` dynamically manages the loss scale."""
    return isinstance(self._loss_scale, loss_scale_module.DynamicLossScale)

  def compute_gradients(self,
                        loss,
                        var_list=None,
                        gate_gradients=optimizer.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This adjusts the dynamic range of the gradient evaluation by scaling up
    the `loss` value. The gradient values are then scaled back down by the
    reciprocal of the loss scale. This is useful in reduced precision training
    where small gradient values would otherwise underflow the representable
    range.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking no
        arguments which returns the value to minimize. When eager execution is
        enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph under
        the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with the
        corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.
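
    For example, a minimal sketch (the names `loss_scale_opt`, `loss`, and
    `var` are illustrative):

    ```
    grads_and_vars = loss_scale_opt.compute_gradients(loss, var_list=[var])
    # The returned gradients have already been divided by the loss scale, so
    # they can be inspected or clipped directly before apply_gradients().
    ```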
    """
    loss = self._scale_loss(loss)
    grads_and_vars = self._optimizer.compute_gradients(
        loss=loss,
        var_list=var_list,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    grads = [g for g, _ in grads_and_vars]
    variables = [v for _, v in grads_and_vars]
    unscaled_grads = self._unscale_grads(grads)
    return list(zip(unscaled_grads, variables))

  def _scale_loss(self, loss):
    loss_scale = self._loss_scale()
    if callable(loss):
      def new_loss():
        loss_val = loss()
        return loss_val * math_ops.cast(loss_scale, loss_val.dtype)
      return new_loss
    else:
      return loss * math_ops.cast(loss_scale, loss.dtype)

  def _unscale_grads(self, grads):
    loss_scale = self._loss_scale()
    loss_scale_reciprocal = 1 / loss_scale
    return [
        None if g is None else self._scale_grad(g, loss_scale_reciprocal)
        for g in grads
    ]

  def _scale_grad(self, grad, loss_scale_reciprocal):
    if isinstance(grad, indexed_slices.IndexedSlices):
      grad_vals = grad.values * loss_scale_reciprocal
      return indexed_slices.IndexedSlices(grad_vals, grad.indices,
                                          grad.dense_shape)
    return grad * loss_scale_reciprocal

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    conditionally applies gradients if all gradient values are finite.
    Otherwise no update is performed (nor is `global_step` incremented).

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation. Defaults to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that conditionally applies the specified gradients. If
      `global_step` was not None, that operation also increments `global_step`.

    Raises:
      ValueError: If called in a cross-replica context, where
        `_distributed_apply()` should be used instead.
    """
    if distribution_strategy_context.in_cross_replica_context():
      raise ValueError('apply_gradients() must be called in a replica context.')

    if not self._doing_dynamic_loss_scaling():
      return self._optimizer.apply_gradients(grads_and_vars, global_step, name)

    replica_context = distribution_strategy_context.get_replica_context()
    grads_and_vars = tuple(grads_and_vars)

    # TODO(nluehr) cleanup GraphKeys.TRAIN_OP
    return replica_context.merge_call(
        self._distributed_apply, args=(grads_and_vars, global_step, name))

  def _distributed_apply(self,
                         distribution,
                         grads_and_vars,
                         global_step=None,
                         name=None):
    """A version of `apply_gradients` for cross-replica context.

    When users are in a cross-replica strategy, they must call this rather
    than `apply_gradients()`.

    Args:
      distribution: a `DistributionStrategy` object.
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()` and then aggregated across replicas.
      global_step: Optional (mirrored) `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Defaults to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients across all
      replicas. If `global_step` was not None, that operation also
      increments `global_step`.
    """
    name = name if name is not None else self.get_name()
    grads = [g for g, _ in grads_and_vars]
    loss_scale_update_op, should_apply_grads = self._loss_scale.update(grads)

    def apply_fn():
      return self._apply_gradients(distribution, grads_and_vars, global_step,
                                   name + '-wrapped')

    maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                           control_flow_ops.no_op)
    return control_flow_ops.group(
        maybe_apply_op, loss_scale_update_op, name=name)

  def _apply_gradients(self, distribution, grads_and_vars, global_step, name):
    """Unconditionally apply gradients in cross-replica context."""
    update_ops = distribution.extended.call_for_each_replica(
        self._optimizer.apply_gradients,
        args=(grads_and_vars, global_step, name))
    return distribution.group(update_ops)

  def _apply_sparse(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _apply_dense(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_sparse(self, grad, handle, indices):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_dense(self, grad, handle):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def variables(self):
    """Returns the variables of the Optimizer."""
    return (self._optimizer.variables() +
            list(self._loss_scale._weights.values()))  # pylint: disable=protected-access