1 #pragma once
2 #include <ATen/native/TensorIterator.h>
3 #include <c10/util/irange.h>
4
5 namespace at::native {
6
7 namespace {
is_constant_index(int ntensor,const int64_t * strides)8 static bool is_constant_index(int ntensor, const int64_t* strides) {
9 AT_ASSERT(ntensor >= 3);
10 for (const auto arg : c10::irange(2, ntensor)) {
11 if (strides[arg] != 0) {
12 return false;
13 }
14 }
15 return true;
16 }
17
18
19 struct Indexer {
IndexerIndexer20 Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides,
21 IntArrayRef original_sizes, IntArrayRef original_strides)
22 : num_indexers(num_indexers)
23 , indexers(indexers)
24 , indexer_strides(indexer_strides)
25 , original_strides(original_strides.data())
26 , original_sizes(original_sizes.data()) {
27 AT_ASSERT(static_cast<int64_t>(original_strides.size()) == num_indexers);
28 AT_ASSERT(static_cast<int64_t>(original_sizes.size()) == num_indexers);
29 }
30
31 int64_t num_indexers;
32 char** indexers;
33 const int64_t* indexer_strides;
34 const int64_t* original_strides;
35 const int64_t* original_sizes;
36
getIndexer37 int64_t get(int64_t idx) {
38 int64_t offset = 0;
39 for (const auto j : c10::irange(num_indexers)) {
40 int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]];
41 int64_t size = original_sizes[j];
42 TORCH_CHECK_INDEX(value >= -size && value < size,
43 "index ", value, " is out of bounds for dimension ", j, " with size ", size);
44 if (value < 0) {
45 value += size;
46 }
47 offset += value * original_strides[j];
48 }
49 return offset;
50 }
51 };
52 } // anonymous namespace
53
54 template <typename scalar_t, typename func_t>
55 void cpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride,
56 const func_t& f, bool serial_execution=false)
57 {
58 int ntensor = iter.ntensors();
59 // When launch the index parallel version, set a relative small grain size less than the INTERNAL::GRAIN_SIZE
60 // to make the whole available thread numbers get more balanced work load and a better cache location.
61 // The grain size here is chosen by the op benchmark to overcome the thread launch overhead
62 const int index_parallel_grain_size = 3000;
63 auto loop = [&](char** data, const int64_t* strides, int64_t n) {
64 auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride);
65 char* dst = data[0];
66 char* src = data[1];
67 if (is_constant_index(ntensor, strides)) {
68 // specialization for when every element uses the same index
69 int64_t offset = indexer.get(0);
70 for (const auto i : c10::irange(n)) {
71 f(dst + strides[0] * i, src + strides[1] * i, offset);
72 }
73 } else {
74 for (const auto i : c10::irange(n)) {
75 int64_t offset = indexer.get(i);
76 f(dst + strides[0] * i, src + strides[1] * i, offset);
77 }
78 }
79 };
80 if (serial_execution) {
81 iter.serial_for_each(loop, {0, iter.numel()});
82 } else {
83 iter.for_each(loop, index_parallel_grain_size);
84 }
85 }
86 } // at
87 // native
88