xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/TensorDimApply.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 #include <c10/util/irange.h>
4 
5 namespace at::native {
6 //input tensors are non-zero dim and non-empty
7 template<typename T1, typename T2, typename Function>
8 
tensor_dim_apply3(const Tensor & self,Tensor & values,Tensor & indices,int64_t dim,Function func)9 void tensor_dim_apply3(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim, Function func) {
10   int ndims = self.dim();
11   int tensor_dim_apply_has_finished = 0;
12   std::vector<int64_t> counter(ndims, 0);
13   const T1* self_data = self.const_data_ptr<T1>();
14   T1* values_data = values.data_ptr<T1>();
15   T2* indices_data = indices.data_ptr<T2>();
16   int64_t self_stride = self.stride(dim);
17   int64_t values_stride = values.stride(dim);
18   int64_t indices_stride = indices.stride(dim);
19   int self_dim_size = self.size(dim);
20 
21   while (!tensor_dim_apply_has_finished) {
22     func(self_data, values_data, indices_data, self_dim_size, self_stride, values_stride, indices_stride);
23     if (ndims == 1) {
24        break;
25     }
26     for (const auto dim_i : c10::irange(ndims)) {
27       if (dim_i == dim) {
28         if (dim_i == (ndims - 1)) {
29           tensor_dim_apply_has_finished = 1;
30           break;
31         }
32         continue;
33       }
34       counter[dim_i]++;
35       self_data += self.stride(dim_i);
36       values_data += values.stride(dim_i);
37       indices_data += indices.stride(dim_i);
38 
39       if (counter[dim_i] == self.size(dim_i)) {
40         if (dim_i == ndims-1) {
41           tensor_dim_apply_has_finished = 1;
42           break;
43         } else {
44           self_data -= counter[dim_i]*self.stride(dim_i);
45           values_data -= counter[dim_i]*values.stride(dim_i);
46           indices_data -= counter[dim_i]*indices.stride(dim_i);
47           counter[dim_i] = 0;
48         }
49       } else {
50         break;
51      }
52     }
53   }
54 }
55 } // namespace at::native
56