xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/training_ops.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_
18 
19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20 #include "tensorflow/core/framework/tensor_types.h"
21 #include "tensorflow/core/lib/core/status.h"
22 #include "tensorflow/core/platform/types.h"
23 
24 namespace tensorflow {
25 namespace functor {
26 
// Each training algorithm has an ApplyXYZ functor struct declared in
// this header file. They are specialized for different devices
// (CPUDevice in training_ops.cc or GPUDevice in training_ops_gpu.cc).
30 
31 template <typename Device, typename T>
32 struct ApplyGradientDescent {
33   void operator()(const Device& d, typename TTypes<T>::Flat var,
34                   typename TTypes<T>::ConstScalar alpha,
35                   typename TTypes<T>::ConstFlat delta);
36 };
37 
38 template <typename Device, typename T>
39 struct ApplyAdadelta {
40   void operator()(const Device& d, typename TTypes<T>::Flat var,
41                   typename TTypes<T>::Flat accum,
42                   typename TTypes<T>::Flat accum_update,
43                   typename TTypes<T>::ConstScalar lr,
44                   typename TTypes<T>::ConstScalar rho,
45                   typename TTypes<T>::ConstScalar epsilon,
46                   typename TTypes<T>::ConstFlat grad);
47 };
48 
49 template <typename Device, typename T, typename Tindex>
50 struct SparseApplyAdadelta {
51   void operator()(const Device& d, typename TTypes<T>::Matrix var,
52                   typename TTypes<T>::Matrix accum,
53                   typename TTypes<T>::Matrix accum_update,
54                   typename TTypes<T>::ConstScalar lr,
55                   typename TTypes<T>::ConstScalar rho,
56                   typename TTypes<T>::ConstScalar epsilon,
57                   typename TTypes<T>::ConstMatrix grad,
58                   typename TTypes<Tindex>::ConstFlat indices);
59 };
60 
61 template <typename Device, typename T>
62 struct FobosElasticNet {
63   void operator()(const Device& d, typename TTypes<T>::Flat var,
64                   typename TTypes<T>::ConstScalar lr,
65                   typename TTypes<T>::ConstScalar l1,
66                   typename TTypes<T>::ConstScalar l2,
67                   typename TTypes<T>::ConstFlat grad);
68 };
69 
70 template <typename Device, typename T>
71 struct ApplyProximalGradientDescent {
72   void operator()(const Device& d, typename TTypes<T>::Flat var,
73                   typename TTypes<T>::ConstScalar lr,
74                   typename TTypes<T>::ConstScalar l1,
75                   typename TTypes<T>::ConstScalar l2,
76                   typename TTypes<T>::ConstFlat grad);
77 };
78 
79 template <typename Device, typename T>
80 struct ApplyAdagrad {
81   void operator()(const Device& d, typename TTypes<T>::Flat var,
82                   typename TTypes<T>::Flat accum,
83                   typename TTypes<T>::ConstScalar lr,
84                   typename TTypes<T>::ConstFlat grad, bool update_slots);
85 };
86 
87 template <typename Device, typename T>
88 struct ApplyAdagradV2 {
89   void operator()(const Device& d, typename TTypes<T>::Flat var,
90                   typename TTypes<T>::Flat accum,
91                   typename TTypes<T>::ConstScalar lr,
92                   typename TTypes<T>::ConstScalar epsilon,
93                   typename TTypes<T>::ConstFlat grad, bool update_slots);
94 };
95 
96 template <typename Device, typename T>
97 struct ApplyAdagradDA {
98   void operator()(const Device& d, typename TTypes<T>::Flat var,
99                   typename TTypes<T>::Flat gradient_accum,
100                   typename TTypes<T>::Flat gradient_squared_accum,
101                   typename TTypes<T>::ConstScalar lr, int64_t global_step,
102                   typename TTypes<T>::ConstScalar l1,
103                   typename TTypes<T>::ConstScalar l2,
104                   typename TTypes<T>::ConstFlat grad);
105 };
106 
107 template <typename Device, typename T, typename Tindex, bool has_epsilon>
108 struct SparseApplyAdagrad {
109   // Note that epsilon is ignored if has_epsilon is false.
110   Status operator()(const Device& d, typename TTypes<T>::Matrix var,
111                     typename TTypes<T>::Matrix accum,
112                     typename TTypes<T>::ConstScalar lr,
113                     typename TTypes<T>::ConstScalar epsilon,
114                     typename TTypes<T>::ConstMatrix grad,
115                     typename TTypes<Tindex>::ConstVec indices,
116                     int64_t inner_dim, bool update_slots);
117 };
118 
119 template <typename Device, typename T>
120 struct ApplyProximalAdagrad {
121   void operator()(const Device& d, typename TTypes<T>::Flat var,
122                   typename TTypes<T>::Flat accum,
123                   typename TTypes<T>::ConstScalar lr,
124                   typename TTypes<T>::ConstScalar l1,
125                   typename TTypes<T>::ConstScalar l2,
126                   typename TTypes<T>::ConstFlat grad);
127 };
128 
129 template <typename Device, typename T, typename Tindex>
130 struct SparseApplyProximalAdagrad {
131   Status operator()(const Device& d, typename TTypes<T>::Matrix var,
132                     typename TTypes<T>::Matrix accum,
133                     typename TTypes<T>::ConstScalar lr,
134                     typename TTypes<T>::ConstScalar l1,
135                     typename TTypes<T>::ConstScalar l2,
136                     typename TTypes<T>::ConstMatrix grad,
137                     typename TTypes<Tindex>::ConstVec indices,
138                     int64_t inner_dim);
139 };
140 
141 template <typename Device, typename T>
142 struct ApplyFtrl {
143   void operator()(const Device& d, typename TTypes<T>::Flat var,
144                   typename TTypes<T>::Flat accum,
145                   typename TTypes<T>::Flat linear,
146                   typename TTypes<T>::ConstFlat grad,
147                   typename TTypes<T>::ConstScalar lr,
148                   typename TTypes<T>::ConstScalar l1,
149                   typename TTypes<T>::ConstScalar l2,
150                   typename TTypes<T>::ConstScalar lr_power);
151 };
152 
153 template <typename Device, typename T>
154 struct ApplyFtrlMultiplyLinearByLr {
155   void operator()(const Device& d, typename TTypes<T>::Flat var,
156                   typename TTypes<T>::Flat accum,
157                   typename TTypes<T>::Flat linear,
158                   typename TTypes<T>::ConstFlat grad,
159                   typename TTypes<T>::ConstScalar lr,
160                   typename TTypes<T>::ConstScalar l1,
161                   typename TTypes<T>::ConstScalar l2,
162                   typename TTypes<T>::ConstScalar lr_power);
163 };
164 
165 template <typename Device, typename T>
166 struct ApplyFtrlV2 {
167   void operator()(const Device& d, typename TTypes<T>::Flat var,
168                   typename TTypes<T>::Flat accum,
169                   typename TTypes<T>::Flat linear,
170                   typename TTypes<T>::ConstFlat grad,
171                   typename TTypes<T>::ConstScalar lr,
172                   typename TTypes<T>::ConstScalar l1,
173                   typename TTypes<T>::ConstScalar l2,
174                   typename TTypes<T>::ConstScalar l2_shrinkage,
175                   typename TTypes<T>::ConstScalar lr_power);
176 };
177 
178 template <typename Device, typename T>
179 struct ApplyFtrlV2MultiplyLinearByLr {
180   void operator()(const Device& d, typename TTypes<T>::Flat var,
181                   typename TTypes<T>::Flat accum,
182                   typename TTypes<T>::Flat linear,
183                   typename TTypes<T>::ConstFlat grad,
184                   typename TTypes<T>::ConstScalar lr,
185                   typename TTypes<T>::ConstScalar l1,
186                   typename TTypes<T>::ConstScalar l2,
187                   typename TTypes<T>::ConstScalar l2_shrinkage,
188                   typename TTypes<T>::ConstScalar lr_power);
189 };
190 
191 template <typename Device, typename T, typename Tindex, bool has_l2_shrinkage>
192 struct SparseApplyFtrl {
193   Status operator()(const Device& d, typename TTypes<T>::Matrix var_flat,
194                     typename TTypes<T>::Matrix accum_flat,
195                     typename TTypes<T>::Matrix linear_flat,
196                     typename TTypes<T>::ConstScalar lr,
197                     typename TTypes<T>::ConstScalar l1,
198                     typename TTypes<T>::ConstScalar l2,
199                     typename TTypes<T>::ConstScalar l2_shrinkage,
200                     typename TTypes<T>::ConstScalar lr_power,
201                     typename TTypes<T>::ConstMatrix grad_flat,
202                     typename TTypes<Tindex>::ConstVec indices_vec,
203                     int64_t inner_dim, bool multiply_linear_by_lr);
204 };
205 
206 template <typename Device, typename T>
207 struct ApplyMomentum {
208   void operator()(const Device& d, typename TTypes<T>::Flat var,
209                   typename TTypes<T>::Flat accum,
210                   typename TTypes<T>::ConstScalar lr,
211                   typename TTypes<T>::ConstFlat grad,
212                   typename TTypes<T>::ConstScalar momentum, bool use_nesterov);
213 };
214 
215 template <typename Device, typename T>
216 struct ApplyKerasMomentum {
217   void operator()(const Device& d, typename TTypes<T>::Flat var,
218                   typename TTypes<T>::Flat accum,
219                   typename TTypes<T>::ConstScalar lr,
220                   typename TTypes<T>::ConstFlat grad,
221                   typename TTypes<T>::ConstScalar momentum, bool use_nesterov);
222 };
223 
224 template <typename Device, typename T, typename Tindex>
225 struct SparseApplyKerasMomentum {
226   Tindex operator()(const Device& d, typename TTypes<T>::Matrix var,
227                     typename TTypes<T>::Matrix accum,
228                     typename TTypes<T>::ConstScalar lr,
229                     typename TTypes<T>::ConstMatrix grad,
230                     typename TTypes<Tindex>::ConstFlat indices,
231                     typename TTypes<T>::ConstScalar momentum,
232                     bool use_nesterov);
233 };
234 
235 template <typename Device, typename T>
236 struct ApplyAdam {
237   void operator()(const Device& d, typename TTypes<T>::Flat var,
238                   typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
239                   typename TTypes<T>::ConstScalar beta1_power,
240                   typename TTypes<T>::ConstScalar beta2_power,
241                   typename TTypes<T>::ConstScalar lr,
242                   typename TTypes<T>::ConstScalar beta1,
243                   typename TTypes<T>::ConstScalar beta2,
244                   typename TTypes<T>::ConstScalar epsilon,
245                   typename TTypes<T>::ConstFlat grad, bool use_nesterov);
246 };
247 
248 template <typename Device, typename T>
249 struct ApplyAdamWithAmsgrad {
250   void operator()(const Device& d, typename TTypes<T>::Flat var,
251                   typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
252                   typename TTypes<T>::Flat vhat,
253                   typename TTypes<T>::ConstScalar beta1_power,
254                   typename TTypes<T>::ConstScalar beta2_power,
255                   typename TTypes<T>::ConstScalar lr,
256                   typename TTypes<T>::ConstScalar beta1,
257                   typename TTypes<T>::ConstScalar beta2,
258                   typename TTypes<T>::ConstScalar epsilon,
259                   typename TTypes<T>::ConstFlat grad);
260 };
261 
262 template <typename Device, typename T>
263 struct ApplyAdaMax {
264   void operator()(const Device& d, typename TTypes<T>::Flat var,
265                   typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
266                   typename TTypes<T>::ConstScalar beta1_power,
267                   typename TTypes<T>::ConstScalar lr,
268                   typename TTypes<T>::ConstScalar beta1,
269                   typename TTypes<T>::ConstScalar beta2,
270                   typename TTypes<T>::ConstScalar epsilon,
271                   typename TTypes<T>::ConstFlat grad);
272 };
273 
274 template <typename Device, typename T>
275 struct ApplyRMSProp {
276   void operator()(const Device& d, typename TTypes<T>::Flat var,
277                   typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
278                   typename TTypes<T>::ConstScalar lr,
279                   typename TTypes<T>::ConstScalar rho,
280                   typename TTypes<T>::ConstScalar momentum,
281                   typename TTypes<T>::ConstScalar epsilon,
282                   typename TTypes<T>::ConstFlat grad);
283 };
284 
285 template <typename Device, typename T>
286 struct ApplyCenteredRMSProp {
287   void operator()(const Device& d, typename TTypes<T>::Flat var,
288                   typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,
289                   typename TTypes<T>::Flat mom,
290                   typename TTypes<T>::ConstScalar lr,
291                   typename TTypes<T>::ConstScalar rho,
292                   typename TTypes<T>::ConstScalar momentum,
293                   typename TTypes<T>::ConstScalar epsilon,
294                   typename TTypes<T>::ConstFlat grad);
295 };
296 
297 template <typename Device, typename T>
298 struct ApplyAddSign {
299   void operator()(const Device& d, typename TTypes<T>::Flat var,
300                   typename TTypes<T>::Flat m,
301                   typename TTypes<T>::ConstScalar lr,
302                   typename TTypes<T>::ConstScalar alpha,
303                   typename TTypes<T>::ConstScalar sign_decay,
304                   typename TTypes<T>::ConstScalar beta,
305                   typename TTypes<T>::ConstFlat grad);
306 };
307 
308 template <typename Device, typename T>
309 struct ApplyPowerSign {
310   void operator()(const Device& d, typename TTypes<T>::Flat var,
311                   typename TTypes<T>::Flat m,
312                   typename TTypes<T>::ConstScalar lr,
313                   typename TTypes<T>::ConstScalar logbase,
314                   typename TTypes<T>::ConstScalar sign_decay,
315                   typename TTypes<T>::ConstScalar beta,
316                   typename TTypes<T>::ConstFlat grad);
317 };
318 
319 }  // end namespace functor
320 }  // end namespace tensorflow
321 
322 #endif  // TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_
323