1 #pragma once 2 3 #include <ATen/SparseCsrTensorUtils.h> 4 #include <ATen/Tensor.h> 5 #include <ATen/core/Scalar.h> 6 7 namespace at::native::sparse::impl::cuda { 8 9 void addmm_out_sparse_csr( 10 const Tensor& input, 11 const at::sparse_csr::SparseCsrTensor& mat1, 12 const Tensor& mat2, 13 const Scalar& beta, 14 const Scalar& alpha, 15 const Tensor& result); 16 17 void addmv_out_sparse_csr( 18 const at::sparse_csr::SparseCsrTensor& mat, 19 const Tensor& vec, 20 const Scalar& beta, 21 const Scalar& alpha, 22 const Tensor& result); 23 24 void add_out_sparse_csr( 25 const at::sparse_csr::SparseCsrTensor& mat1, 26 const at::sparse_csr::SparseCsrTensor& mat2, 27 const Scalar& alpha, 28 const Scalar& beta, 29 const at::sparse_csr::SparseCsrTensor& result); 30 31 void triangular_solve_out_sparse_csr( 32 const at::sparse_csr::SparseCsrTensor& A, 33 const Tensor& B, 34 const Tensor& X, 35 bool upper, 36 bool transpose, 37 bool unitriangular); 38 39 void sampled_addmm_out_sparse_csr( 40 const Tensor& mat1, 41 const Tensor& mat2, 42 const Scalar& beta, 43 const Scalar& alpha, 44 const at::sparse_csr::SparseCsrTensor& result); 45 46 } // namespace at::native::sparse::impl::cuda 47