1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/transforms/const_dedupe_hoist/pass.h"
17
18 #include <forward_list>
19 #include <memory>
20
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/Support/Casting.h"
25 #include "mlir/IR/Attributes.h" // from @llvm-project
26 #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
27 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
28 #include "mlir/IR/MLIRContext.h" // from @llvm-project
29 #include "mlir/IR/Visitors.h" // from @llvm-project
30 #include "mlir/Pass/Pass.h" // from @llvm-project
31 #include "tensorflow/core/ir/dialect.h"
32 #include "tensorflow/core/ir/ops.h"
33 #include "tensorflow/core/ir/utility.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/transforms/pass_detail.h"
36
37 namespace mlir {
38 namespace tfg {
39
40 namespace {
41
42 struct DedupeAndHoistConstantPass
43 : DedupeAndHoistConstantBase<DedupeAndHoistConstantPass> {
initializemlir::tfg::__anon8f90a3de0111::DedupeAndHoistConstantPass44 LogicalResult initialize(MLIRContext* context) override {
45 dtype_id = StringAttr::get(context, "dtype");
46 name_id = StringAttr::get(context, TFGraphDialect::getNameAttrKey());
47 t_id = StringAttr::get(context, "T");
48 tfg_const = StringAttr::get(context, "tfg.Const");
49 value_id = StringAttr::get(context, "value");
50 mlir_context = context;
51 return success();
52 }
53 void runOnOperation() override;
54
55 void RunOnGraphOrFuncOp(Operation* op);
56
57 // Propagate all control deps of op to its users.
58 void PropagateEdges(Operation* op);
59
60 // Returns whether identity op is required.
61 bool RequiresIdentity(Operation* op);
62
63 // Returns an identity op with same attributes and control deps as input and
64 // value as operand.
65 Operation* BuildIdentity(Operation* input, Operation* value);
66
67 FunctionTable* function_table;
68
69 // Identifiers used for operation type & attributes checked.
70 StringAttr dtype_id;
71 StringAttr name_id;
72 StringAttr tfg_const;
73 StringAttr t_id;
74 StringAttr value_id;
75 MLIRContext* mlir_context;
76 };
77
78 } // namespace
79
80 // Checking ConstOp's for equivalence skipping names.
81 struct EquivalentConst : public llvm::DenseMapInfo<Operation*> {
getHashValuemlir::tfg::EquivalentConst82 static unsigned getHashValue(const Operation* op_c) {
83 auto* op = const_cast<Operation*>(op_c);
84 auto hash = llvm::hash_value("");
85 // We know only TFG ConstOp will be here, so can query the name attribute
86 // from it.
87 StringAttr name_id =
88 cast<TFGraphDialect>(op->getDialect())->getNameAttrIdentifier();
89 for (auto attr : op->getAttrs()) {
90 // Skip name from hash.
91 if (attr.getName() == name_id) continue;
92 hash = llvm::hash_combine(hash, attr.getValue());
93 }
94 return hash;
95 }
96
isEqualmlir::tfg::EquivalentConst97 static bool isEqual(const Operation* lhs_c, const Operation* rhs_c) {
98 auto* lhs = const_cast<Operation*>(lhs_c);
99 auto* rhs = const_cast<Operation*>(rhs_c);
100 if (lhs == rhs) return true;
101 if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
102 rhs == getTombstoneKey() || rhs == getEmptyKey())
103 return false;
104 // Attributes are stored sorted by name.
105 StringAttr name_id =
106 cast<TFGraphDialect>(lhs->getDialect())->getNameAttrIdentifier();
107 for (auto it : llvm::zip(lhs->getAttrs(), rhs->getAttrs())) {
108 NamedAttribute lhs_attr = std::get<0>(it);
109 NamedAttribute rhs_attr = std::get<1>(it);
110 if (lhs_attr.getName() != rhs_attr.getName()) return false;
111 if (lhs_attr.getValue() != rhs_attr.getValue()) {
112 if (lhs_attr.getName() != name_id) return false;
113 }
114 }
115 return true;
116 }
117 };
118
PropagateEdges(Operation * op)119 void DedupeAndHoistConstantPass::PropagateEdges(Operation* op) {
120 SmallVector<Operation*> users(op->getUsers());
121 Value new_const = op->getResult(1);
122 // ConstOp's only have control operands, so any operand of the op is a
123 // control operand.
124 for (Operation* user : users) {
125 SetVector<Value> operands;
126 auto add_ctl_operands = [&](Operation* operation) {
127 // Filter out where there is a control edge already.
128 auto op_operands =
129 llvm::make_filter_range(TFOp(operation).getControlOperands(),
130 [&](Value v) { return v == new_const; });
131 operands.insert(op_operands.begin(), op_operands.end());
132 };
133 add_ctl_operands(user);
134 add_ctl_operands(op);
135 // Erase all control operands (effectively deduping control operands).
136 // TODO(jpienaar): This could be optimized by avoiding cases where we don't
137 // need to dedupe etc.
138 TFOp tf_user(user);
139 user->eraseOperands(tf_user.getNonControlOperands().size(),
140 tf_user.getControlOperands().size());
141 user->insertOperands(user->getNumOperands(), operands.takeVector());
142 }
143 }
144
RequiresIdentity(Operation * op)145 bool DedupeAndHoistConstantPass::RequiresIdentity(Operation* op) {
146 for (Operation* user : op->getUsers())
147 if (function_table->MayBeCall(user)) return true;
148 return false;
149 }
150
BuildIdentity(Operation * input,Operation * value)151 Operation* DedupeAndHoistConstantPass::BuildIdentity(Operation* input,
152 Operation* value) {
153 OperationState state(input->getLoc(), "tfg.Identity");
154 state.addTypes(input->getResultTypes());
155 state.addOperands({value->getResult(0)});
156
157 SetVector<Value> operands;
158 auto op_operands = TFOp(input).getControlOperands();
159 operands.insert(op_operands.begin(), op_operands.end());
160 state.addOperands(operands.takeVector());
161
162 // All attributes except for value, name, and dtype (which is remapped to I)
163 auto attrs = llvm::to_vector(
164 llvm::make_filter_range(input->getAttrs(), [&](NamedAttribute attr) {
165 return attr.getName() != value_id && attr.getName() != dtype_id &&
166 attr.getName() != name_id;
167 }));
168 state.addAttributes(attrs);
169
170 // Concat `const_dedupe_hoist` prefix with the const op name to avoid name
171 // collision.
172 // TODO(rdzhabarov): Improve name generation to avoid potential collisions.
173 if (auto const_name = input->getAttrOfType<StringAttr>(name_id)) {
174 state.addAttribute(
175 name_id, StringAttr::get(mlir_context, "const_dedupe_hoist/" +
176 const_name.getValue()));
177 }
178 // Map dtype to T attribute.
179 state.addAttribute(t_id, input->getAttr(dtype_id));
180 return OpBuilder(input).create(state);
181 }
182
RunOnGraphOrFuncOp(Operation * op)183 void DedupeAndHoistConstantPass::RunOnGraphOrFuncOp(Operation* op) {
184 DenseMap<Operation*, std::vector<Operation*>, EquivalentConst> constant_ops;
185
186 // Collect all small constant ops grouped by attributes.
187 op->walk([&](Operation* inner_op) {
188 if (inner_op->getName().getIdentifier() != tfg_const) return;
189
190 ElementsAttr val = inner_op->getAttr(value_id).cast<ElementsAttr>();
191 if (val.getNumElements() > max_size_) return;
192 constant_ops[inner_op].push_back(inner_op);
193 });
194
195 // Iterate over all constant ops and perform constant deduping.
196 for (const auto& it : constant_ops) {
197 if (it.second.size() > 1) {
198 Operation* top = OpBuilder(it.second.front()).clone(*it.second.front());
199 top->eraseOperands(0, top->getNumOperands());
200
201 for (auto jt : it.second) {
202 if (!assume_strict_calls_ && RequiresIdentity(jt)) {
203 // Create a new identity node with all the control deps of the node
204 // being replaced that forwards the value of top.
205 Operation* id = BuildIdentity(jt, top);
206 jt->replaceAllUsesWith(id);
207 } else {
208 // Just propagate control deps from the duplicated op to its users and
209 // then replace uses with top.
210 PropagateEdges(jt);
211 jt->replaceAllUsesWith(top);
212 }
213 jt->erase();
214 }
215 }
216 }
217 }
218
runOnOperation()219 void DedupeAndHoistConstantPass::runOnOperation() {
220 markAnalysesPreserved<FunctionTable>();
221
222 ModuleOp module = getOperation();
223 if (!assume_strict_calls_) {
224 function_table = &getAnalysis<FunctionTable>();
225 assume_strict_calls_ = function_table->empty();
226 }
227
228 for (auto& op : module.getOps())
229 // Only hoist inside Graph or GraphFunc ops.
230 if (isa<GraphFuncOp, GraphOp>(op)) RunOnGraphOrFuncOp(&op);
231 }
232
233 } // namespace tfg
234 } // namespace mlir
235
CreateDedupeAndHoistConstantPass()236 std::unique_ptr<mlir::Pass> mlir::tfg::CreateDedupeAndHoistConstantPass() {
237 return std::make_unique<mlir::tfg::DedupeAndHoistConstantPass>();
238 }
239