/* Provides the implementations of MKL Sparse BLAS function templates. */ #define TORCH_ASSERT_NO_OPERATORS #include #include namespace at::mkl::sparse { namespace { template MKL_Complex to_mkl_complex(c10::complex scalar) { MKL_Complex mkl_scalar; mkl_scalar.real = scalar.real(); mkl_scalar.imag = scalar.imag(); return mkl_scalar; } } // namespace template <> void create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES(float)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_s_create_csr( A, indexing, rows, cols, rows_start, rows_end, col_indx, values)); } template <> void create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES(double)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_d_create_csr( A, indexing, rows, cols, rows_start, rows_end, col_indx, values)); } template <> void create_csr>( MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_c_create_csr( A, indexing, rows, cols, rows_start, rows_end, col_indx, reinterpret_cast(values))); } template <> void create_csr>( MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_z_create_csr( A, indexing, rows, cols, rows_start, rows_end, col_indx, reinterpret_cast(values))); } template <> void create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES(float)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_s_create_bsr( A, indexing, block_layout, rows, cols, block_size, rows_start, rows_end, col_indx, values)); } template <> void create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES(double)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_d_create_bsr( A, indexing, block_layout, rows, cols, block_size, rows_start, rows_end, col_indx, values)); } template <> void create_bsr>( MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_c_create_bsr( A, indexing, block_layout, rows, cols, block_size, rows_start, rows_end, col_indx, reinterpret_cast(values))); } template <> void create_bsr>( MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_z_create_bsr( A, indexing, block_layout, rows, cols, block_size, rows_start, rows_end, col_indx, reinterpret_cast(values))); } template <> void mv(MKL_SPARSE_MV_ARGTYPES(float)) { TORCH_MKLSPARSE_CHECK( mkl_sparse_s_mv(operation, alpha, A, descr, x, beta, y)); } template <> void mv(MKL_SPARSE_MV_ARGTYPES(double)) { TORCH_MKLSPARSE_CHECK( mkl_sparse_d_mv(operation, alpha, A, descr, x, beta, y)); } template <> void mv>(MKL_SPARSE_MV_ARGTYPES(c10::complex)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_c_mv( operation, to_mkl_complex(alpha), A, descr, reinterpret_cast(x), to_mkl_complex(beta), reinterpret_cast(y))); } template <> void mv>(MKL_SPARSE_MV_ARGTYPES(c10::complex)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_z_mv( operation, to_mkl_complex(alpha), A, descr, reinterpret_cast(x), to_mkl_complex(beta), reinterpret_cast(y))); } template <> void add(MKL_SPARSE_ADD_ARGTYPES(float)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_s_add(operation, A, alpha, B, C)); } template <> void add(MKL_SPARSE_ADD_ARGTYPES(double)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_d_add(operation, A, alpha, B, C)); } template <> void add>(MKL_SPARSE_ADD_ARGTYPES(c10::complex)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_c_add( operation, A, to_mkl_complex(alpha), B, C)); } template <> void add>(MKL_SPARSE_ADD_ARGTYPES(c10::complex)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_z_add( operation, A, to_mkl_complex(alpha), B, C)); } template <> void export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES(float)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_s_export_csr( source, indexing, rows, cols, rows_start, rows_end, col_indx, values)); } template <> void export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES(double)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_d_export_csr( source, indexing, rows, cols, rows_start, rows_end, col_indx, values)); } template <> void export_csr>( MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_c_export_csr( source, indexing, rows, cols, rows_start, rows_end, col_indx, reinterpret_cast(values))); } template <> void export_csr>( MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_z_export_csr( source, indexing, rows, cols, rows_start, rows_end, col_indx, reinterpret_cast(values))); } template <> void mm(MKL_SPARSE_MM_ARGTYPES(float)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_s_mm( operation, alpha, A, descr, layout, B, columns, ldb, beta, C, ldc)); } template <> void mm(MKL_SPARSE_MM_ARGTYPES(double)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_d_mm( operation, alpha, A, descr, layout, B, columns, ldb, beta, C, ldc)); } template <> void mm>(MKL_SPARSE_MM_ARGTYPES(c10::complex)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_c_mm( operation, to_mkl_complex(alpha), A, descr, layout, reinterpret_cast(B), columns, ldb, to_mkl_complex(beta), reinterpret_cast(C), ldc)); } template <> void mm>(MKL_SPARSE_MM_ARGTYPES(c10::complex)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_z_mm( operation, to_mkl_complex(alpha), A, descr, layout, reinterpret_cast(B), columns, ldb, to_mkl_complex(beta), reinterpret_cast(C), ldc)); } template <> void spmmd(MKL_SPARSE_SPMMD_ARGTYPES(float)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_s_spmmd( operation, A, B, layout, C, ldc)); } template <> void spmmd(MKL_SPARSE_SPMMD_ARGTYPES(double)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_d_spmmd( operation, A, B, layout, C, ldc)); } template <> void spmmd>(MKL_SPARSE_SPMMD_ARGTYPES(c10::complex)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_c_spmmd( operation, A, B, layout, reinterpret_cast(C), ldc)); } template <> void spmmd>(MKL_SPARSE_SPMMD_ARGTYPES(c10::complex)) { TORCH_MKLSPARSE_CHECK(mkl_sparse_z_spmmd( operation, A, B, layout, reinterpret_cast(C), ldc)); } template <> sparse_status_t trsv(MKL_SPARSE_TRSV_ARGTYPES(float)) { sparse_status_t status = mkl_sparse_s_trsv(operation, alpha, A, descr, x, y); TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_s_trsv"); return status; } template <> sparse_status_t trsv(MKL_SPARSE_TRSV_ARGTYPES(double)) { sparse_status_t status = mkl_sparse_d_trsv(operation, alpha, A, descr, x, y); TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_d_trsv"); return status; } template <> sparse_status_t trsv>(MKL_SPARSE_TRSV_ARGTYPES(c10::complex)) { sparse_status_t status = mkl_sparse_c_trsv( operation, to_mkl_complex(alpha), A, descr, reinterpret_cast(x), reinterpret_cast(y)); TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_c_trsv"); return status; } template <> sparse_status_t trsv>( MKL_SPARSE_TRSV_ARGTYPES(c10::complex)) { sparse_status_t status = mkl_sparse_z_trsv( operation, to_mkl_complex(alpha), A, descr, reinterpret_cast(x), reinterpret_cast(y)); TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_z_trsv"); return status; } template <> sparse_status_t trsm(MKL_SPARSE_TRSM_ARGTYPES(float)) { sparse_status_t status = mkl_sparse_s_trsm( operation, alpha, A, descr, layout, x, columns, ldx, y, ldy); TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_s_trsm"); return status; } template <> sparse_status_t trsm(MKL_SPARSE_TRSM_ARGTYPES(double)) { sparse_status_t status = mkl_sparse_d_trsm( operation, alpha, A, descr, layout, x, columns, ldx, y, ldy); TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_d_trsm"); return status; } template <> sparse_status_t trsm>(MKL_SPARSE_TRSM_ARGTYPES(c10::complex)) { sparse_status_t status = mkl_sparse_c_trsm( operation, to_mkl_complex(alpha), A, descr, layout, reinterpret_cast(x), columns, ldx, reinterpret_cast(y), ldy); TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_c_trsm"); return status; } template <> sparse_status_t trsm>( MKL_SPARSE_TRSM_ARGTYPES(c10::complex)) { sparse_status_t status = mkl_sparse_z_trsm( operation, to_mkl_complex(alpha), A, descr, layout, reinterpret_cast(x), columns, ldx, reinterpret_cast(y), ldy); TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_z_trsm"); return status; } } // namespace at::mkl::sparse