# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Momentum for TensorFlow."""
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["train.MomentumOptimizer"])
class MomentumOptimizer(optimizer.Optimizer):
  """Optimizer that implements the Momentum algorithm.

  Computes (if `use_nesterov = False`):

  ```
  accumulation = momentum * accumulation + gradient
  variable -= learning_rate * accumulation
  ```

  Note that in the dense version of this algorithm, `accumulation` is updated
  and applied regardless of a gradient's value, whereas the sparse version (when
  the gradient is an `IndexedSlices`, typically because of `tf.gather` or an
  embedding) only updates variable slices and corresponding `accumulation` terms
  when that part of the variable was used in the forward pass.

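  As an illustrative sketch only (plain NumPy here, not the fused TF kernels),
  the sparse path can be pictured as updating just the rows that were used:

  ```python
  import numpy as np

  # Sketch only: mirrors the update rule above for a few gathered rows.
  variable = np.zeros((4, 2), dtype=np.float32)     # e.g. an embedding table
  accumulation = np.zeros_like(variable)
  learning_rate, momentum = 0.1, 0.9
  grad_indices = [0, 2]                             # rows touched by tf.gather
  grad_values = np.ones((2, 2), dtype=np.float32)   # gradient for those rows

  for i, g in zip(grad_indices, grad_values):
    accumulation[i] = momentum * accumulation[i] + g
    variable[i] -= learning_rate * accumulation[i]
  ```
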
  @compatibility(TF2)
  `tf.compat.v1.train.MomentumOptimizer` is compatible with eager execution
  and `tf.function`.
  When eager execution is enabled, `learning_rate` and `momentum` can each be
  a callable that takes no arguments and returns the actual value to use. This
  can be useful for changing these values across different invocations of
  optimizer functions.

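  For example, a minimal sketch under eager execution (the decay expression
  and variable names here are illustrative, not part of the API):

  ```python
  step = tf.Variable(0, trainable=False, dtype=tf.int64)  # advanced elsewhere
  # Zero-argument callables; re-evaluated each time gradients are applied.
  lr_fn = lambda: 0.1 * 0.96 ** tf.cast(step, tf.float32)
  momentum_fn = lambda: 0.9
  optimizer = tf.compat.v1.train.MomentumOptimizer(
      learning_rate=lr_fn, momentum=momentum_fn)
  ```
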
  To switch to native TF2 style, please directly use
  [`tf.keras.optimizers.SGD`]
  (https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD)
  with the `momentum` argument.

  #### Structural mapping to native TF2

  Before:

  ```python
  optimizer = tf.compat.v1.train.MomentumOptimizer(
    learning_rate=learning_rate,
    momentum=momentum,
    use_nesterov=use_nesterov)
  ```

  After:

  ```python
  optimizer = tf.keras.optimizers.SGD(
    learning_rate=learning_rate,
    momentum=momentum,
    nesterov=use_nesterov)
  ```

  #### How to map arguments
  | TF1 Arg Name       | TF2 Arg Name   | Note                             |
  | ------------------ | -------------- | -------------------------------- |
  | `learning_rate`    | `learning_rate`| Be careful when setting a        |
  : : : `learning_rate` tensor computed from the global step. In TF1 this  :
  : : : usually implied a dynamic learning rate that was recomputed at     :
  : : : each step. In TF2 (eager + `tf.function`) it is treated as a       :
  : : : scalar computed only once, not as a symbolic placeholder that is   :
  : : : re-evaluated every step (see the sketch below the table).          :
  | `momentum`         | `momentum`     | -                                |
  | `use_locking`      | -              | Not applicable in TF2.           |
  | `use_nesterov`     | `nesterov`     | -                                |

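  If the learning rate should still change every step in TF2, one minimal
  sketch is to pass a schedule (or a zero-argument callable) instead of a
  precomputed tensor; the decay values below are illustrative:

  ```python
  lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
      initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.96)
  # The schedule is re-evaluated from the optimizer's iteration count at each
  # step, rather than being captured once as a scalar.
  optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
  ```
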
  #### Before & after usage example
  Before:

  ```python
  x = tf.Variable([1, 2, 3], dtype=tf.float32)
  grad = tf.constant([0.1, 0.2, 0.3])
  optimizer = tf.compat.v1.train.MomentumOptimizer(
    learning_rate=0.001,
    momentum=0.9,
    use_nesterov=False)
  optimizer.apply_gradients(zip([grad], [x]))
  ```

  After:

  ```python
  x = tf.Variable([1, 2, 3], dtype=tf.float32)
  grad = tf.constant([0.1, 0.2, 0.3])
  optimizer = tf.keras.optimizers.SGD(
    learning_rate=0.001,
    momentum=0.9,
    nesterov=False)
  optimizer.apply_gradients(zip([grad], [x]))
  ```

  @end_compatibility

  """

  def __init__(self, learning_rate, momentum,
               use_locking=False, name="Momentum", use_nesterov=False):
    """Construct a new Momentum optimizer.

    Args:
      learning_rate: A `Tensor` or a floating point value.  The learning rate.
      momentum: A `Tensor` or a floating point value.  The momentum.
      use_locking: If `True` use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients.  Defaults to "Momentum".
      use_nesterov: If `True` use Nesterov Momentum.
        See (Sutskever et al., 2013).
        This implementation always computes gradients at the value of the
        variable(s) passed to the optimizer. Using Nesterov Momentum makes the
        variable(s) track the values called `theta_t + mu*v_t` in the paper.
        This implementation is an approximation of the original formula, valid
        for high values of momentum. It will compute the "adjusted gradient"
        in NAG by assuming that the new gradient will be estimated by the
        current average gradient plus the product of momentum and the change
        in the average gradient.

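    In update-rule terms, a rough sketch of the dense step (pseudocode; the
    actual computation happens inside the fused training op) is:

    ```
    accumulation = momentum * accumulation + gradient
    if use_nesterov:
      variable -= learning_rate * (gradient + momentum * accumulation)
    else:
      variable -= learning_rate * accumulation
    ```
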
    References:
      On the importance of initialization and momentum in deep learning:
        [Sutskever et al., 2013]
        (http://proceedings.mlr.press/v28/sutskever13.html)
        ([pdf](http://proceedings.mlr.press/v28/sutskever13.pdf))


    """
    super(MomentumOptimizer, self).__init__(use_locking, name)
    self._learning_rate = learning_rate
    self._momentum = momentum
    self._use_nesterov = use_nesterov

  def _create_slots(self, var_list):
    for v in var_list:
      self._zeros_slot(v, "momentum", self._name)

  def _prepare(self):
    learning_rate = self._learning_rate
    if callable(learning_rate):
      learning_rate = learning_rate()
    self._learning_rate_tensor = ops.convert_to_tensor(learning_rate,
                                                       name="learning_rate")
    momentum = self._momentum
    if callable(momentum):
      momentum = momentum()
    self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum")

  def _apply_dense(self, grad, var):
    mom = self.get_slot(var, "momentum")
    return training_ops.apply_momentum(
        var, mom,
        math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
        grad,
        math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov).op

  def _resource_apply_dense(self, grad, var):
    mom = self.get_slot(var, "momentum")
    return training_ops.resource_apply_momentum(
        var.handle, mom.handle,
        math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype),
        grad,
        math_ops.cast(self._momentum_tensor, grad.dtype.base_dtype),
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov)

  def _apply_sparse(self, grad, var):
    mom = self.get_slot(var, "momentum")
    return training_ops.sparse_apply_momentum(
        var, mom,
        math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
        grad.values, grad.indices,
        math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov).op

  def _resource_apply_sparse(self, grad, var, indices):
    mom = self.get_slot(var, "momentum")
    return training_ops.resource_sparse_apply_momentum(
        var.handle, mom.handle,
        math_ops.cast(self._learning_rate_tensor, grad.dtype),
        grad, indices,
        math_ops.cast(self._momentum_tensor, grad.dtype),
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov)