xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/const_analysis.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/const_analysis.h"
17 
18 #include <unordered_map>
19 #include <unordered_set>
20 
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/core/common_runtime/function.h"
26 #include "tensorflow/core/framework/attr_value.pb.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/graph/algorithm.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 
32 namespace tensorflow {
33 
34 namespace {
35 
GetFunctionBody(FunctionLibraryRuntime * flib_runtime,const NodeDef & node,StringPiece func_attr_name,const FunctionBody ** fbody)36 Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime,
37                        const NodeDef& node, StringPiece func_attr_name,
38                        const FunctionBody** fbody) {
39   NameAttrList name_attr_list;
40   TF_RETURN_IF_ERROR(GetNodeAttr(node, func_attr_name, &name_attr_list));
41   FunctionLibraryRuntime::Handle func_handle;
42   TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
43       name_attr_list.name(), AttrSlice(&name_attr_list.attr()), &func_handle));
44   *fbody = flib_runtime->GetFunctionBody(func_handle);
45   return OkStatus();
46 }
47 
GetFunctionBodies(FunctionLibraryRuntime * flib_runtime,const NodeDef & node,StringPiece func_list_attr_name,std::vector<const FunctionBody * > * fbodies)48 Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime,
49                          const NodeDef& node, StringPiece func_list_attr_name,
50                          std::vector<const FunctionBody*>* fbodies) {
51   std::vector<NameAttrList> name_attr_lists;
52   TF_RETURN_IF_ERROR(GetNodeAttr(node, func_list_attr_name, &name_attr_lists));
53   for (const NameAttrList& name_attr_list : name_attr_lists) {
54     FunctionLibraryRuntime::Handle func_handle;
55     TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
56         name_attr_list.name(), AttrSlice(&name_attr_list.attr()),
57         &func_handle));
58     fbodies->push_back(flib_runtime->GetFunctionBody(func_handle));
59   }
60   return OkStatus();
61 }
62 
CondConstInputIndices(absl::Span<const FunctionBody * const> branch_bodies,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)63 Status CondConstInputIndices(
64     absl::Span<const FunctionBody* const> branch_bodies,
65     std::vector<int>* const_input_idxs, FunctionLibraryRuntime* flib_runtime) {
66   TF_RET_CHECK(!branch_bodies.empty());
67   TF_RET_CHECK(branch_bodies[0] != nullptr);
68   int num_inputs = branch_bodies[0]->fdef.signature().input_arg_size();
69   // Stores indices of the "branch function" inputs that are expected to be
70   // compile time constants.
71   std::vector<bool> compile_time_const_arg_indices(num_inputs);
72   for (auto fbody : branch_bodies) {
73     TF_RET_CHECK(fbody != nullptr);
74     TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
75         *(fbody->graph), &compile_time_const_arg_indices,
76         /*compile_time_const_nodes=*/nullptr, flib_runtime));
77   }
78   for (int i = 0, end = compile_time_const_arg_indices.size(); i < end; i++) {
79     if (compile_time_const_arg_indices[i]) {
80       // The 0th input is the pred or branch index, which is not passed to the
81       // branches. So the i'th input of a branch function corresponds to the
82       // i + 1'th input of the If/Case op.
83       const_input_idxs->push_back(i + 1);
84     }
85   }
86   return OkStatus();
87 }
88 
GetCompileTimeConstInputs(const NodeDef & node,const OpKernel * op_kernel,const OpDef * op_def,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)89 Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel,
90                                  const OpDef* op_def,
91                                  std::vector<int>* const_input_idxs,
92                                  FunctionLibraryRuntime* flib_runtime) {
93   DCHECK(op_def != nullptr || op_kernel != nullptr);
94   if (node.op() == "While" || node.op() == "StatelessWhile") {
95     // For While nodes, recurse into the body and cond graphs.
96     const FunctionBody* fcond = nullptr;
97     const FunctionBody* fbody = nullptr;
98     TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "cond", &fcond));
99     TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "body", &fbody));
100     TF_RET_CHECK(fcond);
101     TF_RET_CHECK(fbody);
102     int num_inputs = fbody->fdef.signature().input_arg_size();
103 
104     // Stores which of the loop inputs are expected to be compile time
105     // constants.
106     std::vector<bool> compile_time_const_arg_indices(num_inputs);
107     TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
108         *(fcond->graph), &compile_time_const_arg_indices,
109         /*compile_time_const_nodes=*/nullptr, flib_runtime));
110     TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
111         *(fbody->graph), &compile_time_const_arg_indices,
112         /*compile_time_const_nodes=*/nullptr, flib_runtime));
113     for (int i = 0; i < num_inputs; i++) {
114       if (compile_time_const_arg_indices[i]) {
115         // Check that this input is actually a loop invariant.
116         TF_ASSIGN_OR_RETURN(
117             bool is_loop_invariant,
118             IsLoopInvariant(fbody, i,
119                             flib_runtime->GetFunctionLibraryDefinition()));
120         if (is_loop_invariant) {
121           const_input_idxs->push_back(i);
122         } else {
123           // TODO(b/178546817): Verify that it's OK and raise an error if we are
124           // using this branch from jit_compile=True.
125           Node* arg_i = fbody->arg_nodes[i];
126           Node* ret_i = fbody->ret_nodes[i];
127           VLOG(1) << "Argument " << i << " to while-loop " << node.name()
128                   << " has to be constant, but it's not a loop invariant, "
129                      "cluster compilation likely to fail at compile time: "
130                   << arg_i->DebugString() << " vs. " << ret_i->DebugString();
131           VLOG(1) << node.ShortDebugString();
132         }
133       }
134     }
135     return OkStatus();
136   } else if (node.op() == "If" || node.op() == "StatelessIf") {
137     const FunctionBody* fthen = nullptr;
138     const FunctionBody* felse = nullptr;
139     TF_RETURN_IF_ERROR(
140         GetFunctionBody(flib_runtime, node, "then_branch", &fthen));
141     TF_RETURN_IF_ERROR(
142         GetFunctionBody(flib_runtime, node, "else_branch", &felse));
143     return CondConstInputIndices({fthen, felse}, const_input_idxs,
144                                  flib_runtime);
145   } else if (node.op() == "Case" || node.op() == "StatelessCase") {
146     std::vector<const FunctionBody*> branch_bodies;
147     TF_RETURN_IF_ERROR(
148         GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies));
149     return CondConstInputIndices(branch_bodies, const_input_idxs, flib_runtime);
150   } else if (node.op() == "PartitionedCall" ||
151              node.op() == "StatefulPartitionedCall") {
152     const FunctionBody* fbody;
153     TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "f", &fbody));
154     int num_inputs = fbody->fdef.signature().input_arg_size();
155     std::vector<bool> compile_time_const_arg_indices(num_inputs);
156     TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
157         *(fbody->graph), &compile_time_const_arg_indices,
158         /*compile_time_const_nodes=*/nullptr, flib_runtime));
159     for (int i = 0; i < num_inputs; i++) {
160       if (compile_time_const_arg_indices[i]) {
161         const_input_idxs->push_back(i);
162       }
163     }
164     return OkStatus();
165   } else if (op_def != nullptr) {
166     return XlaOpRegistry::CompileTimeConstantInputs(node, *op_def,
167                                                     const_input_idxs);
168   } else {
169     return XlaOpRegistry::CompileTimeConstantInputs(*op_kernel,
170                                                     const_input_idxs);
171   }
172 }
173 
GetCompileTimeConstInputs(const Node * node,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)174 Status GetCompileTimeConstInputs(const Node* node,
175                                  std::vector<int>* const_input_idxs,
176                                  FunctionLibraryRuntime* flib_runtime) {
177   return GetCompileTimeConstInputs(node->def(), /*op_kernel=*/nullptr,
178                                    &node->op_def(), const_input_idxs,
179                                    flib_runtime);
180 }
181 
182 }  // namespace
183 
184 // Backwards dataflow analysis that finds arguments to a graph that must be
185 // compile-time constants.
BackwardsConstAnalysis(const Graph & g,std::vector<bool> * compile_time_const_arg_indices,std::vector<bool> * compile_time_const_nodes,FunctionLibraryRuntime * flib_runtime,std::function<bool (const Edge &)> edge_filter_input)186 Status BackwardsConstAnalysis(
187     const Graph& g, std::vector<bool>* compile_time_const_arg_indices,
188     std::vector<bool>* compile_time_const_nodes,
189     FunctionLibraryRuntime* flib_runtime,
190     std::function<bool(const Edge&)> edge_filter_input) {
191   if (!compile_time_const_nodes && g.GetConstArgIndicesCache().has_value() &&
192       !edge_filter_input) {
193     VLOG(5) << "Using cached argument indices on graph " << &g;
194     *compile_time_const_arg_indices = g.GetConstArgIndicesCache().value();
195     return OkStatus();
196   }
197   auto edge_filter = [&](const Edge& e) {
198     return edge_filter_input ? edge_filter_input(e) : true;
199   };
200 
201   std::vector<bool> compile_time_const_nodes_impl;
202   if (compile_time_const_nodes) {
203     CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
204   } else {
205     compile_time_const_nodes_impl.resize(g.num_node_ids());
206     compile_time_const_nodes = &compile_time_const_nodes_impl;
207   }
208 
209   Status status;
210   auto visit = [&](Node* node) {
211     if (!status.ok()) return;
212 
213     // If this is a metadata-only op, don't propagate the const requirement.
214     if (XlaOpRegistry::IsMetadataOp(node->type_string())) {
215       VLOG(3) << "must-be-const node is metadata op: " << node->name();
216       return;
217     }
218 
219     // If this node must be const, and it isn't a metadata op, then all of its
220     // parents must be const.
221     if ((*compile_time_const_nodes)[node->id()]) {
222       VLOG(3) << "marking consts for must-be-const node " << node->name();
223       if (node->type_string() == "_Arg") {
224         int index;
225         status = GetNodeAttr(node->attrs(), "index", &index);
226         if (!status.ok()) return;
227         if (compile_time_const_arg_indices) {
228           (*compile_time_const_arg_indices)[index] = true;
229         }
230         VLOG(3) << "  const _Arg " << index << ": " << node->name();
231         return;
232       }
233       for (const Edge* pred : node->in_edges()) {
234         if (!pred->IsControlEdge() && edge_filter(*pred)) {
235           // If the src node of the `pred` is an IdentityN/While do not mark it
236           // as a compile-time const. Only mark the corresponding input to the
237           // IdentityN/While node as a const. XLA IdentityN op simply forwards
238           // its inputs so this is safe; loop-invariance is checked elsewhere.
239           while (edge_filter(*pred) && IsConstTraversableOpType(pred->src())) {
240             status = pred->src()->input_edge(pred->src_output(), &pred);
241             if (!status.ok()) return;
242           }
243           if (edge_filter(*pred)) {
244             VLOG(4) << "  " << pred->src()->name() << " must be const (is "
245                     << pred->src()->type_string() << ")";
246             (*compile_time_const_nodes)[pred->src()->id()] = true;
247           }
248         }
249       }
250       return;
251     }
252 
253     // Mark any compile-time constant operator arguments as const.
254     std::vector<int> const_input_idxs;
255     status = GetCompileTimeConstInputs(node, &const_input_idxs, flib_runtime);
256 
257     if (!status.ok() || const_input_idxs.empty()) {
258       return;
259     }
260 
261     VLOG(3) << "marking consts for must-be-const inputs of " << node->name();
262     for (Edge const* edge : node->in_edges()) {
263       if (!edge->IsControlEdge() &&
264           absl::c_binary_search(const_input_idxs, edge->dst_input()) &&
265           edge_filter(*edge)) {
266         // Do not mark IdentityN / While nodes as compile-time const.
267         // If the src node of the `pred` is an IdentityN do not mark it as a
268         // compile-time const. Only mark the corresponding input to the
269         // IdentityN/While node as a const. XLA IdentityN op simply forwards its
270         // inputs so this is safe; loop invariance is checked elsewhere.
271         while (edge_filter(*edge) && IsConstTraversableOpType(edge->src())) {
272           status = edge->src()->input_edge(edge->src_output(), &edge);
273           if (!status.ok()) return;
274         }
275         if (edge_filter(*edge)) {
276           VLOG(4) << "  input " << edge->dst_input() << ": "
277                   << edge->src()->name() << " must be const (is "
278                   << edge->src()->type_string() << ")";
279           (*compile_time_const_nodes)[edge->src()->id()] = true;
280         }
281       }
282     }
283   };
284 
285   // Post-order traversal visits nodes in reverse topological order for an
286   // acyclic graph.
287   DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
288       [](const Edge& edge) { return !edge.src()->IsNextIteration(); });
289   if (compile_time_const_arg_indices && !edge_filter_input) {
290     VLOG(5) << "Setting the cache on the graph: " << &g;
291     g.GetConstArgIndicesCache() = *compile_time_const_arg_indices;
292   }
293   return status;
294 }
295 
GetCompileTimeConstInputs(const OpKernel * op_kernel,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)296 Status GetCompileTimeConstInputs(const OpKernel* op_kernel,
297                                  std::vector<int>* const_input_idxs,
298                                  FunctionLibraryRuntime* flib_runtime) {
299   return GetCompileTimeConstInputs(op_kernel->def(), op_kernel,
300                                    /*op_def=*/nullptr, const_input_idxs,
301                                    flib_runtime);
302 }
303 
304 }  // namespace tensorflow
305