/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_MATMUL_UTIL_H_

#include <optional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/stream_executor/cuda/cuda_blas_lt.h"

namespace tensorflow {

// Reads the workspace memory limit (in MB) from an environment variable and
// returns it in bytes. If the variable is not set, returns
// `default_value_in_bytes`.
int64_t GetWorkspaceLimit(int64_t default_value_in_bytes);

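// Parameters that uniquely identify a cuBLASLt matmul plan. The struct is
// equality-comparable and hashable (see AbslHashValue below) so it can serve
// as the key of BlasLtMatmulPlanMap.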
struct BlasLtMatmulPlanParams {
  std::string ToString() const;
  bool operator==(const BlasLtMatmulPlanParams& other) const;

  se::blas::DataType dtype;
  size_t m;
  size_t n;
  size_t k;
  se::blas::Transpose trans_a;
  se::blas::Transpose trans_b;
  size_t batch_count = 1;
  bool broadcast_a = false;
  bool broadcast_b = false;
  se::cuda::BlasLt::Epilogue epilogue = se::cuda::BlasLt::Epilogue::kDefault;
};

namespace internal {

// Flattens the plan parameters into a tuple so they can be hashed with
// absl::Hash (see AbslHashValue below).
inline auto AsTuple(const BlasLtMatmulPlanParams& p) {
  return std::make_tuple(p.dtype, p.m, p.n, p.k, p.trans_a, p.trans_b,
                         p.batch_count, p.broadcast_a, p.broadcast_b,
                         p.epilogue);
}

}  // namespace internal

// Makes BlasLtMatmulPlanParams hashable with absl::Hash so it can be used as
// a flat_hash_map key.
template <typename H>
H AbslHashValue(H h, const BlasLtMatmulPlanParams& params) {
  return H::combine(std::move(h), internal::AsTuple(params));
}

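// A cuBLASLt matmul plan together with the candidate algorithms reported for
// it.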
struct PlanAndAlgorithms {
  se::cuda::BlasLt::MatmulPlan plan;
  std::vector<se::cuda::BlasLt::MatmulAlgorithm> algorithms;
};

// Thread-safe map from matmul parameters to their corresponding plan and
// algorithms.
class BlasLtMatmulPlanMap {
 public:
  const PlanAndAlgorithms* Find(const BlasLtMatmulPlanParams& params) const;
  const PlanAndAlgorithms* Insert(const BlasLtMatmulPlanParams& params,
                                  PlanAndAlgorithms value);

 private:
  mutable absl::Mutex mu_;
  absl::flat_hash_map<BlasLtMatmulPlanParams, PlanAndAlgorithms>
      params_plan_map_ ABSL_GUARDED_BY(mu_);
};
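
// A sketch of the intended lookup pattern (inferred from the interface): call
// Find() with the parameters, and on a miss construct the PlanAndAlgorithms
// once and publish it with Insert().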

// Returns the cuBLAS computation type corresponding to `dtype`.
StatusOr<se::blas::ComputationType> GetBlasComputationType(
    const DataType& dtype);

// Returns the plan and associated algorithms for `params`. If
// `max_algorithm_count` is given, at most that many algorithms are requested.
StatusOr<const PlanAndAlgorithms*> GetPlanAndAlgorithms(
    se::Stream* stream, const BlasLtMatmulPlanParams& params,
    std::optional<int> max_algorithm_count = std::nullopt);
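
// A minimal usage sketch (illustrative only; it assumes a valid se::Stream*,
// device buffers `a`, `b`, `c` of matching shapes, a scratch allocator, and
// placeholder dimensions m/n/k, and it simply picks the first returned
// algorithm rather than autotuning):
//
//   BlasLtMatmulPlanParams params;
//   params.dtype = se::blas::DataType::kFloat;
//   params.m = m;
//   params.n = n;
//   params.k = k;
//   params.trans_a = se::blas::Transpose::kNoTranspose;
//   params.trans_b = se::blas::Transpose::kNoTranspose;
//
//   TF_ASSIGN_OR_RETURN(const PlanAndAlgorithms* plan_and_algorithms,
//                       GetPlanAndAlgorithms(stream, params));
//   TF_RETURN_IF_ERROR(DoBlasLtMatmul(stream, plan_and_algorithms->plan,
//                                     a, b, c,
//                                     plan_and_algorithms->algorithms[0],
//                                     scratch_allocator));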

template <typename T>
Status DoBlasLtMatmul(se::Stream* stream,
                      const se::cuda::BlasLt::MatmulPlan& plan,
                      const se::DeviceMemory<T>& a,
                      const se::DeviceMemory<T>& b, se::DeviceMemory<T>& c,
                      const se::cuda::BlasLt::MatmulAlgorithm& algorithm,
                      se::ScratchAllocator& scratch_allocator,
                      const se::DeviceMemory<T>& bias = {},
                      se::blas::ProfileResult* profile_result = nullptr) {
  se::cuda::BlasLt* blas_lt = se::cuda::GetBlasLt(stream);
  TF_RET_CHECK(blas_lt != nullptr);

  // Note: the operands are intentionally passed to DoMatmul in (b, a) order.

  // The scale type may be f32 even when the data type is f16 or bf16.
  if constexpr (std::is_same_v<T, Eigen::half> ||
                std::is_same_v<T, Eigen::bfloat16>) {
    if (plan.op_desc.scale_type() == CUDA_R_32F) {
      return blas_lt->DoMatmul(stream, plan, se::HostOrDeviceScalar<float>(1.0),
                               b, a, se::HostOrDeviceScalar<float>(0.0), c, c,
                               algorithm, scratch_allocator, bias,
                               profile_result);
    }
  }
  return blas_lt->DoMatmul(stream, plan, se::HostOrDeviceScalar<T>(T(1.0)), b,
                           a, se::HostOrDeviceScalar<T>(T(0.0)), c, c,
                           algorithm, scratch_allocator, bias, profile_result);
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_MATMUL_UTIL_H_