1 /* Copyright 2022 Google Inc. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <memory>
18 #include <queue>
19 #include <vector>
20
21 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
22 #include "mlir/IR/Builders.h" // from @llvm-project
23 #include "mlir/Pass/Pass.h" // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
25 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
27
28 namespace mlir {
29 namespace TF {
30 namespace {
31
32 std::vector<Operation*> groupOperationsByDialect(Block& block);
33
34 // Reorder operations so that consecutive ops stay in the same dialect, as far
35 // as possible. This is to optimize the op order for the group-by-dialect pass,
36 // which factors consecutive same-dialect ops into functions.
37 // TODO(kramm): This pass needs to become aware of side-effects between ops
38 // of different dialects.
39 class OrderByDialectPass : public OrderByDialectPassBase<OrderByDialectPass> {
40 public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OrderByDialectPass)41 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OrderByDialectPass)
42
43 void runOnOperation() override {
44 for (Block& block : getOperation().getBody()) {
45 auto ops = groupOperationsByDialect(block);
46 // Replace the block with the reordered block.
47 for (Operation* op : ops) {
48 op->remove();
49 block.push_back(op);
50 }
51 }
52 }
53 };
54
55 // Similar to MLIR's topological sort (lib/Transforms/TopologicalSort.cpp)
56 // but has an explicit scoring function to determine the next op to emit.
groupOperationsByDialect(Block & block)57 std::vector<Operation*> groupOperationsByDialect(Block& block) {
58 llvm::DenseMap<Operation*, int> remaining_incoming_edges;
59 llvm::DenseMap<Operation*, int> position;
60 SmallVector<Operation*> ready;
61
62 int i = 0;
63 for (Operation& op : block.getOperations()) {
64 remaining_incoming_edges[&op] = op.getNumOperands();
65 position[&op] = i++;
66 if (op.getNumOperands() == 0) {
67 ready.push_back(&op);
68 }
69 }
70
71 std::queue<Value> todo;
72 for (Value value : block.getArguments()) {
73 todo.push(value);
74 }
75
76 StringRef current_dialect = "<none>";
77
78 std::vector<Operation*> result;
79 while (!todo.empty() || !ready.empty()) {
80 while (!todo.empty()) {
81 Value value = todo.front();
82 todo.pop();
83 // All operations that have all their inputs available are good to go.
84 for (OpOperand& operand : value.getUses()) {
85 // Uses, not Users, in case getUsers ever dedups.
86 Operation* user = operand.getOwner();
87 if (--remaining_incoming_edges[user] == 0) {
88 ready.push_back(user);
89 }
90 }
91 }
92
93 // Find the "best" operation to emit. We
94 // (a) stay in the same dialect as far as possible.
95 // (b) preserve order within the ops of one dialect.
96 // (c) emit the terminator last.
97 auto better = [&](Operation* a, Operation* b) {
98 if (a->hasTrait<OpTrait::IsTerminator>() !=
99 b->hasTrait<OpTrait::IsTerminator>()) {
100 return b->hasTrait<OpTrait::IsTerminator>();
101 }
102 bool a_current = a->getName().getDialectNamespace() == current_dialect;
103 bool b_current = b->getName().getDialectNamespace() == current_dialect;
104 if (a_current != b_current) {
105 return a_current;
106 }
107 return position[a] < position[b]; // preserve order
108 };
109
110 Operation* best = nullptr;
111 for (Operation* op : ready) {
112 if (best == nullptr || better(op, best)) {
113 best = op;
114 }
115 }
116
117 // Consider this operation emitted, and make its results available.
118 ready.erase(std::find(ready.begin(), ready.end(), best));
119 current_dialect = best->getName().getDialectNamespace();
120 for (Value result : best->getResults()) {
121 todo.push(result);
122 }
123 result.push_back(best);
124 }
125 return result;
126 }
127
128 } // namespace
129
CreateOrderByDialectPass()130 std::unique_ptr<Pass> CreateOrderByDialectPass() {
131 return std::make_unique<OrderByDialectPass>();
132 }
133
RegisterOrderByDialectPass()134 void RegisterOrderByDialectPass() { registerPass(CreateOrderByDialectPass); }
135
136 } // namespace TF
137 } // namespace mlir
138