xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/order_by_dialect.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <memory>
18 #include <queue>
19 #include <vector>
20 
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/IR/Builders.h"  // from @llvm-project
23 #include "mlir/Pass/Pass.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
25 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
27 
28 namespace mlir {
29 namespace TF {
30 namespace {
31 
32 std::vector<Operation*> groupOperationsByDialect(Block& block);
33 
34 // Reorder operations so that consecutive ops stay in the same dialect, as far
35 // as possible. This is to optimize the op order for the group-by-dialect pass,
36 // which factors consecutive same-dialect ops into functions.
37 // TODO(kramm): This pass needs to become aware of side-effects between ops
38 // of different dialects.
39 class OrderByDialectPass : public OrderByDialectPassBase<OrderByDialectPass> {
40  public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OrderByDialectPass)41   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OrderByDialectPass)
42 
43   void runOnOperation() override {
44     for (Block& block : getOperation().getBody()) {
45       auto ops = groupOperationsByDialect(block);
46       // Replace the block with the reordered block.
47       for (Operation* op : ops) {
48         op->remove();
49         block.push_back(op);
50       }
51     }
52   }
53 };
54 
55 // Similar to MLIR's topological sort (lib/Transforms/TopologicalSort.cpp)
56 // but has an explicit scoring function to determine the next op to emit.
groupOperationsByDialect(Block & block)57 std::vector<Operation*> groupOperationsByDialect(Block& block) {
58   llvm::DenseMap<Operation*, int> remaining_incoming_edges;
59   llvm::DenseMap<Operation*, int> position;
60   SmallVector<Operation*> ready;
61 
62   int i = 0;
63   for (Operation& op : block.getOperations()) {
64     remaining_incoming_edges[&op] = op.getNumOperands();
65     position[&op] = i++;
66     if (op.getNumOperands() == 0) {
67       ready.push_back(&op);
68     }
69   }
70 
71   std::queue<Value> todo;
72   for (Value value : block.getArguments()) {
73     todo.push(value);
74   }
75 
76   StringRef current_dialect = "<none>";
77 
78   std::vector<Operation*> result;
79   while (!todo.empty() || !ready.empty()) {
80     while (!todo.empty()) {
81       Value value = todo.front();
82       todo.pop();
83       // All operations that have all their inputs available are good to go.
84       for (OpOperand& operand : value.getUses()) {
85         // Uses, not Users, in case getUsers ever dedups.
86         Operation* user = operand.getOwner();
87         if (--remaining_incoming_edges[user] == 0) {
88           ready.push_back(user);
89         }
90       }
91     }
92 
93     // Find the "best" operation to emit. We
94     // (a) stay in the same dialect as far as possible.
95     // (b) preserve order within the ops of one dialect.
96     // (c) emit the terminator last.
97     auto better = [&](Operation* a, Operation* b) {
98       if (a->hasTrait<OpTrait::IsTerminator>() !=
99           b->hasTrait<OpTrait::IsTerminator>()) {
100         return b->hasTrait<OpTrait::IsTerminator>();
101       }
102       bool a_current = a->getName().getDialectNamespace() == current_dialect;
103       bool b_current = b->getName().getDialectNamespace() == current_dialect;
104       if (a_current != b_current) {
105         return a_current;
106       }
107       return position[a] < position[b];  // preserve order
108     };
109 
110     Operation* best = nullptr;
111     for (Operation* op : ready) {
112       if (best == nullptr || better(op, best)) {
113         best = op;
114       }
115     }
116 
117     // Consider this operation emitted, and make its results available.
118     ready.erase(std::find(ready.begin(), ready.end(), best));
119     current_dialect = best->getName().getDialectNamespace();
120     for (Value result : best->getResults()) {
121       todo.push(result);
122     }
123     result.push_back(best);
124   }
125   return result;
126 }
127 
128 }  // namespace
129 
CreateOrderByDialectPass()130 std::unique_ptr<Pass> CreateOrderByDialectPass() {
131   return std::make_unique<OrderByDialectPass>();
132 }
133 
RegisterOrderByDialectPass()134 void RegisterOrderByDialectPass() { registerPass(CreateOrderByDialectPass); }
135 
136 }  // namespace TF
137 }  // namespace mlir
138