#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/FusedAdam.h>
#include <ATen/Dispatch.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
namespace at::native {

namespace{

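// adam_math (reduced-precision overload): enabled for Half and BFloat16 parameters.
// State is stored in scalar_t, but all arithmetic is performed in opmath_t (float):
// each low-precision vector is split into two float vectors, updated, and converted
// back before being stored.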
template <typename scalar_t, typename opmath_t, ADAM_MODE adam_mode>
typename std::enable_if<
    std::is_same<scalar_t, Half>::value || std::is_same<scalar_t, BFloat16>::value,
    void>::
    type inline adam_math(
  scalar_t* param_ptr,
  scalar_t* exp_avg_ptr,
  scalar_t* exp_avg_sq_ptr,
  scalar_t* grad_ptr,
  scalar_t* max_exp_avg_sq_ptr,
  double lr,
  double bias_correction1,
  double bias_correction2,
  double exp_avg_grad_coefficient,
  double exp_avg_sq_grad_coefficient,
  double bias_correction2_sqrt,
  double eps,
  double weight_decay,
  double beta2,
  bool amsgrad,
  bool maximize,
  const float* grad_scale_ptr,
  int64_t size
){
  double step_size = lr / bias_correction1;
  using lpVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<opmath_t>;
  int64_t d = 0;
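  // Vectorized main loop: process lpVec::size() elements per iteration; the
  // remainder (size % lpVec::size()) is handled by the scalar tail loop below.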
  for (; d < size - (size % lpVec::size()); d += lpVec::size()) {
    lpVec param_lpvec = lpVec::loadu(param_ptr + d);
    auto [param_vec1, param_vec2] = vec::convert_to_float<scalar_t>(param_lpvec);
    lpVec grad_lpvec = lpVec::loadu(grad_ptr + d);
    auto [grad_vec1, grad_vec2] = vec::convert_to_float<scalar_t>(grad_lpvec);
    if (grad_scale_ptr) {
      grad_vec1 = grad_vec1 / fVec(float(*grad_scale_ptr));
      grad_vec2 = grad_vec2 / fVec(float(*grad_scale_ptr));
      lpVec grad_vec_to_store = vec::convert_from_float<scalar_t>(grad_vec1, grad_vec2);
      grad_vec_to_store.store(grad_ptr + d);
    }
    if (maximize){
      grad_vec1 = grad_vec1 * fVec(opmath_t(-1.0));
      grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
    }
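    // Weight decay: ORIGINAL (L2) adds weight_decay * param to the gradient;
    // ADAMW decays the parameter directly before the Adam update.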
    if (weight_decay != 0.f){
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad_vec1 += param_vec1 * fVec(opmath_t(weight_decay));
        grad_vec2 += param_vec2 * fVec(opmath_t(weight_decay));
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param_vec1 = param_vec1 * fVec(opmath_t(1 - lr * weight_decay));
        param_vec2 = param_vec2 * fVec(opmath_t(1 - lr * weight_decay));
      }
    }

    lpVec exp_avg_lpvec = lpVec::loadu(exp_avg_ptr + d);
    auto [exp_avg_vec1, exp_avg_vec2] = vec::convert_to_float<scalar_t>(exp_avg_lpvec);

    // exp_avg.lerp_(grad, 1 - beta1)
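    // Branch-free lerp: where |weight| < 0.5 compute exp_avg + weight * (grad - exp_avg);
    // otherwise grad - (1 - weight) * (grad - exp_avg), which is more accurate when the
    // weight is close to 1. blendv selects between the two formulations per lane.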
    const fVec lerp_weight = fVec(opmath_t(exp_avg_grad_coefficient));
    auto mask = lerp_weight.abs() < fVec(0.5);
    auto coeff = fVec::blendv(lerp_weight - fVec(1), lerp_weight, mask);

    auto base1 = fVec::blendv(grad_vec1, exp_avg_vec1, mask);
    exp_avg_vec1 = vec::fmadd(coeff, grad_vec1 - exp_avg_vec1, base1);

    auto base2 = fVec::blendv(grad_vec2, exp_avg_vec2, mask);
    exp_avg_vec2 = vec::fmadd(coeff, grad_vec2 - exp_avg_vec2, base2);

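    // exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2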
    lpVec exp_avg_sq_lpvec = lpVec::loadu(exp_avg_sq_ptr + d);
    auto [exp_avg_sq_vec1, exp_avg_sq_vec2] = vec::convert_to_float<scalar_t>(exp_avg_sq_lpvec);
    exp_avg_sq_vec1 = exp_avg_sq_vec1 * fVec(opmath_t(beta2)) +
        fVec(opmath_t(exp_avg_sq_grad_coefficient)) * grad_vec1 * grad_vec1;
    exp_avg_sq_vec2 = exp_avg_sq_vec2 * fVec(opmath_t(beta2)) +
        fVec(opmath_t(exp_avg_sq_grad_coefficient)) * grad_vec2 * grad_vec2;

    vec::convert_from_float<scalar_t>(exp_avg_vec1, exp_avg_vec2).store(exp_avg_ptr + d);
    vec::convert_from_float<scalar_t>(exp_avg_sq_vec1, exp_avg_sq_vec2).store(exp_avg_sq_ptr + d);

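    // Denominator: sqrt(second moment) / sqrt(bias_correction2) + eps.
    // With amsgrad, the running maximum of exp_avg_sq is used instead.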
    fVec denom_vec1, denom_vec2;
    if (amsgrad) {
      lpVec max_exp_avg_sq_lpvec = lpVec::loadu(max_exp_avg_sq_ptr + d);
      auto [max_exp_avg_sq_vec1, max_exp_avg_sq_vec2] = vec::convert_to_float<scalar_t>(max_exp_avg_sq_lpvec);
      max_exp_avg_sq_vec1 = maximum(max_exp_avg_sq_vec1, exp_avg_sq_vec1);
      max_exp_avg_sq_vec2 = maximum(max_exp_avg_sq_vec2, exp_avg_sq_vec2);
      vec::convert_from_float<scalar_t>(max_exp_avg_sq_vec1, max_exp_avg_sq_vec2).store(max_exp_avg_sq_ptr + d);
      denom_vec1 =
          (max_exp_avg_sq_vec1.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
      denom_vec2 =
          (max_exp_avg_sq_vec2.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
    } else {
      denom_vec1 =
          (exp_avg_sq_vec1.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
      denom_vec2 =
          (exp_avg_sq_vec2.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
    }
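    // param -= (lr / bias_correction1) * exp_avg / denom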
    param_vec1 = param_vec1 + fVec(opmath_t(-step_size)) * exp_avg_vec1 / denom_vec1;
    param_vec2 = param_vec2 + fVec(opmath_t(-step_size)) * exp_avg_vec2 / denom_vec2;
    vec::convert_from_float<scalar_t>(param_vec1, param_vec2).store(param_ptr + d);
  }
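  // Scalar tail loop: same update as above for the remaining elements, with all
  // arithmetic carried out in opmath_t.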
  for (; d < size; d++) {
    opmath_t grad_val = grad_ptr[d];
    opmath_t param_val = param_ptr[d];
    if (grad_scale_ptr) {
      grad_val = grad_ptr[d] / float(*grad_scale_ptr);
      grad_ptr[d] = grad_val;
    }
    if (maximize) grad_val = -grad_val;
    if (weight_decay != 0.f){
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad_val += param_val * opmath_t(weight_decay);
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param_val = param_val * opmath_t(1 - lr * weight_decay);
      }
    }
    // exp_avg.lerp_(grad, 1 - beta1)
    opmath_t exp_avg_var = exp_avg_ptr[d];
    auto is_lerp_weight_small = std::abs(opmath_t(exp_avg_grad_coefficient)) < opmath_t(0.5);
    if (is_lerp_weight_small) {
      exp_avg_var = exp_avg_var + opmath_t(exp_avg_grad_coefficient) * (grad_val - exp_avg_var);
    } else {
      exp_avg_var = grad_val - (grad_val - exp_avg_var) * (opmath_t(1) - opmath_t(exp_avg_grad_coefficient));
    }
    exp_avg_ptr[d] = scalar_t(exp_avg_var);
    opmath_t exp_avg_sq_var = exp_avg_sq_ptr[d];
    exp_avg_sq_var = exp_avg_sq_var * opmath_t(beta2);
    exp_avg_sq_var = exp_avg_sq_var +
        opmath_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
    exp_avg_sq_ptr[d] = scalar_t(exp_avg_sq_var);
    opmath_t denom_val;
    if (amsgrad) {
      opmath_t max_exp_avg_sq_var = max_exp_avg_sq_ptr[d];
      max_exp_avg_sq_var = std::max(max_exp_avg_sq_var, exp_avg_sq_var);
      max_exp_avg_sq_ptr[d] =
          scalar_t(max_exp_avg_sq_var);
      denom_val =
          std::sqrt(max_exp_avg_sq_var) / opmath_t(bias_correction2_sqrt) + opmath_t(eps);
    } else {
      denom_val = std::sqrt(exp_avg_sq_var) / opmath_t(bias_correction2_sqrt) + opmath_t(eps);
    }
    param_ptr[d] = param_val - opmath_t(step_size) * exp_avg_var / denom_val;
  }
}

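// adam_math (full-precision overload): enabled for float and double parameters.
// All math is done directly in scalar_t, so no widening conversions are needed.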
template <typename scalar_t, typename opmath_t, ADAM_MODE adam_mode>
typename std::enable_if<
    std::is_same<scalar_t, float>::value || std::is_same<scalar_t, double>::value,
    void>::
    type inline adam_math(
  scalar_t* param_ptr,
  scalar_t* exp_avg_ptr,
  scalar_t* exp_avg_sq_ptr,
  scalar_t* grad_ptr,
  scalar_t* max_exp_avg_sq_ptr,
  double lr,
  double bias_correction1,
  double bias_correction2,
  double exp_avg_grad_coefficient,
  double exp_avg_sq_grad_coefficient,
  double bias_correction2_sqrt,
  double eps,
  double weight_decay,
  double beta2,
  bool amsgrad,
  bool maximize,
  const float* grad_scale_ptr,
  int64_t size
){
  double step_size = lr / bias_correction1;
  using Vec = at::vec::Vectorized<scalar_t>;
  int64_t d = 0;
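  // Vectorized main loop; the remainder is handled by the scalar tail loop below.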
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec param_vec = Vec::loadu(param_ptr + d);
    Vec grad_vec = Vec::loadu(grad_ptr + d);
    if (grad_scale_ptr) {
      grad_vec = grad_vec / Vec(scalar_t(*grad_scale_ptr));
      Vec grad_vec_to_store = grad_vec;
      grad_vec_to_store.store(grad_ptr + d);
    }
    if (maximize) grad_vec = grad_vec * Vec(scalar_t(-1.0));
    if (weight_decay != 0.f){
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad_vec += param_vec * Vec(scalar_t(weight_decay));
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param_vec = param_vec * Vec(scalar_t(1 - lr * weight_decay));
      }
    }
    Vec exp_avg_vec = Vec::loadu(exp_avg_ptr + d);
    // exp_avg.lerp_(grad, 1 - beta1)
    const Vec lerp_weight = Vec(scalar_t(exp_avg_grad_coefficient));
    auto mask = lerp_weight.abs() < Vec(0.5);
    auto coeff = Vec::blendv(lerp_weight - Vec(1), lerp_weight, mask);
    auto base = Vec::blendv(grad_vec, exp_avg_vec, mask);
    exp_avg_vec = vec::fmadd(coeff, grad_vec - exp_avg_vec, base);

    Vec exp_avg_sq_vec = Vec::loadu(exp_avg_sq_ptr + d) * Vec(scalar_t(beta2)) +
        Vec(scalar_t(exp_avg_sq_grad_coefficient)) * grad_vec * grad_vec;
    exp_avg_vec.store(exp_avg_ptr + d);
    exp_avg_sq_vec.store(exp_avg_sq_ptr + d);

    Vec denom_vec;
    if (amsgrad) {
      Vec max_exp_avg_sq_vec =
          maximum(Vec::loadu(max_exp_avg_sq_ptr + d), exp_avg_sq_vec);
      max_exp_avg_sq_vec.store(max_exp_avg_sq_ptr + d);
      denom_vec =
          (max_exp_avg_sq_vec.sqrt() / Vec(scalar_t(bias_correction2_sqrt))) + Vec(scalar_t(eps));
    } else {
      denom_vec =
          (exp_avg_sq_vec.sqrt() / Vec(scalar_t(bias_correction2_sqrt))) + Vec(scalar_t(eps));
    }
    param_vec = param_vec + Vec(scalar_t(-step_size)) * exp_avg_vec / denom_vec;
    param_vec.store(param_ptr + d);
  }
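  // Scalar tail loop for the remaining elements.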
  for (; d < size; d++) {
    scalar_t grad_val = grad_ptr[d];
    if (grad_scale_ptr) {
      grad_val = grad_ptr[d] / scalar_t(*grad_scale_ptr);
      grad_ptr[d] = grad_val;
    }
    if (maximize) grad_val = -grad_val;
    if (weight_decay != 0.f){
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad_val += param_ptr[d] * scalar_t(weight_decay);
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param_ptr[d] = param_ptr[d] * scalar_t(1 - lr * weight_decay);
      }
    }
    // exp_avg.lerp_(grad, 1 - beta1)
    auto is_lerp_weight_small = std::abs(scalar_t(exp_avg_grad_coefficient)) < scalar_t(0.5);
    if (is_lerp_weight_small) {
      exp_avg_ptr[d] = exp_avg_ptr[d] + scalar_t(exp_avg_grad_coefficient) * (grad_val - exp_avg_ptr[d]);
    } else {
      exp_avg_ptr[d] = grad_val - (grad_val - exp_avg_ptr[d]) * (scalar_t(1) - scalar_t(exp_avg_grad_coefficient));
    }
    exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] * scalar_t(beta2);
    exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] +
        scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
    scalar_t denom_val;
    if (amsgrad) {
      max_exp_avg_sq_ptr[d] =
          std::max(max_exp_avg_sq_ptr[d], exp_avg_sq_ptr[d]);
      denom_val =
          std::sqrt(max_exp_avg_sq_ptr[d]) / scalar_t(bias_correction2_sqrt) + scalar_t(eps);
    } else {
      denom_val = std::sqrt(exp_avg_sq_ptr[d]) / scalar_t(bias_correction2_sqrt) + scalar_t(eps);
    }
    param_ptr[d] = param_ptr[d] - scalar_t(step_size) * exp_avg_ptr[d] / denom_val;
  }
}

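// Per-tensor fused Adam step: reads the step count, precomputes the bias-correction
// terms in double (to match the non-fused implementation), then updates the tensor
// in cache-line-aligned chunks in parallel via adam_math.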
template <typename scalar_t, ADAM_MODE adam_mode>
void adam_fused_step_impl(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& exp_avg,
    const at::Tensor& exp_avg_sq,
    const at::Tensor& max_exp_avg_sq,
    const at::Tensor& state_step,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const float* grad_scale_ptr) {
  using opmath_t = at::opmath_type<scalar_t>;
  double step = state_step.item<float>();
  scalar_t* param_data = param.data_ptr<scalar_t>();
  scalar_t* exp_avg_data = exp_avg.data_ptr<scalar_t>();
  scalar_t* exp_avg_sq_data = exp_avg_sq.data_ptr<scalar_t>();
  scalar_t* max_exp_avg_sq_data = amsgrad ? max_exp_avg_sq.data_ptr<scalar_t>() : nullptr;
  scalar_t* grad_data = grad.data_ptr<scalar_t>();

  // need to use double here to align with non-fused adam
  double bias_correction1 = 1 - std::pow(beta1, step);
  double bias_correction2 = 1 - std::pow(beta2, step);
  double exp_avg_grad_coefficient = 1 - beta1;
  double exp_avg_sq_grad_coefficient = 1 - beta2;
  double bias_correction2_sqrt = std::sqrt(bias_correction2);

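  // Partition the work into cache-line-aligned chunks so that concurrent threads do
  // not write to the same cache line across task boundaries (avoids false sharing).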
  constexpr size_t cache_line_size = 64;
  constexpr int64_t cache_line_aligned_task_unit = cache_line_size / sizeof(scalar_t);
  size_t num_units = divup(param.numel(), cache_line_aligned_task_unit);

  auto adam_fn = [&](int64_t begin, int64_t end) {
    // local pointers
    begin *= cache_line_aligned_task_unit;
    end = std::min(end * cache_line_aligned_task_unit, param.numel());
    scalar_t* param_ptr = param_data + begin;
    scalar_t* exp_avg_ptr = exp_avg_data + begin;
    scalar_t* exp_avg_sq_ptr = exp_avg_sq_data + begin;
    scalar_t* grad_ptr = grad_data + begin;
    scalar_t* max_exp_avg_sq_ptr = amsgrad ? max_exp_avg_sq_data + begin : nullptr;

    const int64_t size = end - begin;
    adam_math<scalar_t, opmath_t, adam_mode>(
      param_ptr,
      exp_avg_ptr,
      exp_avg_sq_ptr,
      grad_ptr,
      max_exp_avg_sq_ptr,
      lr,
      bias_correction1,
      bias_correction2,
      exp_avg_grad_coefficient,
      exp_avg_sq_grad_coefficient,
      bias_correction2_sqrt,
      eps,
      weight_decay,
      beta2,
      amsgrad,
      maximize,
      grad_scale_ptr,
      size
    );
  };
  at::parallel_for(
      0, num_units, 0, adam_fn);
}

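// Dispatch stub entry point: selects the scalar type and ADAM_MODE (ORIGINAL vs.
// ADAMW) and forwards to adam_fused_step_impl.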
void fused_adam_kernel(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& exp_avg,
    const at::Tensor& exp_avg_sq,
    const at::Tensor& max_exp_avg_sq,
    const at::Tensor& state_step,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const float* grad_scale_ptr,
    const ADAM_MODE adam_mode
  ) {
  Tensor grad_contiguous = grad.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, param.scalar_type(), "fused_adam_kernel", [&] {
    if(adam_mode == ADAM_MODE::ORIGINAL){
      adam_fused_step_impl<scalar_t, ADAM_MODE::ORIGINAL>(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, state_step, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_ptr);
    } else {
      adam_fused_step_impl<scalar_t, ADAM_MODE::ADAMW>(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, state_step, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_ptr);
    }

  });
}

} // anonymous namespace

REGISTER_DISPATCH(fused_adam_stub, &fused_adam_kernel);
} // namespace at::native