1 #pragma once 2 3 #include <ATen/native/DispatchStub.h> 4 #include <c10/util/ArrayRef.h> 5 #include <optional> 6 7 namespace at { 8 9 class Tensor; 10 11 namespace native { 12 13 using mul_sparse_sparse_out_fn = void (*)(Tensor& res, const Tensor& x, const Tensor& y); 14 DECLARE_DISPATCH(mul_sparse_sparse_out_fn, mul_sparse_sparse_out_stub); 15 16 using sparse_mask_intersection_out_fn = void (*)(Tensor& res, const Tensor& x, const Tensor& y, const std::optional<Tensor>& x_hash_opt); 17 DECLARE_DISPATCH(sparse_mask_intersection_out_fn, sparse_mask_intersection_out_stub); 18 19 using sparse_mask_projection_out_fn = void (*)(Tensor& res, const Tensor& x, const Tensor& y, const std::optional<Tensor>& x_hash_opt, bool accumulate_matches); 20 DECLARE_DISPATCH(sparse_mask_projection_out_fn, sparse_mask_projection_out_stub); 21 22 using flatten_indices_fn = Tensor (*)(const Tensor& indices, IntArrayRef size); 23 DECLARE_DISPATCH(flatten_indices_fn, flatten_indices_stub); 24 25 } // namespace native 26 } // namespace at 27