xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/ValidateCompressedIndicesKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/sparse/ValidateCompressedIndicesCommon.h>
3 #include <ATen/native/cuda/Loops.cuh>
4 
5 namespace at::native {
6 
namespace {

// Launcher policy handed to validate_compressed_sparse_indices_kernel as its
// template-template argument (see _validate_compressed_sparse_indices_cuda).
// The shared validation code instantiates it with its own functor type and
// calls the static launch() hook; this CUDA flavour simply forwards the
// iterator and functor to gpu_kernel from ATen/native/cuda/Loops.cuh.
// NOTE(review): presumably gpu_kernel runs `fn` element-wise on the device
// over `it` — confirm against Loops.cuh.
template <class func_t>
class CUDAKernelLauncher {
 public:
  static void launch(TensorIteratorBase& it, const func_t& fn) {
    gpu_kernel(it, fn);
  }
};

} // namespace
17 
// CUDA entry point for compressed-sparse index validation.
//
// All arguments are passed through unchanged to the backend-agnostic
// validate_compressed_sparse_indices_kernel (declared in
// ATen/native/sparse/ValidateCompressedIndicesCommon.h). It is instantiated
// with CUDAKernelLauncher, so its element-wise checks go through gpu_kernel.
//
// Parameters (meanings inferred from names — NOTE(review): confirm against
// the common header):
//   is_crow - presumably true for row-compressed (CSR/BSR) layouts and false
//             for column-compressed ones
//   cidx    - the compressed indices tensor
//   idx     - the plain indices tensor
//   cdim    - size of the compressed dimension
//   dim     - size of the plain dimension
//   nnz     - number of specified elements
void _validate_compressed_sparse_indices_cuda(
    const bool is_crow,
    const Tensor& cidx,
    const Tensor& idx,
    const int64_t cdim,
    const int64_t dim,
    const int64_t nnz) {
  validate_compressed_sparse_indices_kernel<CUDAKernelLauncher>(
      is_crow,
      cidx,
      idx,
      cdim,
      dim,
      nnz);
}
28 
29 } // namespace at::native
30