#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/FusedAdagrad.h>
#include <ATen/Dispatch.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
namespace at::native {

namespace {

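// Both adagrad_math specializations below implement the same per-element
// Adagrad update; they differ only in whether the arithmetic is widened to
// an opmath type (Half/BFloat16 inputs) or done directly in scalar_t
// (float/double inputs):
//   grad   = grad / grad_scale               (if grad_scale_ptr is non-null;
//                                             the unscaled grad is written back)
//   grad   = maximize ? -grad : grad
//   grad  += weight_decay * param            (if weight_decay != 0)
//   sum   += grad * grad
//   param -= clr * grad / (sqrt(sum) + eps)
// where clr is the step-decayed learning rate computed by the caller.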
template <typename scalar_t, typename opmath_t>
typename std::enable_if<
    std::is_same<scalar_t, Half>::value || std::is_same<scalar_t, BFloat16>::value,
    void>::type
inline adagrad_math(
  scalar_t* param_ptr,
  scalar_t* grad_ptr,
  scalar_t* state_sum_ptr,
  const double clr,
  const double eps,
  const double weight_decay,
  const bool maximize,
  const float* grad_scale_ptr,
  int64_t size
){
  using lpVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<opmath_t>;
  int64_t d = 0;
  // Vectorized main loop: load low-precision values, widen to opmath_t
  // (float) for the arithmetic, and store the results back as scalar_t.
  for (; d < size - (size % lpVec::size()); d += lpVec::size()) {
    lpVec param_lpvec = lpVec::loadu(param_ptr + d);
    auto [param_vec1, param_vec2] = vec::convert_to_float<scalar_t>(param_lpvec);
    lpVec grad_lpvec = lpVec::loadu(grad_ptr + d);
    auto [grad_vec1, grad_vec2] = vec::convert_to_float<scalar_t>(grad_lpvec);
    if (grad_scale_ptr) {
      // Unscale the gradient and write the unscaled value back to grad.
      grad_vec1 = grad_vec1 / fVec(float(*grad_scale_ptr));
      grad_vec2 = grad_vec2 / fVec(float(*grad_scale_ptr));
      lpVec grad_vec_to_store = vec::convert_from_float<scalar_t>(grad_vec1, grad_vec2);
      grad_vec_to_store.store(grad_ptr + d);
    }
    if (maximize) {
      grad_vec1 = grad_vec1 * fVec(opmath_t(-1.0));
      grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
    }
    if (weight_decay != 0.0) {
      grad_vec1 += param_vec1 * fVec(opmath_t(weight_decay));
      grad_vec2 += param_vec2 * fVec(opmath_t(weight_decay));
    }
    auto [state_sum_vec1, state_sum_vec2] = vec::convert_to_float<scalar_t>(lpVec::loadu(state_sum_ptr + d));
    state_sum_vec1 += grad_vec1 * grad_vec1;
    state_sum_vec2 += grad_vec2 * grad_vec2;
    vec::convert_from_float<scalar_t>(state_sum_vec1, state_sum_vec2).store(state_sum_ptr + d);

    fVec std_vec1 = state_sum_vec1.sqrt() + fVec(opmath_t(eps));
    fVec std_vec2 = state_sum_vec2.sqrt() + fVec(opmath_t(eps));
    param_vec1 = param_vec1 - fVec(opmath_t(clr)) * grad_vec1 / std_vec1;
    param_vec2 = param_vec2 - fVec(opmath_t(clr)) * grad_vec2 / std_vec2;
    vec::convert_from_float<scalar_t>(param_vec1, param_vec2).store(param_ptr + d);
  }
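  // Scalar tail: process the remaining elements with the same update,
  // still accumulating in opmath_t.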
  for (; d < size; d++) {
    opmath_t grad_val = grad_ptr[d];
    opmath_t param_val = param_ptr[d];
    if (grad_scale_ptr) {
      grad_val = grad_ptr[d] / opmath_t(*grad_scale_ptr);
      grad_ptr[d] = grad_val;
    }
    if (maximize) grad_val = -grad_val;
    if (weight_decay != 0.0) {
      grad_val += param_val * opmath_t(weight_decay);
    }
    opmath_t state_sum_val = state_sum_ptr[d];
    state_sum_val += grad_val * grad_val;
    state_sum_ptr[d] = state_sum_val;
    opmath_t std_val = std::sqrt(state_sum_val) + opmath_t(eps);
    param_val -= opmath_t(clr) * grad_val / std_val;
    param_ptr[d] = param_val;
  }
}

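// float/double specialization: arithmetic is done directly in scalar_t, so
// no convert_to_float / convert_from_float round trip is needed.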
template <typename scalar_t, typename opmath_t>
typename std::enable_if<
    std::is_same<scalar_t, float>::value || std::is_same<scalar_t, double>::value,
    void>::type
inline adagrad_math(
  scalar_t* param_ptr,
  scalar_t* grad_ptr,
  scalar_t* state_sum_ptr,
  const double clr,
  const double eps,
  const double weight_decay,
  const bool maximize,
  const float* grad_scale_ptr,
  int64_t size
){
  using Vec = at::vec::Vectorized<scalar_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec param_vec = Vec::loadu(param_ptr + d);
    Vec grad_vec = Vec::loadu(grad_ptr + d);
    if (grad_scale_ptr) {
      // Unscale the gradient and write the unscaled value back to grad.
      grad_vec = grad_vec / Vec(scalar_t(*grad_scale_ptr));
      Vec grad_vec_to_store = grad_vec;
      grad_vec_to_store.store(grad_ptr + d);
    }
    if (maximize) grad_vec = grad_vec * Vec(scalar_t(-1.0));
    if (weight_decay != 0.0) {
      grad_vec += param_vec * Vec(scalar_t(weight_decay));
    }

    Vec sum_vec = Vec::loadu(state_sum_ptr + d) + grad_vec * grad_vec;
    sum_vec.store(state_sum_ptr + d);

    Vec std_vec = sum_vec.sqrt() + Vec(scalar_t(eps));
    param_vec = param_vec - Vec(scalar_t(clr)) * grad_vec / std_vec;
    param_vec.store(param_ptr + d);
  }
  // Scalar tail for the trailing elements.
  scalar_t grad_val_to_store;
  for (; d < size; d++) {
    scalar_t grad_val = grad_ptr[d];
    if (grad_scale_ptr) {
      grad_val = grad_ptr[d] / scalar_t(*grad_scale_ptr);
      grad_val_to_store = grad_val;
      grad_ptr[d] = grad_val_to_store;
    }
    if (maximize) grad_val = -grad_val;
    if (weight_decay != 0.0) {
      grad_val += param_ptr[d] * scalar_t(weight_decay);
    }
    state_sum_ptr[d] += grad_val * grad_val;

    scalar_t std_val = std::sqrt(state_sum_ptr[d]) + scalar_t(eps);
    param_ptr[d] -= scalar_t(clr) * grad_val / std_val;
  }
}

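// Per-dtype driver: computes the step-decayed learning rate
//   clr = lr / (1 + (step - 1) * lr_decay),
// carves the flat tensors into cache-line-aligned chunks, and applies
// adagrad_math to each chunk in parallel.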
template <typename scalar_t>
void adagrad_fused_step_impl(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& state_sum,
    const at::Tensor& state_step,
    const double lr,
    const double lr_decay,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const float* grad_scale_ptr) {
  using opmath_t = at::opmath_type<scalar_t>;
  scalar_t* param_data = param.data_ptr<scalar_t>();
  scalar_t* grad_data = grad.data_ptr<scalar_t>();
  scalar_t* state_sum_data = state_sum.data_ptr<scalar_t>();
  double step = state_step.item<float>();
  double clr = lr / (1.0 + (step - 1.0) * lr_decay);

  constexpr size_t cache_line_size = 64;
  constexpr int64_t cache_line_aligned_task_unit = cache_line_size / sizeof(scalar_t);
  size_t num_units = divup(param.numel(), cache_line_aligned_task_unit);

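  // Each task unit is one cache-line-sized run of elements (to limit false
  // sharing between threads); the lambda below converts unit indices back
  // into element offsets.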
  auto adagrad_fn = [&](int64_t begin, int64_t end) {
    // local pointers
    begin *= cache_line_aligned_task_unit;
    end = std::min(end * cache_line_aligned_task_unit, param.numel());
    scalar_t* param_ptr = param_data + begin;
    scalar_t* grad_ptr = grad_data + begin;
    scalar_t* state_sum_ptr = state_sum_data + begin;

    const int64_t size = end - begin;
    adagrad_math<scalar_t, opmath_t>(
      param_ptr,
      grad_ptr,
      state_sum_ptr,
      clr,
      eps,
      weight_decay,
      maximize,
      grad_scale_ptr,
      size
    );
  };
  at::parallel_for(
      0, num_units, 0, adagrad_fn);
}

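// Kernel entry point: dispatches on the parameter dtype (float, double,
// Half, BFloat16) and forwards to the per-dtype implementation above.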
void fused_adagrad_kernel(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& state_sum,
    const at::Tensor& state_step,
    const double lr,
    const double lr_decay,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const float* grad_scale_ptr
  ) {
  Tensor grad_contiguous = grad.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, param.scalar_type(), "fused_adagrad_kernel", [&] {
    adagrad_fused_step_impl<scalar_t>(
        param,
        grad,
        state_sum,
        state_step,
        lr,
        lr_decay,
        weight_decay,
        eps,
        maximize,
        grad_scale_ptr);
  });
}

} // anonymous namespace

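// Register the CPU implementation with the fused_adagrad_stub dispatch table.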
REGISTER_DISPATCH(fused_adagrad_stub, &fused_adagrad_kernel);
} // namespace at::native