// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/FusedSGDKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Parallel.h>
4 #include <ATen/OpMathType.h>
5 #include <ATen/native/DispatchStub.h>
6 #include <ATen/native/FusedSGD.h>
7 #include <ATen/Dispatch.h>
8 #include <ATen/cpu/vec/vec.h>
9 #include <ATen/cpu/vec/functional.h>
10 namespace at::native {
11 
12 namespace{
13 
14 template <typename scalar_t, typename opmath_t>
15 typename std::enable_if<
16     std::is_same<scalar_t, Half>::value || std::is_same<scalar_t, BFloat16>::value,
17     void>::
sgd_math(scalar_t * param_ptr,scalar_t * grad_ptr,scalar_t * momentum_buf_ptr,const double weight_decay,const double momentum,const double lr,const double dampening,const bool nesterov,const bool maximize,const bool is_first_step,const float * grad_scale_ptr,int64_t size)18     type inline sgd_math(
19   scalar_t* param_ptr,
20   scalar_t* grad_ptr,
21   scalar_t* momentum_buf_ptr,
22   const double weight_decay,
23   const double momentum,
24   const double lr,
25   const double dampening,
26   const bool nesterov,
27   const bool maximize,
28   const bool is_first_step,
29   const float* grad_scale_ptr,
30   int64_t size
31 ){
32   using lpVec = at::vec::Vectorized<scalar_t>;
33   using fVec = at::vec::Vectorized<opmath_t>;
34   int64_t d = 0;
35   for (; d < size - (size % lpVec::size()); d += lpVec::size()) {
36     lpVec param_lpvec = lpVec::loadu(param_ptr + d);
37     auto [param_vec1, param_vec2] = vec::convert_to_float<scalar_t>(param_lpvec);
38     lpVec grad_lpvec = lpVec::loadu(grad_ptr + d);
39     auto [grad_vec1, grad_vec2] = vec::convert_to_float<scalar_t>(grad_lpvec);
40     if (grad_scale_ptr) {
41       grad_vec1 = grad_vec1 / fVec(float(*grad_scale_ptr));
42       grad_vec2 = grad_vec2 / fVec(float(*grad_scale_ptr));
43       lpVec grad_vec_to_store = vec::convert_from_float<scalar_t>(grad_vec1, grad_vec2);
44       grad_vec_to_store.store(grad_ptr + d);
45     }
46     if (maximize){
47       grad_vec1 = grad_vec1 * fVec(opmath_t(-1.0));
48       grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
49     }
50     if (weight_decay != 0.0){
51       grad_vec1 = vec::fmadd(param_vec1, fVec(scalar_t(weight_decay)), grad_vec1);
52       grad_vec2 = vec::fmadd(param_vec2, fVec(scalar_t(weight_decay)), grad_vec2);
53     }
54     if (momentum != 0.0) {
55       fVec momentum_vec1, momentum_vec2;
56       if (is_first_step) {
57         momentum_vec1 = grad_vec1;
58         momentum_vec2 = grad_vec2;
59       } else {
60         momentum_vec1 = fVec::loadu(momentum_buf_ptr + d) * fVec(scalar_t(momentum));
61         momentum_vec2 = fVec::loadu(momentum_buf_ptr + d + fVec::size()) * fVec(scalar_t(momentum));
62         momentum_vec1 = vec::fmadd(fVec(scalar_t(1 - dampening)), grad_vec1, momentum_vec1);
63         momentum_vec2 = vec::fmadd(fVec(scalar_t(1 - dampening)), grad_vec2, momentum_vec2);
64       }
65       vec::convert_from_float<scalar_t>(momentum_vec1, momentum_vec2).store(momentum_buf_ptr + d);;
66       if (nesterov) {
67         grad_vec1 = vec::fmadd(momentum_vec1, fVec(scalar_t(momentum)), grad_vec1);
68         grad_vec2 = vec::fmadd(momentum_vec2, fVec(scalar_t(momentum)), grad_vec2);
69       } else {
70         grad_vec1 = momentum_vec1;
71         grad_vec2 = momentum_vec2;
72       }
73     }
74   }
75   for (; d < size; d++) {
76     opmath_t grad_val = grad_ptr[d];
77     opmath_t param_val = param_ptr[d];
78     if (grad_scale_ptr) {
79       grad_val = grad_ptr[d] / opmath_t(*grad_scale_ptr);
80       grad_ptr[d] = grad_val;
81     }
82     if (maximize) grad_val = -grad_val;
83     if (weight_decay != 0.0){
84       grad_val += param_val * opmath_t(weight_decay);
85     }
86     if (momentum != 0.0) {
87       opmath_t momentum_buf_var = momentum_buf_ptr[d];
88       if (is_first_step) {
89         momentum_buf_var = grad_val;
90       } else {
91         momentum_buf_var = momentum_buf_var * opmath_t(momentum) +
92             grad_val * opmath_t(1 - dampening);
93       }
94       momentum_buf_ptr[d] = momentum_buf_var;
95       if (nesterov) {
96         grad_val += momentum_buf_var * opmath_t(momentum);
97       } else {
98         grad_val = momentum_buf_var;
99       }
100     }
101     param_ptr[d] = param_val - grad_val * opmath_t(lr);
102   }
103 }
104 
105 
106 template <typename scalar_t, typename opmath_t>
107 typename std::enable_if<
108     std::is_same<scalar_t, float>::value || std::is_same<scalar_t, double>::value,
109     void>::
sgd_math(scalar_t * param_ptr,scalar_t * grad_ptr,scalar_t * momentum_buf_ptr,const double weight_decay,const double momentum,const double lr,const double dampening,const bool nesterov,const bool maximize,const bool is_first_step,const float * grad_scale_ptr,int64_t size)110     type inline sgd_math(
111   scalar_t* param_ptr,
112   scalar_t* grad_ptr,
113   scalar_t* momentum_buf_ptr,
114   const double weight_decay,
115   const double momentum,
116   const double lr,
117   const double dampening,
118   const bool nesterov,
119   const bool maximize,
120   const bool is_first_step,
121   const float* grad_scale_ptr,
122   int64_t size
123 ){
124   using Vec = at::vec::Vectorized<scalar_t>;
125   int64_t d = 0;
126   for (; d < size - (size % Vec::size()); d += Vec::size()) {
127     Vec param_vec = Vec::loadu(param_ptr + d);
128     Vec grad_vec = Vec::loadu(grad_ptr + d);
129     if (grad_scale_ptr) {
130       grad_vec = grad_vec / Vec(scalar_t(*grad_scale_ptr));
131       Vec grad_vec_to_store = grad_vec;
132       grad_vec_to_store.store(grad_ptr + d);
133     }
134     if (maximize) grad_vec = grad_vec * Vec(scalar_t(-1.0));
135     if (weight_decay != 0.0){
136       grad_vec = vec::fmadd(param_vec, Vec(scalar_t(weight_decay)), grad_vec);
137     }
138     if (momentum != 0.0) {
139       Vec momentum_vec;
140       if (is_first_step) {
141         momentum_vec = grad_vec;
142       } else {
143         momentum_vec =
144             Vec::loadu(momentum_buf_ptr + d) * Vec(scalar_t(momentum));
145         momentum_vec = vec::fmadd(Vec(scalar_t(1 - dampening)), grad_vec, momentum_vec);
146       }
147       momentum_vec.store(momentum_buf_ptr + d);
148       if (nesterov) {
149         grad_vec =  vec::fmadd(momentum_vec, Vec(scalar_t(momentum)), grad_vec);
150       } else {
151         grad_vec = momentum_vec;
152       }
153     }
154     param_vec += grad_vec * Vec(scalar_t(-lr));
155     param_vec.store(param_ptr + d);
156   }
157   for (; d < size; d++) {
158     scalar_t grad_val = grad_ptr[d];
159     if (grad_scale_ptr) {
160       grad_val = grad_ptr[d] / scalar_t(*grad_scale_ptr);
161       grad_ptr[d] = grad_val;
162     }
163     if (maximize) grad_val = -grad_val;
164     if (weight_decay != 0.0){
165       grad_val += param_ptr[d] * scalar_t(weight_decay);
166     }
167     if (momentum != 0.0) {
168       if (is_first_step) {
169         momentum_buf_ptr[d] = grad_val;
170       } else {
171         momentum_buf_ptr[d] = momentum_buf_ptr[d] * scalar_t(momentum) +
172             grad_val * scalar_t(1 - dampening);
173       }
174       if (nesterov) {
175         grad_val += momentum_buf_ptr[d] * scalar_t(momentum);
176       } else {
177         grad_val = momentum_buf_ptr[d];
178       }
179     }
180     param_ptr[d] -= grad_val * scalar_t(lr);
181   }
182 }
183 
184 template <typename scalar_t>
sgd_fused_step_impl(const at::Tensor & param,const at::Tensor & grad,const at::Tensor & momentum_buffer,const double weight_decay,const double momentum,const double lr,const double dampening,const bool nesterov,const bool maximize,const bool is_first_step,const float * grad_scale_ptr)185 void sgd_fused_step_impl(
186     const at::Tensor& param,
187     const at::Tensor& grad,
188     const at::Tensor& momentum_buffer,
189     const double weight_decay,
190     const double momentum,
191     const double lr,
192     const double dampening,
193     const bool nesterov,
194     const bool maximize,
195     const bool is_first_step,
196     const float* grad_scale_ptr) {
197   using opmath_t = at::opmath_type<scalar_t>;
198   scalar_t* param_data = param.data_ptr<scalar_t>();
199   scalar_t* grad_data = grad.data_ptr<scalar_t>();
200   bool has_momentum_buffer = momentum != 0.0;
201   scalar_t* momentum_buffer_data = has_momentum_buffer ? momentum_buffer.data_ptr<scalar_t>() : nullptr;
202 
203   constexpr size_t cache_line_size = 64;
204   constexpr int64_t cache_line_aligned_task_unit = cache_line_size / sizeof(scalar_t);
205   size_t num_units = divup(param.numel(), cache_line_aligned_task_unit);
206 
207   auto sgd_fn = [&](int64_t begin, int64_t end) {
208         // local pointers
209         begin *= cache_line_aligned_task_unit;
210         end = std::min(end * cache_line_aligned_task_unit, param.numel());
211         scalar_t* param_ptr = param_data + begin;
212         scalar_t* grad_ptr = grad_data + begin;
213         scalar_t* momentum_buffer_ptr = has_momentum_buffer ? momentum_buffer_data + begin : nullptr;
214 
215         const int64_t size = end - begin;
216         sgd_math<scalar_t, opmath_t>(
217           param_ptr,
218           grad_ptr,
219           momentum_buffer_ptr,
220           weight_decay,
221           momentum,
222           lr,
223           dampening,
224           nesterov,
225           maximize,
226           is_first_step,
227           grad_scale_ptr,
228           size
229         );
230       };
231   at::parallel_for(
232       0, num_units, 0, sgd_fn);
233 }
234 
fused_sgd_kernel(const at::Tensor & param,const at::Tensor & grad,const at::Tensor & momentum_buffer,const double weight_decay,const double momentum,const double lr,const double dampening,const bool nesterov,const bool maximize,const bool is_first_step,const float * grad_scale_ptr)235 void fused_sgd_kernel(
236     const at::Tensor& param,
237     const at::Tensor& grad,
238     const at::Tensor& momentum_buffer,
239     const double weight_decay,
240     const double momentum,
241     const double lr,
242     const double dampening,
243     const bool nesterov,
244     const bool maximize,
245     const bool is_first_step,
246     const float* grad_scale_ptr
247   ) {
248   Tensor grad_contiguous = grad.contiguous();
249   AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, param.scalar_type(), "fused_sgd_kernel", [&] {
250     sgd_fused_step_impl<scalar_t>(
251       param,
252       grad,
253       momentum_buffer,
254       weight_decay,
255       momentum,
256       lr,
257       dampening,
258       nesterov,
259       maximize,
260       is_first_step,
261       grad_scale_ptr);
262   });
263 }
264 
265 }
266 
// Register this CPU implementation as the backend for the fused SGD stub.
REGISTER_DISPATCH(fused_sgd_stub, &fused_sgd_kernel);
268 } // namespace at::native
269