/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_MATMUL_UTIL_H_

#include <optional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/stream_executor/cuda/cuda_blas_lt.h"

namespace tensorflow {

// Reads the workspace memory limit (in MB) from an environment variable and
// returns the limit in bytes. If the variable is unset, returns
// `default_value_in_bytes`.
int64_t GetWorkspaceLimit(int64_t default_value_in_bytes);

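// Example (a sketch; the default value shown here is illustrative):
//
//   int64_t workspace_bytes =
//       GetWorkspaceLimit(/*default_value_in_bytes=*/1 << 22);
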
struct BlasLtMatmulPlanParams {
  std::string ToString() const;
  bool operator==(const BlasLtMatmulPlanParams& other) const;

  se::blas::DataType dtype;
  size_t m;
  size_t n;
  size_t k;
  se::blas::Transpose trans_a;
  se::blas::Transpose trans_b;
  size_t batch_count = 1;
  bool broadcast_a = false;
  bool broadcast_b = false;
  se::cuda::BlasLt::Epilogue epilogue = se::cuda::BlasLt::Epilogue::kDefault;
};

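// Example of filling in params for a plain fp32 GEMM (a sketch; the
// dimensions are illustrative):
//
//   BlasLtMatmulPlanParams params;
//   params.dtype = se::blas::DataType::kFloat;
//   params.m = 128;
//   params.n = 64;
//   params.k = 256;
//   params.trans_a = se::blas::Transpose::kNoTranspose;
//   params.trans_b = se::blas::Transpose::kNoTranspose;
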
namespace internal {

// Bundles all fields into a tuple so that hashing can reuse absl's tuple
// hash combination.
inline auto AsTuple(const BlasLtMatmulPlanParams& p) {
  return std::make_tuple(p.dtype, p.m, p.n, p.k, p.trans_a, p.trans_b,
                         p.batch_count, p.broadcast_a, p.broadcast_b,
                         p.epilogue);
}

}  // namespace internal

// Together with operator==, this lets BlasLtMatmulPlanParams be used directly
// as an absl::flat_hash_map key.
template <typename H>
H AbslHashValue(H h, const BlasLtMatmulPlanParams& params) {
  return H::combine(std::move(h), internal::AsTuple(params));
}

struct PlanAndAlgorithms {
  se::cuda::BlasLt::MatmulPlan plan;
  std::vector<se::cuda::BlasLt::MatmulAlgorithm> algorithms;
};

// Thread-safe map from matmul parameters to their corresponding plan and
// algorithms.
class BlasLtMatmulPlanMap {
 public:
  const PlanAndAlgorithms* Find(const BlasLtMatmulPlanParams& params) const;
  const PlanAndAlgorithms* Insert(const BlasLtMatmulPlanParams& params,
                                  PlanAndAlgorithms value);

 private:
  mutable absl::Mutex mu_;
  absl::flat_hash_map<BlasLtMatmulPlanParams, PlanAndAlgorithms>
      params_plan_map_ ABSL_GUARDED_BY(mu_);
};

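// Typical find-or-insert usage (a sketch; `plan_map` is a hypothetical
// long-lived instance and `plan_and_algorithms` a freshly built value):
//
//   const PlanAndAlgorithms* entry = plan_map.Find(params);
//   if (entry == nullptr) {
//     entry = plan_map.Insert(params, std::move(plan_and_algorithms));
//   }
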
// Returns the BLAS computation type to use for the given TF dtype.
StatusOr<se::blas::ComputationType> GetBlasComputationType(
    const DataType& dtype);

// Returns the matmul plan and candidate algorithms for `params`. If
// `max_algorithm_count` is set, at most that many algorithms are returned.
StatusOr<const PlanAndAlgorithms*> GetPlanAndAlgorithms(
    se::Stream* stream, const BlasLtMatmulPlanParams& params,
    std::optional<int> max_algorithm_count = std::nullopt);

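// Example (a sketch; choosing among the returned algorithms, e.g. by
// autotuning, is up to the caller):
//
//   TF_ASSIGN_OR_RETURN(const PlanAndAlgorithms* plan_and_algorithms,
//                       GetPlanAndAlgorithms(stream, params));
//   const auto& algorithms = plan_and_algorithms->algorithms;
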
template <typename T>
Status DoBlasLtMatmul(se::Stream* stream,
                      const se::cuda::BlasLt::MatmulPlan& plan,
                      const se::DeviceMemory<T>& a,
                      const se::DeviceMemory<T>& b, se::DeviceMemory<T>& c,
                      const se::cuda::BlasLt::MatmulAlgorithm& algorithm,
                      se::ScratchAllocator& scratch_allocator,
                      const se::DeviceMemory<T>& bias = {},
                      se::blas::ProfileResult* profile_result = nullptr) {
  se::cuda::BlasLt* blas_lt = se::cuda::GetBlasLt(stream);
  TF_RET_CHECK(blas_lt != nullptr);

  // Note: `b` and `a` are intentionally passed in swapped order; this is the
  // usual trick for computing the row-major product a * b with a
  // column-major BLAS.
  //
  // The scale type may be f32 when the data type is f16 or bf16.
  if constexpr (std::is_same_v<T, Eigen::half> ||
                std::is_same_v<T, Eigen::bfloat16>) {
    if (plan.op_desc.scale_type() == CUDA_R_32F) {
      return blas_lt->DoMatmul(stream, plan, se::HostOrDeviceScalar<float>(1.0),
                               b, a, se::HostOrDeviceScalar<float>(0.0), c, c,
                               algorithm, scratch_allocator, bias,
                               profile_result);
    }
  }
  return blas_lt->DoMatmul(stream, plan, se::HostOrDeviceScalar<T>(T(1.0)), b,
                           a, se::HostOrDeviceScalar<T>(T(0.0)), c, c,
                           algorithm, scratch_allocator, bias, profile_result);
}

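// Example call (a sketch; `a`, `b`, `c`, and `scratch_allocator` are assumed
// to be set up by the calling kernel, and algorithms[0] is simply the first
// candidate, not necessarily the fastest):
//
//   TF_RETURN_IF_ERROR(DoBlasLtMatmul(
//       stream, plan_and_algorithms->plan, a, b, c,
//       plan_and_algorithms->algorithms[0], scratch_allocator));
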
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_MATMUL_UTIL_H_