1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Parallel.h>
4 #include <ATen/OpMathType.h>
5 #include <ATen/native/DispatchStub.h>
6 #include <ATen/native/FusedSGD.h>
7 #include <ATen/Dispatch.h>
8 #include <ATen/cpu/vec/vec.h>
9 #include <ATen/cpu/vec/functional.h>
10 namespace at::native {
11
12 namespace{
13
14 template <typename scalar_t, typename opmath_t>
15 typename std::enable_if<
16 std::is_same<scalar_t, Half>::value || std::is_same<scalar_t, BFloat16>::value,
17 void>::
sgd_math(scalar_t * param_ptr,scalar_t * grad_ptr,scalar_t * momentum_buf_ptr,const double weight_decay,const double momentum,const double lr,const double dampening,const bool nesterov,const bool maximize,const bool is_first_step,const float * grad_scale_ptr,int64_t size)18 type inline sgd_math(
19 scalar_t* param_ptr,
20 scalar_t* grad_ptr,
21 scalar_t* momentum_buf_ptr,
22 const double weight_decay,
23 const double momentum,
24 const double lr,
25 const double dampening,
26 const bool nesterov,
27 const bool maximize,
28 const bool is_first_step,
29 const float* grad_scale_ptr,
30 int64_t size
31 ){
32 using lpVec = at::vec::Vectorized<scalar_t>;
33 using fVec = at::vec::Vectorized<opmath_t>;
34 int64_t d = 0;
35 for (; d < size - (size % lpVec::size()); d += lpVec::size()) {
36 lpVec param_lpvec = lpVec::loadu(param_ptr + d);
37 auto [param_vec1, param_vec2] = vec::convert_to_float<scalar_t>(param_lpvec);
38 lpVec grad_lpvec = lpVec::loadu(grad_ptr + d);
39 auto [grad_vec1, grad_vec2] = vec::convert_to_float<scalar_t>(grad_lpvec);
40 if (grad_scale_ptr) {
41 grad_vec1 = grad_vec1 / fVec(float(*grad_scale_ptr));
42 grad_vec2 = grad_vec2 / fVec(float(*grad_scale_ptr));
43 lpVec grad_vec_to_store = vec::convert_from_float<scalar_t>(grad_vec1, grad_vec2);
44 grad_vec_to_store.store(grad_ptr + d);
45 }
46 if (maximize){
47 grad_vec1 = grad_vec1 * fVec(opmath_t(-1.0));
48 grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
49 }
50 if (weight_decay != 0.0){
51 grad_vec1 = vec::fmadd(param_vec1, fVec(scalar_t(weight_decay)), grad_vec1);
52 grad_vec2 = vec::fmadd(param_vec2, fVec(scalar_t(weight_decay)), grad_vec2);
53 }
54 if (momentum != 0.0) {
55 fVec momentum_vec1, momentum_vec2;
56 if (is_first_step) {
57 momentum_vec1 = grad_vec1;
58 momentum_vec2 = grad_vec2;
59 } else {
60 momentum_vec1 = fVec::loadu(momentum_buf_ptr + d) * fVec(scalar_t(momentum));
61 momentum_vec2 = fVec::loadu(momentum_buf_ptr + d + fVec::size()) * fVec(scalar_t(momentum));
62 momentum_vec1 = vec::fmadd(fVec(scalar_t(1 - dampening)), grad_vec1, momentum_vec1);
63 momentum_vec2 = vec::fmadd(fVec(scalar_t(1 - dampening)), grad_vec2, momentum_vec2);
64 }
65 vec::convert_from_float<scalar_t>(momentum_vec1, momentum_vec2).store(momentum_buf_ptr + d);;
66 if (nesterov) {
67 grad_vec1 = vec::fmadd(momentum_vec1, fVec(scalar_t(momentum)), grad_vec1);
68 grad_vec2 = vec::fmadd(momentum_vec2, fVec(scalar_t(momentum)), grad_vec2);
69 } else {
70 grad_vec1 = momentum_vec1;
71 grad_vec2 = momentum_vec2;
72 }
73 }
74 }
75 for (; d < size; d++) {
76 opmath_t grad_val = grad_ptr[d];
77 opmath_t param_val = param_ptr[d];
78 if (grad_scale_ptr) {
79 grad_val = grad_ptr[d] / opmath_t(*grad_scale_ptr);
80 grad_ptr[d] = grad_val;
81 }
82 if (maximize) grad_val = -grad_val;
83 if (weight_decay != 0.0){
84 grad_val += param_val * opmath_t(weight_decay);
85 }
86 if (momentum != 0.0) {
87 opmath_t momentum_buf_var = momentum_buf_ptr[d];
88 if (is_first_step) {
89 momentum_buf_var = grad_val;
90 } else {
91 momentum_buf_var = momentum_buf_var * opmath_t(momentum) +
92 grad_val * opmath_t(1 - dampening);
93 }
94 momentum_buf_ptr[d] = momentum_buf_var;
95 if (nesterov) {
96 grad_val += momentum_buf_var * opmath_t(momentum);
97 } else {
98 grad_val = momentum_buf_var;
99 }
100 }
101 param_ptr[d] = param_val - grad_val * opmath_t(lr);
102 }
103 }
104
105
106 template <typename scalar_t, typename opmath_t>
107 typename std::enable_if<
108 std::is_same<scalar_t, float>::value || std::is_same<scalar_t, double>::value,
109 void>::
sgd_math(scalar_t * param_ptr,scalar_t * grad_ptr,scalar_t * momentum_buf_ptr,const double weight_decay,const double momentum,const double lr,const double dampening,const bool nesterov,const bool maximize,const bool is_first_step,const float * grad_scale_ptr,int64_t size)110 type inline sgd_math(
111 scalar_t* param_ptr,
112 scalar_t* grad_ptr,
113 scalar_t* momentum_buf_ptr,
114 const double weight_decay,
115 const double momentum,
116 const double lr,
117 const double dampening,
118 const bool nesterov,
119 const bool maximize,
120 const bool is_first_step,
121 const float* grad_scale_ptr,
122 int64_t size
123 ){
124 using Vec = at::vec::Vectorized<scalar_t>;
125 int64_t d = 0;
126 for (; d < size - (size % Vec::size()); d += Vec::size()) {
127 Vec param_vec = Vec::loadu(param_ptr + d);
128 Vec grad_vec = Vec::loadu(grad_ptr + d);
129 if (grad_scale_ptr) {
130 grad_vec = grad_vec / Vec(scalar_t(*grad_scale_ptr));
131 Vec grad_vec_to_store = grad_vec;
132 grad_vec_to_store.store(grad_ptr + d);
133 }
134 if (maximize) grad_vec = grad_vec * Vec(scalar_t(-1.0));
135 if (weight_decay != 0.0){
136 grad_vec = vec::fmadd(param_vec, Vec(scalar_t(weight_decay)), grad_vec);
137 }
138 if (momentum != 0.0) {
139 Vec momentum_vec;
140 if (is_first_step) {
141 momentum_vec = grad_vec;
142 } else {
143 momentum_vec =
144 Vec::loadu(momentum_buf_ptr + d) * Vec(scalar_t(momentum));
145 momentum_vec = vec::fmadd(Vec(scalar_t(1 - dampening)), grad_vec, momentum_vec);
146 }
147 momentum_vec.store(momentum_buf_ptr + d);
148 if (nesterov) {
149 grad_vec = vec::fmadd(momentum_vec, Vec(scalar_t(momentum)), grad_vec);
150 } else {
151 grad_vec = momentum_vec;
152 }
153 }
154 param_vec += grad_vec * Vec(scalar_t(-lr));
155 param_vec.store(param_ptr + d);
156 }
157 for (; d < size; d++) {
158 scalar_t grad_val = grad_ptr[d];
159 if (grad_scale_ptr) {
160 grad_val = grad_ptr[d] / scalar_t(*grad_scale_ptr);
161 grad_ptr[d] = grad_val;
162 }
163 if (maximize) grad_val = -grad_val;
164 if (weight_decay != 0.0){
165 grad_val += param_ptr[d] * scalar_t(weight_decay);
166 }
167 if (momentum != 0.0) {
168 if (is_first_step) {
169 momentum_buf_ptr[d] = grad_val;
170 } else {
171 momentum_buf_ptr[d] = momentum_buf_ptr[d] * scalar_t(momentum) +
172 grad_val * scalar_t(1 - dampening);
173 }
174 if (nesterov) {
175 grad_val += momentum_buf_ptr[d] * scalar_t(momentum);
176 } else {
177 grad_val = momentum_buf_ptr[d];
178 }
179 }
180 param_ptr[d] -= grad_val * scalar_t(lr);
181 }
182 }
183
184 template <typename scalar_t>
sgd_fused_step_impl(const at::Tensor & param,const at::Tensor & grad,const at::Tensor & momentum_buffer,const double weight_decay,const double momentum,const double lr,const double dampening,const bool nesterov,const bool maximize,const bool is_first_step,const float * grad_scale_ptr)185 void sgd_fused_step_impl(
186 const at::Tensor& param,
187 const at::Tensor& grad,
188 const at::Tensor& momentum_buffer,
189 const double weight_decay,
190 const double momentum,
191 const double lr,
192 const double dampening,
193 const bool nesterov,
194 const bool maximize,
195 const bool is_first_step,
196 const float* grad_scale_ptr) {
197 using opmath_t = at::opmath_type<scalar_t>;
198 scalar_t* param_data = param.data_ptr<scalar_t>();
199 scalar_t* grad_data = grad.data_ptr<scalar_t>();
200 bool has_momentum_buffer = momentum != 0.0;
201 scalar_t* momentum_buffer_data = has_momentum_buffer ? momentum_buffer.data_ptr<scalar_t>() : nullptr;
202
203 constexpr size_t cache_line_size = 64;
204 constexpr int64_t cache_line_aligned_task_unit = cache_line_size / sizeof(scalar_t);
205 size_t num_units = divup(param.numel(), cache_line_aligned_task_unit);
206
207 auto sgd_fn = [&](int64_t begin, int64_t end) {
208 // local pointers
209 begin *= cache_line_aligned_task_unit;
210 end = std::min(end * cache_line_aligned_task_unit, param.numel());
211 scalar_t* param_ptr = param_data + begin;
212 scalar_t* grad_ptr = grad_data + begin;
213 scalar_t* momentum_buffer_ptr = has_momentum_buffer ? momentum_buffer_data + begin : nullptr;
214
215 const int64_t size = end - begin;
216 sgd_math<scalar_t, opmath_t>(
217 param_ptr,
218 grad_ptr,
219 momentum_buffer_ptr,
220 weight_decay,
221 momentum,
222 lr,
223 dampening,
224 nesterov,
225 maximize,
226 is_first_step,
227 grad_scale_ptr,
228 size
229 );
230 };
231 at::parallel_for(
232 0, num_units, 0, sgd_fn);
233 }
234
fused_sgd_kernel(const at::Tensor & param,const at::Tensor & grad,const at::Tensor & momentum_buffer,const double weight_decay,const double momentum,const double lr,const double dampening,const bool nesterov,const bool maximize,const bool is_first_step,const float * grad_scale_ptr)235 void fused_sgd_kernel(
236 const at::Tensor& param,
237 const at::Tensor& grad,
238 const at::Tensor& momentum_buffer,
239 const double weight_decay,
240 const double momentum,
241 const double lr,
242 const double dampening,
243 const bool nesterov,
244 const bool maximize,
245 const bool is_first_step,
246 const float* grad_scale_ptr
247 ) {
248 Tensor grad_contiguous = grad.contiguous();
249 AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, param.scalar_type(), "fused_sgd_kernel", [&] {
250 sgd_fused_step_impl<scalar_t>(
251 param,
252 grad,
253 momentum_buffer,
254 weight_decay,
255 momentum,
256 lr,
257 dampening,
258 nesterov,
259 maximize,
260 is_first_step,
261 grad_scale_ptr);
262 });
263 }
264
265 }
266
// Hook fused_sgd_kernel up to fused_sgd_stub.
// NOTE(review): the stub is presumably declared in ATen/native/FusedSGD.h and
// the macro in ATen/native/DispatchStub.h - confirm.
REGISTER_DISPATCH(fused_sgd_stub, &fused_sgd_kernel);
268 } // namespace at::native
269