1 #pragma once
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>

#include <vector>
4
5 namespace at::native {
6 //input tensors are non-zero dim and non-empty
7 template<typename T1, typename T2, typename Function>
8
tensor_dim_apply3(const Tensor & self,Tensor & values,Tensor & indices,int64_t dim,Function func)9 void tensor_dim_apply3(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim, Function func) {
10 int ndims = self.dim();
11 int tensor_dim_apply_has_finished = 0;
12 std::vector<int64_t> counter(ndims, 0);
13 const T1* self_data = self.const_data_ptr<T1>();
14 T1* values_data = values.data_ptr<T1>();
15 T2* indices_data = indices.data_ptr<T2>();
16 int64_t self_stride = self.stride(dim);
17 int64_t values_stride = values.stride(dim);
18 int64_t indices_stride = indices.stride(dim);
19 int self_dim_size = self.size(dim);
20
21 while (!tensor_dim_apply_has_finished) {
22 func(self_data, values_data, indices_data, self_dim_size, self_stride, values_stride, indices_stride);
23 if (ndims == 1) {
24 break;
25 }
26 for (const auto dim_i : c10::irange(ndims)) {
27 if (dim_i == dim) {
28 if (dim_i == (ndims - 1)) {
29 tensor_dim_apply_has_finished = 1;
30 break;
31 }
32 continue;
33 }
34 counter[dim_i]++;
35 self_data += self.stride(dim_i);
36 values_data += values.stride(dim_i);
37 indices_data += indices.stride(dim_i);
38
39 if (counter[dim_i] == self.size(dim_i)) {
40 if (dim_i == ndims-1) {
41 tensor_dim_apply_has_finished = 1;
42 break;
43 } else {
44 self_data -= counter[dim_i]*self.stride(dim_i);
45 values_data -= counter[dim_i]*values.stride(dim_i);
46 indices_data -= counter[dim_i]*indices.stride(dim_i);
47 counter[dim_i] = 0;
48 }
49 } else {
50 break;
51 }
52 }
53 }
54 }
55 } // namespace at::native
56