xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/SparseStubs.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/native/DispatchStub.h>
4 #include <c10/util/ArrayRef.h>
5 #include <optional>
6 
7 namespace at {
8 
9 class Tensor;
10 
11 namespace native {
12 
13 using mul_sparse_sparse_out_fn = void (*)(Tensor& res, const Tensor& x, const Tensor& y);
14 DECLARE_DISPATCH(mul_sparse_sparse_out_fn, mul_sparse_sparse_out_stub);
15 
16 using sparse_mask_intersection_out_fn = void (*)(Tensor& res, const Tensor& x, const Tensor& y, const std::optional<Tensor>& x_hash_opt);
17 DECLARE_DISPATCH(sparse_mask_intersection_out_fn, sparse_mask_intersection_out_stub);
18 
19 using sparse_mask_projection_out_fn = void (*)(Tensor& res, const Tensor& x, const Tensor& y, const std::optional<Tensor>& x_hash_opt, bool accumulate_matches);
20 DECLARE_DISPATCH(sparse_mask_projection_out_fn, sparse_mask_projection_out_stub);
21 
22 using flatten_indices_fn = Tensor (*)(const Tensor& indices, IntArrayRef size);
23 DECLARE_DISPATCH(flatten_indices_fn, flatten_indices_stub);
24 
25 } // namespace native
26 } // namespace at
27