// aten/src/ATen/native/cpu/IndexKernelUtils.h
#pragma once
#include <ATen/native/TensorIterator.h>
#include <c10/util/irange.h>

namespace at::native {

namespace {
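// Returns true when every index tensor operand (operands 2..ntensor-1 of the
// TensorIterator) has a zero stride for this loop, i.e. all elements in the
// chunk read the same index value.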
static bool is_constant_index(int ntensor, const int64_t* strides) {
  AT_ASSERT(ntensor >= 3);
  for (const auto arg : c10::irange(2, ntensor)) {
    if (strides[arg] != 0) {
      return false;
    }
  }
  return true;
}

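// Computes byte offsets into the indexed tensor from one or more index
// tensors. `indexers` points at the raw data of the index tensors,
// `indexer_strides` are their strides for the current loop, and
// `original_sizes`/`original_strides` describe the indexed dimensions.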
struct Indexer {
  Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides,
          IntArrayRef original_sizes, IntArrayRef original_strides)
    : num_indexers(num_indexers)
    , indexers(indexers)
    , indexer_strides(indexer_strides)
    , original_strides(original_strides.data())
    , original_sizes(original_sizes.data()) {
    AT_ASSERT(static_cast<int64_t>(original_strides.size()) == num_indexers);
    AT_ASSERT(static_cast<int64_t>(original_sizes.size()) == num_indexers);
  }

  int64_t num_indexers;
  char** indexers;
  const int64_t* indexer_strides;
  const int64_t* original_strides;
  const int64_t* original_sizes;

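  // Returns the byte offset for element `idx`: reads one int64 index from each
  // index tensor, bounds-checks it against the dimension size, wraps negative
  // values, and accumulates value * stride over all indexed dimensions.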
  int64_t get(int64_t idx) {
    int64_t offset = 0;
    for (const auto j : c10::irange(num_indexers)) {
      int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]];
      int64_t size = original_sizes[j];
      TORCH_CHECK_INDEX(value >= -size && value < size,
                        "index ", value, " is out of bounds for dimension ", j, " with size ", size);
      if (value < 0) {
        value += size;
      }
      offset += value * original_strides[j];
    }
    return offset;
  }
};
} // anonymous namespace

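// Generic CPU loop for index-style ops. The TensorIterator is expected to have
// the two data tensors as operands 0 and 1 and one operand per index tensor
// after that. For each element, `f` receives pointers into operands 0 and 1
// plus the byte offset computed from the index tensors; how that offset is
// applied is up to `f`.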
template <typename scalar_t, typename func_t>
void cpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride,
                      const func_t& f, bool serial_execution=false)
{
  int ntensor = iter.ntensors();
  // When launching the parallel version of the kernel, use a grain size smaller
  // than at::internal::GRAIN_SIZE so the available threads get a more balanced
  // workload and better cache locality. The value was chosen via op benchmarks
  // to amortize the thread-launch overhead.
  const int index_parallel_grain_size = 3000;
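  // Inside the loop, data[0]/strides[0] and data[1]/strides[1] are the two data
  // operands (named dst and src below); data[2..ntensor-1] are the index
  // tensors, which are handed to the Indexer.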
  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride);
    char* dst = data[0];
    char* src = data[1];
    if (is_constant_index(ntensor, strides)) {
      // specialization for when every element uses the same index
      int64_t offset = indexer.get(0);
      for (const auto i : c10::irange(n)) {
        f(dst + strides[0] * i, src + strides[1] * i, offset);
      }
    } else {
      for (const auto i : c10::irange(n)) {
        int64_t offset = indexer.get(i);
        f(dst + strides[0] * i, src + strides[1] * i, offset);
      }
    }
  };
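  // Either run the loop once over the full range on the current thread, or let
  // the TensorIterator parallelize it with the reduced grain size chosen above.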
  if (serial_execution) {
    iter.serial_for_each(loop, {0, iter.numel()});
  } else {
    iter.for_each(loop, index_parallel_grain_size);
  }
}
} // namespace at::native