// aten/src/ATen/native/cpu/FusedAdamKernel.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/FusedAdam.h>
#include <ATen/Dispatch.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
namespace at::native {

namespace{

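// Scalar/vectorized Adam math for reduced-precision parameters (Half/BFloat16).
// Each vector lane is widened to opmath_t (float) before any arithmetic, then
// narrowed back on store. The update implemented here is the standard Adam/AdamW step:
//   exp_avg    = exp_avg + (1 - beta1) * (grad - exp_avg)            (lerp form)
//   exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad
//   param      = param - (lr / bias_correction1) * exp_avg
//                  / (sqrt(exp_avg_sq) / sqrt(bias_correction2) + eps)
// With amsgrad, max_exp_avg_sq replaces exp_avg_sq in the denominator.
// The main loop consumes whole Vectorized<scalar_t> lanes; the scalar tail loop
// handles the leftover elements.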
template <typename scalar_t, typename opmath_t, ADAM_MODE adam_mode>
typename std::enable_if<
    std::is_same<scalar_t, Half>::value || std::is_same<scalar_t, BFloat16>::value,
    void>::
    type inline adam_math(
  scalar_t* param_ptr,
  scalar_t* exp_avg_ptr,
  scalar_t* exp_avg_sq_ptr,
  scalar_t* grad_ptr,
  scalar_t* max_exp_avg_sq_ptr,
  double lr,
  double bias_correction1,
  double bias_correction2,
  double exp_avg_grad_coefficient,
  double exp_avg_sq_grad_coefficient,
  double bias_correction2_sqrt,
  double eps,
  double weight_decay,
  double beta2,
  bool amsgrad,
  bool maximize,
  const float* grad_scale_ptr,
  int64_t size
){
  double step_size = lr / bias_correction1;
  using lpVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<opmath_t>;
  int64_t d = 0;
  for (; d < size - (size % lpVec::size()); d += lpVec::size()) {
    lpVec param_lpvec = lpVec::loadu(param_ptr + d);
    auto [param_vec1, param_vec2] = vec::convert_to_float<scalar_t>(param_lpvec);
    lpVec grad_lpvec = lpVec::loadu(grad_ptr + d);
    auto [grad_vec1, grad_vec2] = vec::convert_to_float<scalar_t>(grad_lpvec);
    if (grad_scale_ptr) {
      grad_vec1 = grad_vec1 / fVec(float(*grad_scale_ptr));
      grad_vec2 = grad_vec2 / fVec(float(*grad_scale_ptr));
      lpVec grad_vec_to_store = vec::convert_from_float<scalar_t>(grad_vec1, grad_vec2);
      grad_vec_to_store.store(grad_ptr + d);
    }
    if (maximize){
      grad_vec1 = grad_vec1 * fVec(opmath_t(-1.0));
      grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
    }
    if (weight_decay != 0.f){
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad_vec1 += param_vec1 * fVec(opmath_t(weight_decay));
        grad_vec2 += param_vec2 * fVec(opmath_t(weight_decay));
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param_vec1 = param_vec1 * fVec(opmath_t(1 - lr * weight_decay));
        param_vec2 = param_vec2 * fVec(opmath_t(1 - lr * weight_decay));
      }
    }

    lpVec exp_avg_lpvec = lpVec::loadu(exp_avg_ptr + d);
    auto [exp_avg_vec1, exp_avg_vec2] = vec::convert_to_float<scalar_t>(exp_avg_lpvec);

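    // Weighted average computed the same way as at::lerp: when |weight| < 0.5
    // use base + weight * (end - base), otherwise end - (end - base) * (1 - weight),
    // which avoids precision loss for weights close to 1.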
    // exp_avg.lerp_(grad, 1 - beta1)
    const fVec lerp_weight = fVec(opmath_t(exp_avg_grad_coefficient));
    auto mask = lerp_weight.abs() < fVec(0.5);
    auto coeff = fVec::blendv(lerp_weight - fVec(1), lerp_weight, mask);

    auto base1 = fVec::blendv(grad_vec1, exp_avg_vec1, mask);
    exp_avg_vec1 = vec::fmadd(coeff, grad_vec1 - exp_avg_vec1, base1);

    auto base2 = fVec::blendv(grad_vec2, exp_avg_vec2, mask);
    exp_avg_vec2 = vec::fmadd(coeff, grad_vec2 - exp_avg_vec2, base2);

    lpVec exp_avg_sq_lpvec = lpVec::loadu(exp_avg_sq_ptr + d);
    auto [exp_avg_sq_vec1, exp_avg_sq_vec2] = vec::convert_to_float<scalar_t>(exp_avg_sq_lpvec);
    exp_avg_sq_vec1 = exp_avg_sq_vec1 * fVec(opmath_t(beta2)) +
        fVec(opmath_t(exp_avg_sq_grad_coefficient)) * grad_vec1 * grad_vec1;
    exp_avg_sq_vec2 = exp_avg_sq_vec2 * fVec(opmath_t(beta2)) +
        fVec(opmath_t(exp_avg_sq_grad_coefficient)) * grad_vec2 * grad_vec2;

    vec::convert_from_float<scalar_t>(exp_avg_vec1, exp_avg_vec2).store(exp_avg_ptr + d);
    vec::convert_from_float<scalar_t>(exp_avg_sq_vec1, exp_avg_sq_vec2).store(exp_avg_sq_ptr + d);

    fVec denom_vec1, denom_vec2;
    if (amsgrad) {
      lpVec max_exp_avg_sq_lpvec = lpVec::loadu(max_exp_avg_sq_ptr + d);
      auto [max_exp_avg_sq_vec1, max_exp_avg_sq_vec2] = vec::convert_to_float<scalar_t>(max_exp_avg_sq_lpvec);
      max_exp_avg_sq_vec1 = maximum(max_exp_avg_sq_vec1, exp_avg_sq_vec1);
      max_exp_avg_sq_vec2 = maximum(max_exp_avg_sq_vec2, exp_avg_sq_vec2);
      vec::convert_from_float<scalar_t>(max_exp_avg_sq_vec1, max_exp_avg_sq_vec2).store(max_exp_avg_sq_ptr + d);
      denom_vec1 =
          (max_exp_avg_sq_vec1.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
      denom_vec2 =
          (max_exp_avg_sq_vec2.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
    } else {
      denom_vec1 =
          (exp_avg_sq_vec1.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
      denom_vec2 =
          (exp_avg_sq_vec2.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
    }
    param_vec1 = param_vec1 + fVec(opmath_t(-step_size)) * exp_avg_vec1 / denom_vec1;
    param_vec2 = param_vec2 + fVec(opmath_t(-step_size)) * exp_avg_vec2 / denom_vec2;
    vec::convert_from_float<scalar_t>(param_vec1, param_vec2).store(param_ptr + d);
  }
  for (; d < size; d++) {
    opmath_t grad_val = grad_ptr[d];
    opmath_t param_val = param_ptr[d];
    if (grad_scale_ptr) {
      grad_val = grad_ptr[d] / float(*grad_scale_ptr);
      grad_ptr[d] = grad_val;
    }
    if (maximize) grad_val = -grad_val;
    if (weight_decay != 0.f){
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad_val += param_val * opmath_t(weight_decay);
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param_val = param_val * opmath_t(1 - lr * weight_decay);
      }
    }
    // exp_avg.lerp_(grad, 1 - beta1)
    opmath_t exp_avg_var = exp_avg_ptr[d];
    auto is_lerp_weight_small = std::abs(opmath_t(exp_avg_grad_coefficient)) < opmath_t(0.5);
    if (is_lerp_weight_small) {
      exp_avg_var = exp_avg_var + opmath_t(exp_avg_grad_coefficient) * (grad_val - exp_avg_var);
    } else {
      exp_avg_var = grad_val - (grad_val - exp_avg_var) * (opmath_t(1) - opmath_t(exp_avg_grad_coefficient));
    }
    exp_avg_ptr[d] = scalar_t(exp_avg_var);
    opmath_t exp_avg_sq_var = exp_avg_sq_ptr[d];
    exp_avg_sq_var = exp_avg_sq_var * opmath_t(beta2);
    exp_avg_sq_var = exp_avg_sq_var +
        opmath_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
    exp_avg_sq_ptr[d] = scalar_t(exp_avg_sq_var);
    opmath_t denom_val;
    if (amsgrad) {
      opmath_t max_exp_avg_sq_var = max_exp_avg_sq_ptr[d];
      max_exp_avg_sq_var = std::max(max_exp_avg_sq_var, exp_avg_sq_var);
      max_exp_avg_sq_ptr[d] =
          scalar_t(max_exp_avg_sq_var);
      denom_val =
          std::sqrt(max_exp_avg_sq_var) / opmath_t(bias_correction2_sqrt) + opmath_t(eps);
    } else {
      denom_val = std::sqrt(exp_avg_sq_var) / opmath_t(bias_correction2_sqrt) + opmath_t(eps);
    }
    param_ptr[d] = param_val - opmath_t(step_size) * exp_avg_var / denom_val;
  }
}


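// Same Adam/AdamW math as above, specialized for float/double parameters:
// scalar_t is already the math type, so no widening/narrowing conversions are
// needed and a single Vectorized<scalar_t> lane type suffices.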
template <typename scalar_t, typename opmath_t, ADAM_MODE adam_mode>
typename std::enable_if<
    std::is_same<scalar_t, float>::value || std::is_same<scalar_t, double>::value,
    void>::
    type inline adam_math(
  scalar_t* param_ptr,
  scalar_t* exp_avg_ptr,
  scalar_t* exp_avg_sq_ptr,
  scalar_t* grad_ptr,
  scalar_t* max_exp_avg_sq_ptr,
  double lr,
  double bias_correction1,
  double bias_correction2,
  double exp_avg_grad_coefficient,
  double exp_avg_sq_grad_coefficient,
  double bias_correction2_sqrt,
  double eps,
  double weight_decay,
  double beta2,
  bool amsgrad,
  bool maximize,
  const float* grad_scale_ptr,
  int64_t size
){
  double step_size = lr / bias_correction1;
  using Vec = at::vec::Vectorized<scalar_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec param_vec = Vec::loadu(param_ptr + d);
    Vec grad_vec = Vec::loadu(grad_ptr + d);
    if (grad_scale_ptr) {
      grad_vec = grad_vec / Vec(scalar_t(*grad_scale_ptr));
      Vec grad_vec_to_store = grad_vec;
      grad_vec_to_store.store(grad_ptr + d);
    }
    if (maximize) grad_vec = grad_vec * Vec(scalar_t(-1.0));
    if (weight_decay != 0.f){
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad_vec += param_vec * Vec(scalar_t(weight_decay));
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param_vec = param_vec * Vec(scalar_t(1 - lr * weight_decay));
      }
    }
    Vec exp_avg_vec = Vec::loadu(exp_avg_ptr + d);
    // exp_avg.lerp_(grad, 1 - beta1)
    const Vec lerp_weight = Vec(scalar_t(exp_avg_grad_coefficient));
    auto mask = lerp_weight.abs() < Vec(0.5);
    auto coeff = Vec::blendv(lerp_weight - Vec(1), lerp_weight, mask);
    auto base = Vec::blendv(grad_vec, exp_avg_vec, mask);
    exp_avg_vec = vec::fmadd(coeff, grad_vec - exp_avg_vec, base);

    Vec exp_avg_sq_vec = Vec::loadu(exp_avg_sq_ptr + d) * Vec(scalar_t(beta2)) +
        Vec(scalar_t(exp_avg_sq_grad_coefficient)) * grad_vec * grad_vec;
    exp_avg_vec.store(exp_avg_ptr + d);
    exp_avg_sq_vec.store(exp_avg_sq_ptr + d);

    Vec denom_vec;
    if (amsgrad) {
      Vec max_exp_avg_sq_vec =
          maximum(Vec::loadu(max_exp_avg_sq_ptr + d), exp_avg_sq_vec);
      max_exp_avg_sq_vec.store(max_exp_avg_sq_ptr + d);
      denom_vec =
          (max_exp_avg_sq_vec.sqrt() / Vec(scalar_t(bias_correction2_sqrt))) + Vec(scalar_t(eps));
    } else {
      denom_vec =
          (exp_avg_sq_vec.sqrt() / Vec(scalar_t(bias_correction2_sqrt))) + Vec(scalar_t(eps));
    }
    param_vec = param_vec + Vec(scalar_t(-step_size)) * exp_avg_vec / denom_vec;
    param_vec.store(param_ptr + d);
  }
  for (; d < size; d++) {
    scalar_t grad_val = grad_ptr[d];
    if (grad_scale_ptr) {
      grad_val = grad_ptr[d] / scalar_t(*grad_scale_ptr);
      grad_ptr[d] = grad_val;
    }
    if (maximize) grad_val = -grad_val;
    if (weight_decay != 0.f){
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad_val += param_ptr[d] * scalar_t(weight_decay);
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param_ptr[d] = param_ptr[d] * scalar_t(1 - lr * weight_decay);
      }
    }
    // exp_avg.lerp_(grad, 1 - beta1)
    auto is_lerp_weight_small = std::abs(scalar_t(exp_avg_grad_coefficient)) < scalar_t(0.5);
    if (is_lerp_weight_small) {
      exp_avg_ptr[d] = exp_avg_ptr[d] + scalar_t(exp_avg_grad_coefficient) * (grad_val - exp_avg_ptr[d]);
    } else {
      exp_avg_ptr[d] = grad_val - (grad_val - exp_avg_ptr[d]) * (scalar_t(1) - scalar_t(exp_avg_grad_coefficient));
    }
    exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] * scalar_t(beta2);
    exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] +
        scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
    scalar_t denom_val;
    if (amsgrad) {
      max_exp_avg_sq_ptr[d] =
          std::max(max_exp_avg_sq_ptr[d], exp_avg_sq_ptr[d]);
      denom_val =
          std::sqrt(max_exp_avg_sq_ptr[d]) / scalar_t(bias_correction2_sqrt) + scalar_t(eps);
    } else {
      denom_val = std::sqrt(exp_avg_sq_ptr[d]) / scalar_t(bias_correction2_sqrt) + scalar_t(eps);
    }
    param_ptr[d] = param_ptr[d] - scalar_t(step_size) * exp_avg_ptr[d] / denom_val;
  }
}


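// Per-tensor driver: reads the step count from state_step, derives the bias
// corrections and grad coefficients in double (to match the non-fused Adam
// implementation), then splits the flat parameter buffer into cache-line-aligned
// chunks and runs adam_math over each chunk via at::parallel_for.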
template <typename scalar_t, ADAM_MODE adam_mode>
void adam_fused_step_impl(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& exp_avg,
    const at::Tensor& exp_avg_sq,
    const at::Tensor& max_exp_avg_sq,
    const at::Tensor& state_step,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const float* grad_scale_ptr) {
  using opmath_t = at::opmath_type<scalar_t>;
  double step = state_step.item<float>();
  scalar_t* param_data = param.data_ptr<scalar_t>();
  scalar_t* exp_avg_data = exp_avg.data_ptr<scalar_t>();
  scalar_t* exp_avg_sq_data = exp_avg_sq.data_ptr<scalar_t>();
  scalar_t* max_exp_avg_sq_data = amsgrad ? max_exp_avg_sq.data_ptr<scalar_t>() : nullptr;
  scalar_t* grad_data = grad.data_ptr<scalar_t>();

  // need to use double here to align with non-fused adam
  double bias_correction1 = 1 - std::pow(beta1, step);
  double bias_correction2 = 1 - std::pow(beta2, step);
  double exp_avg_grad_coefficient = 1 - beta1;
  double exp_avg_sq_grad_coefficient = 1 - beta2;
  double bias_correction2_sqrt = std::sqrt(bias_correction2);

  constexpr size_t cache_line_size = 64;
  constexpr int64_t cache_line_aligned_task_unit = cache_line_size / sizeof(scalar_t);
  size_t num_units = divup(param.numel(), cache_line_aligned_task_unit);

  auto adam_fn = [&](int64_t begin, int64_t end) {
        // local pointers
        begin *= cache_line_aligned_task_unit;
        end = std::min(end * cache_line_aligned_task_unit, param.numel());
        scalar_t* param_ptr = param_data + begin;
        scalar_t* exp_avg_ptr = exp_avg_data + begin;
        scalar_t* exp_avg_sq_ptr = exp_avg_sq_data + begin;
        scalar_t* grad_ptr = grad_data + begin;
        scalar_t* max_exp_avg_sq_ptr = amsgrad ? max_exp_avg_sq_data + begin : nullptr;

        const int64_t size = end - begin;
        adam_math<scalar_t, opmath_t, adam_mode>(
          param_ptr,
          exp_avg_ptr,
          exp_avg_sq_ptr,
          grad_ptr,
          max_exp_avg_sq_ptr,
          lr,
          bias_correction1,
          bias_correction2,
          exp_avg_grad_coefficient,
          exp_avg_sq_grad_coefficient,
          bias_correction2_sqrt,
          eps,
          weight_decay,
          beta2,
          amsgrad,
          maximize,
          grad_scale_ptr,
          size
        );
      };
  at::parallel_for(
      0, num_units, 0, adam_fn);
}

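// Entry point registered for fused_adam_stub: dispatches on the parameter dtype
// (float, double, Half, BFloat16) and on ADAM_MODE (ORIGINAL vs ADAMW), then
// forwards to adam_fused_step_impl. Callers are expected to reach this through
// the DispatchStub, roughly fused_adam_stub(kCPU, param, grad, ...), rather
// than invoking it directly.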
void fused_adam_kernel(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& exp_avg,
    const at::Tensor& exp_avg_sq,
    const at::Tensor& max_exp_avg_sq,
    const at::Tensor& state_step,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const float* grad_scale_ptr,
    const ADAM_MODE adam_mode
  ) {
  Tensor grad_contiguous = grad.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, param.scalar_type(), "fused_adam_kernel", [&] {
    if (adam_mode == ADAM_MODE::ORIGINAL) {
      adam_fused_step_impl<scalar_t, ADAM_MODE::ORIGINAL>(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, state_step, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_ptr);
    } else {
      adam_fused_step_impl<scalar_t, ADAM_MODE::ADAMW>(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, state_step, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_ptr);
    }
  });
}

} // anonymous namespace

REGISTER_DISPATCH(fused_adam_stub, &fused_adam_kernel);
} // namespace at::native