/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
/*
Mapping of GpuSolver Methods to respective ROCm Library APIs.
/--------------------------------------------------------------------------------------------------/
/  GpuSolverMethod  //  rocblasAPI     //  rocsolverAPI              //  hipsolverAPI (ROCM>4.5)   /
/  Geam             //  rocblas_Xgeam  //  ----                      //  ----                      /
/  Getrf            //  ----           //  rocsolver_Xgetrf          //  hipsolverXgetrf           /
/  GetrfBatched     //  ----           //  rocsolver_Xgetrf_batched  //  ----                      /
/  GetriBatched     //  ----           //  rocsolver_Xgetri_batched  //  ----                      /
/  Getrs            //  ----           //  rocsolver_Xgetrs          //  hipsolverXgetrs           /
/  GetrsBatched     //  ----           //  rocsolver_Xgetrs_batched  //  ----                      /
/  Geqrf            //  ----           //  rocsolver_Xgeqrf          //  hipsolverXgeqrf           /
/  Heevd            //  ----           //  ----                      //  hipsolverXheevd           /
/  Potrf            //  ----           //  rocsolver_Xpotrf          //  hipsolverXpotrf           /
/  PotrfBatched     //  ----           //  rocsolver_Xpotrf_batched  //  hipsolverXpotrfBatched    /
/  Trsm             //  rocblas_Xtrsm  //  ----                      //  ----                      /
/  Ungqr            //  ----           //  rocsolver_Xungqr          //  hipsolverXungqr           /
/  Unmqr            //  ----           //  rocsolver_Xunmqr          //  hipsolverXunmqr           /
/--------------------------------------------------------------------------------------------------/
*/
#if TENSORFLOW_USE_ROCM
#include <complex>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_solvers.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/platform/default/dso_loader.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rocm/rocblas_wrapper.h"

namespace tensorflow {
namespace {

using stream_executor::gpu::GpuExecutor;
using stream_executor::gpu::ScopedActivateExecutorContext;

inline bool CopyHostToDevice(OpKernelContext* context, void* dst,
                             const void* src, uint64 bytes) {
  auto stream = context->op_device_context()->stream();
  se::DeviceMemoryBase wrapped_dst(dst);
  return stream->ThenMemcpy(&wrapped_dst, src, bytes).ok();
}

struct GpuSolverHandles {
  explicit GpuSolverHandles(GpuExecutor* parent, hipStream_t stream) {
    parent_ = parent;
    ScopedActivateExecutorContext sac{parent_};
#if TF_ROCM_VERSION >= 40500
    CHECK(wrap::hipsolverCreate(&hipsolver_handle) == rocblas_status_success)
        << "Failed to create hipsolver instance";
#endif
    CHECK(wrap::rocblas_create_handle(&rocm_blas_handle) ==
          rocblas_status_success)
        << "Failed to create rocBlas instance.";
    CHECK(wrap::rocblas_set_stream(rocm_blas_handle, stream) ==
          rocblas_status_success)
        << "Failed to set rocBlas stream.";
  }

  ~GpuSolverHandles() {
    ScopedActivateExecutorContext sac{parent_};
    CHECK(wrap::rocblas_destroy_handle(rocm_blas_handle) ==
          rocblas_status_success)
        << "Failed to destroy rocBlas instance.";
#if TF_ROCM_VERSION >= 40500
    CHECK(wrap::hipsolverDestroy(hipsolver_handle) == rocblas_status_success)
        << "Failed to destroy hipsolver instance.";
#endif
  }
  GpuExecutor* parent_;
  rocblas_handle rocm_blas_handle;
#if TF_ROCM_VERSION >= 40500
  hipsolverHandle_t hipsolver_handle;
#endif
};

using HandleMap =
    std::unordered_map<hipStream_t, std::unique_ptr<GpuSolverHandles>>;

// Returns a singleton map used for storing initialized handles for each unique
// gpu stream.
HandleMap* GetHandleMapSingleton() {
  static HandleMap* cm = new HandleMap;
  return cm;
}

static mutex handle_map_mutex(LINKER_INITIALIZED);

}  // namespace

GpuSolver::GpuSolver(OpKernelContext* context) : context_(context) {
  mutex_lock lock(handle_map_mutex);
  GpuExecutor* gpu_executor = static_cast<GpuExecutor*>(
      context->op_device_context()->stream()->parent()->implementation());
  const hipStream_t* hip_stream_ptr = CHECK_NOTNULL(
      reinterpret_cast<const hipStream_t*>(context->op_device_context()
                                               ->stream()
                                               ->implementation()
                                               ->GpuStreamMemberHack()));

  hip_stream_ = *hip_stream_ptr;
  HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
  auto it = handle_map->find(hip_stream_);
  if (it == handle_map->end()) {
    LOG(INFO) << "Creating GpuSolver handles for stream " << hip_stream_;
    // Previously unseen Gpu stream. Initialize a set of Gpu solver library
    // handles for it.
    std::unique_ptr<GpuSolverHandles> new_handles(
        new GpuSolverHandles(gpu_executor, hip_stream_));
    it = handle_map->insert(std::make_pair(hip_stream_, std::move(new_handles)))
             .first;
  }
  rocm_blas_handle_ = it->second->rocm_blas_handle;
#if TF_ROCM_VERSION >= 40500
  hipsolver_handle_ = it->second->hipsolver_handle;
#endif
}

GpuSolver::~GpuSolver() {
  for (auto tensor_ref : scratch_tensor_refs_) {
    tensor_ref.Unref();
  }
}

// Static
void GpuSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<GpuSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_infos,
    std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
        info_checker_callback) {
  CHECK(info_checker_callback != nullptr);
  std::vector<HostLapackInfo> host_lapack_infos;
  if (dev_lapack_infos.empty()) {
    info_checker_callback(Status::OK(), host_lapack_infos);
    return;
  }

  // Launch memcpys to copy info back from device to host
  for (const auto& dev_lapack_info : dev_lapack_infos) {
    bool success = true;
    auto host_copy = dev_lapack_info.CopyToHost(&success);
    OP_REQUIRES(
        solver->context(), success,
        errors::Internal(
            "Failed to launch copy of dev_lapack_info to host, debug_info = ",
            dev_lapack_info.debug_info()));
    host_lapack_infos.push_back(std::move(host_copy));
  }

  // This callback checks that all batch items in all calls were processed
  // successfully and passes status to the info_checker_callback accordingly.
  auto* stream = solver->context()->op_device_context()->stream();
  auto wrapped_info_checker_callback =
      [stream](
          GpuSolver* solver,
          std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
              info_checker_callback,
          std::vector<HostLapackInfo> host_lapack_infos) {
        ScopedActivateExecutorContext scoped_activation{stream->parent()};
        Status status;
        for (const auto& host_lapack_info : host_lapack_infos) {
          for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) {
            const int info_value = host_lapack_info(i);
            if (info_value != 0) {
              status = errors::InvalidArgument(
                  "Got info = ", info_value, " for batch index ", i,
                  ", expected info = 0. Debug_info = ",
                  host_lapack_info.debug_info());
            }
          }
          if (!status.ok()) {
            break;
          }
        }
        // Delete solver to release temp tensor refs.
        delete solver;

        // Delegate further error checking to provided functor.
        info_checker_callback(status, host_lapack_infos);
      };
  // Note: An std::function cannot have unique_ptr arguments (it must be copy
  // constructible and therefore so must its arguments). Therefore, we release
  // solver into a raw pointer to be deleted at the end of
  // wrapped_info_checker_callback.
  // Release ownership of solver. It will be deleted in the cb callback.
  auto solver_raw_ptr = solver.release();
  auto cb =
      std::bind(wrapped_info_checker_callback, solver_raw_ptr,
                std::move(info_checker_callback), std::move(host_lapack_infos));

  solver_raw_ptr->context()
      ->device()
      ->tensorflow_accelerator_device_info()
      ->event_mgr->ThenExecute(stream, std::move(cb));
}
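
// Typical use from an async op kernel. This is an illustrative sketch only:
// the op, buffer names, and shapes below are hypothetical, not part of this
// file.
//
//   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//     auto solver = absl::make_unique<GpuSolver>(ctx);
//     std::vector<DeviceLapackInfo> dev_info;
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
//     OP_REQUIRES_OK_ASYNC(ctx,
//                          solver->Getrf(n, n, dev_a, n, dev_pivots,
//                                        dev_info.back().mutable_data()),
//                          done);
//     // Hands ownership of `solver` to the event manager; `done` runs after
//     // the info values have been copied back and validated.
//     GpuSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver),
//                                                    dev_info,
//                                                    std::move(done));
//   }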

// static
void GpuSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<GpuSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_info,
    AsyncOpKernel::DoneCallback done) {
  OpKernelContext* context = solver->context();
  auto wrapped_done = [context, done](
                          const Status& status,
                          const std::vector<HostLapackInfo>& /* unused */) {
    if (done != nullptr) {
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    } else {
      OP_REQUIRES_OK(context, status);
    }
  };
  CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_lapack_info,
                                      wrapped_done);
}

#define TF_RETURN_IF_ROCBLAS_ERROR(expr)                                  \
  do {                                                                    \
    auto status = (expr);                                                 \
    if (TF_PREDICT_FALSE(status != rocblas_status_success)) {             \
      return errors::Internal(__FILE__, ":", __LINE__,                    \
                              ": rocBlas call failed status = ", status); \
    }                                                                     \
  } while (0)
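
// Converts a failing rocBLAS/rocSOLVER/hipSOLVER status into an
// errors::Internal Status and returns it from the enclosing function, e.g.:
//
//   TF_RETURN_IF_ROCBLAS_ERROR(wrap::rocblas_set_stream(handle, stream));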

// Macro that specializes a solver method for all 4 standard
// numeric types.
#define TF_CALL_ROCSOLV_TYPES(m) \
  m(float, s) m(double, d) m(std::complex<float>, c) m(std::complex<double>, z)
#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, s) m(double, d)
#define BLAS_SOLVER_FN(method, type_prefix) \
  wrap::rocblas##_##type_prefix##method

#if TF_ROCM_VERSION >= 40500
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
#define TF_CALL_LAPACK_TYPES_NO_REAL(m) \
  m(std::complex<float>, C) m(std::complex<double>, Z)
#define SOLVER_FN(method, hip_prefix) wrap::hipsolver##hip_prefix##method
#else
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, s) m(double, d) m(std::complex<float>, c) m(std::complex<double>, z)
#define TF_CALL_LAPACK_TYPES_NO_REAL(m) \
  m(std::complex<float>, c) m(std::complex<double>, z)
#define SOLVER_FN(method, type_prefix) wrap::rocsolver##_##type_prefix##method
#endif

// Macros to construct rocsolver/hipsolver method names.
#define ROCSOLVER_FN(method, type_prefix) \
  wrap::rocsolver##_##type_prefix##method
#define BUFSIZE_FN(method, hip_prefix) \
  wrap::hipsolver##hip_prefix##method##_bufferSize
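
// Example of how the macros above resolve: with ROCm >= 4.5,
// SOLVER_FN(getrf, S) names wrap::hipsolverSgetrf and BUFSIZE_FN(getrf, S)
// names wrap::hipsolverSgetrf_bufferSize; below 4.5, SOLVER_FN(getrf, s)
// names wrap::rocsolver_sgetrf instead. The TF_CALL_* macros then stamp out
// one explicit template specialization per (Scalar, prefix) pair.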

#if TF_ROCM_VERSION >= 40500

#define GETRF_INSTANCE(Scalar, type_prefix)                                \
  template <>                                                              \
  Status GpuSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda,        \
                                  int* dev_pivots, int* dev_lapack_info) { \
    mutex_lock lock(handle_map_mutex);                                     \
    int lwork;                                                             \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(getrf, type_prefix)(             \
        hipsolver_handle_, m, n, AsHipComplex(A), lda, &lwork));           \
    auto dev_work =                                                        \
        this->GetScratchSpace<Scalar>(lwork, "", /*on_host*/ false);       \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrf, type_prefix)(              \
        hipsolver_handle_, m, n, AsHipComplex(A), lda,                     \
        AsHipComplex(dev_work.mutable_data()), lwork, dev_pivots,          \
        dev_lapack_info));                                                 \
    return Status::OK();                                                   \
  }

TF_CALL_LAPACK_TYPES(GETRF_INSTANCE);
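
// Illustrative sketch of an LU factor-and-solve using the specializations
// generated in this branch (buffer names and setup below are hypothetical;
// real kernels obtain device pointers from input/output tensors):
//
//   GpuSolver solver(context);
//   auto dev_info = solver.GetDeviceLapackInfo(1, "getrf");
//   TF_RETURN_IF_ERROR(solver.Getrf(n, n, dev_a, n, dev_pivots,
//                                   dev_info.mutable_data()));
//   TF_RETURN_IF_ERROR(solver.Getrs(HIPSOLVER_OP_N, n, nrhs, dev_a, n,
//                                   dev_pivots, dev_b, n,
//                                   dev_info.mutable_data()));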

#define GEQRF_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status GpuSolver::Geqrf(int m, int n, Scalar* dev_A, int lda,              \
                          Scalar* dev_tau, int* dev_lapack_info) {           \
    mutex_lock lock(handle_map_mutex);                                       \
    int lwork;                                                               \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(geqrf, type_prefix)(               \
        hipsolver_handle_, m, n, AsHipComplex(dev_A), lda, &lwork));         \
    auto dev_work =                                                          \
        this->GetScratchSpace<Scalar>(lwork, "", /*on_host*/ false);         \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(geqrf, type_prefix)(                \
        hipsolver_handle_, m, n, AsHipComplex(dev_A), lda,                   \
        AsHipComplex(dev_tau), AsHipComplex(dev_work.mutable_data()), lwork, \
        dev_lapack_info));                                                   \
    return Status::OK();                                                     \
  }

TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);

#define UNMQR_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status GpuSolver::Unmqr(hipsolverSideMode_t side,                            \
                          hipsolverOperation_t trans, int m, int n, int k,     \
                          const Scalar* dev_a, int lda, const Scalar* dev_tau, \
                          Scalar* dev_c, int ldc, int* dev_lapack_info) {      \
    mutex_lock lock(handle_map_mutex);                                         \
    using HipScalar = typename HipComplexT<Scalar>::type;                      \
    ScratchSpace<uint8> dev_a_copy = this->GetScratchSpace<uint8>(             \
        sizeof(Scalar*) * m * k, "", /*on host */ false);                      \
    if (!CopyHostToDevice(context_, dev_a_copy.mutable_data(), dev_a,          \
                          dev_a_copy.bytes())) {                               \
      return errors::Internal("Unmqr: Failed to copy ptrs to device");         \
    }                                                                          \
    ScratchSpace<uint8> dev_tau_copy = this->GetScratchSpace<uint8>(           \
        sizeof(Scalar*) * k * n, "", /*on host */ false);                      \
    if (!CopyHostToDevice(context_, dev_tau_copy.mutable_data(), dev_tau,      \
                          dev_tau_copy.bytes())) {                             \
      return errors::Internal("Unmqr: Failed to copy ptrs to device");         \
    }                                                                          \
    int lwork;                                                                 \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(unmqr, type_prefix)(                 \
        hipsolver_handle_, side, trans, m, n, k,                               \
        reinterpret_cast<HipScalar*>(dev_a_copy.mutable_data()), lda,          \
        reinterpret_cast<HipScalar*>(dev_tau_copy.mutable_data()),             \
        AsHipComplex(dev_c), ldc, &lwork));                                    \
    auto dev_work =                                                            \
        this->GetScratchSpace<Scalar>(lwork, "", /*on_host*/ false);           \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(unmqr, type_prefix)(                  \
        hipsolver_handle_, side, trans, m, n, k,                               \
        reinterpret_cast<HipScalar*>(dev_a_copy.mutable_data()), lda,          \
        reinterpret_cast<HipScalar*>(dev_tau_copy.mutable_data()),             \
        AsHipComplex(dev_c), ldc, AsHipComplex(dev_work.mutable_data()),       \
        lwork, dev_lapack_info));                                              \
    return Status::OK();                                                       \
  }

TF_CALL_LAPACK_TYPES_NO_REAL(UNMQR_INSTANCE);

#define UNGQR_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status GpuSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda,       \
                          const Scalar* dev_tau, int* dev_lapack_info) {     \
    mutex_lock lock(handle_map_mutex);                                       \
    using HipScalar = typename HipComplexT<Scalar>::type;                    \
    ScratchSpace<uint8> dev_tau_copy = this->GetScratchSpace<uint8>(         \
        sizeof(HipScalar*) * k * n, "", /*on host */ false);                 \
    if (!CopyHostToDevice(context_, dev_tau_copy.mutable_data(), dev_tau,    \
                          dev_tau_copy.bytes())) {                           \
      return errors::Internal("Ungqr: Failed to copy ptrs to device");       \
    }                                                                        \
    int lwork;                                                               \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(ungqr, type_prefix)(               \
        hipsolver_handle_, m, n, k, AsHipComplex(dev_a), lda,                \
        reinterpret_cast<HipScalar*>(dev_tau_copy.mutable_data()), &lwork)); \
    auto dev_work =                                                          \
        this->GetScratchSpace<Scalar>(lwork, "", /*on_host*/ false);         \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(ungqr, type_prefix)(                \
        hipsolver_handle_, m, n, k, AsHipComplex(dev_a), lda,                \
        reinterpret_cast<HipScalar*>(dev_tau_copy.mutable_data()),           \
        AsHipComplex(dev_work.mutable_data()), lwork, dev_lapack_info));     \
    return Status::OK();                                                     \
  }

TF_CALL_LAPACK_TYPES_NO_REAL(UNGQR_INSTANCE);

#define POTRF_INSTANCE(Scalar, type_prefix)                                \
  template <>                                                              \
  Status GpuSolver::Potrf<Scalar>(hipsolverFillMode_t uplo, int n,         \
                                  Scalar* dev_A, int lda,                  \
                                  int* dev_lapack_info) {                  \
    mutex_lock lock(handle_map_mutex);                                     \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                \
    int lwork;                                                             \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(potrf, type_prefix)(             \
        hipsolver_handle_, uplo, n, AsHipComplex(dev_A), lda, &lwork));    \
    auto dev_work =                                                        \
        this->GetScratchSpace<Scalar>(lwork, "", /*on_host*/ false);       \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(potrf, type_prefix)(              \
        hipsolver_handle_, uplo, n, AsHipComplex(dev_A), lda,              \
        AsHipComplex(dev_work.mutable_data()), lwork, dev_lapack_info));   \
    return Status::OK();                                                   \
  }

TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);

#define GETRS_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status GpuSolver::Getrs<Scalar>(hipsolverOperation_t trans, int n, int nrhs, \
                                  Scalar* A, int lda, int* dev_pivots,         \
                                  Scalar* B, int ldb, int* dev_lapack_info) {  \
    mutex_lock lock(handle_map_mutex);                                         \
    int lwork;                                                                 \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(getrs, type_prefix)(                 \
        hipsolver_handle_, trans, n, nrhs, AsHipComplex(A), lda, dev_pivots,   \
        AsHipComplex(B), ldb, &lwork));                                        \
    auto dev_work =                                                            \
        this->GetScratchSpace<Scalar>(lwork, "", /*on_host*/ false);           \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrs, type_prefix)(                  \
        hipsolver_handle_, trans, n, nrhs, AsHipComplex(A), lda, dev_pivots,   \
        AsHipComplex(B), ldb, AsHipComplex(dev_work.mutable_data()), lwork,    \
        dev_lapack_info));                                                     \
    return Status::OK();                                                       \
  }

TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);

#define POTRF_BATCHED_INSTANCE(Scalar, type_prefix)                           \
  template <>                                                                 \
  Status GpuSolver::PotrfBatched<Scalar>(                                     \
      hipsolverFillMode_t uplo, int n, const Scalar* const host_a_dev_ptrs[], \
      int lda, DeviceLapackInfo* dev_lapack_info, int batch_size) {           \
    rocblas_stride stride = n;                                                \
    mutex_lock lock(handle_map_mutex);                                        \
    using HipScalar = typename HipComplexT<Scalar>::type;                     \
    ScratchSpace<uint8> dev_a = this->GetScratchSpace<uint8>(                 \
        sizeof(HipScalar*) * batch_size, "", /*on host */ false);             \
    if (!CopyHostToDevice(context_, dev_a.mutable_data(), host_a_dev_ptrs,    \
                          dev_a.bytes())) {                                   \
      return errors::Internal("PotrfBatched: Failed to copy ptrs to device"); \
    }                                                                         \
    int lwork;                                                                \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(potrfBatched, type_prefix)(         \
        hipsolver_handle_, uplo, n,                                           \
        reinterpret_cast<HipScalar**>(dev_a.mutable_data()), lda, &lwork,     \
        batch_size));                                                         \
    auto dev_work =                                                           \
        this->GetScratchSpace<Scalar>(lwork, "", /*on_host*/ false);          \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(potrfBatched, type_prefix)(          \
        hipsolver_handle_, uplo, n,                                           \
        reinterpret_cast<HipScalar**>(dev_a.mutable_data()), lda,             \
        AsHipComplex(dev_work.mutable_data()), lwork,                         \
        dev_lapack_info->mutable_data(), batch_size));                        \
    return Status::OK();                                                      \
  }

TF_CALL_LAPACK_TYPES(POTRF_BATCHED_INSTANCE);
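
// Note on the batched convention: `host_a_dev_ptrs` is a host-side array of
// `batch_size` device pointers, one per matrix; the pointer array itself is
// staged into device scratch space with CopyHostToDevice before the batched
// call. Illustrative setup (hypothetical names, not part of this file):
//
//   std::vector<Scalar*> host_ptrs(batch_size);
//   for (int i = 0; i < batch_size; ++i) {
//     host_ptrs[i] = dev_matrices + i * n * n;
//   }
//   solver.PotrfBatched(HIPSOLVER_FILL_MODE_UPPER, n, host_ptrs.data(), n,
//                       &dev_info, batch_size);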

#define HEEVD_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status GpuSolver::Heevd<Scalar>(                                             \
      hipsolverEigMode_t jobz, hipsolverFillMode_t uplo, int n, Scalar* dev_A, \
      int lda, typename Eigen::NumTraits<Scalar>::Real* dev_W,                 \
      int* dev_lapack_info) {                                                  \
    mutex_lock lock(handle_map_mutex);                                         \
    using EigenScalar = typename Eigen::NumTraits<Scalar>::Real;               \
    int lwork;                                                                 \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(heevd, type_prefix)(                 \
        hipsolver_handle_, jobz, uplo, n, AsHipComplex(dev_A), lda, dev_W,     \
        &lwork));                                                              \
    auto dev_workspace =                                                       \
        this->GetScratchSpace<Scalar>(lwork, "", /*on host */ false);          \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(heevd, type_prefix)(                  \
        hipsolver_handle_, jobz, uplo, n, AsHipComplex(dev_A), lda, dev_W,     \
        AsHipComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));  \
    return Status::OK();                                                       \
  }

TF_CALL_LAPACK_TYPES_NO_REAL(HEEVD_INSTANCE);

#else
// Macro that specializes a solver method for all 4 standard
// numeric types.
// Macro to construct rocsolver method names.

#define GETRF_INSTANCE(Scalar, type_prefix)                                \
  template <>                                                              \
  Status GpuSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda,        \
                                  int* dev_pivots, int* dev_lapack_info) { \
    mutex_lock lock(handle_map_mutex);                                     \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrf, type_prefix)(              \
        rocm_blas_handle_, m, n, reinterpret_cast<ROCmScalar*>(A), lda,    \
        dev_pivots, dev_lapack_info));                                     \
    return Status::OK();                                                   \
  }

TF_CALL_LAPACK_TYPES(GETRF_INSTANCE);

#define GEQRF_INSTANCE(Scalar, type_prefix)                                 \
  template <>                                                               \
  Status GpuSolver::Geqrf(int m, int n, Scalar* dev_A, int lda,             \
                          Scalar* dev_tau, int* dev_lapack_info) {          \
    mutex_lock lock(handle_map_mutex);                                      \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                 \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(geqrf, type_prefix)(               \
        rocm_blas_handle_, m, n, reinterpret_cast<ROCmScalar*>(dev_A), lda, \
        reinterpret_cast<ROCmScalar*>(dev_tau)));                           \
    return Status::OK();                                                    \
  }

TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);

#define UMMQR_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status GpuSolver::Unmqr(rocblas_side side, rocblas_operation trans, int m, \
                          int n, int k, const Scalar* dev_a, int lda,        \
                          const Scalar* dev_tau, Scalar* dev_c, int ldc,     \
                          int* dev_lapack_info) {                            \
    mutex_lock lock(handle_map_mutex);                                       \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                  \
    ScratchSpace<uint8> dev_a_copy = this->GetScratchSpace<uint8>(           \
        sizeof(ROCmScalar*) * m * k, "", /*on host */ false);                \
    if (!CopyHostToDevice(context_, dev_a_copy.mutable_data(), dev_a,        \
                          dev_a_copy.bytes())) {                             \
      return errors::Internal("Unmqr: Failed to copy ptrs to device");       \
    }                                                                        \
    ScratchSpace<uint8> dev_tau_copy = this->GetScratchSpace<uint8>(         \
        sizeof(ROCmScalar*) * k * n, "", /*on host */ false);                \
    if (!CopyHostToDevice(context_, dev_tau_copy.mutable_data(), dev_tau,    \
                          dev_tau_copy.bytes())) {                           \
      return errors::Internal("Unmqr: Failed to copy ptrs to device");       \
    }                                                                        \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(unmqr, type_prefix)(                \
        rocm_blas_handle_, side, trans, m, n, k,                             \
        reinterpret_cast<ROCmScalar*>(dev_a_copy.mutable_data()), lda,       \
        reinterpret_cast<ROCmScalar*>(dev_tau_copy.mutable_data()),          \
        reinterpret_cast<ROCmScalar*>(dev_c), ldc));                         \
    return Status::OK();                                                     \
  }

TF_CALL_LAPACK_TYPES_NO_REAL(UMMQR_INSTANCE);

#define UNGQR_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status GpuSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda,         \
                          const Scalar* dev_tau, int* dev_lapack_info) {       \
    mutex_lock lock(handle_map_mutex);                                         \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                    \
    ScratchSpace<uint8> dev_tau_copy = this->GetScratchSpace<uint8>(           \
        sizeof(ROCmScalar*) * k * n, "", /*on host */ false);                  \
    if (!CopyHostToDevice(context_, dev_tau_copy.mutable_data(), dev_tau,      \
                          dev_tau_copy.bytes())) {                             \
      return errors::Internal("Ungqr: Failed to copy ptrs to device");         \
    }                                                                          \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(ungqr, type_prefix)(                  \
        rocm_blas_handle_, m, n, k, reinterpret_cast<ROCmScalar*>(dev_a), lda, \
        reinterpret_cast<ROCmScalar*>(dev_tau_copy.mutable_data())));          \
    return Status::OK();                                                       \
  }

TF_CALL_LAPACK_TYPES_NO_REAL(UNGQR_INSTANCE);

#define POTRF_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status GpuSolver::Potrf<Scalar>(rocblas_fill uplo, int n, Scalar* dev_A,   \
                                  int lda, int* dev_lapack_info) {           \
    mutex_lock lock(handle_map_mutex);                                       \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                  \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(potrf, type_prefix)(                \
        rocm_blas_handle_, uplo, n, reinterpret_cast<ROCmScalar*>(dev_A),    \
        lda, dev_lapack_info));                                              \
    return Status::OK();                                                     \
  }

TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);

#define GETRS_INSTANCE(Scalar, type_prefix)                                   \
  template <>                                                                 \
  Status GpuSolver::Getrs<Scalar>(rocblas_operation trans, int n, int nrhs,   \
                                  Scalar* A, int lda, const int* dev_pivots,  \
                                  Scalar* B, int ldb, int* dev_lapack_info) { \
    mutex_lock lock(handle_map_mutex);                                        \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                   \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrs, type_prefix)(                 \
        rocm_blas_handle_, trans, n, nrhs, reinterpret_cast<ROCmScalar*>(A),  \
        lda, dev_pivots, reinterpret_cast<ROCmScalar*>(B), ldb));             \
    return Status::OK();                                                      \
  }

TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);

#define POTRF_BATCHED_INSTANCE(Scalar, type_prefix)                           \
  template <>                                                                 \
  Status GpuSolver::PotrfBatched<Scalar>(                                     \
      rocblas_fill uplo, int n, const Scalar* const host_a_dev_ptrs[],        \
      int lda, DeviceLapackInfo* dev_lapack_info, int batch_size) {           \
    rocblas_stride stride = n;                                                \
    mutex_lock lock(handle_map_mutex);                                        \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                   \
    ScratchSpace<uint8> dev_a = this->GetScratchSpace<uint8>(                 \
        sizeof(ROCmScalar*) * batch_size, "", /*on host */ false);            \
    if (!CopyHostToDevice(context_, dev_a.mutable_data(), host_a_dev_ptrs,    \
                          dev_a.bytes())) {                                   \
      return errors::Internal("PotrfBatched: Failed to copy ptrs to device"); \
    }                                                                         \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(potrf_batched, type_prefix)(         \
        rocm_blas_handle_, uplo, n,                                           \
        reinterpret_cast<ROCmScalar**>(dev_a.mutable_data()), lda,            \
        dev_lapack_info->mutable_data(), batch_size));                        \
    return Status::OK();                                                      \
  }

TF_CALL_LAPACK_TYPES(POTRF_BATCHED_INSTANCE);

#endif
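
// Note: in the ROCm < 4.5 branch above, the rocsolver entry points are called
// directly on rocm_blas_handle_ and do not take a caller-provided workspace,
// so there is no *_bufferSize query; the hipsolver branch (ROCm >= 4.5) must
// query lwork first and pass device scratch space explicitly.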

#define GETRI_BATCHED_INSTANCE(Scalar, type_prefix)                           \
  template <>                                                                 \
  Status GpuSolver::GetriBatched<Scalar>(                                     \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                  \
      const int* dev_pivots, const Scalar* const host_a_inverse_dev_ptrs[],   \
      int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {        \
    mutex_lock lock(handle_map_mutex);                                        \
    rocblas_stride stride = n;                                                \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                   \
    ScratchSpace<uint8> dev_a = this->GetScratchSpace<uint8>(                 \
        sizeof(ROCmScalar*) * batch_size, "", /*on host */ false);            \
    if (!CopyHostToDevice(context_, dev_a.mutable_data(), host_a_dev_ptrs,    \
                          dev_a.bytes())) {                                   \
      return errors::Internal("GetriBatched: Failed to copy ptrs to device"); \
    }                                                                         \
    ScratchSpace<uint8> dev_a_inverse = this->GetScratchSpace<uint8>(         \
        sizeof(ROCmScalar*) * batch_size, "", /*on host */ false);            \
    if (!CopyHostToDevice(context_, dev_a_inverse.mutable_data(),             \
                          host_a_inverse_dev_ptrs, dev_a_inverse.bytes())) {  \
      return errors::Internal("GetriBatched: Failed to copy ptrs to device"); \
    }                                                                         \
    ScratchSpace<uint8> pivots = this->GetScratchSpace<uint8>(                \
        sizeof(ROCmScalar*) * batch_size, "", /*on host */ false);            \
    if (!CopyHostToDevice(context_, pivots.mutable_data(), dev_pivots,        \
                          pivots.bytes())) {                                  \
      return errors::Internal("GetriBatched: Failed to copy ptrs to device"); \
    }                                                                         \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getri_batched, type_prefix)(         \
        rocm_blas_handle_, n,                                                 \
        reinterpret_cast<ROCmScalar**>(dev_a.mutable_data()), lda,            \
        reinterpret_cast<int*>(pivots.mutable_data()), stride,                \
        dev_lapack_info->mutable_data(), batch_size));                        \
    return Status::OK();                                                      \
  }

TF_CALL_ROCSOLV_TYPES(GETRI_BATCHED_INSTANCE);

#define GETRF_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status GpuSolver::GetrfBatched<Scalar>(                                      \
      int n, Scalar** A, int lda, int* dev_pivots, DeviceLapackInfo* dev_info, \
      const int batch_size) {                                                  \
    mutex_lock lock(handle_map_mutex);                                         \
    rocblas_stride stride = n;                                                 \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                    \
    ScratchSpace<uint8> dev_a = this->GetScratchSpace<uint8>(                  \
        sizeof(ROCmScalar*) * batch_size, "", /*on host */ false);             \
    if (!CopyHostToDevice(context_, dev_a.mutable_data(), A, dev_a.bytes())) { \
      return errors::Internal("GetrfBatched: Failed to copy ptrs to device");  \
    }                                                                          \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrf_batched, type_prefix)(          \
        rocm_blas_handle_, n, n,                                               \
        reinterpret_cast<ROCmScalar**>(dev_a.mutable_data()), lda, dev_pivots, \
        stride, dev_info->mutable_data(), batch_size));                        \
    return Status::OK();                                                       \
  }

TF_CALL_ROCSOLV_TYPES(GETRF_BATCHED_INSTANCE);

#define GETRS_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status GpuSolver::GetrsBatched<Scalar>(                                      \
      const rocblas_operation trans, int n, int nrhs, Scalar** A, int lda,     \
      int* dev_pivots, Scalar** B, const int ldb, int* host_lapack_info,       \
      const int batch_size) {                                                  \
    rocblas_stride stride = n;                                                 \
    mutex_lock lock(handle_map_mutex);                                         \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                    \
    ScratchSpace<uint8> dev_a = this->GetScratchSpace<uint8>(                  \
        sizeof(ROCmScalar*) * batch_size, "", /*on host */ false);             \
    if (!CopyHostToDevice(context_, dev_a.mutable_data(), A, dev_a.bytes())) { \
      return errors::Internal("GetrsBatched: Failed to copy ptrs to device");  \
    }                                                                          \
    ScratchSpace<uint8> dev_b = this->GetScratchSpace<uint8>(                  \
        sizeof(ROCmScalar*) * batch_size, "", /*on host */ false);             \
    if (!CopyHostToDevice(context_, dev_b.mutable_data(), B, dev_b.bytes())) { \
      return errors::Internal("GetrsBatched: Failed to copy ptrs to device");  \
    }                                                                          \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrs_batched, type_prefix)(          \
        rocm_blas_handle_, trans, n, nrhs,                                     \
        reinterpret_cast<ROCmScalar**>(dev_a.mutable_data()), lda, dev_pivots, \
        stride, reinterpret_cast<ROCmScalar**>(dev_b.mutable_data()), ldb,     \
        batch_size));                                                          \
    return Status::OK();                                                       \
  }

TF_CALL_ROCSOLV_TYPES(GETRS_BATCHED_INSTANCE);

// Allocates a temporary tensor. The GpuSolver object maintains a
// TensorReference to the underlying Tensor to prevent it from being deallocated
// prematurely.
Status GpuSolver::allocate_scoped_tensor(DataType type,
                                         const TensorShape& shape,
                                         Tensor* out_temp) {
  const Status status = context_->allocate_temp(type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

Status GpuSolver::forward_input_or_allocate_scoped_tensor(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, Tensor* out_temp) {
  const Status status = context_->forward_input_or_allocate_temp(
      candidate_input_indices, type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}
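
// Usage sketch (illustrative only; `tmp` and the shape are hypothetical): a
// kernel that wants a temporary whose lifetime is tied to the solver can write
//
//   Tensor tmp;
//   TF_RETURN_IF_ERROR(solver->allocate_scoped_tensor(
//       DataTypeToEnum<Scalar>::value, TensorShape({batch_size, n, n}), &tmp));
//
// The TensorReference taken above is only released in ~GpuSolver, so the
// buffer stays alive until CheckLapackInfoAndDeleteSolverAsync deletes the
// solver after the enqueued GPU work has completed.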

template <typename Scalar, typename SolverFnT>
static inline Status TrsmImpl(GpuExecutor* gpu_executor, SolverFnT solver,
                              rocblas_handle rocm_blas_handle,
                              rocblas_side side, rocblas_fill uplo,
                              rocblas_operation trans, rocblas_diagonal diag,
                              int m, int n,
                              const Scalar* alpha, /* host or device pointer */
                              const Scalar* A, int lda, Scalar* B, int ldb) {
  mutex_lock lock(handle_map_mutex);
  using ROCmScalar = typename ROCmComplexT<Scalar>::type;

  ScopedActivateExecutorContext sac{gpu_executor};
  TF_RETURN_IF_ROCBLAS_ERROR(solver(rocm_blas_handle, side, uplo, trans, diag,
                                    m, n,
                                    reinterpret_cast<const ROCmScalar*>(alpha),
                                    reinterpret_cast<const ROCmScalar*>(A), lda,
                                    reinterpret_cast<ROCmScalar*>(B), ldb));

  return Status::OK();
}

#define TRSM_INSTANCE(Scalar, type_prefix)                                   \
  template <>                                                                \
  Status GpuSolver::Trsm<Scalar>(                                            \
      rocblas_side side, rocblas_fill uplo, rocblas_operation trans,         \
      rocblas_diagonal diag, int m, int n,                                   \
      const Scalar* alpha, /* host or device pointer */                      \
      const Scalar* A, int lda, Scalar* B, int ldb) {                        \
    GpuExecutor* gpu_executor = static_cast<GpuExecutor*>(                   \
        context_->op_device_context()->stream()->parent()->implementation()); \
    return TrsmImpl(gpu_executor, BLAS_SOLVER_FN(trsm, type_prefix),         \
                    rocm_blas_handle_, side, uplo, trans, diag, m, n, alpha, \
                    A, lda, B, ldb);                                         \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(TRSM_INSTANCE);
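
// Illustrative use of Trsm (buffer names are hypothetical): after a Cholesky
// factor from Potrf, the lower-triangular solve L * X = B can be expressed as
//
//   const Scalar alpha(1);
//   solver.Trsm(rocblas_side_left, rocblas_fill_lower, rocblas_operation_none,
//               rocblas_diagonal_non_unit, n, nrhs, &alpha, dev_l, n, dev_b,
//               n);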

template <typename Scalar, typename SolverFnT>
Status MatInvBatchedImpl(GpuExecutor* gpu_executor, SolverFnT solver,
                         rocblas_handle rocm_blas_handle, int n,
                         const Scalar* const host_a_dev_ptrs[], int lda,
                         int* dev_pivots,
                         const Scalar* const host_a_inverse_dev_ptrs[],
                         int ldainv, DeviceLapackInfo* dev_lapack_info,
                         int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using ROCmScalar = typename ROCmComplexT<Scalar>::type;
  ScopedActivateExecutorContext sac{gpu_executor};

  GetrfBatched(n, host_a_dev_ptrs, lda, dev_pivots, dev_lapack_info,
               batch_size);

  GetriBatched(n, host_a_dev_ptrs, lda, dev_pivots, host_a_inverse_dev_ptrs,
               ldainv, dev_lapack_info, batch_size);

  return Status::OK();
}

#define MATINVBATCHED_INSTANCE(Scalar, type_prefix)                           \
  template <>                                                                 \
  Status GpuSolver::MatInvBatched<Scalar>(                                    \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                  \
      const Scalar* const host_a_inverse_dev_ptrs[], int ldainv,              \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                    \
    GpuExecutor* gpu_executor = static_cast<GpuExecutor*>(                    \
        context_->op_device_context()->stream()->parent()->implementation()); \
    Tensor pivots;                                                            \
    allocate_scoped_tensor(DataTypeToEnum<int>::value,                        \
                           TensorShape{batch_size, n}, &pivots);              \
    auto pivots_mat = pivots.template matrix<int>();                          \
    int* dev_pivots = pivots_mat.data();                                      \
    return MatInvBatchedImpl(                                                 \
        gpu_executor, BLAS_SOLVER_FN(matinvbatched, type_prefix),             \
        rocm_blas_handle_, n, host_a_dev_ptrs, lda, dev_pivots,               \
        host_a_inverse_dev_ptrs, ldainv, dev_lapack_info, batch_size);        \
  }

template <typename Scalar, typename SolverFnT>
Status GeamImpl(GpuExecutor* gpu_executor, SolverFnT solver,
                rocblas_handle rocm_blas_handle, rocblas_operation transa,
                rocblas_operation transb, int m, int n, const Scalar* alpha,
                /* host or device pointer */ const Scalar* A, int lda,
                const Scalar* beta,
                /* host or device pointer */ const Scalar* B, int ldb,
                Scalar* C, int ldc) {
  mutex_lock lock(handle_map_mutex);
  using ROCmScalar = typename ROCmComplexT<Scalar>::type;

  ScopedActivateExecutorContext sac{gpu_executor};
  TF_RETURN_IF_ROCBLAS_ERROR(solver(rocm_blas_handle, transa, transb, m, n,
                                    reinterpret_cast<const ROCmScalar*>(alpha),
                                    reinterpret_cast<const ROCmScalar*>(A), lda,
                                    reinterpret_cast<const ROCmScalar*>(beta),
                                    reinterpret_cast<const ROCmScalar*>(B), ldb,
                                    reinterpret_cast<ROCmScalar*>(C), ldc));
  return Status::OK();
}

#define GEAM_INSTANCE(Scalar, type_prefix)                                     \
  template <>                                                                  \
  Status GpuSolver::Geam<Scalar>(                                              \
      rocblas_operation transa, rocblas_operation transb, int m, int n,        \
      const Scalar* alpha, const Scalar* A, int lda, const Scalar* beta,       \
      const Scalar* B, int ldb, Scalar* C, int ldc) {                          \
    GpuExecutor* gpu_executor = static_cast<GpuExecutor*>(                     \
        context_->op_device_context()->stream()->parent()->implementation()); \
    return GeamImpl(gpu_executor, BLAS_SOLVER_FN(geam, type_prefix),           \
                    rocm_blas_handle_, transa, transb, m, n, alpha, A, lda,    \
                    beta, B, ldb, C, ldc);                                     \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(GEAM_INSTANCE);
}  // namespace tensorflow

#endif  // TENSORFLOW_USE_ROCM
