#include <memory>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cpu/SoftmaxKernel.h>

#include <algorithm>
#include <iterator>
#include <numeric>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/OpMathType.h>
#include <ATen/core/Tensor.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>

// [Note AVX-SSE transitions] In general we avoid calls into cmath for code
// compiled with AVX/AVX2. This is because of SSE-AVX transitions and a bug in
// Glibc 2.23. See https://bugs.launchpad.net/ubuntu/+source/glibc/+bug/1663280
//
// On grain size: the grain size is chosen so that each task performs roughly
// GRAIN_SIZE computations. Each task works across dim_size elements; 16 is a
// very rough approximation of the number of computations per dim_size element,
// counting simple operations (*, +, -) as 1 and exp or log as 4.
//
// We use a chunk size such that it'd fit in L1D.

namespace at::native {

namespace {
template <typename scalar_t>
inline void _vec_log_softmax_lastdim(
    const scalar_t* input_data_base,
    scalar_t* output_data_base,
    int64_t outer_size,
    int64_t dim_size) {
  using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
  // Coincidentally, at::internal::GRAIN_SIZE is 32768, which is equal to the
  // size of L1D cache on many processors. Some processors have 48 KB L1D cache
  // nowadays, so maybe in the future, we can leverage the knowledge of a
  // machine's L1D cache size.
  int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
      1,
      at::internal::GRAIN_SIZE / (sizeof(scalar_t) * dim_size));
  int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, outer_size);
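  // For example, with float32 (4 bytes per element) and dim_size = 1024,
  // MAX_CHUNK_SIZE = 32768 / (4 * 1024) = 8, so each chunk covers up to 8 rows.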
  // Note: grain_size value of 0
  // We don't change the number of OpenMP threads in the OpenMP thread pool,
  // so with a coarse grain size some threads would do useful work while others
  // sit idle. We simply use a grain_size of 0 and rely upon invoke_parallel to
  // distribute work among threads in an equitable manner. We compute CHUNK_SIZE
  // to ensure that each thread's computations are efficient.
  parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
    // MSVC requires such a declaration of dynamic arrays
    // Source: https://stackoverflow.com/a/33423538
    auto tmp_sum_scalar = std::make_unique<scalar_t[]>(CHUNK_SIZE);
    auto max_input_arr = std::make_unique<scalar_t[]>(CHUNK_SIZE);
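    // Three passes over each chunk of rows: (1) compute the per-row max,
    // (2) compute the per-row sum of exp(x - max), (3) write the result
    // x - max - log(sum) to the output.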
    for (int64_t ii = begin; ii < end; ii += CHUNK_SIZE) {
      int64_t loop_end = CHUNK_SIZE;
      if (ii + CHUNK_SIZE > end)
        loop_end = end - ii;
      for (const auto j : c10::irange(loop_end)) {
        int64_t i = ii + j;
        const scalar_t* input_data = input_data_base + i * dim_size;
        max_input_arr[j] = vec::reduce_all<scalar_t>(
            [](Vec& x, Vec& y) { return vec::maximum(x, y); },
            input_data,
            dim_size);
      }
      for (const auto j : c10::irange(loop_end)) {
        int64_t i = ii + j;
        const scalar_t* input_data = input_data_base + i * dim_size;
        scalar_t max_input = max_input_arr[j];
        tmp_sum_scalar[j] = vec::map_reduce_all<scalar_t>(
            [max_input](Vec x) { return (x - Vec(max_input)).exp(); },
            [](Vec x, Vec y) { return x + y; },
            input_data,
            dim_size);
      }
      // See [Note AVX-SSE transitions] for why this should call the
      // vectorized version (aside from perf improvements).
      vec::map(
          [](Vec x) { return x.log(); },
          tmp_sum_scalar.get(),
          tmp_sum_scalar.get(),
          loop_end);
      for (const auto j : c10::irange(loop_end)) {
        int64_t i = ii + j;
        const scalar_t* input_data = input_data_base + i * dim_size;
        scalar_t* output_data = output_data_base + i * dim_size;
        scalar_t tmp_sum = tmp_sum_scalar[j];
        scalar_t max_input = max_input_arr[j];

        // It's necessary to keep the order of the operations below.
        // When the inputs are large in magnitude but their differences are
        // small, computing `max_input + tmp_sum` first would lose precision.
        // See an example in
        // https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379
        vec::map(
            [tmp_sum, max_input](Vec x) {
              return x - Vec(max_input) - Vec(tmp_sum);
            },
            output_data,
            input_data,
            dim_size);
      }
    }
  });
}

template<typename scalar_t>
inline typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
_vec_softmax_lastdim(
    const scalar_t* input_data_base,
    scalar_t* output_data_base,
    int64_t outer_size,
    int64_t dim_size) {
  using Vec = vec::Vectorized<scalar_t>;
  // See Note: grain_size value of 0
  parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      const scalar_t* input_data = input_data_base + i * dim_size;
      scalar_t* output_data = output_data_base + i * dim_size;
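      // softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))); subtracting the
      // row max first keeps exp() from overflowing.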
      scalar_t max_input = vec::reduce_all<scalar_t>(
          [](Vec& x, Vec& y) { return vec::maximum(x, y); },
          input_data,
          dim_size);
      vec::map(
          [max_input](Vec x) { return (x - Vec(max_input)).exp(); },
          output_data,
          input_data,
          dim_size);
      scalar_t tmp_sum = vec::reduce_all<scalar_t>(
          [](Vec x, Vec y) { return x + y; }, output_data, dim_size);
      tmp_sum = 1 / tmp_sum;
      vec::map(
          [tmp_sum](Vec x) { return x * Vec(tmp_sum); },
          output_data,
          output_data,
          dim_size);
    }
  });
}

template<typename scalar_t>
inline typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
_vec_softmax_lastdim(
    const scalar_t* input_data_base,
    scalar_t* output_data_base,
    int64_t outer_size,
    int64_t dim_size) {
  using Vec = vec::Vectorized<scalar_t>;
  using fVec = vec::Vectorized<float>;
  // See Note: grain_size value of 0
  parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
    // thread local temp buffer.
    auto buffer = std::make_unique<float []>(dim_size);
    float* buffer_data = buffer.get();
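    // Accumulate in float32: the reduced-precision input (BFloat16/Half) is
    // converted once into this buffer so max/sum/exp run at full precision,
    // and results are converted back to scalar_t only when stored.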

    for (const auto i : c10::irange(begin, end)) {
      const scalar_t* input_data = input_data_base + i * dim_size;
      scalar_t* output_data = output_data_base + i * dim_size;
      // reduce to max and cache float input data
      fVec max_fvec = fVec(-std::numeric_limits<float>::infinity());
      int64_t d0 = 0;
      for (; d0 < dim_size - (dim_size % Vec::size()); d0 += Vec::size()) {
        Vec data_vec = Vec::loadu(input_data + d0);
        auto [data_fvec0, data_fvec1] = vec::convert_to_float<scalar_t>(data_vec);
        max_fvec = vec::maximum(max_fvec, data_fvec0);
        max_fvec = vec::maximum(max_fvec, data_fvec1);
        data_fvec0.store(buffer_data + d0);
        data_fvec1.store(buffer_data + d0 + fVec::size());
      }
      float max_val = vec::vec_reduce_all([](fVec& x, fVec& y) { return vec::maximum(x, y); }, max_fvec);
      for (; d0 < dim_size; d0++) {
        float data_val = input_data[d0];
        max_val = std::max(max_val, data_val);
        buffer_data[d0] = data_val;
      }

      // map (x - max).exp() and reduce to sum
      fVec sum_fvec = fVec(float(0));
      int64_t d1 = 0;
      for (; d1 < dim_size - (dim_size % fVec::size()); d1 += fVec::size()) {
        fVec data_fvec = (fVec::loadu(buffer_data + d1) - fVec(max_val)).exp();
        sum_fvec += data_fvec;
        data_fvec.store(buffer_data + d1);
      }
      float sum_val = vec::vec_reduce_all([](fVec& x, fVec& y) { return x + y; }, sum_fvec);
      for (; d1 < dim_size; d1++) {
        float data_val = std::exp(buffer_data[d1] - max_val);
        sum_val += data_val;
        buffer_data[d1] = data_val;
      }

      sum_val = 1 / sum_val;
      int64_t d2 = 0;
      for (; d2 < dim_size - (dim_size % Vec::size()); d2 += Vec::size()) {
        fVec out_fvec0 = fVec::loadu(buffer_data + d2) * fVec(sum_val);
        fVec out_fvec1 = fVec::loadu(buffer_data + d2 + fVec::size()) * fVec(sum_val);
        Vec out_vec = vec::convert_from_float<scalar_t>(out_fvec0, out_fvec1);
        out_vec.store(output_data + d2);
      }
      for (; d2 < dim_size; d2++) {
        output_data[d2] = scalar_t(buffer_data[d2] * sum_val);
      }
    }
  });
}

template <typename scalar_t, bool log_softmax>
inline void _vec_host_softmax_backward_lastdim(
    scalar_t* grad_input_data_base,
    const scalar_t* grad_data_base,
    const scalar_t* output_data_base,
    int64_t outer_size,
    int64_t dim_size) {
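  // Backward formulas applied per row:
  //   log_softmax: grad_input = grad_output - exp(output) * sum(grad_output)
  //   softmax:     grad_input = (grad_output - sum(grad_output * output)) * output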
  using Vec = vec::Vectorized<at::opmath_type<scalar_t>>;
  // See Note: grain_size value of 0
  parallel_for(
      0,
      outer_size,
      0,
      [&](int64_t begin, int64_t end) {
        for (const auto i : c10::irange(begin, end)) {
          scalar_t* grad_input_data = grad_input_data_base + i * dim_size;
          const scalar_t* grad_data = grad_data_base + i * dim_size;
          const scalar_t* output_data = output_data_base + i * dim_size;
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
          scalar_t sum;
          if (log_softmax) {
            sum = vec::reduce_all<scalar_t>(
                [](Vec& x, Vec& y) { return x + y; }, grad_data, dim_size);
          } else {
            sum = vec::map2_reduce_all<scalar_t>(
                [](Vec x, Vec y) { return x * y; },
                [](Vec x, Vec y) { return x + y; },
                grad_data,
                output_data,
                dim_size);
          }
          if (log_softmax) {
            vec::map2(
                [sum](Vec x, Vec y) { return x - ((y.exp()) * Vec(sum)); },
                grad_input_data,
                grad_data,
                output_data,
                dim_size);
          } else {
            vec::map2(
                [sum](Vec x, Vec y) { return (x - Vec(sum)) * y; },
                grad_input_data,
                grad_data,
                output_data,
                dim_size);
          }
        }
      });
}

template<typename scalar_t>
inline typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
_vec_softmax_backward(
    scalar_t* grad_input_data_base,
    const scalar_t* grad_output_data_base,
    const scalar_t* output_data_base,
    int64_t outer_size,
    int64_t inner_size,
    int64_t dim_size) {
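  // Per slice along dim: grad_input = output * (grad_output - sum(grad_output * output)).
  // inner_size is processed in chunks so that a {dim_size, CHUNK_SIZE} block of
  // the working set stays cache resident while the vertical sum is accumulated.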
  using Vec = vec::Vectorized<scalar_t>;
  int64_t outer_stride = dim_size * inner_size;
  int64_t BLOCK_SIZE = 128 * 1024;
  int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
      BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
  MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
  int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
  int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
  // See Note: grain_size value of 0
  parallel_for(
      0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
        // thread local temp buffer that holds vertical sum result
        auto buffer = std::make_unique<scalar_t[]>(CHUNK_SIZE);
        scalar_t* tmp_sum_data = buffer.get();

        for (int64_t i = begin; i < end; i++) {
          int64_t outer_idx = i / num_chunks;
          int64_t k = i % num_chunks;
          int64_t inner_idx_begin = k * CHUNK_SIZE;
          int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);

          // init
          Vec zero_vec = Vec(scalar_t(0));
          int64_t d0 = 0;
          for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
            zero_vec.store(tmp_sum_data + d0);
          }
          for (; d0 < size; d0++) {
            tmp_sum_data[d0] = scalar_t(0);
          }

          // compute sum of grad_output * output
          for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
            int64_t offset = outer_idx * outer_stride + dim_idx * inner_size +
                inner_idx_begin;
            const scalar_t* grad_output_ptr = grad_output_data_base + offset;
            const scalar_t* output_ptr = output_data_base + offset;

            int64_t d1 = 0;
            for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
              Vec grad_output_vec = Vec::loadu(grad_output_ptr + d1);
              Vec output_vec = Vec::loadu(output_ptr + d1);
              Vec sum_vec = Vec::loadu(tmp_sum_data + d1);
              sum_vec += grad_output_vec * output_vec;
              sum_vec.store(tmp_sum_data + d1);
            }
            for (; d1 < size; d1++) {
              tmp_sum_data[d1] += grad_output_ptr[d1] * output_ptr[d1];
            }
          }

          // compute output * (grad_output - sum)
          for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
            int64_t offset = outer_idx * outer_stride + dim_idx * inner_size +
                inner_idx_begin;
            const scalar_t* grad_output_ptr = grad_output_data_base + offset;
            const scalar_t* output_ptr = output_data_base + offset;
            scalar_t* grad_input_ptr = grad_input_data_base + offset;

            int64_t d2 = 0;
            for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
              Vec grad_output_vec = Vec::loadu(grad_output_ptr + d2);
              Vec output_vec = Vec::loadu(output_ptr + d2);
              Vec sum_vec = Vec::loadu(tmp_sum_data + d2);
              Vec grad_input_vec = output_vec * (grad_output_vec - sum_vec);
              grad_input_vec.store(grad_input_ptr + d2);
            }
            for (; d2 < size; d2++) {
              grad_input_ptr[d2] = output_ptr[d2] * (grad_output_ptr[d2] - tmp_sum_data[d2]);
            }
          }
        }
      });
}

template<typename scalar_t>
inline typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
_vec_softmax_backward(
    scalar_t* grad_input_data_base,
    const scalar_t* grad_output_data_base,
    const scalar_t* output_data_base,
    int64_t outer_size,
    int64_t inner_size,
    int64_t dim_size) {
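  // Same computation as above, but grad_output and output are converted to
  // float32 once per block and cached in thread-local buffers, so the sum and
  // the final update both run at full precision without re-converting.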
  using Vec = vec::Vectorized<scalar_t>;
  using fVec = vec::Vectorized<float>;
  int64_t outer_stride = dim_size * inner_size;
  int64_t BLOCK_SIZE = 128 * 1024;
  int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
      BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
  MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
  int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
  int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
  // See Note: grain_size value of 0
  parallel_for(
      0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
        // thread local temp buffer that holds vertical sum result
        auto buffer = std::make_unique<float[]>(CHUNK_SIZE);
        float* tmp_sum_data = buffer.get();

        // thread local buffer that holds grad_output and output data in float32
        auto grad_output_buffer = std::make_unique<float[]>(dim_size * CHUNK_SIZE);
        float* grad_output_buffer_data = grad_output_buffer.get();

        auto output_buffer = std::make_unique<float[]>(dim_size * CHUNK_SIZE);
        float* output_buffer_data = output_buffer.get();

        for (int64_t i = begin; i < end; i++) {
          int64_t outer_idx = i / num_chunks;
          int64_t k = i % num_chunks;
          int64_t inner_idx_begin = k * CHUNK_SIZE;
          int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);

          // init
          fVec zero_fvec = fVec(float(0));
          int64_t d0 = 0;
          for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
            zero_fvec.store(tmp_sum_data + d0);
            zero_fvec.store(tmp_sum_data + d0 + fVec::size());
          }
          for (; d0 < size; d0++) {
            tmp_sum_data[d0] = float(0);
          }

          // compute sum of grad_output * output
          for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
            int64_t offset = outer_idx * outer_stride + dim_idx * inner_size +
                inner_idx_begin;
            const scalar_t* grad_output_ptr = grad_output_data_base + offset;
            const scalar_t* output_ptr = output_data_base + offset;
            float* grad_output_buffer_ptr =
                grad_output_buffer_data + dim_idx * CHUNK_SIZE;
            float* output_buffer_ptr =
                output_buffer_data + dim_idx * CHUNK_SIZE;

            int64_t d1 = 0;
            for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
              Vec grad_output_vec = Vec::loadu(grad_output_ptr + d1);
              auto [grad_output_fvec0, grad_output_fvec1] =
                  vec::convert_to_float<scalar_t>(grad_output_vec);
              Vec output_vec = Vec::loadu(output_ptr + d1);
              auto [output_fvec0, output_fvec1] =
                  vec::convert_to_float<scalar_t>(output_vec);
              fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d1);
              fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d1 + fVec::size());
              sum_fvec0 += grad_output_fvec0 * output_fvec0;
              sum_fvec1 += grad_output_fvec1 * output_fvec1;
              sum_fvec0.store(tmp_sum_data + d1);
              sum_fvec1.store(tmp_sum_data + d1 + fVec::size());

              // cache the 'converted' float grad_output and output
              grad_output_fvec0.store(grad_output_buffer_ptr + d1);
              grad_output_fvec1.store(
                  grad_output_buffer_ptr + d1 + fVec::size());
              output_fvec0.store(output_buffer_ptr + d1);
              output_fvec1.store(output_buffer_ptr + d1 + fVec::size());
            }
            for (; d1 < size; d1++) {
              float grad_output_val = float(grad_output_ptr[d1]);
              float output_val = float(output_ptr[d1]);
              tmp_sum_data[d1] += grad_output_val * output_val;
              grad_output_buffer_ptr[d1] = grad_output_val;
              output_buffer_ptr[d1] = output_val;
            }
          }

          // compute output * (grad_output - sum)
          for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
            scalar_t* grad_input_ptr = grad_input_data_base +
                outer_idx * outer_stride + dim_idx * inner_size +
                inner_idx_begin;
            float* grad_output_buffer_ptr =
                grad_output_buffer_data + dim_idx * CHUNK_SIZE;
            float* output_buffer_ptr =
                output_buffer_data + dim_idx * CHUNK_SIZE;

            int64_t d2 = 0;
            for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
              fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2);
              fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size());
              fVec grad_output_fvec0 = fVec::loadu(grad_output_buffer_ptr + d2);
              fVec grad_output_fvec1 =
                  fVec::loadu(grad_output_buffer_ptr + d2 + fVec::size());
              fVec output_fvec0 = fVec::loadu(output_buffer_ptr + d2);
              fVec output_fvec1 =
                  fVec::loadu(output_buffer_ptr + d2 + fVec::size());
              fVec grad_input_fvec0 =
                  output_fvec0 * (grad_output_fvec0 - sum_fvec0);
              fVec grad_input_fvec1 =
                  output_fvec1 * (grad_output_fvec1 - sum_fvec1);
              Vec grad_input_vec =
                  vec::convert_from_float<scalar_t>(grad_input_fvec0, grad_input_fvec1);
              grad_input_vec.store(grad_input_ptr + d2);
            }
            for (; d2 < size; d2++) {
              grad_input_ptr[d2] = output_buffer_ptr[d2] * (grad_output_buffer_ptr[d2] - tmp_sum_data[d2]);
            }
          }
        }
      });
}

template<typename scalar_t>
inline typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
_vec_log_softmax_backward(
    scalar_t* grad_input_data_base,
    const scalar_t* grad_output_data_base,
    const scalar_t* output_data_base,
    int64_t outer_size,
    int64_t inner_size,
    int64_t dim_size) {
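  // log_softmax backward: grad_input = grad_output - exp(output) * sum(grad_output, dim),
  // where output already holds log-probabilities, so exp(output) recovers softmax(x).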
  using Vec = vec::Vectorized<scalar_t>;
  int64_t outer_stride = dim_size * inner_size;
  int64_t BLOCK_SIZE = 128 * 1024;
  int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
      BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
  MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
  int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
  int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
  // See Note: grain_size value of 0
  parallel_for(
      0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
        // thread local temp buffer that holds vertical sum result
        auto buffer = std::make_unique<scalar_t[]>(CHUNK_SIZE);
        scalar_t* tmp_sum_data = buffer.get();

        for (int64_t i = begin; i < end; i++) {
          int64_t outer_idx = i / num_chunks;
          int64_t k = i % num_chunks;
          int64_t inner_idx_begin = k * CHUNK_SIZE;
          int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);

          // init
          Vec zero_vec = Vec(scalar_t(0));
          int64_t d0 = 0;
          for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
            zero_vec.store(tmp_sum_data + d0);
          }
          for (; d0 < size; d0++) {
            tmp_sum_data[d0] = scalar_t(0);
          }

          // compute sum of grad_output
          for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
            const scalar_t* grad_output_ptr = grad_output_data_base +
                outer_idx * outer_stride + dim_idx * inner_size +
                inner_idx_begin;

            int64_t d1 = 0;
            for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
              Vec grad_output_vec = Vec::loadu(grad_output_ptr + d1);
              Vec sum_vec = Vec::loadu(tmp_sum_data + d1);
              sum_vec += grad_output_vec;
              sum_vec.store(tmp_sum_data + d1);
            }
            for (; d1 < size; d1++) {
              tmp_sum_data[d1] += grad_output_ptr[d1];
            }
          }

          // compute grad_output - output.exp() * sum
          for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
            int64_t offset = outer_idx * outer_stride + dim_idx * inner_size +
                inner_idx_begin;
            const scalar_t* grad_output_ptr = grad_output_data_base + offset;
            const scalar_t* output_ptr = output_data_base + offset;
            scalar_t* grad_input_ptr = grad_input_data_base + offset;

            int64_t d2 = 0;
            for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
              Vec grad_output_vec = Vec::loadu(grad_output_ptr + d2);
              Vec output_vec = Vec::loadu(output_ptr + d2);
              Vec sum_vec = Vec::loadu(tmp_sum_data + d2);
              Vec grad_input_vec = grad_output_vec - output_vec.exp() * sum_vec;
              grad_input_vec.store(grad_input_ptr + d2);
            }
            for (; d2 < size; d2++) {
              grad_input_ptr[d2] = grad_output_ptr[d2] -
                  std::exp(output_ptr[d2]) * tmp_sum_data[d2];
            }
          }
        }
      });
}

template<typename scalar_t>
inline typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
_vec_log_softmax_backward(
    scalar_t* grad_input_data_base,
    const scalar_t* grad_output_data_base,
    const scalar_t* output_data_base,
    int64_t outer_size,
    int64_t inner_size,
    int64_t dim_size) {
  using Vec = vec::Vectorized<scalar_t>;
  using fVec = vec::Vectorized<float>;
  int64_t outer_stride = dim_size * inner_size;
  int64_t BLOCK_SIZE = 128 * 1024;
  int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
      BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
  MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
  int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
  int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
  // See Note: grain_size value of 0
  parallel_for(
      0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
        // thread local temp buffer that holds vertical sum result
        auto buffer = std::make_unique<float[]>(CHUNK_SIZE);
        float* tmp_sum_data = buffer.get();

        // thread local buffer that holds grad_output data in float32
        auto grad_output_buffer = std::make_unique<float[]>(dim_size * CHUNK_SIZE);
        float* grad_output_buffer_data = grad_output_buffer.get();

        for (int64_t i = begin; i < end; i++) {
          int64_t outer_idx = i / num_chunks;
          int64_t k = i % num_chunks;
          int64_t inner_idx_begin = k * CHUNK_SIZE;
          int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);

          // init
          fVec zero_fvec = fVec(float(0));
          int64_t d0 = 0;
          for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
            zero_fvec.store(tmp_sum_data + d0);
            zero_fvec.store(tmp_sum_data + d0 + fVec::size());
          }
          for (; d0 < size; d0++) {
            tmp_sum_data[d0] = float(0);
          }

          // compute sum of grad_output
          for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
            const scalar_t* grad_output_ptr = grad_output_data_base +
                outer_idx * outer_stride + dim_idx * inner_size +
                inner_idx_begin;
            float* grad_output_buffer_ptr =
                grad_output_buffer_data + dim_idx * CHUNK_SIZE;

            int64_t d1 = 0;
            for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
              Vec grad_output_vec = Vec::loadu(grad_output_ptr + d1);
              auto [grad_output_fvec0, grad_output_fvec1] =
                  vec::convert_to_float<scalar_t>(grad_output_vec);
              fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d1);
              fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d1 + fVec::size());
              sum_fvec0 += grad_output_fvec0;
              sum_fvec1 += grad_output_fvec1;
              sum_fvec0.store(tmp_sum_data + d1);
              sum_fvec1.store(tmp_sum_data + d1 + fVec::size());

              // cache the 'converted' float grad_output
              grad_output_fvec0.store(grad_output_buffer_ptr + d1);
              grad_output_fvec1.store(
                  grad_output_buffer_ptr + d1 + fVec::size());
            }
            for (; d1 < size; d1++) {
              float grad_output_val = float(grad_output_ptr[d1]);
              tmp_sum_data[d1] += grad_output_val;
              grad_output_buffer_ptr[d1] = grad_output_val;
            }
          }

          // compute grad_output - output.exp() * sum
          for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
            int64_t offset = outer_idx * outer_stride + dim_idx * inner_size +
                inner_idx_begin;
            const scalar_t* output_ptr = output_data_base + offset;
            scalar_t* grad_input_ptr = grad_input_data_base + offset;
            float* grad_output_buffer_ptr =
                grad_output_buffer_data + dim_idx * CHUNK_SIZE;

            int64_t d2 = 0;
            for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
              Vec output_vec = Vec::loadu(output_ptr + d2);
              auto [output_fvec0, output_fvec1] =
                  vec::convert_to_float<scalar_t>(output_vec);
              fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2);
              fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size());
              fVec grad_output_fvec0 = fVec::loadu(grad_output_buffer_ptr + d2);
              fVec grad_output_fvec1 =
                  fVec::loadu(grad_output_buffer_ptr + d2 + fVec::size());
              fVec grad_input_fvec0 =
                  grad_output_fvec0 - output_fvec0.exp() * sum_fvec0;
              fVec grad_input_fvec1 =
                  grad_output_fvec1 - output_fvec1.exp() * sum_fvec1;
              Vec grad_input_vec =
                  vec::convert_from_float<scalar_t>(grad_input_fvec0, grad_input_fvec1);
              grad_input_vec.store(grad_input_ptr + d2);
            }
            for (; d2 < size; d2++) {
              grad_input_ptr[d2] = grad_output_buffer_ptr[d2] -
                  std::exp(float(output_ptr[d2])) * tmp_sum_data[d2];
            }
          }
        }
      });
}

template <typename scalar_t, bool LogSoftMax>
struct vec_host_softmax_lastdim {
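  // Folds all leading dimensions into outer_size, e.g. a {N, C, W} input with
  // softmax over the last dim becomes outer_size = N * C rows of dim_size = W.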
  static void apply(const Tensor& output, const Tensor& input) {
    int64_t outer_size = 1;
    int64_t dim_size = input.size(input.ndimension() - 1);
    for (int64_t i = 0; i < input.ndimension() - 1; ++i)
      outer_size *= input.size(i);
    const scalar_t* input_data_base = input.const_data_ptr<scalar_t>();
    scalar_t* output_data_base = output.data_ptr<scalar_t>();
    if (LogSoftMax) {
      _vec_log_softmax_lastdim(
          input_data_base, output_data_base, outer_size, dim_size);
    } else {
      _vec_softmax_lastdim(
          input_data_base, output_data_base, outer_size, dim_size);
    }
  }
};

template<typename scalar_t>
inline typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
_vec_softmax(
    const scalar_t* input_data_base,
    scalar_t* output_data_base,
    int64_t outer_size,
    int64_t inner_size,
    int64_t dim_size) {
  using Vec = vec::Vectorized<float>;
  using Vec16 = vec::Vectorized<scalar_t>;
  int64_t dim_stride = inner_size;
  int64_t outer_stride = dim_size * dim_stride;
  int vectorized_step = Vec16().size(); // Currently, we only support BFloat16/Half in this special implementation
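  // Vectorization here is across inner_size: each SIMD lane handles a different
  // inner index, and the reduction over dim walks dim_stride-spaced rows, so no
  // horizontal (within-register) reduction is needed.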
  // See Note: grain_size value of 0
  parallel_for(
      0, outer_size * inner_size, 0, [&](int64_t begin, int64_t end) {
        int64_t idx = begin;
        std::unique_ptr<float[]> temp_vec_input(new float[dim_size*vectorized_step]());
        std::unique_ptr<float[]> temp_vec_output(new float[dim_size*vectorized_step]());
        float* temp_vec_input_data = temp_vec_input.get();
        float* temp_vec_output_data = temp_vec_output.get();
        while (idx < end) {
          int64_t outer_idx = idx / inner_size;
          int64_t inner_idx = idx % inner_size;
          if (((inner_idx + vectorized_step) <= inner_size) && ((idx + vectorized_step) <= end)) {
            // Vectorization
            const scalar_t* input_data =
                input_data_base + outer_idx * outer_stride + inner_idx;
            scalar_t* output_data =
                output_data_base + outer_idx * outer_stride + inner_idx;
            // Step 1: Get max score
            Vec16 max_vec_bf16 = Vec16::loadu(input_data);
            std::tuple<Vec, Vec> convert_result = vec::convert_to_float<scalar_t>(max_vec_bf16);
            Vec max_vec_o1 = std::get<0>(convert_result);
            Vec max_vec_o2 = std::get<1>(convert_result);
            std::get<0>(convert_result).store(temp_vec_input_data);
            std::get<1>(convert_result).store(temp_vec_input_data + Vec().size());
            for (const auto d : c10::irange(1, dim_size)) {
              Vec16 input_vec_bf16 = Vec16::loadu(input_data + d * dim_stride);
              convert_result = vec::convert_to_float<scalar_t>(input_vec_bf16);
              max_vec_o1 = vec::maximum(max_vec_o1, std::get<0>(convert_result));
              max_vec_o2 = vec::maximum(max_vec_o2, std::get<1>(convert_result));
              std::get<0>(convert_result).store(temp_vec_input_data + d*vectorized_step);
              std::get<1>(convert_result).store(temp_vec_input_data + d*vectorized_step + Vec().size());
            }
            // Step 2: Calculate sum
            Vec sum_vec_o1 = Vec(0.0);
            Vec sum_vec_o2 = Vec(0.0);
            for (const auto d : c10::irange(dim_size)) {
              Vec output_vec_o1 = Vec::loadu(temp_vec_input_data + d*vectorized_step);
              Vec output_vec_o2 = Vec::loadu(temp_vec_input_data + d*vectorized_step + Vec().size());
              output_vec_o1 = (output_vec_o1 - max_vec_o1).exp();
              output_vec_o2 = (output_vec_o2 - max_vec_o2).exp();
              output_vec_o1.store(temp_vec_output_data + d*vectorized_step);
              output_vec_o2.store(temp_vec_output_data + d*vectorized_step + Vec().size());

              sum_vec_o1 = sum_vec_o1 + output_vec_o1;
              sum_vec_o2 = sum_vec_o2 + output_vec_o2;
            }
            // Step 3: Unify
            for (const auto d : c10::irange(dim_size)) {
              Vec output_vec_o1 = Vec::loadu(temp_vec_output_data + d*vectorized_step);
              Vec output_vec_o2 = Vec::loadu(temp_vec_output_data + d*vectorized_step + Vec().size());
              output_vec_o1 = output_vec_o1/sum_vec_o1;
              output_vec_o2 = output_vec_o2/sum_vec_o2;
              Vec16 output_vec_bf16 = vec::convert_from_float<scalar_t>(output_vec_o1, output_vec_o2);
              output_vec_bf16.store(output_data + d * dim_stride);
            }
            idx += vectorized_step;
          } else {
            // Tail case (scalar): exactly the same logic as host_softmax
            // inside aten/src/ATen/native/SoftMax.cpp. Two kinds of cases fall
            // through to this path:
            // Case 1: the idx at the end of a thread's total chunk has fewer
            //         than vectorized_step elements left.
            // Case 2: the idx at the end of an inner_size row within a thread
            //         has fewer than vectorized_step elements left.
            int64_t tail_number = ((idx+vectorized_step) > end) ? /*Case1*/ (end - idx) : /*Case2*/ (inner_size - inner_idx);
            for (const auto i : c10::irange(tail_number)) {
              outer_idx = (idx + i) / inner_size;
              inner_idx = (idx + i) % inner_size;
              const scalar_t* input_data =
                  input_data_base + outer_idx * outer_stride + inner_idx;
              scalar_t* output_data =
                  output_data_base + outer_idx * outer_stride + inner_idx;
              // Step 1: Get max score
              float max_input = float(input_data[0]);
              for (const auto d : c10::irange(1, dim_size)) {
                max_input = std::max(max_input, float(input_data[d * dim_stride]));
              }
              // Step 2: Calculate the sum
              float sum_data = 0.0;
              float temp_output_data = 0.0;
              for (const auto d : c10::irange(dim_size)) {
                temp_output_data = std::exp(input_data[d * dim_stride] - max_input);
                sum_data += temp_output_data;
                output_data[d * dim_stride] = scalar_t(temp_output_data);
              }
              // Step 3: Unify
              for (const auto d : c10::irange(dim_size)) {
                output_data[d * dim_stride] =
                    scalar_t(float(output_data[d * dim_stride])/sum_data);
              }
            }
            idx += tail_number;
          }
        }
      });
}

template<typename scalar_t>
inline typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
_vec_softmax(
    const scalar_t* input_data_base,
    scalar_t* output_data_base,
    int64_t outer_size,
    int64_t inner_size,
    int64_t dim_size) {
  using Vec = vec::Vectorized<scalar_t>;
  int64_t dim_stride = inner_size;
  int64_t outer_stride = dim_size * dim_stride;
  int vectorized_step = Vec().size();
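  // Full-precision variant of the kernel above: the same inner_size-lane
  // vectorization, but without the float32 staging buffers since scalar_t is
  // already the accumulation type.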
  // See Note: grain_size value of 0
  parallel_for(
      0, outer_size * inner_size, 0, [&](int64_t begin, int64_t end) {
        int64_t idx = begin;
        while (idx < end) {
          int64_t outer_idx = idx / inner_size;
          int64_t inner_idx = idx % inner_size;
          if (((inner_idx + vectorized_step) <= inner_size) && ((idx + vectorized_step) <= end)) {
            // Vectorization
            const scalar_t* input_data =
                input_data_base + outer_idx * outer_stride + inner_idx;
            scalar_t* output_data =
                output_data_base + outer_idx * outer_stride + inner_idx;
            // Step 1: Get max score
            Vec max_vec = Vec::loadu(input_data);
            for (const auto d : c10::irange(1, dim_size)) {
              Vec input_vec = Vec::loadu(input_data + d * dim_stride);
              max_vec = vec::maximum(max_vec, input_vec);
            }
            // Step 2: Calculate sum
            Vec sum_vec = Vec(0.0);
            for (const auto d : c10::irange(dim_size)) {
              Vec output_vec =
                  (Vec::loadu(input_data + d * dim_stride) - max_vec).exp();
              output_vec.store(output_data + d * dim_stride);
              sum_vec = sum_vec + output_vec;
            }
            // Step 3: Unify
            for (const auto d : c10::irange(dim_size)) {
              Vec output_vec =
                  Vec::loadu(output_data + d * dim_stride) / sum_vec;
              output_vec.store(output_data + d * dim_stride);
            }
            idx += vectorized_step;
          } else {
            // Tail case (scalar): exactly the same logic as host_softmax
            // inside aten/src/ATen/native/SoftMax.cpp. Two kinds of cases fall
            // through to this path:
            // Case 1: the idx at the end of a thread's total chunk has fewer
            //         than vectorized_step elements left.
            // Case 2: the idx at the end of an inner_size row within a thread
            //         has fewer than vectorized_step elements left.
            int64_t tail_number = ((idx+vectorized_step) > end) ? /*Case1*/ (end - idx) : /*Case2*/ (inner_size - inner_idx);
            for (const auto i : c10::irange(tail_number)) {
              outer_idx = (idx + i) / inner_size;
              inner_idx = (idx + i) % inner_size;
              const scalar_t* input_data =
                  input_data_base + outer_idx * outer_stride + inner_idx;
              scalar_t* output_data =
                  output_data_base + outer_idx * outer_stride + inner_idx;
              // Step 1: Get max score
              scalar_t max_input = input_data[0];
              for (const auto d : c10::irange(1, dim_size)) {
                max_input = std::max(max_input, input_data[d * dim_stride]);
              }
              // Step 2: Calculate the sum
              scalar_t sum_data = 0;
              for (const auto d : c10::irange(dim_size)) {
                output_data[d * dim_stride] =
                    std::exp(input_data[d * dim_stride] - max_input);
                sum_data += output_data[d * dim_stride];
              }
              // Step 3: Unify
              for (const auto d : c10::irange(dim_size)) {
                output_data[d * dim_stride] =
                    output_data[d * dim_stride]/sum_data;
              }
            }
            idx += tail_number;
          }
        }
      });
}

// NB: fast kernel for log_softmax when dim != -1
// input shape is normalized to {outer_size, dim_size, inner_size}
//
// The algorithm requires loading the input tensor 3 times. To increase
// parallelism and cache hit rate, inner_size is blocked as:
//   inner_size: {CHUNK_SIZE, CHUNK_SIZE, ..., Remainder}
//
// Parallelize on {outer_size, num_chunks} and do a vertical reduction on each
// block of {dim_size, CHUNK_SIZE}; the block size (128KB) is selected so that
// the block stays L2-resident.
//
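// For example, with inner_size = 10000 and CHUNK_SIZE = 4096 the blocking is
// {4096, 4096, 1808}, i.e. num_chunks = 3, and parallel_for runs over
// outer_size * 3 tasks.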
template<typename scalar_t>
inline typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
_vec_logsoftmax(
    const scalar_t* input_data_base,
    scalar_t* output_data_base,
    int64_t outer_size,
    int64_t inner_size,
    int64_t dim_size) {
  using Vec = vec::Vectorized<scalar_t>;
  int64_t BLOCK_SIZE = 128 * 1024;
  int64_t MAX_CHUNK_SIZE = std::max<int64_t>(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
  MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
  int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
  int64_t num_chunks = divup(inner_size, CHUNK_SIZE);

  // See Note: grain_size value of 0
  at::parallel_for(0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
    // thread local temp buffer which holds vertical reduction result: max and sum.
    auto buffer = std::make_unique<scalar_t []>(CHUNK_SIZE * 2);
    scalar_t* input_max_data = buffer.get();
    scalar_t* tmp_sum_data = buffer.get() + CHUNK_SIZE;

    for (int64_t i = begin; i < end; i++) {
      int64_t outer_idx = i / num_chunks;
      int64_t k = i % num_chunks;
      int64_t inner_idx_begin = k * CHUNK_SIZE;
      int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);

      // init
      Vec zero_vec = Vec(scalar_t(0));
      Vec min_vec = Vec(-std::numeric_limits<scalar_t>::infinity());
      int64_t d0 = 0;
      for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
        min_vec.store(input_max_data + d0);
        zero_vec.store(tmp_sum_data + d0);
      }
      for (; d0 < size; d0++) {
        input_max_data[d0] = -std::numeric_limits<scalar_t>::infinity();
        tmp_sum_data[d0] = scalar_t(0);
      }

      // compute max
      for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
        const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
            + dim_idx * inner_size + inner_idx_begin;

        int64_t d1 = 0;
        for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
          Vec data_vec = Vec::loadu(input_ptr + d1);
          Vec max_vec = Vec::loadu(input_max_data + d1);
          max_vec = Vec::blendv(max_vec, data_vec, data_vec > max_vec);
          max_vec.store(input_max_data + d1);
        }
        for (; d1 < size; d1++) {
          scalar_t data_val = input_ptr[d1];
          scalar_t max_val = input_max_data[d1];
          input_max_data[d1] = data_val > max_val ? data_val : max_val;
        }
      }

      // compute sum of (x - max).exp()
      for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
        const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
            + dim_idx * inner_size + inner_idx_begin;

        int64_t d2 = 0;
        for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
          Vec data_vec = Vec::loadu(input_ptr + d2);
          Vec sum_vec = Vec::loadu(tmp_sum_data + d2);
          Vec max_vec = Vec::loadu(input_max_data + d2);
          sum_vec += (data_vec - max_vec).exp();
          sum_vec.store(tmp_sum_data + d2);
        }
        for (; d2 < size; d2++) {
          scalar_t data_val = input_ptr[d2];
          scalar_t max_val = input_max_data[d2];
          tmp_sum_data[d2] += std::exp(data_val - max_val);
        }
      }

      // apply log
      vec::map([](Vec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);

      // compute x - max - sum
      for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
        int64_t offset = outer_idx * dim_size * inner_size + dim_idx * inner_size + inner_idx_begin;
        const scalar_t* input_ptr = input_data_base + offset;
        scalar_t* output_ptr = output_data_base + offset;

        int64_t d3 = 0;
        for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
          Vec data_vec = Vec::loadu(input_ptr + d3);
          Vec max_vec = Vec::loadu(input_max_data + d3);
          Vec sum_vec = Vec::loadu(tmp_sum_data + d3);
          Vec out_vec = data_vec - max_vec - sum_vec;
          out_vec.store(output_ptr + d3);
        }
        for (; d3 < size; d3++) {
          output_ptr[d3] = input_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3];
        }
      }
    }
  });
}

template<typename scalar_t>
inline typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
_vec_logsoftmax(
    const scalar_t* input_data_base,
    scalar_t* output_data_base,
    int64_t outer_size,
    int64_t inner_size,
    int64_t dim_size) {
  using Vec = vec::Vectorized<scalar_t>;
  using fVec = vec::Vectorized<float>;
  int64_t BLOCK_SIZE = 128 * 1024;
  int64_t MAX_CHUNK_SIZE = std::max<int64_t>(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
  MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
  int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
  int64_t num_chunks = divup(inner_size, CHUNK_SIZE);

  // See Note: grain_size value of 0
  at::parallel_for(0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
    auto buffer = std::make_unique<float []>(CHUNK_SIZE * 2);
    float* input_max_data = buffer.get();
    float* tmp_sum_data = buffer.get() + CHUNK_SIZE;

    // thread local buffer that holds input data in float32 to save the next 2 dtype conversions
    auto input_buffer = std::make_unique<float []>(dim_size * CHUNK_SIZE);
    float* input_buffer_data = input_buffer.get();

    // init
    for (int64_t i = begin; i < end; i++) {
      int64_t outer_idx = i / num_chunks;
      int64_t k = i % num_chunks;
      int64_t inner_idx_begin = k * CHUNK_SIZE;
      int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);

      fVec zero_fvec = fVec(float(0));
      fVec min_fvec = fVec(-std::numeric_limits<float>::infinity());
      int64_t d0 = 0;
      for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
        min_fvec.store(input_max_data + d0);
        min_fvec.store(input_max_data + d0 + fVec::size());
        zero_fvec.store(tmp_sum_data + d0);
        zero_fvec.store(tmp_sum_data + d0 + fVec::size());
      }
      for (; d0 < size; d0++) {
        input_max_data[d0] = -std::numeric_limits<float>::infinity();
        tmp_sum_data[d0] = float(0);
      }

      // compute max
      for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
        const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
            + dim_idx * inner_size + inner_idx_begin;
        float* input_buffer_ptr = input_buffer_data + dim_idx * CHUNK_SIZE;

        int64_t d1 = 0;
        for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
          Vec data_vec = Vec::loadu(input_ptr + d1);
          auto [data_fvec0, data_fvec1] = vec::convert_to_float<scalar_t>(data_vec);
          fVec max_fvec0 = fVec::loadu(input_max_data + d1);
          fVec max_fvec1 = fVec::loadu(input_max_data + d1 + fVec::size());
          max_fvec0 = fVec::blendv(max_fvec0, data_fvec0, data_fvec0 > max_fvec0);
          max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1);
          // each half stores its own running max
          max_fvec0.store(input_max_data + d1);
          max_fvec1.store(input_max_data + d1 + fVec::size());

          // cache the 'converted' float input
          data_fvec0.store(input_buffer_ptr + d1);
          data_fvec1.store(input_buffer_ptr + d1 + fVec::size());
        }
        for (; d1 < size; d1++) {
          float data_val = float(input_ptr[d1]);
          float max_val = input_max_data[d1];
          input_max_data[d1] = data_val > max_val ? data_val : max_val;
          input_buffer_ptr[d1] = data_val;
        }
      }

      // compute sum of (x - max).exp()
      for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
        float* input_buffer_ptr = input_buffer_data + dim_idx * CHUNK_SIZE;

        int64_t d2 = 0;
        for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
          fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d2);
          fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d2 + fVec::size());
          fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2);
          fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size());
          fVec max_fvec0 = fVec::loadu(input_max_data + d2);
          fVec max_fvec1 = fVec::loadu(input_max_data + d2 + fVec::size());
          sum_fvec0 += (data_fvec0 - max_fvec0).exp();
          sum_fvec1 += (data_fvec1 - max_fvec1).exp();
          sum_fvec0.store(tmp_sum_data + d2);
          sum_fvec1.store(tmp_sum_data + d2 + fVec::size());
        }
        for (; d2 < size; d2++) {
          float data_val = input_buffer_ptr[d2];
          float max_val = input_max_data[d2];
          tmp_sum_data[d2] += std::exp(data_val - max_val);
        }
      }

      // apply log
      vec::map([](fVec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);

      // compute x - max - sum
      for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
        float* input_buffer_ptr = input_buffer_data + dim_idx * CHUNK_SIZE;
        scalar_t* output_ptr = output_data_base + outer_idx * dim_size * inner_size
            + dim_idx * inner_size + inner_idx_begin;

        int64_t d3 = 0;
        for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
          fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d3);
          fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d3 + fVec::size());
          fVec max_fvec0 = fVec::loadu(input_max_data + d3);
          fVec max_fvec1 = fVec::loadu(input_max_data + d3 + fVec::size());
          fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d3);
          fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d3 + fVec::size());
          fVec out_fvec0 = data_fvec0 - max_fvec0 - sum_fvec0;
          fVec out_fvec1 = data_fvec1 - max_fvec1 - sum_fvec1;
          Vec out_vec = vec::convert_from_float<scalar_t>(out_fvec0, out_fvec1);
          out_vec.store(output_ptr + d3);
        }
        for (; d3 < size; d3++) {
          output_ptr[d3] = scalar_t(input_buffer_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]);
        }
      }
    }
  });
}

template <typename scalar_t, bool LogSoftMax>
struct vec_softmax {
  static void apply(const Tensor& output, const Tensor& input, int64_t dim) {
    int64_t outer_size = 1;
    int64_t dim_size = input.size(dim);
    int64_t inner_size = 1;
    for (const auto i : c10::irange(dim))
      outer_size *= input.size(i);
    for (int64_t i = dim + 1; i < input.dim(); ++i)
      inner_size *= input.size(i);
    const scalar_t* input_data_base = input.const_data_ptr<scalar_t>();
    scalar_t* output_data_base = output.data_ptr<scalar_t>();
    if (LogSoftMax) {
      _vec_logsoftmax(
          input_data_base, output_data_base, outer_size, inner_size, dim_size);
    } else {
      _vec_softmax(
          input_data_base, output_data_base, outer_size, inner_size, dim_size);
    }
  }
};

template <typename scalar_t, bool LogSoftMax>
struct vec_host_softmax_backward_lastdim {
  static void
  apply(const Tensor& grad_input, const Tensor& grad, const Tensor& output) {
    int64_t outer_size = 1;
    int64_t dim_size = grad.size(grad.ndimension() - 1);
    for (int64_t i = 0; i < grad.ndimension() - 1; ++i)
      outer_size *= grad.size(i);
    scalar_t* grad_input_data_base = grad_input.mutable_data_ptr<scalar_t>();
    const scalar_t* grad_data_base = grad.const_data_ptr<scalar_t>();
    const scalar_t* output_data_base = output.const_data_ptr<scalar_t>();
    _vec_host_softmax_backward_lastdim<scalar_t, LogSoftMax>(
        grad_input_data_base,
        grad_data_base,
        output_data_base,
        outer_size,
        dim_size);
  }
};

template <typename scalar_t, bool LogSoftMax>
struct vec_host_softmax_backward {
  static void apply(
      const Tensor& grad_input,
      const Tensor& grad,
      const Tensor& output,
      int64_t dim) {
    int64_t outer_size = 1;
    int64_t dim_size = grad.size(dim);
    int64_t inner_size = 1;
    for (const auto i : c10::irange(dim)) {
      outer_size *= grad.size(i);
    }
    for (int64_t i = dim + 1; i < grad.dim(); ++i) {
      inner_size *= grad.size(i);
    }
    scalar_t* grad_input_data_base = grad_input.mutable_data_ptr<scalar_t>();
    const scalar_t* grad_output_data_base = grad.const_data_ptr<scalar_t>();
    const scalar_t* output_data_base = output.const_data_ptr<scalar_t>();
    if (LogSoftMax) {
      _vec_log_softmax_backward<scalar_t>(
          grad_input_data_base,
          grad_output_data_base,
          output_data_base,
          outer_size,
          inner_size,
          dim_size);
    } else {
      _vec_softmax_backward<scalar_t>(
          grad_input_data_base,
          grad_output_data_base,
          output_data_base,
          outer_size,
          inner_size,
          dim_size);
    }
  }
};

static void softmax_lastdim_kernel_impl(
    const Tensor& result,
    const Tensor& self) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::BFloat16, at::ScalarType::Half, self.scalar_type(),
      "softmax_lastdim_kernel_impl",
      [&] { vec_host_softmax_lastdim<scalar_t, false>::apply(result, self); });
}

static void softmax_kernel_impl(const Tensor& result, const Tensor& self, int64_t dim) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, self.scalar_type(),
    "softmax_kernel_impl",
    [&] { vec_softmax<scalar_t, false>::apply(result, self, dim); });
}

static void log_softmax_lastdim_kernel_impl(
    const Tensor& result,
    const Tensor& self) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::BFloat16, at::ScalarType::Half, self.scalar_type(),
      "log_softmax_lastdim_kernel_impl",
      [&] { vec_host_softmax_lastdim<scalar_t, true>::apply(result, self); });
}

static void log_softmax_kernel_impl(const Tensor& result, const Tensor& self, int64_t dim) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, self.scalar_type(),
    "log_softmax_kernel_impl",
    [&] { vec_softmax<scalar_t, true>::apply(result, self, dim); });
}

static void softmax_backward_lastdim_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad,
    const Tensor& output) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::BFloat16, at::ScalarType::Half, grad.scalar_type(),
      "softmax_backward_lastdim_kernel_impl", [&] {
        vec_host_softmax_backward_lastdim<scalar_t, false>::apply(
            grad_input, grad, output);
      });
}

static void log_softmax_backward_lastdim_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad,
    const Tensor& output) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::BFloat16, at::ScalarType::Half, grad.scalar_type(),
      "log_softmax_backward_lastdim_kernel_impl", [&] {
        vec_host_softmax_backward_lastdim<scalar_t, true>::apply(
            grad_input, grad, output);
      });
}

static void softmax_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad,
    const Tensor& output,
    int64_t dim) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      grad.scalar_type(),
      "softmax_backward_kernel_impl",
      [&] {
        vec_host_softmax_backward<scalar_t, false>::apply(
            grad_input, grad, output, dim);
      });
}

static void log_softmax_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad,
    const Tensor& output,
    int64_t dim) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      grad.scalar_type(),
      "log_softmax_backward_kernel_impl",
      [&] {
        vec_host_softmax_backward<scalar_t, true>::apply(
            grad_input, grad, output, dim);
      });
}

} // anonymous namespace

ALSO_REGISTER_AVX512_DISPATCH(softmax_lastdim_kernel, &softmax_lastdim_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(log_softmax_lastdim_kernel, &log_softmax_lastdim_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(
    softmax_backward_lastdim_kernel,
    &softmax_backward_lastdim_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(
    log_softmax_backward_lastdim_kernel,
    &log_softmax_backward_lastdim_kernel_impl);

ALSO_REGISTER_AVX512_DISPATCH(softmax_kernel, &softmax_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(log_softmax_kernel, &log_softmax_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(softmax_backward_kernel, &softmax_backward_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(
    log_softmax_backward_kernel,
    &log_softmax_backward_kernel_impl);
} // namespace at::native