xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/control_flow_deps_to_chains.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/control_flow_deps_to_chains.h"
17 
#include <algorithm>
#include <cstdint>
#include <map>
#include <string>
#include <vector>
21 
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/framework/node_def_util.h"
25 #include "tensorflow/core/framework/op_def_builder.h"
26 #include "tensorflow/core/framework/tensor.pb.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/platform/strcat.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/core/util/dump_graph.h"
31 
32 namespace tensorflow {
33 
34 // TODO(mdan): Move this into Grappler - cleaner interface.
Run(const GraphOptimizationPassOptions & options)35 Status ControlFlowDepsToChainsPass::Run(
36     const GraphOptimizationPassOptions& options) {
37   VLOG(1) << "ControlFlowDepsToChainsPass::Run";
38 
39   if (options.graph == nullptr) {
40     VLOG(1) << "ControlFlowDepsToChainsPass::Run Aborted";
41     return OkStatus();
42   }
43 
44   Graph* g = options.graph->get();
45   DCHECK(g != nullptr);
46   FunctionLibraryDefinition* flib_def = options.flib_def;
47   DCHECK(flib_def != nullptr);
48 
49   if (VLOG_IS_ON(1)) {
50     DumpGraphToFile("control_flow_deps_to_chains_before", *g, flib_def);
51   }
52 
53   for (Node* n : g->nodes()) {
54     if (n == nullptr) {
55       continue;
56     }
57     if (!n->IsWhileNode()) {
58       continue;
59     }
60 
61     // TODO(mdan): This breaks encapsulation of Node/Graph. Is there any needed?
62     // TODO(mdan): Consolidate this with AddWhileInputHack.
63     NodeDef* while_node = n->mutable_def();
64     const auto& attrs = while_node->attr();
65     auto* mattrs = while_node->mutable_attr();
66 
67     string body_name = attrs.at("body").func().name();
68     auto* body_graph = flib_def->Find(body_name);
69     DCHECK(body_graph != nullptr);
70 
71     // Look for required annotations.
72 
73     if (attrs.find("_stateful_parallelism") == attrs.end()) {
74       continue;
75     }
76     if (!attrs.at("_stateful_parallelism").b()) {
77       continue;
78     }
79     if (attrs.find("parallel_iterations") != attrs.end()) {
80       if (attrs.at("parallel_iterations").i() < 2) {
81         continue;  // Loops which are already sequential are more efficient
82                    // without chains.
83       }
84     }
85     // TODO(mdan): We don't really need this attribute.
86     if (attrs.find("_num_original_outputs") == attrs.end()) {
87       continue;
88     }
89     int body_barrier_loc = -1;
90     std::map<string, int> node_index;
91     for (int i = 0, s = body_graph->node_def_size(); i < s; i++) {
92       node_index.emplace(body_graph->node_def(i).name(), i);
93       if (body_barrier_loc < 0) {
94         const auto& node_attr = body_graph->node_def(i).attr();
95         if (node_attr.find("_acd_function_control_output") != node_attr.end()) {
96           body_barrier_loc = i;
97         }
98       }
99     }
100     if (body_barrier_loc < 0) {
101       continue;
102     }
103     bool ok_for_lowering = true;
104     for (int i = 0; i < body_graph->control_ret_size(); i++) {
105       const auto& control_node = body_graph->node_def(
106           node_index[body_graph->signature().control_output(i)]);
107       const auto& control_attr = control_node.attr();
108       if (control_attr.find("_res_first_used_by") == control_attr.end()) {
109         ok_for_lowering = false;
110         break;
111       }
112     }
113     if (!ok_for_lowering) {
114       continue;
115     }
116 
117     int num_loop_vars = body_graph->signature().input_arg_size();
118     int num_new_chains = body_graph->control_ret_size();
119     int num_node_inputs = while_node->input_size();
120 
121     if (!num_new_chains) {
122       continue;  // Nothing to do for stateless loops.
123     }
124 
125     // Add extra loop vars to the while node.
126 
127     // TODO(mdan): If the loop vars contains the resource, we should reuse it.
128     // Note that stateful ops of resource inputs cause their resources to be
129     // captured into the loop vars (through the body/cond captures). We could
130     // effectively use those as chains.
131 
132     // TODO(mdan): Is there a more efficient way to do this?
133     // Insert the new While node inputs: at the end of the loop vars, but before
134     // any non-loop var inputs (like control dependencies). Once the initial
135     // chain values are created below, they will be added to these inputs.
136     for (int i = 0; i < num_new_chains; i++) {
137       while_node->add_input();
138     }
139     for (int i = num_node_inputs - 1; i >= num_loop_vars; i--) {
140       while_node->set_input(i + num_new_chains, while_node->input(i));
141     }
142 
143     std::vector<Node*> new_inputs;
144     std::vector<int> new_input_locations;
145     // Set their name to a gensym, type to float and shape to scalar.
146     for (int i = 0; i < num_new_chains; i++) {
147       string c_name = g->NewName("acd__chain");
148 
149       // The initial value for the i'th chain loop var.
150       NodeDef new_in;
151       new_in.set_name(c_name);
152       new_in.set_op("Const");
153       AttrValue att_dtype;
154       att_dtype.set_type(DT_FLOAT);
155       new_in.mutable_attr()->insert({"dtype", att_dtype});
156       AttrValue att_value;
157       att_value.mutable_tensor()->set_dtype(DT_FLOAT);
158       att_value.mutable_tensor()->mutable_tensor_shape();
159       att_value.mutable_tensor()->add_int_val(0);
160       new_in.mutable_attr()->insert({"value", att_value});
161       Status status;
162       new_inputs.push_back(g->AddNode(new_in, &status));
163       TF_RETURN_WITH_CONTEXT_IF_ERROR(status, "while creating chain", c_name);
164 
165       int loc = num_loop_vars + i;
166       new_input_locations.push_back(loc);
167       while_node->set_input(loc, c_name);
168       mattrs->at("T").mutable_list()->add_type(DT_FLOAT);
169       mattrs->at("output_shapes").mutable_list()->add_shape();
170     }
171 
172     // TODO(mdan): This should not be necessary to update. Delete?
173     mattrs->at("_num_original_outputs").set_i(num_loop_vars + num_new_chains);
174     n->UpdateProperties();
175     for (int i = 0; i < num_new_chains; i++) {
176       g->AddEdge(new_inputs[i], 0, n, new_input_locations[i]);
177     }
178 
179     // TODO(mdan): This is wasteful. Can we just mutate the original proto?
180     FunctionDef modified_body = *body_graph;
181 
182     // Disable the global end-of-body barrier from the body function.
183     // Because removing a node is too inefficient (would have to walk all the
184     // inputs of all graph nodes), we instead clear its control dependencies.
185     modified_body.mutable_node_def(body_barrier_loc)->clear_input();
186 
187     // Add extra loop vars to the body function.
188 
189     for (int i = 0; i < num_new_chains; i++) {
190       // Input loop vars.
191       // TODO(mdan): Double check that this doesn't clash with names in body.
192       string c_name = g->NewName("acd__chainv");
193       std::replace(c_name.begin(), c_name.end(), '/', '_');
194       auto* new_arg = modified_body.mutable_signature()->add_input_arg();
195       new_arg->set_name(c_name);
196       new_arg->set_type(DT_FLOAT);
197 
198       // Output ops. These are copies of the inputs conditioned on the actual
199       // control outputs.
200       string c_out_name = g->NewName("acd__outchain");
201       auto* new_out = modified_body.add_node_def();
202       new_out->set_name(c_out_name);
203       new_out->set_op("Identity");
204       new_out->add_input(c_name);
205       new_out->add_input(
206           strings::StrCat("^", body_graph->signature().control_output(i)));
207       AttrValue attr;
208       attr.set_type(DT_FLOAT);
209       new_out->mutable_attr()->insert({"T", attr});
210 
211       // Output loop var declarations.
212       string c_ret_name = c_out_name;
213       std::replace(c_ret_name.begin(), c_ret_name.end(), '/', '_');
214       auto* new_out_arg = modified_body.mutable_signature()->add_output_arg();
215       new_out_arg->set_name(c_ret_name);
216       new_out_arg->set_type(DT_FLOAT);
217 
218       // Actual output loop vars.
219       modified_body.mutable_ret()->insert(
220           {c_ret_name, strings::StrCat(c_out_name, ":output:0")});
221       AttrValue attr_val;
222       attr_val.mutable_list()->add_shape();
223       FunctionDef_ArgAttrs arg_attrs;
224       arg_attrs.mutable_attr()->insert({"_output_shapes", attr_val});
225       modified_body.mutable_arg_attr()->insert(
226           {static_cast<uint32_t>(i + num_loop_vars), arg_attrs});
227     }
228 
229     // Wire chain loop vars to the ops they need to condition.
230 
231     node_index.clear();
232     for (int i = 0; i < modified_body.node_def_size(); i++) {
233       node_index.emplace(modified_body.node_def(i).name(), i);
234     }
235     auto& modified_sig = modified_body.signature();
236     for (int i = 0; i < num_new_chains; i++) {
237       const auto& control_node =
238           modified_body.node_def(node_index[modified_sig.control_output(i)]);
239       for (const auto& r :
240            control_node.attr().at("_res_first_used_by").list().s()) {
241         NodeDef* first_node = modified_body.mutable_node_def(node_index[r]);
242         // This control dependency ensures proper sequencing of stateful ops
243         // upon entry into the loop body, so that they run after the ops
244         // which affected the same resource in the previous iteration.
245         first_node->add_input(strings::StrCat(
246             "^", modified_sig.input_arg(i + num_loop_vars).name()));
247       }
248     }
249 
250     // Clear body function's control returns.
251     modified_body.mutable_control_ret()->clear();
252 
253     // Add extra loop vars to the cond function.
254 
255     // TODO(mdan): This is wasteful. Can't we just mutate the original proto?
256     string cond_name = attrs.at("cond").func().name();
257     auto* cond_graph = flib_def->Find(cond_name);
258     DCHECK(cond_graph != nullptr);
259     FunctionDef modified_cond = *cond_graph;
260 
261     int cond_barrier_loc = -1;
262     for (int i = 0, s = cond_graph->node_def_size(); i < s; i++) {
263       if (cond_barrier_loc < 0) {
264         const auto& node_attr = cond_graph->node_def(i).attr();
265         if (node_attr.find("_acd_function_control_output") != node_attr.end()) {
266           cond_barrier_loc = i;
267         }
268       }
269     }
270     if (cond_barrier_loc > 0) {
271       // Disable the global end-of-body barrier from the cond function.
272       // Because removing a node is too inefficient (would have to walk all the
273       // inputs of all graph nodes), we instead clear its control dependencies.
274       modified_cond.mutable_node_def(cond_barrier_loc)->clear_input();
275     }
276 
277     for (int i = 0; i < num_new_chains; i++) {
278       // Input loop vars.
279       // TODO(mdan): These should gate the stateful ops in the cond.
280       // Until ACD supplies the necessary information, these are dummies in this
281       // function.
282       string c_name = g->NewName("acd__chain");
283       auto* new_arg = modified_cond.mutable_signature()->add_input_arg();
284       new_arg->set_name(c_name);
285       new_arg->set_type(DT_FLOAT);
286 
287       // TODO(mdan): Return values on the cond function? Most likely a bug.
288       AttrValue attr_val;
289       attr_val.mutable_list()->add_shape();
290       FunctionDef_ArgAttrs arg_attrs;
291       arg_attrs.mutable_attr()->insert({"_output_shapes", attr_val});
292       modified_cond.mutable_arg_attr()->insert(
293           {static_cast<uint32_t>(i + num_loop_vars), arg_attrs});
294     }
295 
296     // Wire the new cond/body functions to the While node.
297 
298     string new_cond_name = g->NewName("acd__while_cond");
299     modified_cond.mutable_signature()->set_name(new_cond_name);
300     mattrs->at("cond").mutable_func()->set_name(new_cond_name);
301 
302     string new_body_name = g->NewName("acd__while_body");
303     modified_body.mutable_signature()->set_name(new_body_name);
304     mattrs->at("body").mutable_func()->set_name(new_body_name);
305 
306     // Commit the new functions.
307 
308     TF_RETURN_WITH_CONTEXT_IF_ERROR(
309         flib_def->AddFunctionDef(modified_body,
310                                  flib_def->GetStackTraces(body_name)),
311         "while attaching ", new_body_name, " to flib_def");
312     TF_RETURN_WITH_CONTEXT_IF_ERROR(
313         flib_def->AddFunctionDef(modified_cond,
314                                  flib_def->GetStackTraces(cond_name)),
315         "while attaching ", new_cond_name, " to flib_def");
316 
317     // TODO(b/183666205): This should not be necessary.
318     // It's unclear why adding the functions here is also required.
319     // Moreover, it's unclear when graph_lib's parent is flib_def itself.
320     auto* graph_lib = g->mutable_flib_def();
321     if (graph_lib->default_registry() != flib_def) {
322       TF_RETURN_WITH_CONTEXT_IF_ERROR(
323           graph_lib->AddFunctionDef(modified_body,
324                                     graph_lib->GetStackTraces(body_name)),
325           "while attaching ", new_body_name, " to graph");
326       TF_RETURN_WITH_CONTEXT_IF_ERROR(
327           graph_lib->AddFunctionDef(modified_cond,
328                                     graph_lib->GetStackTraces(cond_name)),
329           "while attaching ", new_cond_name, " to graph");
330     }
331   }
332 
333   if (VLOG_IS_ON(1)) {
334     DumpGraphToFile("control_flow_deps_to_chains_after", *g, flib_def);
335   }
336 
337   return OkStatus();
338 }
339 
340 // Note: This needs to run before functional control flow lowering, which is 10.
341 REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 9,
342                       ControlFlowDepsToChainsPass);
343 
344 }  // namespace tensorflow
345