/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_passes.h"

#include <functional>
#include <memory>
#include <string>

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/Optional.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
40 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
41 #include "mlir/Dialect/Shape/Transforms/Passes.h"  // from @llvm-project
42 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
43 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
48 #include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h"
49 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
50 #include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_clustering.h"
51 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
52 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
53 
54 namespace tensorflow {
55 
56 #define GEN_PASS_CLASSES
57 #include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_passes.h.inc"
58 
59 // -------------------------------------------------------------------------- //
60 // Helper functions used by the passes implemented below.
61 // -------------------------------------------------------------------------- //
62 
63 // Returns true if the `value` type is a memref that is contiguous in memory.
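// For example (illustrative types only): `memref<4x4xf32>` is contiguous, while
// `memref<4x4xf32, affine_map<(d0, d1) -> (d0 * 8 + d1)>>` carries a
// non-trivial strided layout and is not.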
static bool IsContiguousMemref(mlir::Value value) {
  mlir::MemRefType memref_type = value.getType().dyn_cast<mlir::MemRefType>();
  if (!memref_type) return false;
  mlir::MemRefType canonical_type = canonicalizeStridedLayout(memref_type);
  return canonical_type.getAffineMaps().empty();
}

// -------------------------------------------------------------------------- //
// Trivial buffer forwarding for the linalg.generic operations.
// -------------------------------------------------------------------------- //
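//
// A minimal sketch of the rewrite (hypothetical IR, not taken from a test):
//
//   %in  = memref.alloc(%d) : memref<?xf32>
//   %out = memref.alloc(%d) : memref<?xf32>
//   linalg.generic ... ins(%in : memref<?xf32>) outs(%out : memref<?xf32>) {...}
//   memref.dealloc %in : memref<?xf32>
//   "use"(%out) : (memref<?xf32>) -> ()
//
// Because %in is only deallocated after the linalg.generic, it can be reused as
// the output buffer; after the now-dead %out allocation is cleaned up this is
// effectively:
//
//   %in = memref.alloc(%d) : memref<?xf32>
//   linalg.generic ... ins(%in : memref<?xf32>) outs(%in : memref<?xf32>) {...}
//   "use"(%in) : (memref<?xf32>) -> ()
//   memref.dealloc %in : memref<?xf32>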

namespace {

struct ForwardingCandidate {
  mlir::Value memref;
  mlir::AffineMap indexing_map;
};

struct LinalgTrivialBufferForwardingPattern
    : public mlir::OpRewritePattern<mlir::linalg::GenericOp> {
  using OpRewritePattern<mlir::linalg::GenericOp>::OpRewritePattern;
  mlir::LogicalResult matchAndRewrite(
      mlir::linalg::GenericOp op,
      mlir::PatternRewriter& rewriter) const override;
};

struct LinalgTrivialBufferForwardingPass
    : public LinalgTrivialBufferForwardingBase<
          LinalgTrivialBufferForwardingPass> {
  void runOnFunction() override;
};
}  // namespace

// Returns true if all linalg.generic operation iterators are "parallel".
static bool AllIteratorsAreParallel(mlir::linalg::GenericOp op) {
  return llvm::all_of(op.iterator_types(), [](mlir::Attribute attr) -> bool {
    auto str_attr = attr.dyn_cast<mlir::StringAttr>();
    return str_attr && str_attr.getValue() == "parallel";
  });
}

// Returns buffer inputs that can be safely used as buffer outputs.
static llvm::SmallVector<mlir::OpOperand*> FindBufferForwardingCandidates(
    mlir::linalg::GenericOp op) {
  llvm::SmallVector<mlir::OpOperand*> candidates;

  for (mlir::OpOperand* input_buffer : op.getInputOperands()) {
    // Input must be a contiguous memref ...
    if (!IsContiguousMemref(input_buffer->get())) continue;

    // ... allocated in the same function.
    auto* alloc = input_buffer->get().getDefiningOp();
    if (!alloc || !mlir::isa<mlir::memref::AllocOp>(alloc)) continue;

    // Find input users that are after the linalg.generic operation in the block.
    auto users = llvm::make_filter_range(alloc->getUsers(),
                                         [&](mlir::Operation* user) -> bool {
                                           return op->isBeforeInBlock(user);
                                         });

    // Input buffer must have exactly one user after linalg.generic.
    llvm::SmallVector<mlir::Operation*> input_users(users.begin(), users.end());
    if (input_users.size() != 1) continue;

    // And it must be a memref.dealloc operation.
    if (!mlir::isa<mlir::memref::DeallocOp>(input_users[0])) continue;

    // This input memref can be safely reused for the output.
    candidates.push_back(input_buffer);
  }

  return candidates;
}

mlir::LogicalResult LinalgTrivialBufferForwardingPattern::matchAndRewrite(
    mlir::linalg::GenericOp op, mlir::PatternRewriter& rewriter) const {
  // With parallel iterators it is guaranteed that every value is used exactly
  // once, so it is safe to forward an input buffer to an output buffer.
  if (!AllIteratorsAreParallel(op))
    return rewriter.notifyMatchFailure(op, "all iterators must be parallel");

  // Find memrefs that potentially could be forwarded.
  llvm::SmallVector<mlir::OpOperand*> forwarding_candidates =
      FindBufferForwardingCandidates(op);
  if (forwarding_candidates.empty())
    return rewriter.notifyMatchFailure(
        op, "did not find any candidates for input forwarding");

  // Inputs that were reused.
  llvm::DenseSet<mlir::OpOperand*> reused_inputs;

  // Try to match output buffers to forwarding candidates.
  for (mlir::OpOperand* output_buffer : op.getOutputOperands()) {
    // Output must be allocated in the same function.
    auto* alloc = output_buffer->get().getDefiningOp();
    if (!alloc || !mlir::isa<mlir::memref::AllocOp>(alloc)) continue;

    // Find compatible input buffer.
    for (mlir::OpOperand* input_buffer : forwarding_candidates) {
      if (reused_inputs.contains(input_buffer)) continue;

      // Memref types must match (dimensions and affine maps).
      if (input_buffer->get().getType() != output_buffer->get().getType())
        continue;

      mlir::AffineMap src_map = op.getTiedIndexingMap(input_buffer);
      mlir::AffineMap dst_map = op.getTiedIndexingMap(output_buffer);

      // Only support identity maps for the output for now.
      if (!dst_map.isIdentity()) continue;

      auto is_projection = [](mlir::AffineMap map) {
        // Allow adding/dropping dimensions but no permutations.
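        // For example, (d0, d1) -> (d0) and (d0, d1) -> (0, d1) are accepted,
        // while the permutation (d0, d1) -> (d1, d0) is rejected.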
        int64_t i = -1;
        for (mlir::AffineExpr expr : map.getResults()) {
          auto constant = expr.dyn_cast<mlir::AffineConstantExpr>();
          if (constant && constant.getValue() == 0) continue;
          auto dim_expr = expr.dyn_cast<mlir::AffineDimExpr>();
          if (!dim_expr || dim_expr.getPosition() <= i) return false;
          i = dim_expr.getPosition();
        }
        return true;
      };

      auto same_shape = [](mlir::Value src, mlir::Value dst) {
        auto src_type = src.getType().cast<mlir::ShapedType>();
        auto dst_type = dst.getType().cast<mlir::ShapedType>();
        mlir::OperandRange src_operands =
            src.getDefiningOp<mlir::memref::AllocOp>().getDynamicSizes();
        mlir::OperandRange dst_operands =
            dst.getDefiningOp<mlir::memref::AllocOp>().getDynamicSizes();
        return src_type.getShape().equals(dst_type.getShape()) &&
               std::equal(src_operands.begin(), src_operands.end(),
                          dst_operands.begin());
      };

      // A reuse is valid if the maps are the same or if the shape is the same
      // and the source is a projection map (in which case the ignored
      // dimensions must be 1 assuming that the operation reads the entire
      // input). Note that we already know that the destination map is an
      // identity map.
      if (src_map != dst_map &&
          !(is_projection(src_map) &&
            same_shape(input_buffer->get(), output_buffer->get()))) {
        continue;
      }

      // Find the input buffer dealloc operation.
      mlir::Operation* input_dealloc = *llvm::find_if(
          input_buffer->get().getUsers(), [](mlir::Operation* user) -> bool {
            return mlir::isa<mlir::memref::DeallocOp>(user);
          });

      // Deallocate output buffer instead of the input buffer.
      input_buffer->get().replaceUsesWithIf(
          output_buffer->get(), [&](mlir::OpOperand& operand) -> bool {
            return operand.getOwner() == input_dealloc;
          });

      // Forward users of output buffer to the input buffer, if they are after
      // linalg.generic operation in the block (or linalg.generic itself).
      output_buffer->get().replaceUsesWithIf(
          input_buffer->get(), [&](mlir::OpOperand& operand) -> bool {
            return operand.getOwner() != input_dealloc &&
                   !operand.getOwner()->isBeforeInBlock(op);
          });

      reused_inputs.insert(input_buffer);
    }
  }

  return mlir::success(!reused_inputs.empty());
}

void LinalgTrivialBufferForwardingPass::runOnFunction() {
  mlir::FuncOp function = getFunction();
  mlir::MLIRContext* ctx = function.getContext();

  mlir::RewritePatternSet patterns(ctx);
  patterns.insert<LinalgTrivialBufferForwardingPattern>(ctx);

  (void)mlir::applyPatternsAndFoldGreedily(function, std::move(patterns));
}

std::unique_ptr<mlir::FunctionPass> CreateLinalgTrivialBufferForwardingPass() {
  return std::make_unique<LinalgTrivialBufferForwardingPass>();
}

// -------------------------------------------------------------------------- //
// Trivial removal of redundant linalg.copy operations.
// -------------------------------------------------------------------------- //
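//
// The pass matches exact alloc/copy/dealloc triples (sketch, hypothetical IR):
//
//   %tmp = memref.alloc() : memref<16xf32>
//   linalg.copy(%src, %tmp) : memref<16xf32>, memref<16xf32>
//   memref.dealloc %src : memref<16xf32>
//   "use"(%tmp) : (memref<16xf32>) -> ()
//
// and removes them, forwarding the copy source to the users of the copy target:
//
//   "use"(%src) : (memref<16xf32>) -> ()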

namespace {

struct LinalgTrivialCopyRemovalPass
    : public LinalgTrivialCopyRemovalBase<LinalgTrivialCopyRemovalPass> {
  void runOnFunction() override;
};

}  // namespace

void LinalgTrivialCopyRemovalPass::runOnFunction() {
  mlir::FuncOp function = getFunction();

  mlir::SmallVector<mlir::Operation*> to_erase;
  function.walk([&to_erase](mlir::linalg::CopyOp copy) {
    // Only match precise alloc/copy/dealloc triples.
    auto alloc = llvm::dyn_cast<mlir::memref::AllocOp>(copy->getPrevNode());
    auto dealloc = llvm::dyn_cast<mlir::memref::DeallocOp>(copy->getNextNode());

    if (!alloc || !dealloc) return;

    // Make sure the alloc and dealloc handle the operands of the copy.
    if (alloc.getResult() != copy.getTarget() ||
        dealloc.memref() != copy.getSource()) {
      return;
    }

    // Remember the operations to delete.
    to_erase.push_back(alloc);
    to_erase.push_back(dealloc);
    to_erase.push_back(copy);
    copy.getTarget().replaceAllUsesWith(copy.getSource());
  });

  for (auto op : to_erase) {
    op->erase();
  }
}

std::unique_ptr<mlir::FunctionPass> CreateLinalgTrivialCopyRemovalPass() {
  return std::make_unique<LinalgTrivialCopyRemovalPass>();
}

// -------------------------------------------------------------------------- //
// Dispatch linalg.matmul to one of the more specialized operations at runtime.
// -------------------------------------------------------------------------- //
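//
// The rewrite below replaces a single linalg.matmul with a run-time dispatch of
// the following shape (pseudocode mirroring the scf.if chain emitted below):
//
//   if (m == 1 && n == 1)       linalg.dot     // [1, k] x [k, 1]
//   else if (m == 1 && n != 1)  linalg.vecmat  // [1, k] x [k, n]
//   else if (n == 1 && m != 1)  linalg.matvec  // [m, k] x [k, 1]
//   else                        linalg.matmul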

namespace {

struct LinalgMatmulSpecializationPattern
    : public mlir::OpRewritePattern<mlir::linalg::MatmulOp> {
  using OpRewritePattern<mlir::linalg::MatmulOp>::OpRewritePattern;
  mlir::LogicalResult matchAndRewrite(
      mlir::linalg::MatmulOp matmul,
      mlir::PatternRewriter& rewriter) const override;
};

struct LinalgMatmulSpecializationPass
    : public LinalgMatmulSpecializationBase<LinalgMatmulSpecializationPass> {
  void runOnFunction() override;
};
}  // namespace

// Convert 2D memref into a 1D memref (vector).
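// For example (a sketch, not verbatim IR), a contiguous memref<1x?xf32> with
// run-time size %k is reinterpreted roughly as
//
//   memref.reinterpret_cast %m to offset: [0], sizes: [%k], strides: [1]
//       : memref<1x?xf32> to memref<?xf32>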
static mlir::Value MemrefToVector(mlir::OpBuilder& builder, mlir::Location loc,
                                  mlir::Value memref, mlir::Value size,
                                  int64_t static_size) {
  assert(static_size >= 0 || static_size == mlir::ShapedType::kDynamicSize);
  auto memref_type = memref.getType().cast<mlir::MemRefType>();
  auto vec_type =
      mlir::MemRefType::get({static_size}, memref_type.getElementType());

  auto static_offsets = builder.getI64ArrayAttr({0});
  auto static_sizes = builder.getI64ArrayAttr({static_size});
  auto static_strided = builder.getI64ArrayAttr({1});

  auto empty = mlir::ValueRange();
  auto sizes = static_size == mlir::ShapedType::kDynamicSize
                   ? mlir::ValueRange(size)
                   : mlir::ValueRange();

  return builder.create<mlir::memref::ReinterpretCastOp>(
      loc, vec_type, memref, /*offsets=*/empty,
      /*sizes=*/sizes, /*strides=*/empty, static_offsets, static_sizes,
      static_strided);
}

// Convert 2D memref into a 0D memref (scalar).
static mlir::Value MemrefToScalar(mlir::OpBuilder& builder, mlir::Location loc,
                                  mlir::Value memref) {
  auto memref_type = memref.getType().cast<mlir::MemRefType>();
  auto scalar_type = mlir::MemRefType::get({}, memref_type.getElementType());

  std::array<int64_t, 0> empty;
  return builder.create<mlir::memref::ReinterpretCastOp>(
      loc, scalar_type, memref, /*offset=*/0,
      /*sizes=*/empty, /*strides=*/empty);
}

mlir::LogicalResult LinalgMatmulSpecializationPattern::matchAndRewrite(
    mlir::linalg::MatmulOp matmul, mlir::PatternRewriter& rewriter) const {
  if (matmul->hasAttr("__tf_cpurt_specialized"))
    return rewriter.notifyMatchFailure(matmul,
                                       "operation was already specialized");

  auto rhs = matmul.getInputOperand(1)->get();
  auto lhs = matmul.getInputOperand(0)->get();
  auto out = matmul.getOutputOperand(0)->get();

  // We do not support inputs or outputs that are not contiguous in memory.
  if (!IsContiguousMemref(lhs) || !IsContiguousMemref(rhs) ||
      !IsContiguousMemref(out))
    return rewriter.notifyMatchFailure(
        matmul, "inputs and output must be contiguous memrefs");

  auto loc = matmul.getLoc();

  // Matmul dimensions: [m, k] x [k, n]
  mlir::Value m = rewriter.create<mlir::memref::DimOp>(loc, lhs, 0);
  mlir::Value k = rewriter.create<mlir::memref::DimOp>(loc, lhs, 1);
  mlir::Value n = rewriter.create<mlir::memref::DimOp>(loc, rhs, 1);

  // Static matmul dimensions, or ShapedType::kDynamicSize if a dimension is
  // not known at compile time.
  int64_t m_static = lhs.getType().cast<mlir::MemRefType>().getDimSize(0);
  int64_t k_static = lhs.getType().cast<mlir::MemRefType>().getDimSize(1);
  int64_t n_static = rhs.getType().cast<mlir::MemRefType>().getDimSize(1);

  auto one = rewriter.create<mlir::ConstantOp>(loc, rewriter.getIndexType(),
                                               rewriter.getIndexAttr(1));
  auto m_is_one =
      rewriter.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::eq, m, one);
  auto n_is_one =
      rewriter.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::eq, n, one);

  auto m_not_one =
      rewriter.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::ne, m, one);
  auto n_not_one =
      rewriter.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::ne, n, one);

  // linalg.dot: n == 1 && m == 1
  auto is_dot_product = rewriter.create<mlir::AndOp>(loc, m_is_one, n_is_one);
  // linalg.vecmat: m == 1 && n != 1
  auto is_vecmat = rewriter.create<mlir::AndOp>(loc, m_is_one, n_not_one);
  // linalg.matvec: n == 1 && m != 1
  auto is_matvec = rewriter.create<mlir::AndOp>(loc, n_is_one, m_not_one);

  // Build a linalg.dot operation casting inputs to vectors.
  auto dot = [&](mlir::OpBuilder& builder, mlir::Location nestedLoc) {
    auto lhs_vec = MemrefToVector(builder, nestedLoc, lhs, k, k_static);
    auto rhs_vec = MemrefToVector(builder, nestedLoc, rhs, k, k_static);
    auto out_scalar = MemrefToScalar(builder, nestedLoc, out);

    builder.create<mlir::linalg::DotOp>(nestedLoc,
                                        mlir::ValueRange({lhs_vec, rhs_vec}),
                                        mlir::ValueRange({out_scalar}));
    builder.create<mlir::scf::YieldOp>(nestedLoc);
  };

  // Build a linalg.vecmat operation casting lhs to vector.
  auto vecmat = [&](mlir::OpBuilder& builder, mlir::Location nestedLoc) {
    auto lhs_vec = MemrefToVector(builder, nestedLoc, lhs, k, k_static);
    auto out_vec = MemrefToVector(builder, nestedLoc, out, n, n_static);

    builder.create<mlir::linalg::VecmatOp>(nestedLoc,
                                           mlir::ValueRange({lhs_vec, rhs}),
                                           mlir::ValueRange({out_vec}));
    builder.create<mlir::scf::YieldOp>(nestedLoc);
  };

  // Build a linalg.matvec operation casting rhs to vector.
  auto matvec = [&](mlir::OpBuilder& builder, mlir::Location nestedLoc) {
    auto rhs_vec = MemrefToVector(builder, nestedLoc, rhs, k, k_static);
    auto out_vec = MemrefToVector(builder, nestedLoc, out, m, m_static);

    builder.create<mlir::linalg::MatvecOp>(nestedLoc,
                                           mlir::ValueRange({lhs, rhs_vec}),
                                           mlir::ValueRange({out_vec}));
    builder.create<mlir::scf::YieldOp>(nestedLoc);
  };

  // Build a generic linalg.matmul operation when it can't be matched to any of
  // the specializations.
  auto generic = [&](mlir::OpBuilder& builder, mlir::Location nestedLoc) {
    llvm::SmallVector<mlir::Value> inputs = matmul.getInputOperands();
    llvm::SmallVector<mlir::Value> outputs = matmul.getOutputOperands();
    auto specialized =
        builder.create<mlir::linalg::MatmulOp>(nestedLoc, inputs, outputs);
    specialized->setAttr("__tf_cpurt_specialized", rewriter.getUnitAttr());
    builder.create<mlir::scf::YieldOp>(nestedLoc);
  };

  // TODO(ezhulenev): Simplify to scf.switch operation.
  // if (is_dot_product) ===>>> linalg.dot    ------------------------------- //
  auto dispatch = rewriter.create<mlir::scf::IfOp>(
      loc, is_dot_product, dot,
      [&](mlir::OpBuilder& builder, mlir::Location nestedLoc) {
        // else if (is_vecmat)  ===>>> linalg.vecmat    --------------------- //
        rewriter.create<mlir::scf::IfOp>(
            nestedLoc, is_vecmat, vecmat,
            [&](mlir::OpBuilder& builder, mlir::Location nestedLoc) {
              // else if (is_matvec)  ===>>> linalg.matvec    --------------- //
              // else                 ===>>> linalg.matmul    --------------- //
              rewriter.create<mlir::scf::IfOp>(nestedLoc, is_matvec, matvec,
                                               generic);
              builder.create<mlir::scf::YieldOp>(nestedLoc);
            });
        builder.create<mlir::scf::YieldOp>(nestedLoc);
      });

  rewriter.replaceOp(matmul, dispatch.results());
  return mlir::success();
}

void LinalgMatmulSpecializationPass::runOnFunction() {
  mlir::FuncOp function = getFunction();
  mlir::MLIRContext* ctx = function.getContext();

  mlir::RewritePatternSet patterns(ctx);
  patterns.insert<LinalgMatmulSpecializationPattern>(ctx);

  (void)mlir::applyPatternsAndFoldGreedily(function, std::move(patterns));
}

std::unique_ptr<mlir::FunctionPass> CreateLinalgMatmulSpecializationPass() {
  return std::make_unique<LinalgMatmulSpecializationPass>();
}

// -------------------------------------------------------------------------- //
// Break Tensorflow _Fused{Op} operations into primitive ones.
// -------------------------------------------------------------------------- //
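//
// For example (a sketch of the rewrite below), a fused operation
//
//   %0 = "tf._FusedMatMul"(%a, %b, %bias) {fused_ops = ["BiasAdd", "Relu"]}
//
// is broken up into the primitive operations
//
//   %1 = "tf.MatMul"(%a, %b)
//   %2 = "tf.BiasAdd"(%1, %bias)
//   %3 = "tf.Relu"(%2)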

namespace {
struct FissionPass : public FissionBase<FissionPass> {
  void runOnFunction() override;
};

struct FusedMatMulFission
    : public mlir::OpRewritePattern<mlir::TF::_FusedMatMulOp> {
  using OpRewritePattern<mlir::TF::_FusedMatMulOp>::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::TF::_FusedMatMulOp op,
      mlir::PatternRewriter& rewriter) const override;
};
}  // namespace

mlir::LogicalResult FusedMatMulFission::matchAndRewrite(
    mlir::TF::_FusedMatMulOp op, mlir::PatternRewriter& rewriter) const {
  auto loc = op.getLoc();
  auto type = op.getResult().getType();

  size_t n = op.fused_ops().size();

  // Extract fused operations from the operation attributes.
  mlir::StringAttr fusion0 =
      n > 0 ? op.fused_ops()[0].dyn_cast<mlir::StringAttr>() : nullptr;
  mlir::StringAttr fusion1 =
      n > 1 ? op.fused_ops()[1].dyn_cast<mlir::StringAttr>() : nullptr;

  // Match against the supported fused operations.
  bool is_bias_add = fusion0 && fusion0.getValue() == "BiasAdd";
  bool is_relu_activation = fusion1 && fusion1.getValue() == "Relu";

  // Create a simple MatMul operation from the fused one.
  auto matmul = [&]() -> mlir::TF::MatMulOp {
    auto lhs = op.getOperand(0);
    auto rhs = op.getOperand(1);
    return rewriter.create<mlir::TF::MatMulOp>(
        loc, type, lhs, rhs, op.transpose_a(), op.transpose_b());
  };

  // FusedMatMul[BiasAdd].
  if (n == 1 && is_bias_add) {
    rewriter.replaceOpWithNewOp<mlir::TF::BiasAddOp>(op, type, matmul(),
                                                     op.getOperand(2));
    return mlir::success();
  }

  // FusedMatMul[BiasAdd, Relu].
  if (n == 2 && is_bias_add && is_relu_activation) {
    auto biased = rewriter.create<mlir::TF::BiasAddOp>(loc, type, matmul(),
                                                       op.getOperand(2));
    rewriter.replaceOpWithNewOp<mlir::TF::ReluOp>(op, type, biased);
    return mlir::success();
  }

  return mlir::failure();
}

void FissionPass::runOnFunction() {
  mlir::FuncOp function = getFunction();
  mlir::MLIRContext* ctx = function.getContext();

  mlir::RewritePatternSet patterns(ctx);
  patterns.insert<FusedMatMulFission>(ctx);

  (void)mlir::applyPatternsAndFoldGreedily(function, std::move(patterns));
}

std::unique_ptr<mlir::FunctionPass> CreateFissionPass() {
  return std::make_unique<FissionPass>();
}

// -------------------------------------------------------------------------- //
// Custom passes that are missing upstream.
// -------------------------------------------------------------------------- //

namespace {
// TODO(herhut): Remove this once leftover tensor_to_memref ops are handled in core.
struct RemoveUnusedBufferCastOperations
    : public mlir::PassWrapper<RemoveUnusedBufferCastOperations,
                               mlir::FunctionPass> {
  void runOnFunction() override;
};

// Adds a Tensorflow producer version to the module to enable shape inference.
struct AddTensorflowProducerVersion
    : public mlir::PassWrapper<AddTensorflowProducerVersion,
                               mlir::OperationPass<mlir::ModuleOp>> {
  void runOnOperation() override;
};

// Use Linalg CodegenStrategy to tile linalg.matmul, linalg.matvec and
// linalg.vecmat operations.
struct CodegenStrategyForMatMulPass
    : public mlir::PassWrapper<CodegenStrategyForMatMulPass,
                               mlir::FunctionPass> {
  void runOnFunction() override;
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::vector::VectorDialect>();
  }
};
}  // namespace

void RemoveUnusedBufferCastOperations::runOnFunction() {
  getFunction().walk([](mlir::memref::BufferCastOp op) {
    // Drop all buffer_cast operations that have no users. Currently this will
    // not happen, as tensor_to_memref has a side-effect. See
    // https://reviews.llvm.org/D91967 for a discussion.
    if (op.memref().getUsers().empty()) {
      op.erase();
    }
  });
}

void CodegenStrategyForMatMulPass::runOnFunction() {
  // Promote tiles to full buffers allocated on the stack.
  mlir::linalg::LinalgPromotionOptions full_alloca_promotion;
  full_alloca_promotion.setUseFullTileBuffersByDefault(true).setUseAlloca(true);

  // Vector contraction options.
  mlir::vector::VectorTransformsOptions vector_transforms_ops;
  vector_transforms_ops.setVectorTransformsOptions(
      mlir::vector::VectorContractLowering::OuterProduct);

  // Vector transfer options.
  mlir::VectorTransferToSCFOptions vector_transfer_opts;
  vector_transfer_opts.setUnroll(true);

  // TODO(ezhulenev): Set up tiling options depending on the target machine.

  // Tile and vectorize linalg.matmul operations.
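  // The tile sizes below follow the canonical linalg.matmul iteration order
  // (m, n, k): 12x32 tiles of the output with reduction chunks of 16. These
  // are heuristic values (see the TODO above about target-specific tuning).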
  mlir::linalg::LinalgTilingOptions matmul_tiling;
  matmul_tiling.setTileSizes({12, 32, 16});

  mlir::linalg::CodegenStrategy matmul_strategy;
  matmul_strategy.tile<mlir::linalg::MatmulOp>(matmul_tiling)
      .promote<mlir::linalg::MatmulOp>(full_alloca_promotion)
      .vectorize<mlir::linalg::MatmulOp>()
      .setVectorTransformsOptions(vector_transforms_ops)
      .setVectorTransferToSCFOptions(vector_transfer_opts);
  matmul_strategy.transform(getFunction());

  // Tile and vectorize linalg.vecmat operations. Interchange loop order to
  // linearly read from the matrix memref.
  mlir::linalg::LinalgTilingOptions vecmat_tiling;
  vecmat_tiling.setTileSizes({16, 8}).setInterchange({1, 0});

  mlir::linalg::CodegenStrategy vecmat_strategy;
  vecmat_strategy.tile<mlir::linalg::VecmatOp>(vecmat_tiling)
      .promote<mlir::linalg::VecmatOp>(full_alloca_promotion)
      .vectorize<mlir::linalg::VecmatOp>()
      .setVectorTransformsOptions(vector_transforms_ops)
      .setVectorTransferToSCFOptions(vector_transfer_opts);
  vecmat_strategy.transform(getFunction());
}

static std::unique_ptr<mlir::FunctionPass>
CreateCodegenStrategyForMatMulPass() {
  return std::make_unique<CodegenStrategyForMatMulPass>();
}

void AddTensorflowProducerVersion::runOnOperation() {
  mlir::ModuleOp module = getOperation();

  // The Tensorflow producer version does not affect anything during shape
  // inference; set it to `0` (any value would do) just to satisfy the
  // attribute checks.
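  // The resulting module attribute looks like (illustrative):
  //   module attributes {tf.versions = {producer = 0 : i32}} { ... }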
  mlir::Builder builder(module);
  auto version = builder.getNamedAttr("producer", builder.getI32IntegerAttr(0));
  module->setAttr("tf.versions", builder.getDictionaryAttr({version}));
}

// -------------------------------------------------------------------------- //
// Cluster operations based on the TF CPURT clustering policy.
// -------------------------------------------------------------------------- //

namespace {
using llvm::ArrayRef;

using mlir::TFDevice::Cluster;
using mlir::TFDevice::ClusteringPolicySet;
using mlir::TFDevice::CreateClusterOp;
using mlir::TFDevice::FindClustersInTheBlock;

struct ClusteringPass : public ClusteringBase<ClusteringPass> {
  ClusteringPass() = default;
  ClusteringPass(ArrayRef<std::string> cluster_oplist, int cluster_min_size) {
    oplist = cluster_oplist;
    min_cluster_size = cluster_min_size;
  }

  void runOnFunction() override;
};
}  // anonymous namespace

void ClusteringPass::runOnFunction() {
  ClusteringPolicySet policies;

  // Parse clustering tier and operations filter from the oplist.
  llvm::DenseSet<llvm::StringRef> opset;
  llvm::Optional<CpurtClusteringTier> tier;

  for (const auto& op : oplist) {
    if (op == "tier1") {
      tier = CpurtClusteringTier::kTier1;
    } else if (op == "all") {
      tier = CpurtClusteringTier::kAll;
    } else {
      opset.insert(op);
    }
  }

  // Run clustering only if the clustering tier or supported operations are
  // explicitly defined by the oplist.
  if (!tier.hasValue() && opset.empty()) return;

  // If the clustering tier is not defined, the opset will later filter the
  // supported operations, so it is ok to use the `all` tier here.
  populateTfCpurtClusteringPolicies(policies,
                                    tier.getValueOr(CpurtClusteringTier::kAll));

  // If the opset is not empty, restrict the operations enabled for clustering.
  auto filter = [&](mlir::Operation* op) -> bool {
    return opset.empty() || opset.contains(op->getName().getStringRef());
  };

  // Annotate all formed clusters with an attribute.
  auto policy = mlir::StringAttr::get(&getContext(), "tfrt.auto-fusion");

  getFunction().walk([&](mlir::Block* block) {
    for (Cluster& cluster : FindClustersInTheBlock(block, policies, filter)) {
      // Do not create too small clusters.
      if (cluster.operations.size() < min_cluster_size) continue;
      // Verify that JIT runtime can compile the cluster.
      if (failed(VerifyCluster(cluster))) continue;

      CreateClusterOp(cluster, policy);
    }
  });
}

std::unique_ptr<mlir::FunctionPass> CreateTfCpurtClusteringPass() {
  return std::make_unique<ClusteringPass>();
}

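// Creates a clustering pass with an explicit oplist. For example (hypothetical
// arguments), to cluster tier 1 operations plus tf.Transpose into clusters of
// at least two operations:
//
//   CreateTfCpurtClusteringPass({"tier1", "tf.Transpose"},
//                               /*min_cluster_size=*/2);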
std::unique_ptr<mlir::FunctionPass> CreateTfCpurtClusteringPass(
    llvm::ArrayRef<std::string> oplist, int min_cluster_size) {
  return std::make_unique<ClusteringPass>(oplist, min_cluster_size);
}

// -------------------------------------------------------------------------- //
// Assemble a TF-CPURT pipeline to lower from Tensorflow dialects to Linalg on
// buffers via progressive lowering to MHLO and Linalg.
// -------------------------------------------------------------------------- //

void CreateTfCpuRtPipeline(mlir::OpPassManager& pm) {
  // Break Tensorflow fused operations into primitive operations before
  // lowering to HLO.
  pm.addNestedPass<mlir::FuncOp>(CreateFissionPass());

  // Run shape inference to propagate potentially specialized input shapes.
  pm.addPass(std::make_unique<AddTensorflowProducerVersion>());
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());

  // Transform TF operations to HLO.
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass());

  // Move up broadcasting operations to allow for more fusion opportunities.
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createBroadcastPropagationPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Transform HLO operations to LinAlg and fuse them.
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeHloToLinalgPass());

  // Lower index cast on tensors to tensor.generate.
  pm.addNestedPass<mlir::FuncOp>(
      mlir::kernel_gen::transforms::CreateLowerIndexCastPass());

  // Lower shape dialect to standard to enable linalg canonicalizations (e.g.
  // use linalg inputs instead of outputs for memref.dim operations).
  pm.addNestedPass<mlir::FuncOp>(
      mlir::kernel_gen::transforms::CreateShapeSimplification());
  pm.addNestedPass<mlir::FuncOp>(mlir::createShapeToShapeLowering());
  pm.addPass(mlir::createConvertShapeToStandardPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createConvertShapeConstraintsPass());

  // Fuse Linalg on tensors operations.
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::memref::createResolveShapedTypeResultDimsPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createLinalgElementwiseOpFusionPass());

  // Bufferize Linalg on tensors program.
  // Always run canonicalizer (which does dead code removal) before bufferizing
  // anything.
  pm.addPass(mlir::createCanonicalizerPass());
  // Now bufferize all the compute operations (hlo + linalg) and func signature.
  pm.addPass(
      mlir::kernel_gen::transforms::CreateComputeOpAndFuncBufferizePass());
  // Turn tensor constants into global memrefs.
  // TODO(kramerb): Expose the patterns and add them to the bufferize passes.
  pm.addPass(mlir::createTensorConstantBufferizePass());
  // Run canonicalizer for dead code removal.
  pm.addPass(mlir::createCanonicalizerPass());
  // tensor_to_memref is not considered dead currently, fix that directly.
  pm.addNestedPass<mlir::FuncOp>(
      std::make_unique<RemoveUnusedBufferCastOperations>());
  // Always run canonicalizer (which does dead code removal) before bufferizing
  // anything.
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::kernel_gen::transforms::CreateFinalBufferizePass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Deallocate all temporary buffers.
  pm.addNestedPass<mlir::FuncOp>(mlir::createBufferDeallocationPass());

  // Do trivial buffer forwarding across linalg.generic operations.
  pm.addNestedPass<mlir::FuncOp>(CreateLinalgTrivialBufferForwardingPass());

  // Remove trivial copy operations.
  pm.addNestedPass<mlir::FuncOp>(CreateLinalgTrivialCopyRemovalPass());

  // Specialize linalg.matmul to linalg.dot, linalg.matvec or linalg.vecmat, and
  // immediately canonicalize to clean up the branches that are not taken.
  pm.addNestedPass<mlir::FuncOp>(CreateLinalgMatmulSpecializationPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Tile and vectorize linalg operations using the Linalg codegen strategy.
  pm.addNestedPass<mlir::FuncOp>(CreateCodegenStrategyForMatMulPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());
}

static mlir::PassPipelineRegistration<> tf_cpurt_pipeline(
    "tf-cpurt-pipeline",
    "Convert Tensorflow dialect to TFRT's CPURT compatible dialects",
    CreateTfCpuRtPipeline);
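
// With this registration the full pipeline can be exercised from any MLIR
// opt-style tool that links this library (tool name assumed), e.g.:
//
//   <mlir-opt-tool> --tf-cpurt-pipeline input.mlir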

}  // namespace tensorflow