// LinearAlgebraStubs.cpp
// Mostly a no-op unless BUILD_LAZY_CUDA_LINALG is defined.
// In that case the linalg library is loaded dynamically when the first linalg call is made.
// This helps reduce the size of the GPU memory context if linear algebra functions are not used.
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/DynamicLibrary.h>
#include <ATen/native/cuda/MiscUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/LinearAlgebra.h>
#include <ATen/native/BatchLinearAlgebra.h>
#include <ATen/native/TransposeType.h>
#if defined(BUILD_LAZY_CUDA_LINALG)
#include <ATen/native/cuda/linalg/BatchLinearAlgebraLib.h>

#if AT_MAGMA_ENABLED()
#include <ATen/cuda/detail/CUDAHooks.h>

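// When MAGMA is enabled, install a no-op MAGMA initializer at static
// initialization time; the assumption is that the lazily loaded
// libtorch_cuda_linalg library performs the real MAGMA setup when it takes
// over the dispatch.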
namespace {
struct MagmaInitializer {
  MagmaInitializer() {
    ::at::cuda::detail::set_magma_init_fn([]{ });
  }
} initializer;
}  // namespace (anonymous)
#endif
#endif
namespace at::native {
#if defined(BUILD_LAZY_CUDA_LINALG)
namespace {
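// Initialize the dispatch table to point back at the stub defined in this
// file; as described in the "Old style dispatches" comment below, this lets
// _cholesky_solve_helper_cuda detect whether the real implementation has
// been registered.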
cuda::detail::LinalgDispatch disp = {_cholesky_solve_helper_cuda};

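// Opens libtorch_cuda_linalg.so on first call and keeps the handle alive for
// the lifetime of the process (the third DynamicLibrary argument requests
// that the handle be leaked rather than closed on destruction).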
at::DynamicLibrary& getTorchLinalgLibrary() {
  static at::DynamicLibrary lib("libtorch_cuda_linalg.so", nullptr, true);
  return lib;
}

// Lazy dispatches do nothing but load the linalg library and call the stub.
// Loading the library should override the registration of these stubs with the proper implementations.
// getTorchLinalgLibrary() throws an exception if the library is not found,
// which makes explicit error checking unnecessary.
// But make sure that this function is called only once, to avoid infinite recursion.
void loadLazyTorchLinalgLibrary() {
  static int invoke_count = 0;
  getTorchLinalgLibrary();
  TORCH_CHECK(invoke_count++ == 0, "lazy wrapper should be called at most once");
}

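// Example flow: the first CUDA Cholesky call lands in lazy_cholesky_kernel,
// which dlopens libtorch_cuda_linalg.so; the library's static initializers
// re-register cholesky_stub with the real kernel, so the cholesky_stub call
// inside the wrapper then dispatches to the real implementation.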
void lazy_cholesky_kernel(const Tensor& input, const Tensor& info, bool upper) {
  loadLazyTorchLinalgLibrary();
  cholesky_stub(DeviceType::CUDA, input, info, upper);
}

Tensor& lazy_cholesky_inverse_kernel(Tensor& result, Tensor& infos, bool upper) {
  loadLazyTorchLinalgLibrary();
  return cholesky_inverse_stub(DeviceType::CUDA, result, infos, upper);
}

void lazy_lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
  loadLazyTorchLinalgLibrary();
  lu_factor_stub(DeviceType::CUDA, input, pivots, infos, compute_pivots);
}

void lazy_triangular_solve_kernel(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
  loadLazyTorchLinalgLibrary();
  triangular_solve_stub(DeviceType::CUDA, A, B, left, upper, transpose, unitriangular);
}

Tensor& lazy_orgqr_kernel(Tensor& result, const Tensor& tau) {
  loadLazyTorchLinalgLibrary();
  return orgqr_stub(DeviceType::CUDA, result, tau);
}

void lazy_ormqr_kernel(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
  loadLazyTorchLinalgLibrary();
  ormqr_stub(DeviceType::CUDA, input, tau, other, left, transpose);
}

void lazy_geqrf_kernel(const Tensor& input, const Tensor& tau) {
  loadLazyTorchLinalgLibrary();
  geqrf_stub(DeviceType::CUDA, input, tau);
}

void lazy_linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
  loadLazyTorchLinalgLibrary();
  linalg_eigh_stub(DeviceType::CUDA, eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
}

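// Note: the eig, svd, lu_solve, and lstsq wrappers below call
// getTorchLinalgLibrary() directly, without the invoke-count guard; they rely
// on the loaded library re-registering these stubs to avoid recursing.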
void lazy_linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors) {
  getTorchLinalgLibrary();
  linalg_eig_stub(DeviceType::CUDA, eigenvalues, eigenvectors, infos, input, compute_eigenvectors);
}

void lazy_svd_kernel(const Tensor& A,
                     const bool full_matrices,
                     const bool compute_uv,
                     const std::optional<c10::string_view>& driver,
                     const Tensor& U,
                     const Tensor& S,
                     const Tensor& Vh,
                     const Tensor& info) {
  getTorchLinalgLibrary();
  svd_stub(DeviceType::CUDA, A, full_matrices, compute_uv, driver, U, S, Vh, info);
}

void lazy_lu_solve(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
  getTorchLinalgLibrary();
  lu_solve_stub(DeviceType::CUDA, LU, pivots, B, trans);
}

void lazy_lstsq_kernel(const Tensor& a, Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, double rcond, std::string driver_name) {
  getTorchLinalgLibrary();
  lstsq_stub(DeviceType::CUDA, a, b, rank, singular_values, infos, rcond, driver_name);
}

void lazy_ldl_factor(
    const Tensor& LD,
    const Tensor& pivots,
    const Tensor& info,
    bool upper,
    bool hermitian) {
  loadLazyTorchLinalgLibrary();
  ldl_factor_stub(DeviceType::CUDA, LD, pivots, info, upper, hermitian);
}

void lazy_ldl_solve(
    const Tensor& LD,
    const Tensor& pivots,
    const Tensor& B,
    bool upper,
    bool hermitian) {
  loadLazyTorchLinalgLibrary();
  ldl_solve_stub(DeviceType::CUDA, LD, pivots, B, upper, hermitian);
}

REGISTER_CUDA_DISPATCH(cholesky_stub, &lazy_cholesky_kernel);
REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &lazy_cholesky_inverse_kernel);
REGISTER_CUDA_DISPATCH(lu_factor_stub, &lazy_lu_factor);
REGISTER_CUDA_DISPATCH(ldl_factor_stub, &lazy_ldl_factor);
REGISTER_CUDA_DISPATCH(ldl_solve_stub, &lazy_ldl_solve);
REGISTER_CUDA_DISPATCH(triangular_solve_stub, &lazy_triangular_solve_kernel);
REGISTER_CUDA_DISPATCH(orgqr_stub, &lazy_orgqr_kernel);
REGISTER_CUDA_DISPATCH(ormqr_stub, &lazy_ormqr_kernel);
REGISTER_CUDA_DISPATCH(geqrf_stub, &lazy_geqrf_kernel);
REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &lazy_linalg_eigh_kernel);
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &lazy_linalg_eig_kernel);
REGISTER_CUDA_DISPATCH(svd_stub, &lazy_svd_kernel);
REGISTER_CUDA_DISPATCH(lu_solve_stub, &lazy_lu_solve);
REGISTER_CUDA_DISPATCH(lstsq_stub, &lazy_lstsq_kernel);
} // anonymous namespace

// Old style dispatches
// The torch_cuda_linalg dynamic library should have a global constructor
// that calls registerLinalgDispatch, so that in order to lazily bind an
// old style dispatch all one has to do is load the library and call disp.func_name.
// Protect from infinite recursion by initializing the dispatch to self and checking
// that the values are different after the linalg library has been loaded.

namespace cuda {
namespace detail {
void registerLinalgDispatch(const LinalgDispatch& disp_) {
  disp = disp_;
}
}} // namespace cuda::detail
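
// For illustration, a minimal sketch of the global constructor that
// libtorch_cuda_linalg.so is expected to provide (hypothetical names; the
// real registration code lives in that library, not here):
//
//   namespace {
//   struct LinalgDispatchInitializer {
//     LinalgDispatchInitializer() {
//       cuda::detail::registerLinalgDispatch(
//           {_cholesky_solve_helper_cuda_impl});  // the real implementation
//     }
//   } linalg_dispatch_initializer;
//   }  // namespace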

Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upper) {
  getTorchLinalgLibrary();
  TORCH_CHECK(disp.cholesky_solve_helper != _cholesky_solve_helper_cuda, "Can't find _cholesky_solve_helper_cuda");
  return disp.cholesky_solve_helper(self, A, upper);
}

#endif /*defined(BUILD_LAZY_CUDA_LINALG)*/

} // namespace at::native