/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_passes.h"

#include <functional>
#include <memory>
#include <string>

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/Optional.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/Dialect/Shape/Transforms/Passes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_clustering.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"

namespace tensorflow {

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_passes.h.inc"

// -------------------------------------------------------------------------- //
// Helper functions used by the passes implemented below.
// -------------------------------------------------------------------------- //

// Returns true if the `value` type is a memref that is contiguous in memory.
static bool IsContiguousMemref(mlir::Value value) {
  mlir::MemRefType memref_type = value.getType().dyn_cast<mlir::MemRefType>();
  if (!memref_type) return false;
  mlir::MemRefType canonical_type = canonicalizeStridedLayout(memref_type);
  return canonical_type.getAffineMaps().empty();
}
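
// A couple of illustrative cases for the helper above (types chosen for the
// example only): an identity-layout memref such as memref<4x8xf32> or
// memref<?x8xf32> is contiguous, while a memref with a non-trivial layout,
// e.g. memref<4x8xf32, affine_map<(d0, d1) -> (d0 * 16 + d1)>>, canonicalizes
// to a non-empty affine map and is rejected.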

// -------------------------------------------------------------------------- //
// Trivial buffer forwarding for the linalg.generic operations.
// -------------------------------------------------------------------------- //
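//
// Rough sketch of the rewrite performed below (IR is illustrative, not taken
// from a real test case):
//
//   %in  = memref.alloc(%d0) : memref<?xf32>
//   %out = memref.alloc(%d0) : memref<?xf32>
//   linalg.generic ... ins(%in : memref<?xf32>) outs(%out : memref<?xf32>)
//   memref.dealloc %in : memref<?xf32>
//   "use"(%out)
//
// becomes
//
//   %in  = memref.alloc(%d0) : memref<?xf32>
//   %out = memref.alloc(%d0) : memref<?xf32>
//   linalg.generic ... ins(%in : memref<?xf32>) outs(%in : memref<?xf32>)
//   memref.dealloc %out : memref<?xf32>
//   "use"(%in)
//
// i.e. the input buffer is reused as the output buffer, and the now-redundant
// output allocation is released where the input used to be deallocated.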

namespace {

struct ForwardingCandidate {
  mlir::Value memref;
  mlir::AffineMap indexing_map;
};

struct LinalgTrivialBufferForwardingPattern
    : public mlir::OpRewritePattern<mlir::linalg::GenericOp> {
  using OpRewritePattern<mlir::linalg::GenericOp>::OpRewritePattern;
  mlir::LogicalResult matchAndRewrite(
      mlir::linalg::GenericOp op,
      mlir::PatternRewriter& rewriter) const override;
};

struct LinalgTrivialBufferForwardingPass
    : public LinalgTrivialBufferForwardingBase<
          LinalgTrivialBufferForwardingPass> {
  void runOnFunction() override;
};
}  // namespace
// Returns true if all linalg.generic operation iterators are "parallel".
static bool AllIteratorsAreParallel(mlir::linalg::GenericOp op) {
  return llvm::all_of(op.iterator_types(), [](mlir::Attribute attr) -> bool {
    auto str_attr = attr.dyn_cast<mlir::StringAttr>();
    return str_attr && str_attr.getValue() == "parallel";
  });
}
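
// For instance (attribute values are illustrative): an elementwise generic
// with iterator_types = ["parallel", "parallel"] passes this check, while a
// reduction with iterator_types = ["parallel", "reduction"] does not.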

// Returns buffer inputs that can be safely used as buffer outputs.
static llvm::SmallVector<mlir::OpOperand*> FindBufferForwardingCandidates(
    mlir::linalg::GenericOp op) {
  llvm::SmallVector<mlir::OpOperand*> candidates;

  for (mlir::OpOperand* input_buffer : op.getInputOperands()) {
    // Input must be a contiguous memref ...
    if (!IsContiguousMemref(input_buffer->get())) continue;

    // ... allocated in the same function.
    auto* alloc = input_buffer->get().getDefiningOp();
    if (!alloc || !mlir::isa<mlir::memref::AllocOp>(alloc)) continue;

    // Find input users that come after the linalg.generic operation in the
    // block.
    auto users = llvm::make_filter_range(alloc->getUsers(),
                                         [&](mlir::Operation* user) -> bool {
                                           return op->isBeforeInBlock(user);
                                         });

    // The input buffer must have exactly one user after the linalg.generic.
    llvm::SmallVector<mlir::Operation*> input_users(users.begin(), users.end());
    if (input_users.size() != 1) continue;

    // And it must be a memref.dealloc operation.
    if (!mlir::isa<mlir::memref::DeallocOp>(input_users[0])) continue;

    // This input memref can be safely reused for the output.
    candidates.push_back(input_buffer);
  }

  return candidates;
}

mlir::LogicalResult LinalgTrivialBufferForwardingPattern::matchAndRewrite(
    mlir::linalg::GenericOp op, mlir::PatternRewriter& rewriter) const {
  // With parallel iterators it is guaranteed that every value is used once,
  // so it is safe to forward an input to the output.
  if (!AllIteratorsAreParallel(op))
    return rewriter.notifyMatchFailure(op, "all iterators must be parallel");

  // Find memrefs that could potentially be forwarded.
  llvm::SmallVector<mlir::OpOperand*> forwarding_candidates =
      FindBufferForwardingCandidates(op);
  if (forwarding_candidates.empty())
    return rewriter.notifyMatchFailure(
        op, "did not find any candidates for input forwarding");

  // Inputs that were already reused.
  llvm::DenseSet<mlir::OpOperand*> reused_inputs;

  // Try to match output buffers to forwarding candidates.
  for (mlir::OpOperand* output_buffer : op.getOutputOperands()) {
    // Output must be allocated in the same function.
    auto* alloc = output_buffer->get().getDefiningOp();
    if (!alloc || !mlir::isa<mlir::memref::AllocOp>(alloc)) continue;

    // Find a compatible input buffer.
    for (mlir::OpOperand* input_buffer : forwarding_candidates) {
      if (reused_inputs.contains(input_buffer)) continue;

      // Memref types must match (dimensions and affine maps).
      if (input_buffer->get().getType() != output_buffer->get().getType())
        continue;

      mlir::AffineMap src_map = op.getTiedIndexingMap(input_buffer);
      mlir::AffineMap dst_map = op.getTiedIndexingMap(output_buffer);

      // Only identity maps are supported for the output for now.
      if (!dst_map.isIdentity()) continue;

      auto is_projection = [](mlir::AffineMap map) {
        // Allow adding/dropping dimensions but no permutations.
        int64_t i = -1;
        for (mlir::AffineExpr expr : map.getResults()) {
          auto constant = expr.dyn_cast<mlir::AffineConstantExpr>();
          if (constant && constant.getValue() == 0) continue;
          auto dim_expr = expr.dyn_cast<mlir::AffineDimExpr>();
          if (!dim_expr || dim_expr.getPosition() <= i) return false;
          i = dim_expr.getPosition();
        }
        return true;
      };
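      // For example (maps are illustrative): (d0, d1, d2) -> (d0, d2) and
      // (d0, d1) -> (0, d1) are accepted as projections, while the transpose
      // (d0, d1) -> (d1, d0) is rejected because the dimension positions are
      // not strictly increasing.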

      auto same_shape = [](mlir::Value src, mlir::Value dst) {
        auto src_type = src.getType().cast<mlir::ShapedType>();
        auto dst_type = dst.getType().cast<mlir::ShapedType>();
        mlir::OperandRange src_operands =
            src.getDefiningOp<mlir::memref::AllocOp>().getDynamicSizes();
        mlir::OperandRange dst_operands =
            dst.getDefiningOp<mlir::memref::AllocOp>().getDynamicSizes();
        return src_type.getShape().equals(dst_type.getShape()) &&
               std::equal(src_operands.begin(), src_operands.end(),
                          dst_operands.begin());
      };

      // A reuse is valid if the maps are the same, or if the shapes are the
      // same and the source map is a projection (in which case the ignored
      // dimensions must be 1, assuming that the operation reads the entire
      // input). Note that we already know that the destination map is an
      // identity map.
      if (src_map != dst_map &&
          !(is_projection(src_map) &&
            same_shape(input_buffer->get(), output_buffer->get()))) {
        continue;
      }

      // Find the input buffer dealloc operation.
      mlir::Operation* input_dealloc = *llvm::find_if(
          input_buffer->get().getUsers(), [](mlir::Operation* user) -> bool {
            return mlir::isa<mlir::memref::DeallocOp>(user);
          });

      // Deallocate the output buffer instead of the input buffer.
      input_buffer->get().replaceUsesWithIf(
          output_buffer->get(), [&](mlir::OpOperand& operand) -> bool {
            return operand.getOwner() == input_dealloc;
          });

      // Forward users of the output buffer to the input buffer if they come
      // after the linalg.generic operation in the block (or are the
      // linalg.generic itself).
      output_buffer->get().replaceUsesWithIf(
          input_buffer->get(), [&](mlir::OpOperand& operand) -> bool {
            return operand.getOwner() != input_dealloc &&
                   !operand.getOwner()->isBeforeInBlock(op);
          });

      reused_inputs.insert(input_buffer);
    }
  }

  return mlir::success(!reused_inputs.empty());
}

void LinalgTrivialBufferForwardingPass::runOnFunction() {
  mlir::FuncOp function = getFunction();
  mlir::MLIRContext* ctx = function.getContext();

  mlir::RewritePatternSet patterns(ctx);
  patterns.insert<LinalgTrivialBufferForwardingPattern>(ctx);

  (void)mlir::applyPatternsAndFoldGreedily(function, std::move(patterns));
}

std::unique_ptr<mlir::FunctionPass> CreateLinalgTrivialBufferForwardingPass() {
  return std::make_unique<LinalgTrivialBufferForwardingPass>();
}

// -------------------------------------------------------------------------- //
// Trivial copy removal for the linalg.copy operations.
// -------------------------------------------------------------------------- //

namespace {

struct LinalgTrivialCopyRemovalPass
    : public LinalgTrivialCopyRemovalBase<LinalgTrivialCopyRemovalPass> {
  void runOnFunction() override;
};

}  // namespace
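
// Rough sketch of the rewrite performed below (IR is illustrative):
//
//   %0 = memref.alloc() : memref<16xf32>
//   linalg.copy(%src, %0) : memref<16xf32>, memref<16xf32>
//   memref.dealloc %src : memref<16xf32>
//   "use"(%0)
//
// becomes
//
//   "use"(%src)
//
// i.e. when a freshly allocated buffer is initialized with a copy of a buffer
// that is immediately deallocated, the copy is dropped and the source buffer
// is used directly.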

void LinalgTrivialCopyRemovalPass::runOnFunction() {
  mlir::FuncOp function = getFunction();

  mlir::SmallVector<mlir::Operation*> to_erase;
  function.walk([&to_erase](mlir::linalg::CopyOp copy) {
    // Only match precise alloc/copy/dealloc triples. The copy can be the
    // first or last operation in the block, so guard against null neighbors.
    auto alloc =
        llvm::dyn_cast_or_null<mlir::memref::AllocOp>(copy->getPrevNode());
    auto dealloc =
        llvm::dyn_cast_or_null<mlir::memref::DeallocOp>(copy->getNextNode());

    if (!alloc || !dealloc) return;

    // Make sure the alloc and dealloc handle the operands of the copy.
    if (alloc.getResult() != copy.getTarget() ||
        dealloc.memref() != copy.getSource()) {
      return;
    }

    // Remember the operations to delete.
    to_erase.push_back(alloc);
    to_erase.push_back(dealloc);
    to_erase.push_back(copy);
    copy.getTarget().replaceAllUsesWith(copy.getSource());
  });

  for (auto op : to_erase) {
    op->erase();
  }
}

std::unique_ptr<mlir::FunctionPass> CreateLinalgTrivialCopyRemovalPass() {
  return std::make_unique<LinalgTrivialCopyRemovalPass>();
}

// -------------------------------------------------------------------------- //
// Dispatch linalg.matmul to one of the more specialized operations at runtime.
// -------------------------------------------------------------------------- //

namespace {

struct LinalgMatmulSpecializationPattern
    : public mlir::OpRewritePattern<mlir::linalg::MatmulOp> {
  using OpRewritePattern<mlir::linalg::MatmulOp>::OpRewritePattern;
  mlir::LogicalResult matchAndRewrite(
      mlir::linalg::MatmulOp matmul,
      mlir::PatternRewriter& rewriter) const override;
};

struct LinalgMatmulSpecializationPass
    : public LinalgMatmulSpecializationBase<LinalgMatmulSpecializationPass> {
  void runOnFunction() override;
};
}  // namespace
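
// The pattern declared above rewrites an [m, k] x [k, n] linalg.matmul into a
// runtime dispatch over the dynamic dimension sizes (sketch only; the actual
// lowering below uses nested scf.if operations):
//
//   if (m == 1 && n == 1)  linalg.dot     // [1, k] x [k, 1]
//   else if (m == 1)       linalg.vecmat  // [1, k] x [k, n]
//   else if (n == 1)       linalg.matvec  // [m, k] x [k, 1]
//   else                   linalg.matmul  // marked __tf_cpurt_specialized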

// Converts a 2D memref into a 1D memref (vector).
static mlir::Value MemrefToVector(mlir::OpBuilder& builder, mlir::Location loc,
                                  mlir::Value memref, mlir::Value size,
                                  int64_t static_size) {
  assert(static_size >= 0 || static_size == mlir::ShapedType::kDynamicSize);
  auto memref_type = memref.getType().cast<mlir::MemRefType>();
  auto vec_type =
      mlir::MemRefType::get({static_size}, memref_type.getElementType());

  auto static_offsets = builder.getI64ArrayAttr({0});
  auto static_sizes = builder.getI64ArrayAttr({static_size});
  auto static_strides = builder.getI64ArrayAttr({1});

  auto empty = mlir::ValueRange();
  auto sizes = static_size == mlir::ShapedType::kDynamicSize
                   ? mlir::ValueRange(size)
                   : mlir::ValueRange();

  return builder.create<mlir::memref::ReinterpretCastOp>(
      loc, vec_type, memref, /*offsets=*/empty,
      /*sizes=*/sizes, /*strides=*/empty, static_offsets, static_sizes,
      static_strides);
}
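
// For instance (shapes are illustrative): a memref<1x4xf32> operand is
// reinterpreted as memref<4xf32>; when the size is dynamic (memref<1x?xf32>),
// the result is memref<?xf32> and the size is passed as an SSA value.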

// Converts a 2D memref into a 0D memref (scalar).
static mlir::Value MemrefToScalar(mlir::OpBuilder& builder, mlir::Location loc,
                                  mlir::Value memref) {
  auto memref_type = memref.getType().cast<mlir::MemRefType>();
  auto scalar_type = mlir::MemRefType::get({}, memref_type.getElementType());

  std::array<int64_t, 0> empty;
  return builder.create<mlir::memref::ReinterpretCastOp>(
      loc, scalar_type, memref, /*offset=*/0,
      /*sizes=*/empty, /*strides=*/empty);
}
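
// For instance (shape is illustrative): the memref<1x1xf32> output of a
// dot-product-sized matmul is reinterpreted as memref<f32>.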

mlir::LogicalResult LinalgMatmulSpecializationPattern::matchAndRewrite(
    mlir::linalg::MatmulOp matmul, mlir::PatternRewriter& rewriter) const {
  if (matmul->hasAttr("__tf_cpurt_specialized"))
    return rewriter.notifyMatchFailure(matmul,
                                       "operation was already specialized");

  auto rhs = matmul.getInputOperand(1)->get();
  auto lhs = matmul.getInputOperand(0)->get();
  auto out = matmul.getOutputOperand(0)->get();

  // We do not support inputs or outputs that are not contiguous in memory.
  if (!IsContiguousMemref(lhs) || !IsContiguousMemref(rhs) ||
      !IsContiguousMemref(out))
    return rewriter.notifyMatchFailure(
        matmul, "inputs and output must be contiguous memrefs");

  auto loc = matmul.getLoc();

  // Matmul dimensions: [m, k] x [k, n]
  mlir::Value m = rewriter.create<mlir::memref::DimOp>(loc, lhs, 0);
  mlir::Value k = rewriter.create<mlir::memref::DimOp>(loc, lhs, 1);
  mlir::Value n = rewriter.create<mlir::memref::DimOp>(loc, rhs, 1);

  // Matmul dimensions that are statically known (a dimension can be
  // ShapedType::kDynamicSize if it is not known at compile time).
  int64_t m_static = lhs.getType().cast<mlir::MemRefType>().getDimSize(0);
  int64_t k_static = lhs.getType().cast<mlir::MemRefType>().getDimSize(1);
  int64_t n_static = rhs.getType().cast<mlir::MemRefType>().getDimSize(1);

  auto one = rewriter.create<mlir::ConstantOp>(loc, rewriter.getIndexType(),
                                               rewriter.getIndexAttr(1));
  auto m_is_one =
      rewriter.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::eq, m, one);
  auto n_is_one =
      rewriter.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::eq, n, one);

  auto m_not_one =
      rewriter.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::ne, m, one);
  auto n_not_one =
      rewriter.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::ne, n, one);

  // linalg.dot: m == 1 && n == 1
  auto is_dot_product = rewriter.create<mlir::AndOp>(loc, m_is_one, n_is_one);
  // linalg.vecmat: m == 1 && n != 1
  auto is_vecmat = rewriter.create<mlir::AndOp>(loc, m_is_one, n_not_one);
  // linalg.matvec: n == 1 && m != 1
  auto is_matvec = rewriter.create<mlir::AndOp>(loc, n_is_one, m_not_one);

  // Build a linalg.dot operation casting inputs to vectors.
  auto dot = [&](mlir::OpBuilder& builder, mlir::Location nestedLoc) {
    auto lhs_vec = MemrefToVector(builder, nestedLoc, lhs, k, k_static);
    auto rhs_vec = MemrefToVector(builder, nestedLoc, rhs, k, k_static);
    auto out_scalar = MemrefToScalar(builder, nestedLoc, out);

    builder.create<mlir::linalg::DotOp>(nestedLoc,
                                        mlir::ValueRange({lhs_vec, rhs_vec}),
                                        mlir::ValueRange({out_scalar}));
    builder.create<mlir::scf::YieldOp>(nestedLoc);
  };

  // Build a linalg.vecmat operation casting lhs to a vector.
  auto vecmat = [&](mlir::OpBuilder& builder, mlir::Location nestedLoc) {
    auto lhs_vec = MemrefToVector(builder, nestedLoc, lhs, k, k_static);
    auto out_vec = MemrefToVector(builder, nestedLoc, out, n, n_static);

    builder.create<mlir::linalg::VecmatOp>(nestedLoc,
                                           mlir::ValueRange({lhs_vec, rhs}),
                                           mlir::ValueRange({out_vec}));
    builder.create<mlir::scf::YieldOp>(nestedLoc);
  };

  // Build a linalg.matvec operation casting rhs to a vector.
  auto matvec = [&](mlir::OpBuilder& builder, mlir::Location nestedLoc) {
    auto rhs_vec = MemrefToVector(builder, nestedLoc, rhs, k, k_static);
    auto out_vec = MemrefToVector(builder, nestedLoc, out, m, m_static);

    builder.create<mlir::linalg::MatvecOp>(nestedLoc,
                                           mlir::ValueRange({lhs, rhs_vec}),
                                           mlir::ValueRange({out_vec}));
    builder.create<mlir::scf::YieldOp>(nestedLoc);
  };

  // Build a generic linalg.matmul operation when the matmul can't be matched
  // to any of the specializations.
  auto generic = [&](mlir::OpBuilder& builder, mlir::Location nestedLoc) {
    llvm::SmallVector<mlir::Value> inputs = matmul.getInputOperands();
    llvm::SmallVector<mlir::Value> outputs = matmul.getOutputOperands();
    auto specialized =
        builder.create<mlir::linalg::MatmulOp>(nestedLoc, inputs, outputs);
    specialized->setAttr("__tf_cpurt_specialized", rewriter.getUnitAttr());
    builder.create<mlir::scf::YieldOp>(nestedLoc);
  };

  // TODO(ezhulenev): Simplify to scf.switch operation.
  // if (is_dot_product) ===>>> linalg.dot ------------------------------- //
  auto dispatch = rewriter.create<mlir::scf::IfOp>(
      loc, is_dot_product, dot,
      [&](mlir::OpBuilder& builder, mlir::Location nestedLoc) {
        // else if (is_vecmat) ===>>> linalg.vecmat --------------------- //
        rewriter.create<mlir::scf::IfOp>(
            nestedLoc, is_vecmat, vecmat,
            [&](mlir::OpBuilder& builder, mlir::Location nestedLoc) {
              // else if (is_matvec) ===>>> linalg.matvec --------------- //
              // else                ===>>> linalg.matmul --------------- //
              rewriter.create<mlir::scf::IfOp>(nestedLoc, is_matvec, matvec,
                                               generic);
              builder.create<mlir::scf::YieldOp>(nestedLoc);
            });
        builder.create<mlir::scf::YieldOp>(nestedLoc);
      });

  rewriter.replaceOp(matmul, dispatch.results());
  return mlir::success();
}

void LinalgMatmulSpecializationPass::runOnFunction() {
  mlir::FuncOp function = getFunction();
  mlir::MLIRContext* ctx = function.getContext();

  mlir::RewritePatternSet patterns(ctx);
  patterns.insert<LinalgMatmulSpecializationPattern>(ctx);

  (void)mlir::applyPatternsAndFoldGreedily(function, std::move(patterns));
}

std::unique_ptr<mlir::FunctionPass> CreateLinalgMatmulSpecializationPass() {
  return std::make_unique<LinalgMatmulSpecializationPass>();
}

// -------------------------------------------------------------------------- //
// Break Tensorflow _Fused{Op} operations into primitive ones.
// -------------------------------------------------------------------------- //

namespace {
struct FissionPass : public FissionBase<FissionPass> {
  void runOnFunction() override;
};

struct FusedMatMulFission
    : public mlir::OpRewritePattern<mlir::TF::_FusedMatMulOp> {
  using OpRewritePattern<mlir::TF::_FusedMatMulOp>::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::TF::_FusedMatMulOp op,
      mlir::PatternRewriter& rewriter) const override;
};
}  // namespace
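
// Rough sketch of the fission performed below (IR and attributes are
// illustrative, other attributes omitted):
//
//   %0 = "tf._FusedMatMul"(%a, %b, %bias) {fused_ops = ["BiasAdd", "Relu"]}
//
// becomes
//
//   %0 = "tf.MatMul"(%a, %b)
//   %1 = "tf.BiasAdd"(%0, %bias)
//   %2 = "tf.Relu"(%1)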

mlir::LogicalResult FusedMatMulFission::matchAndRewrite(
    mlir::TF::_FusedMatMulOp op, mlir::PatternRewriter& rewriter) const {
  auto loc = op.getLoc();
  auto type = op.getResult().getType();

  size_t n = op.fused_ops().size();

  // Extract the fused operations from the operation attributes.
  mlir::StringAttr fusion0 =
      n > 0 ? op.fused_ops()[0].dyn_cast<mlir::StringAttr>() : nullptr;
  mlir::StringAttr fusion1 =
      n > 1 ? op.fused_ops()[1].dyn_cast<mlir::StringAttr>() : nullptr;

  // Match against the supported fusions.
  bool is_bias_add = fusion0 && fusion0.getValue() == "BiasAdd";
  bool is_relu_activation = fusion1 && fusion1.getValue() == "Relu";

  // Create a simple MatMul operation from the fused one.
  auto matmul = [&]() -> mlir::TF::MatMulOp {
    auto lhs = op.getOperand(0);
    auto rhs = op.getOperand(1);
    return rewriter.create<mlir::TF::MatMulOp>(
        loc, type, lhs, rhs, op.transpose_a(), op.transpose_b());
  };

  // FusedMatMul[BiasAdd].
  if (n == 1 && is_bias_add) {
    rewriter.replaceOpWithNewOp<mlir::TF::BiasAddOp>(op, type, matmul(),
                                                     op.getOperand(2));
    return mlir::success();
  }

  // FusedMatMul[BiasAdd, Relu].
  if (n == 2 && is_bias_add && is_relu_activation) {
    auto biased = rewriter.create<mlir::TF::BiasAddOp>(loc, type, matmul(),
                                                       op.getOperand(2));
    rewriter.replaceOpWithNewOp<mlir::TF::ReluOp>(op, type, biased);
    return mlir::success();
  }

  return mlir::failure();
}

void FissionPass::runOnFunction() {
  mlir::FuncOp function = getFunction();
  mlir::MLIRContext* ctx = function.getContext();

  mlir::RewritePatternSet patterns(ctx);
  patterns.insert<FusedMatMulFission>(ctx);

  (void)mlir::applyPatternsAndFoldGreedily(function, std::move(patterns));
}

std::unique_ptr<mlir::FunctionPass> CreateFissionPass() {
  return std::make_unique<FissionPass>();
}

// -------------------------------------------------------------------------- //
// Custom passes that are missing upstream.
// -------------------------------------------------------------------------- //

namespace {
// TODO(herhut): Remove this once leftover tensor_to_memref ops are handled in
// core.
struct RemoveUnusedBufferCastOperations
    : public mlir::PassWrapper<RemoveUnusedBufferCastOperations,
                               mlir::FunctionPass> {
  void runOnFunction() override;
};

// Adds a Tensorflow producer version to the module to enable shape inference.
struct AddTensorflowProducerVersion
    : public mlir::PassWrapper<AddTensorflowProducerVersion,
                               mlir::OperationPass<mlir::ModuleOp>> {
  void runOnOperation() override;
};

// Use Linalg CodegenStrategy to tile linalg.matmul, linalg.matvec and
// linalg.vecmat operations.
struct CodegenStrategyForMatMulPass
    : public mlir::PassWrapper<CodegenStrategyForMatMulPass,
                               mlir::FunctionPass> {
  void runOnFunction() override;
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::vector::VectorDialect>();
  }
};
}  // namespace

void RemoveUnusedBufferCastOperations::runOnFunction() {
  getFunction().walk([](mlir::memref::BufferCastOp op) {
    // Drop all buffer_cast operations that have no users. Currently this will
    // not happen on its own, as tensor_to_memref has a side effect. See
    // https://reviews.llvm.org/D91967 for a discussion.
    if (op.memref().getUsers().empty()) {
      op.erase();
    }
  });
}

void CodegenStrategyForMatMulPass::runOnFunction() {
  // Promote tiles to full buffers allocated on the stack.
  mlir::linalg::LinalgPromotionOptions full_alloca_promotion;
  full_alloca_promotion.setUseFullTileBuffersByDefault(true).setUseAlloca(true);

  // Vector contraction options.
  mlir::vector::VectorTransformsOptions vector_transforms_ops;
  vector_transforms_ops.setVectorTransformsOptions(
      mlir::vector::VectorContractLowering::OuterProduct);

  // Vector transfer options.
  mlir::VectorTransferToSCFOptions vector_transfer_opts;
  vector_transfer_opts.setUnroll(true);

  // TODO(ezhulenev): Set up tiling options depending on the target machine.

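  // A tile-size vector {tm, tn, tk} tiles the (m, n, k) loops of the matmul,
  // so every tile computes a tm x tn block of the output from promoted
  // tm x tk and tk x tn operand blocks kept in stack-allocated buffers. The
  // fixed sizes below are a generic default; see the TODO above about making
  // them target dependent.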
  // Tile and vectorize linalg.matmul operations.
  mlir::linalg::LinalgTilingOptions matmul_tiling;
  matmul_tiling.setTileSizes({12, 32, 16});

  mlir::linalg::CodegenStrategy matmul_strategy;
  matmul_strategy.tile<mlir::linalg::MatmulOp>(matmul_tiling)
      .promote<mlir::linalg::MatmulOp>(full_alloca_promotion)
      .vectorize<mlir::linalg::MatmulOp>()
      .setVectorTransformsOptions(vector_transforms_ops)
      .setVectorTransferToSCFOptions(vector_transfer_opts);
  matmul_strategy.transform(getFunction());

  // Tile and vectorize linalg.vecmat operations. Interchange loop order to
  // linearly read from the matrix memref.
  mlir::linalg::LinalgTilingOptions vecmat_tiling;
  vecmat_tiling.setTileSizes({16, 8}).setInterchange({1, 0});

  mlir::linalg::CodegenStrategy vecmat_strategy;
  vecmat_strategy.tile<mlir::linalg::VecmatOp>(vecmat_tiling)
      .promote<mlir::linalg::VecmatOp>(full_alloca_promotion)
      .vectorize<mlir::linalg::VecmatOp>()
      .setVectorTransformsOptions(vector_transforms_ops)
      .setVectorTransferToSCFOptions(vector_transfer_opts);
  vecmat_strategy.transform(getFunction());
}

static std::unique_ptr<mlir::FunctionPass>
CreateCodegenStrategyForMatMulPass() {
  return std::make_unique<CodegenStrategyForMatMulPass>();
}

void AddTensorflowProducerVersion::runOnOperation() {
  mlir::ModuleOp module = getOperation();

  // The Tensorflow producer version does not really impact anything during
  // shape inference. Set it to `0` (any value would do) to pass the attribute
  // checks.
  mlir::Builder builder(module);
  auto version =
      builder.getNamedAttr("producer", builder.getI32IntegerAttr(0));
  module->setAttr("tf.versions", builder.getDictionaryAttr({version}));
}
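
// After this pass the module carries (attribute shown for illustration):
//   module attributes {tf.versions = {producer = 0 : i32}} { ... }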

// -------------------------------------------------------------------------- //
// Cluster operations based on the TF CPURT clustering policy.
// -------------------------------------------------------------------------- //

namespace {
using llvm::ArrayRef;

using mlir::TFDevice::Cluster;
using mlir::TFDevice::ClusteringPolicySet;
using mlir::TFDevice::CreateClusterOp;
using mlir::TFDevice::FindClustersInTheBlock;

struct ClusteringPass : public ClusteringBase<ClusteringPass> {
  ClusteringPass() = default;
  ClusteringPass(ArrayRef<std::string> cluster_oplist, int cluster_min_size) {
    oplist = cluster_oplist;
    min_cluster_size = cluster_min_size;
  }

  void runOnFunction() override;
};
}  // anonymous namespace
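
// The `oplist` option accepts either a clustering tier ("tier1" or "all") or
// an explicit list of operation names; anything that is not a tier name is
// treated as an operation filter, e.g. (values are illustrative):
//   oplist = {"tier1"}            -> cluster all tier-1 operations
//   oplist = {"tf.Add", "tf.Sub"} -> cluster only the listed operations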

void ClusteringPass::runOnFunction() {
  ClusteringPolicySet policies;

  // Parse the clustering tier and the operations filter from the oplist.
  llvm::DenseSet<llvm::StringRef> opset;
  llvm::Optional<CpurtClusteringTier> tier;

  for (const auto& op : oplist) {
    if (op == "tier1") {
      tier = CpurtClusteringTier::kTier1;
    } else if (op == "all") {
      tier = CpurtClusteringTier::kAll;
    } else {
      opset.insert(op);
    }
  }

  // Run clustering only if the clustering tier or the supported operations are
  // explicitly defined by the oplist.
  if (!tier.hasValue() && opset.empty()) return;

  // If the clustering tier is not defined, the opset will later filter the
  // supported operations, so it is ok to use the `all` tier.
  populateTfCpurtClusteringPolicies(policies,
                                    tier.getValueOr(CpurtClusteringTier::kAll));

  // If the opset is not empty, restrict the operations that are enabled for
  // clustering.
  auto filter = [&](mlir::Operation* op) -> bool {
    return opset.empty() || opset.contains(op->getName().getStringRef());
  };

  // Annotate all formed clusters with an attribute.
  auto policy = mlir::StringAttr::get(&getContext(), "tfrt.auto-fusion");

  getFunction().walk([&](mlir::Block* block) {
    for (Cluster& cluster : FindClustersInTheBlock(block, policies, filter)) {
      // Do not create clusters that are too small.
      if (cluster.operations.size() < min_cluster_size) continue;
      // Verify that the JIT runtime can compile the cluster.
      if (failed(VerifyCluster(cluster))) continue;

      CreateClusterOp(cluster, policy);
    }
  });
}

std::unique_ptr<mlir::FunctionPass> CreateTfCpurtClusteringPass() {
  return std::make_unique<ClusteringPass>();
}

std::unique_ptr<mlir::FunctionPass> CreateTfCpurtClusteringPass(
    llvm::ArrayRef<std::string> oplist, int min_cluster_size) {
  return std::make_unique<ClusteringPass>(oplist, min_cluster_size);
}

// -------------------------------------------------------------------------- //
// Assemble a TF-CPURT pipeline to lower from Tensorflow dialects to Linalg on
// buffers via progressive lowering to MHLO and Linalg.
// -------------------------------------------------------------------------- //

void CreateTfCpuRtPipeline(mlir::OpPassManager& pm) {
  // Break Tensorflow fused operations into primitive operations before
  // lowering to HLO.
  pm.addNestedPass<mlir::FuncOp>(CreateFissionPass());

  // Run shape inference to propagate potentially specialized input shapes.
  pm.addPass(std::make_unique<AddTensorflowProducerVersion>());
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());

  // Transform TF operations to HLO.
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass());

  // Move up broadcasting operations to allow for more fusion opportunities.
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createBroadcastPropagationPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Transform HLO operations to Linalg and fuse them.
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeHloToLinalgPass());

  // Lower index cast on tensors to tensor.generate.
  pm.addNestedPass<mlir::FuncOp>(
      mlir::kernel_gen::transforms::CreateLowerIndexCastPass());

  // Lower the shape dialect to standard to enable linalg canonicalizations
  // (e.g. use linalg inputs instead of outputs for memref.dim operations).
  pm.addNestedPass<mlir::FuncOp>(
      mlir::kernel_gen::transforms::CreateShapeSimplification());
  pm.addNestedPass<mlir::FuncOp>(mlir::createShapeToShapeLowering());
  pm.addPass(mlir::createConvertShapeToStandardPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createConvertShapeConstraintsPass());

  // Fuse Linalg on tensors operations.
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::memref::createResolveShapedTypeResultDimsPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createLinalgElementwiseOpFusionPass());

  // Bufferize the Linalg on tensors program.
  // Always run the canonicalizer (which does dead code removal) before
  // bufferizing anything.
  pm.addPass(mlir::createCanonicalizerPass());
  // Now bufferize all the compute operations (hlo + linalg) and the function
  // signatures.
  pm.addPass(
      mlir::kernel_gen::transforms::CreateComputeOpAndFuncBufferizePass());
  // Turn tensor constants into global memrefs.
  // TODO(kramerb): Expose the patterns and add them to the bufferize passes.
  pm.addPass(mlir::createTensorConstantBufferizePass());
  // Run the canonicalizer for dead code removal.
  pm.addPass(mlir::createCanonicalizerPass());
  // tensor_to_memref is not considered dead currently, fix that directly.
  pm.addNestedPass<mlir::FuncOp>(
      std::make_unique<RemoveUnusedBufferCastOperations>());
  // Always run the canonicalizer (which does dead code removal) before
  // bufferizing anything.
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::kernel_gen::transforms::CreateFinalBufferizePass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Deallocate all temporary buffers.
  pm.addNestedPass<mlir::FuncOp>(mlir::createBufferDeallocationPass());

  // Do trivial buffer forwarding across linalg.generic operations.
  pm.addNestedPass<mlir::FuncOp>(CreateLinalgTrivialBufferForwardingPass());

  // Remove trivial copy operations.
  pm.addNestedPass<mlir::FuncOp>(CreateLinalgTrivialCopyRemovalPass());

  // Specialize linalg.matmul to linalg.dot, linalg.matvec or linalg.vecmat,
  // and immediately canonicalize to clean up the branches that are not taken.
  pm.addNestedPass<mlir::FuncOp>(CreateLinalgMatmulSpecializationPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Tile and vectorize linalg operations using the Linalg Codegen Strategy.
  pm.addNestedPass<mlir::FuncOp>(CreateCodegenStrategyForMatMulPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());
}

static mlir::PassPipelineRegistration<> tf_cpurt_pipeline(
    "tf-cpurt-pipeline",
    "Convert Tensorflow dialect to TFRT's CPURT compatible dialects",
    CreateTfCpuRtPipeline);
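
// With this registration linked in, an MLIR opt-style tool built from this
// repository can run the full pipeline on a module, e.g. (tool name and usage
// are illustrative):
//   some-tf-opt-tool input.mlir -tf-cpurt-pipeline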

}  // namespace tensorflow