/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"

#include <climits>
#include <cstdint>
#include <utility>

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace mlir {
namespace TF {

namespace {

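// Rewrite pattern that unrolls a tf.BatchMatMul-style op (templated over the
// V1/V2/V3 variants) into per-batch tf.MatMul ops.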
template <typename BatchMatMulOpType>
class ConvertTFBatchMatMulOp : public OpRewritePattern<BatchMatMulOpType> {
  using OpRewritePattern<BatchMatMulOpType>::OpRewritePattern;

  static TF::ReshapeOp createReshapeOp(Value value, ArrayRef<int64_t> shape,
                                       Type element_type, Location loc,
                                       PatternRewriter& rewriter);

  static std::vector<Value> sliceInput(Value value, int batch_size,
                                       Location loc, PatternRewriter& rewriter);

  LogicalResult matchAndRewrite(BatchMatMulOpType op,
                                PatternRewriter& rewriter) const override;
};

// Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out
// of the inputs, matmul them individually, then stack them all back together at
// the end.
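//
// For example (shapes are illustrative):
//   %r = "tf.BatchMatMulV2"(%lhs, %rhs)
//       : (tensor<2x4x5xf32>, tensor<2x5x6xf32>) -> tensor<2x4x6xf32>
// is rewritten, roughly, into a tf.Split of each operand along dimension 0,
// two [4, 5] x [5, 6] tf.MatMul ops, and a tf.Pack of the two [4, 6] results
// back into a tensor<2x4x6xf32>.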
struct UnrollBatchMatMulPass
    : public UnrollBatchMatMulPassBase<UnrollBatchMatMulPass> {
  void runOnOperation() override;
};

void UnrollBatchMatMulPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  auto func = getOperation();
  PopulateUnrollTfBatchMatMul(&getContext(), patterns);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

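// Creates a tf.Reshape op that reshapes `value` to the static `shape`, routing
// the target shape through a tf.Const operand.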
template <typename BatchMatMulOpType>
TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
    Value value, ArrayRef<int64_t> shape, Type element_type, Location loc,
    PatternRewriter& rewriter) {
  int64_t shape_rank = shape.size();
  auto shape_spec_type =
      RankedTensorType::get({shape_rank}, rewriter.getIntegerType(64));
  Type resultType = RankedTensorType::get(shape, element_type);
  auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape);
  auto shape_tensor =
      rewriter.create<TF::ConstOp>(loc, shape_spec_type, constant_attr);
  return rewriter.create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
                                        /*shape=*/shape_tensor);
}

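// Slices `value`, a tensor holding `batch_size` total batches, into
// `batch_size` rank-2 tensors of shape [num_rows, num_cols], one per batch.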
94 template <typename BatchMatMulOpType>
sliceInput(Value value,int batch_size,Location loc,PatternRewriter & rewriter)95 std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
96     Value value, int batch_size, Location loc, PatternRewriter& rewriter) {
97   RankedTensorType tensorType = value.getType().cast<RankedTensorType>();
98   Type element_type = tensorType.getElementType();
99 
100   int rank = tensorType.getShape().size();
101   int num_rows = tensorType.getShape()[rank - 2];
102   int num_cols = tensorType.getShape()[rank - 1];
103 
104   std::vector<Value> sliced;
105 
106   if (batch_size == 1) {
107     // Batch size is 1, no splitting is required
108     // Squeeze the batch dimension, i.e. reshape
109     // [1, num_rows, num_cols] -> [num_rows, num_cols]
110     auto reshape_op = createReshapeOp(value, {num_rows, num_cols}, element_type,
111                                       loc, rewriter);
112     sliced.emplace_back(reshape_op.output());
113   } else {
114     // Reshape to rank-3 tensor with first dimension as the batch size.
115     auto reshape_op = createReshapeOp(value, {batch_size, num_rows, num_cols},
116                                       element_type, loc, rewriter);
117 
118     // Create a constant op for the split axis (=0)
119     auto split_dimension_type =
120         RankedTensorType::get({}, rewriter.getIntegerType(32));
121     auto split_dimension_attr = DenseElementsAttr::get(split_dimension_type, 0);
122     auto split_dimension_op = rewriter.create<TF::ConstOp>(
123         loc, split_dimension_type, split_dimension_attr);
124 
125     // Split along each batch.
126     SmallVector<int64_t, 3> slice_size = {1, num_rows, num_cols};
127     Type slice_result_type = RankedTensorType::get(slice_size, element_type);
128     llvm::SmallVector<Type, 4> output_types(batch_size, slice_result_type);
129     auto split_op = rewriter.create<TF::SplitOp>(
130         loc, output_types, split_dimension_op.output(), reshape_op.output());
131 
132     // Squeeze each batch, i.e. reshape
133     // [1, num_rows, num_cols] -> [num_rows, num_cols]
134     for (const auto& split_value : split_op.output()) {
135       auto reshape_op = createReshapeOp(split_value, {num_rows, num_cols},
136                                         element_type, loc, rewriter);
137 
138       sliced.emplace_back(reshape_op.output());
139     }
140   }
141   return sliced;
142 }
143 
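// Matches a tf.BatchMatMul-style op and rewrites it as a set of per-batch
// tf.MatMul ops whose results are packed and reshaped back to the original
// output shape.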
template <typename BatchMatMulOpType>
LogicalResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
    BatchMatMulOpType op, PatternRewriter& rewriter) const {
  Value input_lhs = op.x();
  Value input_rhs = op.y();

  if (!input_lhs.getType().isa<RankedTensorType>()) {
    // LHS must be a ranked tensor type.
    return failure();
  }
  if (!input_rhs.getType().isa<RankedTensorType>()) {
    // RHS must be a ranked tensor type.
    return failure();
  }

  auto lhs_type = input_lhs.getType().cast<RankedTensorType>();
  auto rhs_type = input_rhs.getType().cast<RankedTensorType>();

  // Skip int8 x int8 => int32.
  if (lhs_type.getElementType().isInteger(8) &&
      rhs_type.getElementType().isInteger(8)) {
    return rewriter.notifyMatchFailure(op,
                                       "skip unrolling for int8 BatchMatMulV3");
  }

  auto element_type = lhs_type.getElementType();

  if (element_type != rhs_type.getElementType()) {
    // The element type of LHS must be the same as the element type of RHS.
    return failure();
  }

  std::vector<int64_t> lhs_shape = lhs_type.getShape();
  std::vector<int64_t> rhs_shape = rhs_type.getShape();

  Location loc = op.getLoc();

  // Ensure that input ranks are at least 2.
  const int lhs_dims = lhs_shape.size();
  const int rhs_dims = rhs_shape.size();
  if (lhs_dims < 2 || rhs_dims < 2) {
    // Both inputs must have rank >= 2.
    return failure();
  }

  // Swap the last two dimensions of LHS and RHS if necessary.
  // The actual transpose is done by MatMulOp.
  if (op.adj_x()) {
    std::swap(lhs_shape[lhs_dims - 1], lhs_shape[lhs_dims - 2]);
  }
  if (op.adj_y()) {
    std::swap(rhs_shape[rhs_dims - 1], rhs_shape[rhs_dims - 2]);
  }

  const int rows = lhs_shape[lhs_dims - 2];
  const int cols = rhs_shape[rhs_dims - 1];

  if (lhs_shape[lhs_dims - 1] != rhs_shape[rhs_dims - 2]) {
    // Input dimensions must be compatible for multiplication.
    return failure();
  }

  const auto matmul_type = RankedTensorType::get({rows, cols}, element_type);

  if (lhs_dims == 2 && rhs_dims == 2) {
    // When both inputs are matrices, just replace the op with a matmul op.
    rewriter.replaceOpWithNewOp<TF::MatMulOp>(op, matmul_type,
                                              /*a=*/input_lhs,
                                              /*b=*/input_rhs,
                                              /*transpose_a=*/op.adj_x(),
                                              /*transpose_b=*/op.adj_y());
    return success();
  }

  // Input dimensions must be defined. MatMulBCast does not support partial
  // shapes.
  for (auto dim : lhs_shape) {
    if (dim == -1) {
      return failure();
    }
  }
  for (auto dim : rhs_shape) {
    if (dim == -1) {
      return failure();
    }
  }
  // Ensure that batch shapes are broadcastable.
  tensorflow::MatMulBCast bcast(
      absl::InlinedVector<int64_t, 4>(lhs_shape.begin(), lhs_shape.end()),
      absl::InlinedVector<int64_t, 4>(rhs_shape.begin(), rhs_shape.end()));

  if (!bcast.IsValid()) {
    // Input batch dimensions must be broadcastable.
    return failure();
  }

  // Compute slices for each batch in the LHS and RHS.
  std::vector<Value> sliced_lhs =
      sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter);
  std::vector<Value> sliced_rhs =
      sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter);

  // Compute (single batch) MatMul for each output batch.
  std::vector<Value> matmuls;
  matmuls.reserve(bcast.output_batch_size());
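  // For illustration: with lhs batch shape [2, 1] and rhs batch shape [1, 3],
  // the broadcast output batch shape is [2, 3], so output_batch_size() is 6,
  // and x_batch_indices() / y_batch_indices() map each of the six output
  // batches back to the lhs/rhs slice that produces it.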
  for (int batch_idx : llvm::seq<int>(0, bcast.output_batch_size())) {
    int lhs_batch_idx, rhs_batch_idx;
    if (bcast.IsBroadcastingRequired()) {
      lhs_batch_idx = bcast.x_batch_indices()[batch_idx];
      rhs_batch_idx = bcast.y_batch_indices()[batch_idx];
    } else {
      lhs_batch_idx = batch_idx;
      rhs_batch_idx = batch_idx;
    }
    auto matmul = rewriter.create<TF::MatMulOp>(loc, matmul_type,
                                                /*a=*/sliced_lhs[lhs_batch_idx],
                                                /*b=*/sliced_rhs[rhs_batch_idx],
                                                /*transpose_a=*/op.adj_x(),
                                                /*transpose_b=*/op.adj_y());
    matmuls.emplace_back(matmul.product());
  }

  // Combine the result of each individual MatMul into a rank-3 tensor.
  Type packed_type = RankedTensorType::get(
      {bcast.output_batch_size(), rows, cols}, element_type);
  const auto axis = rewriter.getI64IntegerAttr(0);
  auto pack_op =
      rewriter.create<TF::PackOp>(loc, packed_type, /*values=*/matmuls, axis);

  // Reshape the rank-3 tensor into the correct output shape.
  const auto& result_batch_shape = bcast.output_batch_shape().dim_sizes();
  std::vector<int64_t> result_shape(result_batch_shape.begin(),
                                    result_batch_shape.end());
  result_shape.push_back(rows);
  result_shape.push_back(cols);

  auto reshape_op = createReshapeOp(pack_op.output(), result_shape,
                                    element_type, loc, rewriter);
  rewriter.replaceOp(op, reshape_op.output());
  return success();
}

std::unique_ptr<OperationPass<func::FuncOp>> CreateUnrollBatchMatMulPassPass() {
  return std::make_unique<UnrollBatchMatMulPass>();
}

}  // namespace TF
}  // namespace mlir

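// Populates `patterns` with the BatchMatMul unrolling pattern for the
// BatchMatMul, BatchMatMulV2, and BatchMatMulV3 ops.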
void mlir::TF::PopulateUnrollTfBatchMatMul(MLIRContext* context,
                                           RewritePatternSet& patterns) {
  patterns.add<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
               ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>,
               ConvertTFBatchMatMulOp<TF::BatchMatMulV3Op>>(context);
}