xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/const_analysis_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Tests for the backward const analysis.
17 
18 #include "tensorflow/compiler/tf2xla/const_analysis.h"
19 
20 #include "tensorflow/cc/framework/ops.h"
21 #include "tensorflow/cc/ops/function_ops.h"
22 #include "tensorflow/cc/ops/functional_ops.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/compiler/jit/flags.h"
25 #include "tensorflow/compiler/jit/xla_cluster_util.h"
26 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
27 #include "tensorflow/core/graph/algorithm.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/platform/test.h"
30 #include "tensorflow/core/public/version.h"
31 
32 namespace tensorflow {
33 namespace {
34 
// Checks which _Arg nodes the backwards const analysis flags as compile-time
// constants, both through the per-argument `const_args` output and through the
// per-node `const_nodes` output.
TEST(ConstAnalysisTest, Basics) {
  Scope scope = Scope::NewRootScope();

  Output arg0 = ops::_Arg(scope.WithOpName("Arg0"), DT_INT32, 0);
  Output arg1 = ops::_Arg(scope.WithOpName("Arg1"), DT_INT32, 1);
  Output arg2 = ops::_Arg(scope.WithOpName("Arg2"), DT_INT32, 2);
  Output arg3 = ops::_Arg(scope.WithOpName("Arg3"), DT_INT32, 3);

  // Shape(Arg0) + Arg1 becomes the shape operand of Reshape(Arg2, ...).
  auto shape_of_arg0 = ops::Shape(scope, arg0);
  auto new_shape = ops::Add(scope, shape_of_arg0, arg1);
  auto reshaped = ops::Reshape(scope, arg2, new_shape);
  // Arg3 is both the reduced tensor and the reduction indices of Sum.
  auto reduced = ops::Sum(scope, arg3, arg3);
  auto product = ops::Mul(scope, reshaped, reduced);

  FixupSourceAndSinkEdges(scope.graph());

  std::vector<bool> const_args(4, false);
  std::vector<bool> const_nodes(scope.graph()->num_node_ids(), false);
  TF_ASSERT_OK(BackwardsConstAnalysis(*scope.graph(), &const_args, &const_nodes,
                                      /*flib_runtime=*/nullptr));

  // Expected const-ness, indexed by argument number:
  //  - Arg0: only its shape is consumed, so its value need not be constant.
  //  - Arg1: reaches the shape operand of Reshape, so it must be constant.
  //  - Arg2: only the value operand of Reshape; need not be constant.
  //  - Arg3: the reduction-indices operand of Sum, so it must be constant.
  const std::vector<bool> expected = {false, true, false, true};
  EXPECT_EQ(const_args, expected);

  const Output args[] = {arg0, arg1, arg2, arg3};
  for (int i = 0; i < 4; ++i) {
    const bool is_const = const_nodes[args[i].node()->id()];
    EXPECT_EQ(is_const, expected[i]) << "Arg" << i;
  }
}
65 
66 // Regression test for a case where the backward const analysis did
67 // not visit nodes in topological order.
TEST(ConstAnalysisTest,TopologicalOrder)68 TEST(ConstAnalysisTest, TopologicalOrder) {
69   for (bool order : {false, true}) {
70     Scope root = Scope::NewRootScope();
71 
72     auto arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0);
73     auto arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1);
74     auto arg2 = ops::_Arg(root.WithOpName("Arg2"), DT_INT32, 2);
75     auto a = ops::Reshape(root, arg0, arg1);
76     auto b = ops::Reshape(root, arg2, a);
77     if (order) {
78       // Consider both orders for arguments to the Sum so we aren't sensitive
79       // to the DFS traversal order.
80       std::swap(a, b);
81     }
82     auto c = ops::Add(root, a, b);
83 
84     Graph graph(OpRegistry::Global());
85     TF_ASSERT_OK(root.ToGraph(&graph));
86 
87     std::vector<bool> const_args(3, false);
88     TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
89                                         /*compile_time_const_nodes=*/nullptr,
90                                         /*flib_runtime=*/nullptr));
91 
92     EXPECT_EQ(const_args, std::vector<bool>({true, true, false}));
93   }
94 }
95 
TestFunctionCall(bool is_stateful_partitioned_call)96 void TestFunctionCall(bool is_stateful_partitioned_call) {
97   FunctionDef callee = FunctionDefHelper::Define(
98       "Callee", {"t:float", "shape:int32"}, {"result:float"}, {},
99       {{{"result"}, "Reshape", {"t", "shape"}, {{"T", DT_FLOAT}}}});
100 
101   FunctionDefLibrary flib;
102   *flib.add_function() = callee;
103   FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
104 
105   Scope root = Scope::NewRootScope().ExitOnError();
106 
107   auto arg0 = ops::_Arg(root.WithOpName("tensor"), DT_FLOAT, 0);
108   auto arg1 = ops::_Arg(root.WithOpName("shape"), DT_INT32, 1);
109 
110   NameAttrList call_attrs;
111   call_attrs.set_name("Callee");
112   if (is_stateful_partitioned_call) {
113     ops::StatefulPartitionedCall b(root.WithOpName("Call"),
114                                    {Output(arg0), Output(arg1)}, {DT_FLOAT},
115                                    call_attrs);
116   } else {
117     ops::PartitionedCall b(root.WithOpName("Call"),
118                            {Output(arg0), Output(arg1)}, {DT_FLOAT},
119                            call_attrs);
120   }
121 
122   Graph graph(&flib_def);
123   TF_ASSERT_OK(root.ToGraph(&graph));
124 
125   OptimizerOptions opts;
126   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
127       new ProcessFunctionLibraryRuntime(nullptr, Env::Default(),
128                                         /*config=*/nullptr,
129                                         TF_GRAPH_DEF_VERSION, &flib_def, opts));
130   FunctionLibraryRuntime* lib_runtime =
131       pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
132 
133   std::vector<bool> const_args(2, false);
134   TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
135                                       /*compile_time_const_nodes=*/nullptr,
136                                       lib_runtime));
137 
138   EXPECT_EQ(const_args, std::vector<bool>({false, true}));
139 }
140 
// Const requirements inside a function must be seen through a (stateless)
// PartitionedCall node.
TEST(ConstAnalysisTest, PartitionedCall) {
  const bool use_stateful_call = false;
  TestFunctionCall(use_stateful_call);
}
144 
// Const requirements inside a function must be seen through a
// StatefulPartitionedCall node.
TEST(ConstAnalysisTest, StatefulPartitionedCall) {
  const bool use_stateful_call = true;
  TestFunctionCall(use_stateful_call);
}
148 
// An argument reached only through a control edge must not be marked as a
// compile-time constant, even though the node it controls is const-required.
TEST(ConstAnalysisTest, DontFollowControlDependencies) {
  Scope scope = Scope::NewRootScope();

  Output control_arg = ops::_Arg(scope.WithOpName("Arg0"), DT_INT32, 0);
  Output data_arg = ops::_Arg(scope.WithOpName("Arg1"), DT_INT32, 1);
  // The constant depends on Arg0 via a control edge only.
  Output one = ops::Const(
      scope.WithOpName("c1").WithControlDependencies(control_arg), 1, {1});
  Output new_shape = ops::Add(scope, data_arg, one);
  // `new_shape` is Reshape's shape operand, so its data inputs must be const.
  Output reshaped = ops::Reshape(scope, data_arg, new_shape);

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));

  std::vector<bool> const_args(2, false);
  TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
                                      /*compile_time_const_nodes=*/nullptr,
                                      /*flib_runtime=*/nullptr));

  // Arg1 reaches the shape through data edges; Arg0 only through control.
  const std::vector<bool> expected = {false, true};
  EXPECT_EQ(const_args, expected);
}
169 
// An explicit, empty kXlaCompileTimeConstantInputsAttr on a Reshape makes the
// analysis treat its shape operand as not needing to be constant, so neither
// argument ends up const.
TEST(ConstAnalysisTest, RespectExplicitAttr_0) {
  Scope scope = Scope::NewRootScope();

  Output control_arg = ops::_Arg(scope.WithOpName("Arg0"), DT_INT32, 0);
  Output data_arg = ops::_Arg(scope.WithOpName("Arg1"), DT_INT32, 1);
  Output one = ops::Const(
      scope.WithOpName("c1").WithControlDependencies(control_arg), 1, {1});
  Output new_shape = ops::Add(scope, data_arg, one);

  // Override the Reshape's const-input list with an empty one so the analysis
  // pretends its shape operand need not be constant.
  Output reshaped = ops::Reshape(scope, data_arg, new_shape);
  const std::vector<string> no_const_inputs;
  reshaped.node()->AddAttr(kXlaCompileTimeConstantInputsAttr, no_const_inputs);

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));

  std::vector<bool> const_args(2, false);
  TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
                                      /*compile_time_const_nodes=*/nullptr,
                                      /*flib_runtime=*/nullptr));

  const std::vector<bool> expected = {false, false};
  EXPECT_EQ(const_args, expected);
}
195 
// kXlaCompileTimeConstantInputsAttr can also add a requirement. Listing "x" on
// an Add makes the analysis pretend the Add's first operand must be constant,
// which in turn makes Arg0 const.
TEST(ConstAnalysisTest, RespectExplicitAttr_1) {
  Scope scope = Scope::NewRootScope();

  Output arg = ops::_Arg(scope.WithOpName("Arg0"), DT_INT32, 0);
  Output one =
      ops::Const(scope.WithOpName("c1").WithControlDependencies(arg), 1, {1});
  Output total = ops::Add(scope, arg, one);

  // "x" is the name of the Add's first input.
  total.node()->AddAttr(kXlaCompileTimeConstantInputsAttr,
                        std::vector<string>{"x"});

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));

  std::vector<bool> const_args(1, false);
  TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
                                      /*compile_time_const_nodes=*/nullptr,
                                      /*flib_runtime=*/nullptr));

  const std::vector<bool> expected = {true};
  EXPECT_EQ(const_args, expected);
}
220 
__anon997b50920202null221 static bool Initialized = [] {
222   tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
223   return true;
224 }();
225 
226 }  // namespace
227 }  // namespace tensorflow
228