1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file wraps rocm driver calls with dso loader so that we don't need to 17 // have explicit linking to librocm. All TF rocm driver usage should route 18 // through this wrapper. 19 20 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_ 21 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_ 22 23 #define __HIP_DISABLE_CPP_FUNCTIONS__ 24 25 #include "rocm/include/hip/hip_runtime.h" 26 #include "tensorflow/stream_executor/lib/env.h" 27 #include "tensorflow/stream_executor/platform/dso_loader.h" 28 #include "tensorflow/stream_executor/platform/port.h" 29 30 namespace tensorflow { 31 namespace wrap { 32 #ifdef PLATFORM_GOOGLE 33 // Use static linked library 34 #define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \ 35 template <typename... Args> \ 36 auto hipSymbolName(Args... args)->decltype(::hipSymbolName(args...)) { \ 37 return ::hipSymbolName(args...); \ 38 } 39 40 // This macro wraps a global identifier, given by hipSymbolName, in a callable 41 // structure that loads the DLL symbol out of the DSO handle in a thread-safe 42 // manner on first use. This dynamic loading technique is used to avoid DSO 43 // dependencies on vendor libraries which may or may not be available in the 44 // deployed binary environment. 45 #else 46 #define TO_STR_(x) #x 47 #define TO_STR(x) TO_STR_(x) 48 49 #define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \ 50 template <typename... Args> \ 51 auto hipSymbolName(Args... args)->decltype(::hipSymbolName(args...)) { \ 52 using FuncPtrT = std::add_pointer<decltype(::hipSymbolName)>::type; \ 53 static FuncPtrT loaded = []() -> FuncPtrT { \ 54 static const char *kName = TO_STR(hipSymbolName); \ 55 void *f; \ 56 auto s = stream_executor::port::Env::Default()->GetSymbolFromLibrary( \ 57 stream_executor::internal::CachedDsoLoader::GetHipDsoHandle() \ 58 .ValueOrDie(), \ 59 kName, &f); \ 60 CHECK(s.ok()) << "could not find " << kName \ 61 << " in HIP DSO; dlerror: " << s.error_message(); \ 62 return reinterpret_cast<FuncPtrT>(f); \ 63 }(); \ 64 return loaded(args...); \ 65 } 66 #endif 67 68 // clang-format off 69 // IMPORTANT: if you add a new HIP API to this list, please notify 70 // the rocm-profiler developers to track the API traces. 71 #define HIP_ROUTINE_EACH(__macro) \ 72 __macro(hipDeviceCanAccessPeer) \ 73 __macro(hipDeviceEnablePeerAccess) \ 74 __macro(hipDeviceGet) \ 75 __macro(hipDeviceGetAttribute) \ 76 __macro(hipDeviceGetName) \ 77 __macro(hipDeviceGetPCIBusId) \ 78 __macro(hipDeviceGetSharedMemConfig) \ 79 __macro(hipDeviceSetSharedMemConfig) \ 80 __macro(hipDeviceSynchronize) \ 81 __macro(hipDeviceTotalMem) \ 82 __macro(hipDriverGetVersion) \ 83 __macro(hipEventCreateWithFlags) \ 84 __macro(hipEventDestroy) \ 85 __macro(hipEventElapsedTime) \ 86 __macro(hipEventQuery) \ 87 __macro(hipEventRecord) \ 88 __macro(hipEventSynchronize) \ 89 __macro(hipFree) \ 90 __macro(hipFuncSetCacheConfig) \ 91 __macro(hipGetDevice) \ 92 __macro(hipGetDeviceCount) \ 93 __macro(hipGetDeviceProperties) \ 94 __macro(hipHostFree) \ 95 __macro(hipHostMalloc) \ 96 __macro(hipHostRegister) \ 97 __macro(hipHostUnregister) \ 98 __macro(hipInit) \ 99 __macro(hipMalloc) \ 100 __macro(hipMemGetAddressRange) \ 101 __macro(hipMemGetInfo) \ 102 __macro(hipMemcpyDtoD) \ 103 __macro(hipMemcpyDtoDAsync) \ 104 __macro(hipMemcpyDtoH) \ 105 __macro(hipMemcpyDtoHAsync) \ 106 __macro(hipMemcpyHtoD) \ 107 __macro(hipMemcpyHtoDAsync) \ 108 __macro(hipMemset) \ 109 __macro(hipMemsetD8) \ 110 __macro(hipMemsetD16) \ 111 __macro(hipMemsetD32) \ 112 __macro(hipMemsetAsync) \ 113 __macro(hipMemsetD8Async) \ 114 __macro(hipMemsetD16Async) \ 115 __macro(hipMemsetD32Async) \ 116 __macro(hipModuleGetFunction) \ 117 __macro(hipModuleGetGlobal) \ 118 __macro(hipModuleLaunchKernel) \ 119 __macro(hipModuleLoadData) \ 120 __macro(hipModuleUnload) \ 121 __macro(hipPointerGetAttributes) \ 122 __macro(hipSetDevice) \ 123 __macro(hipStreamAddCallback) \ 124 __macro(hipStreamCreateWithFlags) \ 125 __macro(hipStreamCreateWithPriority) \ 126 __macro(hipStreamDestroy) \ 127 __macro(hipStreamQuery) \ 128 __macro(hipStreamSynchronize) \ 129 __macro(hipStreamWaitEvent) \ 130 // clang-format on 131 132 HIP_ROUTINE_EACH(STREAM_EXECUTOR_HIP_WRAP) 133 #undef HIP_ROUTINE_EACH 134 #undef STREAM_EXECUTOR_HIP_WRAP 135 #undef TO_STR 136 #undef TO_STR_ 137 } // namespace wrap 138 } // namespace tensorflow 139 140 #endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_ 141