xref: /aosp_15_r20/external/pytorch/test/cpp_extensions/cublas_extension.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <iostream>
2 
3 #include <torch/extension.h>
4 #include <ATen/cuda/CUDAContext.h>
5 
6 #include <cublas_v2.h>
7 
noop_cublas_function(torch::Tensor x)8 torch::Tensor noop_cublas_function(torch::Tensor x) {
9   cublasHandle_t handle;
10   TORCH_CUDABLAS_CHECK(cublasCreate(&handle));
11   TORCH_CUDABLAS_CHECK(cublasDestroy(handle));
12   return x;
13 }
14 
// Python bindings for this extension: exposes noop_cublas_function under the
// same name, with the docstring "a cublas function".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("noop_cublas_function", noop_cublas_function, "a cublas function");
}
18