/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"

#include <algorithm>
#include <cstdint>
#include <optional>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/stream_executor/blas.h"

#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_blas_lt.h"
#include "tensorflow/stream_executor/host_or_device_scalar.h"
#endif  // GOOGLE_CUDA

namespace xla {
namespace gpu {

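// Returns the dimensions of `shape` that are neither batch nor contracting
// dimensions, in increasing order. For example, for a dot LHS of shape
// [B, M, K] with batch_dims = {0} and contracting_dims = {2}, this returns
// {1} (the "M" dimension).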
StatusOr<std::vector<int64_t>> GetNonContractingDims(
    const Shape& shape, absl::Span<const int64_t> batch_dims,
    absl::Span<const int64_t> contracting_dims) {
  std::vector<int64_t> non_contracting_dims;
  // This is O(rank**2), but we expect rank to be small.
  for (int64_t dim = 0; dim < shape.rank(); ++dim) {
    bool is_batch = absl::c_count(batch_dims, dim) != 0;
    bool is_contracting = absl::c_count(contracting_dims, dim) != 0;
    TF_RET_CHECK(!(is_batch && is_contracting));
    if (!(is_batch || is_contracting)) non_contracting_dims.push_back(dim);
  }

  TF_RET_CHECK(batch_dims.size() + contracting_dims.size() +
                   non_contracting_dims.size() ==
               shape.rank());
  return non_contracting_dims;
}

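// Collapses `shape` into a rank-3 (batch, rows, columns) shape, checking that
// each dimension group is laid out physically sequentially. An illustrative
// example: shape = f32[2,3,4,5]{3,2,1,0} with batch_dims = {0},
// row_dims = {1}, col_dims = {2,3} collapses to f32[2,3,20]{2,1,0}.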
StatusOr<Shape> GetBatchRowColumnShape(const Shape& shape,
                                       absl::Span<const int64_t> batch_dims,
                                       absl::Span<const int64_t> row_dims,
                                       absl::Span<const int64_t> col_dims) {
  TF_RET_CHECK(shape.has_layout());
  TF_RET_CHECK(!row_dims.empty());
  TF_RET_CHECK(!col_dims.empty());

  std::vector<int64_t> minor_to_major;
  for (size_t i = 0; i < shape.rank();) {
    // The GeMM output always has its layout set such that the batch, row, and
    // col dim groups are each laid out physically sequentially. GeMM operands
    // must, therefore, be laid out similarly.
    auto check_physically_sequential = [&](absl::Span<const int64_t> dims) {
      for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
        // NOTE: `i` is incremented as we check the dimensions.
        if (*it != shape.layout().minor_to_major()[i++])
          return InvalidArgument("dims not physically sequential");
      }
      return OkStatus();
    };

    int64_t dim = shape.layout().minor_to_major()[i];
    if (dim == row_dims.back()) {
      minor_to_major.push_back(1);
      TF_RETURN_IF_ERROR(check_physically_sequential(row_dims));
    } else if (dim == col_dims.back()) {
      minor_to_major.push_back(2);
      TF_RETURN_IF_ERROR(check_physically_sequential(col_dims));
    } else if (!batch_dims.empty() && (dim == batch_dims.back())) {
      minor_to_major.push_back(0);
      TF_RETURN_IF_ERROR(check_physically_sequential(batch_dims));
    } else {
      return InvalidArgument("dims not physically sequential");
    }
  }

  if (batch_dims.empty()) minor_to_major.push_back(0);

  auto dim_size = [&](absl::Span<const int64_t> dims) {
    return absl::c_accumulate(dims, 1, [&](int64_t size, int64_t dim) {
      return size * shape.dimensions(dim);
    });
  };

  return ShapeUtil::MakeShapeWithLayout(
      shape.element_type(),
      {dim_size(batch_dims), dim_size(row_dims), dim_size(col_dims)},
      minor_to_major);
}

// Returns the matrix layout for a logical shape (batch, rows, columns).
/*static*/ StatusOr<MatrixLayout> MatrixLayout::For(const Shape& shape) {
  TF_RET_CHECK(shape.rank() == 3);
  TF_RET_CHECK(shape.has_layout());

  int64_t batch_size = shape.dimensions(0);
  int64_t num_rows = shape.dimensions(1);
  int64_t num_cols = shape.dimensions(2);

  MatrixLayout::Order order = MatrixLayout::Order::kRowMajor;
  int64_t leading_dim_stride = num_cols;
  int64_t batch_stride = num_rows * num_cols;

  // `MatrixLayout`, like BLAS, uses only two strides, so either the row or
  // column must be contiguous in memory (i.e. most minor physical dimension).
  absl::Span<const int64_t> minor_to_major = shape.layout().minor_to_major();
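  // The switch key encodes `minor_to_major` as an octal literal whose digits,
  // read left to right, give the logical dimension at each physical position
  // from most major to most minor; e.g. 012 means (batch, rows, cols)
  // major-to-minor, i.e. minor_to_major = {2, 1, 0}.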
  switch (64 * minor_to_major[2] + 8 * minor_to_major[1] + minor_to_major[0]) {
    case 012:  // (B,R,C) (major-to-minor)
      break;
    case 021:  // (B,C,R)
      order = MatrixLayout::Order::kColumnMajor;
      leading_dim_stride = num_rows;
      break;
    case 0102:  // (R,B,C)
      leading_dim_stride = batch_size * num_cols;
      batch_stride = num_cols;
      break;
    case 0201:  // (C,B,R)
      order = MatrixLayout::Order::kColumnMajor;
      leading_dim_stride = batch_size * num_rows;
      batch_stride = num_rows;
      break;
    default:
      return Unimplemented("batch in most minor dimension");
  }

  if (batch_size == 1) batch_stride = 0;
  return MatrixLayout{
      shape.element_type(), num_rows, num_cols, order,
      leading_dim_stride, batch_size, batch_stride,
  };
}

/*static*/ StatusOr<MatrixLayout> MatrixLayout::For(
    const Shape& shape, absl::Span<const int64_t> batch_dims,
    absl::Span<const int64_t> row_dims, absl::Span<const int64_t> col_dims) {
  TF_ASSIGN_OR_RETURN(
      Shape batch_row_col_shape,
      GetBatchRowColumnShape(shape, batch_dims, row_dims, col_dims));
  return MatrixLayout::For(batch_row_col_shape);
}

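// Builds the layout of a GeMM output shape, assuming its logical dimensions
// are ordered as (batch dims, LHS row dims, RHS col dims).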
/*static*/ StatusOr<MatrixLayout> MatrixLayout::For(const Shape& shape,
                                                    size_t lhs_num_batch_dims,
                                                    size_t lhs_num_row_dims,
                                                    size_t rhs_num_batch_dims,
                                                    size_t rhs_num_col_dims) {
  size_t num_batch_dims = std::max(lhs_num_batch_dims, rhs_num_batch_dims);

  TF_RET_CHECK(shape.rank() ==
               num_batch_dims + lhs_num_row_dims + rhs_num_col_dims);

  std::vector<int64_t> dims(shape.rank());
  absl::c_iota(dims, 0);

  auto batch_dims = absl::Span<const int64_t>(dims).first(num_batch_dims);
  auto row_dims =
      absl::Span<const int64_t>(dims).subspan(num_batch_dims, lhs_num_row_dims);
  auto col_dims = absl::Span<const int64_t>(dims).last(rhs_num_col_dims);

  return MatrixLayout::For(shape, batch_dims, row_dims, col_dims);
}

void MatrixLayout::Transpose() {
  std::swap(num_rows, num_cols);
  order = (order == Order::kRowMajor) ? Order::kColumnMajor : Order::kRowMajor;
}

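// Returns true if the transpose feeding operand `operand_idx` of `dot` can be
// folded into the GeMM, i.e. if the pre-transpose operand still maps onto a
// BLAS-compatible (batch, rows, cols) layout once the dot dimension numbers
// are remapped through the transpose.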
StatusOr<bool> CanFoldTransposeOperandIntoDot(const HloInstruction& dot,
                                              int64_t operand_idx) {
  TF_RET_CHECK(dot.opcode() == HloOpcode::kDot);
  TF_RET_CHECK(dot.operand_count() > operand_idx);

  const HloInstruction& transpose = *dot.operand(operand_idx);
  TF_RET_CHECK(transpose.opcode() == HloOpcode::kTranspose);

  const DotDimensionNumbers& dot_dims = dot.dot_dimension_numbers();

  auto transposed = [&](const auto& dims) {
    std::vector<int64_t> transposed_dims;
    transposed_dims.reserve(dims.size());
    for (int64_t dim : dims) {
      transposed_dims.push_back(transpose.dimensions(dim));
    }
    return transposed_dims;
  };

  auto batch_dims = (operand_idx == 0) ? dot_dims.lhs_batch_dimensions()
                                       : dot_dims.rhs_batch_dimensions();
  auto contracting_dims = (operand_idx == 0)
                              ? dot_dims.lhs_contracting_dimensions()
                              : dot_dims.rhs_contracting_dimensions();
  TF_ASSIGN_OR_RETURN(
      std::vector<int64_t> non_contracting_dims,
      GetNonContractingDims(transpose.shape(), batch_dims, contracting_dims));

  // If we're able to construct a valid `MatrixLayout` for the transposed
  // dimensions, then GeMM can support folding the transpose.
  return MatrixLayout::For(transpose.operand(0)->shape(),
                           transposed(batch_dims), transposed(contracting_dims),
                           transposed(non_contracting_dims))
      .ok();
}

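// Builds a GemmConfig from raw shapes and dot dimension numbers. The LHS is
// treated as a (batch, non-contracting, contracting) matrix and the RHS as a
// (batch, contracting, non-contracting) matrix, so the contracted dimensions
// line up as the inner "k" dimension of the GeMM.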
/*static*/ StatusOr<GemmConfig> GemmConfig::For(
    const Shape& lhs_shape, absl::Span<const int64_t> lhs_batch_dims,
    absl::Span<const int64_t> lhs_contracting_dims, const Shape& rhs_shape,
    absl::Span<const int64_t> rhs_batch_dims,
    absl::Span<const int64_t> rhs_contracting_dims, const Shape& output_shape,
    double alpha_real, double alpha_imag, double beta,
    std::optional<int64_t> algorithm, int64_t compute_precision) {
  absl::Span<const int64_t> lhs_col_dims = lhs_contracting_dims;
  TF_ASSIGN_OR_RETURN(
      std::vector<int64_t> lhs_row_dims,
      GetNonContractingDims(lhs_shape, lhs_batch_dims, lhs_col_dims));

  TF_ASSIGN_OR_RETURN(
      MatrixLayout lhs_layout,
      MatrixLayout::For(lhs_shape, lhs_batch_dims, lhs_row_dims, lhs_col_dims));

  absl::Span<const int64_t> rhs_row_dims = rhs_contracting_dims;
  TF_ASSIGN_OR_RETURN(
      std::vector<int64_t> rhs_col_dims,
      GetNonContractingDims(rhs_shape, rhs_batch_dims, rhs_row_dims));

  TF_ASSIGN_OR_RETURN(
      MatrixLayout rhs_layout,
      MatrixLayout::For(rhs_shape, rhs_batch_dims, rhs_row_dims, rhs_col_dims));

  int64_t num_batch_dims =
      std::max(lhs_batch_dims.size(), rhs_batch_dims.size());

  TF_RET_CHECK(output_shape.rank() ==
               num_batch_dims + lhs_row_dims.size() + rhs_col_dims.size());

  std::vector<int64_t> output_dims(output_shape.rank());
  absl::c_iota(output_dims, 0);

  auto output_batch_dims =
      absl::Span<const int64_t>(output_dims).first(num_batch_dims);
  auto output_row_dims = absl::Span<const int64_t>(output_dims)
                             .subspan(num_batch_dims, lhs_row_dims.size());
  auto output_col_dims =
      absl::Span<const int64_t>(output_dims).last(rhs_col_dims.size());

  TF_ASSIGN_OR_RETURN(MatrixLayout output_layout,
                      MatrixLayout::For(output_shape, output_batch_dims,
                                        output_row_dims, output_col_dims));

  // TODO(cjfj): We should also check that the batch, contracting and
  // non-contracting dimensions match in size and relative physical location.
  TF_RET_CHECK(lhs_layout.num_cols == rhs_layout.num_rows);
  TF_RET_CHECK(output_layout.num_rows == lhs_layout.num_rows);
  TF_RET_CHECK(output_layout.num_cols == rhs_layout.num_cols);
  TF_RET_CHECK((lhs_layout.batch_size == output_layout.batch_size) ||
               (lhs_layout.batch_size == 1));
  TF_RET_CHECK((rhs_layout.batch_size == output_layout.batch_size) ||
               (rhs_layout.batch_size == 1));

  switch (output_shape.element_type()) {
    case F16:
    case BF16:
    case F32:
    case F64:
      TF_RET_CHECK(alpha_imag == 0);
      break;
    case C64:
    case C128:
      break;
    case S32:
      TF_RET_CHECK(alpha_imag == 0);
      if (lhs_layout.dtype != PrimitiveType::S8 ||
          rhs_layout.dtype != PrimitiveType::S8) {
        return InternalError(
            "For int32 gemm output only int8 input is supported, got input: "
            "%s, %s",
            primitive_util::LowercasePrimitiveTypeName(lhs_layout.dtype),
            primitive_util::LowercasePrimitiveTypeName(rhs_layout.dtype));
      }
      break;
    default:
      return InternalError("Unexpected GEMM datatype: %s",
                           primitive_util::LowercasePrimitiveTypeName(
                               output_shape.element_type()));
  }

  return GemmConfig{
      lhs_layout, rhs_layout, output_layout, {alpha_real, alpha_imag},
      beta, algorithm, compute_precision,
  };
}

/*static*/ StatusOr<GemmConfig> GemmConfig::For(const HloInstruction* gemm) {
  TF_ASSIGN_OR_RETURN(GemmBackendConfig config,
                      gemm->backend_config<GemmBackendConfig>());

  std::optional<int64_t> algorithm;
  if (config.algorithm_case() != GemmBackendConfig::ALGORITHM_NOT_SET) {
    algorithm = config.selected_algorithm();
  }

  const Shape& lhs_shape = gemm->operand(0)->shape();
  const Shape& rhs_shape = gemm->operand(1)->shape();
  const DotDimensionNumbers& dot_dims = config.dot_dimension_numbers();

  return GemmConfig::For(
      lhs_shape, dot_dims.lhs_batch_dimensions(),
      dot_dims.lhs_contracting_dimensions(), rhs_shape,
      dot_dims.rhs_batch_dimensions(), dot_dims.rhs_contracting_dimensions(),
      /*output_shape=*/gemm->shape(), config.alpha_real(), config.alpha_imag(),
      config.beta(), algorithm, se::blas::kDefaultComputePrecision);
}

/*static*/ StatusOr<GemmConfig> GemmConfig::For(mlir::lmhlo_gpu::GEMMOp op) {
  mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers();

  std::optional<int64_t> algorithm;
  if (op.getAlgorithm()) algorithm = *op.getAlgorithm();

  int64_t compute_precision = 0;  // Default
  if (op.getPrecisionConfig().has_value()) {
    auto precision_config = op.getPrecisionConfig();
    for (auto attr : precision_config.getValue()) {
      int64_t value = static_cast<int64_t>(
          attr.template cast<mlir::mhlo::PrecisionAttr>().getValue());
      if (value > compute_precision) {
        compute_precision = value;
      }
    }
  }

  return GemmConfig::For(
      GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(),
      dot_dims.getLhsContractingDimensions(), GetShape(op.getB()),
      dot_dims.getRhsBatchingDimensions(),
      dot_dims.getRhsContractingDimensions(), GetShape(op.getC()),
      op.getAlphaReal().convertToDouble(), op.getAlphaImag().convertToDouble(),
      op.getBeta().convertToDouble(), algorithm, compute_precision);
}

namespace {

// BLAS GeMM's output is column-major. If we require row-major, use identity:
// C^T = (A @ B)^T = B^T @ A^T.
bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs,
                           MatrixLayout& output) {
  bool swap_operands = output.order != MatrixLayout::Order::kColumnMajor;
  if (swap_operands) {
    std::swap(lhs, rhs);
    lhs.Transpose();
    rhs.Transpose();
    output.Transpose();
  }
  return swap_operands;
}

StatusOr<se::blas::ComputationType> GetBlasComputationType(
    PrimitiveType dtype) {
  switch (dtype) {
    case F16:  // fall-through
    case BF16:
      // Accumulate in f32 precision.
      return se::blas::ComputationType::kF32;
    case F32:  // fall-through
    case C64:
      return se::blas::ComputationType::kTF32AsF32;
    case F64:  // fall-through
    case C128:
      return se::blas::ComputationType::kF64;
    case S32:
      return se::blas::ComputationType::kI32;
    default:
      return InternalError("unsupported type");
  }
}

se::blas::Transpose AsBlasTranspose(MatrixLayout::Order order) {
  // BLAS is column-major by default.
  return (order == MatrixLayout::Order::kColumnMajor)
             ? se::blas::Transpose::kNoTranspose
             : se::blas::Transpose::kTranspose;
}

se::blas::MatrixDescriptor GetMatrixDesc(const MatrixLayout& layout,
                                         se::DeviceMemoryBase data) {
  return {
      data,
      layout.leading_dim_stride,
      layout.batch_stride,
      AsBlasTranspose(layout.order),
  };
}

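// Runs a GeMM with an explicitly chosen algorithm. `Input` is the element
// type of the operands and `Output` that of the result (they differ only in
// the int8 -> int32 case); the BLAS computation type is derived from the
// output type.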
template <typename Input, typename Output>
Status DoGemmWithAlgorithm(int64_t batch_size, int64_t m, int64_t n, int64_t k,
                           const se::blas::MatrixDescriptor& lhs,
                           const se::blas::MatrixDescriptor& rhs,
                           const se::blas::MatrixDescriptor& output,
                           Output alpha, Output beta, se::Stream* stream,
                           se::blas::AlgorithmType algorithm,
                           se::blas::ProfileResult* profile_result) {
  CHECK(output.transpose == se::blas::Transpose::kNoTranspose);
  PrimitiveType output_type = primitive_util::NativeToPrimitiveType<Output>();
  TF_ASSIGN_OR_RETURN(se::blas::ComputationType computation_type,
                      GetBlasComputationType(output_type));
  se::DeviceMemory<Output> output_data(output.data);

  if (batch_size != 1) {
    return stream->ThenBlasGemmStridedBatchedWithAlgorithm(
        lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
        lhs.leading_dim_stride, lhs.batch_stride, rhs.cast<Input>(),
        rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data,
        output.leading_dim_stride, output.batch_stride, batch_size,
        computation_type, algorithm, profile_result);
  } else {
    return stream->ThenBlasGemmWithAlgorithm(
        lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
        lhs.leading_dim_stride, rhs.cast<Input>(), rhs.leading_dim_stride, beta,
        &output_data, output.leading_dim_stride, computation_type, algorithm,
        profile_result);
  }
}

template <typename Input>
Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k,
              const se::blas::MatrixDescriptor& lhs,
              const se::blas::MatrixDescriptor& rhs,
              const se::blas::MatrixDescriptor& output, Input alpha, Input beta,
              se::Stream* stream,
              std::optional<se::blas::AlgorithmType> algorithm,
              se::blas::ComputePrecision compute_precision,
              se::blas::ProfileResult* profile_result) {
  CHECK(output.transpose == se::blas::Transpose::kNoTranspose);
  se::DeviceMemory<Input> output_data(output.data);

  if (algorithm) {
    return DoGemmWithAlgorithm<Input, Input>(batch_size, m, n, k, lhs, rhs,
                                             output, alpha, beta, stream,
                                             *algorithm, profile_result);
  }

  if (batch_size != 1) {
    return stream->ThenBlasGemmStridedBatched(
        lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
        lhs.leading_dim_stride, lhs.batch_stride, rhs.cast<Input>(),
        rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data,
        output.leading_dim_stride, output.batch_stride, batch_size);
  }

  return stream->ThenBlasGemm(
      lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
      lhs.leading_dim_stride, rhs.cast<Input>(), rhs.leading_dim_stride, beta,
      &output_data, output.leading_dim_stride, compute_precision);
}

}  // namespace

Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer,
               se::DeviceMemoryBase rhs_buffer,
               se::DeviceMemoryBase output_buffer, se::Stream* stream,
               std::optional<se::blas::AlgorithmType> algorithm,
               se::blas::ProfileResult* profile_result) {
  VLOG(2) << "Executing a GemmThunk";

  MatrixLayout lhs_layout = config.lhs_layout;
  MatrixLayout rhs_layout = config.rhs_layout;
  MatrixLayout output_layout = config.output_layout;
  bool must_swap_operands =
      MakeOutputColumnMajor(lhs_layout, rhs_layout, output_layout);
  if (must_swap_operands) {
    std::swap(lhs_buffer, rhs_buffer);
  }

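  // With the output forced to be column-major, the problem size follows the
  // usual BLAS convention: m and n are the output's rows and columns, and k
  // is the contracted dimension (the LHS column count).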
  int64_t m = output_layout.num_rows;
  int64_t n = output_layout.num_cols;
  int64_t k = lhs_layout.num_cols;
  se::blas::MatrixDescriptor lhs = GetMatrixDesc(lhs_layout, lhs_buffer);
  se::blas::MatrixDescriptor rhs = GetMatrixDesc(rhs_layout, rhs_buffer);
  se::blas::MatrixDescriptor output =
      GetMatrixDesc(output_layout, output_buffer);
  int64_t batch_size = output_layout.batch_size;

  if (!algorithm) algorithm = config.algorithm;

  switch (output_layout.dtype) {
    case S32:
      if (!algorithm) algorithm = se::blas::kDefaultGemmAlgo;
      return DoGemmWithAlgorithm<int8_t, int32_t>(
          batch_size, m, n, k, lhs, rhs, output,
          static_cast<int32_t>(config.alpha.real()),
          static_cast<int32_t>(config.beta), stream, *algorithm,
          profile_result);
    case F16:
      return DoGemm<Eigen::half>(batch_size, m, n, k, lhs, rhs, output,
                                 static_cast<Eigen::half>(config.alpha.real()),
                                 static_cast<Eigen::half>(config.beta), stream,
                                 algorithm, config.compute_precision,
                                 profile_result);
    case BF16:
      return DoGemm<Eigen::bfloat16>(
          batch_size, m, n, k, lhs, rhs, output,
          static_cast<Eigen::bfloat16>(config.alpha.real()),
          static_cast<Eigen::bfloat16>(config.beta), stream, algorithm,
          config.compute_precision, profile_result);
    case F32:
      return DoGemm<float>(batch_size, m, n, k, lhs, rhs, output,
                           config.alpha.real(), config.beta, stream, algorithm,
                           config.compute_precision, profile_result);
    case F64:
      return DoGemm<double>(batch_size, m, n, k, lhs, rhs, output,
                            config.alpha.real(), config.beta, stream, algorithm,
                            config.compute_precision, profile_result);
    case C64:
      return DoGemm<complex64>(batch_size, m, n, k, lhs, rhs, output,
                               static_cast<complex64>(config.alpha),
                               static_cast<complex64>(config.beta), stream,
                               algorithm, config.compute_precision,
                               profile_result);
    case C128:
      return DoGemm<complex128>(
          batch_size, m, n, k, lhs, rhs, output, config.alpha,
          static_cast<complex128>(config.beta), stream, algorithm,
          config.compute_precision, profile_result);
    default:
      return InternalError(
          "Unexpected GEMM dtype: %s",
          primitive_util::LowercasePrimitiveTypeName(output_layout.dtype));
  }
}

#if GOOGLE_CUDA

namespace {

StatusOr<se::blas::DataType> AsBlasDataType(PrimitiveType dtype) {
  switch (dtype) {
    case F16:
      return se::blas::DataType::kHalf;
    case BF16:
      return se::blas::DataType::kBF16;
    case F32:
      return se::blas::DataType::kFloat;
    case F64:
      return se::blas::DataType::kDouble;
    case C64:
      return se::blas::DataType::kComplexFloat;
    case C128:
      return se::blas::DataType::kComplexDouble;
    default:
      return InternalError("unsupported type");
  }
}

StatusOr<se::cuda::BlasLt::MatrixLayout> AsBlasLtMatrixLayout(
    const MatrixLayout& layout) {
  TF_ASSIGN_OR_RETURN(se::blas::DataType dtype, AsBlasDataType(layout.dtype));

  auto order = (layout.order == MatrixLayout::Order::kColumnMajor)
                   ? se::cuda::BlasLt::MatrixLayout::Order::kColumnMajor
                   : se::cuda::BlasLt::MatrixLayout::Order::kRowMajor;

  return se::cuda::BlasLt::MatrixLayout::Create(
      dtype, layout.num_rows, layout.num_cols, order, layout.batch_size,
      layout.leading_dim_stride, layout.batch_stride);
}

}  // namespace

namespace cublas_lt {

StatusOr<se::cuda::BlasLt::Epilogue> AsBlasLtEpilogue(
    mlir::lmhlo_gpu::CublasLtMatmulEpilogue epilogue) {
  switch (epilogue) {
    case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Default:
      return se::cuda::BlasLt::Epilogue::kDefault;
    case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Bias:
      return se::cuda::BlasLt::Epilogue::kBias;
    default:
      return InternalError("unknown epilogue");
  }
}

/*static*/ StatusOr<MatmulPlan> MatmulPlan::For(
    mlir::lmhlo_gpu::CublasLtMatmulOp op) {
  mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers();

  int64_t compute_precision = 0;  // Default
  if (op.getPrecisionConfig().has_value()) {
    auto precision_config = op.getPrecisionConfig();
    for (auto attr : precision_config.getValue()) {
      int64_t value = static_cast<int64_t>(
          attr.template cast<mlir::mhlo::PrecisionAttr>().getValue());
      if (value > compute_precision) {
        compute_precision = value;
      }
    }
  }

  TF_ASSIGN_OR_RETURN(
      GemmConfig config,
      GemmConfig::For(GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(),
                      dot_dims.getLhsContractingDimensions(),
                      GetShape(op.getB()), dot_dims.getRhsBatchingDimensions(),
                      dot_dims.getRhsContractingDimensions(),
                      GetShape(op.getC()), op.getAlphaReal().convertToDouble(),
                      op.getAlphaImag().convertToDouble(),
                      op.getBeta().convertToDouble(), op.getAlgorithm(),
                      compute_precision));

  TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::Epilogue epilogue,
                      AsBlasLtEpilogue(op.getEpilogue()));
  return From(config, epilogue);
}

/*static*/ StatusOr<MatmulPlan> MatmulPlan::From(
    const GemmConfig& config, se::cuda::BlasLt::Epilogue epilogue) {
  MatrixLayout lhs_layout = config.lhs_layout;
  MatrixLayout rhs_layout = config.rhs_layout;
  MatrixLayout output_layout = config.output_layout;

  // cublasLt matmul requires batch sizes to be equal. If only one operand has
  // a batch, the other will be broadcast (as its batch_stride == 0).
  size_t batch_size = std::max(lhs_layout.batch_size, rhs_layout.batch_size);
  lhs_layout.batch_size = batch_size;
  rhs_layout.batch_size = batch_size;

  bool must_swap_operands =
      MakeOutputColumnMajor(lhs_layout, rhs_layout, output_layout);

  TF_ASSIGN_OR_RETURN(se::blas::DataType output_dtype,
                      AsBlasDataType(output_layout.dtype));
  TF_ASSIGN_OR_RETURN(se::blas::ComputationType computation_type,
                      GetBlasComputationType(output_layout.dtype));
  TF_ASSIGN_OR_RETURN(
      se::cuda::BlasLt::MatmulDesc op_desc,
      se::cuda::BlasLt::MatmulDesc::Create(
          computation_type,
          se::cuda::BlasLt::GetScaleType(output_dtype, computation_type),
          /*trans_a=*/se::blas::Transpose::kNoTranspose,
          /*trans_b=*/se::blas::Transpose::kNoTranspose, epilogue));

  TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::MatrixLayout a_desc,
                      AsBlasLtMatrixLayout(lhs_layout));
  TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::MatrixLayout b_desc,
                      AsBlasLtMatrixLayout(rhs_layout));
  TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::MatrixLayout c_desc,
                      AsBlasLtMatrixLayout(output_layout));
  TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::MatrixLayout d_desc,
                      AsBlasLtMatrixLayout(output_layout));

  return MatmulPlan{
      se::cuda::BlasLt::MatmulPlan{std::move(op_desc), std::move(a_desc),
                                   std::move(b_desc), std::move(c_desc),
                                   std::move(d_desc)},
      config.alpha, config.beta, must_swap_operands};
}

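// Runs the typed cublasLt matmul. `Input` is the element type of the matrices
// and `Scale` the type of alpha/beta; for f16 and bf16 inputs the scale type
// is f32, matching the dispatch in ExecuteOnStream below.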
template <typename Input, typename Scale>
Status MatmulPlan::DoMatmul(se::Stream* stream, se::DeviceMemoryBase a_buffer,
                            se::DeviceMemoryBase b_buffer,
                            se::DeviceMemoryBase c_buffer,
                            se::DeviceMemoryBase d_buffer,
                            se::DeviceMemoryBase bias_buffer,
                            const se::cuda::BlasLt::MatmulAlgorithm& algorithm,
                            se::ScratchAllocator& scratch_allocator,
                            se::blas::ProfileResult* profile_result) {
  se::cuda::BlasLt* blas_lt = se::cuda::GetBlasLt(stream);
  TF_RET_CHECK(blas_lt != nullptr);

  Scale alpha;
  if constexpr (std::is_same_v<Scale, complex64> ||
                std::is_same_v<Scale, complex128>) {
    alpha = static_cast<Scale>(alpha_);
  } else {
    alpha = static_cast<Scale>(alpha_.real());
  }

  Scale beta = static_cast<Scale>(beta_);

  se::DeviceMemory<Input> output(d_buffer);
  return blas_lt->DoMatmul(
      stream, plan_, se::HostOrDeviceScalar<Scale>(alpha),
      se::DeviceMemory<Input>(a_buffer), se::DeviceMemory<Input>(b_buffer),
      se::HostOrDeviceScalar<Scale>(beta), se::DeviceMemory<Input>(c_buffer),
      output, algorithm, scratch_allocator,
      se::DeviceMemory<Input>(bias_buffer), profile_result);
}

Status MatmulPlan::ExecuteOnStream(
    se::Stream* stream, se::DeviceMemoryBase a_buffer,
    se::DeviceMemoryBase b_buffer, se::DeviceMemoryBase c_buffer,
    se::DeviceMemoryBase d_buffer, se::DeviceMemoryBase bias_buffer,
    const se::cuda::BlasLt::MatmulAlgorithm& algorithm,
    se::ScratchAllocator& scratch_allocator,
    se::blas::ProfileResult* profile_result) {
  if (must_swap_operands_) {
    std::swap(a_buffer, b_buffer);
  }

  switch (plan_.d_desc.type()) {
    case CUDA_R_16F:
      return DoMatmul<Eigen::half, float>(stream, a_buffer, b_buffer, c_buffer,
                                          d_buffer, bias_buffer, algorithm,
                                          scratch_allocator, profile_result);
    case CUDA_R_16BF:
      return DoMatmul<Eigen::bfloat16, float>(
          stream, a_buffer, b_buffer, c_buffer, d_buffer, bias_buffer,
          algorithm, scratch_allocator, profile_result);
    case CUDA_R_32F:
      return DoMatmul<float>(stream, a_buffer, b_buffer, c_buffer, d_buffer,
                             bias_buffer, algorithm, scratch_allocator,
                             profile_result);
    case CUDA_R_64F:
      return DoMatmul<double>(stream, a_buffer, b_buffer, c_buffer, d_buffer,
                              bias_buffer, algorithm, scratch_allocator,
                              profile_result);
    case CUDA_C_32F:
      return DoMatmul<complex64>(stream, a_buffer, b_buffer, c_buffer, d_buffer,
                                 bias_buffer, algorithm, scratch_allocator,
                                 profile_result);
    case CUDA_C_64F:
      return DoMatmul<complex128>(stream, a_buffer, b_buffer, c_buffer,
                                  d_buffer, bias_buffer, algorithm,
                                  scratch_allocator, profile_result);
    default:
      return InternalError("Unexpected dtype");
  }
}

StatusOr<std::vector<se::cuda::BlasLt::MatmulAlgorithm>>
MatmulPlan::GetAlgorithms(se::Stream* stream) const {
  se::cuda::BlasLt* blas_lt = se::cuda::GetBlasLt(stream);
  TF_RET_CHECK(blas_lt != nullptr);
  TF_ASSIGN_OR_RETURN(auto preference,
                      se::cuda::BlasLt::MatmulPreference::Create(
                          /*max_workspace_size=*/1ll << 32));  // 4GB
  return blas_lt->GetMatmulAlgorithms(plan_, preference);
}

}  // namespace cublas_lt

#endif  // GOOGLE_CUDA

}  // namespace gpu
}  // namespace xla