xref: /aosp_15_r20/external/tensorflow/tensorflow/stream_executor/rocm/rocblas_wrapper.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file wraps rocblas API calls with dso loader so that we don't need to
17 // have explicit linking to librocblas. All TF hipsarse API usage should route
18 // through this wrapper.
19 
20 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_
21 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_
22 
23 #include "rocm/include/rocblas.h"
24 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
25 #include "tensorflow/stream_executor/lib/env.h"
26 #include "tensorflow/stream_executor/platform/dso_loader.h"
27 #include "tensorflow/stream_executor/platform/port.h"
28 
29 namespace tensorflow {
30 namespace wrap {
31 
32 using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle;
33 
34 #ifdef PLATFORM_GOOGLE
35 #define ROCBLAS_API_WRAPPER(__name)           \
36   struct WrapperShim__##__name {              \
37     static const char* kName;                 \
38     template <typename... Args>               \
39     rocblas_status operator()(Args... args) { \
40       return ::__name(args...);               \
41     }                                         \
42   } __name;                                   \
43   const char* WrapperShim__##__name::kName = #__name;
44 
45 #else
46 
47 #define ROCBLAS_API_WRAPPER(__name)                                        \
48   struct DynLoadShim__##__name {                                           \
49     static const char* kName;                                              \
50     using FuncPtrT = std::add_pointer<decltype(::__name)>::type;           \
51     static void* GetDsoHandle() {                                          \
52       auto s = GetRocblasDsoHandle();                                      \
53       return s.ValueOrDie();                                               \
54     }                                                                      \
55     static FuncPtrT LoadOrDie() {                                          \
56       void* f;                                                             \
57       auto s =                                                             \
58           Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), kName, &f); \
59       CHECK(s.ok()) << "could not find " << kName                          \
60                     << " in rocblas DSO; dlerror: " << s.error_message();  \
61       return reinterpret_cast<FuncPtrT>(f);                                \
62     }                                                                      \
63     static FuncPtrT DynLoad() {                                            \
64       static FuncPtrT f = LoadOrDie();                                     \
65       return f;                                                            \
66     }                                                                      \
67     template <typename... Args>                                            \
68     rocblas_status operator()(Args... args) {                              \
69       return DynLoad()(args...);                                           \
70     }                                                                      \
71   } __name;                                                                \
72   const char* DynLoadShim__##__name::kName = #__name;
73 
74 #endif
75 
76 // clang-format off
77 #define FOREACH_ROCBLAS_API(__macro)            \
78   __macro(rocblas_snrm2)                        \
79   __macro(rocblas_dnrm2)                        \
80   __macro(rocblas_scnrm2)                       \
81   __macro(rocblas_dznrm2)                       \
82   __macro(rocblas_sdot)                         \
83   __macro(rocblas_ddot)                         \
84   __macro(rocblas_cdotu)                        \
85   __macro(rocblas_cdotc)                        \
86   __macro(rocblas_zdotu)                        \
87   __macro(rocblas_zdotc)                        \
88   __macro(rocblas_sscal)                        \
89   __macro(rocblas_dscal)                        \
90   __macro(rocblas_cscal)                        \
91   __macro(rocblas_csscal)                       \
92   __macro(rocblas_zscal)                        \
93   __macro(rocblas_zdscal)                       \
94   __macro(rocblas_saxpy)                        \
95   __macro(rocblas_daxpy)                        \
96   __macro(rocblas_caxpy)                        \
97   __macro(rocblas_zaxpy)                        \
98   __macro(rocblas_scopy)                        \
99   __macro(rocblas_dcopy)                        \
100   __macro(rocblas_ccopy)                        \
101   __macro(rocblas_zcopy)                        \
102   __macro(rocblas_sswap)                        \
103   __macro(rocblas_dswap)                        \
104   __macro(rocblas_cswap)                        \
105   __macro(rocblas_zswap)                        \
106   __macro(rocblas_isamax)                       \
107   __macro(rocblas_idamax)                       \
108   __macro(rocblas_icamax)                       \
109   __macro(rocblas_izamax)                       \
110   __macro(rocblas_isamin)                       \
111   __macro(rocblas_idamin)                       \
112   __macro(rocblas_icamin)                       \
113   __macro(rocblas_izamin)                       \
114   __macro(rocblas_sasum)                        \
115   __macro(rocblas_dasum)                        \
116   __macro(rocblas_scasum)                       \
117   __macro(rocblas_dzasum)                       \
118   __macro(rocblas_srot)                         \
119   __macro(rocblas_drot)                         \
120   __macro(rocblas_crot)                         \
121   __macro(rocblas_csrot)                        \
122   __macro(rocblas_zrot)                         \
123   __macro(rocblas_zdrot)                        \
124   __macro(rocblas_srotg)                        \
125   __macro(rocblas_drotg)                        \
126   __macro(rocblas_crotg)                        \
127   __macro(rocblas_zrotg)                        \
128   __macro(rocblas_srotm)                        \
129   __macro(rocblas_drotm)                        \
130   __macro(rocblas_srotmg)                       \
131   __macro(rocblas_drotmg)                       \
132   __macro(rocblas_sgemv)                        \
133   __macro(rocblas_dgemv)                        \
134   __macro(rocblas_cgemv)                        \
135   __macro(rocblas_zgemv)                        \
136   __macro(rocblas_sgbmv)                        \
137   __macro(rocblas_dgbmv)                        \
138   __macro(rocblas_cgbmv)                        \
139   __macro(rocblas_zgbmv)                        \
140   __macro(rocblas_strmv)                        \
141   __macro(rocblas_dtrmv)                        \
142   __macro(rocblas_ctrmv)                        \
143   __macro(rocblas_ztrmv)                        \
144   __macro(rocblas_stbmv)                        \
145   __macro(rocblas_dtbmv)                        \
146   __macro(rocblas_ctbmv)                        \
147   __macro(rocblas_ztbmv)                        \
148   __macro(rocblas_stpmv)                        \
149   __macro(rocblas_dtpmv)                        \
150   __macro(rocblas_ctpmv)                        \
151   __macro(rocblas_ztpmv)                        \
152   __macro(rocblas_strsv)                        \
153   __macro(rocblas_dtrsv)                        \
154   __macro(rocblas_ctrsv)                        \
155   __macro(rocblas_ztrsv)                        \
156   __macro(rocblas_stpsv)                        \
157   __macro(rocblas_dtpsv)                        \
158   __macro(rocblas_ctpsv)                        \
159   __macro(rocblas_ztpsv)                        \
160   __macro(rocblas_stbsv)                        \
161   __macro(rocblas_dtbsv)                        \
162   __macro(rocblas_ctbsv)                        \
163   __macro(rocblas_ztbsv)                        \
164   __macro(rocblas_ssymv)                        \
165   __macro(rocblas_dsymv)                        \
166   __macro(rocblas_csymv)                        \
167   __macro(rocblas_zsymv)                        \
168   __macro(rocblas_chemv)                        \
169   __macro(rocblas_zhemv)                        \
170   __macro(rocblas_ssbmv)                        \
171   __macro(rocblas_dsbmv)                        \
172   __macro(rocblas_chbmv)                        \
173   __macro(rocblas_zhbmv)                        \
174   __macro(rocblas_sspmv)                        \
175   __macro(rocblas_dspmv)                        \
176   __macro(rocblas_chpmv)                        \
177   __macro(rocblas_zhpmv)                        \
178   __macro(rocblas_sger)                         \
179   __macro(rocblas_dger)                         \
180   __macro(rocblas_cgeru)                        \
181   __macro(rocblas_cgerc)                        \
182   __macro(rocblas_zgeru)                        \
183   __macro(rocblas_zgerc)                        \
184   __macro(rocblas_ssyr)                         \
185   __macro(rocblas_dsyr)                         \
186   __macro(rocblas_csyr)                         \
187   __macro(rocblas_zsyr)                         \
188   __macro(rocblas_cher)                         \
189   __macro(rocblas_zher)                         \
190   __macro(rocblas_sspr)                         \
191   __macro(rocblas_dspr)                         \
192   __macro(rocblas_chpr)                         \
193   __macro(rocblas_zhpr)                         \
194   __macro(rocblas_ssyr2)                        \
195   __macro(rocblas_dsyr2)                        \
196   __macro(rocblas_csyr2)                        \
197   __macro(rocblas_zsyr2)                        \
198   __macro(rocblas_cher2)                        \
199   __macro(rocblas_zher2)                        \
200   __macro(rocblas_sspr2)                        \
201   __macro(rocblas_dspr2)                        \
202   __macro(rocblas_chpr2)                        \
203   __macro(rocblas_zhpr2)                        \
204   __macro(rocblas_sgemm)                        \
205   __macro(rocblas_dgemm)                        \
206   __macro(rocblas_hgemm)                        \
207   __macro(rocblas_cgemm)                        \
208   __macro(rocblas_zgemm)                        \
209   __macro(rocblas_ssyrk)                        \
210   __macro(rocblas_dsyrk)                        \
211   __macro(rocblas_csyrk)                        \
212   __macro(rocblas_zsyrk)                        \
213   __macro(rocblas_cherk)                        \
214   __macro(rocblas_zherk)                        \
215   __macro(rocblas_ssyr2k)                       \
216   __macro(rocblas_dsyr2k)                       \
217   __macro(rocblas_csyr2k)                       \
218   __macro(rocblas_zsyr2k)                       \
219   __macro(rocblas_cher2k)                       \
220   __macro(rocblas_zher2k)                       \
221   __macro(rocblas_ssyrkx)                       \
222   __macro(rocblas_dsyrkx)                       \
223   __macro(rocblas_csyrkx)                       \
224   __macro(rocblas_zsyrkx)                       \
225   __macro(rocblas_cherkx)                       \
226   __macro(rocblas_zherkx)                       \
227   __macro(rocblas_ssymm)                        \
228   __macro(rocblas_dsymm)                        \
229   __macro(rocblas_csymm)                        \
230   __macro(rocblas_zsymm)                        \
231   __macro(rocblas_chemm)                        \
232   __macro(rocblas_zhemm)                        \
233   __macro(rocblas_strsm)                        \
234   __macro(rocblas_dtrsm)                        \
235   __macro(rocblas_ctrsm)                        \
236   __macro(rocblas_ztrsm)                        \
237   __macro(rocblas_strmm)                        \
238   __macro(rocblas_dtrmm)                        \
239   __macro(rocblas_ctrmm)                        \
240   __macro(rocblas_ztrmm)                        \
241   __macro(rocblas_sgeam)                        \
242   __macro(rocblas_dgeam)                        \
243   __macro(rocblas_cgeam)                        \
244   __macro(rocblas_zgeam)                        \
245   __macro(rocblas_sdgmm)                        \
246   __macro(rocblas_ddgmm)                        \
247   __macro(rocblas_cdgmm)                        \
248   __macro(rocblas_zdgmm)                        \
249   __macro(rocblas_sgemm_batched)                \
250   __macro(rocblas_dgemm_batched)                \
251   __macro(rocblas_cgemm_batched)                \
252   __macro(rocblas_zgemm_batched)                \
253   __macro(rocblas_hgemm_strided_batched)        \
254   __macro(rocblas_sgemm_strided_batched)        \
255   __macro(rocblas_dgemm_strided_batched)        \
256   __macro(rocblas_cgemm_strided_batched)        \
257   __macro(rocblas_zgemm_strided_batched)        \
258   __macro(rocblas_gemm_ex)                      \
259   __macro(rocblas_gemm_strided_batched_ex)      \
260   __macro(rocblas_strsm_batched)                \
261   __macro(rocblas_dtrsm_batched)                \
262   __macro(rocblas_ctrsm_batched)                \
263   __macro(rocblas_ztrsm_batched)                \
264   __macro(rocblas_create_handle)                \
265   __macro(rocblas_destroy_handle)               \
266   __macro(rocblas_set_stream)
267 
268 // clang-format on
269 
270 FOREACH_ROCBLAS_API(ROCBLAS_API_WRAPPER)
271 
272 }  // namespace wrap
273 }  // namespace tensorflow
274 
275 #endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_
276