1 #pragma once 2 3 #include <ATen/core/Tensor.h> 4 #include <ATen/native/DispatchStub.h> 5 #include <ATen/native/ReductionType.h> 6 7 namespace at::native { 8 9 using spmm_reduce_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); 10 using spmm_reduce_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); 11 using spmm_reduce_backward_input_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); 12 using spmm_reduce_backward_input_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); 13 using spmm_reduce_backward_other_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); 14 15 DECLARE_DISPATCH(spmm_reduce_fn, spmm_reduce_stub); 16 DECLARE_DISPATCH(spmm_reduce_arg_fn, spmm_reduce_arg_stub); 17 DECLARE_DISPATCH(spmm_reduce_backward_input_fn, spmm_reduce_backward_input_stub); 18 DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_input_arg_stub); 19 DECLARE_DISPATCH(spmm_reduce_backward_other_fn, spmm_reduce_backward_other_stub); 20 DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_other_arg_stub); 21 22 } // at::native 23