#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/FusedAdagrad.h>
#include <ATen/Dispatch.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
namespace at::native {

namespace {

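// Adagrad math for reduced-precision parameters (Half/BFloat16): values are
// loaded in the low-precision type, converted to float (opmath_t) for the
// update, and converted back before being stored.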
template <typename scalar_t, typename opmath_t>
typename std::enable_if<
    std::is_same<scalar_t, Half>::value || std::is_same<scalar_t, BFloat16>::value,
    void>::
    type inline adagrad_math(
  scalar_t* param_ptr,
  scalar_t* grad_ptr,
  scalar_t* state_sum_ptr,
  const double clr,
  const double eps,
  const double weight_decay,
  const bool maximize,
  const float* grad_scale_ptr,
  int64_t size
){
  using lpVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<opmath_t>;
  int64_t d = 0;
  // Vectorized main loop: load low-precision values, do the math in float.
  for (; d < size - (size % lpVec::size()); d += lpVec::size()) {
    lpVec param_lpvec = lpVec::loadu(param_ptr + d);
    auto [param_vec1, param_vec2] = vec::convert_to_float<scalar_t>(param_lpvec);
    lpVec grad_lpvec = lpVec::loadu(grad_ptr + d);
    auto [grad_vec1, grad_vec2] = vec::convert_to_float<scalar_t>(grad_lpvec);
    if (grad_scale_ptr) {
      // Divide out the grad scale and write the unscaled gradient back in place.
      grad_vec1 = grad_vec1 / fVec(float(*grad_scale_ptr));
      grad_vec2 = grad_vec2 / fVec(float(*grad_scale_ptr));
      lpVec grad_vec_to_store = vec::convert_from_float<scalar_t>(grad_vec1, grad_vec2);
      grad_vec_to_store.store(grad_ptr + d);
    }
    if (maximize) {
      grad_vec1 = grad_vec1 * fVec(opmath_t(-1.0));
      grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
    }
    if (weight_decay != 0.0) {
      // Hyperparameters are converted to opmath_t (float) rather than scalar_t
      // so small values (e.g. eps) do not underflow in Half/BFloat16.
      grad_vec1 += param_vec1 * fVec(opmath_t(weight_decay));
      grad_vec2 += param_vec2 * fVec(opmath_t(weight_decay));
    }
    auto [state_sum_vec1, state_sum_vec2] = vec::convert_to_float<scalar_t>(lpVec::loadu(state_sum_ptr + d));
    state_sum_vec1 += grad_vec1 * grad_vec1;
    state_sum_vec2 += grad_vec2 * grad_vec2;
    vec::convert_from_float<scalar_t>(state_sum_vec1, state_sum_vec2).store(state_sum_ptr + d);

    fVec std_vec1 = state_sum_vec1.sqrt() + fVec(opmath_t(eps));
    fVec std_vec2 = state_sum_vec2.sqrt() + fVec(opmath_t(eps));
    param_vec1 = param_vec1 - fVec(opmath_t(clr)) * grad_vec1 / std_vec1;
    param_vec2 = param_vec2 - fVec(opmath_t(clr)) * grad_vec2 / std_vec2;
    vec::convert_from_float<scalar_t>(param_vec1, param_vec2).store(param_ptr + d);
  }
  // Scalar tail for the elements that do not fill a full vector lane.
  for (; d < size; d++) {
    opmath_t grad_val = grad_ptr[d];
    opmath_t param_val = param_ptr[d];
    if (grad_scale_ptr) {
      grad_val = grad_ptr[d] / opmath_t(*grad_scale_ptr);
      grad_ptr[d] = grad_val;
    }
    if (maximize) grad_val = -grad_val;
    if (weight_decay != 0.0) {
      grad_val += param_val * opmath_t(weight_decay);
    }
    opmath_t state_sum_val = state_sum_ptr[d];
    state_sum_val += grad_val * grad_val;
    state_sum_ptr[d] = state_sum_val;
    opmath_t std_val = std::sqrt(state_sum_val) + opmath_t(eps);
    param_val -= opmath_t(clr) * grad_val / std_val;
    param_ptr[d] = param_val;
  }
}
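// Adagrad math for full-precision parameters (float/double): the update is
// computed directly in the parameter dtype, no widening conversion is needed.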
template <typename scalar_t, typename opmath_t>
typename std::enable_if<
    std::is_same<scalar_t, float>::value || std::is_same<scalar_t, double>::value,
    void>::
    type inline adagrad_math(
  scalar_t* param_ptr,
  scalar_t* grad_ptr,
  scalar_t* state_sum_ptr,
  const double clr,
  const double eps,
  const double weight_decay,
  const bool maximize,
  const float* grad_scale_ptr,
  int64_t size
){
  using Vec = at::vec::Vectorized<scalar_t>;
  int64_t d = 0;
  // Vectorized main loop.
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec param_vec = Vec::loadu(param_ptr + d);
    Vec grad_vec = Vec::loadu(grad_ptr + d);
    if (grad_scale_ptr) {
      // Divide out the grad scale and write the unscaled gradient back in place.
      grad_vec = grad_vec / Vec(scalar_t(*grad_scale_ptr));
      grad_vec.store(grad_ptr + d);
    }
    if (maximize) grad_vec = grad_vec * Vec(scalar_t(-1.0));
    if (weight_decay != 0.0) {
      grad_vec += param_vec * Vec(scalar_t(weight_decay));
    }

    Vec sum_vec = Vec::loadu(state_sum_ptr + d) + grad_vec * grad_vec;
    sum_vec.store(state_sum_ptr + d);

    Vec std_vec = sum_vec.sqrt() + Vec(scalar_t(eps));
    param_vec = param_vec - Vec(scalar_t(clr)) * grad_vec / std_vec;
    param_vec.store(param_ptr + d);
  }
  // Scalar tail for the elements that do not fill a full vector lane.
  for (; d < size; d++) {
    scalar_t grad_val = grad_ptr[d];
    if (grad_scale_ptr) {
      grad_val = grad_ptr[d] / scalar_t(*grad_scale_ptr);
      grad_ptr[d] = grad_val;
    }
    if (maximize) grad_val = -grad_val;
    if (weight_decay != 0.0) {
      grad_val += param_ptr[d] * scalar_t(weight_decay);
    }
    state_sum_ptr[d] += grad_val * grad_val;

    scalar_t std_val = std::sqrt(state_sum_ptr[d]) + scalar_t(eps);
    param_ptr[d] -= scalar_t(clr) * grad_val / std_val;
  }
}

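// Applies one fused Adagrad step to a single (param, grad, state_sum) triple.
// The flat tensors are split into cache-line-aligned chunks, each processed
// in parallel by adagrad_math.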
template <typename scalar_t>
void adagrad_fused_step_impl(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& state_sum,
    const at::Tensor& state_step,
    const double lr,
    const double lr_decay,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const float* grad_scale_ptr) {
  using opmath_t = at::opmath_type<scalar_t>;
  scalar_t* param_data = param.data_ptr<scalar_t>();
  scalar_t* grad_data = grad.data_ptr<scalar_t>();
  scalar_t* state_sum_data = state_sum.data_ptr<scalar_t>();
  double step = state_step.item<float>();
  // Effective learning rate after lr decay: lr / (1 + (step - 1) * lr_decay).
  double clr = lr / (1.0 + (step - 1.0) * lr_decay);

  // Split the flat tensors into cache-line-aligned task units so parallel
  // workers never write to the same cache line (no false sharing).
  constexpr size_t cache_line_size = 64;
  constexpr int64_t cache_line_aligned_task_unit = cache_line_size / sizeof(scalar_t);
  size_t num_units = divup(param.numel(), cache_line_aligned_task_unit);

  auto adagrad_fn = [&](int64_t begin, int64_t end) {
        // Translate task-unit indices into element offsets, clamped to numel.
        begin *= cache_line_aligned_task_unit;
        end = std::min(end * cache_line_aligned_task_unit, param.numel());
        scalar_t* param_ptr = param_data + begin;
        scalar_t* grad_ptr = grad_data + begin;
        scalar_t* state_sum_ptr = state_sum_data + begin;

        const int64_t size = end - begin;
        adagrad_math<scalar_t, opmath_t>(
          param_ptr,
          grad_ptr,
          state_sum_ptr,
          clr,
          eps,
          weight_decay,
          maximize,
          grad_scale_ptr,
          size
        );
      };
  at::parallel_for(0, num_units, 0, adagrad_fn);
}

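// Dispatch entry point registered as fused_adagrad_stub: dispatches on the
// parameter dtype (float, double, Half, BFloat16) to adagrad_fused_step_impl.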
void fused_adagrad_kernel(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& state_sum,
    const at::Tensor& state_step,
    const double lr,
    const double lr_decay,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const float* grad_scale_ptr
  ) {
  Tensor grad_contiguous = grad.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, param.scalar_type(), "fused_adagrad_kernel", [&] {
    adagrad_fused_step_impl<scalar_t>(
      param,
      grad,
      state_sum,
      state_step,
      lr,
      lr_decay,
      weight_decay,
      eps,
      maximize,
      grad_scale_ptr);
  });
}

} // anonymous namespace

REGISTER_DISPATCH(fused_adagrad_stub, &fused_adagrad_kernel);
} // namespace at::native