1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // ROCM-specific support for BLAS functionality -- this wraps the rocBLAS 17 // library capabilities, and is only included into ROCM implementation code -- 18 // it will not introduce rocm headers into other code. 19 20 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_ 21 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_ 22 23 #include "absl/base/thread_annotations.h" 24 #include "absl/synchronization/mutex.h" 25 #include "absl/types/span.h" 26 #include "rocm/include/rocblas.h" 27 #include "tensorflow/stream_executor/blas.h" 28 #include "tensorflow/stream_executor/platform/port.h" 29 #include "tensorflow/stream_executor/plugin_registry.h" 30 #include "tensorflow/stream_executor/temporary_device_memory.h" 31 32 namespace stream_executor { 33 34 class Stream; 35 36 namespace gpu { 37 38 // Type conversion helper that helps to map non-rocblas types to rocblas types 39 // Right now, it only converts the Eigen::half type to rocblas_half type 40 template <typename T> 41 struct RocBlasTypeConversionHelper { 42 using mapped_type = T; 43 }; 44 45 template <> 46 struct RocBlasTypeConversionHelper<Eigen::half> { 47 using mapped_type = rocblas_half; 48 }; 49 50 template <> 51 struct RocBlasTypeConversionHelper<std::complex<float>> { 52 using mapped_type = rocblas_float_complex; 53 }; 54 55 template <> 56 struct RocBlasTypeConversionHelper<std::complex<double>> { 57 using mapped_type = rocblas_double_complex; 58 }; 59 60 // Opaque and unique identifier for the rocBLAS plugin. 61 extern const PluginId kRocBlasPlugin; 62 63 class GpuExecutor; 64 65 // BLAS plugin for ROCM platform via rocBLAS library. 66 // 67 // This satisfies the platform-agnostic BlasSupport interface. 68 // 69 // Note that the rocBLAS handle that this encapsulates is implicitly tied to the 70 // context (and, as a result, the device) that the parent GpuExecutor is tied 71 // to. This simply happens as an artifact of creating the rocBLAS handle when a 72 // ROCM context is active. 73 // 74 // Thread-safe post-initialization. 75 class ROCMBlas : public blas::BlasSupport { 76 public: 77 explicit ROCMBlas(GpuExecutor *parent); 78 79 // Allocates a rocBLAS handle. 80 bool Init(); 81 82 // Releases the rocBLAS handle, if present. 83 ~ROCMBlas() override; 84 85 TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES 86 87 private: 88 // Tells rocBLAS to enqueue the BLAS operation onto a particular Stream. 89 // 90 // rocBLAS is stateful, and only be associated with one stream (in order to 91 // enqueue dispatch) at a given time. As a result, this generally must be 92 // invoked before calling into rocBLAS. 93 bool SetStream(Stream *stream) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); 94 95 // A helper function that calls the real rocBLAS function together with error 96 // handling. 97 // 98 // rocblas_func: rocBLAS function pointer. 99 // rocblas_name: rocBLAS function name. 100 // stream: Stream to enqueue the BLAS operation onto. 101 // pointer_mode_host: Indicate if the pointer to a scalar value is from host 102 // (true) or device (false). 103 // err_on_failure: Whether to print an error if the rocBLAS function 104 // fails. args: Arguments of rocBLAS function. 105 template <typename FuncT, typename... Args> 106 bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, 107 bool pointer_mode_host, bool err_on_failure, 108 Args... args); 109 110 // Convenience functions that call DoBlasInternalImpl with different values 111 // for err_on_failure. 112 template <typename FuncT, typename... Args> 113 bool DoBlasInternal(FuncT rocblas_func, Stream *stream, 114 bool pointer_mode_host, Args... args) { 115 return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, 116 /*err_on_failure=*/true, args...); 117 } 118 119 // Same as above, but returns Status. 120 template <typename... Args> 121 port::Status DoBlasInternalStatus(Args... args) { 122 if (!DoBlasInternal(args...)) { 123 return port::InternalError("Failed calling rocBLAS"); 124 } 125 return port::Status::OK(); 126 } 127 128 template <typename FuncT, typename... Args> 129 bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream, 130 bool pointer_mode_host, Args... args) { 131 return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, 132 /*err_on_failure=*/false, args...); 133 } 134 135 // A helper allocation function to convert raw pointers memory layout to 136 // strided flavor 137 template <typename T> 138 port::Status AllocateStridedBuffer( 139 const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *> 140 &raw_ptrs, 141 int batch_count, uint64_t batch_stride, 142 ScratchAllocator *scratch_allocator, Stream *stream, 143 std::unique_ptr<TemporaryDeviceMemory< 144 typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory, 145 DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type> 146 *device_memory, 147 bool copy_data, bool &reallocated); 148 149 // A helper function to implement DoBlasGemmBatched interfaces for generic 150 // types. 151 // 152 // Note: This function is implemented using gemm_strided_batched interface, 153 // NOT gemm_batched interface, because rocblas do not support it. As a 154 // result, if the passed in batch matrix are not allocated in strided batched 155 // format, it might end up in non-trivial amount of memory allocation and 156 // copy. To avoid this, always prioritize to use DoBlasGemmStridedBatched 157 // interface. 158 // 159 // In most use cases, batch matrix do get allocated in strided manner, making 160 // calling this interface equivalent with DoBlasGemmStridedBatched. The only 161 // use case we see so far that violates this observation is when batch 162 // matrix is created by broadcasting from a smaller matrix. When it happens, 163 // It will take advantage of the AllocateStridedBuffer subroutine to 164 // reallocate the memory layout to be strided batched. 165 template <typename T, typename FuncT> 166 port::Status DoBlasGemmBatchedInternal( 167 FuncT rocblas_func, Stream *stream, blas::Transpose transa, 168 blas::Transpose transb, uint64_t m, uint64 n, uint64 k, T alpha, 169 const absl::Span<DeviceMemory<T> *const> &a_ptrs_to_wrappers, int lda, 170 const absl::Span<DeviceMemory<T> *const> &b_ptrs_to_wrappers, int ldb, 171 T beta, const absl::Span<DeviceMemory<T> *const> &c_ptrs_to_wrappers, 172 int ldc, int batch_count, ScratchAllocator *scratch_allocator); 173 174 // Helper function for implementing DoBlasGemmWithProfiling. 175 template <typename T, typename ParamType> 176 bool DoBlasGemmWithProfilingImpl(Stream *stream, blas::Transpose transa, 177 blas::Transpose transb, uint64_t m, 178 uint64_t n, uint64 k, const ParamType &alpha, 179 const DeviceMemory<T> &a, int lda, 180 const DeviceMemory<T> &b, int ldb, 181 const ParamType &beta, DeviceMemory<T> *c, 182 int ldc, 183 blas::ProfileResult *output_profile_result); 184 185 // Helper function for implementing DoBlasGemvWithProfiling. 186 template <typename T> 187 bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans, 188 uint64_t m, uint64 n, const T &alpha, 189 const DeviceMemory<T> &a, int lda, 190 const DeviceMemory<T> &x, int incx, 191 const T &beta, DeviceMemory<T> *y, int incy, 192 blas::ProfileResult *output_profile_result); 193 194 // mutex that guards the rocBLAS handle for this device. 195 absl::Mutex mu_; 196 197 // GpuExecutor which instantiated this ROCMBlas. 198 // Immutable post-initialization. 199 GpuExecutor *parent_; 200 201 // rocBLAS library handle on the device. 202 rocblas_handle blas_ ABSL_GUARDED_BY(mu_); 203 204 SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas); 205 }; 206 207 } // namespace gpu 208 } // namespace stream_executor 209 210 #endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_ 211