// LinearAlgebraStubs.cpp
// Mostly a no-op unless BUILD_LAZY_CUDA_LINALG is defined.
// In that case the library is dynamically loaded when the first linalg call is made.
// This helps reduce the size of the GPU memory context if linear algebra functions are not used.
5 #include <ATen/Context.h>
6 #include <ATen/cuda/CUDAContext.h>
7 #include <ATen/cuda/CUDAConfig.h>
8 #include <ATen/NativeFunctions.h>
9 #include <ATen/Dispatch.h>
10 #include <ATen/DynamicLibrary.h>
11 #include <ATen/NativeFunctions.h>
12 #include <ATen/native/cuda/MiscUtils.h>
13 #include <ATen/native/Resize.h>
14 #include <ATen/native/LinearAlgebra.h>
15 #include <ATen/native/BatchLinearAlgebra.h>
16 #include <ATen/native/TransposeType.h>
17 #if defined(BUILD_LAZY_CUDA_LINALG)
18 #include <ATen/native/cuda/linalg/BatchLinearAlgebraLib.h>
19
20 #if AT_MAGMA_ENABLED()
21 #include <ATen/cuda/detail/CUDAHooks.h>
22
namespace {
// Static-initialization-time hook: installs a no-op MAGMA init function so
// that MAGMA initialization is effectively deferred — the real initialization
// happens inside libtorch_cuda_linalg.so once it is lazily loaded.
struct MagmaInitializer {
  MagmaInitializer() {
    ::at::cuda::detail::set_magma_init_fn([]{ });
  };
} initializer;
}  // namespace (anonymous)
30 #endif
31 #endif
32 namespace at::native {
33 #if defined(BUILD_LAZY_CUDA_LINALG)
34 namespace {
35 cuda::detail::LinalgDispatch disp = {_cholesky_solve_helper_cuda};
36
// Returns the process-wide handle to libtorch_cuda_linalg.so, loading it on
// first call (function-local static). DynamicLibrary throws if the library
// cannot be found, so a successful return means the library — whose static
// initializers should re-register the real kernels — has been loaded.
at::DynamicLibrary& getTorchLinalgLibrary() {
  static at::DynamicLibrary lib("libtorch_cuda_linalg.so", nullptr, true);
  return lib;
}
41
// Lazy dispatches do nothing but load the linalg library and call the stub.
// Loading the library should override the stub registrations with the proper
// implementations. getTorchLinalgLibrary() throws an exception if the library
// is not found, which makes explicit error checking unnecessary.
// But make sure this function is called only once: if the loaded library
// failed to override a stub, a second invocation would mean the lazy wrapper
// is calling itself — report that instead of recursing infinitely.
void loadLazyTorchLinalgLibrary() {
  static int invoke_count = 0;
  getTorchLinalgLibrary();
  TORCH_CHECK(invoke_count++ == 0, "lazy wrapper should be called at most once");
}
52
// Lazy stub for cholesky_stub: loads libtorch_cuda_linalg.so (which should
// re-register the real kernel) and forwards this first call to the stub.
void lazy_cholesky_kernel(const Tensor& input, const Tensor& info, bool upper) {
  loadLazyTorchLinalgLibrary();
  cholesky_stub(DeviceType::CUDA, input, info, upper);
}
57
// Lazy stub for cholesky_inverse_stub; loads the linalg library, then
// forwards to the (now re-registered) real kernel and returns its result.
Tensor& lazy_cholesky_inverse_kernel(Tensor &result, Tensor& infos, bool upper) {
  loadLazyTorchLinalgLibrary();
  return cholesky_inverse_stub(DeviceType::CUDA, result, infos, upper);
}
62
// Lazy stub for lu_factor_stub; loads the linalg library and forwards the call.
void lazy_lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
  loadLazyTorchLinalgLibrary();
  lu_factor_stub(DeviceType::CUDA, input, pivots, infos, compute_pivots);
}
67
// Lazy stub for triangular_solve_stub; loads the linalg library and forwards the call.
void lazy_triangular_solve_kernel(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
  loadLazyTorchLinalgLibrary();
  triangular_solve_stub(DeviceType::CUDA, A, B, left, upper, transpose, unitriangular);
}
72
// Lazy stub for orgqr_stub; loads the linalg library, forwards the call, and
// returns the real kernel's result.
Tensor& lazy_orgqr_kernel(Tensor& result, const Tensor& tau) {
  loadLazyTorchLinalgLibrary();
  return orgqr_stub(DeviceType::CUDA, result, tau);
}
77
// Lazy stub for ormqr_stub; loads the linalg library and forwards the call.
void lazy_ormqr_kernel(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
  loadLazyTorchLinalgLibrary();
  ormqr_stub(DeviceType::CUDA, input, tau, other, left, transpose);
}
82
// Lazy stub for geqrf_stub; loads the linalg library and forwards the call.
void lazy_geqrf_kernel(const Tensor& input, const Tensor& tau) {
  loadLazyTorchLinalgLibrary();
  geqrf_stub(DeviceType::CUDA, input, tau);
}
87
// Lazy stub for linalg_eigh_stub; loads the linalg library and forwards the call.
void lazy_linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
  loadLazyTorchLinalgLibrary();
  linalg_eigh_stub(DeviceType::CUDA, eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
}
92
lazy_linalg_eig_kernel(Tensor & eigenvalues,Tensor & eigenvectors,Tensor & infos,const Tensor & input,bool compute_eigenvectors)93 void lazy_linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors) {
94 getTorchLinalgLibrary();
95 linalg_eig_stub(DeviceType::CUDA, eigenvalues, eigenvectors, infos, input, compute_eigenvectors);
96 }
97
// Lazy stub for svd_stub; loads the linalg library and forwards the call.
// Routed through loadLazyTorchLinalgLibrary() so that, if loading the library
// fails to override this registration, the guard reports an error rather than
// letting the wrapper recurse infinitely — consistent with the other wrappers.
void lazy_svd_kernel(const Tensor& A,
                     const bool full_matrices,
                     const bool compute_uv,
                     const std::optional<c10::string_view>& driver,
                     const Tensor& U,
                     const Tensor& S,
                     const Tensor& Vh,
                     const Tensor& info) {
  loadLazyTorchLinalgLibrary();
  svd_stub(DeviceType::CUDA, A, full_matrices, compute_uv, driver, U, S, Vh, info);
}
109
lazy_lu_solve(const Tensor & LU,const Tensor & pivots,const Tensor & B,TransposeType trans)110 void lazy_lu_solve(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
111 getTorchLinalgLibrary();
112 lu_solve_stub(DeviceType::CUDA, LU, pivots, B, trans);
113 }
114
lazy_lstsq_kernel(const Tensor & a,Tensor & b,Tensor & rank,Tensor & singular_values,Tensor & infos,double rcond,std::string driver_name)115 void lazy_lstsq_kernel(const Tensor& a, Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, double rcond, std::string driver_name) {
116 getTorchLinalgLibrary();
117 lstsq_stub(DeviceType::CUDA, a, b, rank, singular_values, infos, rcond, driver_name);
118 }
119
// Lazy stub for ldl_factor_stub; loads the linalg library and forwards the call.
void lazy_ldl_factor(
    const Tensor& LD,
    const Tensor& pivots,
    const Tensor& info,
    bool upper,
    bool hermitian) {
  loadLazyTorchLinalgLibrary();
  ldl_factor_stub(DeviceType::CUDA, LD, pivots, info, upper, hermitian);
}
129
// Lazy stub for ldl_solve_stub; loads the linalg library and forwards the call.
void lazy_ldl_solve(
    const Tensor& LD,
    const Tensor& pivots,
    const Tensor& B,
    bool upper,
    bool hermitian) {
  loadLazyTorchLinalgLibrary();
  ldl_solve_stub(DeviceType::CUDA, LD, pivots, B, upper, hermitian);
}
139
// Register the lazy wrappers as the CUDA implementations of each stub.
// Loading libtorch_cuda_linalg.so should override these registrations with
// the real kernels, so each wrapper runs at most once.
// NOTE(review): trailing semicolons are inconsistent (cholesky_stub and
// svd_stub lack one) — harmless only if the macro supplies its own terminator;
// worth normalizing. Verify against the REGISTER_CUDA_DISPATCH definition.
REGISTER_CUDA_DISPATCH(cholesky_stub, &lazy_cholesky_kernel)
REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &lazy_cholesky_inverse_kernel);
REGISTER_CUDA_DISPATCH(lu_factor_stub, &lazy_lu_factor);
REGISTER_CUDA_DISPATCH(ldl_factor_stub, &lazy_ldl_factor);
REGISTER_CUDA_DISPATCH(ldl_solve_stub, &lazy_ldl_solve);
REGISTER_CUDA_DISPATCH(triangular_solve_stub, &lazy_triangular_solve_kernel);
REGISTER_CUDA_DISPATCH(orgqr_stub, &lazy_orgqr_kernel);
REGISTER_CUDA_DISPATCH(ormqr_stub, &lazy_ormqr_kernel);
REGISTER_CUDA_DISPATCH(geqrf_stub, &lazy_geqrf_kernel);
REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &lazy_linalg_eigh_kernel);
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &lazy_linalg_eig_kernel);
REGISTER_CUDA_DISPATCH(svd_stub, &lazy_svd_kernel)
REGISTER_CUDA_DISPATCH(lu_solve_stub, &lazy_lu_solve);
REGISTER_CUDA_DISPATCH(lstsq_stub, &lazy_lstsq_kernel);
154 } // anonymous namespace
155
// Old style dispatches
// The torch_cuda_linalg dynamic library should have a global constructor
// that calls registerLinalgDispatch, so in order to lazily bind an
// old style dispatch all one has to do is load the library and call disp.func_name.
// Protect from infinite recursion by initializing the dispatch table to self and
// checking that the values differ after the linalg library has been loaded.
162
namespace cuda {
namespace detail {
// Called from libtorch_cuda_linalg.so once it is loaded; installs the real
// implementations into the old-style dispatch table, replacing the
// self-referential sentinel value `disp` was initialized with.
void registerLinalgDispatch(const LinalgDispatch& disp_) {
  disp = disp_;
}
}} //namespace cuda::detail
169
// Old-style lazy dispatch for cholesky_solve: loads the linalg library, then
// verifies that the dispatch-table entry was actually replaced (it is
// initialized to this very function) before calling through it — turning a
// failed registration into an error instead of infinite recursion.
Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upper) {
  getTorchLinalgLibrary();
  TORCH_CHECK(disp.cholesky_solve_helper != _cholesky_solve_helper_cuda, "Can't find _cholesky_solve_helper_cuda");
  return disp.cholesky_solve_helper(self, A, upper);
}
175
176 #endif /*defined(BUILD_LAZY_CUDA_LINALG)*/
177
178 } // namespace at::native
179