1 /* Copyright 2022 Google Inc. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <memory>
18 #include <string>
19 #include <vector>
20
21 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
22 #include "mlir/IR/Builders.h" // from @llvm-project
23 #include "mlir/Pass/Pass.h" // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
25 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
27
28 namespace mlir {
29 namespace TF {
30 namespace {
31
// Forward declaration; the definition appears further down in this file.
// Moves `ops` into a newly created function and leaves a call to that function
// in their place. `function_id` is appended to the dialect namespace of the
// first op to form the new function's name. Does nothing if `ops` is empty.
void wrapOpsInFunction(std::vector<Operation*>& ops, int function_id);
33
34 class GroupByDialectPass : public GroupByDialectPassBase<GroupByDialectPass> {
35 public:
runOnOperation()36 void runOnOperation() override {
37 mlir::func::FuncOp func = getOperation();
38 int function_id = 0;
39
40 for (Block& block : func.getBody().getBlocks()) {
41 StringRef current_dialect("<none>");
42 std::vector<Operation*> ops;
43 for (Operation& op : block.getOperations()) {
44 StringRef dialect = op.getName().getDialectNamespace();
45 if (dialect != current_dialect) {
46 if (!top_level_dialects_.contains(current_dialect)) {
47 wrapOpsInFunction(ops, function_id++);
48 }
49 ops.clear();
50 current_dialect = dialect;
51 }
52 ops.push_back(&op);
53 }
54 if (!top_level_dialects_.contains(current_dialect)) {
55 wrapOpsInFunction(ops, function_id++);
56 }
57 }
58 }
59
60 llvm::SmallDenseSet<StringRef> top_level_dialects_ = {"glue", "func"};
61 };
62
63 // Compute the set of all values which are inputs to `ops`, but not generated
64 // by an operation in `ops`, and all outputs which are used outside of `ops.
computeInputsOutputs(std::vector<Operation * > & ops,std::vector<Value> * inputs,std::vector<Value> * outputs)65 void computeInputsOutputs(std::vector<Operation*>& ops,
66 std::vector<Value>* inputs,
67 std::vector<Value>* outputs) {
68 // All operations.
69 llvm::DenseSet<Operation*> all_operations(ops.begin(), ops.end());
70
71 // All results of all ops.
72 llvm::DenseSet<Value> all_internal_results;
73 for (Operation* op : ops) {
74 for (Value result : op->getResults()) {
75 all_internal_results.insert(result);
76 }
77 }
78
79 // All operand values in our set not produced as result by some op in our set.
80 llvm::DenseSet<Value> inputs_seen;
81 for (Operation* op : ops) {
82 for (Value operand : op->getOperands()) {
83 if (!all_internal_results.contains(operand)) {
84 if (!inputs_seen.contains(operand)) {
85 inputs->push_back(operand);
86 inputs_seen.insert(operand);
87 }
88 }
89 }
90 }
91
92 // All results in our set that have a user outside our set.
93 llvm::DenseSet<Value> outputs_seen;
94 for (Operation* op : ops) {
95 for (Value result : op->getResults()) {
96 for (auto& use : result.getUses()) {
97 if (!all_operations.contains(use.getOwner())) {
98 if (!outputs_seen.contains(result)) {
99 outputs->push_back(result);
100 outputs_seen.insert(result);
101 }
102 break;
103 }
104 }
105 }
106 }
107 }
108
wrapOpsInFunction(std::vector<Operation * > & ops,int function_id)109 void wrapOpsInFunction(std::vector<Operation*>& ops, int function_id) {
110 if (ops.empty()) {
111 return;
112 }
113
114 std::vector<Value> inputs;
115 std::vector<Value> outputs;
116 computeInputsOutputs(ops, &inputs, &outputs);
117
118 std::vector<Type> input_types;
119 std::vector<Type> output_types;
120
121 input_types.reserve(inputs.size());
122 for (Value v : inputs) {
123 input_types.push_back(v.getType());
124 }
125 output_types.reserve(outputs.size());
126 for (Value v : outputs) {
127 output_types.push_back(v.getType());
128 }
129
130 // Create the function.
131 MLIRContext* context = ops[0]->getContext();
132 StringRef dialect = ops[0]->getName().getDialectNamespace();
133 OpBuilder builder(context);
134 builder.setInsertionPointToEnd(ops[0]->getParentOp()->getBlock());
135 auto func = builder.create<mlir::func::FuncOp>(
136 ops[0]->getLoc(), dialect.str() + std::to_string(function_id),
137 builder.getFunctionType(input_types, output_types));
138 func->setAttr("dialect", builder.getStringAttr(dialect));
139 auto block = func.addEntryBlock();
140
141 llvm::DenseSet<Operation*> all_operations(ops.begin(), ops.end());
142 for (BlockArgument& arg : block->getArguments()) {
143 inputs[arg.getArgNumber()].replaceUsesWithIf(arg, [=](OpOperand& o) {
144 // Within the operations we're moving, we need to replace uses of
145 // values generated elsewhere.
146 return all_operations.contains(o.getOwner());
147 });
148 }
149
150 // Insert function call.
151 builder.setInsertionPoint(ops[0]);
152 auto call = builder.create<mlir::func::CallOp>(
153 ops[0]->getLoc(), func.getFunctionType().getResults(), func.getSymName(),
154 inputs);
155 for (auto& v : llvm::enumerate(outputs)) {
156 v.value().replaceUsesWithIf(call.getResult(v.index()), [=](OpOperand& o) {
157 // Outside of what we're moving, results of our operations need to
158 // be replaced by results from the function call.
159 return !all_operations.contains(o.getOwner());
160 });
161 }
162
163 // Move ops inside function & add return.
164 builder.setInsertionPointToEnd(block);
165 for (Operation* op : ops) {
166 op->remove();
167 builder.insert(op);
168 }
169 builder.create<mlir::func::ReturnOp>(ops[0]->getLoc(), outputs);
170 }
171
172 } // namespace
173
// Factory for the pass: returns a newly allocated GroupByDialectPass owned by
// the caller.
std::unique_ptr<Pass> CreateGroupByDialectPass() {
  return std::make_unique<GroupByDialectPass>();
}
177
// Registers the pass by handing the CreateGroupByDialectPass factory to
// registerPass.
void RegisterGroupByDialectPass() { registerPass(CreateGroupByDialectPass); }
179
180 } // namespace TF
181 } // namespace mlir
182