/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/matmul_util.h"

#include <cstdlib>
#include <cstring>
#include <limits>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/matmul_autotune.h"
#include "tensorflow/stream_executor/cuda/cuda_blas_lt.h"

namespace tensorflow {

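// Returns the scratch/workspace size limit in bytes for cuBLASLt matmuls. The
// limit may be overridden, in MiB, via the TF_CUBLAS_WORKSPACE_LIMIT_IN_MB
// environment variable; if the variable is unset or unparsable,
// `default_value_in_bytes` is returned.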
int64_t GetWorkspaceLimit(int64_t default_value_in_bytes) {
  const char* workspace_limit_in_mb_str =
      getenv("TF_CUBLAS_WORKSPACE_LIMIT_IN_MB");
  if (workspace_limit_in_mb_str != nullptr &&
      strcmp(workspace_limit_in_mb_str, "") != 0) {
    int64_t scratch_limit_in_mb = -1;
    if (strings::safe_strto64(workspace_limit_in_mb_str,
                              &scratch_limit_in_mb)) {
      return scratch_limit_in_mb * (1 << 20);
    } else {
      LOG(WARNING) << "Invalid value for TF_CUBLAS_WORKSPACE_LIMIT_IN_MB: "
                   << workspace_limit_in_mb_str;
    }
  }
  return default_value_in_bytes;
}

std::string BlasLtMatmulPlanParams::ToString() const {
  return "";  // TODO
}

bool BlasLtMatmulPlanParams::operator==(
    const BlasLtMatmulPlanParams& other) const {
  return internal::AsTuple(*this) == internal::AsTuple(other);
}

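// Thread-safe lookup: returns the cached plan and algorithms for `params`, or
// nullptr if no entry exists yet.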
const PlanAndAlgorithms* BlasLtMatmulPlanMap::Find(
    const BlasLtMatmulPlanParams& params) const {
  absl::MutexLock lock(&mu_);
  auto it = params_plan_map_.find(params);
  return (it != params_plan_map_.end()) ? &it->second : nullptr;
}

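// Thread-safe insertion: stores `value` under `params` and returns a pointer
// to the stored entry. If an entry for `params` already exists, it is kept
// (emplace does not overwrite) and the existing entry is returned.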
const PlanAndAlgorithms* BlasLtMatmulPlanMap::Insert(
    const BlasLtMatmulPlanParams& params, PlanAndAlgorithms value) {
  absl::MutexLock lock(&mu_);
  return &params_plan_map_.emplace(params, std::move(value)).first->second;
}

namespace {

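// Returns the maximum number of algorithms to consider during matmul
// autotuning, read from TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS (default 10).
// Values outside [1, INT_MAX] are logged as errors.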
int MatmulMaxAutotuneAlgorithmCount() {
  int64_t value;
  Status status =
      ReadInt64FromEnvVar("TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS", 10, &value);
  if (!status.ok()) {
    LOG(ERROR) << status.error_message();
  }
  static constexpr const int kMaxValue = std::numeric_limits<int>::max();
  if (value < 1 || value > kMaxValue) {
    LOG(ERROR) << "Invalid value for TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS: "
               << value << " is not in range [1, " << kMaxValue << "]";
  }
  return value;
}

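// Maps a Blas data type to the computation type used for the plan: FP16
// inputs may be computed in FP32 (per MatmulDoFP32ComputationFP16Input), and
// FP32 / complex-FP32 inputs use TF32 when tensor_float_32 execution is
// enabled.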
StatusOr<se::blas::ComputationType> GetBlasComputationType(
    const se::blas::DataType& dtype) {
  using se::blas::ComputationType;
  static bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input();
  switch (dtype) {
    case se::blas::DataType::kHalf:
      return use_f32_for_f16_computation ? ComputationType::kF32
                                         : ComputationType::kF16;
    case se::blas::DataType::kBF16:
      return ComputationType::kF32;
    case se::blas::DataType::kFloat:  // fall-through
    case se::blas::DataType::kComplexFloat:
      return tensor_float_32_execution_enabled() ? ComputationType::kTF32AsF32
                                                 : ComputationType::kF32;
    case se::blas::DataType::kDouble:  // fall-through
    case se::blas::DataType::kComplexDouble:
      return ComputationType::kF64;
    default:
      return errors::Internal("Unsupported dtype for Blas Plans.");
  }
}

}  // namespace

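// Returns the cuBLASLt matmul plan and candidate algorithms for `params` on
// the cuBLASLt handle associated with `stream`, creating and caching them on
// first use. If `max_algorithm_count` is not provided, the limit comes from
// TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS (default 10).
//
// Usage sketch (illustrative only; the fields shown are the ones this file
// reads from BlasLtMatmulPlanParams and may not be exhaustive):
//
//   BlasLtMatmulPlanParams params;
//   params.dtype = se::blas::DataType::kFloat;
//   params.m = m;
//   params.n = n;
//   params.k = k;
//   params.trans_a = se::blas::Transpose::kNoTranspose;
//   params.trans_b = se::blas::Transpose::kNoTranspose;
//   params.batch_count = 1;
//   TF_ASSIGN_OR_RETURN(const PlanAndAlgorithms* plan_and_algorithms,
//                       GetPlanAndAlgorithms(stream, params, std::nullopt));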
StatusOr<const PlanAndAlgorithms*> GetPlanAndAlgorithms(
    se::Stream* stream, const BlasLtMatmulPlanParams& params,
    std::optional<int> max_algorithm_count) {
  static const int64_t max_scratch_size =
      GetWorkspaceLimit(1LL << 32);  // 4GB by default
  static const int64_t max_autotune_algorithm_count =
      MatmulMaxAutotuneAlgorithmCount();

  if (!max_algorithm_count) max_algorithm_count = max_autotune_algorithm_count;

  static auto& plan_map = *new BlasLtMatmulPlanMap();

  const PlanAndAlgorithms* plan_and_algorithms = plan_map.Find(params);
  if (!plan_and_algorithms) {
    se::cuda::BlasLt* blas_lt = se::cuda::GetBlasLt(stream);
    TF_RET_CHECK(blas_lt != nullptr);

    TF_ASSIGN_OR_RETURN(se::blas::ComputationType computation_type,
                        GetBlasComputationType(params.dtype));

    se::blas::DataType scale_type =
        se::cuda::BlasLt::GetScaleType(params.dtype, computation_type);

    // cublas_lt's output is column-major. We want row-major so use identity:
    // C^T = (A @ B)^T = B^T @ A^T.
    constexpr auto kColMajor =
        se::cuda::BlasLt::MatrixLayout::Order::kColumnMajor;

    size_t rows_a = params.k;
    size_t cols_a = params.m;
    size_t rows_b = params.n;
    size_t cols_b = params.k;

    if (params.trans_a != se::blas::Transpose::kNoTranspose) {
      std::swap(rows_a, cols_a);
    }

    if (params.trans_b != se::blas::Transpose::kNoTranspose) {
      std::swap(rows_b, cols_b);
    }

    int64_t batch_stride_a =
        params.broadcast_a ? 0 : static_cast<int64_t>(rows_a * cols_a);
    int64_t batch_stride_b =
        params.broadcast_b ? 0 : static_cast<int64_t>(rows_b * cols_b);

    TF_ASSIGN_OR_RETURN(
        auto a_desc,
        se::cuda::BlasLt::MatrixLayout::Create(
            params.dtype, rows_a, cols_a, kColMajor, params.batch_count,
            /*leading_dim_stride=*/std::nullopt, batch_stride_a));
    TF_ASSIGN_OR_RETURN(
        auto b_desc,
        se::cuda::BlasLt::MatrixLayout::Create(
            params.dtype, rows_b, cols_b, kColMajor, params.batch_count,
            /*leading_dim_stride=*/std::nullopt, batch_stride_b));
    TF_ASSIGN_OR_RETURN(auto c_desc, se::cuda::BlasLt::MatrixLayout::Create(
                                         params.dtype, params.n, params.m,
                                         kColMajor, params.batch_count));
    TF_ASSIGN_OR_RETURN(auto d_desc, se::cuda::BlasLt::MatrixLayout::Create(
                                         params.dtype, params.n, params.m,
                                         kColMajor, params.batch_count));

    // `A` and `B` swapped (see above re. column-major output).
    TF_ASSIGN_OR_RETURN(auto op_desc,
                        se::cuda::BlasLt::MatmulDesc::Create(
                            computation_type, scale_type,
                            /*trans_a=*/params.trans_b,
                            /*trans_b=*/params.trans_a, params.epilogue));

    // `A` and `B` swapped (see above re. column-major output).
    se::cuda::BlasLt::MatmulPlan plan{std::move(op_desc), std::move(b_desc),
                                      std::move(a_desc), std::move(c_desc),
                                      std::move(d_desc)};
    TF_ASSIGN_OR_RETURN(
        auto preference,
        se::cuda::BlasLt::MatmulPreference::Create(max_scratch_size));

    TF_ASSIGN_OR_RETURN(
        std::vector<se::cuda::BlasLt::MatmulAlgorithm> algorithms,
        blas_lt->GetMatmulAlgorithms(plan, preference, *max_algorithm_count));

    plan_and_algorithms =
        plan_map.Insert(params, {std::move(plan), std::move(algorithms)});
  }
  return plan_and_algorithms;
}

}  // namespace tensorflow