1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/const_analysis.h"
17
18 #include <unordered_map>
19 #include <unordered_set>
20
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/core/common_runtime/function.h"
26 #include "tensorflow/core/framework/attr_value.pb.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/graph/algorithm.h"
30 #include "tensorflow/core/lib/core/errors.h"
31
32 namespace tensorflow {
33
34 namespace {
35
GetFunctionBody(FunctionLibraryRuntime * flib_runtime,const NodeDef & node,StringPiece func_attr_name,const FunctionBody ** fbody)36 Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime,
37 const NodeDef& node, StringPiece func_attr_name,
38 const FunctionBody** fbody) {
39 NameAttrList name_attr_list;
40 TF_RETURN_IF_ERROR(GetNodeAttr(node, func_attr_name, &name_attr_list));
41 FunctionLibraryRuntime::Handle func_handle;
42 TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
43 name_attr_list.name(), AttrSlice(&name_attr_list.attr()), &func_handle));
44 *fbody = flib_runtime->GetFunctionBody(func_handle);
45 return OkStatus();
46 }
47
GetFunctionBodies(FunctionLibraryRuntime * flib_runtime,const NodeDef & node,StringPiece func_list_attr_name,std::vector<const FunctionBody * > * fbodies)48 Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime,
49 const NodeDef& node, StringPiece func_list_attr_name,
50 std::vector<const FunctionBody*>* fbodies) {
51 std::vector<NameAttrList> name_attr_lists;
52 TF_RETURN_IF_ERROR(GetNodeAttr(node, func_list_attr_name, &name_attr_lists));
53 for (const NameAttrList& name_attr_list : name_attr_lists) {
54 FunctionLibraryRuntime::Handle func_handle;
55 TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
56 name_attr_list.name(), AttrSlice(&name_attr_list.attr()),
57 &func_handle));
58 fbodies->push_back(flib_runtime->GetFunctionBody(func_handle));
59 }
60 return OkStatus();
61 }
62
CondConstInputIndices(absl::Span<const FunctionBody * const> branch_bodies,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)63 Status CondConstInputIndices(
64 absl::Span<const FunctionBody* const> branch_bodies,
65 std::vector<int>* const_input_idxs, FunctionLibraryRuntime* flib_runtime) {
66 TF_RET_CHECK(!branch_bodies.empty());
67 TF_RET_CHECK(branch_bodies[0] != nullptr);
68 int num_inputs = branch_bodies[0]->fdef.signature().input_arg_size();
69 // Stores indices of the "branch function" inputs that are expected to be
70 // compile time constants.
71 std::vector<bool> compile_time_const_arg_indices(num_inputs);
72 for (auto fbody : branch_bodies) {
73 TF_RET_CHECK(fbody != nullptr);
74 TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
75 *(fbody->graph), &compile_time_const_arg_indices,
76 /*compile_time_const_nodes=*/nullptr, flib_runtime));
77 }
78 for (int i = 0, end = compile_time_const_arg_indices.size(); i < end; i++) {
79 if (compile_time_const_arg_indices[i]) {
80 // The 0th input is the pred or branch index, which is not passed to the
81 // branches. So the i'th input of a branch function corresponds to the
82 // i + 1'th input of the If/Case op.
83 const_input_idxs->push_back(i + 1);
84 }
85 }
86 return OkStatus();
87 }
88
GetCompileTimeConstInputs(const NodeDef & node,const OpKernel * op_kernel,const OpDef * op_def,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)89 Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel,
90 const OpDef* op_def,
91 std::vector<int>* const_input_idxs,
92 FunctionLibraryRuntime* flib_runtime) {
93 DCHECK(op_def != nullptr || op_kernel != nullptr);
94 if (node.op() == "While" || node.op() == "StatelessWhile") {
95 // For While nodes, recurse into the body and cond graphs.
96 const FunctionBody* fcond = nullptr;
97 const FunctionBody* fbody = nullptr;
98 TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "cond", &fcond));
99 TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "body", &fbody));
100 TF_RET_CHECK(fcond);
101 TF_RET_CHECK(fbody);
102 int num_inputs = fbody->fdef.signature().input_arg_size();
103
104 // Stores which of the loop inputs are expected to be compile time
105 // constants.
106 std::vector<bool> compile_time_const_arg_indices(num_inputs);
107 TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
108 *(fcond->graph), &compile_time_const_arg_indices,
109 /*compile_time_const_nodes=*/nullptr, flib_runtime));
110 TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
111 *(fbody->graph), &compile_time_const_arg_indices,
112 /*compile_time_const_nodes=*/nullptr, flib_runtime));
113 for (int i = 0; i < num_inputs; i++) {
114 if (compile_time_const_arg_indices[i]) {
115 // Check that this input is actually a loop invariant.
116 TF_ASSIGN_OR_RETURN(
117 bool is_loop_invariant,
118 IsLoopInvariant(fbody, i,
119 flib_runtime->GetFunctionLibraryDefinition()));
120 if (is_loop_invariant) {
121 const_input_idxs->push_back(i);
122 } else {
123 // TODO(b/178546817): Verify that it's OK and raise an error if we are
124 // using this branch from jit_compile=True.
125 Node* arg_i = fbody->arg_nodes[i];
126 Node* ret_i = fbody->ret_nodes[i];
127 VLOG(1) << "Argument " << i << " to while-loop " << node.name()
128 << " has to be constant, but it's not a loop invariant, "
129 "cluster compilation likely to fail at compile time: "
130 << arg_i->DebugString() << " vs. " << ret_i->DebugString();
131 VLOG(1) << node.ShortDebugString();
132 }
133 }
134 }
135 return OkStatus();
136 } else if (node.op() == "If" || node.op() == "StatelessIf") {
137 const FunctionBody* fthen = nullptr;
138 const FunctionBody* felse = nullptr;
139 TF_RETURN_IF_ERROR(
140 GetFunctionBody(flib_runtime, node, "then_branch", &fthen));
141 TF_RETURN_IF_ERROR(
142 GetFunctionBody(flib_runtime, node, "else_branch", &felse));
143 return CondConstInputIndices({fthen, felse}, const_input_idxs,
144 flib_runtime);
145 } else if (node.op() == "Case" || node.op() == "StatelessCase") {
146 std::vector<const FunctionBody*> branch_bodies;
147 TF_RETURN_IF_ERROR(
148 GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies));
149 return CondConstInputIndices(branch_bodies, const_input_idxs, flib_runtime);
150 } else if (node.op() == "PartitionedCall" ||
151 node.op() == "StatefulPartitionedCall") {
152 const FunctionBody* fbody;
153 TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "f", &fbody));
154 int num_inputs = fbody->fdef.signature().input_arg_size();
155 std::vector<bool> compile_time_const_arg_indices(num_inputs);
156 TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
157 *(fbody->graph), &compile_time_const_arg_indices,
158 /*compile_time_const_nodes=*/nullptr, flib_runtime));
159 for (int i = 0; i < num_inputs; i++) {
160 if (compile_time_const_arg_indices[i]) {
161 const_input_idxs->push_back(i);
162 }
163 }
164 return OkStatus();
165 } else if (op_def != nullptr) {
166 return XlaOpRegistry::CompileTimeConstantInputs(node, *op_def,
167 const_input_idxs);
168 } else {
169 return XlaOpRegistry::CompileTimeConstantInputs(*op_kernel,
170 const_input_idxs);
171 }
172 }
173
GetCompileTimeConstInputs(const Node * node,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)174 Status GetCompileTimeConstInputs(const Node* node,
175 std::vector<int>* const_input_idxs,
176 FunctionLibraryRuntime* flib_runtime) {
177 return GetCompileTimeConstInputs(node->def(), /*op_kernel=*/nullptr,
178 &node->op_def(), const_input_idxs,
179 flib_runtime);
180 }
181
182 } // namespace
183
184 // Backwards dataflow analysis that finds arguments to a graph that must be
185 // compile-time constants.
BackwardsConstAnalysis(const Graph & g,std::vector<bool> * compile_time_const_arg_indices,std::vector<bool> * compile_time_const_nodes,FunctionLibraryRuntime * flib_runtime,std::function<bool (const Edge &)> edge_filter_input)186 Status BackwardsConstAnalysis(
187 const Graph& g, std::vector<bool>* compile_time_const_arg_indices,
188 std::vector<bool>* compile_time_const_nodes,
189 FunctionLibraryRuntime* flib_runtime,
190 std::function<bool(const Edge&)> edge_filter_input) {
191 if (!compile_time_const_nodes && g.GetConstArgIndicesCache().has_value() &&
192 !edge_filter_input) {
193 VLOG(5) << "Using cached argument indices on graph " << &g;
194 *compile_time_const_arg_indices = g.GetConstArgIndicesCache().value();
195 return OkStatus();
196 }
197 auto edge_filter = [&](const Edge& e) {
198 return edge_filter_input ? edge_filter_input(e) : true;
199 };
200
201 std::vector<bool> compile_time_const_nodes_impl;
202 if (compile_time_const_nodes) {
203 CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
204 } else {
205 compile_time_const_nodes_impl.resize(g.num_node_ids());
206 compile_time_const_nodes = &compile_time_const_nodes_impl;
207 }
208
209 Status status;
210 auto visit = [&](Node* node) {
211 if (!status.ok()) return;
212
213 // If this is a metadata-only op, don't propagate the const requirement.
214 if (XlaOpRegistry::IsMetadataOp(node->type_string())) {
215 VLOG(3) << "must-be-const node is metadata op: " << node->name();
216 return;
217 }
218
219 // If this node must be const, and it isn't a metadata op, then all of its
220 // parents must be const.
221 if ((*compile_time_const_nodes)[node->id()]) {
222 VLOG(3) << "marking consts for must-be-const node " << node->name();
223 if (node->type_string() == "_Arg") {
224 int index;
225 status = GetNodeAttr(node->attrs(), "index", &index);
226 if (!status.ok()) return;
227 if (compile_time_const_arg_indices) {
228 (*compile_time_const_arg_indices)[index] = true;
229 }
230 VLOG(3) << " const _Arg " << index << ": " << node->name();
231 return;
232 }
233 for (const Edge* pred : node->in_edges()) {
234 if (!pred->IsControlEdge() && edge_filter(*pred)) {
235 // If the src node of the `pred` is an IdentityN/While do not mark it
236 // as a compile-time const. Only mark the corresponding input to the
237 // IdentityN/While node as a const. XLA IdentityN op simply forwards
238 // its inputs so this is safe; loop-invariance is checked elsewhere.
239 while (edge_filter(*pred) && IsConstTraversableOpType(pred->src())) {
240 status = pred->src()->input_edge(pred->src_output(), &pred);
241 if (!status.ok()) return;
242 }
243 if (edge_filter(*pred)) {
244 VLOG(4) << " " << pred->src()->name() << " must be const (is "
245 << pred->src()->type_string() << ")";
246 (*compile_time_const_nodes)[pred->src()->id()] = true;
247 }
248 }
249 }
250 return;
251 }
252
253 // Mark any compile-time constant operator arguments as const.
254 std::vector<int> const_input_idxs;
255 status = GetCompileTimeConstInputs(node, &const_input_idxs, flib_runtime);
256
257 if (!status.ok() || const_input_idxs.empty()) {
258 return;
259 }
260
261 VLOG(3) << "marking consts for must-be-const inputs of " << node->name();
262 for (Edge const* edge : node->in_edges()) {
263 if (!edge->IsControlEdge() &&
264 absl::c_binary_search(const_input_idxs, edge->dst_input()) &&
265 edge_filter(*edge)) {
266 // Do not mark IdentityN / While nodes as compile-time const.
267 // If the src node of the `pred` is an IdentityN do not mark it as a
268 // compile-time const. Only mark the corresponding input to the
269 // IdentityN/While node as a const. XLA IdentityN op simply forwards its
270 // inputs so this is safe; loop invariance is checked elsewhere.
271 while (edge_filter(*edge) && IsConstTraversableOpType(edge->src())) {
272 status = edge->src()->input_edge(edge->src_output(), &edge);
273 if (!status.ok()) return;
274 }
275 if (edge_filter(*edge)) {
276 VLOG(4) << " input " << edge->dst_input() << ": "
277 << edge->src()->name() << " must be const (is "
278 << edge->src()->type_string() << ")";
279 (*compile_time_const_nodes)[edge->src()->id()] = true;
280 }
281 }
282 }
283 };
284
285 // Post-order traversal visits nodes in reverse topological order for an
286 // acyclic graph.
287 DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
288 [](const Edge& edge) { return !edge.src()->IsNextIteration(); });
289 if (compile_time_const_arg_indices && !edge_filter_input) {
290 VLOG(5) << "Setting the cache on the graph: " << &g;
291 g.GetConstArgIndicesCache() = *compile_time_const_arg_indices;
292 }
293 return status;
294 }
295
GetCompileTimeConstInputs(const OpKernel * op_kernel,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)296 Status GetCompileTimeConstInputs(const OpKernel* op_kernel,
297 std::vector<int>* const_input_idxs,
298 FunctionLibraryRuntime* flib_runtime) {
299 return GetCompileTimeConstInputs(op_kernel->def(), op_kernel,
300 /*op_def=*/nullptr, const_input_idxs,
301 flib_runtime);
302 }
303
304 } // namespace tensorflow
305