// Copyright 2020 The TensorFlow Runtime Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//===- gemm_pattern.cc ----------------------------------------------------===//
//
// Pattern to lower lhlo_gpu GEMM ops to the tfrt cuda dialect.
//
//===----------------------------------------------------------------------===//
#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.h"

#include <assert.h>
#include <stdint.h>

#include <type_traits>
#include <utility>

#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tfrt/gpu/kernels/gpu_ops.h"  // from @tf_runtime
#include "tfrt/gpu/pass/pass.h"  // from @tf_runtime
#include "tfrt/gpu/wrapper/cublas_wrapper.h"  // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/basic_kernels.h"  // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime

namespace tensorflow {
namespace {

using llvm::ArrayRef;

// This struct contains the metadata of a matrix, e.g., its base address and
// dimensions.
struct MatrixDescriptor {
  mlir::Value converted_value;  // Already-converted buffer holding the data.
  bool transpose;               // Whether this matrix needs to be transposed.
  int64_t num_rows;
  int64_t num_cols;
};

cudaDataType_t MlirTypeToCudaDataType(mlir::Type type) {
  mlir::Builder builder(type.getContext());
  if (type.isF16())
    return CUDA_R_16F;
  else if (type.isF32())
    return CUDA_R_32F;
  else if (type.isF64())
    return CUDA_R_64F;
  else if (type == mlir::ComplexType::get(builder.getF32Type()))
    return CUDA_C_32F;
  else if (type == mlir::ComplexType::get(builder.getF64Type()))
    return CUDA_C_64F;

  llvm_unreachable("unsupported type");
}

// TODO(b/176561997): remove this once lhlo_gpu ops have properly typed alpha
// and beta attributes. We can't use std::complex here because the effect of
// instantiating it for anything other than float, double, or long double is
// unspecified. We need it for APFloat.
template <class T>
struct Complex {
  T real;
  T imag;
};
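// For real element types only the `real` member is used (see
// MakeScalingFactorConstant below).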

// TODO(b/176561997): remove this once lhlo_gpu ops have properly typed alpha
// and beta attributes.
mlir::Value MakeScalingFactorConstant(mlir::OpBuilder& builder,
                                      mlir::Location loc, mlir::Type type,
                                      Complex<llvm::APFloat> value) {
  // Dummy boolean we need to pass to the convert functions. Since this whole
  // function will go away once the scaling factors are properly typed
  // (b/176561997), we won't worry about possible losses during conversion for
  // now.
  bool losesInfo = false;
  // TODO(b/176913138): remove second argument to `builder.create` calls
  // TODO(b/176562488): handle {,B}F16
  if (type.isF32()) {
    value.real.convert(llvm::APFloat::IEEEsingle(),
                       llvm::RoundingMode::NearestTiesToEven, &losesInfo);
    return builder.create<tfrt::compiler::ConstantF32Op>(loc, type, value.real);
  } else if (type.isF64()) {
    value.real.convert(llvm::APFloat::IEEEdouble(),
                       llvm::RoundingMode::NearestTiesToEven, &losesInfo);
    return builder.create<tfrt::compiler::ConstantF64Op>(loc, type, value.real);
  } else if (type == mlir::ComplexType::get(builder.getF32Type())) {
    value.real.convert(llvm::APFloat::IEEEsingle(),
                       llvm::RoundingMode::NearestTiesToEven, &losesInfo);
    value.imag.convert(llvm::APFloat::IEEEsingle(),
                       llvm::RoundingMode::NearestTiesToEven, &losesInfo);
    return builder.create<tfrt::compiler::ConstantComplexF32Op>(
        loc, type, value.real, value.imag);
  } else if (type == mlir::ComplexType::get(builder.getF64Type())) {
    value.real.convert(llvm::APFloat::IEEEdouble(),
                       llvm::RoundingMode::NearestTiesToEven, &losesInfo);
    value.imag.convert(llvm::APFloat::IEEEdouble(),
                       llvm::RoundingMode::NearestTiesToEven, &losesInfo);
    return builder.create<tfrt::compiler::ConstantComplexF64Op>(
        loc, type, value.real, value.imag);
  }

  llvm_unreachable("unsupported type");
}

// Create all the Ops necessary for the GEMM operation, including the GEMM
// operation itself.
// TODO(b/175130778): revisit the element_type parameter when we move from
// GpuBuffers to MemRefs.
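// The emitted ops roughly correspond to a cublasGemmEx call (or, when
// batch_size > 1, cublasGemmStridedBatchedEx) computing
//   C = alpha * op(A) * op(B) + beta * C.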
FailureOr<Value> CreateTfrtOps(
    mlir::Location loc, mlir::Value chain, mlir::Value stream,
    int64_t batch_size, mlir::Type element_type, MatrixDescriptor lhs_matrix,
    MatrixDescriptor rhs_matrix, MatrixDescriptor output_matrix,
    Complex<llvm::APFloat> alpha, Complex<llvm::APFloat> beta,
    cublasGemmAlgo_t algorithm, mlir::OpBuilder& builder) {
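  // The contraction dimension k: the number of columns of op(A), which equals
  // the number of rows of op(B).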
  auto k_val = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;

  // TODO(b/176913138): remove second argument to `rewriter.create` calls
  auto m = builder.create<tfrt::compiler::ConstantI32Op>(
      loc, builder.getI32Type(), output_matrix.num_rows);
  auto n = builder.create<tfrt::compiler::ConstantI32Op>(
      loc, builder.getI32Type(), output_matrix.num_cols);
  auto k = builder.create<tfrt::compiler::ConstantI32Op>(
      loc, builder.getI32Type(), k_val);

  auto const_alpha =
      MakeScalingFactorConstant(builder, loc, element_type, alpha);

  auto lda = builder.create<tfrt::compiler::ConstantI32Op>(
      loc, builder.getI32Type(), lhs_matrix.num_rows);
  auto ldb = builder.create<tfrt::compiler::ConstantI32Op>(
      loc, builder.getI32Type(), rhs_matrix.num_rows);

  auto const_beta = MakeScalingFactorConstant(builder, loc, element_type, beta);

  cudaDataType_t data_type = MlirTypeToCudaDataType(element_type);

  auto ldc = builder.create<tfrt::compiler::ConstantI32Op>(
      loc, builder.getI32Type(), output_matrix.num_rows);

  auto compute_type = data_type;  // use the data_type for compute as well.

  auto algo = builder.create<tfrt::gpu::BlasGemmAlgoOp>(loc, algorithm);

  auto blas_handle_type = builder.getType<tfrt::gpu::BlasHandleType>();
  auto blas_handle =
      builder.create<tfrt::gpu::BlasCreateOp>(loc, blas_handle_type, stream);

  auto lhs_op = lhs_matrix.transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto rhs_op = rhs_matrix.transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
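  // CUBLAS_OP_T makes cuBLAS read the operand as transposed; CUBLAS_OP_N reads
  // it as stored.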

  if (batch_size != 1) {
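    // Emit a strided batched GEMM. Each operand is assumed to hold its
    // matrices densely packed, so the stride between consecutive matrices is
    // rows * cols elements.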
    int64_t lhs_stride_val = lhs_matrix.num_rows * lhs_matrix.num_cols;
    int64_t rhs_stride_val = rhs_matrix.num_rows * rhs_matrix.num_cols;
    int64_t output_stride_val = output_matrix.num_rows * output_matrix.num_cols;
    auto lhs_stride = builder.create<tfrt::compiler::ConstantI64Op>(
        loc, builder.getI64Type(), lhs_stride_val);
    auto rhs_stride = builder.create<tfrt::compiler::ConstantI64Op>(
        loc, builder.getI64Type(), rhs_stride_val);
    auto output_stride = builder.create<tfrt::compiler::ConstantI64Op>(
        loc, builder.getI64Type(), output_stride_val);
    auto batch = builder.create<tfrt::compiler::ConstantI32Op>(
        loc, builder.getI32Type(), batch_size);
    return builder
        .create<tfrt::gpu::BlasGemmBatchExOp>(
            loc, chain.getType(), blas_handle, lhs_op, rhs_op, m, n, k,
            const_alpha, lhs_matrix.converted_value, data_type, lda, lhs_stride,
            rhs_matrix.converted_value, data_type, ldb, rhs_stride, const_beta,
            output_matrix.converted_value, data_type, ldc, output_stride, batch,
            compute_type, algo, chain)
        .getResult();
  }

  return builder
      .create<tfrt::gpu::BlasGemmOp>(loc, chain.getType(), blas_handle, lhs_op,
                                     rhs_op, m, n, k, const_alpha,
                                     lhs_matrix.converted_value, data_type, lda,
                                     rhs_matrix.converted_value, data_type, ldb,
                                     const_beta, output_matrix.converted_value,
                                     data_type, ldc, compute_type, algo, chain)
      .getResult();
}

template <class GemmOpType>
FailureOr<Value> GemmOpConversionRewrite(
    GemmOpType srcOp, Value chain, Value stream,
    mlir::BlockAndValueMapping& mapping, mlir::OpBuilder& builder,
    absl::optional<llvm::APFloat> beta_arg = absl::nullopt) {
  mlir::Type element_type = srcOp.output()
                                .getType()
                                .template cast<mlir::MemRefType>()
                                .getElementType();
  // Ensure all operands have the same element type.
  if (element_type !=
      srcOp.lhs().getType().template cast<mlir::MemRefType>().getElementType())
    return mlir::failure();
  if (element_type !=
      srcOp.rhs().getType().template cast<mlir::MemRefType>().getElementType())
    return mlir::failure();
  const mlir::mhlo::DotDimensionNumbers dim_nums =
      srcOp.dot_dimension_numbers();

  // The row and column dimensions are the last two dimensions. All the
  // dimensions before them are batching dimensions.
  int64_t row_dim = dim_nums.lhs_batching_dimensions().size();
  int64_t col_dim = dim_nums.lhs_batching_dimensions().size() + 1;
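  // For example, for a batched dot with operand shapes [b, m, k] x [b, k, n],
  // lhs_batching_dimensions is {0}, so row_dim == 1 and col_dim == 2.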

  int64_t batch_size = srcOp.batch_size();

  // Check that the batch dims don't cover the last two dims.
  for (auto batch_dim : dim_nums.lhs_batching_dimensions()) {
    if (row_dim == batch_dim) return mlir::failure();
    if (col_dim == batch_dim) return mlir::failure();
  }

  // Verify that the non-batch dimensions are minor-most. This is required for
  // efficient access.
  const xla::Shape& lhs_shape = xla::TypeToShape(srcOp.lhs().getType());
  const xla::Shape& rhs_shape = xla::TypeToShape(srcOp.rhs().getType());
  const xla::Shape& output_shape = xla::TypeToShape(srcOp.output().getType());
  for (const auto* shape : {&lhs_shape, &rhs_shape, &output_shape}) {
    if (shape->layout().minor_to_major(row_dim) >= 2) return mlir::failure();
    if (shape->layout().minor_to_major(col_dim) >= 2) return mlir::failure();
  }

  // BLAS gemm expects the inputs and the output to be in column-major order.
  // Therefore, we need to convert a multiplication between row-major matrices
  // into one between column-major matrices. The key insight for the conversion
  // is that, in linear storage, a matrix M in column-major order is identical
  // to the transpose of M in row-major order. In other words,
  //
  //   column-major(M) = row-major(M^T).
  //
  // Leveraging this insight, we can perform a dot between row-major matrices
  // as follows.
  //
  //   row-major(C)
  //     = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T)
  //     = gemm(column-major(B^T), column-major(A^T))
  //     = gemm(row-major(B), row-major(A))
  //
  // Although we do not modify the content of A and B in linear memory, we
  // should use the dimensions of B^T and A^T when calling gemm. For example,
  // the leading dimension of the LHS matrix of gemm is the number of rows in
  // B^T and thus the number of columns in B.
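  // Concretely, a 2x3 row-major matrix A stored as [a00 a01 a02 a10 a11 a12]
  // is, byte for byte, the 3x2 column-major matrix A^T, so we pass the
  // unmodified buffers to gemm but describe them with transposed dimensions.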
  auto make_descriptor = [&](const xla::Shape& shape,
                             mlir::Value replaced_value,
                             bool transpose) -> MatrixDescriptor {
    bool is_row_major = xla::LayoutUtil::Minor(shape.layout(), row_dim) != 0;
    bool layout_mismatch =
        xla::LayoutUtil::Minor(shape.layout(), row_dim) !=
        xla::LayoutUtil::Minor(output_shape.layout(), row_dim);
    return MatrixDescriptor{
        replaced_value, static_cast<bool>(transpose ^ layout_mismatch),
        shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
        shape.dimensions(row_dim + static_cast<int64>(!is_row_major))};
  };

  MatrixDescriptor lhs_matrix = make_descriptor(
      lhs_shape, mapping.lookup(srcOp.lhs()),
      dim_nums.lhs_contracting_dimensions().getValue<int64_t>({0}) == row_dim);
  MatrixDescriptor rhs_matrix = make_descriptor(
      rhs_shape, mapping.lookup(srcOp.rhs()),
      dim_nums.rhs_contracting_dimensions().getValue<int64_t>({0}) == col_dim);
  MatrixDescriptor output_matrix = MatrixDescriptor{
      mapping.lookup(srcOp.output()), /*transpose=*/false,
      output_shape.dimensions(row_dim), output_shape.dimensions(col_dim)};

  Complex<llvm::APFloat> alpha{srcOp.alpha_real(), srcOp.alpha_imag()};
  // If no beta_arg is supplied, use a zero with the same float semantics
  // (IEEE single, IEEE double, ...) as alpha.
  llvm::APFloat beta_real = beta_arg.has_value()
                                ? beta_arg.value()
                                : APFloat::getZero(alpha.real.getSemantics());
  Complex<llvm::APFloat> beta{beta_real,
                              APFloat::getZero(alpha.imag.getSemantics())};

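  // If the output is stored row-major, compute column-major C^T = B^T x A^T
  // instead (see the discussion above): swap the operands and swap the output
  // dimensions.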
  if (xla::LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) {
    std::swap(lhs_matrix, rhs_matrix);
    std::swap(output_matrix.num_cols, output_matrix.num_rows);
  }

  auto algorithm = static_cast<cublasGemmAlgo_t>(
      srcOp.algorithm().getValueOr(CUBLAS_GEMM_DEFAULT));

  return CreateTfrtOps(srcOp.getLoc(), chain, stream, batch_size, element_type,
                       lhs_matrix, rhs_matrix, output_matrix, alpha, beta,
                       algorithm, builder);
}

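// GEMMOp has no bias and hence no beta attribute; GemmOpConversionRewrite then
// falls back to a zero beta. GEMM_BiasOp carries an explicit beta attribute.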
absl::optional<llvm::APFloat> GetBeta(lmhlo_gpu::GEMMOp op) {
  return absl::nullopt;
}

absl::optional<llvm::APFloat> GetBeta(lmhlo_gpu::GEMM_BiasOp op) {
  return op.beta();
}

template <class GemmOpType>
struct GemmRewritePattern : tfrt::gpu::GpuAsyncOpConversionPattern<GemmOpType> {
  using tfrt::gpu::GpuAsyncOpConversionPattern<
      GemmOpType>::GpuAsyncOpConversionPattern;
  FailureOr<Value> matchAndRewriteOp(
      GemmOpType op, Value chain, Value stream, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const override {
    if (!all_of(operands, [](Value operand) {
          return operand.getType().isa<tfrt::gpu::BufferType>();
        }))
      return rewriter.notifyMatchFailure(op, "expected buffer operands");

    BlockAndValueMapping mapping;
    for (auto pair : llvm::zip_first(op->getOperands(), operands))
      mapping.map(std::get<0>(pair), std::get<1>(pair));

    rewriter.eraseOp(op);

    return GemmOpConversionRewrite(op, chain, stream, mapping, rewriter,
                                   GetBeta(op));
  }
};

}  // namespace

mlir::LogicalResult GemmOpConversionRewrite(mlir::lmhlo_gpu::GEMMOp srcOp,
                                            mlir::BlockAndValueMapping& mapping,
                                            mlir::OpBuilder& builder) {
  auto chain_type = builder.getType<tfrt::compiler::ChainType>();
  auto chain =
      builder.create<tfrt::compiler::NewChainOp>(srcOp->getLoc(), chain_type);
  auto enclosing_func =
      llvm::cast<mlir::FuncOp>(builder.getBlock()->getParentOp());
  // The first argument is the GpuStream.
  auto stream = enclosing_func.getArgument(0);
  return GemmOpConversionRewrite(srcOp, chain, stream, mapping, builder);
}

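// Typical usage (sketch; assumes the surrounding conversion pass has already
// set up a ConversionTarget, and that `context`, `module`, and `target` are
// whatever that pass provides):
//
//   mlir::RewritePatternSet patterns(context);
//   populateGemmConversionPattern(patterns);
//   if (failed(mlir::applyPartialConversion(module, target,
//                                           std::move(patterns))))
//     signalPassFailure();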
void populateGemmConversionPattern(RewritePatternSet& patterns) {
  patterns.add<GemmRewritePattern<lmhlo_gpu::GEMMOp>,
               GemmRewritePattern<lmhlo_gpu::GEMM_BiasOp>>(
      patterns.getContext());
}

}  // namespace tensorflow