xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/SmallVector.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
18 #include "mlir/IR/Attributes.h"  // from @llvm-project
19 #include "mlir/IR/Block.h"  // from @llvm-project
20 #include "mlir/IR/Builders.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
22 #include "mlir/IR/Operation.h"  // from @llvm-project
23 #include "mlir/Pass/Pass.h"  // from @llvm-project
24 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
25 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
29 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
30 
31 namespace mlir {
32 namespace TFDevice {
33 
34 namespace {
35 
// Name of the attribute that holds a symbol reference to the outlined
// function. It is set on the op being outlined, and from there copied onto the
// replacement `tf_device.cluster_func` / `tf_device.launch_func` op.
constexpr char kFuncAttr[]{"func"};
37 
38 struct ClusterOutliningPass
39     : public TF::ClusterOutliningPassBase<ClusterOutliningPass> {
40   void runOnOperation() override;
41 };
42 
43 struct LaunchOutliningPass
44     : public TF::LaunchOutliningPassBase<LaunchOutliningPass> {
45   void runOnOperation() override;
46 };
47 
ReplaceClusterReturnWithReturn(tf_device::ReturnOp cluster_return_op,OpBuilder * builder)48 void ReplaceClusterReturnWithReturn(tf_device::ReturnOp cluster_return_op,
49                                     OpBuilder* builder) {
50   builder->create<func::ReturnOp>(cluster_return_op.getLoc(),
51                                   cluster_return_op.getOperands());
52   cluster_return_op.erase();
53 }
54 
55 // Builds a function that outlines region attached to cluster_op or launch_op,
56 // and inserts built function into given module.
57 template <typename ClusterOrLaunchOp>
BuildFunction(llvm::ArrayRef<Value> live_ins,ClusterOrLaunchOp op,SymbolTable * symbol_table,OpBuilder * builder)58 func::FuncOp BuildFunction(llvm::ArrayRef<Value> live_ins, ClusterOrLaunchOp op,
59                            SymbolTable* symbol_table, OpBuilder* builder) {
60   llvm::SmallVector<Type, 4> operand_types;
61   operand_types.reserve(live_ins.size());
62   for (Value v : live_ins) operand_types.emplace_back(v.getType());
63 
64   auto func_type = builder->getFunctionType(operand_types, op.getResultTypes());
65 
66   // TODO(lyandy): Define better name for outlined function. Potentially some
67   // name can be added during cluster formation.
68   func::FuncOp outlined_func =
69       func::FuncOp::create(op.getLoc(), "_func", func_type);
70 
71   // This function is not externally visible and marking it private would allow
72   // symbol-dce pass to remove it when it is not referenced anymore.
73   outlined_func.setPrivate();
74 
75   // Create function body.
76   Block* outlined_func_block = outlined_func.addEntryBlock();
77 
78   // Replace uses of live-in values within cluster_op region with function
79   // arguments.
80   Region& op_region = op.body();
81   for (auto p : llvm::zip(live_ins, outlined_func_block->getArguments())) {
82     replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op_region);
83   }
84 
85   // Move all instructions in cluster_op into outlined_function's only block.
86   auto& op_body = op.GetBody().getOperations();
87   outlined_func_block->getOperations().splice(
88       outlined_func_block->end(), op_body, op_body.begin(), op_body.end());
89 
90   // Replace `tf_device.return` terminator with `std.return` in function
91   // body.
92   auto return_op =
93       cast<tf_device::ReturnOp>(outlined_func_block->getTerminator());
94   builder->setInsertionPoint(return_op);
95   ReplaceClusterReturnWithReturn(return_op, builder);
96 
97   symbol_table->insert(outlined_func);
98   return outlined_func;
99 }
100 
101 // Outlines body of `tf_device.cluster` into a function and create a
102 // `tf_device.cluster_func` to invoke that function. `tf_device.cluster` is
103 // removed afterwards.`
OutlineCluster(tf_device::ClusterOp cluster_op,SymbolTable * symbol_table,OpBuilder * builder)104 void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table,
105                     OpBuilder* builder) {
106   llvm::SetVector<Value> live_ins;
107   getUsedValuesDefinedAbove(cluster_op.body(), cluster_op.body(), live_ins);
108 
109   func::FuncOp outlined_func =
110       BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder);
111   cluster_op->setAttr(
112       builder->getStringAttr(kFuncAttr),
113       mlir::SymbolRefAttr::get(builder->getContext(), outlined_func.getName()));
114 
115   builder->setInsertionPoint(cluster_op);
116   auto cluster_func_op = builder->create<tf_device::ClusterFuncOp>(
117       cluster_op.getLoc(), outlined_func.getFunctionType().getResults(),
118       live_ins.getArrayRef(), cluster_op->getAttrs());
119 
120   cluster_op.replaceAllUsesWith(cluster_func_op);
121   cluster_op.erase();
122 }
123 
124 // Outlines body of `tf_device.launch` into a function and create a
125 // `tf_device.launch_func` to invoke that function. `tf_device.launch` is
126 // removed afterwards.`
OutlineLaunch(tf_device::LaunchOp launch_op,SymbolTable * symbol_table,OpBuilder * builder)127 void OutlineLaunch(tf_device::LaunchOp launch_op, SymbolTable* symbol_table,
128                    OpBuilder* builder) {
129   llvm::SetVector<Value> live_ins;
130   getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins);
131 
132   func::FuncOp outlined_func =
133       BuildFunction(live_ins.getArrayRef(), launch_op, symbol_table, builder);
134   launch_op->setAttr(
135       builder->getStringAttr(kFuncAttr),
136       mlir::SymbolRefAttr::get(builder->getContext(), outlined_func.getName()));
137 
138   builder->setInsertionPoint(launch_op);
139   auto cluster_func_op = builder->create<tf_device::LaunchFuncOp>(
140       launch_op.getLoc(), outlined_func.getFunctionType().getResults(),
141       live_ins.getArrayRef(), launch_op->getAttrs());
142 
143   launch_op.replaceAllUsesWith(cluster_func_op);
144   launch_op.erase();
145 }
146 
runOnOperation()147 void ClusterOutliningPass::runOnOperation() {
148   ModuleOp module = getOperation();
149   SymbolTable symbol_table(module);
150   OpBuilder builder(module.getContext());
151   module.walk([&](tf_device::ClusterOp cluster) {
152     OutlineCluster(cluster, &symbol_table, &builder);
153   });
154 }
155 
runOnOperation()156 void LaunchOutliningPass::runOnOperation() {
157   ModuleOp module = getOperation();
158   SymbolTable symbol_table(module);
159   OpBuilder builder(module.getContext());
160   module.walk([&](tf_device::LaunchOp launch) {
161     OutlineLaunch(launch, &symbol_table, &builder);
162   });
163 }
164 
165 }  // namespace
166 
CreateClusterOutliningPass()167 std::unique_ptr<OperationPass<ModuleOp>> CreateClusterOutliningPass() {
168   return std::make_unique<ClusterOutliningPass>();
169 }
170 
CreateLaunchOutliningPass()171 std::unique_ptr<OperationPass<ModuleOp>> CreateLaunchOutliningPass() {
172   return std::make_unique<LaunchOutliningPass>();
173 }
174 
175 }  // namespace TFDevice
176 }  // namespace mlir
177