/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/graph_rewrite/variable_merger_pass.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

// The name of a stateful op is semantically meaningful because ops with the
// same name will share the same kernel. We therefore form new op names using a
// deterministic function (a fingerprint) of the old names.
uint64 MergedOpFingerprint(absl::Span<Node* const> ops) {
  std::vector<string> op_names;
  op_names.reserve(ops.size());
  for (const Node* node : ops) {
    op_names.push_back(node->name());
  }
  return Fingerprint64(absl::StrJoin(op_names, ","));
}

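// Merges the VarHandleOps in `nodes` (all assigned to `device`) into a single
// _VarHandlesOp node whose i-th output replaces the output of the i-th
// original handle op.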
Status MergeVarHandleOps(const string& device, absl::Span<Node* const> nodes,
                         Graph* graph) {
  int num_var_handles(nodes.size());
  if (num_var_handles <= 1) return OkStatus();

  std::vector<string> containers(num_var_handles);
  std::vector<string> names(num_var_handles);
  DataTypeVector dtypes(num_var_handles);
  std::vector<PartialTensorShape> shapes(num_var_handles);
  for (int i = 0; i < num_var_handles; ++i) {
    TF_RETURN_IF_ERROR(
        GetNodeAttr(nodes[i]->attrs(), "container", &containers[i]));
    TF_RETURN_IF_ERROR(
        GetNodeAttr(nodes[i]->attrs(), "shared_name", &names[i]));
    TF_RETURN_IF_ERROR(GetNodeAttr(nodes[i]->attrs(), "dtype", &dtypes[i]));
    TF_RETURN_IF_ERROR(GetNodeAttr(nodes[i]->attrs(), "shape", &shapes[i]));
  }
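  // The merged op carries the per-variable attributes as parallel lists.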
  NodeDefBuilder builder(graph->NewName(strings::StrCat(
                             "VarHandles_", MergedOpFingerprint(nodes))),
                         "_VarHandlesOp");
  builder.Attr("N", num_var_handles);
  builder.Attr("containers", containers);
  builder.Attr("shared_names", names);
  builder.Attr("dtypes", dtypes);
  builder.Attr("shapes", shapes);
  builder.Device(device);
  NodeDef node_def;
  TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
  TF_ASSIGN_OR_RETURN(Node * node, graph->AddNode(node_def));
  node->set_assigned_device_name(device);

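  // The original VarHandleOps were graph roots; keep the merged op rooted at
  // the source node as well.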
  graph->AddControlEdge(graph->source_node(), node);
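  // Rewire the consumers of each original handle op to the corresponding
  // output of the merged op, then delete the original node.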
  for (int i = 0; i < num_var_handles; ++i) {
    std::vector<std::pair<Node*, int>> consumers;
    for (const Edge* e : nodes[i]->out_edges()) {
      consumers.emplace_back(e->dst(), e->dst_input());
    }
    graph->RemoveNode(nodes[i]);
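    // A negative dst_input marks a control edge; recreate it as a control
    // edge from the merged op. Data edges move to output i.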
    for (const auto& t : consumers) {
      graph->AddEdge(node, t.second < 0 ? -1 : i, t.first, t.second);
    }
  }
  return OkStatus();
}

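// Merges the ReadVariableOps in `nodes` into a single _ReadVariablesOp. All of
// the reads consume handles produced by `handle_op` and share the same control
// dependency `control_node`, which may be null.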
Status MergeReadVariableOps(Node* handle_op, Node* control_node,
                            absl::Span<Node* const> nodes, Graph* graph) {
  int num_reads(nodes.size());
  if (num_reads <= 1) return OkStatus();

  DataTypeVector dtypes(num_reads);
  for (int i = 0; i < num_reads; ++i) {
    TF_RETURN_IF_ERROR(GetNodeAttr(nodes[i]->attrs(), "dtype", &dtypes[i]));
  }
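  // Build the merged _ReadVariablesOp on the same device as the handle op.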
  NodeDef node_def;
  node_def.set_name(graph->NewName(
      strings::StrCat("ReadVariables_", MergedOpFingerprint(nodes))));
  node_def.set_op("_ReadVariablesOp");
  AddNodeAttr("N", num_reads, &node_def);
  AddNodeAttr("dtypes", dtypes, &node_def);
  node_def.set_device(handle_op->requested_device());
  TF_ASSIGN_OR_RETURN(Node * node, graph->AddNode(node_def));
  node->set_assigned_device_name(handle_op->assigned_device_name());
  if (control_node) graph->AddControlEdge(control_node, node);
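  // For the i-th read: feed the variable handle it was reading into input i of
  // the merged op, move its consumers to output i, and delete the original.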
  for (int i = 0; i < num_reads; ++i) {
    const Edge* handle_edge;
    TF_RETURN_IF_ERROR(nodes[i]->input_edge(0, &handle_edge));
    graph->AddEdge(handle_edge->src(), handle_edge->src_output(), node, i);

    std::vector<std::pair<Node*, int>> consumers;
    for (const Edge* e : nodes[i]->out_edges()) {
      consumers.emplace_back(e->dst(), e->dst_input());
    }
    graph->RemoveNode(nodes[i]);
    for (const auto& t : consumers) {
      graph->AddEdge(node, t.second < 0 ? -1 : i, t.first, t.second);
    }
  }
  return OkStatus();
}

}  // namespace

Status VariableMergerPass::Run(const GraphOptimizationPassOptions& options) {
  Graph* graph = options.graph->get();

  VLOG(1) << DumpGraphToFile("variable_merger_pass_before", *graph);

  // Find VarHandleOps that are graph roots and group them by assigned device.
  // Also find any ReadVariableOps that are consumers of those handles.
  absl::flat_hash_map<string, std::vector<Node*>> var_handle_ops_by_device;
  absl::flat_hash_set<Node*> read_variable_ops;

  for (Node* m : graph->source_node()->out_nodes()) {
    // We check that the VarHandleOp has no control edges other than the one we
    // followed from the source node.
    if (m->type_string() == "VarHandleOp" && m->in_edges().size() == 1) {
      var_handle_ops_by_device[m->assigned_device_name()].push_back(m);
      for (Node* n : m->out_nodes()) {
        // A ReadVariableOp may have control edges; we group the reads by their
        // merged VarHandleOp and control dependency.
        if (n->type_string() == "ReadVariableOp" && n->in_edges().size() <= 2) {
          read_variable_ops.insert(n);
        }
      }
    }
  }

  auto node_name_comparator = [](Node* a, Node* b) {
    return a->name() < b->name();
  };

  // First merge the var handle ops.
  for (auto& vh : var_handle_ops_by_device) {
    // Sort the handles by name for determinism.
    std::sort(vh.second.begin(), vh.second.end(), node_name_comparator);
    TF_RETURN_IF_ERROR(MergeVarHandleOps(vh.first, vh.second, graph));
  }

  // Group ReadVariableOps by the pair <VarHandleOp, ControlDependencyNode>.
  // ControlDependencyNode may be nullptr.
  absl::flat_hash_map<std::pair<Node*, Node*>, std::vector<Node*>> read_var_ops;

  for (Node* n : read_variable_ops) {
    Node* control_node = nullptr;
    Node* var_handle_op = nullptr;
    // Each ReadVariableOp has at most one control input, since we only chose
    // ReadVariableOps with at most two input edges.
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) {
        control_node = e->src();
      } else {
        var_handle_op = e->src();
      }
    }
    TF_RET_CHECK(var_handle_op != nullptr);
    read_var_ops[std::pair<Node*, Node*>(var_handle_op, control_node)]
        .push_back(n);
  }

  for (auto& r : read_var_ops) {
    // Sort the reads by name for determinism.
    std::sort(r.second.begin(), r.second.end(), node_name_comparator);
    TF_RETURN_IF_ERROR(
        MergeReadVariableOps(r.first.first, r.first.second, r.second, graph));
  }

  VLOG(1) << DumpGraphToFile("variable_merger_pass_after", *graph);
  return OkStatus();
}

}  // namespace tensorflow