xref: /aosp_15_r20/external/tensorflow/tensorflow/core/transforms/const_dedupe_hoist/pass.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/transforms/const_dedupe_hoist/pass.h"
17 
18 #include <forward_list>
19 #include <memory>
20 
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/Support/Casting.h"
25 #include "mlir/IR/Attributes.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinAttributeInterfaces.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
28 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
29 #include "mlir/IR/Visitors.h"  // from @llvm-project
30 #include "mlir/Pass/Pass.h"  // from @llvm-project
31 #include "tensorflow/core/ir/dialect.h"
32 #include "tensorflow/core/ir/ops.h"
33 #include "tensorflow/core/ir/utility.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/transforms/pass_detail.h"
36 
37 namespace mlir {
38 namespace tfg {
39 
40 namespace {
41 
42 struct DedupeAndHoistConstantPass
43     : DedupeAndHoistConstantBase<DedupeAndHoistConstantPass> {
initializemlir::tfg::__anon8f90a3de0111::DedupeAndHoistConstantPass44   LogicalResult initialize(MLIRContext* context) override {
45     dtype_id = StringAttr::get(context, "dtype");
46     name_id = StringAttr::get(context, TFGraphDialect::getNameAttrKey());
47     t_id = StringAttr::get(context, "T");
48     tfg_const = StringAttr::get(context, "tfg.Const");
49     value_id = StringAttr::get(context, "value");
50     mlir_context = context;
51     return success();
52   }
53   void runOnOperation() override;
54 
55   void RunOnGraphOrFuncOp(Operation* op);
56 
57   // Propagate all control deps of op to its users.
58   void PropagateEdges(Operation* op);
59 
60   // Returns whether identity op is required.
61   bool RequiresIdentity(Operation* op);
62 
63   // Returns an identity op with same attributes and control deps as input and
64   // value as operand.
65   Operation* BuildIdentity(Operation* input, Operation* value);
66 
67   FunctionTable* function_table;
68 
69   // Identifiers used for operation type & attributes checked.
70   StringAttr dtype_id;
71   StringAttr name_id;
72   StringAttr tfg_const;
73   StringAttr t_id;
74   StringAttr value_id;
75   MLIRContext* mlir_context;
76 };
77 
78 }  // namespace
79 
80 // Checking ConstOp's for equivalence skipping names.
81 struct EquivalentConst : public llvm::DenseMapInfo<Operation*> {
getHashValuemlir::tfg::EquivalentConst82   static unsigned getHashValue(const Operation* op_c) {
83     auto* op = const_cast<Operation*>(op_c);
84     auto hash = llvm::hash_value("");
85     // We know only TFG ConstOp will be here, so can query the name attribute
86     // from it.
87     StringAttr name_id =
88         cast<TFGraphDialect>(op->getDialect())->getNameAttrIdentifier();
89     for (auto attr : op->getAttrs()) {
90       // Skip name from hash.
91       if (attr.getName() == name_id) continue;
92       hash = llvm::hash_combine(hash, attr.getValue());
93     }
94     return hash;
95   }
96 
isEqualmlir::tfg::EquivalentConst97   static bool isEqual(const Operation* lhs_c, const Operation* rhs_c) {
98     auto* lhs = const_cast<Operation*>(lhs_c);
99     auto* rhs = const_cast<Operation*>(rhs_c);
100     if (lhs == rhs) return true;
101     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
102         rhs == getTombstoneKey() || rhs == getEmptyKey())
103       return false;
104     // Attributes are stored sorted by name.
105     StringAttr name_id =
106         cast<TFGraphDialect>(lhs->getDialect())->getNameAttrIdentifier();
107     for (auto it : llvm::zip(lhs->getAttrs(), rhs->getAttrs())) {
108       NamedAttribute lhs_attr = std::get<0>(it);
109       NamedAttribute rhs_attr = std::get<1>(it);
110       if (lhs_attr.getName() != rhs_attr.getName()) return false;
111       if (lhs_attr.getValue() != rhs_attr.getValue()) {
112         if (lhs_attr.getName() != name_id) return false;
113       }
114     }
115     return true;
116   }
117 };
118 
PropagateEdges(Operation * op)119 void DedupeAndHoistConstantPass::PropagateEdges(Operation* op) {
120   SmallVector<Operation*> users(op->getUsers());
121   Value new_const = op->getResult(1);
122   // ConstOp's only have control operands, so any operand of the op is a
123   // control operand.
124   for (Operation* user : users) {
125     SetVector<Value> operands;
126     auto add_ctl_operands = [&](Operation* operation) {
127       // Filter out where there is a control edge already.
128       auto op_operands =
129           llvm::make_filter_range(TFOp(operation).getControlOperands(),
130                                   [&](Value v) { return v == new_const; });
131       operands.insert(op_operands.begin(), op_operands.end());
132     };
133     add_ctl_operands(user);
134     add_ctl_operands(op);
135     // Erase all control operands (effectively deduping control operands).
136     // TODO(jpienaar): This could be optimized by avoiding cases where we don't
137     // need to dedupe etc.
138     TFOp tf_user(user);
139     user->eraseOperands(tf_user.getNonControlOperands().size(),
140                         tf_user.getControlOperands().size());
141     user->insertOperands(user->getNumOperands(), operands.takeVector());
142   }
143 }
144 
RequiresIdentity(Operation * op)145 bool DedupeAndHoistConstantPass::RequiresIdentity(Operation* op) {
146   for (Operation* user : op->getUsers())
147     if (function_table->MayBeCall(user)) return true;
148   return false;
149 }
150 
BuildIdentity(Operation * input,Operation * value)151 Operation* DedupeAndHoistConstantPass::BuildIdentity(Operation* input,
152                                                      Operation* value) {
153   OperationState state(input->getLoc(), "tfg.Identity");
154   state.addTypes(input->getResultTypes());
155   state.addOperands({value->getResult(0)});
156 
157   SetVector<Value> operands;
158   auto op_operands = TFOp(input).getControlOperands();
159   operands.insert(op_operands.begin(), op_operands.end());
160   state.addOperands(operands.takeVector());
161 
162   // All attributes except for value, name, and dtype (which is remapped to I)
163   auto attrs = llvm::to_vector(
164       llvm::make_filter_range(input->getAttrs(), [&](NamedAttribute attr) {
165         return attr.getName() != value_id && attr.getName() != dtype_id &&
166                attr.getName() != name_id;
167       }));
168   state.addAttributes(attrs);
169 
170   // Concat `const_dedupe_hoist` prefix with the const op name to avoid name
171   // collision.
172   // TODO(rdzhabarov): Improve name generation to avoid potential collisions.
173   if (auto const_name = input->getAttrOfType<StringAttr>(name_id)) {
174     state.addAttribute(
175         name_id, StringAttr::get(mlir_context, "const_dedupe_hoist/" +
176                                                    const_name.getValue()));
177   }
178   // Map dtype to T attribute.
179   state.addAttribute(t_id, input->getAttr(dtype_id));
180   return OpBuilder(input).create(state);
181 }
182 
RunOnGraphOrFuncOp(Operation * op)183 void DedupeAndHoistConstantPass::RunOnGraphOrFuncOp(Operation* op) {
184   DenseMap<Operation*, std::vector<Operation*>, EquivalentConst> constant_ops;
185 
186   // Collect all small constant ops grouped by attributes.
187   op->walk([&](Operation* inner_op) {
188     if (inner_op->getName().getIdentifier() != tfg_const) return;
189 
190     ElementsAttr val = inner_op->getAttr(value_id).cast<ElementsAttr>();
191     if (val.getNumElements() > max_size_) return;
192     constant_ops[inner_op].push_back(inner_op);
193   });
194 
195   // Iterate over all constant ops and perform constant deduping.
196   for (const auto& it : constant_ops) {
197     if (it.second.size() > 1) {
198       Operation* top = OpBuilder(it.second.front()).clone(*it.second.front());
199       top->eraseOperands(0, top->getNumOperands());
200 
201       for (auto jt : it.second) {
202         if (!assume_strict_calls_ && RequiresIdentity(jt)) {
203           // Create a new identity node with all the control deps of the node
204           // being replaced that forwards the value of top.
205           Operation* id = BuildIdentity(jt, top);
206           jt->replaceAllUsesWith(id);
207         } else {
208           // Just propagate control deps from the duplicated op to its users and
209           // then replace uses with top.
210           PropagateEdges(jt);
211           jt->replaceAllUsesWith(top);
212         }
213         jt->erase();
214       }
215     }
216   }
217 }
218 
runOnOperation()219 void DedupeAndHoistConstantPass::runOnOperation() {
220   markAnalysesPreserved<FunctionTable>();
221 
222   ModuleOp module = getOperation();
223   if (!assume_strict_calls_) {
224     function_table = &getAnalysis<FunctionTable>();
225     assume_strict_calls_ = function_table->empty();
226   }
227 
228   for (auto& op : module.getOps())
229     // Only hoist inside Graph or GraphFunc ops.
230     if (isa<GraphFuncOp, GraphOp>(op)) RunOnGraphOrFuncOp(&op);
231 }
232 
233 }  // namespace tfg
234 }  // namespace mlir
235 
CreateDedupeAndHoistConstantPass()236 std::unique_ptr<mlir::Pass> mlir::tfg::CreateDedupeAndHoistConstantPass() {
237   return std::make_unique<mlir::tfg::DedupeAndHoistConstantPass>();
238 }
239