xref: /aosp_15_r20/external/tensorflow/tensorflow/stream_executor/rocm/rocm_blas.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // ROCM-specific support for BLAS functionality -- this wraps the rocBLAS
17 // library capabilities, and is only included into ROCM implementation code --
18 // it will not introduce rocm headers into other code.
19 
20 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
21 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
22 
23 #include "absl/base/thread_annotations.h"
24 #include "absl/synchronization/mutex.h"
25 #include "absl/types/span.h"
26 #include "rocm/include/rocblas.h"
27 #include "tensorflow/stream_executor/blas.h"
28 #include "tensorflow/stream_executor/platform/port.h"
29 #include "tensorflow/stream_executor/plugin_registry.h"
30 #include "tensorflow/stream_executor/temporary_device_memory.h"
31 
32 namespace stream_executor {
33 
34 class Stream;
35 
36 namespace gpu {
37 
// Trait that maps an element type used by the BLAS interface onto the type
// that rocBLAS expects for it, exposed as the nested alias `mapped_type`.
//
// This primary template is the identity mapping: any type without an explicit
// specialization is handed to rocBLAS unchanged. Explicit specializations in
// this file cover Eigen::half, std::complex<float> and std::complex<double>.
template <typename NativeT>
struct RocBlasTypeConversionHelper {
  using mapped_type = NativeT;
};
44 
45 template <>
46 struct RocBlasTypeConversionHelper<Eigen::half> {
47   using mapped_type = rocblas_half;
48 };
49 
50 template <>
51 struct RocBlasTypeConversionHelper<std::complex<float>> {
52   using mapped_type = rocblas_float_complex;
53 };
54 
55 template <>
56 struct RocBlasTypeConversionHelper<std::complex<double>> {
57   using mapped_type = rocblas_double_complex;
58 };
59 
60 // Opaque and unique identifier for the rocBLAS plugin.
61 extern const PluginId kRocBlasPlugin;
62 
63 class GpuExecutor;
64 
65 // BLAS plugin for ROCM platform via rocBLAS library.
66 //
67 // This satisfies the platform-agnostic BlasSupport interface.
68 //
69 // Note that the rocBLAS handle that this encapsulates is implicitly tied to the
70 // context (and, as a result, the device) that the parent GpuExecutor is tied
71 // to. This simply happens as an artifact of creating the rocBLAS handle when a
72 // ROCM context is active.
73 //
74 // Thread-safe post-initialization.
75 class ROCMBlas : public blas::BlasSupport {
76  public:
77   explicit ROCMBlas(GpuExecutor *parent);
78 
79   // Allocates a rocBLAS handle.
80   bool Init();
81 
82   // Releases the rocBLAS handle, if present.
83   ~ROCMBlas() override;
84 
85   TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES
86 
87  private:
88   // Tells rocBLAS to enqueue the BLAS operation onto a particular Stream.
89   //
90   // rocBLAS is stateful, and only be associated with one stream (in order to
91   // enqueue dispatch) at a given time. As a result, this generally must be
92   // invoked before calling into rocBLAS.
93   bool SetStream(Stream *stream) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
94 
95   // A helper function that calls the real rocBLAS function together with error
96   // handling.
97   //
98   // rocblas_func:       rocBLAS function pointer.
99   // rocblas_name:       rocBLAS function name.
100   // stream:             Stream to enqueue the BLAS operation onto.
101   // pointer_mode_host:  Indicate if the pointer to a scalar value is from host
102   //                     (true) or device (false).
103   // err_on_failure:     Whether to print an error if the rocBLAS function
104   // fails. args:               Arguments of rocBLAS function.
105   template <typename FuncT, typename... Args>
106   bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
107                           bool pointer_mode_host, bool err_on_failure,
108                           Args... args);
109 
110   // Convenience functions that call DoBlasInternalImpl with different values
111   // for err_on_failure.
112   template <typename FuncT, typename... Args>
113   bool DoBlasInternal(FuncT rocblas_func, Stream *stream,
114                       bool pointer_mode_host, Args... args) {
115     return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
116                               /*err_on_failure=*/true, args...);
117   }
118 
119   // Same as above, but returns Status.
120   template <typename... Args>
121   port::Status DoBlasInternalStatus(Args... args) {
122     if (!DoBlasInternal(args...)) {
123       return port::InternalError("Failed calling rocBLAS");
124     }
125     return port::Status::OK();
126   }
127 
128   template <typename FuncT, typename... Args>
129   bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream,
130                                bool pointer_mode_host, Args... args) {
131     return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
132                               /*err_on_failure=*/false, args...);
133   }
134 
135   // A helper allocation function to convert raw pointers memory layout to
136   // strided flavor
137   template <typename T>
138   port::Status AllocateStridedBuffer(
139       const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
140           &raw_ptrs,
141       int batch_count, uint64_t batch_stride,
142       ScratchAllocator *scratch_allocator, Stream *stream,
143       std::unique_ptr<TemporaryDeviceMemory<
144           typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
145       DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
146           *device_memory,
147       bool copy_data, bool &reallocated);
148 
149   // A helper function to implement DoBlasGemmBatched interfaces for generic
150   // types.
151   //
152   // Note: This function is implemented using gemm_strided_batched interface,
153   // NOT gemm_batched interface, because rocblas do not support it. As a
154   // result, if the passed in batch matrix are not allocated in strided batched
155   // format, it might end up in non-trivial amount of memory allocation and
156   // copy. To avoid this, always prioritize to use DoBlasGemmStridedBatched
157   // interface.
158   //
159   // In most use cases, batch matrix do get allocated in strided manner, making
160   // calling this interface equivalent with DoBlasGemmStridedBatched. The only
161   // use case we see so far that violates this observation is when batch
162   // matrix is created by broadcasting from a smaller matrix. When it happens,
163   // It will take advantage of the AllocateStridedBuffer subroutine to
164   // reallocate the memory layout to be strided batched.
165   template <typename T, typename FuncT>
166   port::Status DoBlasGemmBatchedInternal(
167       FuncT rocblas_func, Stream *stream, blas::Transpose transa,
168       blas::Transpose transb, uint64_t m, uint64 n, uint64 k, T alpha,
169       const absl::Span<DeviceMemory<T> *const> &a_ptrs_to_wrappers, int lda,
170       const absl::Span<DeviceMemory<T> *const> &b_ptrs_to_wrappers, int ldb,
171       T beta, const absl::Span<DeviceMemory<T> *const> &c_ptrs_to_wrappers,
172       int ldc, int batch_count, ScratchAllocator *scratch_allocator);
173 
174   // Helper function for implementing DoBlasGemmWithProfiling.
175   template <typename T, typename ParamType>
176   bool DoBlasGemmWithProfilingImpl(Stream *stream, blas::Transpose transa,
177                                    blas::Transpose transb, uint64_t m,
178                                    uint64_t n, uint64 k, const ParamType &alpha,
179                                    const DeviceMemory<T> &a, int lda,
180                                    const DeviceMemory<T> &b, int ldb,
181                                    const ParamType &beta, DeviceMemory<T> *c,
182                                    int ldc,
183                                    blas::ProfileResult *output_profile_result);
184 
185   // Helper function for implementing DoBlasGemvWithProfiling.
186   template <typename T>
187   bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
188                                    uint64_t m, uint64 n, const T &alpha,
189                                    const DeviceMemory<T> &a, int lda,
190                                    const DeviceMemory<T> &x, int incx,
191                                    const T &beta, DeviceMemory<T> *y, int incy,
192                                    blas::ProfileResult *output_profile_result);
193 
194   // mutex that guards the rocBLAS handle for this device.
195   absl::Mutex mu_;
196 
197   // GpuExecutor which instantiated this ROCMBlas.
198   // Immutable post-initialization.
199   GpuExecutor *parent_;
200 
201   // rocBLAS library handle on the device.
202   rocblas_handle blas_ ABSL_GUARDED_BY(mu_);
203 
204   SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas);
205 };
206 
207 }  // namespace gpu
208 }  // namespace stream_executor
209 
210 #endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
211