1 #include <memory>
2 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
3 #include <ATen/native/cpu/SoftmaxKernel.h>
4 
5 #include <algorithm>
6 #include <iterator>
7 #include <numeric>
8 
9 #include <ATen/Dispatch.h>
10 #include <ATen/Parallel.h>
11 #include <ATen/TensorIterator.h>
12 #include <ATen/OpMathType.h>
13 #include <ATen/core/Tensor.h>
14 #include <ATen/cpu/vec/functional.h>
15 #include <ATen/cpu/vec/vec.h>
16 #include <c10/util/irange.h>
18 
19 // [Note AVX-SSE transitions] In general we avoid calls into cmath for code
20 // compiled with AVX/AVX2. This is because of SSE-AVX transitions and a bug in
21 // Glibc 2.23. See https://bugs.launchpad.net/ubuntu/+source/glibc/+bug/1663280
22 //
23 // On grainsize: the grainsize is chosen so that each task performs roughly
24 // GRAIN_SIZE computations. Each task works across dim_size elements. 16 is a
25 // very rough estimate of the number of computations per dim_size element,
26 // counting simple computations (*, +, -) as 1 and exp or log as 4.
27 //
28 // We use a chunk size such that it'd fit in L1D.
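// As a rough illustration of that estimate, with GRAIN_SIZE == 32768 and ~16
// computations per element, one "grain" of work corresponds to roughly
// 32768 / 16 = 2048 elements.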
29 
30 namespace at::native {
31 
32 namespace {
33 template <typename scalar_t>
34 inline void _vec_log_softmax_lastdim(
35     const scalar_t* input_data_base,
36     scalar_t* output_data_base,
37     int64_t outer_size,
38     int64_t dim_size) {
39   using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
40   // Coincidentally, at::internal::GRAIN_SIZE is 32768, which is equal to the
41   // size of L1D cache on many processors. Some processors have 48 KB L1D cache
42   // nowadays, so maybe in the future, we can leverage the knowledge of a
43   // machine's L1D cache size.
44   int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
45       1,
46       at::internal::GRAIN_SIZE / (sizeof(scalar_t) * dim_size));
47   int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, outer_size);
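  // Worked example (illustrative values, not a guarantee): for float (4 bytes)
  // with dim_size = 1024, MAX_CHUNK_SIZE = 32768 / (4 * 1024) = 8, so a chunk
  // of 8 rows (~32 KB) fits in a typical 32 KB L1D cache.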
48   // Note: grain_size value of 0
49   // We don't change the number of OpenMP threads in the OpenMP thread-pool,
50   // so with a coarse grain size some threads would do useful work while
51   // others sit idle. We simply use a grain_size of 0 and rely upon
52   // invoke_parallel to distribute work among threads in an equitable manner.
53   // We compute CHUNK_SIZE to ensure each thread's computations are efficient.
54   parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
55     // MSVC requires such a declaration of dynamic arrays
56     // Source: https://stackoverflow.com/a/33423538
57     auto tmp_sum_scalar = std::make_unique<scalar_t[]>(CHUNK_SIZE);
58     auto max_input_arr = std::make_unique<scalar_t[]>(CHUNK_SIZE);
59     for (int64_t ii = begin; ii < end; ii += CHUNK_SIZE) {
60       int64_t loop_end = CHUNK_SIZE;
61       if (ii + CHUNK_SIZE > end)
62         loop_end = end - ii;
63       for (const auto j : c10::irange(loop_end)) {
64         int64_t i = ii + j;
65         const scalar_t* input_data = input_data_base + i * dim_size;
66         max_input_arr[j] = vec::reduce_all<scalar_t>(
67             [](Vec& x, Vec& y) { return vec::maximum(x, y); },
68             input_data,
69             dim_size);
70       }
71       for (const auto j : c10::irange(loop_end)) {
72         int64_t i = ii + j;
73         const scalar_t* input_data = input_data_base + i * dim_size;
74         scalar_t max_input = max_input_arr[j];
75         tmp_sum_scalar[j] = vec::map_reduce_all<scalar_t>(
76             [max_input](Vec x) { return (x - Vec(max_input)).exp(); },
77             [](Vec x, Vec y) { return x + y; },
78             input_data,
79             dim_size);
80       }
81       // See [Note AVX-SSE transitions] for why this should call the
82       // vectorized version (aside from perf improvements).
83       vec::map(
84           [](Vec x) { return x.log(); },
85           tmp_sum_scalar.get(),
86           tmp_sum_scalar.get(),
87           loop_end);
88       for (const auto j : c10::irange(loop_end)) {
89         int64_t i = ii + j;
90         const scalar_t* input_data = input_data_base + i * dim_size;
91         scalar_t* output_data = output_data_base + i * dim_size;
92         scalar_t tmp_sum = tmp_sum_scalar[j];
93         scalar_t max_input = max_input_arr[j];
94 
95         // It's necessary to keep the order of the operations below.
96         // When the input values have large magnitude and their differences
97         // are small, computing `max_input` plus `tmp_sum` first would cause
98         // a numerical problem (loss of precision). See an example in
99         // https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379
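        // That is, compute (x - max_input) - tmp_sum rather than
        // x - (max_input + tmp_sum): with e.g. max_input = 1e8f and
        // tmp_sum ~ 1.0f, the sum 1e8f + 1.0f already rounds back to 1e8f
        // in float32, so the contribution of tmp_sum would be lost.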
100         vec::map(
101             [tmp_sum, max_input](Vec x) {
102               return x - Vec(max_input) - Vec(tmp_sum);
103             },
104             output_data,
105             input_data,
106             dim_size);
107       }
108     }
109   });
110 }
111 
112 template<typename scalar_t>
113 inline typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
114 _vec_softmax_lastdim(
115     const scalar_t* input_data_base,
116     scalar_t* output_data_base,
117     int64_t outer_size,
118     int64_t dim_size) {
119   using Vec = vec::Vectorized<scalar_t>;
120   // See Note: grain_size value of 0
121   parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
122     for (const auto i : c10::irange(begin, end)) {
123       const scalar_t* input_data = input_data_base + i * dim_size;
124       scalar_t* output_data = output_data_base + i * dim_size;
125       scalar_t max_input = vec::reduce_all<scalar_t>(
126           [](Vec& x, Vec& y) { return vec::maximum(x, y); },
127           input_data,
128           dim_size);
129       vec::map(
130           [max_input](Vec x) { return (x - Vec(max_input)).exp(); },
131           output_data,
132           input_data,
133           dim_size);
134       scalar_t tmp_sum = vec::reduce_all<scalar_t>(
135           [](Vec x, Vec y) { return x + y; }, output_data, dim_size);
136       tmp_sum = 1 / tmp_sum;
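      // Compute the reciprocal once and multiply in the map below; one division
      // here replaces dim_size divisions.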
137       vec::map(
138           [tmp_sum](Vec x) { return x * Vec(tmp_sum); },
139           output_data,
140           output_data,
141           dim_size);
142     }
143   });
144 }
145 
146 template<typename scalar_t>
147 inline typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
148 _vec_softmax_lastdim(
149     const scalar_t* input_data_base,
150     scalar_t* output_data_base,
151     int64_t outer_size,
152     int64_t dim_size) {
153   using Vec = vec::Vectorized<scalar_t>;
154   using fVec = vec::Vectorized<float>;
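  // For reduced-precision scalar_t (BFloat16/Half), Vec holds twice as many
  // lanes as fVec, so one Vec load converts into two fVec values via
  // convert_to_float (and two fVec values convert back via convert_from_float).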
155   // See Note: grain_size value of 0
156   parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
157     // thread local temp buffer.
158     auto buffer = std::make_unique<float []>(dim_size);
159     float* buffer_data = buffer.get();
160 
161     for (const auto i : c10::irange(begin, end)) {
162       const scalar_t* input_data = input_data_base + i * dim_size;
163       scalar_t* output_data = output_data_base + i * dim_size;
164       // reduce to max and cache float input data
165       fVec max_fvec = fVec(-std::numeric_limits<float>::infinity());
166       int64_t d0 = 0;
167       for (; d0 < dim_size - (dim_size % Vec::size()); d0 += Vec::size()) {
168         Vec data_vec = Vec::loadu(input_data + d0);
169         auto [data_fvec0, data_fvec1] = vec::convert_to_float<scalar_t>(data_vec);
170         max_fvec = vec::maximum(max_fvec, data_fvec0);
171         max_fvec = vec::maximum(max_fvec, data_fvec1);
172         data_fvec0.store(buffer_data + d0);
173         data_fvec1.store(buffer_data + d0 + fVec::size());
174       }
175       float max_val = vec::vec_reduce_all([](fVec& x, fVec& y) { return vec::maximum(x, y); }, max_fvec);
176       for (; d0 < dim_size; d0++) {
177         float data_val = input_data[d0];
178         max_val = std::max(max_val, data_val);
179         buffer_data[d0] = data_val;
180       }
181 
182       // map (x - max).exp() and reduce to sum
183       fVec sum_fvec = fVec(float(0));
184       int64_t d1 = 0;
185       for (; d1 < dim_size - (dim_size % fVec::size()); d1 += fVec::size()) {
186         fVec data_fvec = (fVec::loadu(buffer_data + d1) - fVec(max_val)).exp();
187         sum_fvec += data_fvec;
188         data_fvec.store(buffer_data + d1);
189       }
190       float sum_val = vec::vec_reduce_all([](fVec& x, fVec& y) { return x + y; }, sum_fvec);
191       for (; d1 < dim_size; d1++) {
192         float data_val = std::exp(buffer_data[d1] - max_val);
193         sum_val += data_val;
194         buffer_data[d1] = data_val;
195       }
196 
197       sum_val = 1 / sum_val;
198       int64_t d2 = 0;
199       for (; d2 < dim_size - (dim_size % Vec::size()); d2 += Vec::size()) {
200         fVec out_fvec0 = fVec::loadu(buffer_data + d2) * fVec(sum_val);
201         fVec out_fvec1 = fVec::loadu(buffer_data + d2 + fVec::size()) * fVec(sum_val);
202         Vec out_vec = vec::convert_from_float<scalar_t>(out_fvec0, out_fvec1);
203         out_vec.store(output_data + d2);
204       }
205       for (; d2 < dim_size; d2++) {
206         output_data[d2] = scalar_t(buffer_data[d2] * sum_val);
207       }
208     }
209   });
210 }
211 
212 template <typename scalar_t, bool log_softmax>
213 inline void _vec_host_softmax_backward_lastdim(
214     scalar_t* grad_input_data_base,
215     const scalar_t* grad_data_base,
216     const scalar_t* output_data_base,
217     int64_t outer_size,
218     int64_t dim_size) {
219   using Vec = vec::Vectorized<at::opmath_type<scalar_t>>;
220   // See Note: grain_size value of 0
221   parallel_for(
222       0,
223       outer_size,
224       0,
225       [&](int64_t begin, int64_t end) {
226         for (const auto i : c10::irange(begin, end)) {
227           scalar_t* grad_input_data = grad_input_data_base + i * dim_size;
228           const scalar_t* grad_data = grad_data_base + i * dim_size;
229           const scalar_t* output_data = output_data_base + i * dim_size;
230           // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
231           scalar_t sum;
232           if (log_softmax) {
233             sum = vec::reduce_all<scalar_t>(
234                 [](Vec& x, Vec& y) { return x + y; }, grad_data, dim_size);
235           } else {
236             sum = vec::map2_reduce_all<scalar_t>(
237                 [](Vec x, Vec y) { return x * y; },
238                 [](Vec x, Vec y) { return x + y; },
239                 grad_data,
240                 output_data,
241                 dim_size);
242           }
243           if (log_softmax) {
244             vec::map2(
245                 [sum](Vec x, Vec y) { return x - ((y.exp()) * Vec(sum)); },
246                 grad_input_data,
247                 grad_data,
248                 output_data,
249                 dim_size);
250           } else {
251             vec::map2(
252                 [sum](Vec x, Vec y) { return (x - Vec(sum)) * y; },
253                 grad_input_data,
254                 grad_data,
255                 output_data,
256                 dim_size);
257           }
258         }
259       });
260 }
261 
262 template<typename scalar_t>
263 inline typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
264 _vec_softmax_backward(
265     scalar_t* grad_input_data_base,
266     const scalar_t* grad_output_data_base,
267     const scalar_t* output_data_base,
268     int64_t outer_size,
269     int64_t inner_size,
270     int64_t dim_size) {
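  // Softmax backward along dim: grad_input = output * (grad_output - sum),
  // where sum = sum over dim of (grad_output * output). The two passes below
  // compute the vertical sum and then the elementwise result per chunk.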
271   using Vec = vec::Vectorized<scalar_t>;
272   int64_t outer_stride = dim_size * inner_size;
273   int64_t BLOCK_SIZE = 128 * 1024;
274   int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
275       BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
276   MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
277   int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
278   int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
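  // Example of the blocking arithmetic (illustrative values): for float with
  // dim_size = 128, MAX_CHUNK_SIZE = 128 * 1024 / 128 / 4 = 256 inner elements,
  // so one {dim_size, CHUNK_SIZE} block is 128 * 256 * 4 bytes = 128 KB,
  // matching the BLOCK_SIZE budget chosen to stay L2-resident.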
279   // See Note: grain_size value of 0
280   parallel_for(
281       0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
282         // thread local temp buffer that holds vertical sum result
283         auto buffer = std::make_unique<scalar_t[]>(CHUNK_SIZE);
284         scalar_t* tmp_sum_data = buffer.get();
285 
286         for (int64_t i = begin; i < end; i++) {
287           int64_t outer_idx = i / num_chunks;
288           int64_t k = i % num_chunks;
289           int64_t inner_idx_begin = k * CHUNK_SIZE;
290           int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);
291 
292           // init
293           Vec zero_vec = Vec(scalar_t(0));
294           int64_t d0 = 0;
295           for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
296             zero_vec.store(tmp_sum_data + d0);
297           }
298           for (; d0 < size; d0++) {
299             tmp_sum_data[d0] = scalar_t(0);
300           }
301 
302           // compute sum of grad_output * output
303           for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
304             int64_t offset = outer_idx * outer_stride + dim_idx * inner_size +
305                 inner_idx_begin;
306             const scalar_t* grad_output_ptr = grad_output_data_base + offset;
307             const scalar_t* output_ptr = output_data_base + offset;
308 
309             int64_t d1 = 0;
310             for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
311               Vec grad_output_vec = Vec::loadu(grad_output_ptr + d1);
312               Vec output_vec = Vec::loadu(output_ptr + d1);
313               Vec sum_vec = Vec::loadu(tmp_sum_data + d1);
314               sum_vec += grad_output_vec * output_vec;
315               sum_vec.store(tmp_sum_data + d1);
316             }
317             for (; d1 < size; d1++) {
318               tmp_sum_data[d1] += grad_output_ptr[d1] * output_ptr[d1];
319             }
320           }
321 
322           // compute output * (grad_output - sum)
323           for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
324             int64_t offset = outer_idx * outer_stride + dim_idx * inner_size +
325                 inner_idx_begin;
326             const scalar_t* grad_output_ptr = grad_output_data_base + offset;
327             const scalar_t* output_ptr = output_data_base + offset;
328             scalar_t* grad_input_ptr = grad_input_data_base + offset;
329 
330             int64_t d2 = 0;
331             for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
332               Vec grad_output_vec = Vec::loadu(grad_output_ptr + d2);
333               Vec output_vec = Vec::loadu(output_ptr + d2);
334               Vec sum_vec = Vec::loadu(tmp_sum_data + d2);
335               Vec grad_input_vec = output_vec * (grad_output_vec - sum_vec);
336               grad_input_vec.store(grad_input_ptr + d2);
337             }
338             for (; d2 < size; d2++) {
339               grad_input_ptr[d2] = output_ptr[d2] * (grad_output_ptr[d2] - tmp_sum_data[d2]);
340             }
341           }
342         }
343       });
344 }
345 
346 template<typename scalar_t>
347 inline typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
348 _vec_softmax_backward(
349     scalar_t* grad_input_data_base,
350     const scalar_t* grad_output_data_base,
351     const scalar_t* output_data_base,
352     int64_t outer_size,
353     int64_t inner_size,
354     int64_t dim_size) {
355   using Vec = vec::Vectorized<scalar_t>;
356   using fVec = vec::Vectorized<float>;
357   int64_t outer_stride = dim_size * inner_size;
358   int64_t BLOCK_SIZE = 128 * 1024;
359   int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
360       BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
361   MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
362   int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
363   int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
364   // See Note: grain_size value of 0
365   parallel_for(
366       0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
367         // thread local temp buffer that holds vertical sum result
368         auto buffer = std::make_unique<float[]>(CHUNK_SIZE);
369         float* tmp_sum_data = buffer.get();
370 
371         // thread local buffer that holds grad_output and output data in float32
372         auto grad_output_buffer = std::make_unique<float[]>(dim_size * CHUNK_SIZE);
373         float* grad_output_buffer_data = grad_output_buffer.get();
374 
375         auto output_buffer = std::make_unique<float[]>(dim_size * CHUNK_SIZE);
376         float* output_buffer_data = output_buffer.get();
377 
378         for (int64_t i = begin; i < end; i++) {
379           int64_t outer_idx = i / num_chunks;
380           int64_t k = i % num_chunks;
381           int64_t inner_idx_begin = k * CHUNK_SIZE;
382           int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);
383 
384           // init
385           fVec zero_fvec = fVec(float(0));
386           int64_t d0 = 0;
387           for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
388             zero_fvec.store(tmp_sum_data + d0);
389             zero_fvec.store(tmp_sum_data + d0 + fVec::size());
390           }
391           for (; d0 < size; d0++) {
392             tmp_sum_data[d0] = float(0);
393           }
394 
395           // compute sum of grad_output * output
396           for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
397             int64_t offset = outer_idx * outer_stride + dim_idx * inner_size +
398                 inner_idx_begin;
399             const scalar_t* grad_output_ptr = grad_output_data_base + offset;
400             const scalar_t* output_ptr = output_data_base + offset;
401             float* grad_output_buffer_ptr =
402                 grad_output_buffer_data + dim_idx * CHUNK_SIZE;
403             float* output_buffer_ptr =
404                 output_buffer_data + dim_idx * CHUNK_SIZE;
405 
406             int64_t d1 = 0;
407             for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
408               Vec grad_output_vec = Vec::loadu(grad_output_ptr + d1);
409               auto [grad_output_fvec0, grad_output_fvec1] =
410                   vec::convert_to_float<scalar_t>(grad_output_vec);
411               Vec output_vec = Vec::loadu(output_ptr + d1);
412               auto [output_fvec0, output_fvec1] =
413                   vec::convert_to_float<scalar_t>(output_vec);
414               fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d1);
415               fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d1 + fVec::size());
416               sum_fvec0 += grad_output_fvec0 * output_fvec0;
417               sum_fvec1 += grad_output_fvec1 * output_fvec1;
418               sum_fvec0.store(tmp_sum_data + d1);
419               sum_fvec1.store(tmp_sum_data + d1 + fVec::size());
420 
421               // cache the 'converted' float grad_output and output
422               grad_output_fvec0.store(grad_output_buffer_ptr + d1);
423               grad_output_fvec1.store(
424                   grad_output_buffer_ptr + d1 + fVec::size());
425               output_fvec0.store(output_buffer_ptr + d1);
426               output_fvec1.store(output_buffer_ptr + d1 + fVec::size());
427             }
428             for (; d1 < size; d1++) {
429               float grad_output_val = float(grad_output_ptr[d1]);
430               float output_val = float(output_ptr[d1]);
431               tmp_sum_data[d1] += grad_output_val * output_val;
432               grad_output_buffer_ptr[d1] = grad_output_val;
433               output_buffer_ptr[d1] = output_val;
434             }
435           }
436 
437           // compute output * (grad_output - sum)
438           for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
439             scalar_t* grad_input_ptr = grad_input_data_base +
440                 outer_idx * outer_stride + dim_idx * inner_size +
441                 inner_idx_begin;
442             float* grad_output_buffer_ptr =
443                 grad_output_buffer_data + dim_idx * CHUNK_SIZE;
444             float* output_buffer_ptr =
445                 output_buffer_data + dim_idx * CHUNK_SIZE;
446 
447             int64_t d2 = 0;
448             for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
449               fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2);
450               fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size());
451               fVec grad_output_fvec0 = fVec::loadu(grad_output_buffer_ptr + d2);
452               fVec grad_output_fvec1 =
453                   fVec::loadu(grad_output_buffer_ptr + d2 + fVec::size());
454               fVec output_fvec0 = fVec::loadu(output_buffer_ptr + d2);
455               fVec output_fvec1 =
456                   fVec::loadu(output_buffer_ptr + d2 + fVec::size());
457               fVec grad_input_fvec0 =
458                   output_fvec0 * (grad_output_fvec0 - sum_fvec0);
459               fVec grad_input_fvec1 =
460                   output_fvec1 * (grad_output_fvec1 - sum_fvec1);
461               Vec grad_input_vec =
462                   vec::convert_from_float<scalar_t>(grad_input_fvec0, grad_input_fvec1);
463               grad_input_vec.store(grad_input_ptr + d2);
464             }
465             for (; d2 < size; d2++) {
466               grad_input_ptr[d2] = output_buffer_ptr[d2] * (grad_output_buffer_ptr[d2] - tmp_sum_data[d2]);
467             }
468           }
469         }
470       });
471 }
472 
473 template<typename scalar_t>
474 inline typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
475 _vec_log_softmax_backward(
476     scalar_t* grad_input_data_base,
477     const scalar_t* grad_output_data_base,
478     const scalar_t* output_data_base,
479     int64_t outer_size,
480     int64_t inner_size,
481     int64_t dim_size) {
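  // log_softmax backward along dim: grad_input = grad_output - exp(output) * sum,
  // where sum = sum over dim of grad_output. Since output holds log_softmax
  // values, exp(output) recovers the softmax probabilities.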
482   using Vec = vec::Vectorized<scalar_t>;
483   int64_t outer_stride = dim_size * inner_size;
484   int64_t BLOCK_SIZE = 128 * 1024;
485   int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
486       BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
487   MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
488   int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
489   int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
490   // See Note: grain_size value of 0
491   parallel_for(
492       0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
493         // thread local temp buffer that holds vertical sum result
494         auto buffer = std::make_unique<scalar_t[]>(CHUNK_SIZE);
495         scalar_t* tmp_sum_data = buffer.get();
496 
497         for (int64_t i = begin; i < end; i++) {
498           int64_t outer_idx = i / num_chunks;
499           int64_t k = i % num_chunks;
500           int64_t inner_idx_begin = k * CHUNK_SIZE;
501           int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);
502 
503           // init
504           Vec zero_vec = Vec(scalar_t(0));
505           int64_t d0 = 0;
506           for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
507             zero_vec.store(tmp_sum_data + d0);
508           }
509           for (; d0 < size; d0++) {
510             tmp_sum_data[d0] = scalar_t(0);
511           }
512 
513           // compute sum of grad_output
514           for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
515             const scalar_t* grad_output_ptr = grad_output_data_base +
516                 outer_idx * outer_stride + dim_idx * inner_size +
517                 inner_idx_begin;
518 
519             int64_t d1 = 0;
520             for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
521               Vec grad_output_vec = Vec::loadu(grad_output_ptr + d1);
522               Vec sum_vec = Vec::loadu(tmp_sum_data + d1);
523               sum_vec += grad_output_vec;
524               sum_vec.store(tmp_sum_data + d1);
525             }
526             for (; d1 < size; d1++) {
527               tmp_sum_data[d1] += grad_output_ptr[d1];
528             }
529           }
530 
531           // compute grad_output - output.exp() * sum
532           for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
533             int64_t offset = outer_idx * outer_stride + dim_idx * inner_size +
534                 inner_idx_begin;
535             const scalar_t* grad_output_ptr = grad_output_data_base + offset;
536             const scalar_t* output_ptr = output_data_base + offset;
537             scalar_t* grad_input_ptr = grad_input_data_base + offset;
538 
539             int64_t d2 = 0;
540             for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
541               Vec grad_output_vec = Vec::loadu(grad_output_ptr + d2);
542               Vec output_vec = Vec::loadu(output_ptr + d2);
543               Vec sum_vec = Vec::loadu(tmp_sum_data + d2);
544               Vec grad_input_vec = grad_output_vec - output_vec.exp() * sum_vec;
545               grad_input_vec.store(grad_input_ptr + d2);
546             }
547             for (; d2 < size; d2++) {
548               grad_input_ptr[d2] = grad_output_ptr[d2] -
549                   std::exp(output_ptr[d2]) * tmp_sum_data[d2];
550             }
551           }
552         }
553       });
554 }
555 
556 template<typename scalar_t>
557 inline typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
558 _vec_log_softmax_backward(
559     scalar_t* grad_input_data_base,
560     const scalar_t* grad_output_data_base,
561     const scalar_t* output_data_base,
562     int64_t outer_size,
563     int64_t inner_size,
564     int64_t dim_size) {
565   using Vec = vec::Vectorized<scalar_t>;
566   using fVec = vec::Vectorized<float>;
567   int64_t outer_stride = dim_size * inner_size;
568   int64_t BLOCK_SIZE = 128 * 1024;
569   int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
570       BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
571   MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
572   int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
573   int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
574   // See Note: grain_size value of 0
575   parallel_for(
576       0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
577         // thread local temp buffer that holds vertical sum result
578         auto buffer = std::make_unique<float[]>(CHUNK_SIZE);
579         float* tmp_sum_data = buffer.get();
580 
581         // thread local buffer that holds grad_output data in float32
582         auto grad_output_buffer = std::make_unique<float[]>(dim_size * CHUNK_SIZE);
583         float* grad_output_buffer_data = grad_output_buffer.get();
584 
585         for (int64_t i = begin; i < end; i++) {
586           int64_t outer_idx = i / num_chunks;
587           int64_t k = i % num_chunks;
588           int64_t inner_idx_begin = k * CHUNK_SIZE;
589           int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);
590 
591           // init
592           fVec zero_fvec = fVec(float(0));
593           int64_t d0 = 0;
594           for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
595             zero_fvec.store(tmp_sum_data + d0);
596             zero_fvec.store(tmp_sum_data + d0 + fVec::size());
597           }
598           for (; d0 < size; d0++) {
599             tmp_sum_data[d0] = float(0);
600           }
601 
602           // compute sum of grad_output
603           for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
604             const scalar_t* grad_output_ptr = grad_output_data_base +
605                 outer_idx * outer_stride + dim_idx * inner_size +
606                 inner_idx_begin;
607             float* grad_output_buffer_ptr =
608                 grad_output_buffer_data + dim_idx * CHUNK_SIZE;
609 
610             int64_t d1 = 0;
611             for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
612               Vec grad_output_vec = Vec::loadu(grad_output_ptr + d1);
613               auto [grad_output_fvec0, grad_output_fvec1] =
614                   vec::convert_to_float<scalar_t>(grad_output_vec);
615               fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d1);
616               fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d1 + fVec::size());
617               sum_fvec0 += grad_output_fvec0;
618               sum_fvec1 += grad_output_fvec1;
619               sum_fvec0.store(tmp_sum_data + d1);
620               sum_fvec1.store(tmp_sum_data + d1 + fVec::size());
621 
622               // cache the 'converted' float grad_output
623               grad_output_fvec0.store(grad_output_buffer_ptr + d1);
624               grad_output_fvec1.store(
625                   grad_output_buffer_ptr + d1 + fVec::size());
626             }
627             for (; d1 < size; d1++) {
628               float grad_output_val = float(grad_output_ptr[d1]);
629               tmp_sum_data[d1] += grad_output_val;
630               grad_output_buffer_ptr[d1] = grad_output_val;
631             }
632           }
633 
634           // compute grad_output - output.exp() * sum
635           for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
636             int64_t offset = outer_idx * outer_stride + dim_idx * inner_size +
637                 inner_idx_begin;
638             const scalar_t* output_ptr = output_data_base + offset;
639             scalar_t* grad_input_ptr = grad_input_data_base + offset;
640             float* grad_output_buffer_ptr =
641                 grad_output_buffer_data + dim_idx * CHUNK_SIZE;
642 
643             int64_t d2 = 0;
644             for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
645               Vec output_vec = Vec::loadu(output_ptr + d2);
646               auto [output_fvec0, output_fvec1] =
647                   vec::convert_to_float<scalar_t>(output_vec);
648               fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2);
649               fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size());
650               fVec grad_output_fvec0 = fVec::loadu(grad_output_buffer_ptr + d2);
651               fVec grad_output_fvec1 =
652                   fVec::loadu(grad_output_buffer_ptr + d2 + fVec::size());
653               fVec grad_input_fvec0 =
654                   grad_output_fvec0 - output_fvec0.exp() * sum_fvec0;
655               fVec grad_input_fvec1 =
656                   grad_output_fvec1 - output_fvec1.exp() * sum_fvec1;
657               Vec grad_input_vec =
658                   vec::convert_from_float<scalar_t>(grad_input_fvec0, grad_input_fvec1);
659               grad_input_vec.store(grad_input_ptr + d2);
660             }
661             for (; d2 < size; d2++) {
662               grad_input_ptr[d2] = grad_output_buffer_ptr[d2] -
663                   std::exp(float(output_ptr[d2])) * tmp_sum_data[d2];
664             }
665           }
666         }
667       });
668 }
669 
670 template <typename scalar_t, bool LogSoftMax>
671 struct vec_host_softmax_lastdim {
672   static void apply(const Tensor& output, const Tensor& input) {
673     int64_t outer_size = 1;
674     int64_t dim_size = input.size(input.ndimension() - 1);
675     for (int64_t i = 0; i < input.ndimension() - 1; ++i)
676       outer_size *= input.size(i);
677     const scalar_t* input_data_base = input.const_data_ptr<scalar_t>();
678     scalar_t* output_data_base = output.data_ptr<scalar_t>();
679     if (LogSoftMax) {
680       _vec_log_softmax_lastdim(
681           input_data_base, output_data_base, outer_size, dim_size);
682     } else {
683       _vec_softmax_lastdim(
684           input_data_base, output_data_base, outer_size, dim_size);
685     }
686   }
687 };
688 
689 template<typename scalar_t>
690 inline typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
691 _vec_softmax(
692     const scalar_t* input_data_base,
693     scalar_t* output_data_base,
694     int64_t outer_size,
695     int64_t inner_size,
696     int64_t dim_size) {
697   using Vec = vec::Vectorized<float>;
698   using Vec16 = vec::Vectorized<scalar_t>;
699   int64_t dim_stride = inner_size;
700   int64_t outer_stride = dim_size * dim_stride;
701   int vectorized_step = Vec16().size(); // Currently, we only support BFloat16/Half in this special implementation
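  // Since dim != -1 here, elements along the softmax dim are strided by
  // inner_size; vectorization is done across the contiguous inner dimension,
  // with each vector lane handling an independent softmax instance.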
702   // See Note: grain_size value of 0
703   parallel_for(
704       0, outer_size * inner_size, 0, [&](int64_t begin, int64_t end) {
705         int64_t idx = begin;
706         std::unique_ptr<float[]> temp_vec_input(new float[dim_size*vectorized_step]());
707         std::unique_ptr<float[]> temp_vec_output(new float[dim_size*vectorized_step]());
708         float* temp_vec_input_data = temp_vec_input.get();
709         float* temp_vec_output_data = temp_vec_output.get();
710         while (idx < end) {
711           int64_t outer_idx = idx / inner_size;
712           int64_t inner_idx = idx % inner_size;
713           if (((inner_idx + vectorized_step) <= inner_size) && ((idx + vectorized_step) <= end)) {
714             // Vectorization
715             const scalar_t* input_data =
716                 input_data_base + outer_idx * outer_stride + inner_idx;
717             scalar_t* output_data =
718                 output_data_base + outer_idx * outer_stride + inner_idx;
719             // Step 1: Get max Score
720             Vec16 max_vec_bf16 = Vec16::loadu(input_data);
721             std::tuple<Vec, Vec> convert_result = vec::convert_to_float<scalar_t>(max_vec_bf16);
722             Vec max_vec_o1 = std::get<0>(convert_result);
723             Vec max_vec_o2 = std::get<1>(convert_result);
724             std::get<0>(convert_result).store(temp_vec_input_data);
725             std::get<1>(convert_result).store(temp_vec_input_data + Vec().size());
726             for (const auto d : c10::irange(1, dim_size)) {
727               Vec16 input_vec_bf16 = Vec16::loadu(input_data + d * dim_stride);
728               convert_result = vec::convert_to_float<scalar_t>(input_vec_bf16);
729               max_vec_o1 = vec::maximum(max_vec_o1, std::get<0>(convert_result));
730               max_vec_o2 = vec::maximum(max_vec_o2, std::get<1>(convert_result));
731               std::get<0>(convert_result).store(temp_vec_input_data + d*vectorized_step);
732               std::get<1>(convert_result).store(temp_vec_input_data + d*vectorized_step + Vec().size());
733             }
734             // Step2: Calculate sum
735             Vec sum_vec_o1 = Vec(0.0);
736             Vec sum_vec_o2 = Vec(0.0);
737             for (const auto d : c10::irange(dim_size)) {
738               Vec output_vec_o1 = Vec::loadu(temp_vec_input_data + d*vectorized_step);
739               Vec output_vec_o2 = Vec::loadu(temp_vec_input_data + d*vectorized_step + Vec().size());
740               output_vec_o1 = (output_vec_o1 - max_vec_o1).exp();
741               output_vec_o2 = (output_vec_o2 - max_vec_o2).exp();
742               output_vec_o1.store(temp_vec_output_data + d*vectorized_step);
743               output_vec_o2.store(temp_vec_output_data + d*vectorized_step + Vec().size());
744 
745               sum_vec_o1 = sum_vec_o1 + output_vec_o1;
746               sum_vec_o2 = sum_vec_o2 + output_vec_o2;
747             }
748             // Step3: Unify
749             for (const auto d : c10::irange(dim_size)) {
750               Vec output_vec_o1 = Vec::loadu(temp_vec_output_data + d*vectorized_step);
751               Vec output_vec_o2 = Vec::loadu(temp_vec_output_data + d*vectorized_step + Vec().size());
752               output_vec_o1 = output_vec_o1/sum_vec_o1;
753               output_vec_o2 = output_vec_o2/sum_vec_o2;
754               Vec16 output_vec_bf16 = vec::convert_from_float<scalar_t>(output_vec_o1, output_vec_o2);
755               output_vec_bf16.store(output_data + d * dim_stride);
756             }
757             idx += vectorized_step;
758           } else {
759             // Tail case (scalar): exactly the same logic as host_softmax
760             // inside aten/src/ATen/native/SoftMax.cpp. Two kinds of cases
761             // fall through to this part:
762             // Case 1: at the end of a thread's total chunk there are not enough elements left to vectorize.
763             // Case 2: at the end of each inner_size row within the thread there are not enough elements left to vectorize.
764             int64_t tail_number = ((idx+vectorized_step) > end) ? /*Case1*/ (end - idx) : /*Case2*/ (inner_size - inner_idx);
765             for (const auto i : c10::irange(tail_number)) {
766               outer_idx = (idx + i) / inner_size;
767               inner_idx = (idx + i) % inner_size;
768               const scalar_t* input_data =
769                   input_data_base + outer_idx * outer_stride + inner_idx;
770               scalar_t* output_data =
771                   output_data_base + outer_idx * outer_stride + inner_idx;
772               // Step1: Get max score
773               float max_input = float(input_data[0]);
774               for (const auto d : c10::irange(1, dim_size)) {
775                 max_input = std::max(max_input, float(input_data[d * dim_stride]));
776               }
777               // Step2: Calculate the Sum
778               float sum_data = 0.0;
779               float temp_output_data = 0.0;
780               for (const auto d : c10::irange(dim_size)) {
781                 temp_output_data = std::exp(input_data[d * dim_stride] - max_input);
782                 sum_data += temp_output_data;
783                 output_data[d * dim_stride] = scalar_t(temp_output_data);
784               }
785               // Step3: Unify
786               for (const auto d : c10::irange(dim_size)) {
787                 output_data[d * dim_stride] =
788                     scalar_t(float(output_data[d * dim_stride])/sum_data);
789               }
790             }
791             idx += tail_number;
792           }
793         }
794       });
795 }
796 
797 template<typename scalar_t>
798 inline typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
799 _vec_softmax(
800     const scalar_t* input_data_base,
801     scalar_t* output_data_base,
802     int64_t outer_size,
803     int64_t inner_size,
804     int64_t dim_size) {
805   using Vec = vec::Vectorized<scalar_t>;
806   int64_t dim_stride = inner_size;
807   int64_t outer_stride = dim_size * dim_stride;
808   int vectorized_step = Vec().size();
809   // See Note: grain_size value of 0
810   parallel_for(
811       0, outer_size * inner_size, 0, [&](int64_t begin, int64_t end) {
812         int64_t idx = begin;
813         while (idx < end) {
814           int64_t outer_idx = idx / inner_size;
815           int64_t inner_idx = idx % inner_size;
816           if (((inner_idx + vectorized_step) <= inner_size) && ((idx + vectorized_step) <= end)) {
817             // Vectorization
818             const scalar_t* input_data =
819                 input_data_base + outer_idx * outer_stride + inner_idx;
820             scalar_t* output_data =
821                 output_data_base + outer_idx * outer_stride + inner_idx;
822             // Step 1: Get max Score
823             Vec max_vec = Vec::loadu(input_data);
824             for (const auto d : c10::irange(1, dim_size)) {
825               Vec input_vec = Vec::loadu(input_data + d * dim_stride);
826               max_vec = vec::maximum(max_vec, input_vec);
827             }
828             // Step2: Calculate sum
829             Vec sum_vec = Vec(0.0);
830             for (const auto d : c10::irange(dim_size)) {
831               Vec output_vec =
832                   (Vec::loadu(input_data + d * dim_stride) - max_vec).exp();
833               output_vec.store(output_data + d * dim_stride);
834               sum_vec = sum_vec + output_vec;
835             }
836             // Step3: Unify
837             for (const auto d : c10::irange(dim_size)) {
838               Vec output_vec =
839                   Vec::loadu(output_data + d * dim_stride) / sum_vec;
840               output_vec.store(output_data + d * dim_stride);
841             }
842             idx += vectorized_step;
843           } else {
844             // Tail case (scalar): exactly the same logic as host_softmax
845             // inside aten/src/ATen/native/SoftMax.cpp. Two kinds of cases
846             // fall through to this part:
847             // Case 1: at the end of a thread's total chunk there are not enough elements left to vectorize.
848             // Case 2: at the end of each inner_size row within the thread there are not enough elements left to vectorize.
849             int64_t tail_number = ((idx+vectorized_step) > end) ? /*Case1*/ (end - idx) : /*Case2*/ (inner_size - inner_idx);
850             for (const auto i : c10::irange(tail_number)) {
851               outer_idx = (idx + i) / inner_size;
852               inner_idx = (idx + i) % inner_size;
853               const scalar_t* input_data =
854                   input_data_base + outer_idx * outer_stride + inner_idx;
855               scalar_t* output_data =
856                   output_data_base + outer_idx * outer_stride + inner_idx;
857               // Step1: Get max score
858               scalar_t max_input = input_data[0];
859               for (const auto d : c10::irange(1, dim_size)) {
860                 max_input = std::max(max_input, input_data[d * dim_stride]);
861               }
862               // Step2: Calculate the Sum
863               scalar_t sum_data = 0;
864               for (const auto d : c10::irange(dim_size)) {
865                 output_data[d * dim_stride] =
866                     std::exp(input_data[d * dim_stride] - max_input);
867                 sum_data += output_data[d * dim_stride];
868               }
869               // Step3: Unify
870               for (const auto d : c10::irange(dim_size)) {
871                 output_data[d * dim_stride] =
872                     output_data[d * dim_stride]/sum_data;
873               }
874             }
875             idx += tail_number;
876           }
877         }
878       });
879 }
880 
881 // NB: fast kernel for log_softmax when dim != -1
882 // input shape is normalized to {outer_size, dim_size, inner_size}
883 //
884 // The algorithm requires loading the input tensor 3 times. To increase
885 // parallelism and cache hit rate, inner_size is blocked as:
886 //   inner_size: {CHUNK_SIZE, CHUNK_SIZE, ..., Remainder}
887 //
888 // Parallelize over {outer_size, num_chunks} and do a vertical reduction on each
889 // block of {dim_size, CHUNK_SIZE}; the block size (128 KB) is selected to be an L2 hit.
890 //
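// A rough sketch of one parallel task under this blocking:
//
//   input[outer_idx]: dim_size rows x inner_size columns
//   columns: [ chunk 0 ][ chunk 1 ] ... [ remainder ]
//   task = one (outer_idx, chunk): vertical max -> sum of exp -> log ->
//          x - max - log_sum over that {dim_size, CHUNK_SIZE} block
//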
891 template<typename scalar_t>
892 inline typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
893 _vec_logsoftmax(
894     const scalar_t* input_data_base,
895     scalar_t* output_data_base,
896     int64_t outer_size,
897     int64_t inner_size,
898     int64_t dim_size) {
899   using Vec = vec::Vectorized<scalar_t>;
900   int64_t BLOCK_SIZE = 128 * 1024;
901   int64_t MAX_CHUNK_SIZE = std::max<int64_t>(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
902   MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
903   int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
904   int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
905 
906   // See Note: grain_size value of 0
907   at::parallel_for(0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
908     // thread local temp buffer which holds vertical reduction result: max and sum.
909     auto buffer = std::make_unique<scalar_t []>(CHUNK_SIZE * 2);
910     scalar_t* input_max_data = buffer.get();
911     scalar_t* tmp_sum_data = buffer.get() + CHUNK_SIZE;
912 
913     for (int64_t i = begin; i < end; i++) {
914       int64_t outer_idx = i / num_chunks;
915       int64_t k = i % num_chunks;
916       int64_t inner_idx_begin = k * CHUNK_SIZE;
917       int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);
918 
919       // init
920       Vec zero_vec = Vec(scalar_t(0));
921       Vec min_vec = Vec(-std::numeric_limits<scalar_t>::infinity());
922       int64_t d0 = 0;
923       for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
924         min_vec.store(input_max_data + d0);
925         zero_vec.store(tmp_sum_data + d0);
926       }
927       for (; d0 < size; d0++) {
928         input_max_data[d0] = -std::numeric_limits<scalar_t>::infinity();
929         tmp_sum_data[d0] = scalar_t(0);
930       }
931 
932       // compute max
933       for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
934         const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
935             + dim_idx * inner_size + inner_idx_begin;
936 
937         int64_t d1 = 0;
938         for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
939           Vec data_vec = Vec::loadu(input_ptr + d1);
940           Vec max_vec = Vec::loadu(input_max_data + d1);
941           max_vec = Vec::blendv(max_vec, data_vec, data_vec > max_vec);
942           max_vec.store(input_max_data + d1);
943         }
944         for (; d1 < size; d1++) {
945           scalar_t data_val = input_ptr[d1];
946           scalar_t max_val = input_max_data[d1];
947           input_max_data[d1] = data_val > max_val ? data_val : max_val;
948         }
949       }
950 
951       // compute sum of (x - max).exp()
952       for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
953         const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
954             + dim_idx * inner_size + inner_idx_begin;
955 
956         int64_t d2 = 0;
957         for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
958           Vec data_vec = Vec::loadu(input_ptr + d2);
959           Vec sum_vec = Vec::loadu(tmp_sum_data + d2);
960           Vec max_vec = Vec::loadu(input_max_data + d2);
961           sum_vec += (data_vec - max_vec).exp();
962           sum_vec.store(tmp_sum_data + d2);
963         }
964         for (; d2 < size; d2++) {
965           scalar_t data_val = input_ptr[d2];
966           scalar_t max_val = input_max_data[d2];
967           tmp_sum_data[d2] += std::exp(data_val - max_val);
968         }
969       }
970 
971       // apply log
972       vec::map([](Vec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
973 
974       // compute x - max - sum
975       for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
976         int64_t offset = outer_idx * dim_size * inner_size + dim_idx * inner_size + inner_idx_begin;
977         const scalar_t* input_ptr = input_data_base + offset;
978         scalar_t* output_ptr = output_data_base + offset;
979 
980         int64_t d3 = 0;
981         for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
982           Vec data_vec = Vec::loadu(input_ptr + d3);
983           Vec max_vec = Vec::loadu(input_max_data + d3);
984           Vec sum_vec = Vec::loadu(tmp_sum_data + d3);
985           Vec out_vec = data_vec - max_vec - sum_vec;
986           out_vec.store(output_ptr + d3);
987         }
988         for (; d3 < size; d3++) {
989           output_ptr[d3] = input_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3];
990         }
991       }
992     }
993   });
994 }
995 
996 template<typename scalar_t>
997 inline typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
998 _vec_logsoftmax(
999     const scalar_t* input_data_base,
1000     scalar_t* output_data_base,
1001     int64_t outer_size,
1002     int64_t inner_size,
1003     int64_t dim_size) {
1004   using Vec = vec::Vectorized<scalar_t>;
1005   using fVec = vec::Vectorized<float>;
1006   int64_t BLOCK_SIZE = 128 * 1024;
1007   int64_t MAX_CHUNK_SIZE = std::max<int64_t>(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
1008   MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
1009   int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
1010   int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
1011 
1012   // See Note: grain_size value of 0
1013   at::parallel_for(0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
1014     auto buffer = std::make_unique<float []>(CHUNK_SIZE * 2);
1015     float* input_max_data = buffer.get();
1016     float* tmp_sum_data = buffer.get() + CHUNK_SIZE;
1017 
1018     // thread local buffer that holds input data in float32 to save the next 2 dtype conversions
1019     auto input_buffer = std::make_unique<float []>(dim_size * CHUNK_SIZE);
1020     float* input_buffer_data = input_buffer.get();
1021 
1022     // init
1023     for (int64_t i = begin; i < end; i++) {
1024       int64_t outer_idx = i / num_chunks;
1025       int64_t k = i % num_chunks;
1026       int64_t inner_idx_begin = k * CHUNK_SIZE;
1027       int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);
1028 
1029       fVec zero_fvec = fVec(float(0));
1030       fVec min_fvec = fVec(-std::numeric_limits<float>::infinity());
1031       int64_t d0 = 0;
1032       for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
1033         min_fvec.store(input_max_data + d0);
1034         min_fvec.store(input_max_data + d0 + fVec::size());
1035         zero_fvec.store(tmp_sum_data + d0);
1036         zero_fvec.store(tmp_sum_data + d0 + fVec::size());
1037       }
1038       for (; d0 < size; d0++) {
1039         input_max_data[d0] = -std::numeric_limits<float>::infinity();
1040         tmp_sum_data[d0] = float(0);
1041       }
1042 
1043       // compute max
1044       for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
1045         const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
1046             + dim_idx * inner_size + inner_idx_begin;
1047         float* input_buffer_ptr = input_buffer_data + dim_idx * CHUNK_SIZE;
1048 
1049         int64_t d1 = 0;
1050         for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
1051           Vec data_vec = Vec::loadu(input_ptr + d1);
1052           auto [data_fvec0, data_fvec1] = vec::convert_to_float<scalar_t>(data_vec);
1053           fVec max_fvec0 = fVec::loadu(input_max_data + d1);
1054           fVec max_fvec1 = fVec::loadu(input_max_data + d1 + fVec::size());
1055           max_fvec0 = fVec::blendv(max_fvec0, data_fvec0, data_fvec0 > max_fvec0);
1056           max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1);
1057           max_fvec0.store(input_max_data + d1);
1058           max_fvec1.store(input_max_data + d1 + fVec::size());
1059 
1060           // cache the 'converted' float input
1061           data_fvec0.store(input_buffer_ptr + d1);
1062           data_fvec1.store(input_buffer_ptr + d1 + fVec::size());
1063         }
1064         for (; d1 < size; d1++) {
1065           float data_val = float(input_ptr[d1]);
1066           float max_val = input_max_data[d1];
1067           input_max_data[d1] = data_val > max_val ? data_val : max_val;
1068           input_buffer_ptr[d1] = data_val;
1069         }
1070       }
1071 
1072       // compute sum of (x - max).exp()
1073       for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
1074         float* input_buffer_ptr = input_buffer_data + dim_idx * CHUNK_SIZE;
1075 
1076         int64_t d2 = 0;
1077         for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
1078           fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d2);
1079           fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d2 + fVec::size());
1080           fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2);
1081           fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size());
1082           fVec max_fvec0 = fVec::loadu(input_max_data + d2);
1083           fVec max_fvec1 = fVec::loadu(input_max_data + d2 + fVec::size());
1084           sum_fvec0 += (data_fvec0 - max_fvec0).exp();
1085           sum_fvec1 += (data_fvec1 - max_fvec1).exp();
1086           sum_fvec0.store(tmp_sum_data + d2);
1087           sum_fvec1.store(tmp_sum_data + d2 + fVec::size());
1088         }
1089         for (; d2 < size; d2++) {
1090           float data_val = input_buffer_ptr[d2];
1091           float max_val = input_max_data[d2];
1092           tmp_sum_data[d2] += std::exp(data_val - max_val);
1093         }
1094       }
1095 
1096       // apply log
1097       vec::map([](fVec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
1098 
1099       // compute x - max - sum
1100       for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
1101         float* input_buffer_ptr = input_buffer_data + dim_idx * CHUNK_SIZE;
1102         scalar_t* output_ptr = output_data_base + outer_idx * dim_size * inner_size
1103             + dim_idx * inner_size + inner_idx_begin;
1104 
1105         int64_t d3 = 0;
1106         for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
1107           fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d3);
1108           fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d3 + fVec::size());
1109           fVec max_fvec0 = fVec::loadu(input_max_data + d3);
1110           fVec max_fvec1 = fVec::loadu(input_max_data + d3 + fVec::size());
1111           fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d3);
1112           fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d3 + fVec::size());
1113           fVec out_fvec0 = data_fvec0 - max_fvec0 - sum_fvec0;
1114           fVec out_fvec1 = data_fvec1 - max_fvec1 - sum_fvec1;
1115           Vec out_vec = vec::convert_from_float<scalar_t>(out_fvec0, out_fvec1);
1116           out_vec.store(output_ptr + d3);
1117         }
1118         for (; d3 < size; d3++) {
1119           output_ptr[d3] = scalar_t(input_buffer_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]);
1120         }
1121       }
1122     }
1123   });
1124 }
1125 
1126 template <typename scalar_t, bool LogSoftMax>
1127 struct vec_softmax {
1128   static void apply(const Tensor& output, const Tensor& input, int64_t dim) {
1129     int64_t outer_size = 1;
1130     int64_t dim_size = input.size(dim);
1131     int64_t inner_size = 1;
1132     for (const auto i : c10::irange(dim)) outer_size *= input.size(i);
1133     for (int64_t i = dim + 1; i < input.dim(); ++i)
1134       inner_size *= input.size(i);
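    // e.g. an input of shape {N, C, H, W} with dim == 1 is treated as
    // outer_size = N, dim_size = C, inner_size = H * W.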
1135     const scalar_t* input_data_base = input.const_data_ptr<scalar_t>();
1136     scalar_t* output_data_base = output.data_ptr<scalar_t>();
1137     if (LogSoftMax) {
1138       _vec_logsoftmax(
1139           input_data_base, output_data_base, outer_size, inner_size, dim_size);
1140     } else {
1141       _vec_softmax(
1142           input_data_base, output_data_base, outer_size, inner_size, dim_size);
1143     }
1144   }
1145 };
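// Example of the decomposition used by vec_softmax::apply: for a contiguous
// input of shape (2, 3, 4) with dim == 1, outer_size == 2, dim_size == 3 and
// inner_size == 4, i.e. each softmax runs over 3 elements that are 4 entries
// apart in memory, and there are 2 * 4 such rows.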
1146 
1147 template <typename scalar_t, bool LogSoftMax>
1148 struct vec_host_softmax_backward_lastdim {
1149   static void
1150   apply(const Tensor& grad_input, const Tensor& grad, const Tensor& output) {
1151     int64_t outer_size = 1;
1152     int64_t dim_size = grad.size(grad.ndimension() - 1);
1153     for (int64_t i = 0; i < grad.ndimension() - 1; ++i)
1154       outer_size *= grad.size(i);
1155     scalar_t* grad_input_data_base = grad_input.mutable_data_ptr<scalar_t>();
1156     const scalar_t* grad_data_base = grad.const_data_ptr<scalar_t>();
1157     const scalar_t* output_data_base = output.const_data_ptr<scalar_t>();
1158     _vec_host_softmax_backward_lastdim<scalar_t, LogSoftMax>(
1159         grad_input_data_base,
1160         grad_data_base,
1161         output_data_base,
1162         outer_size,
1163         dim_size);
1164   }
1165 };
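// The lastdim backward wrapper flattens every leading dimension into
// outer_size and treats each row of dim_size contiguous elements
// independently, mirroring the forward lastdim path.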
1166 
1167 template <typename scalar_t, bool LogSoftMax>
1168 struct vec_host_softmax_backward {
1169   static void apply(
1170       const Tensor& grad_input,
1171       const Tensor& grad,
1172       const Tensor& output,
1173       int64_t dim) {
1174     int64_t outer_size = 1;
1175     int64_t dim_size = grad.size(dim);
1176     int64_t inner_size = 1;
1177     for (const auto i : c10::irange(dim)) {
1178       outer_size *= grad.size(i);
1179     }
1180     for (int64_t i = dim + 1; i < grad.dim(); ++i) {
1181       inner_size *= grad.size(i);
1182     }
1183     scalar_t* grad_input_data_base = grad_input.mutable_data_ptr<scalar_t>();
1184     const scalar_t* grad_output_data_base = grad.const_data_ptr<scalar_t>();
1185     const scalar_t* output_data_base = output.const_data_ptr<scalar_t>();
1186     if (LogSoftMax) {
1187       _vec_log_softmax_backward<scalar_t>(
1188           grad_input_data_base,
1189           grad_output_data_base,
1190           output_data_base,
1191           outer_size,
1192           inner_size,
1193           dim_size);
1194     } else {
1195       _vec_softmax_backward<scalar_t>(
1196           grad_input_data_base,
1197           grad_output_data_base,
1198           output_data_base,
1199           outer_size,
1200           inner_size,
1201           dim_size);
1202     }
1203   }
1204 };
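// For reference, these backward kernels produce the standard gradients,
// reduced over the softmax dimension:
//   softmax:     grad_input = output * (grad - sum(grad * output))
//   log_softmax: grad_input = grad - exp(output) * sum(grad)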
1205 
1206 static void softmax_lastdim_kernel_impl(
1207     const Tensor& result,
1208     const Tensor& self) {
1209   AT_DISPATCH_FLOATING_TYPES_AND2(
1210       at::ScalarType::BFloat16, at::ScalarType::Half, self.scalar_type(),
1211       "softmax_lastdim_kernel_impl",
1212       [&] { vec_host_softmax_lastdim<scalar_t, false>::apply(result, self); });
1213 }
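// AT_DISPATCH_FLOATING_TYPES_AND2 instantiates the lambda once per supported
// dtype (double, float, plus BFloat16 and Half here), binding scalar_t to that
// type; the string argument only names the kernel in the error raised for
// unsupported dtypes.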
1214 
1215 static void softmax_kernel_impl(const Tensor& result, const Tensor& self, int64_t dim) {
1216   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, self.scalar_type(),
1217     "softmax_kernel_impl",
1218     [&] { vec_softmax<scalar_t, false>::apply(result, self, dim); });
1219 }
1220 
1221 static void log_softmax_lastdim_kernel_impl(
1222     const Tensor& result,
1223     const Tensor& self) {
1224   AT_DISPATCH_FLOATING_TYPES_AND2(
1225       at::ScalarType::BFloat16, at::ScalarType::Half, self.scalar_type(),
1226       "log_softmax_lastdim_kernel_impl",
1227       [&] { vec_host_softmax_lastdim<scalar_t, true>::apply(result, self); });
1228 }
1229 
1230 static void log_softmax_kernel_impl(const Tensor& result, const Tensor& self, int64_t dim) {
1231   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, self.scalar_type(),
1232     "softmax_kernel_impl",
1233     [&] { vec_softmax<scalar_t, true>::apply(result, self, dim); });
1234 }
1235 
1236 static void softmax_backward_lastdim_kernel_impl(
1237     const Tensor& grad_input,
1238     const Tensor& grad,
1239     const Tensor& output) {
1240   AT_DISPATCH_FLOATING_TYPES_AND2(
1241       at::ScalarType::BFloat16, at::ScalarType::Half, grad.scalar_type(),
1242       "softmax_backward_lastdim_kernel_impl", [&] {
1243         vec_host_softmax_backward_lastdim<scalar_t, false>::apply(
1244             grad_input, grad, output);
1245       });
1246 }
1247 
1248 static void log_softmax_backward_lastdim_kernel_impl(
1249     const Tensor& grad_input,
1250     const Tensor& grad,
1251     const Tensor& output) {
1252   AT_DISPATCH_FLOATING_TYPES_AND2(
1253       at::ScalarType::BFloat16, at::ScalarType::Half, grad.scalar_type(),
1254       "log_softmax_backward_lastdim_kernel_impl", [&] {
1255         vec_host_softmax_backward_lastdim<scalar_t, true>::apply(
1256             grad_input, grad, output);
1257       });
1258 }
1259 
1260 static void softmax_backward_kernel_impl(
1261     const Tensor& grad_input,
1262     const Tensor& grad,
1263     const Tensor& output,
1264     int64_t dim) {
1265   AT_DISPATCH_FLOATING_TYPES_AND2(
1266       at::ScalarType::BFloat16,
1267       at::ScalarType::Half,
1268       grad.scalar_type(),
1269       "softmax_backward_kernel_impl",
1270       [&] {
1271         vec_host_softmax_backward<scalar_t, false>::apply(
1272             grad_input, grad, output, dim);
1273       });
1274 }
1275 
1276 static void log_softmax_backward_kernel_impl(
1277     const Tensor& grad_input,
1278     const Tensor& grad,
1279     const Tensor& output,
1280     int64_t dim) {
1281   AT_DISPATCH_FLOATING_TYPES_AND2(
1282       at::ScalarType::BFloat16,
1283       at::ScalarType::Half,
1284       grad.scalar_type(),
1285       "log_softmax_backward_kernel_impl",
1286       [&] {
1287         vec_host_softmax_backward<scalar_t, true>::apply(
1288             grad_input, grad, output, dim);
1289       });
1290 }
1291 
1292 } // anonymous namespace
1293 
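// ALSO_REGISTER_AVX512_DISPATCH (see DispatchStub.h) registers the AVX-512
// build of these kernels in addition to the default CPU capabilities; with
// plain REGISTER_DISPATCH the AVX-512 slot would fall back to the AVX2
// kernel. Softmax is one of the kernels considered worth the wider vectors.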
1294 ALSO_REGISTER_AVX512_DISPATCH(softmax_lastdim_kernel, &softmax_lastdim_kernel_impl);
1295 ALSO_REGISTER_AVX512_DISPATCH(log_softmax_lastdim_kernel, &log_softmax_lastdim_kernel_impl);
1296 ALSO_REGISTER_AVX512_DISPATCH(
1297     softmax_backward_lastdim_kernel,
1298     &softmax_backward_lastdim_kernel_impl);
1299 ALSO_REGISTER_AVX512_DISPATCH(
1300     log_softmax_backward_lastdim_kernel,
1301     &log_softmax_backward_lastdim_kernel_impl);
1302 
1303 ALSO_REGISTER_AVX512_DISPATCH(softmax_kernel, &softmax_kernel_impl);
1304 ALSO_REGISTER_AVX512_DISPATCH(log_softmax_kernel, &log_softmax_kernel_impl);
1305 ALSO_REGISTER_AVX512_DISPATCH(softmax_backward_kernel, &softmax_backward_kernel_impl);
1306 ALSO_REGISTER_AVX512_DISPATCH(
1307     log_softmax_backward_kernel,
1308     &log_softmax_backward_kernel_impl);
1309 } // namespace at::native
1310