// Copyright 2020 The TensorFlow Runtime Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//===- gemm_pattern.cc ----------------------------------------------------===//
//
// Pattern to lower lhlo_gpu GEMM ops to the tfrt cuda dialect.
//
//===----------------------------------------------------------------------===//
#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.h"

#include <assert.h>
#include <stdint.h>

#include <type_traits>
#include <utility>

#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tfrt/gpu/kernels/gpu_ops.h"  // from @tf_runtime
#include "tfrt/gpu/pass/pass.h"  // from @tf_runtime
#include "tfrt/gpu/wrapper/cublas_wrapper.h"  // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/basic_kernels.h"  // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime

namespace tensorflow {
namespace {

using llvm::ArrayRef;

// This struct contains the metadata of a matrix, e.g., its base address and
// dimensions.
struct MatrixDescriptor {
  mlir::Value converted_value;
  bool transpose;   // Whether this matrix needs to be transposed.
  int64_t num_rows;  // Rows of the matrix as stored, i.e. before any transpose.
  int64_t num_cols;  // Columns of the matrix as stored.
};

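// Maps an MLIR element type to the corresponding cudaDataType_t enum used by
// the cuBLAS GEMM calls below.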
cudaDataType_t MlirTypeToCudaDataType(mlir::Type type) {
  mlir::Builder builder(type.getContext());
  if (type.isF16())
    return CUDA_R_16F;
  else if (type.isF32())
    return CUDA_R_32F;
  else if (type.isF64())
    return CUDA_R_64F;
  else if (type == mlir::ComplexType::get(builder.getF32Type()))
    return CUDA_C_32F;
  else if (type == mlir::ComplexType::get(builder.getF64Type()))
    return CUDA_C_64F;

  llvm_unreachable("unsupported type");
}

// TODO(b/176561997): remove this once lhlo_gpu ops have properly typed alpha
// and beta attributes. We can't use std::complex here because the effect of
// instantiating it for anything other than float, double, or long double is
// unspecified. We need it for APFloat.
template <class T>
struct Complex {
  T real;
  T imag;
};

// TODO(b/176561997): remove this once lhlo_gpu ops have properly typed alpha
// and beta attributes.
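// Materializes the given scaling factor (alpha or beta) as a tfrt constant op
// of the requested element type, converting the APFloat payload(s) to the
// matching float semantics.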
mlir::Value MakeScalingFactorConstant(mlir::OpBuilder& builder,
                                      mlir::Location loc, mlir::Type type,
                                      Complex<llvm::APFloat> value) {
  // Dummy boolean we need to pass to convert functions. Since this whole
  // function will go away when the scaling factors are properly typed
  // (b/176561997), we won't worry about possible losses during conversions for
  // now.
  bool losesInfo = false;
  // TODO(b/176913138): remove second argument to `builder.create` calls
  // TODO(b/176562488): handle {,B}F16
  if (type.isF32()) {
    value.real.convert(llvm::APFloat::IEEEsingle(),
                       llvm::RoundingMode::NearestTiesToEven, &losesInfo);
    return builder.create<tfrt::compiler::ConstantF32Op>(loc, type, value.real);
  } else if (type.isF64()) {
    value.real.convert(llvm::APFloat::IEEEdouble(),
                       llvm::RoundingMode::NearestTiesToEven, &losesInfo);
    return builder.create<tfrt::compiler::ConstantF64Op>(loc, type, value.real);
  } else if (type == mlir::ComplexType::get(builder.getF32Type())) {
    value.real.convert(llvm::APFloat::IEEEsingle(),
                       llvm::RoundingMode::NearestTiesToEven, &losesInfo);
    value.imag.convert(llvm::APFloat::IEEEsingle(),
                       llvm::RoundingMode::NearestTiesToEven, &losesInfo);
    return builder.create<tfrt::compiler::ConstantComplexF32Op>(
        loc, type, value.real, value.imag);
  } else if (type == mlir::ComplexType::get(builder.getF64Type())) {
    value.real.convert(llvm::APFloat::IEEEdouble(),
                       llvm::RoundingMode::NearestTiesToEven, &losesInfo);
    value.imag.convert(llvm::APFloat::IEEEdouble(),
                       llvm::RoundingMode::NearestTiesToEven, &losesInfo);
    return builder.create<tfrt::compiler::ConstantComplexF64Op>(
        loc, type, value.real, value.imag);
  }

  llvm_unreachable("unsupported type");
}

// Create all the Ops necessary for the GEMM operation, including the GEMM
// operation itself.
// TODO(b/175130778): remove the element_type parameter when we move from
// GpuBuffers to MemRefs.
FailureOr<Value> CreateTfrtOps(
    mlir::Location loc, mlir::Value chain, mlir::Value stream,
    int64_t batch_size, mlir::Type element_type, MatrixDescriptor lhs_matrix,
    MatrixDescriptor rhs_matrix, MatrixDescriptor output_matrix,
    Complex<llvm::APFloat> alpha, Complex<llvm::APFloat> beta,
    cublasGemmAlgo_t algorithm, mlir::OpBuilder& builder) {
  // The contraction size k: if the lhs is transposed, the contracted extent is
  // its stored row count, otherwise its stored column count.
  auto k_val = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;

  // TODO(b/176913138): remove second argument to `rewriter.create` calls
  auto m = builder.create<tfrt::compiler::ConstantI32Op>(
      loc, builder.getI32Type(), output_matrix.num_rows);
  auto n = builder.create<tfrt::compiler::ConstantI32Op>(
      loc, builder.getI32Type(), output_matrix.num_cols);
  auto k = builder.create<tfrt::compiler::ConstantI32Op>(
      loc, builder.getI32Type(), k_val);

  auto const_alpha =
      MakeScalingFactorConstant(builder, loc, element_type, alpha);

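  // The matrices are interpreted as column-major, so the leading dimension of
  // each operand is simply its number of rows as stored.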
  auto lda = builder.create<tfrt::compiler::ConstantI32Op>(
      loc, builder.getI32Type(), lhs_matrix.num_rows);
  auto ldb = builder.create<tfrt::compiler::ConstantI32Op>(
      loc, builder.getI32Type(), rhs_matrix.num_rows);

  auto const_beta = MakeScalingFactorConstant(builder, loc, element_type, beta);

  cudaDataType_t data_type = MlirTypeToCudaDataType(element_type);

  auto ldc = builder.create<tfrt::compiler::ConstantI32Op>(
      loc, builder.getI32Type(), output_matrix.num_rows);

  auto compute_type = data_type;  // use the data_type for compute as well.

  auto algo = builder.create<tfrt::gpu::BlasGemmAlgoOp>(loc, algorithm);

  auto blas_handle_type = builder.getType<tfrt::gpu::BlasHandleType>();
  auto blas_handle =
      builder.create<tfrt::gpu::BlasCreateOp>(loc, blas_handle_type, stream);

  auto lhs_op = lhs_matrix.transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto rhs_op = rhs_matrix.transpose ? CUBLAS_OP_T : CUBLAS_OP_N;

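  // For batched GEMM, the matrices of a batch are assumed to be densely packed
  // in memory, so the stride between consecutive matrices is rows * cols
  // elements.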
  if (batch_size != 1) {
    int64_t lhs_stride_val = lhs_matrix.num_rows * lhs_matrix.num_cols;
    int64_t rhs_stride_val = rhs_matrix.num_rows * rhs_matrix.num_cols;
    int64_t output_stride_val = output_matrix.num_rows * output_matrix.num_cols;
    auto lhs_stride = builder.create<tfrt::compiler::ConstantI64Op>(
        loc, builder.getI64Type(), lhs_stride_val);
    auto rhs_stride = builder.create<tfrt::compiler::ConstantI64Op>(
        loc, builder.getI64Type(), rhs_stride_val);
    auto output_stride = builder.create<tfrt::compiler::ConstantI64Op>(
        loc, builder.getI64Type(), output_stride_val);
    auto batch = builder.create<tfrt::compiler::ConstantI32Op>(
        loc, builder.getI32Type(), batch_size);
    return builder
        .create<tfrt::gpu::BlasGemmBatchExOp>(
            loc, chain.getType(), blas_handle, lhs_op, rhs_op, m, n, k,
            const_alpha, lhs_matrix.converted_value, data_type, lda, lhs_stride,
            rhs_matrix.converted_value, data_type, ldb, rhs_stride, const_beta,
            output_matrix.converted_value, data_type, ldc, output_stride, batch,
            compute_type, algo, chain)
        .getResult();
  }

  return builder
      .create<tfrt::gpu::BlasGemmOp>(loc, chain.getType(), blas_handle, lhs_op,
                                     rhs_op, m, n, k, const_alpha,
                                     lhs_matrix.converted_value, data_type, lda,
                                     rhs_matrix.converted_value, data_type, ldb,
                                     const_beta, output_matrix.converted_value,
                                     data_type, ldc, compute_type, algo, chain)
      .getResult();
}

template <class GemmOpType>
FailureOr<Value> GemmOpConversionRewrite(
    GemmOpType srcOp, Value chain, Value stream,
    mlir::BlockAndValueMapping& mapping, mlir::OpBuilder& builder,
    absl::optional<llvm::APFloat> beta_arg = absl::nullopt) {
  mlir::Type element_type = srcOp.output()
                                .getType()
                                .template cast<mlir::MemRefType>()
                                .getElementType();
  // Ensure all operands have the same element type.
  if (element_type !=
      srcOp.lhs().getType().template cast<mlir::MemRefType>().getElementType())
    return mlir::failure();
  if (element_type !=
      srcOp.rhs().getType().template cast<mlir::MemRefType>().getElementType())
    return mlir::failure();
  const mlir::mhlo::DotDimensionNumbers dim_nums =
      srcOp.dot_dimension_numbers();

  // The row and column dimensions are the last two dimensions. All the
  // dimensions before them are batching dimensions.
  int64_t row_dim = dim_nums.lhs_batching_dimensions().size();
  int64_t col_dim = dim_nums.lhs_batching_dimensions().size() + 1;

  int64_t batch_size = srcOp.batch_size();

  // Check that the batch dims don't cover the last two dims.
  for (auto batch_dim : dim_nums.lhs_batching_dimensions()) {
    if (row_dim == batch_dim) return mlir::failure();
    if (col_dim == batch_dim) return mlir::failure();
  }

  // Verify that the non-batch dimensions are minor-most. This is required for
  // efficient access.
  const xla::Shape& lhs_shape = xla::TypeToShape(srcOp.lhs().getType());
  const xla::Shape& rhs_shape = xla::TypeToShape(srcOp.rhs().getType());
  const xla::Shape& output_shape = xla::TypeToShape(srcOp.output().getType());
  for (const auto* shape : {&lhs_shape, &rhs_shape, &output_shape}) {
    if (shape->layout().minor_to_major(row_dim) >= 2) return mlir::failure();
    if (shape->layout().minor_to_major(col_dim) >= 2) return mlir::failure();
  }

  // BLAS gemm expects the inputs and the output to be in column-major order.
  // Therefore, we need to convert multiplication between row-major matrices to
  // that between column-major matrices. The key insight for the conversion is
  // that, in linear storage, matrix M in column-major order is identical to the
  // transpose of M in row-major order. In other words,
  //
  //   column-major(M) = row-major(M^T).
  //
  // Leveraging this insight, we can perform dot between row-major matrices as
  // follows.
  //
  // row-major(C)
  //   = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T)
  //   = gemm(column-major(B^T), column-major(A^T))
  //   = gemm(row-major(B), row-major(A))
  //
  // Although we do not modify the content of A and B in linear memory, we
  // should use the dimensions of B^T and A^T when calling gemm. For example,
  // the leading dimension of the LHS matrix of gemm is the number of rows in
  // B^T and thus the number of columns in B.
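  //
  // A small concrete illustration: for row-major A of shape [2, 3] and B of
  // shape [3, 4], C = A x B has shape [2, 4]. Reinterpreted as column-major
  // buffers, the same bytes describe B^T (4x3), A^T (3x2), and C^T (4x2), so
  // the equivalent call is gemm(B^T, A^T) with m = 4, n = 2, k = 3 and leading
  // dimensions 4, 3, and 4.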
  auto make_descriptor = [&](const xla::Shape& shape,
                             mlir::Value replaced_value,
                             bool transpose) -> MatrixDescriptor {
    bool is_row_major = xla::LayoutUtil::Minor(shape.layout(), row_dim) != 0;
    bool layout_mismatch =
        xla::LayoutUtil::Minor(shape.layout(), row_dim) !=
        xla::LayoutUtil::Minor(output_shape.layout(), row_dim);
    return MatrixDescriptor{
        replaced_value, static_cast<bool>(transpose ^ layout_mismatch),
        shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
        shape.dimensions(row_dim + static_cast<int64>(!is_row_major))};
  };

  MatrixDescriptor lhs_matrix = make_descriptor(
      lhs_shape, mapping.lookup(srcOp.lhs()),
      dim_nums.lhs_contracting_dimensions().getValue<int64_t>({0}) == row_dim);
  MatrixDescriptor rhs_matrix = make_descriptor(
      rhs_shape, mapping.lookup(srcOp.rhs()),
      dim_nums.rhs_contracting_dimensions().getValue<int64_t>({0}) == col_dim);
  MatrixDescriptor output_matrix = MatrixDescriptor{
      mapping.lookup(srcOp.output()), /*transpose=*/false,
      output_shape.dimensions(row_dim), output_shape.dimensions(col_dim)};

  Complex<llvm::APFloat> alpha{srcOp.alpha_real(), srcOp.alpha_imag()};
  // If no beta_arg is supplied, default beta to zero, using the same float
  // semantics (IEEE single, IEEE double, ...) as alpha.
  llvm::APFloat beta_real = beta_arg.has_value()
                                ? beta_arg.value()
                                : APFloat::getZero(alpha.real.getSemantics());
  Complex<llvm::APFloat> beta{beta_real,
                              APFloat::getZero(alpha.imag.getSemantics())};

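  // If the output is row-major, compute C^T = B^T x A^T instead (see the
  // comment above) by swapping the operands and the output dimensions.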
  if (xla::LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) {
    std::swap(lhs_matrix, rhs_matrix);
    std::swap(output_matrix.num_cols, output_matrix.num_rows);
  }

  auto algorithm = static_cast<cublasGemmAlgo_t>(
      srcOp.algorithm().getValueOr(CUBLAS_GEMM_DEFAULT));

  return CreateTfrtOps(srcOp.getLoc(), chain, stream, batch_size, element_type,
                       lhs_matrix, rhs_matrix, output_matrix, alpha, beta,
                       algorithm, builder);
}

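// A plain GEMM op has no bias, so no beta attribute is available and the
// caller defaults beta to zero; GEMM_Bias carries an explicit beta.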
absl::optional<llvm::APFloat> GetBeta(lmhlo_gpu::GEMMOp op) {
  return absl::nullopt;
}

absl::optional<llvm::APFloat> GetBeta(lmhlo_gpu::GEMM_BiasOp op) {
  return op.beta();
}

template <class GemmOpType>
struct GemmRewritePattern : tfrt::gpu::GpuAsyncOpConversionPattern<GemmOpType> {
  using tfrt::gpu::GpuAsyncOpConversionPattern<
      GemmOpType>::GpuAsyncOpConversionPattern;
  FailureOr<Value> matchAndRewriteOp(
      GemmOpType op, Value chain, Value stream, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const override {
    if (!all_of(operands, [](Value operand) {
          return operand.getType().isa<tfrt::gpu::BufferType>();
        }))
      return rewriter.notifyMatchFailure(op, "expected buffer operands");

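    // Map each original memref operand to its converted gpu buffer so the
    // rewrite below can look up the converted values.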
    BlockAndValueMapping mapping;
    for (auto pair : llvm::zip_first(op->getOperands(), operands))
      mapping.map(std::get<0>(pair), std::get<1>(pair));

    // During dialect conversion, erasing the op only schedules the erasure, so
    // the op can still be read by GemmOpConversionRewrite below.
    rewriter.eraseOp(op);

    return GemmOpConversionRewrite(op, chain, stream, mapping, rewriter,
                                   GetBeta(op));
  }
};

}  // namespace

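// Standalone entry point for lowering a single lmhlo_gpu.GEMMOp: creates a
// fresh chain and takes the stream from the enclosing function's first
// argument.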
mlir::LogicalResult GemmOpConversionRewrite(mlir::lmhlo_gpu::GEMMOp srcOp,
                                            mlir::BlockAndValueMapping& mapping,
                                            mlir::OpBuilder& builder) {
  auto chain_type = builder.getType<tfrt::compiler::ChainType>();
  auto chain =
      builder.create<tfrt::compiler::NewChainOp>(srcOp->getLoc(), chain_type);
  auto enclosing_func =
      llvm::cast<mlir::FuncOp>(builder.getBlock()->getParentOp());
  // The first argument is the GpuStream.
  auto stream = enclosing_func.getArgument(0);
  return GemmOpConversionRewrite(srcOp, chain, stream, mapping, builder);
}

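// Registers the rewrite patterns for both the plain GEMM and the GEMM-with-bias
// variants.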
void populateGemmConversionPattern(RewritePatternSet& patterns) {
  patterns.add<GemmRewritePattern<lmhlo_gpu::GEMMOp>,
               GemmRewritePattern<lmhlo_gpu::GEMM_BiasOp>>(
      patterns.getContext());
}

}  // namespace tensorflow