// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/FlattenIndicesCommon.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/Tensor.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/Dispatch.h>
#include <ATen/native/sparse/Macros.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/SparseTensorUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/tensor.h>
#endif

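// GPUCC, FUNCAPI and RESTRICT are provided by ATen/native/sparse/Macros.h;
// GPUCC is only defined when compiling under CUDA/HIP, so NAME below resolves
// to a backend-specific label used in AT_DISPATCH_INDEX_TYPES and in error
// messages.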
#ifdef GPUCC
#define NAME "flatten_indices_cuda"
#else
#define NAME "flatten_indices_cpu"
#endif

namespace at::native {

namespace {

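// Adapts a backend-provided kernel template (e.g. a CPU or CUDA
// TensorIterator loop) to a uniform static launch() interface so that the
// flattening logic below stays backend-agnostic.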
template <template <typename func_t> class kernel_t>
struct KernelLauncher {
  template <typename func_t>
  static void launch(TensorIteratorBase& iter, const func_t& f) {
    kernel_t<func_t>::launch(iter, f);
  }
};

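// Computes, for every column i of `indices` (shape (sparse_dim, nnz)), its
// linear row-major offset within `size`:
//   flat[i] = sum_d indices[d][i] * contiguous_strides(size)[d].
// Example: size = (4, 3) gives strides (3, 1), so the index (2, 1) flattens
// to 2 * 3 + 1 * 1 = 7.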
template <
  template <typename func_t> class kernel_t,
  typename index_t,
  int64_t max_static_len = 0>
Tensor _flatten_indices_impl(const Tensor& indices, IntArrayRef size) {
  TORCH_INTERNAL_ASSERT(indices.dim() > 1 && static_cast<size_t>(indices.size(0)) == size.size());

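  // hash_coeffs are the row-major (contiguous) strides of `size`. The
  // TensorGeometryHolder is assumed to keep them either in fixed-size
  // on-stack storage (when max_static_len > 0) or in tensor-backed storage,
  // so they can be safely captured by value in the device lambda below.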
  // Need owning storage in case of the Tensor class.
  const auto hash_coeffs_storage = [&]() -> auto {
    auto strides = c10::contiguous_strides(size);
    return at::sparse::TensorGeometryHolder<max_static_len>(strides, strides, indices.options());
  }();
  const auto hash_coeffs = std::get<0>(*hash_coeffs_storage);

  const auto hash_indices = [&]() -> Tensor {
    // non-const because of gcc-5/clang-5 issues
    auto sparse_dim = indices.size(0);
    auto indices_dim_stride = indices.stride(0);
    auto indices_nnz_stride = indices.stride(1);

    // `hash` doubles as the kernel input (its arange values are the nnz
    // indices passed to the lambda) and as the output buffer, hence the
    // disabled memory-overlap check.
    auto hash = at::arange(indices.size(1), indices.options().dtype(kLong));

    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .add_output(hash)
      .add_input(hash)
      .build();

    {
      const auto* RESTRICT ptr_indices = indices.const_data_ptr<index_t>();

      KernelLauncher<kernel_t>::launch(iter,
          // NOTE: capture by value required by CUDA
          [=] FUNCAPI (int64_t nnz_idx) -> int64_t {
          const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
          auto hash = static_cast<int64_t>(0);
          for (int64_t dim = 0; dim < sparse_dim; ++dim) {
            const auto dim_hash_coeff = hash_coeffs[dim];
            const auto dim_index = ptr_indices_dim[dim * indices_dim_stride];
            hash += dim_index * dim_hash_coeff;
          }
          return hash;
      });
    }

    return hash;
  }();

  return hash_indices;
}

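// Entry point for backends: validates that `indices` has shape
// (sparse_dim, nnz) with sparse_dim == size.size(), dispatches on the index
// dtype, and takes the statically sized path whenever sparse_dim fits within
// max_sparse_dims.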
template <template <typename func_t> class kernel_t>
Tensor _flatten_indices(const Tensor& indices, IntArrayRef size) {
  TORCH_CHECK(indices.dim() > 1 && static_cast<size_t>(indices.size(0)) == size.size(),
      NAME, "(): the dimensionality of sparse `indices` and the length of `size` must match. ",
            "Got `indices.size(0) == ", indices.size(0), "` != `size.size() == ", size.size(), "`.");
  Tensor flattened_indices;
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), NAME, [&] () {
    constexpr int64_t max_sparse_dims = 8;
    if (indices.size(0) <= max_sparse_dims) {
      flattened_indices = _flatten_indices_impl<kernel_t, index_t, max_sparse_dims>(indices, size);
    } else {
      flattened_indices = _flatten_indices_impl<kernel_t, index_t>(indices, size);
    }
  });
  return flattened_indices;
}

} // anonymous namespace

} // at::native
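
// Usage sketch (not part of this header; illustrative only): a backend
// provides a launcher template that forwards to its TensorIterator loop and
// instantiates _flatten_indices with it. The names `CpuLauncher` and
// `flatten_indices_cpu` below are placeholders, not the actual kernel files;
// cpu_kernel comes from ATen/native/cpu/Loops.h.
//
//   template <typename func_t>
//   struct CpuLauncher {
//     static void launch(at::TensorIteratorBase& iter, const func_t& f) {
//       at::native::cpu_kernel(iter, f);
//     }
//   };
//
//   at::Tensor flatten_indices_cpu(const at::Tensor& indices, at::IntArrayRef size) {
//     return at::native::_flatten_indices<CpuLauncher>(indices, size);
//   }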