/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/graph_rewrite/variable_merger_pass.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

// The name of a stateful op is semantically meaningful because ops with the
// same name will share the same kernel. We therefore form new op names using a
// deterministic function (a fingerprint) of the old names.
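// For example, merging handles named "v0" and "v1" yields the name
// "VarHandles_" + Fingerprint64("v0,v1"), further uniquified by
// Graph::NewName below; the same ordered names always map to the same merged
// name (callers sort the nodes by name first).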
uint64 MergedOpFingerprint(absl::Span<Node* const> ops) {
  std::vector<string> op_names;
  op_names.reserve(ops.size());
  for (const Node* node : ops) {
    op_names.push_back(node->name());
  }
  return Fingerprint64(absl::StrJoin(op_names, ","));
}

Status MergeVarHandleOps(const string& device, absl::Span<Node* const> nodes,
                         Graph* graph) {
  int num_var_handles(nodes.size());
  if (num_var_handles <= 1) return OkStatus();

  std::vector<string> containers(num_var_handles);
  std::vector<string> names(num_var_handles);
  DataTypeVector dtypes(num_var_handles);
  std::vector<PartialTensorShape> shapes(num_var_handles);
  for (int i = 0; i < num_var_handles; ++i) {
    TF_RETURN_IF_ERROR(
        GetNodeAttr(nodes[i]->attrs(), "container", &containers[i]));
    TF_RETURN_IF_ERROR(
        GetNodeAttr(nodes[i]->attrs(), "shared_name", &names[i]));
    TF_RETURN_IF_ERROR(GetNodeAttr(nodes[i]->attrs(), "dtype", &dtypes[i]));
    TF_RETURN_IF_ERROR(GetNodeAttr(nodes[i]->attrs(), "shape", &shapes[i]));
  }
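  // The merged _VarHandlesOp packs the attributes of the N original ops into
  // parallel list attrs (containers, shared_names, dtypes, shapes).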
  NodeDefBuilder builder(graph->NewName(strings::StrCat(
                             "VarHandles_", MergedOpFingerprint(nodes))),
                         "_VarHandlesOp");
  builder.Attr("N", num_var_handles);
  builder.Attr("containers", containers);
  builder.Attr("shared_names", names);
  builder.Attr("dtypes", dtypes);
  builder.Attr("shapes", shapes);
  builder.Device(device);
  NodeDef node_def;
  TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
  TF_ASSIGN_OR_RETURN(Node * node, graph->AddNode(node_def));
  node->set_assigned_device_name(device);

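  // Anchor the merged op to the source node so it remains a graph root, then
  // forward each original handle's consumers to output i of the merged op. A
  // dst_input of -1 marks a control edge, whose source slot must also be -1.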
  graph->AddControlEdge(graph->source_node(), node);
  for (int i = 0; i < num_var_handles; ++i) {
    std::vector<std::pair<Node*, int>> consumers;
    for (const Edge* e : nodes[i]->out_edges()) {
      consumers.emplace_back(e->dst(), e->dst_input());
    }
    graph->RemoveNode(nodes[i]);
    for (const auto& t : consumers) {
      graph->AddEdge(node, t.second < 0 ? -1 : i, t.first, t.second);
    }
  }
  return OkStatus();
}

Status MergeReadVariableOps(Node* handle_op, Node* control_node,
                            absl::Span<Node* const> nodes, Graph* graph) {
  int num_reads(nodes.size());
  if (num_reads <= 1) return OkStatus();

  DataTypeVector dtypes(num_reads);
  for (int i = 0; i < num_reads; ++i) {
    TF_RETURN_IF_ERROR(GetNodeAttr(nodes[i]->attrs(), "dtype", &dtypes[i]));
  }
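  // The merged _ReadVariablesOp carries the dtypes of the N original reads as
  // a list attr.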
  NodeDef node_def;
  node_def.set_name(graph->NewName(
      strings::StrCat("ReadVariables_", MergedOpFingerprint(nodes))));
  node_def.set_op("_ReadVariablesOp");
  AddNodeAttr("N", num_reads, &node_def);
  AddNodeAttr("dtypes", dtypes, &node_def);
  node_def.set_device(handle_op->requested_device());
  TF_ASSIGN_OR_RETURN(Node * node, graph->AddNode(node_def));
  node->set_assigned_device_name(handle_op->assigned_device_name());
  if (control_node) graph->AddControlEdge(control_node, node);
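  // For each original read: feed the variable handle that fed nodes[i] into
  // input i of the merged op, then forward nodes[i]'s consumers to output i
  // (control edges keep source slot -1).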
  for (int i = 0; i < num_reads; ++i) {
    const Edge* handle_edge;
    TF_RETURN_IF_ERROR(nodes[i]->input_edge(0, &handle_edge));
    graph->AddEdge(handle_edge->src(), handle_edge->src_output(), node, i);

    std::vector<std::pair<Node*, int>> consumers;
    for (const Edge* e : nodes[i]->out_edges()) {
      consumers.emplace_back(e->dst(), e->dst_input());
    }
    graph->RemoveNode(nodes[i]);
    for (const auto& t : consumers) {
      graph->AddEdge(node, t.second < 0 ? -1 : i, t.first, t.second);
    }
  }
  return OkStatus();
}

}  // namespace

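// Merges VarHandleOps that are graph roots into one _VarHandlesOp per
// assigned device, then merges ReadVariableOps that share a (merged) handle
// op and an optional control dependency into a single _ReadVariablesOp. A
// sketch of the rewrite, assuming two mergeable variables v0 and v1:
//
//   VarHandleOp(v0)  VarHandleOp(v1)          _VarHandlesOp(N=2)
//         |                |           =>        |          |
//   ReadVariableOp   ReadVariableOp          _ReadVariablesOp(N=2)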
Status VariableMergerPass::Run(const GraphOptimizationPassOptions& options) {
  Graph* graph = options.graph->get();

  VLOG(1) << DumpGraphToFile("variable_merger_pass_before", *graph);

  // Find VarHandleOps that are graph roots and group them by assigned device.
  // Also find any ReadVariableOps that are consumers of those handles.
  absl::flat_hash_map<string, std::vector<Node*>> var_handle_ops_by_device;
  absl::flat_hash_set<Node*> read_variable_ops;

  for (Node* m : graph->source_node()->out_nodes()) {
    // Only consider a VarHandleOp whose sole input edge is the control edge
    // we just followed from the source node.
    if (m->type_string() == "VarHandleOp" && m->in_edges().size() == 1) {
      var_handle_ops_by_device[m->assigned_device_name()].push_back(m);
      for (Node* n : m->out_nodes()) {
        // A ReadVariableOp may have control edges; below, reads are grouped
        // by (merged VarHandleOp, control dependency).
        if (n->type_string() == "ReadVariableOp" && n->in_edges().size() <= 2) {
          read_variable_ops.insert(n);
        }
      }
    }
  }

  auto node_name_comparator = [](Node* a, Node* b) {
    return a->name() < b->name();
  };

  // First merge the var handle ops.
  for (auto& vh : var_handle_ops_by_device) {
    // Sort the handles by name for determinism.
    std::sort(vh.second.begin(), vh.second.end(), node_name_comparator);
    TF_RETURN_IF_ERROR(MergeVarHandleOps(vh.first, vh.second, graph));
  }

  // Group ReadVariableOps by the pair <VarHandleOp, ControlDependencyNode>;
  // the ControlDependencyNode may be nullptr.
  absl::flat_hash_map<std::pair<Node*, Node*>, std::vector<Node*>> read_var_ops;
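  // Reads of the same merged handle guarded by the same (possibly null)
  // control input share a key and will be merged into one _ReadVariablesOp.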

  for (Node* n : read_variable_ops) {
    Node* control_node = nullptr;
    Node* var_handle_op = nullptr;
    // Each ReadVariableOp has at most one control input, since we only kept
    // ReadVariableOps with at most two input edges.
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) {
        control_node = e->src();
      } else {
        var_handle_op = e->src();
      }
    }
    TF_RET_CHECK(var_handle_op != nullptr);
    read_var_ops[std::pair<Node*, Node*>(var_handle_op, control_node)]
        .push_back(n);
  }

  for (auto& r : read_var_ops) {
    // Sort the reads by name for determinism.
    std::sort(r.second.begin(), r.second.end(), node_name_comparator);
    TF_RETURN_IF_ERROR(
        MergeReadVariableOps(r.first.first, r.first.second, r.second, graph));
  }

  VLOG(1) << DumpGraphToFile("variable_merger_pass_after", *graph);
  return OkStatus();
}

}  // namespace tensorflow