/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/matmul_util.h"

#include <cstdlib>
#include <cstring>
#include <limits>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/matmul_autotune.h"
#include "tensorflow/stream_executor/cuda/cuda_blas_lt.h"

namespace tensorflow {

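// Returns the cuBLAS scratch-space limit in bytes. If the
// TF_CUBLAS_WORKSPACE_LIMIT_IN_MB environment variable is set to a parsable
// value, that value (interpreted as MiB) wins; otherwise
// `default_value_in_bytes` is returned.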
int64_t GetWorkspaceLimit(int64_t default_value_in_bytes) {
  const char* workspace_limit_in_mb_str =
      getenv("TF_CUBLAS_WORKSPACE_LIMIT_IN_MB");
  if (workspace_limit_in_mb_str != nullptr &&
      strcmp(workspace_limit_in_mb_str, "") != 0) {
    int64_t scratch_limit_in_mb = -1;
    if (strings::safe_strto64(workspace_limit_in_mb_str,
                              &scratch_limit_in_mb)) {
      return scratch_limit_in_mb * (1 << 20);
    } else {
      LOG(WARNING) << "Invalid value for TF_CUBLAS_WORKSPACE_LIMIT_IN_MB: "
                   << workspace_limit_in_mb_str;
    }
  }
  return default_value_in_bytes;
}

std::string BlasLtMatmulPlanParams::ToString() const {
  return "";  // TODO
}

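// Two parameter structs are equal when all of their fields match;
// internal::AsTuple packs the fields so the comparison stays in sync with the
// struct definition.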
bool BlasLtMatmulPlanParams::operator==(
    const BlasLtMatmulPlanParams& other) const {
  return internal::AsTuple(*this) == internal::AsTuple(other);
}

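// Returns the cached plan/algorithms for `params`, or nullptr if no entry has
// been inserted yet.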
const PlanAndAlgorithms* BlasLtMatmulPlanMap::Find(
    const BlasLtMatmulPlanParams& params) const {
  absl::MutexLock lock(&mu_);
  auto it = params_plan_map_.find(params);
  return (it != params_plan_map_.end()) ? &it->second : nullptr;
}

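// Inserts `value` under `params` unless an entry already exists, and returns a
// pointer to the cached entry either way.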
const PlanAndAlgorithms* BlasLtMatmulPlanMap::Insert(
    const BlasLtMatmulPlanParams& params, PlanAndAlgorithms value) {
  absl::MutexLock lock(&mu_);
  return &params_plan_map_.emplace(params, std::move(value)).first->second;
}

namespace {

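// Reads the maximum number of algorithms to consider during autotuning from
// the TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS environment variable (default 10).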
int MatmulMaxAutotuneAlgorithmCount() {
  int64_t value;
  Status status =
      ReadInt64FromEnvVar("TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS", 10, &value);
  if (!status.ok()) {
    LOG(ERROR) << status.error_message();
  }
  static constexpr const int kMaxValue = std::numeric_limits<int>::max();
  if (value < 1 || value > kMaxValue) {
    LOG(ERROR) << "Invalid value for TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS: "
               << value << " is not in range [1, " << kMaxValue << "]";
    // Fall back to the default rather than returning an out-of-range value.
    value = 10;
  }
  return value;
}

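// Maps the matmul data type to the Blas computation type, honoring the
// FP32-for-FP16 computation setting and TensorFloat-32 execution where they
// apply.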
StatusOr<se::blas::ComputationType> GetBlasComputationType(
    const se::blas::DataType& dtype) {
  using se::blas::ComputationType;
  static bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input();
  switch (dtype) {
    case se::blas::DataType::kHalf:
      return use_f32_for_f16_computation ? ComputationType::kF32
                                         : ComputationType::kF16;
    case se::blas::DataType::kBF16:
      return ComputationType::kF32;
    case se::blas::DataType::kFloat:  // fall-through
    case se::blas::DataType::kComplexFloat:
      return tensor_float_32_execution_enabled() ? ComputationType::kTF32AsF32
                                                 : ComputationType::kF32;
    case se::blas::DataType::kDouble:  // fall-through
    case se::blas::DataType::kComplexDouble:
      return ComputationType::kF64;
    default:
      return errors::Internal("Unsupported dtype for Blas Plans.");
  }
}

}  // namespace

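// Returns the cublasLt matmul plan and the candidate algorithms for `params`,
// creating and caching them on first use. When `max_algorithm_count` is not
// provided, it defaults to MatmulMaxAutotuneAlgorithmCount().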
StatusOr<const PlanAndAlgorithms*> GetPlanAndAlgorithms(
    se::Stream* stream, const BlasLtMatmulPlanParams& params,
    std::optional<int> max_algorithm_count) {
  static const int64_t max_scratch_size =
      GetWorkspaceLimit(1LL << 32);  // 4GB by default
  static const int64_t max_autotune_algorithm_count =
      MatmulMaxAutotuneAlgorithmCount();

  if (!max_algorithm_count) max_algorithm_count = max_autotune_algorithm_count;

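  // Process-wide cache of plans keyed by the matmul parameters; intentionally
  // leaked so it is never destroyed.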
  static auto& plan_map = *new BlasLtMatmulPlanMap();

  const PlanAndAlgorithms* plan_and_algorithms = plan_map.Find(params);
  if (!plan_and_algorithms) {
    se::cuda::BlasLt* blas_lt = se::cuda::GetBlasLt(stream);
    TF_RET_CHECK(blas_lt != nullptr);

    TF_ASSIGN_OR_RETURN(se::blas::ComputationType computation_type,
                        GetBlasComputationType(params.dtype));

    se::blas::DataType scale_type =
        se::cuda::BlasLt::GetScaleType(params.dtype, computation_type);

    // cublas_lt's output is column-major. We want row-major so use identity:
    // C^T = (A @ B)^T = B^T @ A^T.
    constexpr auto kColMajor =
        se::cuda::BlasLt::MatrixLayout::Order::kColumnMajor;

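    // Reinterpreting the row-major inputs as column-major buffers turns
    // A (m x k) into A^T (k x m) and B (k x n) into B^T (n x k); the sizes
    // below describe those column-major views before any explicit transpose.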
    size_t rows_a = params.k;
    size_t cols_a = params.m;
    size_t rows_b = params.n;
    size_t cols_b = params.k;

    if (params.trans_a != se::blas::Transpose::kNoTranspose) {
      std::swap(rows_a, cols_a);
    }

    if (params.trans_b != se::blas::Transpose::kNoTranspose) {
      std::swap(rows_b, cols_b);
    }

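    // A batch stride of zero makes every batch entry read the same matrix,
    // which implements broadcasting of A or B across the batch.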
    int64_t batch_stride_a =
        params.broadcast_a ? 0 : static_cast<int64_t>(rows_a * cols_a);
    int64_t batch_stride_b =
        params.broadcast_b ? 0 : static_cast<int64_t>(rows_b * cols_b);

    TF_ASSIGN_OR_RETURN(
        auto a_desc,
        se::cuda::BlasLt::MatrixLayout::Create(
            params.dtype, rows_a, cols_a, kColMajor, params.batch_count,
            /*leading_dim_stride=*/std::nullopt, batch_stride_a));
    TF_ASSIGN_OR_RETURN(
        auto b_desc,
        se::cuda::BlasLt::MatrixLayout::Create(
            params.dtype, rows_b, cols_b, kColMajor, params.batch_count,
            /*leading_dim_stride=*/std::nullopt, batch_stride_b));
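    // The output descriptors are n x m: the column-major result of B^T @ A^T
    // is C^T, which read back as a row-major buffer is the desired m x n
    // matrix C = A @ B.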
    TF_ASSIGN_OR_RETURN(auto c_desc, se::cuda::BlasLt::MatrixLayout::Create(
                                         params.dtype, params.n, params.m,
                                         kColMajor, params.batch_count));
    TF_ASSIGN_OR_RETURN(auto d_desc, se::cuda::BlasLt::MatrixLayout::Create(
                                         params.dtype, params.n, params.m,
                                         kColMajor, params.batch_count));

    // `A` and `B` swapped (see above re. column-major output).
    TF_ASSIGN_OR_RETURN(auto op_desc,
                        se::cuda::BlasLt::MatmulDesc::Create(
                            computation_type, scale_type,
                            /*trans_a=*/params.trans_b,
                            /*trans_b=*/params.trans_a, params.epilogue));

    // `A` and `B` swapped (see above re. column-major output).
    se::cuda::BlasLt::MatmulPlan plan{std::move(op_desc), std::move(b_desc),
                                      std::move(a_desc), std::move(c_desc),
                                      std::move(d_desc)};
    TF_ASSIGN_OR_RETURN(
        auto preference,
        se::cuda::BlasLt::MatmulPreference::Create(max_scratch_size));

    TF_ASSIGN_OR_RETURN(
        std::vector<se::cuda::BlasLt::MatmulAlgorithm> algorithms,
        blas_lt->GetMatmulAlgorithms(plan, preference, *max_algorithm_count));

    plan_and_algorithms =
        plan_map.Insert(params, {std::move(plan), std::move(algorithms)});
  }
  return plan_and_algorithms;
}

}  // namespace tensorflow