xref: /aosp_15_r20/external/tensorflow/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file wraps ROCm (HIP) driver calls with a DSO loader so that we don't
// need to link explicitly against the HIP runtime library (PLATFORM_GOOGLE
// builds link it statically instead). All TF ROCm driver usage should route
// through this wrapper.
19 
20 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_
21 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_
22 
23 #define __HIP_DISABLE_CPP_FUNCTIONS__
24 
#include <type_traits>

#include "rocm/include/hip/hip_runtime.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/port.h"
29 
30 namespace tensorflow {
31 namespace wrap {
#ifdef PLATFORM_GOOGLE
// Static-link build: the HIP runtime is linked into the binary, so each
// wrapper is a thin forwarding template that calls the global HIP symbol of the
// same name directly.
#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName)                          \
  template <typename... Args>                                            \
  auto hipSymbolName(Args... args)->decltype(::hipSymbolName(args...)) { \
    return ::hipSymbolName(args...);                                     \
  }

#else
// Stringizes `x` after macro-expanding it. The two-level indirection is needed
// because `#` applied directly would not expand a macro argument first.
#define TO_STR_(x) #x
#define TO_STR(x) TO_STR_(x)

// This macro wraps a global HIP identifier, given by hipSymbolName, in a
// forwarding function template that loads the symbol out of the HIP DSO handle
// on first use. This dynamic loading technique is used to avoid DSO
// dependencies on vendor libraries which may or may not be available in the
// deployed binary environment.
//
// Notes:
//  * The resolved function pointer is cached in a function-local static, whose
//    initialization C++11 guarantees is thread-safe. There is one such static
//    per template instantiation, i.e. per distinct set of argument types.
//  * Failure to obtain the HIP DSO handle (ValueOrDie) or to find the symbol
//    (CHECK) aborts the process. No error is returned to the caller.
//  * `f` is initialized so the out-parameter never holds an indeterminate
//    value, even if the lookup reports success without writing it.
#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName)                             \
  template <typename... Args>                                               \
  auto hipSymbolName(Args... args)->decltype(::hipSymbolName(args...)) {    \
    using FuncPtrT = std::add_pointer_t<decltype(::hipSymbolName)>;         \
    static FuncPtrT loaded = []() -> FuncPtrT {                             \
      static const char *kName = TO_STR(hipSymbolName);                     \
      void *f = nullptr;                                                    \
      auto s = stream_executor::port::Env::Default()->GetSymbolFromLibrary( \
          stream_executor::internal::CachedDsoLoader::GetHipDsoHandle()     \
              .ValueOrDie(),                                                \
          kName, &f);                                                       \
      CHECK(s.ok()) << "could not find " << kName                           \
                    << " in HIP DSO; dlerror: " << s.error_message();       \
      return reinterpret_cast<FuncPtrT>(f);                                 \
    }();                                                                    \
    return loaded(args...);                                                 \
  }
#endif
67 
// clang-format off
// X-macro listing every HIP API routine exposed through tensorflow::wrap.
// Expand it as HIP_ROUTINE_EACH(SOME_MACRO) to apply SOME_MACRO(name) to each
// routine name below, in order.
// The parameter is deliberately not spelled with a double underscore, because
// such identifiers are reserved to the implementation.
// IMPORTANT: if you add a new HIP API to this list, please notify
// the rocm-profiler developers to track the API traces.
#define HIP_ROUTINE_EACH(ROUTINE_MACRO)             \
  ROUTINE_MACRO(hipDeviceCanAccessPeer)             \
  ROUTINE_MACRO(hipDeviceEnablePeerAccess)          \
  ROUTINE_MACRO(hipDeviceGet)                       \
  ROUTINE_MACRO(hipDeviceGetAttribute)              \
  ROUTINE_MACRO(hipDeviceGetName)                   \
  ROUTINE_MACRO(hipDeviceGetPCIBusId)               \
  ROUTINE_MACRO(hipDeviceGetSharedMemConfig)        \
  ROUTINE_MACRO(hipDeviceSetSharedMemConfig)        \
  ROUTINE_MACRO(hipDeviceSynchronize)               \
  ROUTINE_MACRO(hipDeviceTotalMem)                  \
  ROUTINE_MACRO(hipDriverGetVersion)                \
  ROUTINE_MACRO(hipEventCreateWithFlags)            \
  ROUTINE_MACRO(hipEventDestroy)                    \
  ROUTINE_MACRO(hipEventElapsedTime)                \
  ROUTINE_MACRO(hipEventQuery)                      \
  ROUTINE_MACRO(hipEventRecord)                     \
  ROUTINE_MACRO(hipEventSynchronize)                \
  ROUTINE_MACRO(hipFree)                            \
  ROUTINE_MACRO(hipFuncSetCacheConfig)              \
  ROUTINE_MACRO(hipGetDevice)                       \
  ROUTINE_MACRO(hipGetDeviceCount)                  \
  ROUTINE_MACRO(hipGetDeviceProperties)             \
  ROUTINE_MACRO(hipHostFree)                        \
  ROUTINE_MACRO(hipHostMalloc)                      \
  ROUTINE_MACRO(hipHostRegister)                    \
  ROUTINE_MACRO(hipHostUnregister)                  \
  ROUTINE_MACRO(hipInit)                            \
  ROUTINE_MACRO(hipMalloc)                          \
  ROUTINE_MACRO(hipMemGetAddressRange)              \
  ROUTINE_MACRO(hipMemGetInfo)                      \
  ROUTINE_MACRO(hipMemcpyDtoD)                      \
  ROUTINE_MACRO(hipMemcpyDtoDAsync)                 \
  ROUTINE_MACRO(hipMemcpyDtoH)                      \
  ROUTINE_MACRO(hipMemcpyDtoHAsync)                 \
  ROUTINE_MACRO(hipMemcpyHtoD)                      \
  ROUTINE_MACRO(hipMemcpyHtoDAsync)                 \
  ROUTINE_MACRO(hipMemset)                          \
  ROUTINE_MACRO(hipMemsetD8)                        \
  ROUTINE_MACRO(hipMemsetD16)                       \
  ROUTINE_MACRO(hipMemsetD32)                       \
  ROUTINE_MACRO(hipMemsetAsync)                     \
  ROUTINE_MACRO(hipMemsetD8Async)                   \
  ROUTINE_MACRO(hipMemsetD16Async)                  \
  ROUTINE_MACRO(hipMemsetD32Async)                  \
  ROUTINE_MACRO(hipModuleGetFunction)               \
  ROUTINE_MACRO(hipModuleGetGlobal)                 \
  ROUTINE_MACRO(hipModuleLaunchKernel)              \
  ROUTINE_MACRO(hipModuleLoadData)                  \
  ROUTINE_MACRO(hipModuleUnload)                    \
  ROUTINE_MACRO(hipPointerGetAttributes)            \
  ROUTINE_MACRO(hipSetDevice)                       \
  ROUTINE_MACRO(hipStreamAddCallback)               \
  ROUTINE_MACRO(hipStreamCreateWithFlags)           \
  ROUTINE_MACRO(hipStreamCreateWithPriority)        \
  ROUTINE_MACRO(hipStreamDestroy)                   \
  ROUTINE_MACRO(hipStreamQuery)                     \
  ROUTINE_MACRO(hipStreamSynchronize)               \
  ROUTINE_MACRO(hipStreamWaitEvent)
// The definition above ends on its last entry (no trailing backslash), so this
// marker is a standalone comment rather than part of the macro.
// clang-format on
131 
132 HIP_ROUTINE_EACH(STREAM_EXECUTOR_HIP_WRAP)
133 #undef HIP_ROUTINE_EACH
134 #undef STREAM_EXECUTOR_HIP_WRAP
135 #undef TO_STR
136 #undef TO_STR_
137 }  // namespace wrap
138 }  // namespace tensorflow
139 
140 #endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_
141