# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adam for TensorFlow."""
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["train.AdamOptimizer"])
class AdamOptimizer(optimizer.Optimizer):
  """Optimizer that implements the Adam algorithm.

  References:
    Adam - A Method for Stochastic Optimization:
      [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
      ([pdf](https://arxiv.org/pdf/1412.6980.pdf))

  @compatibility(TF2)
  `tf.compat.v1.train.AdamOptimizer` is compatible with eager mode and
  `tf.function`.
  When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
  `epsilon` can each be a callable that takes no arguments and returns the
  actual value to use. This can be useful for changing these values across
  different invocations of optimizer functions.
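
  For example, a minimal sketch of a callable learning rate under eager
  execution (the constant returned here is purely illustrative):

  ```python
  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=lambda: 0.001)
  ```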

  To switch to native TF2 style, use [`tf.keras.optimizers.Adam`]
  (https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam)
  instead. Please note that, due to implementation differences,
  `tf.keras.optimizers.Adam` and
  `tf.compat.v1.train.AdamOptimizer` may have slight differences in
  floating point numerics even though the formula used for the variable
  updates still matches.

  #### Structural Mapping to Native TF2

  Before:

  ```python
  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)
  ```

  After:

  ```python
  optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
  ```

  #### How to Map Arguments

  |TF1 Arg Name          |TF2 Arg Name |Note                  |
  |----------------------|-------------|----------------------|
  |learning_rate         |learning_rate|Be careful of setting learning_rate as a
  :                      :             : tensor value computed from the global
  :                      :             : step. In TF1 this was usually meant to
  :                      :             : imply a dynamic learning rate and would
  :                      :             : recompute in each step. In TF2 (eager +
  :                      :             : function) it will treat it as a scalar
  :                      :             : value that only gets computed once
  :                      :             : instead of a symbolic placeholder to be
  :                      :             : computed each time.
  |beta1                 |beta_1       |                      |
  |beta2                 |beta_2       |                      |
  |epsilon               |epsilon      |Default value is 1e-08 in TF1, but
  :                      :             : 1e-07 in TF2.
  |use_locking           |N/A          |Not applicable in TF2.|

  #### Before & After Usage Example

  Before:

  ```python
  x = tf.Variable([1, 2, 3], dtype=tf.float32)
  grad = tf.constant([0.1, 0.2, 0.3])
  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)
  optimizer.apply_gradients(zip([grad], [x]))
  ```

  After:

  ```python
  x = tf.Variable([1, 2, 3], dtype=tf.float32)
  grad = tf.constant([0.1, 0.2, 0.3])
  optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
  optimizer.apply_gradients(zip([grad], [x]))
  ```

  @end_compatibility
  """

  def __init__(self,
               learning_rate=0.001,
               beta1=0.9,
               beta2=0.999,
               epsilon=1e-8,
               use_locking=False,
               name="Adam"):
    r"""Construct a new Adam optimizer.

    Initialization:

    $$m_0 := 0 \text{ (Initialize 1st moment vector)}$$
    $$v_0 := 0 \text{ (Initialize 2nd moment vector)}$$
    $$t := 0 \text{ (Initialize timestep)}$$

    The update rule for `variable` with gradient `g` uses an optimization
    described at the end of Section 2 of the paper:

    $$t := t + 1$$
    $$\text{lr}_t := \mathrm{learning\_rate} *
      \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$

    $$m_t := \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
    $$v_t := \beta_2 * v_{t-1} + (1 - \beta_2) * g * g$$
    $$\text{variable} := \text{variable} -
      \text{lr}_t * m_t / (\sqrt{v_t} + \epsilon)$$
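
    As a minimal NumPy sketch of the same update (illustrative only; none of
    these names are part of this class):

    ```python
    import numpy as np

    def adam_step(variable, g, m, v, t, learning_rate=0.001,
                  beta1=0.9, beta2=0.999, epsilon=1e-8):
      # One Adam step on NumPy arrays, mirroring the formulas above.
      t += 1
      lr_t = learning_rate * np.sqrt(1 - beta2**t) / (1 - beta1**t)
      m = beta1 * m + (1 - beta1) * g
      v = beta2 * v + (1 - beta2) * g * g
      variable = variable - lr_t * m / (np.sqrt(v) + epsilon)
      return variable, m, v, t
    ```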

    The default value of 1e-8 for epsilon might not be a good default in
    general. For example, when training an Inception network on ImageNet a
    current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
    formulation just before Section 2.1 of the Kingma and Ba paper rather than
    the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
    hat" in the paper.

    The sparse implementation of this algorithm (used when the gradient is an
    `IndexedSlices` object, typically because of `tf.gather` or an embedding
    lookup in the forward pass) does apply momentum to variable slices even if
    they were not used in the forward pass (meaning they have a gradient equal
    to zero). Momentum decay (beta1) is also applied to the entire momentum
    accumulator. This means that the sparse behavior is equivalent to the dense
    behavior (in contrast to some momentum implementations which ignore momentum
    unless a variable slice was actually used).

    Args:
      learning_rate: A Tensor or a floating point value. The learning rate.
      beta1: A float value or a constant float tensor. The exponential decay
        rate for the 1st moment estimates.
      beta2: A float value or a constant float tensor. The exponential decay
        rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper.
      use_locking: If True, use locks for update operations.
      name: Optional name for the operations created when applying gradients.
        Defaults to "Adam".
    """

    super(AdamOptimizer, self).__init__(use_locking, name)
    self._lr = learning_rate
    self._beta1 = beta1
    self._beta2 = beta2
    self._epsilon = epsilon

    # Tensor versions of the constructor arguments, created in _prepare().
    self._lr_t = None
    self._beta1_t = None
    self._beta2_t = None
    self._epsilon_t = None

  def _get_beta_accumulators(self):
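    # Fetch the non-slot variables that track beta1**t and beta2**t; the update
    # ops read these to apply bias correction.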
    with ops.init_scope():
      if context.executing_eagerly():
        graph = None
      else:
        graph = ops.get_default_graph()
      return (self._get_non_slot_variable("beta1_power", graph=graph),
              self._get_non_slot_variable("beta2_power", graph=graph))

  def _create_slots(self, var_list):
    # Create the beta1 and beta2 accumulators on the same device as the first
    # variable. Sort the var_list to make sure this device is consistent across
    # workers (these need to go on the same PS, otherwise some updates are
    # silently ignored).
    first_var = min(var_list, key=lambda x: x.name)
    self._create_non_slot_variable(
        initial_value=self._beta1, name="beta1_power", colocate_with=first_var)
    self._create_non_slot_variable(
        initial_value=self._beta2, name="beta2_power", colocate_with=first_var)

    # Create slots for the first and second moments.
    for v in var_list:
      self._zeros_slot(v, "m", self._name)
      self._zeros_slot(v, "v", self._name)

  def _prepare(self):
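    # Resolve hyperparameters that may have been passed as callables (useful
    # under eager execution) and convert them to tensors for the update ops.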
    lr = self._call_if_callable(self._lr)
    beta1 = self._call_if_callable(self._beta1)
    beta2 = self._call_if_callable(self._beta2)
    epsilon = self._call_if_callable(self._epsilon)

    self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
    self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
    self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
    self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")

  def _apply_dense(self, grad, var):
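    # Dense (ref-variable) path: hand the whole bias-corrected update to the
    # fused Adam training op.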
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators()
    return training_ops.apply_adam(
        var,
        m,
        v,
        math_ops.cast(beta1_power, var.dtype.base_dtype),
        math_ops.cast(beta2_power, var.dtype.base_dtype),
        math_ops.cast(self._lr_t, var.dtype.base_dtype),
        math_ops.cast(self._beta1_t, var.dtype.base_dtype),
        math_ops.cast(self._beta2_t, var.dtype.base_dtype),
        math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
        grad,
        use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, var):
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators()
    return training_ops.resource_apply_adam(
        var.handle,
        m.handle,
        v.handle,
        math_ops.cast(beta1_power, grad.dtype.base_dtype),
        math_ops.cast(beta2_power, grad.dtype.base_dtype),
        math_ops.cast(self._lr_t, grad.dtype.base_dtype),
        math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
        math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
        math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
        grad,
        use_locking=self._use_locking)

  def _apply_sparse_shared(self, grad, var, indices, scatter_add):
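    # Shared sparse update used by both the ref-variable and resource-variable
    # paths; `scatter_add` supplies the matching scatter primitive. Bias
    # correction is applied by rescaling the learning rate below.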
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * (1 - beta1_t)
    m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
      m_t = scatter_add(m, indices, m_scaled_g_values)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * (1 - beta2_t)
    v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
      v_t = scatter_add(v, indices, v_scaled_g_values)
    v_sqrt = math_ops.sqrt(v_t)
    var_update = state_ops.assign_sub(
        var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
    return control_flow_ops.group(*[var_update, m_t, v_t])

  def _apply_sparse(self, grad, var):
    return self._apply_sparse_shared(
        grad.values,
        var,
        grad.indices,
        lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
            x,
            i,
            v,
            use_locking=self._use_locking))

  def _resource_scatter_add(self, x, i, v):
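    # Scatter-add `v` at indices `i` into the resource variable `x`, then return
    # its updated value; the control dependency orders the read after the add.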
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
      return x.value()

  def _resource_apply_sparse(self, grad, var, indices):
    return self._apply_sparse_shared(grad, var, indices,
                                     self._resource_scatter_add)

  def _finish(self, update_ops, name_scope):
    # Update the power accumulators.
    with ops.control_dependencies(update_ops):
      beta1_power, beta2_power = self._get_beta_accumulators()
      with ops.colocate_with(beta1_power):
        update_beta1 = beta1_power.assign(
            beta1_power * self._beta1_t, use_locking=self._use_locking)
        update_beta2 = beta2_power.assign(
            beta2_power * self._beta2_t, use_locking=self._use_locking)
    return control_flow_ops.group(
        *update_ops + [update_beta1, update_beta2], name=name_scope)