1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
17
18 #include <set>
19 #include <string>
20 #include <unordered_set>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "tensorflow/core/framework/attr_value_util.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/graph/tensor_id.h"
30 #include "tensorflow/core/grappler/graph_topology_view.h"
31 #include "tensorflow/core/grappler/grappler_item.h"
32 #include "tensorflow/core/grappler/op_types.h"
33 #include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
34 #include "tensorflow/core/grappler/utils.h"
35 #include "tensorflow/core/grappler/utils/canonicalizer.h"
36 #include "tensorflow/core/grappler/utils/topological_sort.h"
37 #include "tensorflow/core/grappler/utils/traversal.h"
38 #include "tensorflow/core/lib/gtl/flatset.h"
39 #include "tensorflow/core/platform/errors.h"
40 #include "tensorflow/core/platform/hash.h"
41 #include "tensorflow/core/platform/status.h"
42 #include "tensorflow/core/platform/strcat.h"
43 #include "tensorflow/core/platform/stringpiece.h"
44 #include "tensorflow/core/platform/types.h"
45
namespace tensorflow {
namespace grappler {
// Forward declaration only: Optimize() accepts a Cluster* but never
// dereferences it, so the full definition from cluster.h is not required.
class Cluster;
}  // namespace grappler
}  // namespace tensorflow
51
52 using tensorflow::strings::StrCat;
53
54 namespace tensorflow {
55 namespace grappler {
56
57 class UniqueNodes {
58 public:
59 // Warning: This is conservative and may fail to find an identical node in
60 // some cases. This happens if the node has large attribute tensor values that
61 // have different proto encoding but identical tensor value.
FindOrAddRepresentative(NodeDef * node)62 NodeDef* FindOrAddRepresentative(NodeDef* node) {
63 uint64 sig = ComputeSignature(*node);
64 std::vector<NodeDef*>& candidates = rep_[sig];
65 for (auto& candidate : candidates) {
66 if ((candidate == node) || SameNode(*candidate, *node)) {
67 return candidate;
68 }
69 }
70 candidates.push_back(node);
71 return node;
72 }
73
RemoveRepresentative(NodeDef * node)74 void RemoveRepresentative(NodeDef* node) {
75 auto it = memoized_signatures_.find(node);
76 if (it == memoized_signatures_.end()) return;
77
78 std::vector<NodeDef*>& candidates = rep_[it->second];
79 for (int i = 0, end = candidates.size(); i < end; ++i) {
80 if (candidates[i] == node) {
81 std::swap(candidates[i], candidates[candidates.size() - 1]);
82 candidates.resize(candidates.size() - 1);
83 break;
84 }
85 }
86 memoized_signatures_.erase(node);
87 }
88
89 private:
90 uint64 ComputeSignature(const NodeDef& node);
91 bool SameNode(const NodeDef& node1, const NodeDef& node2) const;
92
93 absl::flat_hash_map<uint64, std::vector<NodeDef*>> rep_;
94 absl::flat_hash_map<const NodeDef*, uint64> memoized_signatures_;
95 };
96
ComputeSignature(const NodeDef & node)97 uint64 UniqueNodes::ComputeSignature(const NodeDef& node) {
98 auto it = memoized_signatures_.find(&node);
99 if (it != memoized_signatures_.end()) return it->second;
100
101 uint64 h = Hash64(node.op());
102 h = Hash64Combine(Hash64(node.device()), h);
103
104 for (const auto& input : node.input()) {
105 const TensorId input_tensor = ParseTensorName(input);
106 uint64 input_hash = Hash64Combine(
107 Hash64(input_tensor.node().data(), input_tensor.node().size()),
108 std::hash<int>()(input_tensor.index()));
109 h = Hash64CombineUnordered(input_hash, h);
110 }
111 for (const auto& attr : node.attr()) {
112 uint64 attr_hash =
113 Hash64Combine(Hash64(attr.first), FastAttrValueHash(attr.second));
114 h = Hash64CombineUnordered(attr_hash, h);
115 }
116 memoized_signatures_.emplace(&node, h);
117 return h;
118 }
119
120 // PRECONDITION:
121 // Node input orders are assumed to be canonicalized, i.e. control inputs for
122 // all nodes as well as regular inputs for commutative nodes must be sorted.
SameNode(const NodeDef & node1,const NodeDef & node2) const123 bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
124 if (node1.op() != node2.op()) {
125 return false;
126 }
127 if (node1.device() != node2.device()) {
128 return false;
129 }
130 if (node1.input_size() != node2.input_size()) {
131 return false;
132 }
133 if (node1.attr_size() != node2.attr_size()) {
134 return false;
135 }
136
137 // Compare inputs.
138 auto it1 = node1.input().begin();
139 auto it2 = node2.input().begin();
140 for (; it1 != node1.input().end(); ++it1, ++it2) {
141 if (*it1 != *it2) return false;
142 }
143
144 // Compare attributes.
145 for (const auto& attr1 : node1.attr()) {
146 auto it = node2.attr().find(attr1.first);
147 if (it == node2.attr().end()) return false;
148 if (!AreAttrValuesEqual(attr1.second, it->second,
149 /*allow_false_negatives=*/true)) {
150 return false;
151 }
152 }
153
154 return true;
155 }
156
CanDedup(const NodeDef & node) const157 bool CommonSubgraphElimination::CanDedup(const NodeDef& node) const {
158 if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
159 return false;
160 }
161 if (IsEnter(node) || IsExit(node)) {
162 return false;
163 }
164 if (node.device().find("SPU") != string::npos) {
165 return false;
166 }
167 // Workaround for Assert and Print mistakenly being labeled as stateful.
168 if (IsAssert(node) || IsPrint(node)) {
169 return true;
170 }
171 return IsFreeOfSideEffect(node);
172 }
173
DedupComputations(GraphDef * optimized_graph)174 Status CommonSubgraphElimination::DedupComputations(GraphDef* optimized_graph) {
175 CanonicalizeGraph(optimized_graph);
176
177 GraphTopologyView graph_view;
178 if (!graph_view.InitializeFromGraph(*optimized_graph).ok()) {
179 LOG(WARNING) << "Failed to initialize GraphTopologyView.";
180 return OkStatus();
181 }
182
183 // If either node or rep feeds an inplace op, deduping them may cause data
184 // races. For example: If we dedup nodes initializing two independent
185 // inplace accumulations, they will write to the same buffer, clobbering
186 // each other's results.
187 absl::flat_hash_set<const NodeDef*> feeds_inplace_op;
188 for (int i = 0; i < optimized_graph->node_size(); ++i) {
189 const NodeDef& root = optimized_graph->node(i);
190 if (feeds_inplace_op.find(&root) != feeds_inplace_op.end()) continue;
191 if (ModifiesInputsInPlace(root)) {
192 const auto is_continue_traversal = [&](const NodeDef* node) -> bool {
193 return node->op() == root.op() || !NeverForwardsInputs(*node);
194 };
195
196 DfsTraversal(graph_view, {&root}, TraversalDirection::kFollowInputs,
197 DfsPredicates::Advance(is_continue_traversal),
198 DfsCallbacks::PreOrder([&](const NodeDef* node) {
199 feeds_inplace_op.insert(node);
200 }));
201 }
202 }
203
204 std::vector<bool> can_dedup(optimized_graph->node_size());
205 for (int i = 0; i < optimized_graph->node_size(); ++i) {
206 const NodeDef& node = optimized_graph->node(i);
207 can_dedup[i] = (feeds_inplace_op.find(&node) == feeds_inplace_op.end()) &&
208 CanDedup(node);
209 }
210
211 bool stop = true;
212 std::set<int> duplicates;
213 UniqueNodes nodes;
214 NodeMap node_map(optimized_graph);
215 do {
216 stop = true;
217 for (int i = 0; i < optimized_graph->node_size(); ++i) {
218 if (!can_dedup[i] || duplicates.find(i) != duplicates.end()) {
219 continue;
220 }
221 NodeDef* node = optimized_graph->mutable_node(i);
222 NodeDef* rep = nodes.FindOrAddRepresentative(node);
223 if (rep == node) {
224 continue;
225 }
226 // Make a copy since we mutate the set below.
227 const auto fanouts = node_map.GetOutputs(node->name());
228 for (NodeDef* fanout : fanouts) {
229 // Update consumers of node.
230 bool updated_fanout = false;
231 for (int i = 0; i < fanout->input_size(); ++i) {
232 string* fanout_input = fanout->mutable_input(i);
233
234 const int position =
235 NodePositionIfSameNode(*fanout_input, node->name());
236 // Update name in-place.
237 if (position < -1) {
238 continue;
239 } else {
240 if (!updated_fanout) {
241 // The signature of the fanout node will change. Remove it from
242 // nodes.
243 nodes.RemoveRepresentative(fanout);
244 }
245 updated_fanout = true;
246 if (position > 0) {
247 *fanout_input = StrCat(rep->name(), ":", position);
248 } else if (position == 0) {
249 *fanout_input = rep->name();
250 } else {
251 *fanout_input = StrCat("^", rep->name());
252 }
253 }
254 }
255 if (updated_fanout) {
256 node_map.UpdateInput(fanout->name(), node->name(), rep->name());
257 CanonicalizeNode(fanout);
258 }
259 }
260 if (fetch_nodes_known_) {
261 node->Clear();
262 }
263 duplicates.insert(i);
264 stop = false;
265 }
266 } while (!stop);
267
268 // Delete duplicates
269 if (fetch_nodes_known_ && !duplicates.empty()) {
270 EraseNodesFromGraph(duplicates, optimized_graph);
271 }
272
273 return OkStatus();
274 }
275
// Grappler entry point: copies item.graph into optimized_graph, topologically
// sorts it, then runs deduplication. The cluster argument is unused.
Status CommonSubgraphElimination::Optimize(Cluster* /*cluster*/,
                                           const GrapplerItem& item,
                                           GraphDef* optimized_graph) {
  // Set up helper data structures.
  nodes_to_preserve_ = item.NodesToPreserve();
  // With no explicit fetch nodes we cannot know which outputs matter, so
  // duplicates are rewired but never erased (see DedupComputations).
  fetch_nodes_known_ = !item.fetch.empty();
  *optimized_graph = item.graph;

  // Perform topological sort on the graph in order to help DedupComputations
  // optimize larger subgraphs starting from the roots with more inputs.
  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

  return DedupComputations(optimized_graph);
}
291
292 } // namespace grappler
293 } // namespace tensorflow
294