#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/core/TensorBase.h>

#include <ATen/Dispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/Parallel.h>
#include <ATen/OpMathType.h>
#include <ATen/native/cpu/WeightNormKernel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>

namespace at::native {

namespace {

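// Forward, dim == 0: view v as [M, N] with M = v.size(0). For each row i,
// compute norm[i] = ||v[i]|| (accumulated in accscalar_t) and
// w[i] = v[i] * (g[i] / norm[i]). Rows are independent, so the outer loop
// over M is parallelized directly.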
template <typename scalar_t, typename accscalar_t>
void weight_norm_first_dim_kernel(
    TensorBase& w,
    TensorBase& norm,
    const TensorBase& v,
    const TensorBase& g,
    int64_t M, int64_t N) {
  const auto v_data = v.data_ptr<scalar_t>();
  const auto g_data = g.data_ptr<scalar_t>();
  auto w_data = w.data_ptr<scalar_t>();
  auto norm_data = norm.data_ptr<accscalar_t>();

  using Vec = vec::Vectorized<accscalar_t>;
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      accscalar_t norm_val = vec::map_reduce_all<scalar_t>(
          [](Vec x) { return x * x; },
          [](Vec x, Vec y) { return x + y; },
          v_data + i * N,
          N);
      norm_val = std::sqrt(norm_val);
      norm_data[i] = norm_val;

      accscalar_t a = g_data[i] / norm_val;
      vec::map(
          [a](Vec x) { return x * Vec(a); },
          w_data + i * N,
          v_data + i * N,
          N);
    }
  });
}

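// Accumulate the squares of one row into the per-thread buffer:
// out[j] += v[j] * v[j]. Used by the vertical reduction in
// weight_norm_last_dim_kernel.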
template <typename scalar_t>
inline void sum_norm_per_row(
    scalar_t* out_ptr,
    const scalar_t* v_ptr,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  vec::map2(
      [](Vec out, Vec v) { return out + v * v; },
      out_ptr,
      out_ptr,
      v_ptr,
      size);
}

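// BFloat16 specialization: load bf16 lanes, widen to float, and accumulate
// the squares into a float buffer so the reduction keeps full precision.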
inline void sum_norm_per_row(
    float* out_ptr,
    const BFloat16* v_ptr,
    int64_t size) {
  using bVec = vec::Vectorized<BFloat16>;
  using fVec = vec::Vectorized<float>;
  int64_t d = 0;
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec v_bvec = bVec::loadu(v_ptr + d);
    auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);

    fVec out_fvec0 = fVec::loadu(out_ptr + d) + v_fvec0 * v_fvec0;
    fVec out_fvec1 = fVec::loadu(out_ptr + d + fVec::size()) + v_fvec1 * v_fvec1;
    out_fvec0.store(out_ptr + d);
    out_fvec1.store(out_ptr + d + fVec::size());
  }
  for (; d < size; ++d) {
    float v_val = float(v_ptr[d]);
    out_ptr[d] += v_val * v_val;
  }
}

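// Scale one row of v by the per-column coefficients a = g / norm:
// w[j] = v[j] * a[j]. The BFloat16 overload below keeps the coefficients
// in float and converts back to bf16 only when storing w.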
template <typename scalar_t>
inline void apply_norm_per_row(
    scalar_t* w_ptr,
    const scalar_t* v_ptr,
    const scalar_t* a_ptr,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  vec::map2(
      [](Vec v, Vec a) { return v * a; },
      w_ptr,
      v_ptr,
      a_ptr,
      size);
}

inline void apply_norm_per_row(
    BFloat16* w_ptr,
    const BFloat16* v_ptr,
    const float* a_ptr,
    int64_t size) {
  using bVec = vec::Vectorized<BFloat16>;
  using fVec = vec::Vectorized<float>;
  int64_t d = 0;
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec v_bvec = bVec::loadu(v_ptr + d);
    auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);

    fVec w_fvec0 = fVec::loadu(a_ptr + d) * v_fvec0;
    fVec w_fvec1 = fVec::loadu(a_ptr + d + fVec::size()) * v_fvec1;
    bVec w_bvec = convert_float_bfloat16(w_fvec0, w_fvec1);
    w_bvec.store(w_ptr + d);
  }
  for (; d < size; ++d) {
    w_ptr[d] = float(v_ptr[d]) * a_ptr[d];
  }
}

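// Forward, dim == v.dim() - 1: view v as [M, N] with N = v.size(-1).
// The column-wise norm is computed in two passes: each thread accumulates
// partial sums of squares into its own row of a [num_threads, N] buffer,
// then the partials are combined and square-rooted into norm. The first
// buffer row is reused to hold g / norm, which is then broadcast over all
// M rows to produce w = v * (g / norm).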
template <typename scalar_t, typename accscalar_t>
void weight_norm_last_dim_kernel(
    TensorBase& w,
    TensorBase& norm,
    const TensorBase& v,
    const TensorBase& g,
    int64_t M, int64_t N) {
  const auto v_data = v.data_ptr<scalar_t>();
  const auto g_data = g.data_ptr<scalar_t>();
  auto w_data = w.data_ptr<scalar_t>();
  auto norm_data = norm.data_ptr<accscalar_t>();

  int num_threads = at::get_num_threads();
  TensorBase buffer = at::detail::empty_cpu({num_threads, N}, norm.options()).zero_();
  auto buffer_data = buffer.data_ptr<accscalar_t>();

  // vertical parallel reduction
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads, "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    auto buffer_ptr = buffer_data + tid * N;
    for (const auto i : c10::irange(begin, end)) {
      sum_norm_per_row(buffer_ptr, v_data + i * N, N);
    }
  });

  for (const auto j : c10::irange(N)) {
    accscalar_t sum = 0;
    for (const auto t : c10::irange(num_threads)) {
      sum += buffer_data[t * N + j];
    }
    norm_data[j] = std::sqrt(sum);
  }

  // reuse the first row of buffer to store g / norm
  vec::convert(g_data, buffer_data, N);
  using Vec = vec::Vectorized<accscalar_t>;
  vec::map2(
      [](Vec g, Vec norm) { return g / norm; },
      buffer_data,
      buffer_data,
      norm_data,
      N);

  // apply w = v * (g / norm)
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      apply_norm_per_row(w_data + i * N, v_data + i * N, buffer_data, N);
    }
  });
}

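// Backward, dim == 0: with sum_i = dot(grad_w[i], v[i]),
//   grad_g[i] = sum_i / norm[i]
//   grad_v[i] = (g[i] / norm[i]) * grad_w[i] - (g[i] * sum_i / norm[i]^3) * v[i]
// Rows are independent, so the loop over M parallelizes directly.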
template <typename scalar_t, typename accscalar_t>
void weight_norm_backward_first_dim_kernel(
    TensorBase& grad_v,
    TensorBase& grad_g,
    const TensorBase& grad_w,
    const TensorBase& saved_v,
    const TensorBase& saved_g,
    const TensorBase& saved_norm,
    int64_t M, int64_t N) {
  const auto grad_w_data = grad_w.data_ptr<scalar_t>();
  const auto saved_v_data = saved_v.data_ptr<scalar_t>();
  const auto saved_g_data = saved_g.data_ptr<scalar_t>();
  const auto saved_norm_data = saved_norm.data_ptr<accscalar_t>();
  auto grad_v_data = grad_v.data_ptr<scalar_t>();
  auto grad_g_data = grad_g.data_ptr<scalar_t>();

  using Vec = vec::Vectorized<accscalar_t>;
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      accscalar_t per_dim_sum_val = vec::map2_reduce_all<scalar_t>(
          [](Vec grad_w, Vec saved_v) { return grad_w * saved_v; },
          [](Vec x, Vec y) { return x + y; },
          grad_w_data + i * N,
          saved_v_data + i * N,
          N);

      accscalar_t saved_norm_val = saved_norm_data[i];
      accscalar_t saved_g_val = accscalar_t(saved_g_data[i]);
      accscalar_t grad_g_val = per_dim_sum_val / saved_norm_val;

      // grad_g = sum / norm
      // grad_v = (g / norm) * (grad_w - v * (sum / norm^2))
      // let a = g / norm
      //     b = a * grad_g / norm
      // then grad_v = a * grad_w - b * v
      grad_g_data[i] = scalar_t(grad_g_val);
      accscalar_t a = saved_g_val / saved_norm_val;
      accscalar_t b = a * grad_g_val / saved_norm_val;

      vec::map2(
          [a, b](Vec grad_w, Vec v) { return Vec(a) * grad_w - Vec(b) * v; },
          grad_v_data + i * N,
          grad_w_data + i * N,
          saved_v_data + i * N,
          N);
    }
  });
}

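// Accumulate the elementwise product of one row of grad_w and v into the
// per-thread buffer: out[j] += grad_w[j] * v[j]. The BFloat16 overload
// widens both inputs to float before accumulating.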
template <typename scalar_t>
inline void sum_product_per_row(
    scalar_t* out_ptr,
    const scalar_t* grad_w_ptr,
    const scalar_t* v_ptr,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  vec::map3(
      [](Vec out, Vec grad_w, Vec v) { return out + grad_w * v; },
      out_ptr,
      out_ptr,
      grad_w_ptr,
      v_ptr,
      size);
}

inline void sum_product_per_row(
    float* out_ptr,
    const BFloat16* grad_w_ptr,
    const BFloat16* v_ptr,
    int64_t size) {
  using bVec = vec::Vectorized<BFloat16>;
  using fVec = vec::Vectorized<float>;
  int64_t d = 0;
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec grad_w_bvec = bVec::loadu(grad_w_ptr + d);
    auto [grad_w_fvec0, grad_w_fvec1] = convert_bfloat16_float(grad_w_bvec);
    bVec v_bvec = bVec::loadu(v_ptr + d);
    auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);

    fVec out_fvec0 = fVec::loadu(out_ptr + d) + grad_w_fvec0 * v_fvec0;
    fVec out_fvec1 = fVec::loadu(out_ptr + d + fVec::size()) + grad_w_fvec1 * v_fvec1;
    out_fvec0.store(out_ptr + d);
    out_fvec1.store(out_ptr + d + fVec::size());
  }
  for (; d < size; ++d) {
    float grad_w_val = float(grad_w_ptr[d]);
    float v_val = float(v_ptr[d]);
    out_ptr[d] += grad_w_val * v_val;
  }
}

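// Compute one row of grad_v from the per-column coefficients a and b:
// grad_v[j] = a[j] * grad_w[j] - b[j] * v[j]. As above, the BFloat16
// overload keeps a and b in float and converts only the stored result.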
template <typename scalar_t>
inline void apply_per_row_backward(
    scalar_t* grad_v_ptr,
    const scalar_t* grad_w_ptr,
    const scalar_t* v_ptr,
    const scalar_t* a_ptr,
    const scalar_t* b_ptr,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  vec::map4(
      [](Vec grad_w, Vec v, Vec a, Vec b) { return a * grad_w - b * v; },
      grad_v_ptr,
      grad_w_ptr,
      v_ptr,
      a_ptr,
      b_ptr,
      size);
}

inline void apply_per_row_backward(
    BFloat16* grad_v_ptr,
    const BFloat16* grad_w_ptr,
    const BFloat16* v_ptr,
    const float* a_ptr,
    const float* b_ptr,
    int64_t size) {
  using bVec = vec::Vectorized<BFloat16>;
  using fVec = vec::Vectorized<float>;
  int64_t d = 0;
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec grad_w_bvec = bVec::loadu(grad_w_ptr + d);
    auto [grad_w_fvec0, grad_w_fvec1] = convert_bfloat16_float(grad_w_bvec);
    bVec v_bvec = bVec::loadu(v_ptr + d);
    auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);

    fVec grad_v_fvec0 = fVec::loadu(a_ptr + d) * grad_w_fvec0 - fVec::loadu(b_ptr + d) * v_fvec0;
    fVec grad_v_fvec1 = fVec::loadu(a_ptr + d + fVec::size()) * grad_w_fvec1
        - fVec::loadu(b_ptr + d + fVec::size()) * v_fvec1;
    bVec grad_v_bvec = convert_float_bfloat16(grad_v_fvec0, grad_v_fvec1);
    grad_v_bvec.store(grad_v_ptr + d);
  }
  for (; d < size; ++d) {
    grad_v_ptr[d] = float(grad_w_ptr[d]) * a_ptr[d] - float(v_ptr[d]) * b_ptr[d];
  }
}

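// Backward, dim == v.dim() - 1: the per-column sum of grad_w * v is built
// with the same two-pass vertical reduction as the forward kernel. The
// buffer needs at least 3 rows because, after the reduction, its first
// three rows are reused for the column-wise sum and the coefficients
// a = g / norm and b = a * grad_g / norm consumed by apply_per_row_backward.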
template <typename scalar_t, typename accscalar_t>
void weight_norm_backward_last_dim_kernel(
    TensorBase& grad_v,
    TensorBase& grad_g,
    const TensorBase& grad_w,
    const TensorBase& saved_v,
    const TensorBase& saved_g,
    const TensorBase& saved_norm,
    int64_t M, int64_t N) {
  const auto grad_w_data = grad_w.data_ptr<scalar_t>();
  const auto saved_v_data = saved_v.data_ptr<scalar_t>();
  const auto saved_g_data = saved_g.data_ptr<scalar_t>();
  const auto saved_norm_data = saved_norm.data_ptr<accscalar_t>();
  auto grad_v_data = grad_v.data_ptr<scalar_t>();
  auto grad_g_data = grad_g.data_ptr<scalar_t>();

  // the temp buffer is used twice:
  //   1. vertical reduction from [M, N] to [T, N]
  //   2. storage for the intermediate data of `sum`, `a` and `b`,
  //      so it needs at least 3 rows
  int num_threads = at::get_num_threads();
  int K = std::max(3, num_threads);
  TensorBase buffer = at::detail::empty_cpu({K, N}, saved_norm.options()).zero_();
  auto buffer_data = buffer.data_ptr<accscalar_t>();

  // vertical parallel reduction
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads, "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    auto buffer_ptr = buffer_data + tid * N;
    for (const auto i : c10::irange(begin, end)) {
      sum_product_per_row(buffer_ptr, grad_w_data + i * N, saved_v_data + i * N, N);
    }
  });

  // store the combined result on the first row of buffer
  for (const auto j : c10::irange(N)) {
    accscalar_t sum = 0;
    for (const auto t : c10::irange(num_threads)) {
      sum += buffer_data[t * N + j];
    }
    buffer_data[j] = sum;
  }

  // reuse the 1st row of buffer to store the sum,
  // the 2nd row to store coefficient a,
  // and the 3rd row to store coefficient b
  accscalar_t* per_dim_sum = buffer_data;
  accscalar_t* a = buffer_data + N;
  accscalar_t* b = buffer_data + 2 * N;

  // a = g / norm
  // b = a * grad_g / norm
  for (const auto j : c10::irange(N)) {
    accscalar_t saved_norm_val = saved_norm_data[j];
    accscalar_t saved_g_val = accscalar_t(saved_g_data[j]);
    accscalar_t grad_g_val = per_dim_sum[j] / saved_norm_val;
    grad_g_data[j] = scalar_t(grad_g_val);

    a[j] = saved_g_val / saved_norm_val;
    b[j] = a[j] * grad_g_val / saved_norm_val;
  }

  // apply grad_v = a * grad_w - b * v
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      apply_per_row_backward(
          grad_v_data + i * N,
          grad_w_data + i * N,
          saved_v_data + i * N,
          a,
          b,
          N);
    }
  });
}

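// Entry point for the forward stub: flattens v into a 2D view [M, N] around
// the requested dim (which must be the first or the last) and dispatches to
// the matching kernel. BFloat16 inputs accumulate in float via opmath_type.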
void weight_norm_kernel(
    TensorBase& w,
    TensorBase& norm,
    const TensorBase& v,
    const TensorBase& g,
    int64_t dim) {
  TORCH_INTERNAL_ASSERT(dim == 0 || dim == v.dim() - 1,
      "fused kernels can only be applied for first or last dim");
  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, v.scalar_type(),
      "weight_norm_kernel", [&]() {
    using accscalar_t = at::opmath_type<scalar_t>;
    if (dim == 0) {
      int64_t M = v.size(0);
      int64_t N = v.numel() / M;
      weight_norm_first_dim_kernel<scalar_t, accscalar_t>(w, norm, v, g, M, N);
    } else {
      int64_t N = v.size(-1);
      int64_t M = v.numel() / N;
      weight_norm_last_dim_kernel<scalar_t, accscalar_t>(w, norm, v, g, M, N);
    }
  });
}

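// Entry point for the backward stub: same layout handling as the forward
// entry point, dispatching on whether dim is the first or the last.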
void weight_norm_backward_kernel(
    TensorBase& grad_v,
    TensorBase& grad_g,
    const TensorBase& grad_w,
    const TensorBase& saved_v,
    const TensorBase& saved_g,
    const TensorBase& saved_norm,
    int64_t dim) {
  TORCH_INTERNAL_ASSERT(dim == 0 || dim == saved_v.dim() - 1,
      "fused kernels can only be applied for first or last dim");
  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, saved_v.scalar_type(),
      "weight_norm_backward_kernel", [&]() {
    using accscalar_t = at::opmath_type<scalar_t>;
    if (dim == 0) {
      int64_t M = saved_v.size(0);
      int64_t N = saved_v.numel() / M;
      weight_norm_backward_first_dim_kernel<scalar_t, accscalar_t>(grad_v, grad_g, grad_w, saved_v, saved_g, saved_norm, M, N);
    } else {
      int64_t N = saved_v.size(-1);
      int64_t M = saved_v.numel() / N;
      weight_norm_backward_last_dim_kernel<scalar_t, accscalar_t>(grad_v, grad_g, grad_w, saved_v, saved_g, saved_norm, M, N);
    }
  });
}

} // anonymous namespace

REGISTER_DISPATCH(weight_norm_stub, &weight_norm_kernel);
REGISTER_DISPATCH(weight_norm_backward_stub, &weight_norm_backward_kernel);

} // at::native