/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"

#include <climits>
#include <cstdint>
#include <utility>

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace mlir {
namespace TF {

namespace {

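// Pattern that rewrites a tf.BatchMatMul-family op into per-batch tf.MatMul
// ops; see matchAndRewrite below for the full rewrite.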
template <typename BatchMatMulOpType>
class ConvertTFBatchMatMulOp : public OpRewritePattern<BatchMatMulOpType> {
  using OpRewritePattern<BatchMatMulOpType>::OpRewritePattern;

  static TF::ReshapeOp createReshapeOp(Value value, ArrayRef<int64_t> shape,
                                       Type element_type, Location loc,
                                       PatternRewriter& rewriter);

  static std::vector<Value> sliceInput(Value value, int batch_size,
                                       Location loc, PatternRewriter& rewriter);

  LogicalResult matchAndRewrite(BatchMatMulOpType op,
                                PatternRewriter& rewriter) const override;
};

// Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out
// of the inputs, matmul them individually, then stack them all back together at
// the end.
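//
// For example (shapes purely illustrative), a batched product
//   tensor<2x4x5xf32> x tensor<2x5x6xf32> -> tensor<2x4x6xf32>
// is unrolled into two tensor<4x5xf32> x tensor<5x6xf32> MatMuls whose
// results are packed and reshaped back into a tensor<2x4x6xf32>.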
struct UnrollBatchMatMulPass
    : public UnrollBatchMatMulPassBase<UnrollBatchMatMulPass> {
  void runOnOperation() override;
};

void UnrollBatchMatMulPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  auto func = getOperation();
  PopulateUnrollTfBatchMatMul(&getContext(), patterns);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

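// Creates a tf.Reshape of `value` to the static `shape`; the target shape is
// materialized as a rank-1 i64 tf.Const that feeds the reshape.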
template <typename BatchMatMulOpType>
TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
    Value value, ArrayRef<int64_t> shape, Type element_type, Location loc,
    PatternRewriter& rewriter) {
  int64_t shape_rank = shape.size();
  auto shape_spec_type =
      RankedTensorType::get({shape_rank}, rewriter.getIntegerType(64));
  Type resultType = RankedTensorType::get(shape, element_type);
  auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape);
  auto shape_tensor =
      rewriter.create<TF::ConstOp>(loc, shape_spec_type, constant_attr);
  return rewriter.create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
                                        /*shape=*/shape_tensor);
}

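// Slices `value` into `batch_size` rank-2 tensors along the leading batch
// dimension: the input is reshaped to [batch_size, num_rows, num_cols], split
// along axis 0, and each slice is squeezed down to [num_rows, num_cols].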
template <typename BatchMatMulOpType>
std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
    Value value, int batch_size, Location loc, PatternRewriter& rewriter) {
  RankedTensorType tensorType = value.getType().cast<RankedTensorType>();
  Type element_type = tensorType.getElementType();

  int rank = tensorType.getShape().size();
  int num_rows = tensorType.getShape()[rank - 2];
  int num_cols = tensorType.getShape()[rank - 1];

  std::vector<Value> sliced;

  if (batch_size == 1) {
    // Batch size is 1, no splitting is required.
    // Squeeze the batch dimension, i.e. reshape
    // [1, num_rows, num_cols] -> [num_rows, num_cols]
    auto reshape_op = createReshapeOp(value, {num_rows, num_cols}, element_type,
                                      loc, rewriter);
    sliced.emplace_back(reshape_op.output());
  } else {
    // Reshape to rank-3 tensor with first dimension as the batch size.
    auto reshape_op = createReshapeOp(value, {batch_size, num_rows, num_cols},
                                      element_type, loc, rewriter);

    // Create a constant op for the split axis (=0).
    auto split_dimension_type =
        RankedTensorType::get({}, rewriter.getIntegerType(32));
    auto split_dimension_attr = DenseElementsAttr::get(split_dimension_type, 0);
    auto split_dimension_op = rewriter.create<TF::ConstOp>(
        loc, split_dimension_type, split_dimension_attr);

    // Split along each batch.
    SmallVector<int64_t, 3> slice_size = {1, num_rows, num_cols};
    Type slice_result_type = RankedTensorType::get(slice_size, element_type);
    llvm::SmallVector<Type, 4> output_types(batch_size, slice_result_type);
    auto split_op = rewriter.create<TF::SplitOp>(
        loc, output_types, split_dimension_op.output(), reshape_op.output());

    // Squeeze each batch, i.e. reshape
    // [1, num_rows, num_cols] -> [num_rows, num_cols]
    for (const auto& split_value : split_op.output()) {
      auto reshape_op = createReshapeOp(split_value, {num_rows, num_cols},
                                        element_type, loc, rewriter);

      sliced.emplace_back(reshape_op.output());
    }
  }
  return sliced;
}

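// Rewrites one BatchMatMul-family op. A rank-2 x rank-2 product is replaced
// with a single tf.MatMul directly; higher-rank products are sliced into
// per-batch MatMuls (honoring adj_x/adj_y and batch broadcasting via
// MatMulBCast), which are then packed and reshaped to the broadcast output
// shape.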
template <typename BatchMatMulOpType>
LogicalResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
    BatchMatMulOpType op, PatternRewriter& rewriter) const {
  Value input_lhs = op.x();
  Value input_rhs = op.y();

  if (!input_lhs.getType().isa<RankedTensorType>()) {
    // LHS must be a ranked tensor type.
    return failure();
  }
  if (!input_rhs.getType().isa<RankedTensorType>()) {
    // RHS must be a ranked tensor type.
    return failure();
  }

  auto lhs_type = input_lhs.getType().cast<RankedTensorType>();
  auto rhs_type = input_rhs.getType().cast<RankedTensorType>();

  // Skip int8 x int8 => int32.
  if (lhs_type.getElementType().isInteger(8) &&
      rhs_type.getElementType().isInteger(8)) {
    return rewriter.notifyMatchFailure(op,
                                       "skip unrolling for int8 BatchMatMulV3");
  }

  auto element_type = lhs_type.getElementType();

  if (element_type != rhs_type.getElementType()) {
    // The element type of the LHS must be the same as that of the RHS.
    return failure();
  }

  std::vector<int64_t> lhs_shape = lhs_type.getShape();
  std::vector<int64_t> rhs_shape = rhs_type.getShape();

  Location loc = op.getLoc();

  // Ensure that input ranks are at least 2.
  const int lhs_dims = lhs_shape.size();
  const int rhs_dims = rhs_shape.size();
  if (lhs_dims < 2 || rhs_dims < 2) {
    // Both inputs must have rank >= 2.
    return failure();
  }

  // Swap the last two dimensions of LHS and RHS if necessary. The actual
  // transpose is done by MatMulOp.
  if (op.adj_x()) {
    std::swap(lhs_shape[lhs_dims - 1], lhs_shape[lhs_dims - 2]);
  }
  if (op.adj_y()) {
    std::swap(rhs_shape[rhs_dims - 1], rhs_shape[rhs_dims - 2]);
  }

  const int rows = lhs_shape[lhs_dims - 2];
  const int cols = rhs_shape[rhs_dims - 1];

  if (lhs_shape[lhs_dims - 1] != rhs_shape[rhs_dims - 2]) {
    // Input dimensions must be compatible for multiplication.
    return failure();
  }

  const auto matmul_type = RankedTensorType::get({rows, cols}, element_type);

  if (lhs_dims == 2 && rhs_dims == 2) {
    // When both inputs are matrices, just replace the op with a matmul op.
    rewriter.replaceOpWithNewOp<TF::MatMulOp>(op, matmul_type,
                                              /*a=*/input_lhs,
                                              /*b=*/input_rhs,
                                              /*transpose_a=*/op.adj_x(),
                                              /*transpose_b=*/op.adj_y());
    return success();
  }

  // Input dimensions must be defined. MatMulBCast does not support partial
  // shapes.
  for (auto dim : lhs_shape) {
    if (dim == -1) {
      return failure();
    }
  }
  for (auto dim : rhs_shape) {
    if (dim == -1) {
      return failure();
    }
  }
  // Ensure that batch shapes are broadcastable.
  tensorflow::MatMulBCast bcast(
      absl::InlinedVector<int64_t, 4>(lhs_shape.begin(), lhs_shape.end()),
      absl::InlinedVector<int64_t, 4>(rhs_shape.begin(), rhs_shape.end()));

  if (!bcast.IsValid()) {
    // Input batch dimensions must be broadcastable.
    return failure();
  }
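  // From this point on, the batch dimensions are known to broadcast. For
  // example (illustrative shapes), batch shapes [2, 1] and [1, 3] broadcast
  // to an output batch shape of [2, 3], i.e. an output batch size of 6.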

  // Compute slices for each batch in the LHS and RHS.
  std::vector<Value> sliced_lhs =
      sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter);
  std::vector<Value> sliced_rhs =
      sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter);

  // Compute (single batch) MatMul for each output batch.
  std::vector<Value> matmuls;
  matmuls.reserve(bcast.output_batch_size());
  for (int batch_idx : llvm::seq<int>(0, bcast.output_batch_size())) {
    int lhs_batch_idx, rhs_batch_idx;
    if (bcast.IsBroadcastingRequired()) {
      lhs_batch_idx = bcast.x_batch_indices()[batch_idx];
      rhs_batch_idx = bcast.y_batch_indices()[batch_idx];
    } else {
      lhs_batch_idx = batch_idx;
      rhs_batch_idx = batch_idx;
    }
    auto matmul = rewriter.create<TF::MatMulOp>(loc, matmul_type,
                                                /*a=*/sliced_lhs[lhs_batch_idx],
                                                /*b=*/sliced_rhs[rhs_batch_idx],
                                                /*transpose_a=*/op.adj_x(),
                                                /*transpose_b=*/op.adj_y());
    matmuls.emplace_back(matmul.product());
  }

  // Combine the result of each individual MatMul into a rank-3 tensor.
  Type packed_type = RankedTensorType::get(
      {bcast.output_batch_size(), rows, cols}, element_type);
  const auto axis = rewriter.getI64IntegerAttr(0);
  auto pack_op =
      rewriter.create<TF::PackOp>(loc, packed_type, /*values=*/matmuls, axis);

  // Reshape the rank-3 tensor into the correct output shape.
  const auto& result_batch_shape = bcast.output_batch_shape().dim_sizes();
  std::vector<int64_t> result_shape(result_batch_shape.begin(),
                                    result_batch_shape.end());
  result_shape.push_back(rows);
  result_shape.push_back(cols);

  auto reshape_op = createReshapeOp(pack_op.output(), result_shape,
                                    element_type, loc, rewriter);
  rewriter.replaceOp(op, reshape_op.output());
  return success();
}

std::unique_ptr<OperationPass<func::FuncOp>> CreateUnrollBatchMatMulPassPass() {
  return std::make_unique<UnrollBatchMatMulPass>();
}

}  // namespace TF
}  // namespace mlir

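// Registers the unrolling pattern for all three BatchMatMul op variants.
// Typical usage mirrors runOnOperation() above (a minimal sketch):
//   RewritePatternSet patterns(context);
//   mlir::TF::PopulateUnrollTfBatchMatMul(context, patterns);
//   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));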
void mlir::TF::PopulateUnrollTfBatchMatMul(MLIRContext* context,
                                           RewritePatternSet& patterns) {
  patterns.add<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
               ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>,
               ConvertTFBatchMatMulOp<TF::BatchMatMulV3Op>>(context);
}