/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MATMUL_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MATMUL_UTILS_H_

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/blas.h"

#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_blas_lt.h"
#include "tensorflow/stream_executor/scratch_allocator.h"
#endif  // GOOGLE_CUDA

namespace xla {
namespace gpu {

// Returns the dimensions of `shape` that are neither batch nor contracting
// dimensions.
StatusOr<std::vector<int64_t>> GetNonContractingDims(
    const Shape& shape, absl::Span<const int64_t> batch_dims,
    absl::Span<const int64_t> contracting_dims);

// Normalize shape to (batch, rows, columns) logical dimensions.
StatusOr<Shape> GetBatchRowColumnShape(const Shape& shape,
                                       absl::Span<const int64_t> batch_dims,
                                       absl::Span<const int64_t> row_dims,
                                       absl::Span<const int64_t> col_dims);

struct MatrixLayout {
  enum class Order {
    kRowMajor,     // Elements in the same row are contiguous in memory.
    kColumnMajor,  // Elements in the same column are contiguous in memory.
  };

  // Returns the matrix layout for a logical shape (batch, rows, columns).
  static StatusOr<MatrixLayout> For(const Shape& shape);
  // Returns the matrix layout with the given batch, row, col dimensions.
  static StatusOr<MatrixLayout> For(const Shape& shape,
                                    absl::Span<const int64_t> batch_dims,
                                    absl::Span<const int64_t> row_dims,
                                    absl::Span<const int64_t> col_dims);
  // Returns the matrix layout for the output.
  static StatusOr<MatrixLayout> For(const Shape& shape,
                                    size_t lhs_num_batch_dims,
                                    size_t lhs_num_row_dims,
                                    size_t rhs_num_batch_dims,
                                    size_t rhs_num_col_dims);

  void Transpose();

  PrimitiveType dtype;
  // `num_rows` / `num_cols` are for the "logical" matrix shape:
  // i.e. the contracting dim has size `num_cols` for LHS operands and
  // `num_rows` for RHS operands.
  int64_t num_rows;
  int64_t num_cols;
  Order order;
  int64_t leading_dim_stride;
  int64_t batch_size;
  int64_t batch_stride;  // `batch_stride` is set to `0` when `batch_size == 1`.
};
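
// Minimal usage sketch (illustrative only; `shape` and the dimension spans
// below are assumptions made for the example, not part of this header): for a
// dense dot operand of shape f32[16, 128, 64] where dimension 0 is the batch
// dimension, dimension 1 holds the rows, and dimension 2 holds the columns,
// a layout could be obtained like this inside a function returning Status:
//
//   TF_ASSIGN_OR_RETURN(
//       MatrixLayout layout,
//       MatrixLayout::For(shape, /*batch_dims=*/{0}, /*row_dims=*/{1},
//                         /*col_dims=*/{2}));
//
// With XLA's default (descending minor-to-major) layout this is expected to
// describe a row-major 128x64 matrix repeated over a batch of 16.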

// GPU folding rule for the `TransposeFolding` pass.
StatusOr<bool> CanFoldTransposeOperandIntoDot(const HloInstruction& dot,
                                              int64_t operand_idx);

struct GemmConfig {
  static StatusOr<GemmConfig> For(const HloInstruction* gemm);
  static StatusOr<GemmConfig> For(mlir::lmhlo_gpu::GEMMOp op);

  static StatusOr<GemmConfig> For(
      const Shape& lhs_shape, absl::Span<const int64_t> lhs_batch_dims,
      absl::Span<const int64_t> lhs_contracting_dims, const Shape& rhs_shape,
      absl::Span<const int64_t> rhs_batch_dims,
      absl::Span<const int64_t> rhs_contracting_dims, const Shape& output_shape,
      double alpha_real, double alpha_imag, double beta,
      std::optional<int64_t> algorithm, int64_t compute_precision);

  MatrixLayout lhs_layout;
  MatrixLayout rhs_layout;
  MatrixLayout output_layout;
  complex128 alpha;
  double beta;
  std::optional<int64_t> algorithm;
  int64_t compute_precision;
};

// Runs the GEMM described by `config` with the passed buffers.
//
// If `algorithm` is provided, it overrides the one specified in `config`.
Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer,
               se::DeviceMemoryBase rhs_buffer,
               se::DeviceMemoryBase output_buffer, se::Stream* stream,
               std::optional<se::blas::AlgorithmType> algorithm = std::nullopt,
               se::blas::ProfileResult* profile_result = nullptr);

#if GOOGLE_CUDA

namespace cublas_lt {

StatusOr<se::cuda::BlasLt::Epilogue> AsBlasLtEpilogue(
    mlir::lmhlo_gpu::CublasLtMatmulEpilogue epilogue);

class MatmulPlan {
 public:
  static StatusOr<MatmulPlan> For(mlir::lmhlo_gpu::CublasLtMatmulOp op);
  static StatusOr<MatmulPlan> From(const GemmConfig& config,
                                   se::cuda::BlasLt::Epilogue epilogue);

  Status ExecuteOnStream(se::Stream* stream, se::DeviceMemoryBase a_buffer,
                         se::DeviceMemoryBase b_buffer,
                         se::DeviceMemoryBase c_buffer,
                         se::DeviceMemoryBase d_buffer,
                         se::DeviceMemoryBase bias_buffer,  // may be null
                         const se::cuda::BlasLt::MatmulAlgorithm& algorithm,
                         se::ScratchAllocator& scratch_allocator,
                         se::blas::ProfileResult* profile_result = nullptr);

  StatusOr<std::vector<se::cuda::BlasLt::MatmulAlgorithm>> GetAlgorithms(
      se::Stream* stream) const;

 private:
  MatmulPlan(se::cuda::BlasLt::MatmulPlan plan, complex128 alpha, double beta,
             bool must_swap_operands)
      : plan_(std::move(plan)),
        alpha_(alpha),
        beta_(beta),
        must_swap_operands_(must_swap_operands) {}

  template <typename Input, typename Scale = Input>
  Status DoMatmul(se::Stream* stream, se::DeviceMemoryBase a_buffer,
                  se::DeviceMemoryBase b_buffer, se::DeviceMemoryBase c_buffer,
                  se::DeviceMemoryBase d_buffer,
                  se::DeviceMemoryBase bias_buffer,  // may be null
                  const se::cuda::BlasLt::MatmulAlgorithm& algorithm,
                  se::ScratchAllocator& scratch_allocator,
                  se::blas::ProfileResult* profile_result);

  se::cuda::BlasLt::MatmulPlan plan_;
  complex128 alpha_;
  double beta_;
  bool must_swap_operands_;
};

}  // namespace cublas_lt

#endif  // GOOGLE_CUDA

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MATMUL_UTILS_H_
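
// Illustrative end-to-end sketch of the GEMM helpers declared above. The
// shapes, buffers, and `stream` below are hypothetical placeholders, and the
// dimension numbers assume a plain 2D dot contracting LHS dim 1 with RHS
// dim 0; this is a usage sketch, not canonical code from the XLA pipeline.
//
//   TF_ASSIGN_OR_RETURN(
//       GemmConfig config,
//       GemmConfig::For(lhs_shape, /*lhs_batch_dims=*/{},
//                       /*lhs_contracting_dims=*/{1}, rhs_shape,
//                       /*rhs_batch_dims=*/{}, /*rhs_contracting_dims=*/{0},
//                       out_shape, /*alpha_real=*/1.0, /*alpha_imag=*/0.0,
//                       /*beta=*/0.0, /*algorithm=*/std::nullopt,
//                       /*compute_precision=*/0));
//   TF_RETURN_IF_ERROR(
//       RunGemm(config, lhs_buffer, rhs_buffer, output_buffer, stream));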