// aten/src/ATen/native/cpu/WeightNormKernel.cpp
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/core/TensorBase.h>

#include <ATen/Dispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/Parallel.h>
#include <ATen/OpMathType.h>
#include <ATen/native/cpu/WeightNormKernel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>

namespace at::native {

namespace {

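// Forward pass for dim == 0: treat v as a 2d view of shape [M, N], where M is
// the size of dim 0 and N collapses the remaining dims. For each row i,
//   norm[i] = || v[i, :] ||_2
//   w[i, :] = (g[i] / norm[i]) * v[i, :]
// Rows are independent, so the outer loop is parallelized over M.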
template <typename scalar_t, typename accscalar_t>
void weight_norm_first_dim_kernel(
    TensorBase& w,
    TensorBase& norm,
    const TensorBase& v,
    const TensorBase& g,
    int64_t M, int64_t N) {
  const auto v_data = v.data_ptr<scalar_t>();
  const auto g_data = g.data_ptr<scalar_t>();
  auto w_data = w.data_ptr<scalar_t>();
  auto norm_data = norm.data_ptr<accscalar_t>();

  using Vec = vec::Vectorized<accscalar_t>;
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      accscalar_t norm_val = vec::map_reduce_all<scalar_t>(
          [](Vec x) { return x * x; },
          [](Vec x, Vec y) { return x + y; },
          v_data + i * N,
          N);
      norm_val = std::sqrt(norm_val);
      norm_data[i] = norm_val;

      accscalar_t a = g_data[i] / norm_val;
      vec::map(
          [a](Vec x) { return x * Vec(a); },
          w_data + i * N,
          v_data + i * N,
          N);
    }
  });
}

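// Accumulate squared values of one row of v into a per-thread accumulator row:
//   out[j] += v[j] * v[j]
// The BFloat16 overload widens each bf16 vector into two float vectors so the
// accumulation happens in float, avoiding bf16 rounding error in the sum.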
template <typename scalar_t>
inline void sum_norm_per_row(
    scalar_t* out_ptr,
    const scalar_t* v_ptr,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  vec::map2(
      [](Vec out, Vec v) { return out + v * v; },
      out_ptr,
      out_ptr,
      v_ptr,
      size);
}

inline void sum_norm_per_row(
    float* out_ptr,
    const BFloat16* v_ptr,
    int64_t size) {
  using bVec = vec::Vectorized<BFloat16>;
  using fVec = vec::Vectorized<float>;
  int64_t d = 0;
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec v_bvec = bVec::loadu(v_ptr + d);
    auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);

    fVec out_fvec0 = fVec::loadu(out_ptr + d) + v_fvec0 * v_fvec0;
    fVec out_fvec1 = fVec::loadu(out_ptr + d + fVec::size()) + v_fvec1 * v_fvec1;
    out_fvec0.store(out_ptr + d);
    out_fvec1.store(out_ptr + d + fVec::size());
  }
  for (; d < size; ++d) {
    float v_val = float(v_ptr[d]);
    out_ptr[d] += v_val * v_val;
  }
}

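// Scale one row of v by the per-column coefficients a = g / norm:
//   w[j] = v[j] * a[j]
// In the BFloat16 overload, a is kept in float and the product is computed in
// float before converting back to bf16 for the store.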
template <typename scalar_t>
inline void apply_norm_per_row(
    scalar_t* w_ptr,
    const scalar_t* v_ptr,
    const scalar_t* a_ptr,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  vec::map2(
      [](Vec v, Vec a) { return v * a; },
      w_ptr,
      v_ptr,
      a_ptr,
      size);
}

inline void apply_norm_per_row(
    BFloat16* w_ptr,
    const BFloat16* v_ptr,
    const float* a_ptr,
    int64_t size) {
  using bVec = vec::Vectorized<BFloat16>;
  using fVec = vec::Vectorized<float>;
  int64_t d = 0;
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec v_bvec = bVec::loadu(v_ptr + d);
    auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);

    fVec w_fvec0 = fVec::loadu(a_ptr + d) * v_fvec0;
    fVec w_fvec1 = fVec::loadu(a_ptr + d + fVec::size()) * v_fvec1;
    bVec w_bvec = convert_float_bfloat16(w_fvec0, w_fvec1);
    w_bvec.store(w_ptr + d);
  }
  for (; d < size; ++d) {
    w_ptr[d] = float(v_ptr[d]) * a_ptr[d];
  }
}

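// Forward pass for dim == v.dim() - 1: treat v as [M, N] with N the size of
// the last dim. The norm reduces over rows, so it runs in two steps: each
// thread accumulates partial sums of squares into its own row of a
// [num_threads, N] buffer, then the rows are combined and square-rooted into
// norm. The first buffer row is then reused to hold g / norm, which is applied
// to every row of v in a second parallel pass.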
template <typename scalar_t, typename accscalar_t>
void weight_norm_last_dim_kernel(
    TensorBase& w,
    TensorBase& norm,
    const TensorBase& v,
    const TensorBase& g,
    int64_t M, int64_t N) {
  const auto v_data = v.data_ptr<scalar_t>();
  const auto g_data = g.data_ptr<scalar_t>();
  auto w_data = w.data_ptr<scalar_t>();
  auto norm_data = norm.data_ptr<accscalar_t>();

  int num_threads = at::get_num_threads();
  TensorBase buffer = at::detail::empty_cpu({num_threads, N}, norm.options()).zero_();
  auto buffer_data = buffer.data_ptr<accscalar_t>();

  // vertical parallel reduction
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads, "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    auto buffer_ptr = buffer_data + tid * N;
    for (const auto i : c10::irange(begin, end)) {
      sum_norm_per_row(buffer_ptr, v_data + i * N, N);
    }
  });

  for (const auto j : c10::irange(N)) {
    accscalar_t sum = 0;
    for (const auto t : c10::irange(num_threads)) {
      sum += buffer_data[t * N + j];
    }
    norm_data[j] = std::sqrt(sum);
  }

  // reuse the first row of buffer to store g / norm
  vec::convert(g_data, buffer_data, N);
  using Vec = vec::Vectorized<accscalar_t>;
  vec::map2(
      [](Vec g, Vec norm) { return g / norm; },
      buffer_data,
      buffer_data,
      norm_data,
      N);

  // apply w = v * (g / norm)
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      apply_norm_per_row(w_data + i * N, v_data + i * N, buffer_data, N);
    }
  });
}

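// Backward pass for dim == 0. For each row i, with sum = dot(grad_w[i], v[i]):
//   grad_g[i] = sum / norm[i]
//   grad_v[i] = (g[i] / norm[i]) * grad_w[i] - (g[i] * sum / norm[i]^3) * v[i]
// The per-row scalars a = g / norm and b = a * grad_g / norm are computed once
// and then applied with a single vectorized map over the row.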
template <typename scalar_t, typename accscalar_t>
void weight_norm_backward_first_dim_kernel(
    TensorBase& grad_v,
    TensorBase& grad_g,
    const TensorBase& grad_w,
    const TensorBase& saved_v,
    const TensorBase& saved_g,
    const TensorBase& saved_norm,
    int64_t M, int64_t N) {
  const auto grad_w_data = grad_w.data_ptr<scalar_t>();
  const auto saved_v_data = saved_v.data_ptr<scalar_t>();
  const auto saved_g_data = saved_g.data_ptr<scalar_t>();
  const auto saved_norm_data = saved_norm.data_ptr<accscalar_t>();
  auto grad_v_data = grad_v.data_ptr<scalar_t>();
  auto grad_g_data = grad_g.data_ptr<scalar_t>();

  using Vec = vec::Vectorized<accscalar_t>;
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      accscalar_t per_dim_sum_val = vec::map2_reduce_all<scalar_t>(
          [](Vec grad_w, Vec saved_v) { return grad_w * saved_v; },
          [](Vec x, Vec y) { return x + y; },
          grad_w_data + i * N,
          saved_v_data + i * N,
          N);

      accscalar_t saved_norm_val = saved_norm_data[i];
      accscalar_t saved_g_val = accscalar_t(saved_g_data[i]);
      accscalar_t grad_g_val = per_dim_sum_val / saved_norm_val;

      // grad_g = sum / norm
      // grad_v = (g / norm) * (grad_w - v * (sum / norm^2))
      //  let a = g / norm
      //      b = a * grad_g / norm
      // grad_v = a * grad_w - b * v
      grad_g_data[i] = scalar_t(grad_g_val);
      accscalar_t a = saved_g_val / saved_norm_val;
      accscalar_t b = a * grad_g_val / saved_norm_val;

      vec::map2(
          [a, b](Vec grad_w, Vec v) { return Vec(a) * grad_w - Vec(b) * v; },
          grad_v_data + i * N,
          grad_w_data + i * N,
          saved_v_data + i * N,
          N);
    }
  });
}

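// Accumulate the elementwise product of one row of grad_w and v into a
// per-thread accumulator row:
//   out[j] += grad_w[j] * v[j]
// As with sum_norm_per_row, the BFloat16 overload widens to float before
// multiplying and accumulating.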
template <typename scalar_t>
inline void sum_product_per_row(
    scalar_t* out_ptr,
    const scalar_t* grad_w_ptr,
    const scalar_t* v_ptr,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  vec::map3(
      [](Vec out, Vec grad_w, Vec v) { return out + grad_w * v; },
      out_ptr,
      out_ptr,
      grad_w_ptr,
      v_ptr,
      size);
}

inline void sum_product_per_row(
    float* out_ptr,
    const BFloat16* grad_w_ptr,
    const BFloat16* v_ptr,
    int64_t size) {
  using bVec = vec::Vectorized<BFloat16>;
  using fVec = vec::Vectorized<float>;
  int64_t d = 0;
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec grad_w_bvec = bVec::loadu(grad_w_ptr + d);
    auto [grad_w_fvec0, grad_w_fvec1] = convert_bfloat16_float(grad_w_bvec);
    bVec v_bvec = bVec::loadu(v_ptr + d);
    auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);

    fVec out_fvec0 = fVec::loadu(out_ptr + d) + grad_w_fvec0 * v_fvec0;
    fVec out_fvec1 = fVec::loadu(out_ptr + d + fVec::size()) + grad_w_fvec1 * v_fvec1;
    out_fvec0.store(out_ptr + d);
    out_fvec1.store(out_ptr + d + fVec::size());
  }
  for (; d < size; ++d) {
    float grad_w_val = float(grad_w_ptr[d]);
    float v_val = float(v_ptr[d]);
    out_ptr[d] += grad_w_val * v_val;
  }
}

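// Compute one row of grad_v from the per-column coefficients a and b:
//   grad_v[j] = a[j] * grad_w[j] - b[j] * v[j]
// The BFloat16 overload reads a and b as float, does the arithmetic in float,
// and converts the result back to bf16 for the store.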
template <typename scalar_t>
inline void apply_per_row_backward(
    scalar_t* grad_v_ptr,
    const scalar_t* grad_w_ptr,
    const scalar_t* v_ptr,
    const scalar_t* a_ptr,
    const scalar_t* b_ptr,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  vec::map4(
      [](Vec grad_w, Vec v, Vec a, Vec b) { return a * grad_w - b * v; },
      grad_v_ptr,
      grad_w_ptr,
      v_ptr,
      a_ptr,
      b_ptr,
      size);
}

inline void apply_per_row_backward(
    BFloat16* grad_v_ptr,
    const BFloat16* grad_w_ptr,
    const BFloat16* v_ptr,
    const float* a_ptr,
    const float* b_ptr,
    int64_t size) {
  using bVec = vec::Vectorized<BFloat16>;
  using fVec = vec::Vectorized<float>;
  int64_t d = 0;
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec grad_w_bvec = bVec::loadu(grad_w_ptr + d);
    auto [grad_w_fvec0, grad_w_fvec1] = convert_bfloat16_float(grad_w_bvec);
    bVec v_bvec = bVec::loadu(v_ptr + d);
    auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);

    fVec grad_v_fvec0 = fVec::loadu(a_ptr + d) * grad_w_fvec0 - fVec::loadu(b_ptr + d) * v_fvec0;
    fVec grad_v_fvec1 = fVec::loadu(a_ptr + d + fVec::size()) * grad_w_fvec1
        - fVec::loadu(b_ptr + d + fVec::size()) * v_fvec1;
    bVec grad_v_bvec = convert_float_bfloat16(grad_v_fvec0, grad_v_fvec1);
    grad_v_bvec.store(grad_v_ptr + d);
  }
  for (; d < size; ++d) {
    grad_v_ptr[d] = float(grad_w_ptr[d]) * a_ptr[d] - float(v_ptr[d]) * b_ptr[d];
  }
}

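// Backward pass for dim == v.dim() - 1. The column-wise sum of grad_w * v is
// reduced in parallel through a per-thread buffer (as in the forward last-dim
// kernel), then the per-column coefficients
//   a = g / norm  and  b = a * grad_g / norm
// are derived from it and applied row by row: grad_v = a * grad_w - b * v.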
template <typename scalar_t, typename accscalar_t>
void weight_norm_backward_last_dim_kernel(
    TensorBase& grad_v,
    TensorBase& grad_g,
    const TensorBase& grad_w,
    const TensorBase& saved_v,
    const TensorBase& saved_g,
    const TensorBase& saved_norm,
    int64_t M, int64_t N) {
  const auto grad_w_data = grad_w.data_ptr<scalar_t>();
  const auto saved_v_data = saved_v.data_ptr<scalar_t>();
  const auto saved_g_data = saved_g.data_ptr<scalar_t>();
  const auto saved_norm_data = saved_norm.data_ptr<accscalar_t>();
  auto grad_v_data = grad_v.data_ptr<scalar_t>();
  auto grad_g_data = grad_g.data_ptr<scalar_t>();

  // the temp buffer is used twice:
  // 1. vertical reduction from [M, N] to [T, N]
  // 2. storing the intermediate results for `sum`, `a` and `b`,
  //    so it needs at least 3 rows
  //
  int num_threads = at::get_num_threads();
  int K = std::max(3, num_threads);
  TensorBase buffer = at::detail::empty_cpu({K, N}, saved_norm.options()).zero_();
  auto buffer_data = buffer.data_ptr<accscalar_t>();

  // vertical parallel reduction
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads, "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    auto buffer_ptr = buffer_data + tid * N;
    for (const auto i : c10::irange(begin, end)) {
      sum_product_per_row(buffer_ptr, grad_w_data + i * N, saved_v_data + i * N, N);
    }
  });

  // store the result in the first row of buffer
  for (const auto j : c10::irange(N)) {
    accscalar_t sum = 0;
    for (const auto t : c10::irange(num_threads)) {
      sum += buffer_data[t * N + j];
    }
    buffer_data[j] = sum;
  }

  // reuse the 1st row of buffer to store the sum
  // 2nd row to store coefficient a
  // 3rd row to store coefficient b
  accscalar_t* per_dim_sum = buffer_data;
  accscalar_t* a = buffer_data + N;
  accscalar_t* b = buffer_data + 2 * N;

  // a = g / norm
  // b = a * grad_g / norm
  for (const auto j : c10::irange(N)) {
    accscalar_t saved_norm_val = saved_norm_data[j];
    accscalar_t saved_g_val = accscalar_t(saved_g_data[j]);
    accscalar_t grad_g_val = per_dim_sum[j] / saved_norm_val;
    grad_g_data[j] = scalar_t(grad_g_val);

    a[j] = saved_g_val / saved_norm_val;
    b[j] = a[j] * grad_g_val / saved_norm_val;
  }

  // apply grad_v = a * grad_w - b * v
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      apply_per_row_backward(
          grad_v_data + i * N,
          grad_w_data + i * N,
          saved_v_data + i * N,
          a,
          b,
          N);
    }
  });
}

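// Dispatcher for the forward stub: flatten v to [M, N] according to dim
// (dim 0 keeps dim 0 as rows; the last dim keeps it as columns) and dispatch
// on dtype. Float and double accumulate in their own type; BFloat16
// accumulates in float via at::opmath_type.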
void weight_norm_kernel(
    TensorBase& w,
    TensorBase& norm,
    const TensorBase& v,
    const TensorBase& g,
    int64_t dim) {
  TORCH_INTERNAL_ASSERT(dim == 0 || dim == v.dim() - 1,
      "fused kernels can only be applied for first or last dim");
  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, v.scalar_type(),
      "weight_norm_kernel", [&]() {
    using accscalar_t = at::opmath_type<scalar_t>;
    if (dim == 0) {
      int64_t M = v.size(0);
      int64_t N = v.numel() / M;
      weight_norm_first_dim_kernel<scalar_t, accscalar_t>(w, norm, v, g, M, N);
    } else {
      int64_t N = v.size(-1);
      int64_t M = v.numel() / N;
      weight_norm_last_dim_kernel<scalar_t, accscalar_t>(w, norm, v, g, M, N);
    }
  });
}

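// Dispatcher for the backward stub: mirrors weight_norm_kernel, choosing the
// first-dim or last-dim backward kernel based on dim and dispatching on the
// dtype of saved_v.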
void weight_norm_backward_kernel(
    TensorBase& grad_v,
    TensorBase& grad_g,
    const TensorBase& grad_w,
    const TensorBase& saved_v,
    const TensorBase& saved_g,
    const TensorBase& saved_norm,
    int64_t dim) {
  TORCH_INTERNAL_ASSERT(dim == 0 || dim == saved_v.dim() - 1,
      "fused kernels can only be applied for first or last dim");
  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, saved_v.scalar_type(),
      "weight_norm_backward_kernel", [&]() {
    using accscalar_t = at::opmath_type<scalar_t>;
    if (dim == 0) {
      int64_t M = saved_v.size(0);
      int64_t N = saved_v.numel() / M;
      weight_norm_backward_first_dim_kernel<scalar_t, accscalar_t>(grad_v, grad_g, grad_w, saved_v, saved_g, saved_norm, M, N);
    } else {
      int64_t N = saved_v.size(-1);
      int64_t M = saved_v.numel() / N;
      weight_norm_backward_last_dim_kernel<scalar_t, accscalar_t>(grad_v, grad_g, grad_w, saved_v, saved_g, saved_norm, M, N);
    }
  });
}

} // anonymous namespace

REGISTER_DISPATCH(weight_norm_stub, &weight_norm_kernel);
REGISTER_DISPATCH(weight_norm_backward_stub, &weight_norm_backward_kernel);

} // at::native