/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/training_ops.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

template <typename T>
__device__ T impl_sign(T x) {
  return x == T(0) ? T(0) : x < T(0) ? T(-1) : T(1);
}

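// Sparse Adagrad update, one thread per gradient element. Each element's row
// in var/accum is looked up through `indices`; out-of-range rows are skipped.
// The update is accum += grad^2 followed by var -= lr * grad / sqrt(accum),
// with `epsilon` added to the denominator when has_epsilon is true
// (the AdagradV2 variant).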
template <typename T, typename Tindex, bool has_epsilon>
__global__ __launch_bounds__(1024) void SparseApplyAdagradKernel(
    T* var, T* accum, const T* lr, const T* epsilon, const T* grad,
    const Tindex* indices, Tindex param_rows, Tindex updates_size,
    Tindex indices_size, bool update_slots) {
  Tindex col_size = updates_size / indices_size;
  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
    Tindex indices_row = grad_index / col_size;
    Tindex param_row = indices[indices_row];
    if (param_row < 0 || param_row >= param_rows) {
      // Ignore indices that are out of range.
      continue;
    }

    // Compute the index of var and accum.
    Tindex param_index = param_row * col_size + (grad_index % col_size);

    // Read variables.
    T var_i = var[param_index];
    T accum_i = accum[param_index];
    T grad_i = grad[grad_index];
    const T lr_t = *lr;
    const T epsilon_t = *epsilon;

    if (update_slots) {
      accum_i += grad_i * grad_i;
    }
    if (has_epsilon) {
      var_i -= lr_t * grad_i / (Eigen::numext::sqrt(accum_i) + epsilon_t);
    } else {
      var_i -= lr_t * grad_i * Eigen::numext::rsqrt(accum_i);
    }

    // Write update back to variables.
    var[param_index] = var_i;
    accum[param_index] = accum_i;
  }
}

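// Sparse proximal Adagrad (FOBOS-style) update: after accumulating grad^2,
// take an Adagrad step v = var - lr / sqrt(accum) * grad, then shrink the
// result toward zero by the l1 term and scale it down by the l2 term.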
template <typename T, typename Tindex>
__global__ __launch_bounds__(1024) void SparseApplyProximalAdagradKernel(
    T* var, T* accum, const T* lr, const T* l1, const T* l2, const T* grad,
    const Tindex* indices, Tindex param_rows, Tindex updates_size,
    Tindex indices_size) {
  Tindex col_size = updates_size / indices_size;
  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
    Tindex indices_row = grad_index / col_size;
    Tindex param_row = indices[indices_row];
    if (param_row < 0 || param_row >= param_rows) {
      // Ignore indices that are out of range.
      continue;
    }

    // Compute the index of var and accum.
    Tindex param_index = param_row * col_size + (grad_index % col_size);

    // Read variables.
    T var_i = var[param_index];
    T accum_i = accum[param_index];
    T grad_i = grad[grad_index];
    const T lr_t = *lr;
    const T l1_t = *l1;
    const T l2_t = *l2;

    accum_i += grad_i * grad_i;
    T learning_rate = lr_t * Eigen::numext::rsqrt(accum_i);
    // compute v = w - lr * grad.
    T prox_var_i = var_i - grad_i * learning_rate;
    // compute sign(v) * max(|v| - lr * max(l1, 0), 0)
    var_i = (prox_var_i >= 0 ? T(1.) : T(-1.)) *
            max(abs(prox_var_i) - learning_rate * max(l1_t, T(0)), T(0)) /
            (T(1.) + l2_t * learning_rate);

    // Write update back to variables.
    var[param_index] = var_i;
    accum[param_index] = accum_i;
  }
}

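// Sparse FTRL-proximal update. When has_l2_shrinkage is true the gradient is
// first shifted by 2 * l2_shrinkage * var. The `linear` and `accum` slots are
// then updated, using a fast sqrt path when lr_power == -0.5; the
// multiply_linear_by_lr variant keeps `linear` pre-multiplied by the learning
// rate so the divisions by lr can be skipped.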
template <typename T, typename Tindex, bool has_l2_shrinkage>
__global__ void SparseApplyFtrlKernel(T* var, T* accum, T* linear, const T* lr,
                                      const T* l1, const T* l2,
                                      const T* l2_shrinkage, const T* lr_power,
                                      const T* grad, const Tindex* indices,
                                      Tindex param_rows, Tindex updates_size,
                                      Tindex indices_size,
                                      bool multiply_linear_by_lr) {
  const Tindex col_size = updates_size / indices_size;
  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
    const Tindex indices_row = grad_index / col_size;
    const Tindex param_row = indices[indices_row];
    if (param_row < 0 || param_row >= param_rows) {
      // Ignore indices that are out of range.
      continue;
    }

    // Compute the index of var and accum.
    const Tindex param_index = param_row * col_size + (grad_index % col_size);

    // Read variables.
    T var_i = var[param_index];
    T accum_i = accum[param_index];
    T linear_i = linear[param_index];
    const T grad_i = grad[grad_index];
    const T lr_t = *lr;
    const T l1_t = *l1;
    const T l2_t = *l2;
    const T lr_power_t = *lr_power;

    const T grad_shr_i =
        has_l2_shrinkage ? grad_i + static_cast<T>(2) * (*l2_shrinkage) * var_i
                         : grad_i;
    const T new_accum_i = accum_i + grad_i * grad_i;
    const bool lr_power_is_neg_half = lr_power_t == static_cast<T>(-0.5);
    const T pow_new_accum = lr_power_is_neg_half
                                ? Eigen::numext::sqrt(new_accum_i)
                                : pow(new_accum_i, -lr_power_t);
    const T pow_accum = lr_power_is_neg_half ? Eigen::numext::sqrt(accum_i)
                                             : pow(accum_i, -lr_power_t);
    T linear_change = grad_shr_i * lr_t - (pow_new_accum - pow_accum) * var_i;
    if (!multiply_linear_by_lr) {
      linear_change /= lr_t;
    }
    linear_i += linear_change;

    T l1_mult = l1_t;
    if (multiply_linear_by_lr) {
      l1_mult *= lr_t;
    }
    const T l1_reg_adjust = max(min(linear_i, l1_mult), -l1_mult);
    const T x = l1_reg_adjust - linear_i;
    T y = pow_new_accum + static_cast<T>(2) * l2_t * lr_t;
    if (!multiply_linear_by_lr) {
      y /= lr_t;
    }
    var_i = x / y;
    accum_i = new_accum_i;

    // Write update back to variables.
    var[param_index] = var_i;
    accum[param_index] = accum_i;
    linear[param_index] = linear_i;
  }
}

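// Dense Adam update. The bias corrections sqrt(1 - beta2^t) / (1 - beta1^t)
// are folded into `mul_factor` together with the learning rate, so the loop
// body only updates the moments m and v and performs one divide per element.
// The Nesterov branch uses the look-ahead term beta1 * m + (1 - beta1) * grad.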
template <typename T>
__global__ __launch_bounds__(1024) void ApplyAdamKernel(
    int32 data_dim, T* var, T* m, T* v, const T* const beta1_power_,
    const T* const beta2_power_, const T* const lr_, const T* const beta1_,
    const T* const beta2_, const T* const epsilon_, const T* grad,
    bool use_nesterov) {
  eigen_assert(blockDim.y == 1);
  eigen_assert(blockDim.z == 1);
  eigen_assert(gridDim.y == 1);
  eigen_assert(gridDim.z == 1);

  const T mul_factor =
      (*lr_) * Eigen::numext::sqrt(static_cast<T>(1.0) - (*beta2_power_)) /
      (static_cast<T>(1.0) - (*beta1_power_));
  const T epsilon = (*epsilon_);
  const T beta1 = (*beta1_);
  const T one_minus_beta1 = static_cast<T>(1.0) - (beta1);
  const T one_minus_beta2 = static_cast<T>(1.0) - (*beta2_);
  const int32 stripe = gridDim.x * blockDim.x;

  for (int32 i = blockIdx.x * blockDim.x + threadIdx.x; i < data_dim;
       i += stripe) {
    auto m_i = m[i];
    auto g_i = grad[i];
    auto v_i = v[i];

    // Avoid += and -= due to std::complex<T> issues on device for MSVC.
    m_i = m_i + one_minus_beta1 * (g_i - m_i);
    v_i = v_i + one_minus_beta2 * (g_i * g_i - v_i);
    if (use_nesterov) {
      var[i] = var[i] - mul_factor * (m_i * beta1 + one_minus_beta1 * g_i) /
                            (epsilon + Eigen::numext::sqrt(v_i));
    } else {
      var[i] = var[i] - mul_factor * m_i / (epsilon + Eigen::numext::sqrt(v_i));
    }

    m[i] = m_i;
    v[i] = v_i;
  }
}

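// Sparse Keras-style momentum update: accum = momentum * accum - lr * grad,
// then var += accum (or var += momentum * accum - lr * grad with Nesterov).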
template <typename T, typename Tindex>
__global__ __launch_bounds__(1024) void SparseApplyKerasMomentumKernel(
    T* var, T* accum, const T* lr, const T* grad, const Tindex* indices,
    const T* momentum, bool use_nesterov, Tindex param_rows,
    Tindex updates_size, Tindex indices_size) {
  Tindex col_size = updates_size / indices_size;
  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
    Tindex indices_row = grad_index / col_size;
    Tindex param_row = indices[indices_row];
    if (param_row < 0 || param_row >= param_rows) {
      // Ignore indices that are out of range.
      continue;
    }

    // Compute the index of var and accum.
    Tindex param_index = param_row * col_size + (grad_index % col_size);

    // Read variables.
    T var_i = var[param_index];
    T accum_i = accum[param_index];
    T grad_i = grad[grad_index];
    const T momentum_t = *momentum;
    const T lr_t = *lr;

    // Variable update computation.
    accum_i = momentum_t * accum_i - lr_t * grad_i;
    // static branching in cuda does not impact performance.
    // Avoid += due to std::complex<T> issues on device for MSVC.
    if (use_nesterov) {
      var_i = var_i + (momentum_t * accum_i - lr_t * grad_i);
    } else {
      var_i = var_i + accum_i;
    }

    // Write update back to variables.
    var[param_index] = var_i;
    accum[param_index] = accum_i;
  }
}

template <typename T>
struct ApplyGradientDescent<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    var.device(d) -= lr.reshape(single).broadcast(bcast) * grad;
  }
};

template <typename T>
__global__ __launch_bounds__(1024) void ApplyAdagradKernel(GpuLaunchConfig cfg,
                                                           T* var, T* accum,
                                                           const T* lr,
                                                           const T* grad,
                                                           bool update_slots) {
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    if (update_slots) accum[i] += grad[i] * grad[i];
    var[i] -= lr[0] * grad[i] * Eigen::numext::rsqrt(accum[i]);
  }
}

template <typename T>
__global__ __launch_bounds__(1024) void ApplyAdagradV2Kernel(
    GpuLaunchConfig cfg, T* var, T* accum, const T* lr, const T* epsilon,
    const T* grad, bool update_slots) {
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    if (update_slots) accum[i] += grad[i] * grad[i];
    T update = grad[i] / (Eigen::numext::sqrt(accum[i]) + epsilon[0]);
    var[i] -= lr[0] * update;
  }
}

template <typename T>
__global__ __launch_bounds__(1024) void ApplyProximalAdagradKernel(
    GpuLaunchConfig cfg, T* var, T* accum, const T* lr, const T* l1,
    const T* l2, const T* grad) {
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    accum[i] += grad[i] * grad[i];
    T lr_scaled = lr[0] * Eigen::numext::rsqrt(accum[i]);
    T prox_var = var[i] - grad[i] * lr_scaled;
    var[i] = impl_sign(prox_var) *
             max(Eigen::numext::abs(prox_var) - lr_scaled * max(l1[0], T(0.f)),
                 T(0.f)) /
             (T(1.f) + l2[0] * lr_scaled);
  }
}

template <typename T>
__global__ __launch_bounds__(1024) void ApplyAdadeltaKernel(
    GpuLaunchConfig cfg, T* var, T* accum, T* accum_update, const T* plr,
    const T* prho, const T* peps, const T* grad) {
  T rho = prho[0];
  T eps = peps[0];
  T lr = plr[0];
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    accum[i] = accum[i] * rho + grad[i] * grad[i] * (T(1.0) - rho);
    T update = Eigen::numext::sqrt(accum_update[i] + eps) * grad[i] *
               Eigen::numext::rsqrt(accum[i] + eps);
    var[i] -= update * lr;
    accum_update[i] = accum_update[i] * rho + update * update * (T(1.0) - rho);
  }
}

template <typename T>
__global__ __launch_bounds__(1024) void ApplyRMSPropKernel(
    GpuLaunchConfig cfg, T* var, T* ms, T* mom, const T* plr, const T* prho,
    const T* pmomentum, const T* peps, const T* grad) {
  T rho = prho[0];
  T eps = peps[0];
  T lr = plr[0];
  T momentum = pmomentum[0];
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    ms[i] += (T(1.0) - rho) * (grad[i] * grad[i] - ms[i]);
    mom[i] =
        mom[i] * momentum + lr * grad[i] * Eigen::numext::rsqrt(eps + ms[i]);
    var[i] -= mom[i];
  }
}

template <typename T>
__global__ __launch_bounds__(1024) void ApplyCenteredRMSPropKernel(
    GpuLaunchConfig cfg, T* var, T* mg, T* ms, T* mom, const T* plr,
    const T* prho, const T* pmomentum, const T* peps, const T* grad) {
  T rho = prho[0];
  T eps = peps[0];
  T lr = plr[0];
  T momentum = pmomentum[0];
  T one_minus_rho = T(1.0) - rho;
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    ms[i] += one_minus_rho * (grad[i] * grad[i] - ms[i]);
    mg[i] += one_minus_rho * (grad[i] - mg[i]);
    T denom = (ms[i] - mg[i] * mg[i]) + eps;
    mom[i] = mom[i] * momentum + lr * grad[i] * Eigen::numext::rsqrt(denom);
    var[i] -= mom[i];
  }
}

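// Helpers that forward the functors' Eigen tensor arguments as the raw device
// pointers (and pass-through bools) expected by the GPU kernels above, so the
// TENSORFLOW_USE_ROCM branches below can launch a kernel with a single
// wrap_kernel_call() instead of evaluating an Eigen expression.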
namespace kernel_forward {
bool to_pointers(bool x) { return x; }
template <class T>
typename T::PointerType to_pointers(T& x) {
  return x.data();
}
template <class T>
typename T::ConstPointerType to_pointers(const T& x) {
  return x.data();
}

template <typename T, typename... CallerArgs, typename... KernelArgs>
void wrap_kernel_call(void (*func)(KernelArgs...), const GPUDevice& d, T var,
                      CallerArgs... args) {
  int32 data_dim = var.dimension(0);
  auto config = GetGpuLaunchConfig(data_dim, d);
  TF_CHECK_OK(GpuLaunchKernel(func, config.block_count, config.thread_per_block,
                              0, d.stream(), config, var.data(),
                              to_pointers(args)...));
}
}  // namespace kernel_forward

using kernel_forward::wrap_kernel_call;

template <typename T>
struct ApplyAdagrad<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad, bool update_slots) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyAdagradKernel<T>, d, var, accum, lr, grad,
                     update_slots);
#else
    if (update_slots) {
      accum.device(d) += grad.square();
    }
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    var.device(d) -= lr.reshape(single).broadcast(bcast) * grad * accum.rsqrt();
#endif
  }
};

template <typename T>
struct ApplyAdagradV2<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool update_slots) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyAdagradV2Kernel<T>, d, var, accum, lr, epsilon, grad,
                     update_slots);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    if (update_slots) {
      accum.device(d) += grad.square();
    }
    const auto update =
        grad / (accum.sqrt() + epsilon.reshape(single).broadcast(bcast));
    var.device(d) -= lr.reshape(single).broadcast(bcast) * update;
#endif
  }
};

template <typename T, typename Tindex, bool has_epsilon>
struct SparseApplyAdagrad<GPUDevice, T, Tindex, has_epsilon> {
  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
                    typename TTypes<T>::Matrix accum,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstScalar epsilon,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
                    bool update_slots) {
    const Tindex first_dim_size = var.dimension(0);
    const Tindex grad_size = grad.size();
    const Tindex indices_size = indices.size();
    if (grad_size == 0) {
      return Status::OK();
    }
    GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
    return GpuLaunchKernel(
        SparseApplyAdagradKernel<T, Tindex, has_epsilon>, config.block_count,
        config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
        lr.data(), epsilon.data(), grad.data(), indices.data(), first_dim_size,
        grad_size, indices_size, update_slots);
  }
};

template <typename T>
struct ApplyProximalAdagrad<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyProximalAdagradKernel<T>, d, var, accum, lr, l1, l2,
                     grad);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    // Fobos update per paper with Adagrad learning rate.
    accum.device(d) += grad.square();
    // Adagrad learning rate.
    // The following is the GPU equivalent of the CPU version:
    // auto learning_rate = accum.constant(lr()) * accum.rsqrt();
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto l1_bcast = l1.reshape(single).broadcast(bcast);
    auto l2_bcast = l2.reshape(single).broadcast(bcast);
    auto learning_rate = lr_bcast * accum.rsqrt();
    auto prox_var = var;
    // compute v = w - lr * grad.
    prox_var.device(d) -= grad * learning_rate;
    // compute sign(v) * max(|v| - lr * max(l1, 0), 0)
    var.device(d) = prox_var.sign() *
                    (prox_var.abs() - learning_rate * l1_bcast.cwiseMax(T(0.f)))
                        .cwiseMax(T(0.f)) /
                    (var.constant(T(1.f)) + l2_bcast * learning_rate);
#endif
  }
};

template <typename T, typename Tindex>
struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex> {
  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
                    typename TTypes<T>::Matrix accum,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstScalar l1,
                    typename TTypes<T>::ConstScalar l2,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<Tindex>::ConstVec indices,
                    int64 inner_dim) {
    const Tindex first_dim_size = var.dimension(0);
    const Tindex grad_size = grad.size();
    const Tindex indices_size = indices.size();
    if (grad_size == 0) {
      return Status::OK();
    }
    GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
    return GpuLaunchKernel(SparseApplyProximalAdagradKernel<T, Tindex>,
                           config.block_count, config.thread_per_block, 0,
                           d.stream(), var.data(), accum.data(), lr.data(),
                           l1.data(), l2.data(), grad.data(), indices.data(),
                           first_dim_size, grad_size, indices_size);
  }
};

template <typename T>
struct ApplyAdadelta<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat accum_update,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyAdadeltaKernel<T>, d, var, accum, accum_update, lr,
                     rho, epsilon, grad);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    accum.device(d) = accum * rho.reshape(single).broadcast(bcast) +
                      grad.square() * (grad.constant(T(1)) -
                                       rho.reshape(single).broadcast(bcast));
    const auto update =
        (accum_update + epsilon.reshape(single).broadcast(bcast)).sqrt() *
        (accum + epsilon.reshape(single).broadcast(bcast)).rsqrt() * grad;
    var.device(d) -= update * lr.reshape(single).broadcast(bcast);
    accum_update.device(d) =
        accum_update * rho.reshape(single).broadcast(bcast) +
        update.square() *
            (grad.constant(T(1)) - rho.reshape(single).broadcast(bcast));
#endif
  }
};

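// Dense FTRL expressed with Eigen ops:
//   new_accum = accum + grad^2
//   linear   += grad - (new_accum^(-lr_power) - accum^(-lr_power)) * var / lr
//   var       = |linear| > l1 ? (l1 * sign(linear) - linear) / y : 0
// where y = new_accum^(-lr_power) / lr + 2 * l2. The MultiplyLinearByLr
// variants keep `linear` scaled by lr, and the V2 variants add the
// l2_shrinkage term to the gradient used for the linear update.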
template <typename T>
struct ApplyFtrl<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar lr_power) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    auto l1_bcast = l1.reshape(single).broadcast(bcast);
    auto l2_bcast = l2.reshape(single).broadcast(bcast);
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
    const auto two = static_cast<T>(2.0);

    auto new_accum = accum + grad.square();
    auto accum_power = accum.binaryExpr(lr_power_bcast,
                                        Eigen::internal::scalar_pow_op<T, T>());
    auto new_accum_power = new_accum.binaryExpr(
        lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
    linear.device(d) += grad - (new_accum_power - accum_power) * var / lr_bcast;
    auto x = (l1_bcast * linear.sign() - linear);
    auto y = (new_accum_power / lr_bcast) + linear.constant(two) * l2_bcast;
    auto pre_shrink = x / y;
    var.device(d) = (linear.abs() > l1_bcast)
                        .select(pre_shrink, var.constant(static_cast<T>(0)));
    accum.device(d) += grad.square();
  }
};

template <typename T>
struct ApplyFtrlMultiplyLinearByLr<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar lr_power) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto l1_lr_bcast = (l1 * lr).reshape(single).broadcast(bcast);
    auto l2_lr_bcast = (l2 * lr).reshape(single).broadcast(bcast);
    auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
    const auto two = static_cast<T>(2.0);

    auto new_accum = accum + grad.square();
    auto accum_power = accum.binaryExpr(lr_power_bcast,
                                        Eigen::internal::scalar_pow_op<T, T>());
    auto new_accum_power = new_accum.binaryExpr(
        lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
    linear.device(d) += grad * lr_bcast - (new_accum_power - accum_power) * var;
    auto x = (l1_lr_bcast * linear.sign() - linear);
    auto y = new_accum_power + linear.constant(two) * l2_lr_bcast;
    auto pre_shrink = x / y;
    var.device(d) = (linear.abs() > l1_lr_bcast)
                        .select(pre_shrink, var.constant(static_cast<T>(0)));
    accum.device(d) += grad.square();
  }
};

template <typename T>
struct ApplyFtrlV2<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar l2_shrinkage,
                  typename TTypes<T>::ConstScalar lr_power) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    auto l1_bcast = l1.reshape(single).broadcast(bcast);
    auto l2_bcast = l2.reshape(single).broadcast(bcast);
    auto l2_shrinkage_bcast = l2_shrinkage.reshape(single).broadcast(bcast);
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
    const auto two = static_cast<T>(2.0);

    auto new_accum = accum + grad.square();
    auto accum_power = accum.binaryExpr(lr_power_bcast,
                                        Eigen::internal::scalar_pow_op<T, T>());
    auto new_accum_power = new_accum.binaryExpr(
        lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
    auto grad_with_shrinkage =
        grad + (var.constant(two) * l2_shrinkage_bcast * var);
    linear.device(d) +=
        grad_with_shrinkage - (new_accum_power - accum_power) * var / lr_bcast;
    auto x = (l1_bcast * linear.sign() - linear);
    auto y = (new_accum_power / lr_bcast) + linear.constant(two) * l2_bcast;
    auto pre_shrink = x / y;
    var.device(d) = (linear.abs() > l1_bcast)
                        .select(pre_shrink, var.constant(static_cast<T>(0)));
    accum.device(d) += grad.square();
  }
};

template <typename T>
struct ApplyFtrlV2MultiplyLinearByLr<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar l2_shrinkage,
                  typename TTypes<T>::ConstScalar lr_power) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    auto l2_shrinkage_bcast = l2_shrinkage.reshape(single).broadcast(bcast);
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto l1_lr_bcast = (l1 * lr).reshape(single).broadcast(bcast);
    auto l2_lr_bcast = (l2 * lr).reshape(single).broadcast(bcast);
    auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
    const auto two = static_cast<T>(2.0);

    auto new_accum = accum + grad.square();
    auto accum_power = accum.binaryExpr(lr_power_bcast,
                                        Eigen::internal::scalar_pow_op<T, T>());
    auto new_accum_power = new_accum.binaryExpr(
        lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
    auto grad_with_shrinkage =
        grad + (var.constant(two) * l2_shrinkage_bcast * var);
    linear.device(d) +=
        grad_with_shrinkage * lr_bcast - (new_accum_power - accum_power) * var;
    auto x = (l1_lr_bcast * linear.sign() - linear);
    auto y = new_accum_power + linear.constant(two) * l2_lr_bcast;
    auto pre_shrink = x / y;
    var.device(d) = (linear.abs() > l1_lr_bcast)
                        .select(pre_shrink, var.constant(static_cast<T>(0)));
    accum.device(d) += grad.square();
  }
};

template <typename T, typename Tindex, bool has_l2_shrinkage>
struct SparseApplyFtrl<GPUDevice, T, Tindex, has_l2_shrinkage> {
  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
                    typename TTypes<T>::Matrix accum,
                    typename TTypes<T>::Matrix linear,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstScalar l1,
                    typename TTypes<T>::ConstScalar l2,
                    typename TTypes<T>::ConstScalar l2_shrinkage,
                    typename TTypes<T>::ConstScalar lr_power,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
                    bool multiply_linear_by_lr) {
    const Tindex first_dim_size = var.dimension(0);
    const Tindex grad_size = grad.size();
    const Tindex indices_size = indices.size();
    if (grad_size == 0) {
      return Status::OK();
    }
    // The simpler overload of GetGpuLaunchConfig() would result in a "too many
    // resources requested for launch" error.
    auto* device_func = SparseApplyFtrlKernel<T, Tindex, has_l2_shrinkage>;
    GpuLaunchConfig config =
        GetGpuLaunchConfig(grad_size, d, device_func, 0, 0);
    return GpuLaunchKernel(
        device_func, config.block_count, config.thread_per_block, 0, d.stream(),
        /*var=*/var.data(),
        /*accum=*/accum.data(),
        /*linear=*/linear.data(), /*lr=*/lr.data(), /*l1=*/l1.data(),
        /*l2=*/l2.data(), /*l2_shrinkage=*/l2_shrinkage.data(),
        /*lr_power=*/lr_power.data(), /*grad=*/grad.data(),
        /*indices=*/indices.data(), /*param_rows=*/first_dim_size,
        /*updates_size=*/grad_size,
        /*indices_size=*/indices_size,
        /*multiply_linear_by_lr=*/multiply_linear_by_lr);
  }
};

template <typename T>
struct ApplyMomentum<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    accum.device(d) = accum * momentum.reshape(single).broadcast(bcast) + grad;
    if (use_nesterov) {
      var.device(d) -= grad * lr.reshape(single).broadcast(bcast) +
                       accum * momentum.reshape(single).broadcast(bcast) *
                           lr.reshape(single).broadcast(bcast);
    } else {
      var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
    }
  }
};

template <typename T>
struct ApplyKerasMomentum<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    accum.device(d) = (accum * momentum.reshape(single).broadcast(bcast) -
                       grad * lr.reshape(single).broadcast(bcast));
    if (use_nesterov) {
      var.device(d) += (accum * momentum.reshape(single).broadcast(bcast) -
                        grad * lr.reshape(single).broadcast(bcast));
    } else {
      var.device(d) += accum;
    }
  }
};

template <typename T, typename Tindex>
struct SparseApplyKerasMomentum<GPUDevice, T, Tindex> {
  Tindex operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
                    typename TTypes<T>::Matrix accum,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<Tindex>::ConstVec indices,
                    typename TTypes<T>::ConstScalar momentum,
                    bool use_nesterov) {
    const Tindex first_dim_size = var.dimension(0);
    const Tindex grad_size = grad.size();
    const Tindex indices_size = indices.size();
    if (grad_size != 0) {
      GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
      TF_CHECK_OK(GpuLaunchKernel(
          SparseApplyKerasMomentumKernel<T, Tindex>, config.block_count,
          config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
          lr.data(), grad.data(), indices.data(), momentum.data(), use_nesterov,
          first_dim_size, grad_size, indices_size));
    }
    return static_cast<Tindex>(-1);
  }
};

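// Sparse Adadelta update: accum and accum_update are running averages (decay
// rho) of grad^2 and of the squared updates; the step applied to var is
// lr * sqrt(accum_update + eps) / sqrt(accum + eps) * grad.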
template <typename T, typename Tindex>
__global__ __launch_bounds__(1024) void SparseApplyAdadeltaKernel(
    T* var, T* accum, T* accum_update, const T* lr, const T* rho,
    const T* epsilon, const T* grad, const Tindex* indices, Tindex param_rows,
    Tindex updates_size, Tindex indices_size) {
  Tindex col_size = updates_size / indices_size;
  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
    Tindex indices_row = grad_index / col_size;
    Tindex param_row = indices[indices_row];
    if (param_row < 0 || param_row >= param_rows) {
      // Ignore indices that are out of range.
      continue;
    }

    // Compute the index of var and accum.
    Tindex param_index = param_row * col_size + (grad_index % col_size);

    // Read variables.
    T var_i = var[param_index];
    T accum_i = accum[param_index];
    T accum_update_i = accum_update[param_index];
    T grad_i = grad[grad_index];
    const T lr_t = *lr;
    const T rho_t = *rho;
    const T epsilon_t = *epsilon;

    // Variable update computation.
    accum_i = accum_i * rho_t + grad_i * grad_i * (T(1.0) - rho_t);
    T update = Eigen::numext::sqrt(accum_update_i + epsilon_t) * grad_i /
               Eigen::numext::sqrt(accum_i + epsilon_t);
    var_i = var_i - update * lr_t;
    accum_update_i =
        accum_update_i * rho_t + update * update * (T(1.0) - rho_t);

    // Write update back to variables.
    var[param_index] = var_i;
    accum[param_index] = accum_i;
    accum_update[param_index] = accum_update_i;
  }
}

template <typename T, typename Tindex>
struct SparseApplyAdadelta<GPUDevice, T, Tindex> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
                  typename TTypes<T>::Matrix accum,
                  typename TTypes<T>::Matrix accum_update,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstMatrix grad,
                  typename TTypes<Tindex>::ConstFlat indices) {
    const Tindex first_dim_size = var.dimension(0);
    const Tindex grad_size = grad.size();
    const Tindex indices_size = indices.size();
    GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
    TF_CHECK_OK(GpuLaunchKernel(
        SparseApplyAdadeltaKernel<T, Tindex>, config.block_count,
        config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
        accum_update.data(), lr.data(), rho.data(), epsilon.data(), grad.data(),
        indices.data(), first_dim_size, grad_size, indices_size));
  }
};

template <typename T>
struct ApplyAdam<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar beta2_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool use_nesterov) {
    int32 data_dim = grad.dimension(0);
    if (data_dim == 0) {
      return;
    }  // No work load.
    GpuLaunchConfig config = GetGpuLaunchConfig(data_dim, d);
    eigen_assert(static_cast<int64_t>(grad.dimension(0)) +
                     static_cast<int64_t>(config.block_count) *
                         static_cast<int64_t>(config.thread_per_block) <
                 std::numeric_limits<int32>::max());

    TF_CHECK_OK(GpuLaunchKernel(
        ApplyAdamKernel<T>, config.block_count, config.thread_per_block, 0,
        d.stream(), data_dim, var.data(), m.data(), v.data(),
        beta1_power.data(), beta2_power.data(), lr.data(), beta1.data(),
        beta2.data(), epsilon.data(), grad.data(), use_nesterov));
  }
};

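// Adam with AMSGrad: same moment updates as ApplyAdam, but the denominator
// uses vhat = max(vhat, v), the running element-wise maximum of the second
// moment, instead of v itself.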
template <typename T>
struct ApplyAdamWithAmsgrad<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::Flat vhat,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar beta2_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    const auto one = static_cast<T>(1.0);
    m.device(d) =
        m + (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
                (grad - m);
    v.device(d) =
        v + (beta2.constant(one) - beta2).reshape(single).broadcast(bcast) *
                (grad.square() - v);
    vhat.device(d) = vhat.cwiseMax(v);

    var.device(d) -= (lr * (beta2_power.constant(one) - beta2_power).sqrt() /
                      (beta1_power.constant(one) - beta1_power))
                         .reshape(single)
                         .broadcast(bcast) *
                     m /
                     (epsilon.reshape(single).broadcast(bcast) + vhat.sqrt());
  }
};

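// AdaMax: the second moment is replaced by the infinity norm
// v = max(beta2 * v, |grad|), so no beta2_power bias correction is applied.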
template <typename T>
struct ApplyAdaMax<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    const auto one = static_cast<T>(1.0);
    m.device(d) +=
        (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
        (grad - m);
    v.device(d) =
        (beta2.reshape(single).broadcast(bcast) * v).cwiseMax(grad.abs());
    var.device(d) -= lr.reshape(single).broadcast(bcast) /
                     (beta1_power.constant(one) - beta1_power)
                         .reshape(single)
                         .broadcast(bcast) *
                     (m / (v + epsilon.reshape(single).broadcast(bcast)));
  }
};

template <typename T>
struct ApplyRMSProp<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyRMSPropKernel<T>, d, var, ms, mom, lr, rho, momentum,
                     epsilon, grad);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    const auto one = static_cast<T>(1.0);
    ms.device(d) =
        ms + (rho.constant(one) - rho).reshape(single).broadcast(bcast) *
                 (grad.square() - ms);
    mom.device(d) =
        mom * momentum.reshape(single).broadcast(bcast) +
        lr.reshape(single).broadcast(bcast) * grad /
            ((epsilon.reshape(single).broadcast(bcast) + ms).sqrt());
    var.device(d) -= mom;
#endif
  }
};

template <typename T>
struct ApplyCenteredRMSProp<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,
                  typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyCenteredRMSPropKernel<T>, d, var, mg, ms, mom, lr,
                     rho, momentum, epsilon, grad);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    const auto one = static_cast<T>(1.0);
    const auto one_minus_rho =
        (rho.constant(one) - rho).reshape(single).broadcast(bcast);
    ms.device(d) = ms + one_minus_rho * (grad.square() - ms);
    mg.device(d) = mg + one_minus_rho * (grad - mg);
    auto denom = (ms - mg.square()) + epsilon.reshape(single).broadcast(bcast);
    mom.device(d) = mom * momentum.reshape(single).broadcast(bcast) +
                    lr.reshape(single).broadcast(bcast) * grad / denom.sqrt();
    var.device(d) -= mom;
#endif
  }
};

template <typename T>
struct ApplyAddSign<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar alpha,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    // The following is the GPU equivalent of the CPU version:
    // m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
    const auto one = static_cast<T>(1.0);
    auto beta_bcast = beta.reshape(single).broadcast(bcast);
    auto one_minus_beta =
        (beta.constant(one) - beta).reshape(single).broadcast(bcast);
    m.device(d) = m * beta_bcast + grad * one_minus_beta;

    // The following is the GPU equivalent of the CPU version:
    // var.device(d) -= lr() * (alpha() + sign_decay() * sign_gm) * grad;
    auto sign_gm = grad.sign() * m.sign();
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto alpha_bcast = alpha.reshape(single).broadcast(bcast);
    auto sign_decay_bcast = sign_decay.reshape(single).broadcast(bcast);
    var.device(d) -=
        lr_bcast * (alpha_bcast + sign_decay_bcast * sign_gm) * grad;
  }
};

template <typename T>
struct ApplyPowerSign<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar logbase,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    // The following is the GPU equivalent of the CPU version:
    // m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
    const auto one = static_cast<T>(1.0);
    auto beta_bcast = beta.reshape(single).broadcast(bcast);
    auto one_minus_beta =
        (beta.constant(one) - beta).reshape(single).broadcast(bcast);
    m.device(d) = m * beta_bcast + grad * one_minus_beta;

    // The following is the GPU equivalent of the CPU version:
    // auto grad_scale = (logbase() * sign_decay() * sign_gm).exp();
    // var.device(d) -= lr() * grad_scale * grad;
    auto sign_gm = grad.sign() * m.sign();
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto logbase_bcast = logbase.reshape(single).broadcast(bcast);
    auto sign_decay_bcast = sign_decay.reshape(single).broadcast(bcast);
    auto grad_scale = (logbase_bcast * sign_decay_bcast * sign_gm).exp();
    var.device(d) -= lr_bcast * grad_scale * grad;
  }
};

}  // namespace functor

template struct functor::ApplyGradientDescent<GPUDevice, Eigen::half>;
template struct functor::ApplyGradientDescent<GPUDevice, float>;
template struct functor::ApplyGradientDescent<GPUDevice, double>;
template struct functor::ApplyGradientDescent<GPUDevice, complex64>;
template struct functor::ApplyGradientDescent<GPUDevice, complex128>;

template struct functor::ApplyAdagrad<GPUDevice, Eigen::half>;
template struct functor::ApplyAdagrad<GPUDevice, float>;
template struct functor::ApplyAdagrad<GPUDevice, double>;
template struct functor::ApplyAdagrad<GPUDevice, complex64>;
template struct functor::ApplyAdagrad<GPUDevice, complex128>;

template struct functor::ApplyAdagradV2<GPUDevice, Eigen::half>;
template struct functor::ApplyAdagradV2<GPUDevice, float>;
template struct functor::ApplyAdagradV2<GPUDevice, double>;
template struct functor::ApplyAdagradV2<GPUDevice, complex64>;
template struct functor::ApplyAdagradV2<GPUDevice, complex128>;

#define EXPLICITLY_INSTANTIATE_FUNCTOR(T)                              \
  template struct functor::SparseApplyAdagrad<GPUDevice, T, int32,    \
                                              /*has_epsilon=*/false>; \
  template struct functor::SparseApplyAdagrad<GPUDevice, T, int64,    \
                                              /*has_epsilon=*/false>; \
  template struct functor::SparseApplyAdagrad<GPUDevice, T, int32,    \
                                              /*has_epsilon=*/true>;  \
  template struct functor::SparseApplyAdagrad<GPUDevice, T, int64,    \
                                              /*has_epsilon=*/true>
EXPLICITLY_INSTANTIATE_FUNCTOR(Eigen::half);
EXPLICITLY_INSTANTIATE_FUNCTOR(float);
EXPLICITLY_INSTANTIATE_FUNCTOR(double);
#undef EXPLICITLY_INSTANTIATE_FUNCTOR

template struct functor::ApplyProximalAdagrad<GPUDevice, Eigen::half>;
template struct functor::ApplyProximalAdagrad<GPUDevice, float>;
template struct functor::ApplyProximalAdagrad<GPUDevice, double>;

template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
                                                    int32>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
                                                    int64>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int32>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int64>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int32>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int64>;

template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
template struct functor::ApplyAdadelta<GPUDevice, float>;
template struct functor::ApplyAdadelta<GPUDevice, double>;
template struct functor::ApplyAdadelta<GPUDevice, complex64>;
template struct functor::ApplyAdadelta<GPUDevice, complex128>;

template struct functor::ApplyFtrl<GPUDevice, Eigen::half>;
template struct functor::ApplyFtrl<GPUDevice, float>;
template struct functor::ApplyFtrl<GPUDevice, double>;

template struct functor::ApplyFtrlMultiplyLinearByLr<GPUDevice, Eigen::half>;
template struct functor::ApplyFtrlMultiplyLinearByLr<GPUDevice, float>;
template struct functor::ApplyFtrlMultiplyLinearByLr<GPUDevice, double>;

template struct functor::ApplyFtrlV2<GPUDevice, Eigen::half>;
template struct functor::ApplyFtrlV2<GPUDevice, float>;
template struct functor::ApplyFtrlV2<GPUDevice, double>;

template struct functor::ApplyFtrlV2MultiplyLinearByLr<GPUDevice, Eigen::half>;
template struct functor::ApplyFtrlV2MultiplyLinearByLr<GPUDevice, float>;
template struct functor::ApplyFtrlV2MultiplyLinearByLr<GPUDevice, double>;

#define EXPLICITLY_INSTANTIATE_FUNCTOR(T)                                \
  template struct functor::SparseApplyFtrl<GPUDevice, T, int32,         \
                                           /*has_l2_shrinkage=*/false>; \
  template struct functor::SparseApplyFtrl<GPUDevice, T, int64,         \
                                           /*has_l2_shrinkage=*/false>; \
  template struct functor::SparseApplyFtrl<GPUDevice, T, int32,         \
                                           /*has_l2_shrinkage=*/true>;  \
  template struct functor::SparseApplyFtrl<GPUDevice, T, int64,         \
                                           /*has_l2_shrinkage=*/true>
EXPLICITLY_INSTANTIATE_FUNCTOR(Eigen::half);
EXPLICITLY_INSTANTIATE_FUNCTOR(float);
EXPLICITLY_INSTANTIATE_FUNCTOR(double);
#undef EXPLICITLY_INSTANTIATE_FUNCTOR

template struct functor::ApplyMomentum<GPUDevice, Eigen::half>;
template struct functor::ApplyMomentum<GPUDevice, float>;
template struct functor::ApplyMomentum<GPUDevice, double>;
template struct functor::ApplyMomentum<GPUDevice, complex64>;
template struct functor::ApplyMomentum<GPUDevice, complex128>;

template struct functor::ApplyKerasMomentum<GPUDevice, Eigen::half>;
template struct functor::ApplyKerasMomentum<GPUDevice, float>;
template struct functor::ApplyKerasMomentum<GPUDevice, double>;
template struct functor::ApplyKerasMomentum<GPUDevice, complex64>;
template struct functor::ApplyKerasMomentum<GPUDevice, complex128>;

template struct functor::SparseApplyKerasMomentum<GPUDevice, Eigen::half,
                                                  int32>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, Eigen::half,
                                                  int64>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, float, int32>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, float, int64>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, double, int32>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, double, int64>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, complex64, int32>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, complex64, int64>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, complex128, int32>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, complex128, int64>;

template struct functor::SparseApplyAdadelta<GPUDevice, Eigen::half, int32>;
template struct functor::SparseApplyAdadelta<GPUDevice, Eigen::half, int64>;
template struct functor::SparseApplyAdadelta<GPUDevice, float, int32>;
template struct functor::SparseApplyAdadelta<GPUDevice, float, int64>;
template struct functor::SparseApplyAdadelta<GPUDevice, double, int32>;
template struct functor::SparseApplyAdadelta<GPUDevice, double, int64>;
template struct functor::SparseApplyAdadelta<GPUDevice, complex64, int32>;
template struct functor::SparseApplyAdadelta<GPUDevice, complex64, int64>;
template struct functor::SparseApplyAdadelta<GPUDevice, complex128, int32>;
template struct functor::SparseApplyAdadelta<GPUDevice, complex128, int64>;

template struct functor::ApplyAdam<GPUDevice, Eigen::half>;
template struct functor::ApplyAdam<GPUDevice, float>;
template struct functor::ApplyAdam<GPUDevice, double>;
template struct functor::ApplyAdam<GPUDevice, complex64>;
template struct functor::ApplyAdam<GPUDevice, complex128>;

template struct functor::ApplyAdamWithAmsgrad<GPUDevice, Eigen::half>;
template struct functor::ApplyAdamWithAmsgrad<GPUDevice, float>;
template struct functor::ApplyAdamWithAmsgrad<GPUDevice, double>;

template struct functor::ApplyAdaMax<GPUDevice, Eigen::half>;
template struct functor::ApplyAdaMax<GPUDevice, float>;
template struct functor::ApplyAdaMax<GPUDevice, double>;

template struct functor::ApplyRMSProp<GPUDevice, Eigen::half>;
template struct functor::ApplyRMSProp<GPUDevice, float>;
template struct functor::ApplyRMSProp<GPUDevice, double>;
template struct functor::ApplyRMSProp<GPUDevice, complex64>;
template struct functor::ApplyRMSProp<GPUDevice, complex128>;

template struct functor::ApplyCenteredRMSProp<GPUDevice, Eigen::half>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, float>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, double>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, complex64>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, complex128>;

template struct functor::ApplyAddSign<GPUDevice, Eigen::half>;
template struct functor::ApplyAddSign<GPUDevice, float>;
template struct functor::ApplyAddSign<GPUDevice, double>;

template struct functor::ApplyPowerSign<GPUDevice, Eigen::half>;
template struct functor::ApplyPowerSign<GPUDevice, float>;
template struct functor::ApplyPowerSign<GPUDevice, double>;

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM