xref: /aosp_15_r20/external/pytorch/aten/src/ATen/TensorIteratorInternal.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/native/TensorIterator.h>
3 #include <c10/util/SmallBuffer.h>
4 #include <c10/util/irange.h>
5 
6 namespace at {
7 
8 struct DimCounter {
9   DimCounter(IntArrayRef shape, Range range);
10 
11   void increment(const std::array<int64_t, 2>& step);
12   bool is_done() const;
13   std::array<int64_t, 2> max_2d_step() const;
14 
15   IntArrayRef shape;
16   Range range;
17   c10::SmallBuffer<int64_t, 4> values;
18   int64_t offset;
19 };
20 
21 namespace internal {
22 
get_data_ptrs(char ** ptrs,ArrayRef<char * > base,IntArrayRef strides,IntArrayRef counter)23 inline void get_data_ptrs(
24     char** ptrs,
25     ArrayRef<char*> base,
26     IntArrayRef strides,
27     IntArrayRef counter) {
28   const auto ntensors = base.size();
29   const auto ndim = counter.size();
30   std::copy(base.begin(), base.end(), ptrs);
31   for (const auto dim : c10::irange(ndim)) {
32     int64_t value = counter[dim];
33     for (const auto arg : c10::irange(ntensors)) {
34       ptrs[arg] += value * strides[dim * ntensors + arg];
35     }
36   }
37 }
38 
serial_for_each(IntArrayRef shape,IntArrayRef strides,char ** base_ptrs,size_t ntensors,typename TensorIteratorBase::loop2d_t loop,Range range)39 inline void serial_for_each(
40     IntArrayRef shape,
41     IntArrayRef strides,
42     char** base_ptrs,
43     size_t ntensors,
44     typename TensorIteratorBase::loop2d_t loop,
45     Range range) {
46   const auto ndim = shape.size();
47   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
48       strides.size() == ntensors * std::max(size_t{2}, ndim));
49 
50   if (ndim <= 1) {
51     if (range.begin == 0) {
52       loop(base_ptrs, strides.data(), range.size(), 1);
53     } else {
54       c10::SmallBuffer<char*, 4> ptrs(ntensors);
55       get_data_ptrs(ptrs.data(), {base_ptrs, ntensors}, strides, {range.begin});
56       loop(ptrs.data(), strides.data(), range.size(), 1);
57     }
58   } else {
59     c10::SmallBuffer<char*, 4> ptrs(ntensors);
60     auto counter = DimCounter(shape, range);
61     while (!counter.is_done()) {
62       get_data_ptrs(
63           ptrs.data(), {base_ptrs, ntensors}, strides, counter.values);
64       auto step = counter.max_2d_step();
65       loop(ptrs.data(), strides.data(), step[0], step[1]);
66       counter.increment(step);
67     }
68   }
69 }
70 
71 } // namespace internal
72 } // namespace at
73