1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file wraps rocblas API calls with dso loader so that we don't need to 17 // have explicit linking to librocblas. All TF hipsarse API usage should route 18 // through this wrapper. 19 20 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_ 21 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_ 22 23 #include "rocm/include/rocblas.h" 24 #include "tensorflow/stream_executor/gpu/gpu_activation.h" 25 #include "tensorflow/stream_executor/lib/env.h" 26 #include "tensorflow/stream_executor/platform/dso_loader.h" 27 #include "tensorflow/stream_executor/platform/port.h" 28 29 namespace tensorflow { 30 namespace wrap { 31 32 using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle; 33 34 #ifdef PLATFORM_GOOGLE 35 #define ROCBLAS_API_WRAPPER(__name) \ 36 struct WrapperShim__##__name { \ 37 static const char* kName; \ 38 template <typename... Args> \ 39 rocblas_status operator()(Args... args) { \ 40 return ::__name(args...); \ 41 } \ 42 } __name; \ 43 const char* WrapperShim__##__name::kName = #__name; 44 45 #else 46 47 #define ROCBLAS_API_WRAPPER(__name) \ 48 struct DynLoadShim__##__name { \ 49 static const char* kName; \ 50 using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \ 51 static void* GetDsoHandle() { \ 52 auto s = GetRocblasDsoHandle(); \ 53 return s.ValueOrDie(); \ 54 } \ 55 static FuncPtrT LoadOrDie() { \ 56 void* f; \ 57 auto s = \ 58 Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), kName, &f); \ 59 CHECK(s.ok()) << "could not find " << kName \ 60 << " in rocblas DSO; dlerror: " << s.error_message(); \ 61 return reinterpret_cast<FuncPtrT>(f); \ 62 } \ 63 static FuncPtrT DynLoad() { \ 64 static FuncPtrT f = LoadOrDie(); \ 65 return f; \ 66 } \ 67 template <typename... Args> \ 68 rocblas_status operator()(Args... args) { \ 69 return DynLoad()(args...); \ 70 } \ 71 } __name; \ 72 const char* DynLoadShim__##__name::kName = #__name; 73 74 #endif 75 76 // clang-format off 77 #define FOREACH_ROCBLAS_API(__macro) \ 78 __macro(rocblas_snrm2) \ 79 __macro(rocblas_dnrm2) \ 80 __macro(rocblas_scnrm2) \ 81 __macro(rocblas_dznrm2) \ 82 __macro(rocblas_sdot) \ 83 __macro(rocblas_ddot) \ 84 __macro(rocblas_cdotu) \ 85 __macro(rocblas_cdotc) \ 86 __macro(rocblas_zdotu) \ 87 __macro(rocblas_zdotc) \ 88 __macro(rocblas_sscal) \ 89 __macro(rocblas_dscal) \ 90 __macro(rocblas_cscal) \ 91 __macro(rocblas_csscal) \ 92 __macro(rocblas_zscal) \ 93 __macro(rocblas_zdscal) \ 94 __macro(rocblas_saxpy) \ 95 __macro(rocblas_daxpy) \ 96 __macro(rocblas_caxpy) \ 97 __macro(rocblas_zaxpy) \ 98 __macro(rocblas_scopy) \ 99 __macro(rocblas_dcopy) \ 100 __macro(rocblas_ccopy) \ 101 __macro(rocblas_zcopy) \ 102 __macro(rocblas_sswap) \ 103 __macro(rocblas_dswap) \ 104 __macro(rocblas_cswap) \ 105 __macro(rocblas_zswap) \ 106 __macro(rocblas_isamax) \ 107 __macro(rocblas_idamax) \ 108 __macro(rocblas_icamax) \ 109 __macro(rocblas_izamax) \ 110 __macro(rocblas_isamin) \ 111 __macro(rocblas_idamin) \ 112 __macro(rocblas_icamin) \ 113 __macro(rocblas_izamin) \ 114 __macro(rocblas_sasum) \ 115 __macro(rocblas_dasum) \ 116 __macro(rocblas_scasum) \ 117 __macro(rocblas_dzasum) \ 118 __macro(rocblas_srot) \ 119 __macro(rocblas_drot) \ 120 __macro(rocblas_crot) \ 121 __macro(rocblas_csrot) \ 122 __macro(rocblas_zrot) \ 123 __macro(rocblas_zdrot) \ 124 __macro(rocblas_srotg) \ 125 __macro(rocblas_drotg) \ 126 __macro(rocblas_crotg) \ 127 __macro(rocblas_zrotg) \ 128 __macro(rocblas_srotm) \ 129 __macro(rocblas_drotm) \ 130 __macro(rocblas_srotmg) \ 131 __macro(rocblas_drotmg) \ 132 __macro(rocblas_sgemv) \ 133 __macro(rocblas_dgemv) \ 134 __macro(rocblas_cgemv) \ 135 __macro(rocblas_zgemv) \ 136 __macro(rocblas_sgbmv) \ 137 __macro(rocblas_dgbmv) \ 138 __macro(rocblas_cgbmv) \ 139 __macro(rocblas_zgbmv) \ 140 __macro(rocblas_strmv) \ 141 __macro(rocblas_dtrmv) \ 142 __macro(rocblas_ctrmv) \ 143 __macro(rocblas_ztrmv) \ 144 __macro(rocblas_stbmv) \ 145 __macro(rocblas_dtbmv) \ 146 __macro(rocblas_ctbmv) \ 147 __macro(rocblas_ztbmv) \ 148 __macro(rocblas_stpmv) \ 149 __macro(rocblas_dtpmv) \ 150 __macro(rocblas_ctpmv) \ 151 __macro(rocblas_ztpmv) \ 152 __macro(rocblas_strsv) \ 153 __macro(rocblas_dtrsv) \ 154 __macro(rocblas_ctrsv) \ 155 __macro(rocblas_ztrsv) \ 156 __macro(rocblas_stpsv) \ 157 __macro(rocblas_dtpsv) \ 158 __macro(rocblas_ctpsv) \ 159 __macro(rocblas_ztpsv) \ 160 __macro(rocblas_stbsv) \ 161 __macro(rocblas_dtbsv) \ 162 __macro(rocblas_ctbsv) \ 163 __macro(rocblas_ztbsv) \ 164 __macro(rocblas_ssymv) \ 165 __macro(rocblas_dsymv) \ 166 __macro(rocblas_csymv) \ 167 __macro(rocblas_zsymv) \ 168 __macro(rocblas_chemv) \ 169 __macro(rocblas_zhemv) \ 170 __macro(rocblas_ssbmv) \ 171 __macro(rocblas_dsbmv) \ 172 __macro(rocblas_chbmv) \ 173 __macro(rocblas_zhbmv) \ 174 __macro(rocblas_sspmv) \ 175 __macro(rocblas_dspmv) \ 176 __macro(rocblas_chpmv) \ 177 __macro(rocblas_zhpmv) \ 178 __macro(rocblas_sger) \ 179 __macro(rocblas_dger) \ 180 __macro(rocblas_cgeru) \ 181 __macro(rocblas_cgerc) \ 182 __macro(rocblas_zgeru) \ 183 __macro(rocblas_zgerc) \ 184 __macro(rocblas_ssyr) \ 185 __macro(rocblas_dsyr) \ 186 __macro(rocblas_csyr) \ 187 __macro(rocblas_zsyr) \ 188 __macro(rocblas_cher) \ 189 __macro(rocblas_zher) \ 190 __macro(rocblas_sspr) \ 191 __macro(rocblas_dspr) \ 192 __macro(rocblas_chpr) \ 193 __macro(rocblas_zhpr) \ 194 __macro(rocblas_ssyr2) \ 195 __macro(rocblas_dsyr2) \ 196 __macro(rocblas_csyr2) \ 197 __macro(rocblas_zsyr2) \ 198 __macro(rocblas_cher2) \ 199 __macro(rocblas_zher2) \ 200 __macro(rocblas_sspr2) \ 201 __macro(rocblas_dspr2) \ 202 __macro(rocblas_chpr2) \ 203 __macro(rocblas_zhpr2) \ 204 __macro(rocblas_sgemm) \ 205 __macro(rocblas_dgemm) \ 206 __macro(rocblas_hgemm) \ 207 __macro(rocblas_cgemm) \ 208 __macro(rocblas_zgemm) \ 209 __macro(rocblas_ssyrk) \ 210 __macro(rocblas_dsyrk) \ 211 __macro(rocblas_csyrk) \ 212 __macro(rocblas_zsyrk) \ 213 __macro(rocblas_cherk) \ 214 __macro(rocblas_zherk) \ 215 __macro(rocblas_ssyr2k) \ 216 __macro(rocblas_dsyr2k) \ 217 __macro(rocblas_csyr2k) \ 218 __macro(rocblas_zsyr2k) \ 219 __macro(rocblas_cher2k) \ 220 __macro(rocblas_zher2k) \ 221 __macro(rocblas_ssyrkx) \ 222 __macro(rocblas_dsyrkx) \ 223 __macro(rocblas_csyrkx) \ 224 __macro(rocblas_zsyrkx) \ 225 __macro(rocblas_cherkx) \ 226 __macro(rocblas_zherkx) \ 227 __macro(rocblas_ssymm) \ 228 __macro(rocblas_dsymm) \ 229 __macro(rocblas_csymm) \ 230 __macro(rocblas_zsymm) \ 231 __macro(rocblas_chemm) \ 232 __macro(rocblas_zhemm) \ 233 __macro(rocblas_strsm) \ 234 __macro(rocblas_dtrsm) \ 235 __macro(rocblas_ctrsm) \ 236 __macro(rocblas_ztrsm) \ 237 __macro(rocblas_strmm) \ 238 __macro(rocblas_dtrmm) \ 239 __macro(rocblas_ctrmm) \ 240 __macro(rocblas_ztrmm) \ 241 __macro(rocblas_sgeam) \ 242 __macro(rocblas_dgeam) \ 243 __macro(rocblas_cgeam) \ 244 __macro(rocblas_zgeam) \ 245 __macro(rocblas_sdgmm) \ 246 __macro(rocblas_ddgmm) \ 247 __macro(rocblas_cdgmm) \ 248 __macro(rocblas_zdgmm) \ 249 __macro(rocblas_sgemm_batched) \ 250 __macro(rocblas_dgemm_batched) \ 251 __macro(rocblas_cgemm_batched) \ 252 __macro(rocblas_zgemm_batched) \ 253 __macro(rocblas_hgemm_strided_batched) \ 254 __macro(rocblas_sgemm_strided_batched) \ 255 __macro(rocblas_dgemm_strided_batched) \ 256 __macro(rocblas_cgemm_strided_batched) \ 257 __macro(rocblas_zgemm_strided_batched) \ 258 __macro(rocblas_gemm_ex) \ 259 __macro(rocblas_gemm_strided_batched_ex) \ 260 __macro(rocblas_strsm_batched) \ 261 __macro(rocblas_dtrsm_batched) \ 262 __macro(rocblas_ctrsm_batched) \ 263 __macro(rocblas_ztrsm_batched) \ 264 __macro(rocblas_create_handle) \ 265 __macro(rocblas_destroy_handle) \ 266 __macro(rocblas_set_stream) 267 268 // clang-format on 269 270 FOREACH_ROCBLAS_API(ROCBLAS_API_WRAPPER) 271 272 } // namespace wrap 273 } // namespace tensorflow 274 275 #endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_ 276