xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
17 
18 #include <set>
19 #include <string>
20 #include <unordered_set>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "tensorflow/core/framework/attr_value_util.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/graph/tensor_id.h"
30 #include "tensorflow/core/grappler/graph_topology_view.h"
31 #include "tensorflow/core/grappler/grappler_item.h"
32 #include "tensorflow/core/grappler/op_types.h"
33 #include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
34 #include "tensorflow/core/grappler/utils.h"
35 #include "tensorflow/core/grappler/utils/canonicalizer.h"
36 #include "tensorflow/core/grappler/utils/topological_sort.h"
37 #include "tensorflow/core/grappler/utils/traversal.h"
38 #include "tensorflow/core/lib/gtl/flatset.h"
39 #include "tensorflow/core/platform/errors.h"
40 #include "tensorflow/core/platform/hash.h"
41 #include "tensorflow/core/platform/status.h"
42 #include "tensorflow/core/platform/strcat.h"
43 #include "tensorflow/core/platform/stringpiece.h"
44 #include "tensorflow/core/platform/types.h"
45 
46 namespace tensorflow {
47 namespace grappler {
48 class Cluster;
49 }  // namespace grappler
50 }  // namespace tensorflow
51 
52 using tensorflow::strings::StrCat;
53 
54 namespace tensorflow {
55 namespace grappler {
56 
57 class UniqueNodes {
58  public:
59   // Warning: This is conservative and may fail to find an identical node in
60   // some cases. This happens if the node has large attribute tensor values that
61   // have different proto encoding but identical tensor value.
FindOrAddRepresentative(NodeDef * node)62   NodeDef* FindOrAddRepresentative(NodeDef* node) {
63     uint64 sig = ComputeSignature(*node);
64     std::vector<NodeDef*>& candidates = rep_[sig];
65     for (auto& candidate : candidates) {
66       if ((candidate == node) || SameNode(*candidate, *node)) {
67         return candidate;
68       }
69     }
70     candidates.push_back(node);
71     return node;
72   }
73 
RemoveRepresentative(NodeDef * node)74   void RemoveRepresentative(NodeDef* node) {
75     auto it = memoized_signatures_.find(node);
76     if (it == memoized_signatures_.end()) return;
77 
78     std::vector<NodeDef*>& candidates = rep_[it->second];
79     for (int i = 0, end = candidates.size(); i < end; ++i) {
80       if (candidates[i] == node) {
81         std::swap(candidates[i], candidates[candidates.size() - 1]);
82         candidates.resize(candidates.size() - 1);
83         break;
84       }
85     }
86     memoized_signatures_.erase(node);
87   }
88 
89  private:
90   uint64 ComputeSignature(const NodeDef& node);
91   bool SameNode(const NodeDef& node1, const NodeDef& node2) const;
92 
93   absl::flat_hash_map<uint64, std::vector<NodeDef*>> rep_;
94   absl::flat_hash_map<const NodeDef*, uint64> memoized_signatures_;
95 };
96 
ComputeSignature(const NodeDef & node)97 uint64 UniqueNodes::ComputeSignature(const NodeDef& node) {
98   auto it = memoized_signatures_.find(&node);
99   if (it != memoized_signatures_.end()) return it->second;
100 
101   uint64 h = Hash64(node.op());
102   h = Hash64Combine(Hash64(node.device()), h);
103 
104   for (const auto& input : node.input()) {
105     const TensorId input_tensor = ParseTensorName(input);
106     uint64 input_hash = Hash64Combine(
107         Hash64(input_tensor.node().data(), input_tensor.node().size()),
108         std::hash<int>()(input_tensor.index()));
109     h = Hash64CombineUnordered(input_hash, h);
110   }
111   for (const auto& attr : node.attr()) {
112     uint64 attr_hash =
113         Hash64Combine(Hash64(attr.first), FastAttrValueHash(attr.second));
114     h = Hash64CombineUnordered(attr_hash, h);
115   }
116   memoized_signatures_.emplace(&node, h);
117   return h;
118 }
119 
120 // PRECONDITION:
121 //  Node input orders are assumed to be canonicalized, i.e. control inputs for
122 //  all nodes as well as regular inputs for commutative nodes must be sorted.
SameNode(const NodeDef & node1,const NodeDef & node2) const123 bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
124   if (node1.op() != node2.op()) {
125     return false;
126   }
127   if (node1.device() != node2.device()) {
128     return false;
129   }
130   if (node1.input_size() != node2.input_size()) {
131     return false;
132   }
133   if (node1.attr_size() != node2.attr_size()) {
134     return false;
135   }
136 
137   // Compare inputs.
138   auto it1 = node1.input().begin();
139   auto it2 = node2.input().begin();
140   for (; it1 != node1.input().end(); ++it1, ++it2) {
141     if (*it1 != *it2) return false;
142   }
143 
144   // Compare attributes.
145   for (const auto& attr1 : node1.attr()) {
146     auto it = node2.attr().find(attr1.first);
147     if (it == node2.attr().end()) return false;
148     if (!AreAttrValuesEqual(attr1.second, it->second,
149                             /*allow_false_negatives=*/true)) {
150       return false;
151     }
152   }
153 
154   return true;
155 }
156 
// Decides whether `node` is eligible for deduplication at all.
bool CommonSubgraphElimination::CanDedup(const NodeDef& node) const {
  // Nodes the caller asked to preserve must keep their identity.
  if (nodes_to_preserve_.count(node.name()) > 0) return false;
  // Frame entry/exit nodes carry control-flow identity.
  if (IsEnter(node) || IsExit(node)) return false;
  // Skip nodes whose device string contains "SPU".
  if (node.device().find("SPU") != string::npos) return false;
  // Workaround for Assert and Print mistakenly being labeled as stateful.
  if (IsAssert(node) || IsPrint(node)) return true;
  return IsFreeOfSideEffect(node);
}
173 
// Collapses duplicate computations: repeatedly picks a representative for each
// equivalence class of identical nodes, rewires all consumers onto it, and
// (when fetch nodes are known) deletes the now-dead duplicates.
Status CommonSubgraphElimination::DedupComputations(GraphDef* optimized_graph) {
  // Canonicalize node input order first so that structurally identical nodes
  // hash and compare equal in UniqueNodes (see SameNode's precondition).
  CanonicalizeGraph(optimized_graph);

  GraphTopologyView graph_view;
  if (!graph_view.InitializeFromGraph(*optimized_graph).ok()) {
    // Best-effort optimizer: bail out gracefully rather than fail the pass.
    LOG(WARNING) << "Failed to initialize GraphTopologyView.";
    return OkStatus();
  }

  // If either node or rep feeds an inplace op, deduping them may cause data
  // races. For example: If we dedup nodes initializing two independent
  // inplace accumulations, they will write to the same buffer, clobbering
  // each other's results.
  absl::flat_hash_set<const NodeDef*> feeds_inplace_op;
  for (int i = 0; i < optimized_graph->node_size(); ++i) {
    const NodeDef& root = optimized_graph->node(i);
    // Already discovered via an earlier traversal; skip re-walking.
    if (feeds_inplace_op.find(&root) != feeds_inplace_op.end()) continue;
    if (ModifiesInputsInPlace(root)) {
      // Walk upstream from the inplace op, continuing through nodes that may
      // forward their input buffers (or same-op chains), and mark everything
      // reached as unsafe to dedup.
      const auto is_continue_traversal = [&](const NodeDef* node) -> bool {
        return node->op() == root.op() || !NeverForwardsInputs(*node);
      };

      DfsTraversal(graph_view, {&root}, TraversalDirection::kFollowInputs,
                   DfsPredicates::Advance(is_continue_traversal),
                   DfsCallbacks::PreOrder([&](const NodeDef* node) {
                     feeds_inplace_op.insert(node);
                   }));
    }
  }

  // Precompute, per node index, whether deduplication is allowed.
  std::vector<bool> can_dedup(optimized_graph->node_size());
  for (int i = 0; i < optimized_graph->node_size(); ++i) {
    const NodeDef& node = optimized_graph->node(i);
    can_dedup[i] = (feeds_inplace_op.find(&node) == feeds_inplace_op.end()) &&
                   CanDedup(node);
  }

  bool stop = true;
  std::set<int> duplicates;
  UniqueNodes nodes;
  NodeMap node_map(optimized_graph);
  // Iterate to a fixed point: rewiring consumers onto a representative can
  // make previously distinct consumers become identical to each other.
  do {
    stop = true;
    for (int i = 0; i < optimized_graph->node_size(); ++i) {
      if (!can_dedup[i] || duplicates.find(i) != duplicates.end()) {
        continue;
      }
      NodeDef* node = optimized_graph->mutable_node(i);
      NodeDef* rep = nodes.FindOrAddRepresentative(node);
      if (rep == node) {
        // This node IS the representative of its class; nothing to rewire.
        continue;
      }
      // Make a copy since we mutate the set below.
      const auto fanouts = node_map.GetOutputs(node->name());
      for (NodeDef* fanout : fanouts) {
        // Update consumers of node.
        bool updated_fanout = false;
        // NOTE: this inner `i` intentionally shadows the outer node index.
        for (int i = 0; i < fanout->input_size(); ++i) {
          string* fanout_input = fanout->mutable_input(i);

          // position semantics (as consumed below): < -1 means the input does
          // not refer to `node`; -1 means control input ("^name"); 0 means
          // plain "name"; > 0 means "name:position".
          const int position =
              NodePositionIfSameNode(*fanout_input, node->name());
          // Update name in-place.
          if (position < -1) {
            continue;
          } else {
            if (!updated_fanout) {
              // The signature of the fanout node will change. Remove it from
              // nodes.
              nodes.RemoveRepresentative(fanout);
            }
            updated_fanout = true;
            if (position > 0) {
              *fanout_input = StrCat(rep->name(), ":", position);
            } else if (position == 0) {
              *fanout_input = rep->name();
            } else {
              *fanout_input = StrCat("^", rep->name());
            }
          }
        }
        if (updated_fanout) {
          // Keep the fanout's NodeMap entry and canonical input order in sync
          // with the in-place edits above.
          node_map.UpdateInput(fanout->name(), node->name(), rep->name());
          CanonicalizeNode(fanout);
        }
      }
      if (fetch_nodes_known_) {
        // Safe to gut the duplicate now; it is erased from the graph below.
        node->Clear();
      }
      duplicates.insert(i);
      stop = false;
    }
  } while (!stop);

  // Delete duplicates. Only safe when fetch nodes are known; otherwise a
  // "duplicate" might still be fetched by name, so it is kept (but bypassed).
  if (fetch_nodes_known_ && !duplicates.empty()) {
    EraseNodesFromGraph(duplicates, optimized_graph);
  }

  return OkStatus();
}
275 
Optimize(Cluster *,const GrapplerItem & item,GraphDef * optimized_graph)276 Status CommonSubgraphElimination::Optimize(Cluster* /*cluster*/,
277                                            const GrapplerItem& item,
278                                            GraphDef* optimized_graph) {
279   // Set up helper data structures.
280   nodes_to_preserve_ = item.NodesToPreserve();
281   fetch_nodes_known_ = !item.fetch.empty();
282   *optimized_graph = item.graph;
283 
284   // Perform topological sort on the graph in order to help DedupComputations
285   // optimize larger subgraphs starting from the roots with more inputs.
286   TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
287   GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
288 
289   return DedupComputations(optimized_graph);
290 }
291 
292 }  // namespace grappler
293 }  // namespace tensorflow
294