xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/group_by_dialect.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <memory>
18 #include <string>
19 #include <vector>
20 
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/IR/Builders.h"  // from @llvm-project
23 #include "mlir/Pass/Pass.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
25 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
27 
28 namespace mlir {
29 namespace TF {
30 namespace {
31 
32 void wrapOpsInFunction(std::vector<Operation*>& ops, int function_id);
33 
34 class GroupByDialectPass : public GroupByDialectPassBase<GroupByDialectPass> {
35  public:
runOnOperation()36   void runOnOperation() override {
37     mlir::func::FuncOp func = getOperation();
38     int function_id = 0;
39 
40     for (Block& block : func.getBody().getBlocks()) {
41       StringRef current_dialect("<none>");
42       std::vector<Operation*> ops;
43       for (Operation& op : block.getOperations()) {
44         StringRef dialect = op.getName().getDialectNamespace();
45         if (dialect != current_dialect) {
46           if (!top_level_dialects_.contains(current_dialect)) {
47             wrapOpsInFunction(ops, function_id++);
48           }
49           ops.clear();
50           current_dialect = dialect;
51         }
52         ops.push_back(&op);
53       }
54       if (!top_level_dialects_.contains(current_dialect)) {
55         wrapOpsInFunction(ops, function_id++);
56       }
57     }
58   }
59 
60   llvm::SmallDenseSet<StringRef> top_level_dialects_ = {"glue", "func"};
61 };
62 
63 // Compute the set of all values which are inputs to `ops`, but not generated
64 // by an operation in `ops`, and all outputs which are used outside of `ops.
computeInputsOutputs(std::vector<Operation * > & ops,std::vector<Value> * inputs,std::vector<Value> * outputs)65 void computeInputsOutputs(std::vector<Operation*>& ops,
66                           std::vector<Value>* inputs,
67                           std::vector<Value>* outputs) {
68   // All operations.
69   llvm::DenseSet<Operation*> all_operations(ops.begin(), ops.end());
70 
71   // All results of all ops.
72   llvm::DenseSet<Value> all_internal_results;
73   for (Operation* op : ops) {
74     for (Value result : op->getResults()) {
75       all_internal_results.insert(result);
76     }
77   }
78 
79   // All operand values in our set not produced as result by some op in our set.
80   llvm::DenseSet<Value> inputs_seen;
81   for (Operation* op : ops) {
82     for (Value operand : op->getOperands()) {
83       if (!all_internal_results.contains(operand)) {
84         if (!inputs_seen.contains(operand)) {
85           inputs->push_back(operand);
86           inputs_seen.insert(operand);
87         }
88       }
89     }
90   }
91 
92   // All results in our set that have a user outside our set.
93   llvm::DenseSet<Value> outputs_seen;
94   for (Operation* op : ops) {
95     for (Value result : op->getResults()) {
96       for (auto& use : result.getUses()) {
97         if (!all_operations.contains(use.getOwner())) {
98           if (!outputs_seen.contains(result)) {
99             outputs->push_back(result);
100             outputs_seen.insert(result);
101           }
102           break;
103         }
104       }
105     }
106   }
107 }
108 
wrapOpsInFunction(std::vector<Operation * > & ops,int function_id)109 void wrapOpsInFunction(std::vector<Operation*>& ops, int function_id) {
110   if (ops.empty()) {
111     return;
112   }
113 
114   std::vector<Value> inputs;
115   std::vector<Value> outputs;
116   computeInputsOutputs(ops, &inputs, &outputs);
117 
118   std::vector<Type> input_types;
119   std::vector<Type> output_types;
120 
121   input_types.reserve(inputs.size());
122   for (Value v : inputs) {
123     input_types.push_back(v.getType());
124   }
125   output_types.reserve(outputs.size());
126   for (Value v : outputs) {
127     output_types.push_back(v.getType());
128   }
129 
130   // Create the function.
131   MLIRContext* context = ops[0]->getContext();
132   StringRef dialect = ops[0]->getName().getDialectNamespace();
133   OpBuilder builder(context);
134   builder.setInsertionPointToEnd(ops[0]->getParentOp()->getBlock());
135   auto func = builder.create<mlir::func::FuncOp>(
136       ops[0]->getLoc(), dialect.str() + std::to_string(function_id),
137       builder.getFunctionType(input_types, output_types));
138   func->setAttr("dialect", builder.getStringAttr(dialect));
139   auto block = func.addEntryBlock();
140 
141   llvm::DenseSet<Operation*> all_operations(ops.begin(), ops.end());
142   for (BlockArgument& arg : block->getArguments()) {
143     inputs[arg.getArgNumber()].replaceUsesWithIf(arg, [=](OpOperand& o) {
144       // Within the operations we're moving, we need to replace uses of
145       // values generated elsewhere.
146       return all_operations.contains(o.getOwner());
147     });
148   }
149 
150   // Insert function call.
151   builder.setInsertionPoint(ops[0]);
152   auto call = builder.create<mlir::func::CallOp>(
153       ops[0]->getLoc(), func.getFunctionType().getResults(), func.getSymName(),
154       inputs);
155   for (auto& v : llvm::enumerate(outputs)) {
156     v.value().replaceUsesWithIf(call.getResult(v.index()), [=](OpOperand& o) {
157       // Outside of what we're moving, results of our operations need to
158       // be replaced by results from the function call.
159       return !all_operations.contains(o.getOwner());
160     });
161   }
162 
163   // Move ops inside function & add return.
164   builder.setInsertionPointToEnd(block);
165   for (Operation* op : ops) {
166     op->remove();
167     builder.insert(op);
168   }
169   builder.create<mlir::func::ReturnOp>(ops[0]->getLoc(), outputs);
170 }
171 
172 }  // namespace
173 
CreateGroupByDialectPass()174 std::unique_ptr<Pass> CreateGroupByDialectPass() {
175   return std::make_unique<GroupByDialectPass>();
176 }
177 
RegisterGroupByDialectPass()178 void RegisterGroupByDialectPass() { registerPass(CreateGroupByDialectPass); }
179 
180 }  // namespace TF
181 }  // namespace mlir
182