/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MATMUL_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MATMUL_UTILS_H_

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/blas.h"

#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_blas_lt.h"
#include "tensorflow/stream_executor/scratch_allocator.h"
#endif  // GOOGLE_CUDA

namespace xla {
namespace gpu {

// Returns the non-contracting dimensions of `shape`, i.e. the dimensions
// that appear in neither `batch_dims` nor `contracting_dims`.
StatusOr<std::vector<int64_t>> GetNonContractingDims(
    const Shape& shape, absl::Span<const int64_t> batch_dims,
    absl::Span<const int64_t> contracting_dims);
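// Example (illustrative only; `operand_shape`, the shape and the dimension
// numbers below are made up): for an operand of shape f32[2, 3, 4] used in a
// dot with batch_dims = {0} and contracting_dims = {2}, the non-contracting
// dimensions are {1}:
//
//   TF_ASSIGN_OR_RETURN(
//       std::vector<int64_t> non_contracting_dims,
//       GetNonContractingDims(operand_shape, /*batch_dims=*/{0},
//                             /*contracting_dims=*/{2}));
//   // non_contracting_dims == {1}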

// Normalize shape to (batch, rows, columns) logical dimensions.
StatusOr<Shape> GetBatchRowColumnShape(const Shape& shape,
                                       absl::Span<const int64_t> batch_dims,
                                       absl::Span<const int64_t> row_dims,
                                       absl::Span<const int64_t> col_dims);
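// Example (a sketch; the shape below is hypothetical): a shape
// f32[2, 3, 4, 5] with batch_dims = {0}, row_dims = {1} and
// col_dims = {2, 3} normalizes to the logical (batch, rows, columns) shape
// f32[2, 3, 20], provided the dimensions can be collapsed this way given the
// shape's physical layout.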

struct MatrixLayout {
  enum class Order {
    kRowMajor,     // Elements in the same row are contiguous in memory.
    kColumnMajor,  // Elements in the same column are contiguous in memory.
  };

  // Returns the matrix layout for a logical shape (batch, rows, columns).
  static StatusOr<MatrixLayout> For(const Shape& shape);
  // Returns the matrix layout with the given batch, row, col dimensions.
  static StatusOr<MatrixLayout> For(const Shape& shape,
                                    absl::Span<const int64_t> batch_dims,
                                    absl::Span<const int64_t> row_dims,
                                    absl::Span<const int64_t> col_dims);
  // Returns the matrix layout for the output.
  static StatusOr<MatrixLayout> For(const Shape& shape,
                                    size_t lhs_num_batch_dims,
                                    size_t lhs_num_row_dims,
                                    size_t rhs_num_batch_dims,
                                    size_t rhs_num_col_dims);

  void Transpose();

  PrimitiveType dtype;
  // `num_rows` / `num_cols` are for the "logical" matrix shape:
  // i.e. the contracting dim has size `num_cols` for LHS operands and
  // `num_rows` for RHS operands.
  int64_t num_rows;
  int64_t num_cols;
  Order order;
  int64_t leading_dim_stride;
  int64_t batch_size;
  int64_t batch_stride;  // `batch_stride` is set to `0` when `batch_size == 1`.
};
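// Example (illustrative; the concrete values assume a dense, row-major
// f32[3, 4] matrix with no batch dimension): MatrixLayout::For would describe
// it roughly as
//
//   dtype = F32, num_rows = 3, num_cols = 4, order = Order::kRowMajor,
//   leading_dim_stride = 4, batch_size = 1, batch_stride = 0.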

// GPU folding rule for the `TransposeFolding` pass.
StatusOr<bool> CanFoldTransposeOperandIntoDot(const HloInstruction& dot,
                                              int64_t operand_idx);

struct GemmConfig {
  static StatusOr<GemmConfig> For(const HloInstruction* gemm);
  static StatusOr<GemmConfig> For(mlir::lmhlo_gpu::GEMMOp op);

  static StatusOr<GemmConfig> For(
      const Shape& lhs_shape, absl::Span<const int64_t> lhs_batch_dims,
      absl::Span<const int64_t> lhs_contracting_dims, const Shape& rhs_shape,
      absl::Span<const int64_t> rhs_batch_dims,
      absl::Span<const int64_t> rhs_contracting_dims, const Shape& output_shape,
      double alpha_real, double alpha_imag, double beta,
      std::optional<int64_t> algorithm, int64_t compute_precision);

  MatrixLayout lhs_layout;
  MatrixLayout rhs_layout;
  MatrixLayout output_layout;
  complex128 alpha;
  double beta;
  std::optional<int64_t> algorithm;
  int64_t compute_precision;
};
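// Example (a sketch; the shapes, alpha/beta values, and precision below are
// made up for illustration): a config for a plain f32[4, 3] x f32[3, 5] ->
// f32[4, 5] dot with no batch dimensions, D = 1.0 * A * B:
//
//   TF_ASSIGN_OR_RETURN(
//       GemmConfig config,
//       GemmConfig::For(/*lhs_shape=*/ShapeUtil::MakeShape(F32, {4, 3}),
//                       /*lhs_batch_dims=*/{}, /*lhs_contracting_dims=*/{1},
//                       /*rhs_shape=*/ShapeUtil::MakeShape(F32, {3, 5}),
//                       /*rhs_batch_dims=*/{}, /*rhs_contracting_dims=*/{0},
//                       /*output_shape=*/ShapeUtil::MakeShape(F32, {4, 5}),
//                       /*alpha_real=*/1.0, /*alpha_imag=*/0.0, /*beta=*/0.0,
//                       /*algorithm=*/std::nullopt,
//                       /*compute_precision=*/0));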

// Runs the GEMM described by `config` on the given buffers and stream.
//
// If `algorithm` is provided, it overrides the one specified in `config`.
Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer,
               se::DeviceMemoryBase rhs_buffer,
               se::DeviceMemoryBase output_buffer, se::Stream* stream,
               std::optional<se::blas::AlgorithmType> algorithm = std::nullopt,
               se::blas::ProfileResult* profile_result = nullptr);
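// Example usage (a sketch; `gemm_instr`, the device buffers, and `stream` are
// hypothetical values supplied by the caller):
//
//   TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm_instr));
//   TF_RETURN_IF_ERROR(RunGemm(config, lhs_buffer, rhs_buffer, output_buffer,
//                              stream));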

#if GOOGLE_CUDA

namespace cublas_lt {

StatusOr<se::cuda::BlasLt::Epilogue> AsBlasLtEpilogue(
    mlir::lmhlo_gpu::CublasLtMatmulEpilogue epilogue);

class MatmulPlan {
 public:
  static StatusOr<MatmulPlan> For(mlir::lmhlo_gpu::CublasLtMatmulOp op);
  static StatusOr<MatmulPlan> From(const GemmConfig& config,
                                   se::cuda::BlasLt::Epilogue epilogue);

  Status ExecuteOnStream(se::Stream* stream, se::DeviceMemoryBase a_buffer,
                         se::DeviceMemoryBase b_buffer,
                         se::DeviceMemoryBase c_buffer,
                         se::DeviceMemoryBase d_buffer,
                         se::DeviceMemoryBase bias_buffer,  // may be null
                         const se::cuda::BlasLt::MatmulAlgorithm& algorithm,
                         se::ScratchAllocator& scratch_allocator,
                         se::blas::ProfileResult* profile_result = nullptr);

  StatusOr<std::vector<se::cuda::BlasLt::MatmulAlgorithm>> GetAlgorithms(
      se::Stream* stream) const;

 private:
  MatmulPlan(se::cuda::BlasLt::MatmulPlan plan, complex128 alpha, double beta,
             bool must_swap_operands)
      : plan_(std::move(plan)),
        alpha_(alpha),
        beta_(beta),
        must_swap_operands_(must_swap_operands) {}

  template <typename Input, typename Scale = Input>
  Status DoMatmul(se::Stream* stream, se::DeviceMemoryBase a_buffer,
                  se::DeviceMemoryBase b_buffer, se::DeviceMemoryBase c_buffer,
                  se::DeviceMemoryBase d_buffer,
                  se::DeviceMemoryBase bias_buffer,  // may be null
                  const se::cuda::BlasLt::MatmulAlgorithm& algorithm,
                  se::ScratchAllocator& scratch_allocator,
                  se::blas::ProfileResult* profile_result);

  se::cuda::BlasLt::MatmulPlan plan_;
  complex128 alpha_;
  double beta_;
  bool must_swap_operands_;
};
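
// Example usage (a sketch; `config`, `epilogue`, the device buffers, the
// scratch allocator, and `stream` are hypothetical values supplied by the
// caller, and at least one algorithm is assumed to be returned):
//
//   TF_ASSIGN_OR_RETURN(MatmulPlan plan, MatmulPlan::From(config, epilogue));
//   TF_ASSIGN_OR_RETURN(auto algorithms, plan.GetAlgorithms(stream));
//   TF_RETURN_IF_ERROR(plan.ExecuteOnStream(
//       stream, a_buffer, b_buffer, c_buffer, d_buffer, bias_buffer,
//       algorithms[0], scratch_allocator));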

}  // namespace cublas_lt

#endif  // GOOGLE_CUDA

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MATMUL_UTILS_H_