xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/training_ops_gpu.cu.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
17 
18 #define EIGEN_USE_GPU
19 
20 #include "tensorflow/core/framework/register_types.h"
21 #include "tensorflow/core/kernels/training_ops.h"
22 #include "tensorflow/core/util/gpu_kernel_helper.h"
23 
24 namespace tensorflow {
25 
26 typedef Eigen::GpuDevice GPUDevice;
27 
28 namespace functor {
29 
30 template <typename T>
31 __device__ T impl_sign(T x) {
32   return x == T(0) ? T(0) : x < T(0) ? T(-1) : T(1);
33 }
34 
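// Per-row sparse Adagrad update implemented by the kernel below. For each
// gradient element g belonging to row indices[r]:
//   accum += g * g                                 (only if update_slots)
//   var   -= lr * g / (sqrt(accum) + epsilon)      (has_epsilon == true)
//   var   -= lr * g * rsqrt(accum)                 (has_epsilon == false)
// Out-of-range rows in `indices` are skipped rather than reported.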
35 template <typename T, typename Tindex, bool has_epsilon>
36 __global__ __launch_bounds__(1024) void SparseApplyAdagradKernel(
37     T* var, T* accum, const T* lr, const T* epsilon, const T* grad,
38     const Tindex* indices, Tindex param_rows, Tindex updates_size,
39     Tindex indices_size, bool update_slots) {
40   Tindex col_size = updates_size / indices_size;
41   GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
42     Tindex indices_row = grad_index / col_size;
43     Tindex param_row = indices[indices_row];
44     if (param_row < 0 || param_row >= param_rows) {
45       // Ignore indices that are out of range.
46       continue;
47     }
48 
49     // Compute the index of var and accum.
50     Tindex param_index = param_row * col_size + (grad_index % col_size);
51 
52     // Read variables.
53     T var_i = var[param_index];
54     T accum_i = accum[param_index];
55     T grad_i = grad[grad_index];
56     const T lr_t = *lr;
57     const T epsilon_t = *epsilon;
58 
59     if (update_slots) {
60       accum_i += grad_i * grad_i;
61     }
62     if (has_epsilon) {
63       var_i -= lr_t * grad_i / (Eigen::numext::sqrt(accum_i) + epsilon_t);
64     } else {
65       var_i -= lr_t * grad_i * Eigen::numext::rsqrt(accum_i);
66     }
67 
68     // Write update back to variables.
69     var[param_index] = var_i;
70     accum[param_index] = accum_i;
71   }
72 }
73 
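// Sparse proximal Adagrad (FOBOS-style) update implemented below:
//   accum += g * g
//   lr_t   = lr * rsqrt(accum)
//   prox   = var - lr_t * g
//   var    = sign(prox) * max(|prox| - lr_t * max(l1, 0), 0) / (1 + l2 * lr_t)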
74 template <typename T, typename Tindex>
75 __global__ __launch_bounds__(1024) void SparseApplyProximalAdagradKernel(
76     T* var, T* accum, const T* lr, const T* l1, const T* l2, const T* grad,
77     const Tindex* indices, Tindex param_rows, Tindex updates_size,
78     Tindex indices_size) {
79   Tindex col_size = updates_size / indices_size;
80   GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
81     Tindex indices_row = grad_index / col_size;
82     Tindex param_row = indices[indices_row];
83     if (param_row < 0 || param_row >= param_rows) {
84       // Ignore indices that are out of range.
85       continue;
86     }
87 
88     // Compute the index of var and accum.
89     Tindex param_index = param_row * col_size + (grad_index % col_size);
90 
91     // Read variables.
92     T var_i = var[param_index];
93     T accum_i = accum[param_index];
94     T grad_i = grad[grad_index];
95     const T lr_t = *lr;
96     const T l1_t = *l1;
97     const T l2_t = *l2;
98 
99     accum_i += grad_i * grad_i;
100     T learning_rate = lr_t * Eigen::numext::rsqrt(accum_i);
101     // compute v = w - lr * grad.
102     T prox_var_i = var_i - grad_i * learning_rate;
103     // compute sign(v) * max(|v| - lr * max(l1, 0), 0)
104     var_i = (prox_var_i >= 0 ? T(1.) : T(-1.)) *
105             max(abs(prox_var_i) - learning_rate * max(l1_t, T(0)), T(0)) /
106             (T(1.) + l2_t * learning_rate);
107 
108     // Write update back to variables.
109     var[param_index] = var_i;
110     accum[param_index] = accum_i;
111   }
112 }
113 
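// Sparse FTRL-proximal update implemented below. In outline (for the
// multiply_linear_by_lr case; otherwise the linear, l1, and denominator terms
// are rescaled by 1/lr):
//   g'        = g + 2 * l2_shrinkage * var     (only when has_l2_shrinkage)
//   new_accum = accum + g * g
//   linear   += g' * lr - (new_accum^(-lr_power) - accum^(-lr_power)) * var
//   var       = (clamp(linear, -l1 * lr, l1 * lr) - linear)
//               / (new_accum^(-lr_power) + 2 * l2 * lr)
// lr_power == -0.5 is special-cased to use sqrt instead of pow.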
114 template <typename T, typename Tindex, bool has_l2_shrinkage>
115 __global__ void SparseApplyFtrlKernel(T* var, T* accum, T* linear, const T* lr,
116                                       const T* l1, const T* l2,
117                                       const T* l2_shrinkage, const T* lr_power,
118                                       const T* grad, const Tindex* indices,
119                                       Tindex param_rows, Tindex updates_size,
120                                       Tindex indices_size,
121                                       bool multiply_linear_by_lr) {
122   const Tindex col_size = updates_size / indices_size;
123   GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
124     const Tindex indices_row = grad_index / col_size;
125     const Tindex param_row = indices[indices_row];
126     if (param_row < 0 || param_row >= param_rows) {
127       // Ignore indices that are out of range.
128       continue;
129     }
130 
131     // Compute the index of var and accum.
132     const Tindex param_index = param_row * col_size + (grad_index % col_size);
133 
134     // Read variables.
135     T var_i = var[param_index];
136     T accum_i = accum[param_index];
137     T linear_i = linear[param_index];
138     const T grad_i = grad[grad_index];
139     const T lr_t = *lr;
140     const T l1_t = *l1;
141     const T l2_t = *l2;
142     const T lr_power_t = *lr_power;
143 
144     const T grad_shr_i =
145         has_l2_shrinkage ? grad_i + static_cast<T>(2) * (*l2_shrinkage) * var_i
146                          : grad_i;
147     const T new_accum_i = accum_i + grad_i * grad_i;
148     const bool lr_power_is_neg_half = lr_power_t == static_cast<T>(-0.5);
149     const T pow_new_accum = lr_power_is_neg_half
150                                 ? Eigen::numext::sqrt(new_accum_i)
151                                 : pow(new_accum_i, -lr_power_t);
152     const T pow_accum = lr_power_is_neg_half ? Eigen::numext::sqrt(accum_i)
153                                              : pow(accum_i, -lr_power_t);
154     T linear_change = grad_shr_i * lr_t - (pow_new_accum - pow_accum) * var_i;
155     if (!multiply_linear_by_lr) {
156       linear_change /= lr_t;
157     }
158     linear_i += linear_change;
159 
160     T l1_mult = l1_t;
161     if (multiply_linear_by_lr) {
162       l1_mult *= lr_t;
163     }
164     const T l1_reg_adjust = max(min(linear_i, l1_mult), -l1_mult);
165     const T x = l1_reg_adjust - linear_i;
166     T y = pow_new_accum + static_cast<T>(2) * l2_t * lr_t;
167     if (!multiply_linear_by_lr) {
168       y /= lr_t;
169     }
170     var_i = x / y;
171     accum_i = new_accum_i;
172 
173     // Write update back to variables.
174     var[param_index] = var_i;
175     accum[param_index] = accum_i;
176     linear[param_index] = linear_i;
177   }
178 }
179 
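// Dense Adam update implemented below. The bias corrections are folded into a
// single step size:
//   mul_factor = lr * sqrt(1 - beta2_power) / (1 - beta1_power)
//   m   += (1 - beta1) * (g - m)
//   v   += (1 - beta2) * (g * g - v)
//   var -= mul_factor * m / (epsilon + sqrt(v))               (plain Adam)
//   var -= mul_factor * (beta1 * m + (1 - beta1) * g)
//          / (epsilon + sqrt(v))                               (use_nesterov)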
180 template <typename T>
181 __global__ __launch_bounds__(1024) void ApplyAdamKernel(
182     int32 data_dim, T* var, T* m, T* v, const T* const beta1_power_,
183     const T* const beta2_power_, const T* const lr_, const T* const beta1_,
184     const T* const beta2_, const T* const epsilon_, const T* grad,
185     bool use_nesterov) {
186   eigen_assert(blockDim.y == 1);
187   eigen_assert(blockDim.z == 1);
188   eigen_assert(gridDim.y == 1);
189   eigen_assert(gridDim.z == 1);
190 
191   const T mul_factor =
192       (*lr_) * Eigen::numext::sqrt(static_cast<T>(1.0) - (*beta2_power_)) /
193       (static_cast<T>(1.0) - (*beta1_power_));
194   const T epsilon = (*epsilon_);
195   const T beta1 = (*beta1_);
196   const T one_minus_beta1 = static_cast<T>(1.0) - (beta1);
197   const T one_minus_beta2 = static_cast<T>(1.0) - (*beta2_);
198   const int32 stripe = gridDim.x * blockDim.x;
199 
200   for (int32 i = blockIdx.x * blockDim.x + threadIdx.x; i < data_dim;
201        i += stripe) {
202     auto m_i = m[i];
203     auto g_i = grad[i];
204     auto v_i = v[i];
205 
206     // Avoid += and -= due to std::complex<T> issues on device for MSVC.
207     m_i = m_i + one_minus_beta1 * (g_i - m_i);
208     v_i = v_i + one_minus_beta2 * (g_i * g_i - v_i);
209     if (use_nesterov) {
210       var[i] = var[i] - mul_factor * (m_i * beta1 + one_minus_beta1 * g_i) /
211                             (epsilon + Eigen::numext::sqrt(v_i));
212     } else {
213       var[i] = var[i] - mul_factor * m_i / (epsilon + Eigen::numext::sqrt(v_i));
214     }
215 
216     m[i] = m_i;
217     v[i] = v_i;
218   }
219 }
220 
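// Sparse Keras-style momentum update implemented below:
//   accum = momentum * accum - lr * g
//   var  += momentum * accum - lr * g     (use_nesterov)
//   var  += accum                         (otherwise)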
221 template <typename T, typename Tindex>
222 __global__ __launch_bounds__(1024) void SparseApplyKerasMomentumKernel(
223     T* var, T* accum, const T* lr, const T* grad, const Tindex* indices,
224     const T* momentum, bool use_nesterov, Tindex param_rows,
225     Tindex updates_size, Tindex indices_size) {
226   Tindex col_size = updates_size / indices_size;
227   GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
228     Tindex indices_row = grad_index / col_size;
229     Tindex param_row = indices[indices_row];
230     if (param_row < 0 || param_row >= param_rows) {
231       // Ignore indices that are out of range.
232       continue;
233     }
234 
235     // Compute the index of var and accum.
236     Tindex param_index = param_row * col_size + (grad_index % col_size);
237 
238     // Read variables.
239     T var_i = var[param_index];
240     T accum_i = accum[param_index];
241     T grad_i = grad[grad_index];
242     const T momentum_t = *momentum;
243     const T lr_t = *lr;
244 
245     // Variable update computation.
246     accum_i = momentum_t * accum_i - lr_t * grad_i;
247     // Static branching on a kernel-uniform flag does not hurt GPU performance.
248     // Avoid += due to std::complex<T> issues on device for MSVC.
249     if (use_nesterov) {
250       var_i = var_i + (momentum_t * accum_i - lr_t * grad_i);
251     } else {
252       var_i = var_i + accum_i;
253     }
254 
255     // Write update back to variables.
256     var[param_index] = var_i;
257     accum[param_index] = accum_i;
258   }
259 }
260 
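// The dense functors below evaluate Eigen expressions on the GPU device.
// Scalar hyperparameters arrive as 0-d tensors, so they are reshaped to size 1
// and broadcast across the flat variable, e.g. for gradient descent:
//   var -= broadcast(lr) * grad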
261 template <typename T>
262 struct ApplyGradientDescent<GPUDevice, T> {
263   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
264                   typename TTypes<T>::ConstScalar lr,
265                   typename TTypes<T>::ConstFlat grad) {
266     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
267     bcast[0] = grad.dimension(0);
268     Eigen::Sizes<1> single;
269     var.device(d) -= lr.reshape(single).broadcast(bcast) * grad;
270   }
271 };
272 
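// The element-wise kernels below (ApplyAdagradKernel, ApplyAdagradV2Kernel,
// ApplyProximalAdagradKernel, ApplyAdadeltaKernel, ApplyRMSPropKernel,
// ApplyCenteredRMSPropKernel) are launched through wrap_kernel_call on the
// ROCm build; the CUDA build instead uses the Eigen expression path in the
// corresponding functors (see the #if TENSORFLOW_USE_ROCM blocks further down).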
273 template <typename T>
274 __global__ __launch_bounds__(1024) void ApplyAdagradKernel(GpuLaunchConfig cfg,
275                                                            T* var, T* accum,
276                                                            const T* lr,
277                                                            const T* grad,
278                                                            bool update_slots) {
279   GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
280     if (update_slots) accum[i] += grad[i] * grad[i];
281     var[i] -= lr[0] * grad[i] * Eigen::numext::rsqrt(accum[i]);
282   }
283 }
284 
285 template <typename T>
286 __global__ __launch_bounds__(1024) void ApplyAdagradV2Kernel(
287     GpuLaunchConfig cfg, T* var, T* accum, const T* lr, const T* epsilon,
288     const T* grad, bool update_slots) {
289   GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
290     if (update_slots) accum[i] += grad[i] * grad[i];
291     T update = grad[i] / (Eigen::numext::sqrt(accum[i]) + epsilon[0]);
292     var[i] -= lr[0] * update;
293   }
294 }
295 
296 template <typename T>
297 __global__ __launch_bounds__(1024) void ApplyProximalAdagradKernel(
298     GpuLaunchConfig cfg, T* var, T* accum, const T* lr, const T* l1,
299     const T* l2, const T* grad) {
300   GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
301     accum[i] += grad[i] * grad[i];
302     T lr_scaled = lr[0] * Eigen::numext::rsqrt(accum[i]);
303     T prox_var = var[i] - grad[i] * lr_scaled;
304     var[i] = impl_sign(prox_var) *
305              max(Eigen::numext::abs(prox_var) - lr_scaled * max(l1[0], T(0.f)),
306                  T(0.f)) /
307              (T(1.f) + l2[0] * lr_scaled);
308   }
309 }
310 
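// Adadelta update implemented below:
//   accum        = rho * accum + (1 - rho) * g * g
//   update       = sqrt(accum_update + eps) * g * rsqrt(accum + eps)
//   var         -= lr * update
//   accum_update = rho * accum_update + (1 - rho) * update * update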
311 template <typename T>
312 __global__ __launch_bounds__(1024) void ApplyAdadeltaKernel(
313     GpuLaunchConfig cfg, T* var, T* accum, T* accum_update, const T* plr,
314     const T* prho, const T* peps, const T* grad) {
315   T rho = prho[0];
316   T eps = peps[0];
317   T lr = plr[0];
318   GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
319     accum[i] = accum[i] * rho + grad[i] * grad[i] * (T(1.0) - rho);
320     T update = Eigen::numext::sqrt(accum_update[i] + eps) * grad[i] *
321                Eigen::numext::rsqrt(accum[i] + eps);
322     var[i] -= update * lr;
323     accum_update[i] = accum_update[i] * rho + update * update * (T(1.0) - rho);
324   }
325 }
326 
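// RMSProp update implemented below:
//   ms  += (1 - rho) * (g * g - ms)
//   mom  = momentum * mom + lr * g * rsqrt(eps + ms)
//   var -= mom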
327 template <typename T>
328 __global__ __launch_bounds__(1024) void ApplyRMSPropKernel(
329     GpuLaunchConfig cfg, T* var, T* ms, T* mom, const T* plr, const T* prho,
330     const T* pmomentum, const T* peps, const T* grad) {
331   T rho = prho[0];
332   T eps = peps[0];
333   T lr = plr[0];
334   T momentum = pmomentum[0];
335   GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
336     ms[i] += (T(1.0) - rho) * (grad[i] * grad[i] - ms[i]);
337     mom[i] =
338         mom[i] * momentum + lr * grad[i] * Eigen::numext::rsqrt(eps + ms[i]);
339     var[i] -= mom[i];
340   }
341 }
342 
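// Centered RMSProp update implemented below; the gradient mean mg is tracked
// so the second moment can be centered:
//   ms  += (1 - rho) * (g * g - ms)
//   mg  += (1 - rho) * (g - mg)
//   mom  = momentum * mom + lr * g * rsqrt(ms - mg * mg + eps)
//   var -= mom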
343 template <typename T>
344 __global__ __launch_bounds__(1024) void ApplyCenteredRMSPropKernel(
345     GpuLaunchConfig cfg, T* var, T* mg, T* ms, T* mom, const T* plr,
346     const T* prho, const T* pmomentum, const T* peps, const T* grad) {
347   T rho = prho[0];
348   T eps = peps[0];
349   T lr = plr[0];
350   T momentum = pmomentum[0];
351   T one_minus_rho = T(1.0) - rho;
352   GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
353     ms[i] += one_minus_rho * (grad[i] * grad[i] - ms[i]);
354     mg[i] += one_minus_rho * (grad[i] - mg[i]);
355     T denom = (ms[i] - mg[i] * mg[i]) + eps;
356     mom[i] = mom[i] * momentum + lr * grad[i] * Eigen::numext::rsqrt(denom);
357     var[i] -= mom[i];
358   }
359 }
360 
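// kernel_forward adapts functor-style arguments (Eigen TensorMaps and bools)
// to raw device pointers and launches an element-wise kernel sized by
// GetGpuLaunchConfig. A typical call, mirroring the ApplyAdagrad functor
// further down:
//   wrap_kernel_call(ApplyAdagradKernel<T>, d, var, accum, lr, grad,
//                    update_slots);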
361 namespace kernel_forward {
362 bool to_pointers(bool x) { return x; }
363 template <class T>
364 typename T::PointerType to_pointers(T& x) {
365   return x.data();
366 }
367 template <class T>
368 typename T::ConstPointerType to_pointers(const T& x) {
369   return x.data();
370 }
371 
372 template <typename T, typename... CallerArgs, typename... KernelArgs>
373 void wrap_kernel_call(void (*func)(KernelArgs...), const GPUDevice& d, T var,
374                       CallerArgs... args) {
375   int32 data_dim = var.dimension(0);
376   auto config = GetGpuLaunchConfig(data_dim, d);
377   TF_CHECK_OK(GpuLaunchKernel(func, config.block_count, config.thread_per_block,
378                               0, d.stream(), config, var.data(),
379                               to_pointers(args)...));
380 }
381 };  // namespace kernel_forward
382 
383 using kernel_forward::wrap_kernel_call;
384 
385 template <typename T>
386 struct ApplyAdagrad<GPUDevice, T> {
387   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
388                   typename TTypes<T>::Flat accum,
389                   typename TTypes<T>::ConstScalar lr,
390                   typename TTypes<T>::ConstFlat grad, bool update_slots) {
391 #if TENSORFLOW_USE_ROCM
392     wrap_kernel_call(ApplyAdagradKernel<T>, d, var, accum, lr, grad,
393                      update_slots);
394 #else
395     if (update_slots) {
396       accum.device(d) += grad.square();
397     }
398     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
399     bcast[0] = grad.dimension(0);
400     Eigen::Sizes<1> single;
401     var.device(d) -= lr.reshape(single).broadcast(bcast) * grad * accum.rsqrt();
402 #endif
403   }
404 };
405 
406 template <typename T>
407 struct ApplyAdagradV2<GPUDevice, T> {
408   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
409                   typename TTypes<T>::Flat accum,
410                   typename TTypes<T>::ConstScalar lr,
411                   typename TTypes<T>::ConstScalar epsilon,
412                   typename TTypes<T>::ConstFlat grad, bool update_slots) {
413 #if TENSORFLOW_USE_ROCM
414     wrap_kernel_call(ApplyAdagradV2Kernel<T>, d, var, accum, lr, epsilon, grad,
415                      update_slots);
416 #else
417     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
418     bcast[0] = grad.dimension(0);
419     Eigen::Sizes<1> single;
420     if (update_slots) {
421       accum.device(d) += grad.square();
422     }
423     const auto update =
424         grad / (accum.sqrt() + epsilon.reshape(single).broadcast(bcast));
425     var.device(d) -= lr.reshape(single).broadcast(bcast) * update;
426 #endif
427   }
428 };
429 
430 template <typename T, typename Tindex, bool has_epsilon>
431 struct SparseApplyAdagrad<GPUDevice, T, Tindex, has_epsilon> {
432   Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
433                     typename TTypes<T>::Matrix accum,
434                     typename TTypes<T>::ConstScalar lr,
435                     typename TTypes<T>::ConstScalar epsilon,
436                     typename TTypes<T>::ConstMatrix grad,
437                     typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
438                     bool update_slots) {
439     const Tindex first_dim_size = var.dimension(0);
440     const Tindex grad_size = grad.size();
441     const Tindex indices_size = indices.size();
442     if (grad_size == 0) {
443       return Status::OK();
444     }
445     GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
446     return GpuLaunchKernel(
447         SparseApplyAdagradKernel<T, Tindex, has_epsilon>, config.block_count,
448         config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
449         lr.data(), epsilon.data(), grad.data(), indices.data(), first_dim_size,
450         grad_size, indices_size, update_slots);
451   }
452 };
453 
454 template <typename T>
455 struct ApplyProximalAdagrad<GPUDevice, T> {
456   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
457                   typename TTypes<T>::Flat accum,
458                   typename TTypes<T>::ConstScalar lr,
459                   typename TTypes<T>::ConstScalar l1,
460                   typename TTypes<T>::ConstScalar l2,
461                   typename TTypes<T>::ConstFlat grad) {
462 #if TENSORFLOW_USE_ROCM
463     wrap_kernel_call(ApplyProximalAdagradKernel<T>, d, var, accum, lr, l1, l2,
464                      grad);
465 #else
466     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
467     bcast[0] = grad.dimension(0);
468     Eigen::Sizes<1> single;
469     // FOBOS update, per the paper, with the Adagrad learning rate.
470     accum.device(d) += grad.square();
471     // Adagrad learning rate.
472     // The following is the GPU equivalent of the CPU version:
473     // auto learning_rate = accum.constant(lr()) * accum.rsqrt();
474     auto lr_bcast = lr.reshape(single).broadcast(bcast);
475     auto l1_bcast = l1.reshape(single).broadcast(bcast);
476     auto l2_bcast = l2.reshape(single).broadcast(bcast);
477     auto learning_rate = lr_bcast * accum.rsqrt();
478     auto prox_var = var;
479     // compute v = w - lr * grad.
480     prox_var.device(d) -= grad * learning_rate;
481     // compute sign(v) * max(|v| - lr * max(l1, 0), 0)
482     var.device(d) = prox_var.sign() *
483                     (prox_var.abs() - learning_rate * l1_bcast.cwiseMax(T(0.f)))
484                         .cwiseMax(T(0.f)) /
485                     (var.constant(T(1.f)) + l2_bcast * learning_rate);
486 #endif
487   }
488 };
489 
490 template <typename T, typename Tindex>
491 struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex> {
492   Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
493                     typename TTypes<T>::Matrix accum,
494                     typename TTypes<T>::ConstScalar lr,
495                     typename TTypes<T>::ConstScalar l1,
496                     typename TTypes<T>::ConstScalar l2,
497                     typename TTypes<T>::ConstMatrix grad,
498                     typename TTypes<Tindex>::ConstVec indices,
499                     int64 inner_dim) {
500     const Tindex first_dim_size = var.dimension(0);
501     const Tindex grad_size = grad.size();
502     const Tindex indices_size = indices.size();
503     if (grad_size == 0) {
504       return Status::OK();
505     }
506     GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
507     return GpuLaunchKernel(SparseApplyProximalAdagradKernel<T, Tindex>,
508                            config.block_count, config.thread_per_block, 0,
509                            d.stream(), var.data(), accum.data(), lr.data(),
510                            l1.data(), l2.data(), grad.data(), indices.data(),
511                            first_dim_size, grad_size, indices_size);
512   }
513 };
514 
515 template <typename T>
516 struct ApplyAdadelta<GPUDevice, T> {
517   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
518                   typename TTypes<T>::Flat accum,
519                   typename TTypes<T>::Flat accum_update,
520                   typename TTypes<T>::ConstScalar lr,
521                   typename TTypes<T>::ConstScalar rho,
522                   typename TTypes<T>::ConstScalar epsilon,
523                   typename TTypes<T>::ConstFlat grad) {
524 #if TENSORFLOW_USE_ROCM
525     wrap_kernel_call(ApplyAdadeltaKernel<T>, d, var, accum, accum_update, lr,
526                      rho, epsilon, grad);
527 #else
528     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
529     bcast[0] = grad.dimension(0);
530     Eigen::Sizes<1> single;
531 
532     accum.device(d) = accum * rho.reshape(single).broadcast(bcast) +
533                       grad.square() * (grad.constant(T(1)) -
534                                        rho.reshape(single).broadcast(bcast));
535     const auto update =
536         (accum_update + epsilon.reshape(single).broadcast(bcast)).sqrt() *
537         (accum + epsilon.reshape(single).broadcast(bcast)).rsqrt() * grad;
538     var.device(d) -= update * lr.reshape(single).broadcast(bcast);
539     accum_update.device(d) =
540         accum_update * rho.reshape(single).broadcast(bcast) +
541         update.square() *
542             (grad.constant(T(1)) - rho.reshape(single).broadcast(bcast));
543 #endif
544   }
545 };
546 
547 template <typename T>
548 struct ApplyFtrl<GPUDevice, T> {
549   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
550                   typename TTypes<T>::Flat accum,
551                   typename TTypes<T>::Flat linear,
552                   typename TTypes<T>::ConstFlat grad,
553                   typename TTypes<T>::ConstScalar lr,
554                   typename TTypes<T>::ConstScalar l1,
555                   typename TTypes<T>::ConstScalar l2,
556                   typename TTypes<T>::ConstScalar lr_power) {
557     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
558     bcast[0] = grad.dimension(0);
559     Eigen::Sizes<1> single;
560 
561     auto l1_bcast = l1.reshape(single).broadcast(bcast);
562     auto l2_bcast = l2.reshape(single).broadcast(bcast);
563     auto lr_bcast = lr.reshape(single).broadcast(bcast);
564     auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
565     const auto two = static_cast<T>(2.0);
566 
567     auto new_accum = accum + grad.square();
568     auto accum_power = accum.binaryExpr(lr_power_bcast,
569                                         Eigen::internal::scalar_pow_op<T, T>());
570     auto new_accum_power = new_accum.binaryExpr(
571         lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
572     linear.device(d) += grad - (new_accum_power - accum_power) * var / lr_bcast;
573     auto x = (l1_bcast * linear.sign() - linear);
574     auto y = (new_accum_power / lr_bcast) + linear.constant(two) * l2_bcast;
575     auto pre_shrink = x / y;
576     var.device(d) = (linear.abs() > l1_bcast)
577                         .select(pre_shrink, var.constant(static_cast<T>(0)));
578     accum.device(d) += grad.square();
579   }
580 };
581 
582 template <typename T>
583 struct ApplyFtrlMultiplyLinearByLr<GPUDevice, T> {
584   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
585                   typename TTypes<T>::Flat accum,
586                   typename TTypes<T>::Flat linear,
587                   typename TTypes<T>::ConstFlat grad,
588                   typename TTypes<T>::ConstScalar lr,
589                   typename TTypes<T>::ConstScalar l1,
590                   typename TTypes<T>::ConstScalar l2,
591                   typename TTypes<T>::ConstScalar lr_power) {
592     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
593     bcast[0] = grad.dimension(0);
594     Eigen::Sizes<1> single;
595 
596     auto lr_bcast = lr.reshape(single).broadcast(bcast);
597     auto l1_lr_bcast = (l1 * lr).reshape(single).broadcast(bcast);
598     auto l2_lr_bcast = (l2 * lr).reshape(single).broadcast(bcast);
599     auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
600     const auto two = static_cast<T>(2.0);
601 
602     auto new_accum = accum + grad.square();
603     auto accum_power = accum.binaryExpr(lr_power_bcast,
604                                         Eigen::internal::scalar_pow_op<T, T>());
605     auto new_accum_power = new_accum.binaryExpr(
606         lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
607     linear.device(d) += grad * lr_bcast - (new_accum_power - accum_power) * var;
608     auto x = (l1_lr_bcast * linear.sign() - linear);
609     auto y = new_accum_power + linear.constant(two) * l2_lr_bcast;
610     auto pre_shrink = x / y;
611     var.device(d) = (linear.abs() > l1_lr_bcast)
612                         .select(pre_shrink, var.constant(static_cast<T>(0)));
613     accum.device(d) += grad.square();
614   }
615 };
616 
617 template <typename T>
618 struct ApplyFtrlV2<GPUDevice, T> {
619   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
620                   typename TTypes<T>::Flat accum,
621                   typename TTypes<T>::Flat linear,
622                   typename TTypes<T>::ConstFlat grad,
623                   typename TTypes<T>::ConstScalar lr,
624                   typename TTypes<T>::ConstScalar l1,
625                   typename TTypes<T>::ConstScalar l2,
626                   typename TTypes<T>::ConstScalar l2_shrinkage,
627                   typename TTypes<T>::ConstScalar lr_power) {
628     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
629     bcast[0] = grad.dimension(0);
630     Eigen::Sizes<1> single;
631 
632     auto l1_bcast = l1.reshape(single).broadcast(bcast);
633     auto l2_bcast = l2.reshape(single).broadcast(bcast);
634     auto l2_shrinkage_bcast = l2_shrinkage.reshape(single).broadcast(bcast);
635     auto lr_bcast = lr.reshape(single).broadcast(bcast);
636     auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
637     const auto two = static_cast<T>(2.0);
638 
639     auto new_accum = accum + grad.square();
640     auto accum_power = accum.binaryExpr(lr_power_bcast,
641                                         Eigen::internal::scalar_pow_op<T, T>());
642     auto new_accum_power = new_accum.binaryExpr(
643         lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
644     auto grad_with_shrinkage =
645         grad + (var.constant(two) * l2_shrinkage_bcast * var);
646     linear.device(d) +=
647         grad_with_shrinkage - (new_accum_power - accum_power) * var / lr_bcast;
648     auto x = (l1_bcast * linear.sign() - linear);
649     auto y = (new_accum_power / lr_bcast) + linear.constant(two) * l2_bcast;
650     auto pre_shrink = x / y;
651     var.device(d) = (linear.abs() > l1_bcast)
652                         .select(pre_shrink, var.constant(static_cast<T>(0)));
653     accum.device(d) += grad.square();
654   }
655 };
656 
657 template <typename T>
658 struct ApplyFtrlV2MultiplyLinearByLr<GPUDevice, T> {
659   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
660                   typename TTypes<T>::Flat accum,
661                   typename TTypes<T>::Flat linear,
662                   typename TTypes<T>::ConstFlat grad,
663                   typename TTypes<T>::ConstScalar lr,
664                   typename TTypes<T>::ConstScalar l1,
665                   typename TTypes<T>::ConstScalar l2,
666                   typename TTypes<T>::ConstScalar l2_shrinkage,
667                   typename TTypes<T>::ConstScalar lr_power) {
668     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
669     bcast[0] = grad.dimension(0);
670     Eigen::Sizes<1> single;
671 
672     auto l2_shrinkage_bcast = l2_shrinkage.reshape(single).broadcast(bcast);
673     auto lr_bcast = lr.reshape(single).broadcast(bcast);
674     auto l1_lr_bcast = (l1 * lr).reshape(single).broadcast(bcast);
675     auto l2_lr_bcast = (l2 * lr).reshape(single).broadcast(bcast);
676     auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
677     const auto two = static_cast<T>(2.0);
678 
679     auto new_accum = accum + grad.square();
680     auto accum_power = accum.binaryExpr(lr_power_bcast,
681                                         Eigen::internal::scalar_pow_op<T, T>());
682     auto new_accum_power = new_accum.binaryExpr(
683         lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
684     auto grad_with_shrinkage =
685         grad + (var.constant(two) * l2_shrinkage_bcast * var);
686     linear.device(d) +=
687         grad_with_shrinkage * lr_bcast - (new_accum_power - accum_power) * var;
688     auto x = (l1_lr_bcast * linear.sign() - linear);
689     auto y = new_accum_power + linear.constant(two) * l2_lr_bcast;
690     auto pre_shrink = x / y;
691     var.device(d) = (linear.abs() > l1_lr_bcast)
692                         .select(pre_shrink, var.constant(static_cast<T>(0)));
693     accum.device(d) += grad.square();
694   }
695 };
696 
697 template <typename T, typename Tindex, bool has_l2_shrinkage>
698 struct SparseApplyFtrl<GPUDevice, T, Tindex, has_l2_shrinkage> {
699   Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
700                     typename TTypes<T>::Matrix accum,
701                     typename TTypes<T>::Matrix linear,
702                     typename TTypes<T>::ConstScalar lr,
703                     typename TTypes<T>::ConstScalar l1,
704                     typename TTypes<T>::ConstScalar l2,
705                     typename TTypes<T>::ConstScalar l2_shrinkage,
706                     typename TTypes<T>::ConstScalar lr_power,
707                     typename TTypes<T>::ConstMatrix grad,
708                     typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
709                     bool multiply_linear_by_lr) {
710     const Tindex first_dim_size = var.dimension(0);
711     const Tindex grad_size = grad.size();
712     const Tindex indices_size = indices.size();
713     if (grad_size == 0) {
714       return Status::OK();
715     }
716     // The simpler overload of GetGpuLaunchConfig() would result in a "too many
717     // resources requested for launch" error.
718     auto* device_func = SparseApplyFtrlKernel<T, Tindex, has_l2_shrinkage>;
719     GpuLaunchConfig config =
720         GetGpuLaunchConfig(grad_size, d, device_func, 0, 0);
721     return GpuLaunchKernel(
722         device_func, config.block_count, config.thread_per_block, 0, d.stream(),
723         /*var=*/var.data(),
724         /*accum=*/accum.data(),
725         /*linear=*/linear.data(), /*lr=*/lr.data(), /*l1=*/l1.data(),
726         /*l2=*/l2.data(), /*l2_shrinkage=*/l2_shrinkage.data(),
727         /*lr_power=*/lr_power.data(), /*grad=*/grad.data(),
728         /*indices=*/indices.data(), /*param_rows=*/first_dim_size,
729         /*updates_size=*/grad_size,
730         /*indices_size=*/indices_size,
731         /*multiply_linear_by_lr=*/multiply_linear_by_lr);
732   }
733 };
734 
735 template <typename T>
736 struct ApplyMomentum<GPUDevice, T> {
737   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
738                   typename TTypes<T>::Flat accum,
739                   typename TTypes<T>::ConstScalar lr,
740                   typename TTypes<T>::ConstFlat grad,
741                   typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
742     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
743     bcast[0] = grad.dimension(0);
744     Eigen::Sizes<1> single;
745     accum.device(d) = accum * momentum.reshape(single).broadcast(bcast) + grad;
746     if (use_nesterov) {
747       var.device(d) -= grad * lr.reshape(single).broadcast(bcast) +
748                        accum * momentum.reshape(single).broadcast(bcast) *
749                            lr.reshape(single).broadcast(bcast);
750     } else {
751       var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
752     }
753   }
754 };
755 
756 template <typename T>
757 struct ApplyKerasMomentum<GPUDevice, T> {
758   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
759                   typename TTypes<T>::Flat accum,
760                   typename TTypes<T>::ConstScalar lr,
761                   typename TTypes<T>::ConstFlat grad,
762                   typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
763     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
764     bcast[0] = grad.dimension(0);
765     Eigen::Sizes<1> single;
766     accum.device(d) = (accum * momentum.reshape(single).broadcast(bcast) -
767                        grad * lr.reshape(single).broadcast(bcast));
768     if (use_nesterov) {
769       var.device(d) += (accum * momentum.reshape(single).broadcast(bcast) -
770                         grad * lr.reshape(single).broadcast(bcast));
771     } else {
772       var.device(d) += accum;
773     }
774   }
775 };
776 
777 template <typename T, typename Tindex>
778 struct SparseApplyKerasMomentum<GPUDevice, T, Tindex> {
779   Tindex operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
780                     typename TTypes<T>::Matrix accum,
781                     typename TTypes<T>::ConstScalar lr,
782                     typename TTypes<T>::ConstMatrix grad,
783                     typename TTypes<Tindex>::ConstVec indices,
784                     typename TTypes<T>::ConstScalar momentum,
785                     bool use_nesterov) {
786     const Tindex first_dim_size = var.dimension(0);
787     const Tindex grad_size = grad.size();
788     const Tindex indices_size = indices.size();
789     if (grad_size != 0) {
790       GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
791       TF_CHECK_OK(GpuLaunchKernel(
792           SparseApplyKerasMomentumKernel<T, Tindex>, config.block_count,
793           config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
794           lr.data(), grad.data(), indices.data(), momentum.data(), use_nesterov,
795           first_dim_size, grad_size, indices_size));
796     }
797     return static_cast<Tindex>(-1);
798   }
799 };
800 
801 template <typename T, typename Tindex>
802 __global__ __launch_bounds__(1024) void SparseApplyAdadeltaKernel(
803     T* var, T* accum, T* accum_update, const T* lr, const T* rho,
804     const T* epsilon, const T* grad, const Tindex* indices, Tindex param_rows,
805     Tindex updates_size, Tindex indices_size) {
806   Tindex col_size = updates_size / indices_size;
807   GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
808     Tindex indices_row = grad_index / col_size;
809     Tindex param_row = indices[indices_row];
810     if (param_row < 0 || param_row >= param_rows) {
811       // Ignore indices that are out of range.
812       continue;
813     }
814 
815     // Compute the index of var and accum.
816     Tindex param_index = param_row * col_size + (grad_index % col_size);
817 
818     // Read variables.
819     T var_i = var[param_index];
820     T accum_i = accum[param_index];
821     T accum_update_i = accum_update[param_index];
822     T grad_i = grad[grad_index];
823     const T lr_t = *lr;
824     const T rho_t = *rho;
825     const T epsilon_t = *epsilon;
826 
827     // Variable update computation.
828     accum_i = accum_i * rho_t + grad_i * grad_i * (T(1.0) - rho_t);
829     T update = Eigen::numext::sqrt(accum_update_i + epsilon_t) * grad_i /
830                Eigen::numext::sqrt(accum_i + epsilon_t);
831     var_i = var_i - update * lr_t;
832     accum_update_i =
833         accum_update_i * rho_t + update * update * (T(1.0) - rho_t);
834 
835     // Write update back to variables.
836     var[param_index] = var_i;
837     accum[param_index] = accum_i;
838     accum_update[param_index] = accum_update_i;
839   }
840 }
841 
842 template <typename T, typename Tindex>
843 struct SparseApplyAdadelta<GPUDevice, T, Tindex> {
844   void operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
845                   typename TTypes<T>::Matrix accum,
846                   typename TTypes<T>::Matrix accum_update,
847                   typename TTypes<T>::ConstScalar lr,
848                   typename TTypes<T>::ConstScalar rho,
849                   typename TTypes<T>::ConstScalar epsilon,
850                   typename TTypes<T>::ConstMatrix grad,
851                   typename TTypes<Tindex>::ConstFlat indices) {
852     const Tindex first_dim_size = var.dimension(0);
853     const Tindex grad_size = grad.size();
854     const Tindex indices_size = indices.size();
855     GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
856     TF_CHECK_OK(GpuLaunchKernel(
857         SparseApplyAdadeltaKernel<T, Tindex>, config.block_count,
858         config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
859         accum_update.data(), lr.data(), rho.data(), epsilon.data(), grad.data(),
860         indices.data(), first_dim_size, grad_size, indices_size));
861   }
862 };
863 
864 template <typename T>
865 struct ApplyAdam<GPUDevice, T> {
866   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
867                   typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
868                   typename TTypes<T>::ConstScalar beta1_power,
869                   typename TTypes<T>::ConstScalar beta2_power,
870                   typename TTypes<T>::ConstScalar lr,
871                   typename TTypes<T>::ConstScalar beta1,
872                   typename TTypes<T>::ConstScalar beta2,
873                   typename TTypes<T>::ConstScalar epsilon,
874                   typename TTypes<T>::ConstFlat grad, bool use_nesterov) {
875     int32 data_dim = grad.dimension(0);
876     if (data_dim == 0) {
877       return;
878     }  // No work load.
879     GpuLaunchConfig config = GetGpuLaunchConfig(data_dim, d);
880     eigen_assert(static_cast<int64_t>(grad.dimension(0)) +
881                      static_cast<int64_t>(config.block_count) *
882                          static_cast<int64_t>(config.thread_per_block) <
883                  std::numeric_limits<int32>::max());
884 
885     TF_CHECK_OK(GpuLaunchKernel(
886         ApplyAdamKernel<T>, config.block_count, config.thread_per_block, 0,
887         d.stream(), data_dim, var.data(), m.data(), v.data(),
888         beta1_power.data(), beta2_power.data(), lr.data(), beta1.data(),
889         beta2.data(), epsilon.data(), grad.data(), use_nesterov));
890   }
891 };
892 
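// AMSGrad variant of Adam implemented below: same moment updates as Adam, but
// the step is normalized by a running maximum of the second moment:
//   vhat = max(vhat, v)
//   var -= lr * sqrt(1 - beta2_power) / (1 - beta1_power)
//          * m / (epsilon + sqrt(vhat))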
893 template <typename T>
894 struct ApplyAdamWithAmsgrad<GPUDevice, T> {
895   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
896                   typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
897                   typename TTypes<T>::Flat vhat,
898                   typename TTypes<T>::ConstScalar beta1_power,
899                   typename TTypes<T>::ConstScalar beta2_power,
900                   typename TTypes<T>::ConstScalar lr,
901                   typename TTypes<T>::ConstScalar beta1,
902                   typename TTypes<T>::ConstScalar beta2,
903                   typename TTypes<T>::ConstScalar epsilon,
904                   typename TTypes<T>::ConstFlat grad) {
905     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
906     bcast[0] = grad.dimension(0);
907     Eigen::Sizes<1> single;
908     const auto one = static_cast<T>(1.0);
909     m.device(d) =
910         m + (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
911                 (grad - m);
912     v.device(d) =
913         v + (beta2.constant(one) - beta2).reshape(single).broadcast(bcast) *
914                 (grad.square() - v);
915     vhat.device(d) = vhat.cwiseMax(v);
916 
917     var.device(d) -= (lr * (beta2_power.constant(one) - beta2_power).sqrt() /
918                       (beta1_power.constant(one) - beta1_power))
919                          .reshape(single)
920                          .broadcast(bcast) *
921                      m /
922                      (epsilon.reshape(single).broadcast(bcast) + vhat.sqrt());
923   }
924 };
925 
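// AdaMax update implemented below; the second moment is replaced by an
// infinity-norm accumulator:
//   m   += (1 - beta1) * (g - m)
//   v    = max(beta2 * v, |g|)
//   var -= lr / (1 - beta1_power) * m / (v + epsilon)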
926 template <typename T>
927 struct ApplyAdaMax<GPUDevice, T> {
928   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
929                   typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
930                   typename TTypes<T>::ConstScalar beta1_power,
931                   typename TTypes<T>::ConstScalar lr,
932                   typename TTypes<T>::ConstScalar beta1,
933                   typename TTypes<T>::ConstScalar beta2,
934                   typename TTypes<T>::ConstScalar epsilon,
935                   typename TTypes<T>::ConstFlat grad) {
936     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
937     bcast[0] = grad.dimension(0);
938     Eigen::Sizes<1> single;
939     const auto one = static_cast<T>(1.0);
940     m.device(d) +=
941         (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
942         (grad - m);
943     v.device(d) =
944         (beta2.reshape(single).broadcast(bcast) * v).cwiseMax(grad.abs());
945     var.device(d) -= lr.reshape(single).broadcast(bcast) /
946                      (beta1_power.constant(one) - beta1_power)
947                          .reshape(single)
948                          .broadcast(bcast) *
949                      (m / (v + epsilon.reshape(single).broadcast(bcast)));
950   }
951 };
952 
953 template <typename T>
954 struct ApplyRMSProp<GPUDevice, T> {
955   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
956                   typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
957                   typename TTypes<T>::ConstScalar lr,
958                   typename TTypes<T>::ConstScalar rho,
959                   typename TTypes<T>::ConstScalar momentum,
960                   typename TTypes<T>::ConstScalar epsilon,
961                   typename TTypes<T>::ConstFlat grad) {
962 #if TENSORFLOW_USE_ROCM
963     wrap_kernel_call(ApplyRMSPropKernel<T>, d, var, ms, mom, lr, rho, momentum,
964                      epsilon, grad);
965 #else
966     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
967     bcast[0] = grad.dimension(0);
968     Eigen::Sizes<1> single;
969     const auto one = static_cast<T>(1.0);
970     ms.device(d) =
971         ms + (rho.constant(one) - rho).reshape(single).broadcast(bcast) *
972                  (grad.square() - ms);
973     mom.device(d) =
974         mom * momentum.reshape(single).broadcast(bcast) +
975         lr.reshape(single).broadcast(bcast) * grad /
976             ((epsilon.reshape(single).broadcast(bcast) + ms).sqrt());
977     var.device(d) -= mom;
978 #endif
979   }
980 };
981 
982 template <typename T>
983 struct ApplyCenteredRMSProp<GPUDevice, T> {
984   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
985                   typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,
986                   typename TTypes<T>::Flat mom,
987                   typename TTypes<T>::ConstScalar lr,
988                   typename TTypes<T>::ConstScalar rho,
989                   typename TTypes<T>::ConstScalar momentum,
990                   typename TTypes<T>::ConstScalar epsilon,
991                   typename TTypes<T>::ConstFlat grad) {
992 #if TENSORFLOW_USE_ROCM
993     wrap_kernel_call(ApplyCenteredRMSPropKernel<T>, d, var, mg, ms, mom, lr,
994                      rho, momentum, epsilon, grad);
995 #else
996     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
997     bcast[0] = grad.dimension(0);
998     Eigen::Sizes<1> single;
999     const auto one = static_cast<T>(1.0);
1000     const auto one_minus_rho =
1001         (rho.constant(one) - rho).reshape(single).broadcast(bcast);
1002     ms.device(d) = ms + one_minus_rho * (grad.square() - ms);
1003     mg.device(d) = mg + one_minus_rho * (grad - mg);
1004     auto denom = (ms - mg.square()) + epsilon.reshape(single).broadcast(bcast);
1005     mom.device(d) = mom * momentum.reshape(single).broadcast(bcast) +
1006                     lr.reshape(single).broadcast(bcast) * grad / denom.sqrt();
1007     var.device(d) -= mom;
1008 #endif
1009   }
1010 };
1011 
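// AddSign update implemented below:
//   m    = beta * m + (1 - beta) * g
//   var -= lr * (alpha + sign_decay * sign(g) * sign(m)) * g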
1012 template <typename T>
1013 struct ApplyAddSign<GPUDevice, T> {
1014   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
1015                   typename TTypes<T>::Flat m,
1016                   typename TTypes<T>::ConstScalar lr,
1017                   typename TTypes<T>::ConstScalar alpha,
1018                   typename TTypes<T>::ConstScalar sign_decay,
1019                   typename TTypes<T>::ConstScalar beta,
1020                   typename TTypes<T>::ConstFlat grad) {
1021     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
1022     bcast[0] = grad.dimension(0);
1023     Eigen::Sizes<1> single;
1024 
1025     // The following is the GPU equivalent of the CPU version:
1026     // m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
1027     const auto one = static_cast<T>(1.0);
1028     auto beta_bcast = beta.reshape(single).broadcast(bcast);
1029     auto one_minus_beta =
1030         (beta.constant(one) - beta).reshape(single).broadcast(bcast);
1031     m.device(d) = m * beta_bcast + grad * one_minus_beta;
1032 
1033     // The following is the GPU equivalent of the CPU version:
1034     // var.device(d) -= lr() * (alpha() + sign_decay() * sign_gm) * grad;
1035     auto sign_gm = grad.sign() * m.sign();
1036     auto lr_bcast = lr.reshape(single).broadcast(bcast);
1037     auto alpha_bcast = alpha.reshape(single).broadcast(bcast);
1038     auto sign_decay_bcast = sign_decay.reshape(single).broadcast(bcast);
1039     var.device(d) -=
1040         lr_bcast * (alpha_bcast + sign_decay_bcast * sign_gm) * grad;
1041   }
1042 };
1043 
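// PowerSign update implemented below:
//   m    = beta * m + (1 - beta) * g
//   var -= lr * exp(logbase * sign_decay * sign(g) * sign(m)) * g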
1044 template <typename T>
1045 struct ApplyPowerSign<GPUDevice, T> {
1046   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
1047                   typename TTypes<T>::Flat m,
1048                   typename TTypes<T>::ConstScalar lr,
1049                   typename TTypes<T>::ConstScalar logbase,
1050                   typename TTypes<T>::ConstScalar sign_decay,
1051                   typename TTypes<T>::ConstScalar beta,
1052                   typename TTypes<T>::ConstFlat grad) {
1053     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
1054     bcast[0] = grad.dimension(0);
1055     Eigen::Sizes<1> single;
1056 
1057     // The following is the GPU equivalent of the CPU version:
1058     // m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
1059     const auto one = static_cast<T>(1.0);
1060     auto beta_bcast = beta.reshape(single).broadcast(bcast);
1061     auto one_minus_beta =
1062         (beta.constant(one) - beta).reshape(single).broadcast(bcast);
1063     m.device(d) = m * beta_bcast + grad * one_minus_beta;
1064 
1065     // The following is the GPU equivalent of the CPU version:
1066     // auto grad_scale = (logbase() * sign_decay() * sign_gm).exp();
1067     // var.device(d) -= lr() * grad_scale * grad;
1068     auto sign_gm = grad.sign() * m.sign();
1069     auto lr_bcast = lr.reshape(single).broadcast(bcast);
1070     auto logbase_bcast = logbase.reshape(single).broadcast(bcast);
1071     auto sign_decay_bcast = sign_decay.reshape(single).broadcast(bcast);
1072     auto grad_scale = (logbase_bcast * sign_decay_bcast * sign_gm).exp();
1073     var.device(d) -= lr_bcast * grad_scale * grad;
1074   }
1075 };
1076 
1077 }  // namespace functor
1078 
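// Explicit instantiations: the functor templates above are compiled here by
// the device compiler for each supported element type so that the separately
// compiled op kernels (registered elsewhere, e.g. in training_ops.cc) can link
// against them. Presumably this list must stay in sync with the types
// registered for the corresponding GPU kernels.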
1079 template struct functor::ApplyGradientDescent<GPUDevice, Eigen::half>;
1080 template struct functor::ApplyGradientDescent<GPUDevice, float>;
1081 template struct functor::ApplyGradientDescent<GPUDevice, double>;
1082 template struct functor::ApplyGradientDescent<GPUDevice, complex64>;
1083 template struct functor::ApplyGradientDescent<GPUDevice, complex128>;
1084 
1085 template struct functor::ApplyAdagrad<GPUDevice, Eigen::half>;
1086 template struct functor::ApplyAdagrad<GPUDevice, float>;
1087 template struct functor::ApplyAdagrad<GPUDevice, double>;
1088 template struct functor::ApplyAdagrad<GPUDevice, complex64>;
1089 template struct functor::ApplyAdagrad<GPUDevice, complex128>;
1090 
1091 template struct functor::ApplyAdagradV2<GPUDevice, Eigen::half>;
1092 template struct functor::ApplyAdagradV2<GPUDevice, float>;
1093 template struct functor::ApplyAdagradV2<GPUDevice, double>;
1094 template struct functor::ApplyAdagradV2<GPUDevice, complex64>;
1095 template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
1096 
1097 #define EXPLICITLY_INSTANTIATE_FUNCTOR(T)                             \
1098   template struct functor::SparseApplyAdagrad<GPUDevice, T, int32,    \
1099                                               /*has_epsilon=*/false>; \
1100   template struct functor::SparseApplyAdagrad<GPUDevice, T, int64,    \
1101                                               /*has_epsilon=*/false>; \
1102   template struct functor::SparseApplyAdagrad<GPUDevice, T, int32,    \
1103                                               /*has_epsilon=*/true>;  \
1104   template struct functor::SparseApplyAdagrad<GPUDevice, T, int64,    \
1105                                               /*has_epsilon=*/true>
1106 EXPLICITLY_INSTANTIATE_FUNCTOR(Eigen::half);
1107 EXPLICITLY_INSTANTIATE_FUNCTOR(float);
1108 EXPLICITLY_INSTANTIATE_FUNCTOR(double);
1109 #undef EXPLICITLY_INSTANTIATE_FUNCTOR
1110 
1111 template struct functor::ApplyProximalAdagrad<GPUDevice, Eigen::half>;
1112 template struct functor::ApplyProximalAdagrad<GPUDevice, float>;
1113 template struct functor::ApplyProximalAdagrad<GPUDevice, double>;
1114 
1115 template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
1116                                                     int32>;
1117 template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
1118                                                     int64>;
1119 template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int32>;
1120 template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int64>;
1121 template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int32>;
1122 template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int64>;
1123 
1124 template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
1125 template struct functor::ApplyAdadelta<GPUDevice, float>;
1126 template struct functor::ApplyAdadelta<GPUDevice, double>;
1127 template struct functor::ApplyAdadelta<GPUDevice, complex64>;
1128 template struct functor::ApplyAdadelta<GPUDevice, complex128>;
1129 
1130 template struct functor::ApplyFtrl<GPUDevice, Eigen::half>;
1131 template struct functor::ApplyFtrl<GPUDevice, float>;
1132 template struct functor::ApplyFtrl<GPUDevice, double>;
1133 
1134 template struct functor::ApplyFtrlMultiplyLinearByLr<GPUDevice, Eigen::half>;
1135 template struct functor::ApplyFtrlMultiplyLinearByLr<GPUDevice, float>;
1136 template struct functor::ApplyFtrlMultiplyLinearByLr<GPUDevice, double>;
1137 
1138 template struct functor::ApplyFtrlV2<GPUDevice, Eigen::half>;
1139 template struct functor::ApplyFtrlV2<GPUDevice, float>;
1140 template struct functor::ApplyFtrlV2<GPUDevice, double>;
1141 
1142 template struct functor::ApplyFtrlV2MultiplyLinearByLr<GPUDevice, Eigen::half>;
1143 template struct functor::ApplyFtrlV2MultiplyLinearByLr<GPUDevice, float>;
1144 template struct functor::ApplyFtrlV2MultiplyLinearByLr<GPUDevice, double>;
1145 
1146 #define EXPLICITLY_INSTANTIATE_FUNCTOR(T)                               \
1147   template struct functor::SparseApplyFtrl<GPUDevice, T, int32,         \
1148                                            /*has_l2_shrinkage=*/false>; \
1149   template struct functor::SparseApplyFtrl<GPUDevice, T, int64,         \
1150                                            /*has_l2_shrinkage=*/false>; \
1151   template struct functor::SparseApplyFtrl<GPUDevice, T, int32,         \
1152                                            /*has_l2_shrinkage=*/true>;  \
1153   template struct functor::SparseApplyFtrl<GPUDevice, T, int64,         \
1154                                            /*has_l2_shrinkage=*/true>
1155 EXPLICITLY_INSTANTIATE_FUNCTOR(Eigen::half);
1156 EXPLICITLY_INSTANTIATE_FUNCTOR(float);
1157 EXPLICITLY_INSTANTIATE_FUNCTOR(double);
1158 #undef EXPLICITLY_INSTANTIATE_FUNCTOR
1159 
1160 template struct functor::ApplyMomentum<GPUDevice, Eigen::half>;
1161 template struct functor::ApplyMomentum<GPUDevice, float>;
1162 template struct functor::ApplyMomentum<GPUDevice, double>;
1163 template struct functor::ApplyMomentum<GPUDevice, complex64>;
1164 template struct functor::ApplyMomentum<GPUDevice, complex128>;
1165 
1166 template struct functor::ApplyKerasMomentum<GPUDevice, Eigen::half>;
1167 template struct functor::ApplyKerasMomentum<GPUDevice, float>;
1168 template struct functor::ApplyKerasMomentum<GPUDevice, double>;
1169 template struct functor::ApplyKerasMomentum<GPUDevice, complex64>;
1170 template struct functor::ApplyKerasMomentum<GPUDevice, complex128>;
1171 
1172 template struct functor::SparseApplyKerasMomentum<GPUDevice, Eigen::half,
1173                                                   int32>;
1174 template struct functor::SparseApplyKerasMomentum<GPUDevice, Eigen::half,
1175                                                   int64>;
1176 template struct functor::SparseApplyKerasMomentum<GPUDevice, float, int32>;
1177 template struct functor::SparseApplyKerasMomentum<GPUDevice, float, int64>;
1178 template struct functor::SparseApplyKerasMomentum<GPUDevice, double, int32>;
1179 template struct functor::SparseApplyKerasMomentum<GPUDevice, double, int64>;
1180 template struct functor::SparseApplyKerasMomentum<GPUDevice, complex64, int32>;
1181 template struct functor::SparseApplyKerasMomentum<GPUDevice, complex64, int64>;
1182 template struct functor::SparseApplyKerasMomentum<GPUDevice, complex128, int32>;
1183 template struct functor::SparseApplyKerasMomentum<GPUDevice, complex128, int64>;
1184 
1185 template struct functor::SparseApplyAdadelta<GPUDevice, Eigen::half, int32>;
1186 template struct functor::SparseApplyAdadelta<GPUDevice, Eigen::half, int64>;
1187 template struct functor::SparseApplyAdadelta<GPUDevice, float, int32>;
1188 template struct functor::SparseApplyAdadelta<GPUDevice, float, int64>;
1189 template struct functor::SparseApplyAdadelta<GPUDevice, double, int32>;
1190 template struct functor::SparseApplyAdadelta<GPUDevice, double, int64>;
1191 template struct functor::SparseApplyAdadelta<GPUDevice, complex64, int32>;
1192 template struct functor::SparseApplyAdadelta<GPUDevice, complex64, int64>;
1193 template struct functor::SparseApplyAdadelta<GPUDevice, complex128, int32>;
1194 template struct functor::SparseApplyAdadelta<GPUDevice, complex128, int64>;
1195 
1196 template struct functor::ApplyAdam<GPUDevice, Eigen::half>;
1197 template struct functor::ApplyAdam<GPUDevice, float>;
1198 template struct functor::ApplyAdam<GPUDevice, double>;
1199 template struct functor::ApplyAdam<GPUDevice, complex64>;
1200 template struct functor::ApplyAdam<GPUDevice, complex128>;
1201 
1202 template struct functor::ApplyAdamWithAmsgrad<GPUDevice, Eigen::half>;
1203 template struct functor::ApplyAdamWithAmsgrad<GPUDevice, float>;
1204 template struct functor::ApplyAdamWithAmsgrad<GPUDevice, double>;
1205 
1206 template struct functor::ApplyAdaMax<GPUDevice, Eigen::half>;
1207 template struct functor::ApplyAdaMax<GPUDevice, float>;
1208 template struct functor::ApplyAdaMax<GPUDevice, double>;
1209 
1210 template struct functor::ApplyRMSProp<GPUDevice, Eigen::half>;
1211 template struct functor::ApplyRMSProp<GPUDevice, float>;
1212 template struct functor::ApplyRMSProp<GPUDevice, double>;
1213 template struct functor::ApplyRMSProp<GPUDevice, complex64>;
1214 template struct functor::ApplyRMSProp<GPUDevice, complex128>;
1215 
1216 template struct functor::ApplyCenteredRMSProp<GPUDevice, Eigen::half>;
1217 template struct functor::ApplyCenteredRMSProp<GPUDevice, float>;
1218 template struct functor::ApplyCenteredRMSProp<GPUDevice, double>;
1219 template struct functor::ApplyCenteredRMSProp<GPUDevice, complex64>;
1220 template struct functor::ApplyCenteredRMSProp<GPUDevice, complex128>;
1221 
1222 template struct functor::ApplyAddSign<GPUDevice, Eigen::half>;
1223 template struct functor::ApplyAddSign<GPUDevice, float>;
1224 template struct functor::ApplyAddSign<GPUDevice, double>;
1225 
1226 template struct functor::ApplyPowerSign<GPUDevice, Eigen::half>;
1227 template struct functor::ApplyPowerSign<GPUDevice, float>;
1228 template struct functor::ApplyPowerSign<GPUDevice, double>;
1229 
1230 }  // end namespace tensorflow
1231 
1232 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1233