// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/SparseCsrTensorUtils.h>
4 #include <ATen/Tensor.h>
5 #include <ATen/core/Scalar.h>
6 
7 namespace at::native::sparse::impl::cuda {
8 
9 void addmm_out_sparse_csr(
10     const Tensor& input,
11     const at::sparse_csr::SparseCsrTensor& mat1,
12     const Tensor& mat2,
13     const Scalar& beta,
14     const Scalar& alpha,
15     const Tensor& result);
16 
17 void addmv_out_sparse_csr(
18     const at::sparse_csr::SparseCsrTensor& mat,
19     const Tensor& vec,
20     const Scalar& beta,
21     const Scalar& alpha,
22     const Tensor& result);
23 
24 void add_out_sparse_csr(
25     const at::sparse_csr::SparseCsrTensor& mat1,
26     const at::sparse_csr::SparseCsrTensor& mat2,
27     const Scalar& alpha,
28     const Scalar& beta,
29     const at::sparse_csr::SparseCsrTensor& result);
30 
31 void triangular_solve_out_sparse_csr(
32     const at::sparse_csr::SparseCsrTensor& A,
33     const Tensor& B,
34     const Tensor& X,
35     bool upper,
36     bool transpose,
37     bool unitriangular);
38 
39 void sampled_addmm_out_sparse_csr(
40     const Tensor& mat1,
41     const Tensor& mat2,
42     const Scalar& beta,
43     const Scalar& alpha,
44     const at::sparse_csr::SparseCsrTensor& result);
45 
46 } // namespace at::native::sparse::impl::cuda
47