/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"

#include <algorithm>
#include <cstdint>
#include <optional>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/stream_executor/blas.h"

#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_blas_lt.h"
#include "tensorflow/stream_executor/host_or_device_scalar.h"
#endif  // GOOGLE_CUDA

namespace xla {
namespace gpu {

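// Returns the dimensions of `shape` that are neither batch nor contracting.
// E.g. for a [b, m, k] operand with batch_dims = {0} and
// contracting_dims = {2}, this returns {1}.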
StatusOr<std::vector<int64_t>> GetNonContractingDims(
    const Shape& shape, absl::Span<const int64_t> batch_dims,
    absl::Span<const int64_t> contracting_dims) {
  std::vector<int64_t> non_contracting_dims;
  // This is O(rank**2), but we expect rank to be small.
  for (int64_t dim = 0; dim < shape.rank(); ++dim) {
    bool is_batch = absl::c_count(batch_dims, dim) != 0;
    bool is_contracting = absl::c_count(contracting_dims, dim) != 0;
    TF_RET_CHECK(!(is_batch && is_contracting));
    if (!(is_batch || is_contracting)) non_contracting_dims.push_back(dim);
  }

  TF_RET_CHECK(batch_dims.size() + contracting_dims.size() +
                   non_contracting_dims.size() ==
               shape.rank());
  return non_contracting_dims;
}

StatusOr<Shape> GetBatchRowColumnShape(const Shape& shape,
                                       absl::Span<const int64_t> batch_dims,
                                       absl::Span<const int64_t> row_dims,
                                       absl::Span<const int64_t> col_dims) {
  TF_RET_CHECK(shape.has_layout());
  TF_RET_CHECK(!row_dims.empty());
  TF_RET_CHECK(!col_dims.empty());

  std::vector<int64_t> minor_to_major;
  for (size_t i = 0; i < shape.rank();) {
    // The GeMM output always has its layout set such that the batch, row, and
    // col dim groups are each laid out physically sequentially. GeMM operands
    // must, therefore, be laid out similarly.
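    // For example, an f32[2,3,4,5]{3,2,1,0} operand with batch_dims = {0},
    // row_dims = {1,2} and col_dims = {3} collapses to f32[2,12,5]{2,1,0};
    // if a dim group were physically interleaved with another group, we
    // would return an error instead.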
    auto check_physically_sequential = [&](absl::Span<const int64_t> dims) {
      for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
        // NOTE: `i` is incremented as we check the dimensions.
        if (*it != shape.layout().minor_to_major()[i++])
          return InvalidArgument("dims not physically sequential");
      }
      return OkStatus();
    };

    int64_t dim = shape.layout().minor_to_major()[i];
    if (dim == row_dims.back()) {
      minor_to_major.push_back(1);
      TF_RETURN_IF_ERROR(check_physically_sequential(row_dims));
    } else if (dim == col_dims.back()) {
      minor_to_major.push_back(2);
      TF_RETURN_IF_ERROR(check_physically_sequential(col_dims));
    } else if (!batch_dims.empty() && (dim == batch_dims.back())) {
      minor_to_major.push_back(0);
      TF_RETURN_IF_ERROR(check_physically_sequential(batch_dims));
    } else {
      return InvalidArgument("dims not physically sequential");
    }
  }

  if (batch_dims.empty()) minor_to_major.push_back(0);

  auto dim_size = [&](absl::Span<const int64_t> dims) {
    return absl::c_accumulate(dims, 1, [&](int64_t size, int64_t dim) {
      return size * shape.dimensions(dim);
    });
  };

  return ShapeUtil::MakeShapeWithLayout(
      shape.element_type(),
      {dim_size(batch_dims), dim_size(row_dims), dim_size(col_dims)},
      minor_to_major);
}

// Returns the matrix layout for a logical shape (batch, rows, columns).
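// The shape must already be in (batch, rows, columns) form, e.g. as produced
// by GetBatchRowColumnShape() above; higher-rank shapes go through the
// overloads below.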
/*static*/ StatusOr<MatrixLayout> MatrixLayout::For(const Shape& shape) {
  TF_RET_CHECK(shape.rank() == 3);
  TF_RET_CHECK(shape.has_layout());

  int64_t batch_size = shape.dimensions(0);
  int64_t num_rows = shape.dimensions(1);
  int64_t num_cols = shape.dimensions(2);

  MatrixLayout::Order order = MatrixLayout::Order::kRowMajor;
  int64_t leading_dim_stride = num_cols;
  int64_t batch_stride = num_rows * num_cols;

  // `MatrixLayout`, like BLAS, uses only two strides, so either the row or
  // column must be contiguous in memory (i.e. most minor physical dimension).
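  // The switch key packs minor_to_major into octal digits, read from the
  // most-major digit down. E.g. layout {2,1,0} hits `case 012` (batch, rows,
  // cols in major-to-minor order): the matrix is row-major, with
  // leading_dim_stride == num_cols and batch_stride == num_rows * num_cols.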
  absl::Span<const int64_t> minor_to_major = shape.layout().minor_to_major();
  switch (64 * minor_to_major[2] + 8 * minor_to_major[1] + minor_to_major[0]) {
    case 012:  // (B,R,C) (major-to-minor)
      break;
    case 021:  // (B,C,R)
      order = MatrixLayout::Order::kColumnMajor;
      leading_dim_stride = num_rows;
      break;
    case 0102:  // (R,B,C)
      leading_dim_stride = batch_size * num_cols;
      batch_stride = num_cols;
      break;
    case 0201:  // (C,B,R)
      order = MatrixLayout::Order::kColumnMajor;
      leading_dim_stride = batch_size * num_rows;
      batch_stride = num_rows;
      break;
    default:
      return Unimplemented("batch in most minor dimension");
  }

  if (batch_size == 1) batch_stride = 0;
  return MatrixLayout{
      shape.element_type(), num_rows,   num_cols,     order,
      leading_dim_stride,   batch_size, batch_stride,
  };
}

/*static*/ StatusOr<MatrixLayout> MatrixLayout::For(
    const Shape& shape, absl::Span<const int64_t> batch_dims,
    absl::Span<const int64_t> row_dims, absl::Span<const int64_t> col_dims) {
  TF_ASSIGN_OR_RETURN(
      Shape batch_row_col_shape,
      GetBatchRowColumnShape(shape, batch_dims, row_dims, col_dims));
  return MatrixLayout::For(batch_row_col_shape);
}

/*static*/ StatusOr<MatrixLayout> MatrixLayout::For(const Shape& shape,
                                                    size_t lhs_num_batch_dims,
                                                    size_t lhs_num_row_dims,
                                                    size_t rhs_num_batch_dims,
                                                    size_t rhs_num_col_dims) {
  size_t num_batch_dims = std::max(lhs_num_batch_dims, rhs_num_batch_dims);

  TF_RET_CHECK(shape.rank() ==
               num_batch_dims + lhs_num_row_dims + rhs_num_col_dims);

  std::vector<int64_t> dims(shape.rank());
  absl::c_iota(dims, 0);

  auto batch_dims = absl::Span<const int64_t>(dims).first(num_batch_dims);
  auto row_dims =
      absl::Span<const int64_t>(dims).subspan(num_batch_dims, lhs_num_row_dims);
  auto col_dims = absl::Span<const int64_t>(dims).last(rhs_num_col_dims);

  return MatrixLayout::For(shape, batch_dims, row_dims, col_dims);
}

void MatrixLayout::Transpose() {
  std::swap(num_rows, num_cols);
  order = (order == Order::kRowMajor) ? Order::kColumnMajor : Order::kRowMajor;
}

StatusOr<bool> CanFoldTransposeOperandIntoDot(const HloInstruction& dot,
                                              int64_t operand_idx) {
  TF_RET_CHECK(dot.opcode() == HloOpcode::kDot);
  TF_RET_CHECK(dot.operand_count() > operand_idx);

  const HloInstruction& transpose = *dot.operand(operand_idx);
  TF_RET_CHECK(transpose.opcode() == HloOpcode::kTranspose);

  const DotDimensionNumbers& dot_dims = dot.dot_dimension_numbers();

  auto transposed = [&](const auto& dims) {
    std::vector<int64_t> transposed_dims;
    transposed_dims.reserve(dims.size());
    for (int64_t dim : dims) {
      transposed_dims.push_back(transpose.dimensions(dim));
    }
    return transposed_dims;
  };

  auto batch_dims = (operand_idx == 0) ? dot_dims.lhs_batch_dimensions()
                                       : dot_dims.rhs_batch_dimensions();
  auto contracting_dims = (operand_idx == 0)
                              ? dot_dims.lhs_contracting_dimensions()
                              : dot_dims.rhs_contracting_dimensions();
  TF_ASSIGN_OR_RETURN(
      std::vector<int64_t> non_contracting_dims,
      GetNonContractingDims(transpose.shape(), batch_dims, contracting_dims));

  // If we're able to construct a valid `MatrixLayout` for the transposed
  // dimensions, then GeMM can support folding the transpose.
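  // That is, the GeMM can read the pre-transpose operand directly (with
  // adjusted strides), so no transposed copy needs to be materialized.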
  return MatrixLayout::For(transpose.operand(0)->shape(),
                           transposed(batch_dims), transposed(contracting_dims),
                           transposed(non_contracting_dims))
      .ok();
}

/*static*/ StatusOr<GemmConfig> GemmConfig::For(
    const Shape& lhs_shape, absl::Span<const int64_t> lhs_batch_dims,
    absl::Span<const int64_t> lhs_contracting_dims, const Shape& rhs_shape,
    absl::Span<const int64_t> rhs_batch_dims,
    absl::Span<const int64_t> rhs_contracting_dims, const Shape& output_shape,
    double alpha_real, double alpha_imag, double beta,
    std::optional<int64_t> algorithm, int64_t compute_precision) {
  absl::Span<const int64_t> lhs_col_dims = lhs_contracting_dims;
  TF_ASSIGN_OR_RETURN(
      std::vector<int64_t> lhs_row_dims,
      GetNonContractingDims(lhs_shape, lhs_batch_dims, lhs_col_dims));

  TF_ASSIGN_OR_RETURN(
      MatrixLayout lhs_layout,
      MatrixLayout::For(lhs_shape, lhs_batch_dims, lhs_row_dims, lhs_col_dims));

  absl::Span<const int64_t> rhs_row_dims = rhs_contracting_dims;
  TF_ASSIGN_OR_RETURN(
      std::vector<int64_t> rhs_col_dims,
      GetNonContractingDims(rhs_shape, rhs_batch_dims, rhs_row_dims));

  TF_ASSIGN_OR_RETURN(
      MatrixLayout rhs_layout,
      MatrixLayout::For(rhs_shape, rhs_batch_dims, rhs_row_dims, rhs_col_dims));

  int64_t num_batch_dims =
      std::max(lhs_batch_dims.size(), rhs_batch_dims.size());

  TF_RET_CHECK(output_shape.rank() ==
               num_batch_dims + lhs_row_dims.size() + rhs_col_dims.size());

  std::vector<int64_t> output_dims(output_shape.rank());
  absl::c_iota(output_dims, 0);

  auto output_batch_dims =
      absl::Span<const int64_t>(output_dims).first(num_batch_dims);
  auto output_row_dims = absl::Span<const int64_t>(output_dims)
                             .subspan(num_batch_dims, lhs_row_dims.size());
  auto output_col_dims =
      absl::Span<const int64_t>(output_dims).last(rhs_col_dims.size());

  TF_ASSIGN_OR_RETURN(MatrixLayout output_layout,
                      MatrixLayout::For(output_shape, output_batch_dims,
                                        output_row_dims, output_col_dims));

  // TODO(cjfj): We should also check that the batch, contracting and
  // non-contracting dimensions match in size and relative physical location.
  TF_RET_CHECK(lhs_layout.num_cols == rhs_layout.num_rows);
  TF_RET_CHECK(output_layout.num_rows == lhs_layout.num_rows);
  TF_RET_CHECK(output_layout.num_cols == rhs_layout.num_cols);
  TF_RET_CHECK((lhs_layout.batch_size == output_layout.batch_size) ||
               (lhs_layout.batch_size == 1));
  TF_RET_CHECK((rhs_layout.batch_size == output_layout.batch_size) ||
               (rhs_layout.batch_size == 1));

  switch (output_shape.element_type()) {
    case F16:
    case BF16:
    case F32:
    case F64:
      TF_RET_CHECK(alpha_imag == 0);
      break;
    case C64:
    case C128:
      break;
    case S32:
      TF_RET_CHECK(alpha_imag == 0);
      if (lhs_layout.dtype != PrimitiveType::S8 ||
          rhs_layout.dtype != PrimitiveType::S8) {
        return InternalError(
            "For int32 gemm output only int8 input is supported, got input: "
            "%s, %s",
            primitive_util::LowercasePrimitiveTypeName(lhs_layout.dtype),
            primitive_util::LowercasePrimitiveTypeName(rhs_layout.dtype));
      }
      break;
    default:
      return InternalError("Unexpected GEMM datatype: %s",
                           primitive_util::LowercasePrimitiveTypeName(
                               output_shape.element_type()));
  }

  return GemmConfig{
      lhs_layout, rhs_layout, output_layout,     {alpha_real, alpha_imag},
      beta,       algorithm,  compute_precision,
  };
}

/*static*/ StatusOr<GemmConfig> GemmConfig::For(const HloInstruction* gemm) {
  TF_ASSIGN_OR_RETURN(GemmBackendConfig config,
                      gemm->backend_config<GemmBackendConfig>());

  std::optional<int64_t> algorithm;
  if (config.algorithm_case() != GemmBackendConfig::ALGORITHM_NOT_SET) {
    algorithm = config.selected_algorithm();
  }

  const Shape& lhs_shape = gemm->operand(0)->shape();
  const Shape& rhs_shape = gemm->operand(1)->shape();
  const DotDimensionNumbers& dot_dims = config.dot_dimension_numbers();

  return GemmConfig::For(
      lhs_shape, dot_dims.lhs_batch_dimensions(),
      dot_dims.lhs_contracting_dimensions(), rhs_shape,
      dot_dims.rhs_batch_dimensions(), dot_dims.rhs_contracting_dimensions(),
      /*output_shape=*/gemm->shape(), config.alpha_real(), config.alpha_imag(),
      config.beta(), algorithm, se::blas::kDefaultComputePrecision);
}

/*static*/ StatusOr<GemmConfig> GemmConfig::For(mlir::lmhlo_gpu::GEMMOp op) {
  mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers();

  std::optional<int64_t> algorithm;
  if (op.getAlgorithm()) algorithm = *op.getAlgorithm();

  int64_t compute_precision = 0;  // Default
  if (op.getPrecisionConfig().has_value()) {
    auto precision_config = op.getPrecisionConfig();
    for (auto attr : precision_config.getValue()) {
      int64_t value = static_cast<int64_t>(
          attr.template cast<mlir::mhlo::PrecisionAttr>().getValue());
      if (value > compute_precision) {
        compute_precision = value;
      }
    }
  }

  return GemmConfig::For(
      GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(),
      dot_dims.getLhsContractingDimensions(), GetShape(op.getB()),
      dot_dims.getRhsBatchingDimensions(),
      dot_dims.getRhsContractingDimensions(), GetShape(op.getC()),
      op.getAlphaReal().convertToDouble(), op.getAlphaImag().convertToDouble(),
      op.getBeta().convertToDouble(), algorithm, compute_precision);
}

namespace {

// BLAS GeMM's output is column-major. If we require row-major, use identity:
// C^T = (A @ B)^T = B^T @ A^T.
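// Swapping and transposing the operands lets us compute the same product while
// treating the row-major [m, n] output as its column-major [n, m] transpose:
// B^T [n, k] @ A^T [k, m].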
bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs,
                           MatrixLayout& output) {
  bool swap_operands = output.order != MatrixLayout::Order::kColumnMajor;
  if (swap_operands) {
    std::swap(lhs, rhs);
    lhs.Transpose();
    rhs.Transpose();
    output.Transpose();
  }
  return swap_operands;
}

StatusOr<se::blas::ComputationType> GetBlasComputationType(
    PrimitiveType dtype) {
  switch (dtype) {
    case F16:  // fall-through
    case BF16:
      // Accumulate in f32 precision.
      return se::blas::ComputationType::kF32;
    case F32:  // fall-through
    case C64:
      return se::blas::ComputationType::kTF32AsF32;
    case F64:  // fall-through
    case C128:
      return se::blas::ComputationType::kF64;
    case S32:
      return se::blas::ComputationType::kI32;
    default:
      return InternalError("unsupported type");
  }
}

se::blas::Transpose AsBlasTranspose(MatrixLayout::Order order) {
  // BLAS is column-major by default.
  return (order == MatrixLayout::Order::kColumnMajor)
             ? se::blas::Transpose::kNoTranspose
             : se::blas::Transpose::kTranspose;
}

se::blas::MatrixDescriptor GetMatrixDesc(const MatrixLayout& layout,
                                         se::DeviceMemoryBase data) {
  return {
      data,
      layout.leading_dim_stride,
      layout.batch_stride,
      AsBlasTranspose(layout.order),
  };
}

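// Runs a (strided-batched) GEMM with an explicitly chosen algorithm. `Input`
// is the operand element type and `Output` the output/accumulation type; in
// this file they differ only for the int8 -> int32 GEMM path.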
template <typename Input, typename Output>
Status DoGemmWithAlgorithm(int64_t batch_size, int64_t m, int64_t n, int64_t k,
                           const se::blas::MatrixDescriptor& lhs,
                           const se::blas::MatrixDescriptor& rhs,
                           const se::blas::MatrixDescriptor& output,
                           Output alpha, Output beta, se::Stream* stream,
                           se::blas::AlgorithmType algorithm,
                           se::blas::ProfileResult* profile_result) {
  CHECK(output.transpose == se::blas::Transpose::kNoTranspose);
  PrimitiveType output_type = primitive_util::NativeToPrimitiveType<Output>();
  TF_ASSIGN_OR_RETURN(se::blas::ComputationType computation_type,
                      GetBlasComputationType(output_type));
  se::DeviceMemory<Output> output_data(output.data);

  if (batch_size != 1) {
    return stream->ThenBlasGemmStridedBatchedWithAlgorithm(
        lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
        lhs.leading_dim_stride, lhs.batch_stride, rhs.cast<Input>(),
        rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data,
        output.leading_dim_stride, output.batch_stride, batch_size,
        computation_type, algorithm, profile_result);
  } else {
    return stream->ThenBlasGemmWithAlgorithm(
        lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
        lhs.leading_dim_stride, rhs.cast<Input>(), rhs.leading_dim_stride, beta,
        &output_data, output.leading_dim_stride, computation_type, algorithm,
        profile_result);
  }
}

template <typename Input>
Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k,
              const se::blas::MatrixDescriptor& lhs,
              const se::blas::MatrixDescriptor& rhs,
              const se::blas::MatrixDescriptor& output, Input alpha, Input beta,
              se::Stream* stream,
              std::optional<se::blas::AlgorithmType> algorithm,
              se::blas::ComputePrecision compute_precision,
              se::blas::ProfileResult* profile_result) {
  CHECK(output.transpose == se::blas::Transpose::kNoTranspose);
  se::DeviceMemory<Input> output_data(output.data);

  if (algorithm) {
    return DoGemmWithAlgorithm<Input, Input>(batch_size, m, n, k, lhs, rhs,
                                             output, alpha, beta, stream,
                                             *algorithm, profile_result);
  }

  if (batch_size != 1) {
    return stream->ThenBlasGemmStridedBatched(
        lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
        lhs.leading_dim_stride, lhs.batch_stride, rhs.cast<Input>(),
        rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data,
        output.leading_dim_stride, output.batch_stride, batch_size);
  }

  return stream->ThenBlasGemm(
      lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
      lhs.leading_dim_stride, rhs.cast<Input>(), rhs.leading_dim_stride, beta,
      &output_data, output.leading_dim_stride, compute_precision);
}

}  // namespace

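// Example usage (an illustrative sketch; `gemm_instr`, the buffers and
// `stream` are placeholders supplied by the caller):
//
//   TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm_instr));
//   TF_RETURN_IF_ERROR(RunGemm(config, lhs_buffer, rhs_buffer, output_buffer,
//                              stream, /*algorithm=*/std::nullopt,
//                              /*profile_result=*/nullptr));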
Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer,
               se::DeviceMemoryBase rhs_buffer,
               se::DeviceMemoryBase output_buffer, se::Stream* stream,
               std::optional<se::blas::AlgorithmType> algorithm,
               se::blas::ProfileResult* profile_result) {
  VLOG(2) << "Executing a GemmThunk";

  MatrixLayout lhs_layout = config.lhs_layout;
  MatrixLayout rhs_layout = config.rhs_layout;
  MatrixLayout output_layout = config.output_layout;
  bool must_swap_operands =
      MakeOutputColumnMajor(lhs_layout, rhs_layout, output_layout);
  if (must_swap_operands) {
    std::swap(lhs_buffer, rhs_buffer);
  }

  int64_t m = output_layout.num_rows;
  int64_t n = output_layout.num_cols;
  int64_t k = lhs_layout.num_cols;
  se::blas::MatrixDescriptor lhs = GetMatrixDesc(lhs_layout, lhs_buffer);
  se::blas::MatrixDescriptor rhs = GetMatrixDesc(rhs_layout, rhs_buffer);
  se::blas::MatrixDescriptor output =
      GetMatrixDesc(output_layout, output_buffer);
  int64_t batch_size = output_layout.batch_size;

  if (!algorithm) algorithm = config.algorithm;

  switch (output_layout.dtype) {
    case S32:
      if (!algorithm) algorithm = se::blas::kDefaultGemmAlgo;
      return DoGemmWithAlgorithm<int8_t, int32_t>(
          batch_size, m, n, k, lhs, rhs, output,
          static_cast<int32_t>(config.alpha.real()),
          static_cast<int32_t>(config.beta), stream, *algorithm,
          profile_result);
    case F16:
      return DoGemm<Eigen::half>(batch_size, m, n, k, lhs, rhs, output,
                                 static_cast<Eigen::half>(config.alpha.real()),
                                 static_cast<Eigen::half>(config.beta), stream,
                                 algorithm, config.compute_precision,
                                 profile_result);
    case BF16:
      return DoGemm<Eigen::bfloat16>(
          batch_size, m, n, k, lhs, rhs, output,
          static_cast<Eigen::bfloat16>(config.alpha.real()),
          static_cast<Eigen::bfloat16>(config.beta), stream, algorithm,
          config.compute_precision, profile_result);
    case F32:
      return DoGemm<float>(batch_size, m, n, k, lhs, rhs, output,
                           config.alpha.real(), config.beta, stream, algorithm,
                           config.compute_precision, profile_result);
    case F64:
      return DoGemm<double>(batch_size, m, n, k, lhs, rhs, output,
                            config.alpha.real(), config.beta, stream, algorithm,
                            config.compute_precision, profile_result);
    case C64:
      return DoGemm<complex64>(batch_size, m, n, k, lhs, rhs, output,
                               static_cast<complex64>(config.alpha),
                               static_cast<complex64>(config.beta), stream,
                               algorithm, config.compute_precision,
                               profile_result);
    case C128:
      return DoGemm<complex128>(
          batch_size, m, n, k, lhs, rhs, output, config.alpha,
          static_cast<complex128>(config.beta), stream, algorithm,
          config.compute_precision, profile_result);
    default:
      return InternalError(
          "Unexpected GEMM dtype: %s",
          primitive_util::LowercasePrimitiveTypeName(output_layout.dtype));
  }
}

#if GOOGLE_CUDA

namespace {

StatusOr<se::blas::DataType> AsBlasDataType(PrimitiveType dtype) {
  switch (dtype) {
    case F16:
      return se::blas::DataType::kHalf;
    case BF16:
      return se::blas::DataType::kBF16;
    case F32:
      return se::blas::DataType::kFloat;
    case F64:
      return se::blas::DataType::kDouble;
    case C64:
      return se::blas::DataType::kComplexFloat;
    case C128:
      return se::blas::DataType::kComplexDouble;
    default:
      return InternalError("unsupported type");
  }
}

StatusOr<se::cuda::BlasLt::MatrixLayout> AsBlasLtMatrixLayout(
    const MatrixLayout& layout) {
  TF_ASSIGN_OR_RETURN(se::blas::DataType dtype, AsBlasDataType(layout.dtype));

  auto order = (layout.order == MatrixLayout::Order::kColumnMajor)
                   ? se::cuda::BlasLt::MatrixLayout::Order::kColumnMajor
                   : se::cuda::BlasLt::MatrixLayout::Order::kRowMajor;

  return se::cuda::BlasLt::MatrixLayout::Create(
      dtype, layout.num_rows, layout.num_cols, order, layout.batch_size,
      layout.leading_dim_stride, layout.batch_stride);
}

}  // namespace

namespace cublas_lt {

StatusOr<se::cuda::BlasLt::Epilogue> AsBlasLtEpilogue(
    mlir::lmhlo_gpu::CublasLtMatmulEpilogue epilogue) {
  switch (epilogue) {
    case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Default:
      return se::cuda::BlasLt::Epilogue::kDefault;
    case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Bias:
      return se::cuda::BlasLt::Epilogue::kBias;
    default:
      return InternalError("unknown epilogue");
  }
}

/*static*/ StatusOr<MatmulPlan> MatmulPlan::For(
    mlir::lmhlo_gpu::CublasLtMatmulOp op) {
  mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers();

  int64_t compute_precision = 0;  // Default
  if (op.getPrecisionConfig().hasValue()) {
    auto precision_config = op.getPrecisionConfig();
    for (auto attr : precision_config.getValue()) {
      int64_t value = static_cast<int64_t>(
          attr.template cast<mlir::mhlo::PrecisionAttr>().getValue());
      if (value > compute_precision) {
        compute_precision = value;
      }
    }
  }

  TF_ASSIGN_OR_RETURN(
      GemmConfig config,
      GemmConfig::For(GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(),
                      dot_dims.getLhsContractingDimensions(),
                      GetShape(op.getB()), dot_dims.getRhsBatchingDimensions(),
                      dot_dims.getRhsContractingDimensions(),
                      GetShape(op.getC()), op.getAlphaReal().convertToDouble(),
                      op.getAlphaImag().convertToDouble(),
                      op.getBeta().convertToDouble(), op.getAlgorithm(),
                      compute_precision));

  TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::Epilogue epilogue,
                      AsBlasLtEpilogue(op.getEpilogue()));
  return From(config, epilogue);
}

/*static*/ StatusOr<MatmulPlan> MatmulPlan::From(
    const GemmConfig& config, se::cuda::BlasLt::Epilogue epilogue) {
  MatrixLayout lhs_layout = config.lhs_layout;
  MatrixLayout rhs_layout = config.rhs_layout;
  MatrixLayout output_layout = config.output_layout;

  // cublasLt matmul requires batch sizes to be equal. If only one operand has a
  // batch, the other will be broadcast (as its batch_stride == 0).
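  // (MatrixLayout::For() sets batch_stride to 0 when batch_size == 1, so after
  // bumping batch_size below, cublasLt re-reads the same matrix for each batch
  // element.)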
  size_t batch_size = std::max(lhs_layout.batch_size, rhs_layout.batch_size);
  lhs_layout.batch_size = batch_size;
  rhs_layout.batch_size = batch_size;

  bool must_swap_operands =
      MakeOutputColumnMajor(lhs_layout, rhs_layout, output_layout);

  TF_ASSIGN_OR_RETURN(se::blas::DataType output_dtype,
                      AsBlasDataType(output_layout.dtype));
  TF_ASSIGN_OR_RETURN(se::blas::ComputationType computation_type,
                      GetBlasComputationType(output_layout.dtype));
  TF_ASSIGN_OR_RETURN(
      se::cuda::BlasLt::MatmulDesc op_desc,
      se::cuda::BlasLt::MatmulDesc::Create(
          computation_type,
          se::cuda::BlasLt::GetScaleType(output_dtype, computation_type),
          /*trans_a=*/se::blas::Transpose::kNoTranspose,
          /*trans_b=*/se::blas::Transpose::kNoTranspose, epilogue));

  TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::MatrixLayout a_desc,
                      AsBlasLtMatrixLayout(lhs_layout));
  TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::MatrixLayout b_desc,
                      AsBlasLtMatrixLayout(rhs_layout));
  TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::MatrixLayout c_desc,
                      AsBlasLtMatrixLayout(output_layout));
  TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::MatrixLayout d_desc,
                      AsBlasLtMatrixLayout(output_layout));

  return MatmulPlan{
      se::cuda::BlasLt::MatmulPlan{std::move(op_desc), std::move(a_desc),
                                   std::move(b_desc), std::move(c_desc),
                                   std::move(d_desc)},
      config.alpha, config.beta, must_swap_operands};
}

template <typename Input, typename Scale>
Status MatmulPlan::DoMatmul(se::Stream* stream, se::DeviceMemoryBase a_buffer,
                            se::DeviceMemoryBase b_buffer,
                            se::DeviceMemoryBase c_buffer,
                            se::DeviceMemoryBase d_buffer,
                            se::DeviceMemoryBase bias_buffer,
                            const se::cuda::BlasLt::MatmulAlgorithm& algorithm,
                            se::ScratchAllocator& scratch_allocator,
                            se::blas::ProfileResult* profile_result) {
  se::cuda::BlasLt* blas_lt = se::cuda::GetBlasLt(stream);
  TF_RET_CHECK(blas_lt != nullptr);

  Scale alpha;
  if constexpr (std::is_same_v<Scale, complex64> ||
                std::is_same_v<Scale, complex128>) {
    alpha = static_cast<Scale>(alpha_);
  } else {
    alpha = static_cast<Scale>(alpha_.real());
  }

  Scale beta = static_cast<Scale>(beta_);

  se::DeviceMemory<Input> output(d_buffer);
  return blas_lt->DoMatmul(
      stream, plan_, se::HostOrDeviceScalar<Scale>(alpha),
      se::DeviceMemory<Input>(a_buffer), se::DeviceMemory<Input>(b_buffer),
      se::HostOrDeviceScalar<Scale>(beta), se::DeviceMemory<Input>(c_buffer),
      output, algorithm, scratch_allocator,
      se::DeviceMemory<Input>(bias_buffer), profile_result);
}

Status MatmulPlan::ExecuteOnStream(
    se::Stream* stream, se::DeviceMemoryBase a_buffer,
    se::DeviceMemoryBase b_buffer, se::DeviceMemoryBase c_buffer,
    se::DeviceMemoryBase d_buffer, se::DeviceMemoryBase bias_buffer,
    const se::cuda::BlasLt::MatmulAlgorithm& algorithm,
    se::ScratchAllocator& scratch_allocator,
    se::blas::ProfileResult* profile_result) {
  if (must_swap_operands_) {
    std::swap(a_buffer, b_buffer);
  }

  switch (plan_.d_desc.type()) {
    case CUDA_R_16F:
      return DoMatmul<Eigen::half, float>(stream, a_buffer, b_buffer, c_buffer,
                                          d_buffer, bias_buffer, algorithm,
                                          scratch_allocator, profile_result);
    case CUDA_R_16BF:
      return DoMatmul<Eigen::bfloat16, float>(
          stream, a_buffer, b_buffer, c_buffer, d_buffer, bias_buffer,
          algorithm, scratch_allocator, profile_result);
    case CUDA_R_32F:
      return DoMatmul<float>(stream, a_buffer, b_buffer, c_buffer, d_buffer,
                             bias_buffer, algorithm, scratch_allocator,
                             profile_result);
    case CUDA_R_64F:
      return DoMatmul<double>(stream, a_buffer, b_buffer, c_buffer, d_buffer,
                              bias_buffer, algorithm, scratch_allocator,
                              profile_result);
    case CUDA_C_32F:
      return DoMatmul<complex64>(stream, a_buffer, b_buffer, c_buffer, d_buffer,
                                 bias_buffer, algorithm, scratch_allocator,
                                 profile_result);
    case CUDA_C_64F:
      return DoMatmul<complex128>(stream, a_buffer, b_buffer, c_buffer,
                                  d_buffer, bias_buffer, algorithm,
                                  scratch_allocator, profile_result);
    default:
      return InternalError("Unexpected dtype");
  }
}

StatusOr<std::vector<se::cuda::BlasLt::MatmulAlgorithm>>
MatmulPlan::GetAlgorithms(se::Stream* stream) const {
  se::cuda::BlasLt* blas_lt = se::cuda::GetBlasLt(stream);
  TF_RET_CHECK(blas_lt != nullptr);
  TF_ASSIGN_OR_RETURN(auto preference,
                      se::cuda::BlasLt::MatmulPreference::Create(
                          /*max_workspace_size=*/1ll << 32));  // 4GB
  return blas_lt->GetMatmulAlgorithms(plan_, preference);
}

}  // namespace cublas_lt

#endif  // GOOGLE_CUDA

}  // namespace gpu
}  // namespace xla