/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Tests for the backward const analysis.

#include "tensorflow/compiler/tf2xla/const_analysis.h"

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {
namespace {

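// BackwardsConstAnalysis walks the graph from outputs to inputs and marks
// the values that XLA requires to be compile-time constants (e.g. the shape
// input of Reshape and the reduction indices of Sum).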
TEST(ConstAnalysisTest, Basics) {
  Scope root = Scope::NewRootScope();

  auto arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0);
  auto arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1);
  auto arg2 = ops::_Arg(root.WithOpName("Arg2"), DT_INT32, 2);
  auto arg3 = ops::_Arg(root.WithOpName("Arg3"), DT_INT32, 3);
  auto a = ops::Shape(root, arg0);
  auto b = ops::Add(root, a, arg1);
  auto c = ops::Reshape(root, arg2, b);
  auto d = ops::Mul(root, c, ops::Sum(root, arg3, arg3));

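  // Hook nodes without incoming/outgoing edges up to the graph's special
  // source and sink nodes so the graph is well-formed for traversal.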
  FixupSourceAndSinkEdges(root.graph());

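  // One entry per _Arg (const_args) and one per node id (const_nodes); the
  // analysis flips an entry to true when that value must be a compile-time
  // constant.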
  std::vector<bool> const_args(4, false);
  std::vector<bool> const_nodes(root.graph()->num_node_ids(), false);
  TF_ASSERT_OK(BackwardsConstAnalysis(*root.graph(), &const_args, &const_nodes,
                                      /*flib_runtime=*/nullptr));

  // Arg 0 doesn't need to be constant since the graph only uses its shape.
  // Arg 1 must be constant because it flows to the shape argument of a
  // Reshape.
  // Arg 2 is used only as the value input to a Reshape and need not be const.
  // Arg 3 is used as the reduction-indices argument to Sum and must be const.
  EXPECT_EQ(const_args, std::vector<bool>({false, true, false, true}));

  EXPECT_FALSE(const_nodes[arg0.node()->id()]);
  EXPECT_TRUE(const_nodes[arg1.node()->id()]);
  EXPECT_FALSE(const_nodes[arg2.node()->id()]);
  EXPECT_TRUE(const_nodes[arg3.node()->id()]);
}

// Regression test for a case where the backward const analysis did not visit
// nodes in topological order.
TEST(ConstAnalysisTest, TopologicalOrder) {
  for (bool order : {false, true}) {
    Scope root = Scope::NewRootScope();

    auto arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0);
    auto arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1);
    auto arg2 = ops::_Arg(root.WithOpName("Arg2"), DT_INT32, 2);
    auto a = ops::Reshape(root, arg0, arg1);
    auto b = ops::Reshape(root, arg2, a);
    if (order) {
      // Consider both argument orders for the Add below so we aren't
      // sensitive to the DFS traversal order.
      std::swap(a, b);
    }
    auto c = ops::Add(root, a, b);

    Graph graph(OpRegistry::Global());
    TF_ASSERT_OK(root.ToGraph(&graph));

    std::vector<bool> const_args(3, false);
    TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
                                        /*compile_time_const_nodes=*/nullptr,
                                        /*flib_runtime=*/nullptr));

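    // The first Reshape's output feeds the shape input of the second Reshape,
    // so both Arg0 and Arg1 must be constant; Arg2 is only ever a value
    // input.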
    EXPECT_EQ(const_args, std::vector<bool>({true, true, false}));
  }
}

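// Builds a graph that calls `Callee` (which reshapes its first input using
// its second) through either PartitionedCall or StatefulPartitionedCall, and
// checks that the constness requirement on the shape input propagates through
// the function call.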
void TestFunctionCall(bool is_stateful_partitioned_call) {
  FunctionDef callee = FunctionDefHelper::Define(
      "Callee", {"t:float", "shape:int32"}, {"result:float"}, {},
      {{{"result"}, "Reshape", {"t", "shape"}, {{"T", DT_FLOAT}}}});

  FunctionDefLibrary flib;
  *flib.add_function() = callee;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  Scope root = Scope::NewRootScope().ExitOnError();

  auto arg0 = ops::_Arg(root.WithOpName("tensor"), DT_FLOAT, 0);
  auto arg1 = ops::_Arg(root.WithOpName("shape"), DT_INT32, 1);

  NameAttrList call_attrs;
  call_attrs.set_name("Callee");
  if (is_stateful_partitioned_call) {
    ops::StatefulPartitionedCall b(root.WithOpName("Call"),
                                   {Output(arg0), Output(arg1)}, {DT_FLOAT},
                                   call_attrs);
  } else {
    ops::PartitionedCall b(root.WithOpName("Call"),
                           {Output(arg0), Output(arg1)}, {DT_FLOAT},
                           call_attrs);
  }

  Graph graph(&flib_def);
  TF_ASSERT_OK(root.ToGraph(&graph));

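  // The analysis needs a FunctionLibraryRuntime to instantiate `Callee` and
  // analyze its body.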
  OptimizerOptions opts;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(nullptr, Env::Default(),
                                        /*config=*/nullptr,
                                        TF_GRAPH_DEF_VERSION, &flib_def, opts));
  FunctionLibraryRuntime* lib_runtime =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

  std::vector<bool> const_args(2, false);
  TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
                                      /*compile_time_const_nodes=*/nullptr,
                                      lib_runtime));

  EXPECT_EQ(const_args, std::vector<bool>({false, true}));
}

TEST(ConstAnalysisTest, PartitionedCall) {
  TestFunctionCall(/*is_stateful_partitioned_call=*/false);
}

TEST(ConstAnalysisTest, StatefulPartitionedCall) {
  TestFunctionCall(/*is_stateful_partitioned_call=*/true);
}

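// A control edge into a constant must not force the edge's source node to
// become a compile-time constant.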
TEST(ConstAnalysisTest, DontFollowControlDependencies) {
  Scope root = Scope::NewRootScope();

  Output arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0);
  Output arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1);
  Output c1 =
      ops::Const(root.WithOpName("c1").WithControlDependencies(arg0), 1, {1});
  Output add = ops::Add(root, arg1, c1);
  Output reshape = ops::Reshape(root, arg1, add);

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&graph));

  std::vector<bool> const_args(2, false);
  TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
                                      /*compile_time_const_nodes=*/nullptr,
                                      /*flib_runtime=*/nullptr));

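  // Arg1 reaches the Reshape's shape input through the Add, so it must be
  // constant; Arg0 only reaches the graph through a control edge into `c1`,
  // which the analysis ignores.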
  EXPECT_EQ(const_args, std::vector<bool>({false, true}));
}

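// An explicit kXlaCompileTimeConstantInputsAttr on a node overrides the op's
// usual list of compile-time-constant inputs; the next two tests exercise
// both directions of the override.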
TEST(ConstAnalysisTest, RespectExplicitAttr_0) {
  Scope root = Scope::NewRootScope();

  Output arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0);
  Output arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1);
  Output c1 =
      ops::Const(root.WithOpName("c1").WithControlDependencies(arg0), 1, {1});
  Output add = ops::Add(root, arg1, c1);

  // Force const analysis to pretend that the shape argument to `reshape` does
  // not need to be a constant.
  Output reshape = ops::Reshape(root, arg1, add);
  reshape.node()->AddAttr(kXlaCompileTimeConstantInputsAttr,
                          std::vector<string>());

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&graph));

  std::vector<bool> const_args(2, false);
  TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
                                      /*compile_time_const_nodes=*/nullptr,
                                      /*flib_runtime=*/nullptr));

  EXPECT_EQ(const_args, std::vector<bool>({false, false}));
}

TEST(ConstAnalysisTest, RespectExplicitAttr_1) {
  Scope root = Scope::NewRootScope();

  Output arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0);
  Output c1 =
      ops::Const(root.WithOpName("c1").WithControlDependencies(arg0), 1, {1});
  Output add = ops::Add(root, arg0, c1);

  // Force const analysis to pretend that the first argument to `add` needs to
  // be a constant.
  std::vector<string> add_constant_inputs;
  add_constant_inputs.push_back("x");
  add.node()->AddAttr(kXlaCompileTimeConstantInputsAttr, add_constant_inputs);

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&graph));

  std::vector<bool> const_args(1, false);
  TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
                                      /*compile_time_const_nodes=*/nullptr,
                                      /*flib_runtime=*/nullptr));

  EXPECT_EQ(const_args, std::vector<bool>({true}));
}

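// Enable XLA devices for these tests; the lambda runs during static
// initialization, before any test executes.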
static bool Initialized = [] {
  tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
  return true;
}();

}  // namespace
}  // namespace tensorflow