/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/constant_folding.h"

#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tensor_coding.h"

namespace tensorflow {
namespace grappler {
namespace {

class ConstantFoldingTest : public GrapplerTest {
 protected:
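  // Builds a 2x2 graph that combines a placeholder with all-zeros and
  // all-ones constants (LogicalAnd/LogicalOr for DT_BOOL, Mul/MulNoNan/Add
  // otherwise) and checks that constant folding rewrites each op into the
  // expected Const, Identity, or Snapshot node.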
  template <DataType DTYPE>
  void SimpleNeutralElementTest() {
    for (bool use_snapshot : {false, true}) {
      typedef typename EnumToDataType<DTYPE>::Type T;
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
      Output x = ops::Placeholder(s.WithOpName("x"), DTYPE,
                                  ops::Placeholder::Shape(TensorShape({2, 2})));
      Output v = ops::Variable(s.WithOpName("v"), {2, 2}, DTYPE);
      Tensor zeros_t(DTYPE, TensorShape({2, 2}));
      Tensor ones_t(DTYPE, TensorShape({2, 2}));
      Tensor x_t(DTYPE, TensorShape({2, 2}));
      for (int i = 0; i < 4; ++i) {
        zeros_t.flat<T>()(i) = T(0);
        ones_t.flat<T>()(i) = T(1);
        x_t.flat<T>()(i) = T(i + 1);
      }
      Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
      Output ones = ops::Const(s.WithOpName("ones"), ones_t);
      Output mul1;
      Output mul2;
      Output add1;
      Output add2;
      if (DTYPE == DT_BOOL) {
        mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
        mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
        add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
        add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
      } else {
        if (DTYPE == DT_FLOAT) {
          mul1 = ops::MulNoNan(s.WithOpName("mul1"), x, zeros);
        } else {
          mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
        }
        mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
        add1 = ops::Add(s.WithOpName("add1"), x, zeros);
        add2 = ops::Add(s.WithOpName("add2"), x, ones);
      }
      if (use_snapshot) {
        // Add an op with a ref input to prevent Snapshot from being
        // turned into Identity.
        ops::Assign(s.WithOpName("assign"), v, ones);
      }
      GrapplerItem item;
      TF_CHECK_OK(s.ToGraphDef(&item.graph));
      item.fetch = {"mul1", "mul2", "add1", "add2"};
      ConstantFolding optimizer(/*cpu_device=*/nullptr);
      GraphDef output;
      Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
      TF_EXPECT_OK(status);

      EXPECT_EQ(7, output.node_size());
      const string snapshot_or_identity =
          use_snapshot ? "Snapshot" : "Identity";
      for (int i = 0; i < output.node_size(); ++i) {
        const NodeDef& node = output.node(i);
        const string& name = node.name();
        if (name == "mul1") {
          EXPECT_EQ("Const", node.op());
          EXPECT_EQ("^x", node.input(0));
          EXPECT_EQ("^zeros", node.input(1));
        } else if (name == "mul2") {
          EXPECT_EQ(snapshot_or_identity, node.op());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("^ones", node.input(1));
        } else if (name == "add1") {
          EXPECT_EQ(snapshot_or_identity, node.op());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("^zeros", node.input(1));
        } else if (name == "add2") {
          if (DTYPE == DT_BOOL) {
            EXPECT_EQ("Const", node.op());
            EXPECT_EQ("^x", node.input(0));
            EXPECT_EQ("^ones", node.input(1));
          } else {
            EXPECT_EQ("Add", node.op());
            EXPECT_EQ("x", node.input(0));
            EXPECT_EQ("ones", node.input(1));
          }
        }
      }
      auto tensors_expected =
          EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
      auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
      EXPECT_EQ(4, tensors_expected.size());
      EXPECT_EQ(4, tensors.size());
      for (int i = 0; i < item.fetch.size(); ++i) {
        test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
      }
    }
  }

  void MulConvPushDownTest(const TensorShape& input_shape,
                           const TensorShape& filter_shape,
                           const TensorShape& mul_const_input_shape,
                           const bool use_3d_conv, const char* padding,
                           const char* data_format, const bool expect_folded) {
    // Tests that the following rewrite is performed:
    //
    //         *                      Conv2D
    //        / \                      /  \
    //       c  Conv2D       -->      x  (c * filter)
    //          /  \
    //         x  filter
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    Tensor filter_values(DT_FLOAT, filter_shape);
    for (int i = 0; i < filter_values.NumElements(); ++i) {
      filter_values.flat<float>()(i) = std::sqrt(static_cast<float>(i));
    }
    Output filter =
        ops::Const(s.WithOpName("filter"), Input::Initializer(filter_values));

    Output input = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                    ops::Placeholder::Shape(input_shape));

    Output conv;
    if (use_3d_conv) {
      conv = ops::Conv3D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1, 1},
                         padding, ops::Conv3D::DataFormat(data_format));
    } else {
      conv = ops::Conv2D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1},
                         padding, ops::Conv2D::DataFormat(data_format));
    }
    Tensor mul_const_input(DT_FLOAT, mul_const_input_shape);
    for (int i = 0; i < mul_const_input.NumElements(); ++i) {
      mul_const_input.flat<float>()(i) = static_cast<float>(i + 3);
    }
    Output c =
        ops::Const(s.WithOpName("c"), Input::Initializer(mul_const_input));
    Output mul = ops::Mul(s.WithOpName("mul"), c, conv);

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    EXPECT_EQ(5, output.node_size());
    int found = 0;
    if (expect_folded) {
      for (const auto& node : output.node()) {
        if (node.name() == "mul") {
          found++;
          EXPECT_EQ(use_3d_conv ? "Conv3D" : "Conv2D", node.op());
          EXPECT_EQ(2, node.input_size());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("conv/merged_input", node.input(1));
        } else if (node.name() == "conv/merged_input") {
          found++;
          EXPECT_EQ("Const", node.op());
          EXPECT_EQ(0, node.input_size());
        }
      }
    } else {
      for (const auto& node : output.node()) {
        if (node.name() == "mul") {
          found++;
          EXPECT_EQ("Mul", node.op());
          EXPECT_EQ(2, node.input_size());
          EXPECT_EQ("c", node.input(0));
          EXPECT_EQ("conv", node.input(1));
        } else if (node.name() == "conv") {
          found++;
          EXPECT_EQ(use_3d_conv ? "Conv3D" : "Conv2D", node.op());
          EXPECT_EQ(2, node.input_size());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("filter", node.input(1));
        }
      }
    }
    EXPECT_EQ(2, found);

    // Check that the const-folded multiplication node has the expected value.
    std::vector<string> fetch = {"mul"};
    Tensor value(DT_FLOAT, input_shape);
    for (int i = 0; i < value.NumElements(); ++i) {
      value.flat<float>()(i) = i;
    }
    auto actual = EvaluateNodes(output, fetch, {{"x", value}});
    auto expected = EvaluateNodes(item.graph, fetch, {{"x", value}});
    test::ExpectTensorEqual<float>(expected[0], actual[0]);
  }

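  // Builds two PadV2 nodes, one with all-zero paddings and one with
  // non-trivial paddings, and verifies that only the no-op pad is replaced
  // by an Identity that carries control dependencies on its former inputs.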
  template <typename T>
  void PaddingWithZeroSize() {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

    auto in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_INT32);
    auto in2 = ops::Variable(scope.WithOpName("in2"), {2, 2}, DT_INT32);
    auto paddings1 =
        ops::Const<T>(scope.WithOpName("paddings1"), {0, 0, 0, 0}, {2, 2});
    auto paddings2 =
        ops::Const<T>(scope.WithOpName("paddings2"), {1, 1, 2, 2}, {2, 2});
    auto c1 = ops::Const(scope.WithOpName("c1"), 1);
    auto c2 = ops::Const(scope.WithOpName("c2"), 1);

    ops::PadV2 p1(scope.WithOpName("p1"), in1, paddings1, c1);
    ops::PadV2 p2(scope.WithOpName("p2"), in2, paddings2, c2);

    ops::Add out(scope.WithOpName("out"), p1, p2);

    GrapplerItem item;
    item.fetch = {"out"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef got;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
    TF_EXPECT_OK(status);

    GraphDef want;
    AddNode("in1", "VariableV2", {}, {}, &want);
    AddNode("in2", "VariableV2", {}, {}, &want);
    AddNode("paddings1", "Const", {}, {}, &want);
    AddNode("paddings2", "Const", {}, {}, &want);
    AddNode("c1", "Const", {}, {}, &want);
    AddNode("c2", "Const", {}, {}, &want);
    AddNode(
        "p1", "Identity",
        {"in1", AsControlDependency("paddings1"), AsControlDependency("c1")},
        {}, &want);
    AddNode("p2", "PadV2", {"in2", "paddings2", "c2"}, {}, &want);
    AddNode("out", "Add", {"p1", "p2"}, {}, &want);

    CompareGraphs(want, got);

    auto in1_t = GenerateRandomTensor<DT_INT32>(TensorShape({4, 6}));
    auto in2_t = GenerateRandomTensor<DT_INT32>(TensorShape({2, 2}));
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors_expected.size());
    auto tensors =
        EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
  }
};

TEST_F(ConstantFoldingTest, SimpleFolding) {
  // Build a simple graph with a few trivially prunable ops.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 1.0f, {1});
  Output b = ops::Const(s.WithOpName("b"), 2.0f, {1});
  Output c = ops::AddN(s.WithOpName("c").WithDevice("/CPU:0"), {a, b});
  Output d = ops::AddN(s.WithOpName("d"), {b, c});

  GrapplerItem item;
  item.fetch.push_back("d");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(1, output.node_size());

  const NodeDef& node_d = output.node(0);
  EXPECT_EQ("d", node_d.name());
  EXPECT_EQ("Const", node_d.op());

  std::vector<string> fetch = {"d"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(1, tensors_expected.size());
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

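// Verifies that a tree of the form op(c1, op(c2, x)) is rewritten so that the
// non-constant leaf feeds the root and the two constants are folded into a
// single Const (see the diagrams below), while a mixed Add/Mul tree is left
// unchanged.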
TEST_F(ConstantFoldingTest, AddTree) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {1});
  Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
  Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output add_child = ops::Add(s.WithOpName("add_child"), c2, x);
  Output add_parent = ops::Add(s.WithOpName("add_parent"), c1, add_child);

  Output c4 = ops::Const(s.WithOpName("c4"), 4.0f, {2});
  Output c5 = ops::Const(s.WithOpName("c5"), 5.0f, {2});
  Output c20 = ops::Const(s.WithOpName("c20"), 20.0f, {2});
  Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output mul_child = ops::Mul(s.WithOpName("mul_child"), c4, y);
  Output mul_parent = ops::Mul(s.WithOpName("mul_parent"), c5, mul_child);
  Output addmul_child = ops::Add(s.WithOpName("addmul_child"), c4, x);
  Output addmul_parent =
      ops::Mul(s.WithOpName("addmul_parent"), c5, addmul_child);

  GrapplerItem item;
  item.fetch = {"add_parent", "mul_parent", "addmul_parent"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // We expect the following rewrite(s) to occur:
  //
  //     +                +             +
  //    / \              / \           / \
  //  1.0  +     -->    x   +    -->  x  3.0
  //      / \              / \
  //    2.0  x           1.0 2.0
  //
  //     *                *             *
  //    / \              / \           / \
  //  4.0  *     -->    y   *    -->  y  20.0
  //      / \              / \
  //    5.0  y           4.0 5.0

  EXPECT_EQ(10, output.node_size());
  for (const auto& node : output.node()) {
    if (node.name() == "add_child") {
      EXPECT_EQ("Const", node.op());
      TensorProto t = node.attr().at("value").tensor();
      ASSERT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(2, t.tensor_shape().dim(0).size());
    } else if (node.name() == "add_parent") {
      EXPECT_EQ("Add", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("add_child", node.input(1));
    } else if (node.name() == "mul_child") {
      EXPECT_EQ("Const", node.op());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(2, t.tensor_shape().dim(0).size());
    } else if (node.name() == "mul_parent") {
      EXPECT_EQ("Mul", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("mul_child", node.input(1));
    } else if (node.name() == "addmul_child") {
      // Unchanged.
      EXPECT_EQ("Add", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("c4", node.input(0));
      EXPECT_EQ("x", node.input(1));
    }
  }

  // Check that the result nodes have the expected value.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));

  std::vector<string> fetch = {"add_parent", "mul_parent"};
  auto tensor_expected =
      EvaluateNodes(item.graph, fetch, {{"x", x_t}, {"y", y_t}});
  ASSERT_EQ(fetch.size(), tensor_expected.size());
  fetch = {"add_parent", "mul_parent"};
  auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}, {"y", y_t}});
  ASSERT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
  }
}

TEST_F(ConstantFoldingTest, AddSubtractTree) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {1});
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output sub_child = ops::Sub(s.WithOpName("sub_child"), x, x);
  Output add_parent = ops::Add(s.WithOpName("add_parent"), sub_child, c1);

  GrapplerItem item;
  item.fetch = {"add_parent"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // We expect the following rewrite(s) to occur:
  //
  //      +             +
  //     / \           / \
  //    -   1   -->   -   x
  //   / \           / \
  //  x   x         1   x

  EXPECT_EQ(4, output.node_size());
  for (const auto& node : output.node()) {
    if (node.name() == "sub_child") {
      EXPECT_EQ("Sub", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("c1", node.input(0));
      EXPECT_EQ("x", node.input(1));
    } else if (node.name() == "add_parent") {
      EXPECT_EQ("Add", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("sub_child", node.input(1));
    }
  }

  // Check that the result nodes have the expected value.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));

  std::vector<string> fetch = {"add_parent"};
  auto tensor_expected = EvaluateNodes(item.graph, fetch, {{"x", x_t}});
  ASSERT_EQ(fetch.size(), tensor_expected.size());
  fetch = {"add_parent"};
  auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}});
  ASSERT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
  }
}

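// Exhaustively exercises constant push-down over two-level Add/Sub/Mul/Div
// trees, varying which argument of the parent and of the child is constant,
// and checks that the optimized graph still computes the same value.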
TEST_F(ConstantFoldingTest, ConstantPushDown) {
  for (int is_add : {true, false}) {
    for (int is_parent_commutative : {true, false}) {
      for (int is_child_commutative : {true, false}) {
        for (int is_left_child_const : {true, false}) {
          for (int is_left_leaf_const : {true, false}) {
            tensorflow::Scope s = tensorflow::Scope::NewRootScope();
            Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
            Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
            Output x =
                ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                 ops::Placeholder::Shape(TensorShape({2, 2})));

            auto get_op = [&](bool is_commutative, bool is_left_arg_const,
                              const string& name, const Output& const_arg,
                              const Output& non_const_arg) -> Output {
              if (is_add) {
                if (is_commutative) {
                  return ops::Add(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                } else {
                  return ops::Sub(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                }
              } else {
                if (is_commutative) {
                  return ops::Mul(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                } else {
                  return ops::Div(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                }
              }
            };

            Output child = get_op(is_child_commutative, is_left_leaf_const,
                                  "child", c2, x);
            Output parent = get_op(is_parent_commutative, is_left_child_const,
                                   "parent", c3, child);
            GrapplerItem item;
            item.fetch = {"parent"};
            TF_CHECK_OK(s.ToGraphDef(&item.graph));

            ConstantFolding optimizer(/*cpu_device=*/nullptr);
            GraphDef output;
            Status status =
                optimizer.Optimize(/*cluster=*/nullptr, item, &output);
            TF_EXPECT_OK(status);

            // Check that the result nodes have the expected value.
            auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
            std::vector<string> fetch = {"parent"};
            auto tensor_expected =
                EvaluateNodes(item.graph, fetch, {{"x", x_t}});
            ASSERT_EQ(fetch.size(), tensor_expected.size());
            fetch = {"parent"};
            auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}});
            ASSERT_EQ(fetch.size(), tensors.size());
            for (int i = 0; i < fetch.size(); i++) {
              test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
            }
          }
        }
      }
    }
  }
}

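// Verifies constant push-down through BiasAdd: combinations of BiasAdd and
// Add over one constant matrix and one constant vector should collapse the
// child node into a Const, while shape-incompatible combinations must be
// left alone.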
TEST_F(ConstantFoldingTest, ConstantPushDownBiasAdd) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c_mat = ops::Const(s.WithOpName("c_mat"), 2.0f, {2, 2});
  Output c_vec = ops::Const(s.WithOpName("c_vec"), 3.0f, {2});
  Output x_mat = ops::Placeholder(s.WithOpName("x_mat"), DT_FLOAT,
                                  ops::Placeholder::Shape(TensorShape({2, 2})));
  Output x_vec = ops::Placeholder(s.WithOpName("x_vec"), DT_FLOAT,
                                  ops::Placeholder::Shape(TensorShape({2})));
  // Rewrite expected for cases 1 through 3 and their symmetric equivalents,
  // and case 4.
  Output child1 = ops::BiasAdd(s.WithOpName("child1"), c_mat, x_vec);
  Output parent1 = ops::Add(s.WithOpName("parent1"), child1, c_vec);
  Output child1a = ops::BiasAdd(s.WithOpName("child1a"), c_mat, x_vec);
  Output parent1a = ops::Add(s.WithOpName("parent1a"), c_vec, child1a);

  Output child2 = ops::BiasAdd(s.WithOpName("child2"), x_mat, c_vec);
  Output parent2 = ops::Add(s.WithOpName("parent2"), child2, c_mat);
  Output child2a = ops::BiasAdd(s.WithOpName("child2a"), x_mat, c_vec);
  Output parent2a = ops::Add(s.WithOpName("parent2a"), c_mat, child2a);

  Output child3 = ops::Add(s.WithOpName("child3"), c_mat, x_vec);
  Output parent3 = ops::BiasAdd(s.WithOpName("parent3"), child3, c_vec);
  Output child3a = ops::Add(s.WithOpName("child3a"), x_vec, c_mat);
  Output parent3a = ops::BiasAdd(s.WithOpName("parent3a"), child3a, c_vec);

  Output child4 = ops::BiasAdd(s.WithOpName("child4"), c_mat, x_vec);
  Output parent4 = ops::BiasAdd(s.WithOpName("parent4"), child4, c_vec);

  // No rewrite expected.
  Output child5 = ops::Add(s.WithOpName("child5"), x_vec, x_vec);
  Output parent5 = ops::BiasAdd(s.WithOpName("parent5"), c_mat, child5);
  Output child6 = ops::Add(s.WithOpName("child6"), x_vec, c_vec);
  Output parent6 = ops::BiasAdd(s.WithOpName("parent6"), c_mat, child6);
  Output child7 = ops::Add(s.WithOpName("child7"), x_mat, c_vec);
  Output parent7 = ops::BiasAdd(s.WithOpName("parent7"), child7, c_vec);

  GrapplerItem item;
  item.fetch = {"parent1",  "parent2", "parent3", "parent1a", "parent2a",
                "parent3a", "parent4", "parent5", "parent6",  "parent7"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(24, output.node_size());
  for (const auto& node : output.node()) {
    if (node.name() == "child1" || node.name() == "child1a" ||
        node.name() == "child2" || node.name() == "child2a" ||
        node.name() == "child3" || node.name() == "child3a" ||
        node.name() == "child4") {
      EXPECT_EQ(node.op(), "Const") << " node: " << node.name();
    } else if (node.name() != "c_mat" && node.name() != "c_vec") {
      EXPECT_NE(node.op(), "Const") << " node: " << node.name();
    }
  }
  // Check that the result nodes have the expected value.
  auto x_mat_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_vec_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
  std::vector<string> fetch = item.fetch;
  auto tensor_expected = EvaluateNodes(
      item.graph, fetch, {{"x_vec", x_vec_t}, {"x_mat", x_mat_t}});
  ASSERT_EQ(fetch.size(), tensor_expected.size());
  auto tensors =
      EvaluateNodes(output, fetch, {{"x_vec", x_vec_t}, {"x_mat", x_mat_t}});
  ASSERT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
  }
}

// This test fails on the ROCm platform (see commit message for details).
#ifndef TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_ScalarConst) {
  for (string data_format : {
           "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
           "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3}
                                              : TensorShape{4, 3, 10, 10},
        /*filter_shape=*/{2, 2, 3, 5},
        /*mul_const_input_shape=*/{},
        /*use_3d_conv=*/false,
        /*padding=*/"VALID", data_format.c_str(),
        /*expect_folded=*/true);
  }
}
#endif

// This test fails on the ROCm platform (see commit message for details).
#ifndef TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_SingletonConst) {
  for (string data_format : {
           "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
           "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    for (auto mul_const_input_shape :
         {TensorShape{1}, TensorShape{1, 1, 1, 1}}) {
      MulConvPushDownTest(
          /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3}
                                                : TensorShape{4, 3, 10, 10},
          /*filter_shape=*/{2, 2, 3, 5}, mul_const_input_shape,
          /*use_3d_conv=*/false,
          /*padding=*/"VALID", data_format.c_str(),
          /*expect_folded=*/true);
    }
  }
}
#endif

TEST_F(ConstantFoldingTest,
       MulConvPushDownTest_Conv2D_SingletonConst_ShapeMismatch) {
  for (string data_format : {
           "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
           "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3}
                                              : TensorShape{4, 3, 10, 10},
        /*filter_shape=*/{2, 2, 3, 5},
        /*mul_const_input_shape=*/{1, 1, 1, 1, 1},
        /*use_3d_conv=*/false,
        /*padding=*/"VALID", data_format.c_str(),
        /*expect_folded=*/false);
  }
}

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_3x1x3Const) {
  for (auto data_format : {
           "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
           "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3},
        /*mul_const_input_shape=*/{3, 1, 3},
        /*use_3d_conv=*/false,
        /*padding=*/"SAME", data_format,
        /*expect_folded=*/false);
  }
}

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_NHWC_VectorLikeConst) {
  for (auto mul_const_input_shape :
       {TensorShape{3}, TensorShape{1, 3}, TensorShape{1, 1, 1, 3}}) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3}, mul_const_input_shape,
        /*use_3d_conv=*/false,
        /*padding=*/"SAME",
        /*data_format=*/"NHWC",
        /*expect_folded=*/true);
  }
}

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_NCHW_VectorLikeConst) {
  for (auto mul_const_input_shape :
       {TensorShape{3}, TensorShape{3, 1, 1}, TensorShape{1, 3, 1, 1}}) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3}, mul_const_input_shape,
        /*use_3d_conv=*/false,
        /*padding=*/"SAME",
        /*data_format=*/"NCHW",
        // TODO(laigd): optimization should happen in this case.
        /*expect_folded=*/false);
  }
}
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_3x1Const) {
  for (auto data_format : {
           "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
           "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3},
        /*mul_const_input_shape=*/{3, 1},
        /*use_3d_conv=*/false,
        /*padding=*/"SAME", data_format,
        /*expect_folded=*/false);
  }
}

// This test fails on the ROCm platform (see commit message for details).
#ifndef TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv3D_NDHWC_1x1x3Const) {
  MulConvPushDownTest(
      /*input_shape=*/{3, 3, 3, 3, 3},
      /*filter_shape=*/{3, 3, 3, 3, 3},
      /*mul_const_input_shape=*/{1, 1, 3},
      /*use_3d_conv=*/true,
      /*padding=*/"SAME",
      /*data_format=*/"NDHWC",
      /*expect_folded=*/true);
}
#endif

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv3D_NCDHW_3x1x1x1Const) {
  MulConvPushDownTest(
      /*input_shape=*/{3, 3, 3, 3, 3},
      /*filter_shape=*/{3, 3, 3, 3, 3},
      /*mul_const_input_shape=*/{3, 1, 1, 1},
      /*use_3d_conv=*/true,
      /*padding=*/"SAME",
      /*data_format=*/"NDHWC",
      // TODO(laigd): optimization should happen in this case.
      /*expect_folded=*/false);
}

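// Covers the neutral- and absorbing-element simplifications (x*0, 0*y, x*1,
// 1*y, x/1, 1/y, x+0, 0+y, x-0, 0-y, and zero-valued MatMul/BiasAdd inputs),
// with the zeros and ones produced by a Const, a ZerosLike/OnesLike, or a
// Fill node.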
TEST_F(ConstantFoldingTest, NeutralElement) {
  int kConst = 0;
  int kLike = 1;
  int kFill = 2;
  for (int const_type : {kConst, kLike, kFill}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({3, 2})));
    Output b = ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 3})));
    Output bias = ops::Placeholder(s.WithOpName("bias"), DT_FLOAT,
                                   ops::Placeholder::Shape(TensorShape({2})));
    Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
    Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
    Output zeros_const_bcast =
        ops::Const(s.WithOpName("zeros_const_bcast"), 0.0f, {2, 2, 2});
    Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
    Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
    Output zeros = const_type == kConst
                       ? zeros_const
                       : (const_type == kLike ? zeros_like : zeros_fill);
    Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
    Output ones_const_bcast =
        ops::Const(s.WithOpName("ones_const_bcast"), 1.0f, {2, 2, 2});
    Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
    Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
    Output ones = const_type == kConst
                      ? ones_const
                      : (const_type == kLike ? ones_like : ones_fill);
    Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
    Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
    Output mul1_bcast =
        ops::Mul(s.WithOpName("mul1_bcast"), x, ones_const_bcast);
    Output mul2_bcast =
        ops::Mul(s.WithOpName("mul2_bcast"), ones_const_bcast, y);
    Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
    Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
    Output mul5 = ops::MulNoNan(s.WithOpName("mul5"), x, zeros_1d);
    Output mul6 = ops::MulNoNan(s.WithOpName("mul6"), zeros_1d, y);
    Output div1 = ops::Div(s.WithOpName("div1"), x, ones);
    Output div2 = ops::Div(s.WithOpName("div2"), ones, y);
    Output floordiv = ops::FloorDiv(s.WithOpName("floordiv"), x, ones);
    Output matmul1 = ops::MatMul(s.WithOpName("matmul1"), x, zeros);
    Output matmul2 = ops::MatMul(s.WithOpName("matmul2"), zeros, y);
    Output matmul3 = ops::MatMul(s.WithOpName("matmul3"), a, zeros);
    Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b);
    Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
    Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
    Output add1_bcast =
        ops::Add(s.WithOpName("add1_bcast"), x, zeros_const_bcast);
    Output add2_bcast =
        ops::Add(s.WithOpName("add2_bcast"), zeros_const_bcast, y);
    Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d);
    Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias);
    Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
    Output sub2 = ops::Sub(s.WithOpName("sub2"), zeros, y);
    Output concat = ops::Stack(
        s.WithOpName("stack"),
        {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, floordiv, matmul1,
         matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch = {"stack",      "matmul3",    "matmul4",   "mul1_bcast",
                  "mul2_bcast", "add1_bcast", "add2_bcast"};

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    const string suffix =
        (const_type == kConst ? "_const"
                              : (const_type == kLike ? "_like" : "_fill"));
    const string zeros_name = strings::StrCat("zeros", suffix);
    const string ones_name = strings::StrCat("ones", suffix);
    const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
    const string ctrl_ones_name = strings::StrCat("^ones", suffix);

    EXPECT_EQ(const_type == kFill ? 43 : 39, output.node_size());
    for (int i = 0; i < output.node_size(); ++i) {
      const NodeDef& node = output.node(i);
      const string& name = node.name();
      if (name == "mul1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "mul2") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "mul1_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^ones_const_bcast", node.input(2));
      } else if (name == "mul2_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^ones_const_bcast", node.input(2));
      } else if (name == "mul3") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "mul4") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "mul5") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ("^zeros_1d", node.input(1));
      } else if (name == "mul6") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^zeros_1d", node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "div1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "div2") {
        EXPECT_EQ("Reciprocal", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "floordiv") {
        EXPECT_EQ("FloorDiv", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ones_name, node.input(1));
      } else if (name == "matmul1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "matmul2") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "matmul3") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^a", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(3, t.tensor_shape().dim(0).size());
        EXPECT_EQ(2, t.tensor_shape().dim(1).size());
      } else if (name == "matmul4") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^b", node.input(1));
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
        EXPECT_EQ(3, t.tensor_shape().dim(1).size());
      } else if (name == "add1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "add2") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "add1_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros_const_bcast", node.input(2));
      } else if (name == "add2_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^zeros_const_bcast", node.input(2));
      } else if (name == "bias_add1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros_1d", node.input(1));
      } else if (name == "bias_add2") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("bias", node.input(0));
        EXPECT_EQ("ConstantFolding/bias_add2-broadcastto_shape-1",
                  node.input(1));
        EXPECT_EQ(ctrl_zeros_name, node.input(2));
      } else if (name == "ConstantFolding/bias_add2-broadcastto_shape-1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ(node.attr().at("dtype").type(), DT_INT32);
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(DT_INT32, t.dtype());
        EXPECT_EQ(1, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
      } else if (name == "sub1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "sub2") {
        EXPECT_EQ("Neg", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      }
      const std::set<string> square_zero_const{"mul1", "mul2",    "mul5",
                                               "mul6", "matmul1", "matmul2"};
      if (square_zero_const.count(name) > 0) {
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
        EXPECT_EQ(2, t.tensor_shape().dim(1).size());
      }
    }
    auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 2}));
    auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
    auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
    auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
    auto bias_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));

    auto tensors_expected = EvaluateNodes(
        item.graph, item.fetch,
        {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
    EXPECT_EQ(item.fetch.size(), tensors_expected.size());
    auto tensors = EvaluateNodes(
        output, item.fetch,
        {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
    EXPECT_EQ(item.fetch.size(), tensors.size());
    for (int i = 0; i < item.fetch.size(); ++i) {
      test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
    }
  }
}

TEST_F(ConstantFoldingTest, NeutralElement_ShortFloats) {
  SimpleNeutralElementTest<DT_BOOL>();
  SimpleNeutralElementTest<DT_HALF>();
  SimpleNeutralElementTest<DT_BFLOAT16>();
}

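// Division by a floating-point constant should be strength-reduced to
// multiplication by its reciprocal; integer division must stay untouched
// since that rewrite would change its rounding behavior.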
TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output cf_half = ops::Const(s.WithOpName("cf_half"), 0.5f, {1});
  Output xf = ops::Placeholder(s.WithOpName("xf"), DT_FLOAT,
                               ops::Placeholder::Shape(TensorShape({2, 2})));
  Output xi = ops::Placeholder(s.WithOpName("xi"), DT_INT32,
                               ops::Placeholder::Shape(TensorShape({2, 2})));
  Output ci = ops::Const(s.WithOpName("ci"), 2, {1});
  Output cf = ops::Const(s.WithOpName("cf"), 2.0f, {1});
  Output div_i = ops::Div(s.WithOpName("div_i"), xi, ci);
  Output div_f = ops::Div(s.WithOpName("div_f"), xf, cf);
  Output realdiv = ops::RealDiv(s.WithOpName("realdiv"), xf, cf);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"div_f", "div_i", "realdiv"};
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(8, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (name == "div_i") {
      // Integer division is unchanged.
      EXPECT_EQ("Div", node.op());
      EXPECT_EQ("xi", node.input(0));
      EXPECT_EQ("ci", node.input(1));
    } else if (name == "div_f") {
      EXPECT_EQ("Mul", node.op());
      EXPECT_EQ("xf", node.input(0));
      EXPECT_EQ("ConstantFolding/div_f_recip", node.input(1));
    } else if (name == "realdiv") {
      EXPECT_EQ("Mul", node.op());
      EXPECT_EQ("xf", node.input(0));
      EXPECT_EQ("ConstantFolding/realdiv_recip", node.input(1));
    } else if (name == "ConstantFolding/div_f_recip") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(DT_FLOAT, t.dtype());
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(1, t.tensor_shape().dim(0).size());
    } else if (name == "ConstantFolding/realdiv_recip") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(DT_FLOAT, t.dtype());
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(1, t.tensor_shape().dim(0).size());
    }
  }

  // Check that the reciprocals have the expected value.
  std::vector<string> fetch = {"cf_half"};
  auto tensor_expected = EvaluateNodes(item.graph, fetch);
  EXPECT_EQ(fetch.size(), tensor_expected.size());
  fetch = {"ConstantFolding/div_f_recip", "ConstantFolding/realdiv_recip"};
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[0], tensors[i]);
  }
}

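// Checks which x * zeros products can be simplified depending on how much of
// the operands' shapes is known: fully deducible cases become Const, the
// partially-known/partially-known case becomes Identity, and products
// involving a fully unknown shape must remain Mul.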
TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x_known =
      ops::Placeholder(s.WithOpName("x_known"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({2, 2})));
  Output x_partially_known =
      ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
  Output zeros_known = ops::ZerosLike(s.WithOpName("zeros_known"), x_known);
  Output zeros_partially_known =
      ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
  Output zeros_unknown =
      ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);

  // Multiplies without any additional ops to supply the output shape.
  int count = 0;
  std::vector<Output> muls;
  std::unordered_set<string> not_converted;
  std::unordered_set<string> to_const;
  std::unordered_set<string> to_identity;
  for (const auto* x : {&x_known, &x_partially_known, &x_unknown}) {
    for (const auto* zeros :
         {&zeros_known, &zeros_partially_known, &zeros_unknown}) {
      const string name = strings::StrCat("mul_", count++);
      muls.push_back(ops::Mul(s.WithOpName(name), *x, *zeros));
      if (x == &x_partially_known && zeros == &zeros_partially_known) {
        to_identity.insert(name);
      } else if (x == &x_unknown || zeros == &zeros_unknown) {
        not_converted.insert(name);
      } else {
        to_const.insert(name);
      }
    }
  }

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(15, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (to_const.count(name) > 0) {
      EXPECT_EQ("Const", node.op()) << node.name();
    } else if (to_identity.count(name) > 0) {
      EXPECT_EQ("Identity", node.op()) << node.name();
    } else if (not_converted.count(name) > 0) {
      EXPECT_EQ("Mul", node.op()) << node.name();
    }
  }

  const std::vector<string> fetch = {"mul_0", "mul_4", "mul_8"};
  auto x_known_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_partially_unknown_t =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
  auto x_unknown_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
  auto expected_tensors =
      EvaluateNodes(item.graph, fetch,
                    {{"x_known", x_known_t},
                     {"x_partially_unknown", x_partially_unknown_t},
                     {"x_unknown", x_unknown_t}});
  EXPECT_EQ(fetch.size(), expected_tensors.size());
  auto tensors = EvaluateNodes(output, fetch,
                               {{"x_known", x_known_t},
                                {"x_partially_unknown", x_partially_unknown_t},
                                {"x_unknown", x_unknown_t}});
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < tensors.size(); i++)
    test::ExpectTensorNear<float>(expected_tensors[i], tensors[i], 1e-5);
}

TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output known_shape = ops::Const(s.WithOpName("known_shape"), 0.0f, {2, 2});
  Output x_partially_known =
      ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
  Output zeros_partially_known =
      ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
  Output zeros_unknown =
      ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);

  // If at least one of the inputs to AddN has a known shape, shape inference
  // will propagate that shape back to the other inputs of AddN, making the
  // output shapes of all of its inputs known.
  std::vector<Output> muls_deduced_output_shape;
  std::unordered_set<string> to_const;
  int count = 0;
  for (const auto& x : {x_partially_known, x_unknown}) {
    for (const auto& zeros : {zeros_partially_known, zeros_unknown}) {
      const string name = strings::StrCat("mul_", count++);
      muls_deduced_output_shape.push_back(
          ops::Mul(s.WithOpName(name), x, zeros));
      to_const.insert(name);
    }
  }
  // We add a known shape as input to AddN to propagate it back to the
  // multiplies above, which means they can all be turned into Const nodes.
  muls_deduced_output_shape.push_back(known_shape);
  Output addn1 = ops::AddN(s.WithOpName("addn1"), muls_deduced_output_shape);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(10, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (to_const.count(name) > 0) {
      EXPECT_EQ("Const", node.op()) << node.name();
      EXPECT_EQ(2, node.input_size());
      EXPECT_TRUE(IsControlInput(node.input(0)));
      EXPECT_TRUE(IsControlInput(node.input(1)));
    }
  }
  const std::vector<string> fetch = {"addn1"};
  auto x_partially_unknown_t =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_unknown_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto expected_tensors =
      EvaluateNodes(item.graph, fetch,
                    {{"x_partially_unknown", x_partially_unknown_t},
                     {"x_unknown", x_unknown_t}});
  EXPECT_EQ(1, expected_tensors.size());
  auto tensors = EvaluateNodes(output, fetch,
                               {{"x_partially_unknown", x_partially_unknown_t},
                                {"x_unknown", x_unknown_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(expected_tensors[0], tensors[0], 1e-5);
}

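// Folds const*const multiplications for each numeric dtype and checks that
// the resulting Const nodes store the folded value in the correct typed
// field of the TensorProto (float_val, double_val, int_val, ...).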
TEST_F(ConstantFoldingTest, CreateConstNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

#define MAKE_TEST_GRAPH(TYPE)                                               \
  Output TYPE##_const =                                                     \
      ops::Const(s.WithOpName(#TYPE "_const"), static_cast<TYPE>(10), {5}); \
  Output TYPE##_mul =                                                       \
      ops::Mul(s.WithOpName(#TYPE "_mul"), TYPE##_const, TYPE##_const);     \
  Output TYPE##_id = ops::Identity(s.WithOpName(#TYPE "_id"), TYPE##_mul)

  MAKE_TEST_GRAPH(float);
  MAKE_TEST_GRAPH(double);
  MAKE_TEST_GRAPH(int64_t);
  MAKE_TEST_GRAPH(int32);
  MAKE_TEST_GRAPH(int16);
  MAKE_TEST_GRAPH(int8);
  MAKE_TEST_GRAPH(uint8);
#undef MAKE_TEST_GRAPH

  Output bool_const = ops::Const(s.WithOpName("bool_const"), true, {5});
  Output bool_and =
      ops::LogicalAnd(s.WithOpName("bool_and"), bool_const, bool_const);
  Output bool_id = ops::Identity(s.WithOpName("bool_id"), bool_and);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(24, output.node_size());
  for (const NodeDef& node : output.node()) {
#define CHECK_RESULT(TYPE, FIELD)                                             \
  if (node.name() == #TYPE "_mul") {                                          \
    EXPECT_EQ(5,                                                              \
              node.attr().at("value").tensor().tensor_shape().dim(0).size()); \
    EXPECT_EQ(1, node.attr().at("value").tensor().FIELD##_val_size());        \
    EXPECT_EQ(10 * 10, node.attr().at("value").tensor().FIELD##_val(0));      \
  }

    CHECK_RESULT(float, float);
    CHECK_RESULT(double, double);
    CHECK_RESULT(int64_t, int64);
    CHECK_RESULT(int32, int);
    CHECK_RESULT(int16, int);
    CHECK_RESULT(int8, int);
    CHECK_RESULT(uint8, int);
#undef CHECK_RESULT

    if (node.name() == "bool_and") {
      EXPECT_EQ(5,
                node.attr().at("value").tensor().tensor_shape().dim(0).size());
      EXPECT_EQ(1, node.attr().at("value").tensor().bool_val_size());
      EXPECT_EQ(true && true, node.attr().at("value").tensor().bool_val(0));
    }
  }
}

TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
  // Build a simple graph around a multi-output op (Unique) whose outputs are
  // both constant foldable.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 10, {5});
  auto b = ops::Unique(s.WithOpName("b"), {a});
  Output c = ops::Identity(s.WithOpName("c"), {b.y});
  Output d = ops::Identity(s.WithOpName("d"), {b.idx});
  Output e = ops::Identity(s.WithOpName("e"), {c});
  Output f = ops::Identity(s.WithOpName("f"), {d});

  GrapplerItem item;
  item.fetch.push_back("e");
  item.fetch.push_back("f");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(2, output.node_size());

  const NodeDef& new_c = output.node(0);
  EXPECT_EQ("e", new_c.name());
  EXPECT_EQ("Const", new_c.op());

  const NodeDef& new_d = output.node(1);
  EXPECT_EQ("f", new_d.name());
  EXPECT_EQ("Const", new_d.op());

  std::vector<string> fetch = {"e", "f"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(fetch.size(), tensors_expected.size());
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
  }
}

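// When nodes with control dependencies are folded away, the resulting Const
// must inherit those dependencies so execution-order constraints survive.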
TEST_F(ConstantFoldingTest, ControlDependencies) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
  Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
  Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
  Output i1 = ops::Identity(scope.WithOpName("i1"), {c});
  Output i2 =
      ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1});
  Output i3 = ops::Identity(scope.WithOpName("i3"), {i2});

  GrapplerItem item;
  item.fetch.push_back("i3");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::vector<string> expected_nodes = {"dflt", "p1", "p2", "i3"};
  EXPECT_EQ(output.node_size(), expected_nodes.size());
  int i = 0;
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_EQ(expected_nodes[i], output.node(i).name());
    i++;
    if (node.name() == "i3") {
      EXPECT_EQ("Const", node.op());
      ++found;
      auto folded = EvaluateNodes(output, {"i3"});
      auto expected = EvaluateNodes(item.graph, {"i3"});
      EXPECT_EQ(1, expected.size());
      EXPECT_EQ(1, folded.size());
      test::ExpectTensorEqual<int>(folded[0], expected[0]);
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
      EXPECT_EQ("^p2", node.input(1));
    }
  }
  EXPECT_EQ(1, found);
}

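// Same graph as above but without any fetch nodes: the intermediate
// identities are still folded to Consts, each keeping the control
// dependencies accumulated along its chain.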
TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
  Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
  Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
  Output i1 = ops::Identity(scope.WithOpName("i1"), {c});
  Output i2 =
      ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1});
  Output i3 = ops::Identity(scope.WithOpName("e"), {i2});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::vector<string> expected_nodes = {"dflt", "p1", "p2", "c",
                                        "i1",   "i2", "e"};
  EXPECT_EQ(output.node_size(), expected_nodes.size());
  int i = 0;
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_EQ(expected_nodes[i], output.node(i).name());
    i++;
    if (node.name() == "i1") {
      EXPECT_EQ("Const", node.op());
      ++found;
      auto folded = EvaluateNodes(output, {"i1"});
      auto expected = EvaluateNodes(item.graph, {"i1"});
      EXPECT_EQ(1, expected.size());
      EXPECT_EQ(1, folded.size());
      test::ExpectTensorEqual<int>(folded[0], expected[0]);
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
    }
    if (node.name() == "i2") {
      EXPECT_EQ("Const", node.op());
      ++found;
      auto folded = EvaluateNodes(output, {"i2"});
      auto expected = EvaluateNodes(item.graph, {"i2"});
      EXPECT_EQ(1, expected.size());
      EXPECT_EQ(1, folded.size());
      test::ExpectTensorEqual<int>(folded[0], expected[0]);
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
      EXPECT_EQ("^p2", node.input(1));
    }
  }
  EXPECT_EQ(2, found);
}

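// A control dependency reachable through multiple paths (here ^p1, via both
// "c" and "i1") must appear only once on the folded Const.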
TEST_F(ConstantFoldingTest,ControlDependenciesDeduplicate)1405 TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
1406 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1407 Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
1408 Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
1409 Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
1410 Output c =
1411 ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
1412 Output i1 = ops::Identity(scope.WithOpName("i1")
1413 .WithControlDependencies(p2)
1414 .WithControlDependencies(p1),
1415 {c});
1416 Output i2 = ops::Identity(scope.WithOpName("i2"), {i1});
1417
1418 GrapplerItem item;
1419 item.fetch.push_back("i2");
1420 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1421 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
1422 EXPECT_EQ(1, tensors_expected.size());
1423 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1424 GraphDef output;
1425 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1426 TF_EXPECT_OK(status);
1427
1428 std::vector<string> expected_nodes = {"dflt", "p1", "p2", "i2"};
1429 EXPECT_EQ(output.node_size(), expected_nodes.size());
1430 int i = 0;
1431 for (const auto& node : output.node()) {
1432 EXPECT_EQ(expected_nodes[i], output.node(i).name());
1433 i++;
1434 if (node.name() == "i2") {
1435 EXPECT_EQ("Const", node.op());
1436 EXPECT_EQ(2, node.input_size());
1437 EXPECT_EQ("^p1", node.input(0));
1438 EXPECT_EQ("^p2", node.input(1));
1439 }
1440 }
1441 auto tensors = EvaluateNodes(output, item.fetch);
1442 EXPECT_EQ(1, tensors.size());
1443 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1444 }
1445
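// DynamicPartition and ConcatOffset produce a variable number of outputs.
// With constant inputs, every one of the eight Identity consumers should
// fold to a Const.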
1446 TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
1447 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1448 // Add a DynamicPartition node to the graph.
1449 Output input = ops::Const(scope.WithOpName("in0"), 314, {3, 4, 5});
1450 Output indices = ops::Const(scope.WithOpName("indices"), 1, {3, 4});
1451 int num_partitions = 4;
1452 ops::DynamicPartition part(scope.WithOpName("partition"), input, indices,
1453 num_partitions);
1454
1455 std::vector<string> outputs;
1456 for (int i = 0; i < num_partitions; ++i) {
1457 string part_out_name = strings::StrCat("part_out", i);
1458 ops::Identity partition_out(scope.WithOpName(part_out_name),
1459 {part.outputs[i]});
1460 outputs.push_back(part_out_name);
1461 }
1462
1463 GrapplerItem item;
1464 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1465
1466 // Add a ConcatOffset node to the graph.
1467 Tensor initial_val(DT_INT32, TensorShape({3}));
1468 test::FillIota<int>(&initial_val, 7);
1469 for (int i = 1; i < 5; ++i) {
1470 TF_CHECK_OK(NodeDefBuilder(strings::StrCat("in", i), "Const")
1471 .Attr("dtype", DT_INT32)
1472 .Attr("value", initial_val)
1473 .Finalize(item.graph.add_node()));
1474 }
1475 Tensor concat_dim(DT_INT32, TensorShape({}));
1476 test::FillIota<int>(&concat_dim, 0);
1477 TF_CHECK_OK(NodeDefBuilder("concat_dim", "Const")
1478 .Attr("dtype", DT_INT32)
1479 .Attr("value", concat_dim)
1480 .Finalize(item.graph.add_node()));
1481
1482 TF_CHECK_OK(NodeDefBuilder("concat_offsets", "ConcatOffset")
1483 .Input("concat_dim", 0, DT_INT32)
1484 .Input({NodeDefBuilder::NodeOut("in1", 0, DT_INT32),
1485 NodeDefBuilder::NodeOut("in2", 0, DT_INT32),
1486 NodeDefBuilder::NodeOut("in3", 0, DT_INT32),
1487 NodeDefBuilder::NodeOut("in4", 0, DT_INT32)})
1488 .Finalize(item.graph.add_node()));
1489
1490 for (int i = 0; i < 4; ++i) {
1491 string concat_offset_out_name = strings::StrCat("concat_offset_out", i);
1492 TF_CHECK_OK(NodeDefBuilder(concat_offset_out_name, "Identity")
1493 .Attr("T", DT_INT32)
1494 .Input("concat_offsets", i, DT_INT32)
1495 .Finalize(item.graph.add_node()));
1496 outputs.push_back(concat_offset_out_name);
1497 }
1498
1499 item.fetch = outputs;
1500 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1501 GraphDef output;
1502 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1503 TF_EXPECT_OK(status);
1504
1505 int constant_folded = 0;
1506 for (const auto& node : output.node()) {
1507 if (node.name().find("part_out") != string::npos ||
1508 node.name().find("concat_offset_out") != string::npos) {
1509 ++constant_folded;
1510 EXPECT_EQ("Const", node.op());
1511 }
1512 }
1513 EXPECT_EQ(8, constant_folded);
1514
1515 auto expected = EvaluateNodes(item.graph, outputs);
1516 auto optimized = EvaluateNodes(output, outputs);
1517 ASSERT_EQ(expected.size(), optimized.size());
1518 for (int i = 0; i < expected.size(); ++i) {
1519 test::ExpectTensorEqual<int>(expected[i], optimized[i]);
1520 }
1521 }
1522
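// Rank/Shape/Size of variables with statically known shapes can be
// materialized as constants, letting the whole product fold down to a
// single Const for the fetched node p2.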
1523 TEST_F(ConstantFoldingTest, ShapeMaterialization) {
1524 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1525 Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
1526 Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
1527 Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
1528 Output rank = ops::Rank(scope.WithOpName("rank"), v1);
1529 Output shape = ops::Shape(scope.WithOpName("shape"), v2);
1530 Output size = ops::Size(scope.WithOpName("size"), v3);
1531 Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
1532 Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);
1533
1534 GrapplerItem item;
1535 item.fetch.push_back("p2");
1536 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1537
1538 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1539 GraphDef output;
1540 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1541 TF_EXPECT_OK(status);
1542
1543 int found = 0;
1544 for (const auto& node : output.node()) {
1545 if (node.name() == "p2") {
1546 ++found;
1547 EXPECT_EQ("Const", node.op());
1548 EXPECT_EQ(3, node.input_size());
1549 EXPECT_EQ("^v3", node.input(0));
1550 EXPECT_EQ("^v1", node.input(1));
1551 EXPECT_EQ("^v2", node.input(2));
1552 Tensor value;
1553 CHECK(value.FromProto(node.attr().at("value").tensor()));
1554 // rank = 1, shape = (5, 7), size = 143 = 11*13
1555 // p2 = (715, 1001) = (5*143, 7*143)
1556 EXPECT_EQ(715, value.flat<int>()(0));
1557 EXPECT_EQ(1001, value.flat<int>()(1));
1558 }
1559 }
1560 EXPECT_EQ(1, found);
1561 auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
1562 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
1563 auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({11, 13}));
1564
1565 auto tensors_expected = EvaluateNodes(
1566 item.graph, item.fetch, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1567 EXPECT_EQ(1, tensors_expected.size());
1568 auto tensors = EvaluateNodes(output, item.fetch,
1569 {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1570 EXPECT_EQ(1, tensors.size());
1571 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1572 }
1573
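// With no fetch nodes, rank/shape/size themselves are folded to Consts,
// each anchored to its source variable by a control dependency.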
1574 TEST_F(ConstantFoldingTest, ShapeMaterializationEmptyFetch) {
1575 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1576 Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
1577 Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
1578 Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
1579 Output rank = ops::Rank(scope.WithOpName("rank"), v1);
1580 Output shape = ops::Shape(scope.WithOpName("shape"), v2);
1581 Output size = ops::Size(scope.WithOpName("size"), v3);
1582 Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
1583 Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);
1584
1585 GrapplerItem item;
1586 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1587
1588 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1589 GraphDef output;
1590 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1591 TF_EXPECT_OK(status);
1592
1593 int found = 0;
1594 for (const auto& node : output.node()) {
1595 if (node.name() == "size") {
1596 ++found;
1597 EXPECT_EQ("Const", node.op());
1598 EXPECT_EQ(1, node.input_size());
1599 EXPECT_EQ("^v3", node.input(0));
1600 Tensor value;
1601 CHECK(value.FromProto(node.attr().at("value").tensor()));
1602 EXPECT_EQ(11 * 13, value.flat<int>()(0));
1603 } else if (node.name() == "rank") {
1604 ++found;
1605 EXPECT_EQ("Const", node.op());
1606 EXPECT_EQ(1, node.input_size());
1607 EXPECT_EQ("^v1", node.input(0));
1608 Tensor value;
1609 CHECK(value.FromProto(node.attr().at("value").tensor()));
1610 EXPECT_EQ(1, value.flat<int>()(0));
1611 } else if (node.name() == "shape") {
1612 ++found;
1613 EXPECT_EQ("Const", node.op());
1614 EXPECT_EQ(1, node.input_size());
1615 EXPECT_EQ("^v2", node.input(0));
1616 Tensor value;
1617 CHECK(value.FromProto(node.attr().at("value").tensor()));
1618 EXPECT_EQ(5, value.flat<int>()(0));
1619 EXPECT_EQ(7, value.flat<int>()(1));
1620 }
1621 }
1622 EXPECT_EQ(3, found);
1623
1624 auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
1625 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
1626 auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({11, 13}));
1627 std::vector<string> fetch_nodes = {"p2"};
1628 auto tensors_expected = EvaluateNodes(
1629 item.graph, fetch_nodes, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1630 EXPECT_EQ(1, tensors_expected.size());
1631 auto tensors = EvaluateNodes(output, fetch_nodes,
1632 {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1633 EXPECT_EQ(1, tensors.size());
1634 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1635 }
1636
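// ShapeN outputs are materialized individually: v1 has an unknown dimension
// and is left alone, v3's fully known shape becomes the s-matshapes-2 Const,
// and v2's scalar (empty) shape keeps feeding from s directly.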
1637 TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN) {
1638 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1639 Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
1640 Output v2 = ops::Variable(scope.WithOpName("v2"), {}, DT_FLOAT);
1641 Output v3 = ops::Variable(scope.WithOpName("v3"), {4, 6}, DT_FLOAT);
1642 auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2, v3});
1643 Output i1a = ops::Identity(scope.WithOpName("i1a"), s[0]);
1644 Output i1b = ops::Identity(scope.WithOpName("i1b"), s[0]);
1645 Output i2a = ops::Identity(scope.WithOpName("i2a"), s[1]);
1646 Output i2b = ops::Identity(scope.WithOpName("i2b"), s[1]);
1647 Output i2c = ops::Identity(scope.WithOpName("i2c"), s[1]);
1648 Output i3a = ops::Identity(scope.WithOpName("i3a"), s[2]);
1649 Output i3b = ops::Identity(scope.WithOpName("i3b"), s[2]);
1650
1651 GrapplerItem item;
1652 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1653
1654 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1655 GraphDef output;
1656 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1657 TF_EXPECT_OK(status);
1658 int found = 0;
1659 for (const auto& node : output.node()) {
1660 EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
1661 node.name());
1662 EXPECT_NE(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
1663 node.name());
1664 if (node.name() == "i1a" || node.name() == "i1b") {
1665 ++found;
1666 EXPECT_EQ("s", node.input(0));
1667 }
1668 if (node.name() == "i2a" || node.name() == "i2b" || node.name() == "i2c") {
1669 ++found;
1670 EXPECT_EQ("s:1", node.input(0));
1671 }
1672 if (node.name() == "i3a" || node.name() == "i3b") {
1673 ++found;
1674 EXPECT_EQ(AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst),
1675 node.input(0));
1676 }
1677 if (node.name() == "s") {
1678 ++found;
1679 EXPECT_EQ("ShapeN", node.op());
1680 EXPECT_EQ("v1", node.input(0));
1681 EXPECT_EQ("v2", node.input(1));
1682 EXPECT_EQ("v3", node.input(2));
1683 }
1684 if (node.name() ==
1685 AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst)) {
1686 ++found;
1687 EXPECT_EQ("Const", node.op());
1688 EXPECT_EQ("^s", node.input(0));
1689 Tensor value;
1690 CHECK(value.FromProto(node.attr().at("value").tensor()));
1691 EXPECT_EQ(4, value.flat<int>()(0));
1692 EXPECT_EQ(6, value.flat<int>()(1));
1693 }
1694 }
1695 EXPECT_EQ(9, found);
1696
1697 auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
1698 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 6}));
1699 auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
1700 const std::vector<string> fetch_nodes = {"i1a", "i1b", "i2a", "i2b",
1701 "i2c", "i3a", "i3b"};
1702 auto tensors_expected = EvaluateNodes(
1703 item.graph, fetch_nodes, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1704 EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
1705 auto tensors = EvaluateNodes(output, fetch_nodes,
1706 {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1707 EXPECT_EQ(fetch_nodes.size(), tensors.size());
1708 for (int i = 0; i < fetch_nodes.size(); i++)
1709 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
1710 }
1711
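// As above, but the ShapeN outputs flow through an IdentityN: the
// materialized shape of v2 replaces id_n's second input and ib folds to a
// Const, while ia (fed by v1's partially unknown shape) is untouched.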
1712 TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN_MultipleOutputs) {
1713 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1714 Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
1715 Output v2 = ops::Variable(scope.WithOpName("v2"), {4, 6}, DT_FLOAT);
1716 auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2});
1717 auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {s[0], s[1]});
1718 Output ia = ops::Identity(scope.WithOpName("ia"), id_n[0]);
1719 Output ib = ops::Identity(scope.WithOpName("ib"), id_n[1]);
1720
1721 GrapplerItem item;
1722 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1723 item.fetch.push_back("ia");
1724 item.fetch.push_back("ib");
1725
1726 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1727 GraphDef output;
1728 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1729 TF_EXPECT_OK(status);
1730
1731 int found = 0;
1732 for (const auto& node : output.node()) {
1733 EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
1734 node.name());
1735 if (node.name() == "s") {
1736 ++found;
1737 EXPECT_EQ("ShapeN", node.op());
1738 EXPECT_EQ("v1", node.input(0));
1739 EXPECT_EQ("v2", node.input(1));
1740 }
1741 if (node.name() == "id_n") {
1742 ++found;
1743 EXPECT_EQ("IdentityN", node.op());
1744 EXPECT_EQ("s", node.input(0));
1745 EXPECT_EQ(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
1746 node.input(1));
1747 }
1748 if (node.name() == "ia") {
1749 ++found;
1750 EXPECT_EQ("id_n", node.input(0));
1751 }
1752 if (node.name() == "ib") {
1753 ++found;
1754 EXPECT_EQ("Const", node.op());
1755 EXPECT_EQ("^s", node.input(0));
1756 EXPECT_EQ("^id_n", node.input(1));
1757 }
1758 }
1759 EXPECT_EQ(4, found);
1760
1761 auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
1762 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
1763 auto tensors_expected =
1764 EvaluateNodes(item.graph, item.fetch, {{"v1", v1_t}, {"v2", v2_t}});
1765 EXPECT_EQ(2, tensors_expected.size());
1766 auto tensors =
1767 EvaluateNodes(output, item.fetch, {{"v1", v1_t}, {"v2", v2_t}});
1768 EXPECT_EQ(2, tensors.size());
1769 for (int i = 0; i < tensors.size(); i++)
1770 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
1771 }
1772
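// rank and size depend only on the shapes flowing out of the Switch, so
// they fold to Consts guarded by control dependencies. switch2 has a
// constant predicate: the statically known branch (i2) folds to a Const,
// while the never-taken branch (i3) is left as an Identity.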
1773 TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
1774 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1775 ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
1776 ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
1777 ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
1778 ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
1779 ops::Identity i(scope.WithOpName("i"), s1.output_true);
1780 ops::Size size(scope.WithOpName("size"), i);
1781 ops::Square p1(scope.WithOpName("p1"), rank);
1782 ops::Square p2(scope.WithOpName("p2"), size);
1783 ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});
1784
1785 Output predicate =
1786 ops::Const(scope.WithOpName("false"), false, TensorShape({}));
1787 Output constant =
1788 ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
1789 ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
1790 ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
1791 ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
1792 ops::Merge m2(scope.WithOpName("m2"),
1793 {statically_known.output, never_generated.output});
1794
1795 GrapplerItem item;
1796 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1797
1798 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1799 GraphDef output;
1800 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1801 TF_EXPECT_OK(status);
1802
1803 std::set<string> present_nodes = {"v_in", "v_ctrl",
1804 "switch", "i",
1805 "p1", "p2",
1806 "m", "false",
1807 "constant", "switch2",
1808 "i2", "i3",
1809 "m2", "ConstantFoldingCtrl/switch_0",
1810 "rank", "size"};
1811 std::set<string> not_present_nodes = {"ConstantFolding/switch2-0"};
1812 EXPECT_EQ(present_nodes.size(), output.node_size());
1813 int found = 0;
1814 for (const auto& node : output.node()) {
1815 EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end())
1816 << node.name();
1817 EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end())
1818 << node.name();
1819 present_nodes.erase(node.name());
1820 not_present_nodes.erase(node.name());
1821 if (node.name() == "rank") {
1822 ++found;
1823 EXPECT_EQ("Const", node.op());
1824 EXPECT_EQ(1, node.input_size());
1825 EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
1826 }
1827 if (node.name() == "size") {
1828 ++found;
1829 EXPECT_EQ("Const", node.op());
1830 EXPECT_EQ(1, node.input_size());
1831 EXPECT_EQ("^i", node.input(0));
1832 }
1833 if (node.name() == "i2") {
1834 ++found;
1835 EXPECT_EQ("Const", node.op());
1836 EXPECT_EQ(0, node.input_size());
1837 }
1838 if (node.name() == "i3") {
1839 ++found;
1840 EXPECT_EQ("Identity", node.op());
1841 EXPECT_EQ(1, node.input_size());
1842 EXPECT_EQ("switch2:1", node.input(0));
1843 }
1844 }
1845 EXPECT_EQ(4, found);
1846
1847 auto v_in_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
1848 Tensor v_ctrl_t(DT_BOOL, TensorShape({}));
1849
1850 v_ctrl_t.flat<bool>()(0) = true;
1851 std::vector<string> fetch_nodes = {"m", "m2"};
1852 auto tensors_expected = EvaluateNodes(
1853 item.graph, fetch_nodes, {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1854 EXPECT_EQ(2, tensors_expected.size());
1855 auto tensors = EvaluateNodes(output, fetch_nodes,
1856 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1857 EXPECT_EQ(2, tensors.size());
1858 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1859 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1860
1861 v_ctrl_t.flat<bool>()(0) = false;
1862 tensors_expected = EvaluateNodes(item.graph, fetch_nodes,
1863 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1864 EXPECT_EQ(2, tensors_expected.size());
1865 tensors = EvaluateNodes(output, fetch_nodes,
1866 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1867 EXPECT_EQ(2, tensors.size());
1868 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1869 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1870 }
1871
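// Same graph with m and m2 fetched: this time rank and size are folded and
// pruned from the output graph entirely, while the constant Switch branch
// folds exactly as before.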
1872 TEST_F(ConstantFoldingTest, SwitchNodes) {
1873 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1874 ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
1875 ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
1876 ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
1877 ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
1878 ops::Identity i(scope.WithOpName("i"), s1.output_true);
1879 ops::Size size(scope.WithOpName("size"), i);
1880 ops::Square p1(scope.WithOpName("p1"), rank);
1881 ops::Square p2(scope.WithOpName("p2"), size);
1882 ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});
1883
1884 Output predicate =
1885 ops::Const(scope.WithOpName("false"), false, TensorShape({}));
1886 Output constant =
1887 ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
1888 ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
1889 ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
1890 ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
1891 ops::Merge m2(scope.WithOpName("m2"),
1892 {statically_known.output, never_generated.output});
1893
1894 GrapplerItem item;
1895 item.fetch.push_back("m");
1896 item.fetch.push_back("m2");
1897
1898 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1899
1900 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1901 GraphDef output;
1902 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1903 TF_EXPECT_OK(status);
1904 std::set<string> present_nodes = {"v_in", "v_ctrl",
1905 "switch", "i",
1906 "p1", "p2",
1907 "m", "false",
1908 "constant", "switch2",
1909 "i2", "i3",
1910 "m2", "ConstantFoldingCtrl/switch_0"};
1911 std::set<string> not_present_nodes = {"rank", "size",
1912 "ConstantFolding/switch2-0"};
1913 EXPECT_EQ(present_nodes.size(), output.node_size());
1914
1915 int found = 0;
1916 for (const auto& node : output.node()) {
1917 EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end());
1918 EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end());
1919 present_nodes.erase(node.name());
1920 not_present_nodes.erase(node.name());
1921 if (node.name() == "i2") {
1922 ++found;
1923 EXPECT_EQ("Const", node.op());
1924 EXPECT_EQ(0, node.input_size());
1925 }
1926 if (node.name() == "i3") {
1927 ++found;
1928 EXPECT_EQ("Identity", node.op());
1929 EXPECT_EQ(1, node.input_size());
1930 EXPECT_EQ("switch2:1", node.input(0));
1931 }
1932 }
1933 EXPECT_EQ(2, found);
1934
1935 auto v_in_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
1936 Tensor v_ctrl_t(DT_BOOL, TensorShape({}));
1937 v_ctrl_t.flat<bool>()(0) = true;
1938 auto tensors_expected = EvaluateNodes(
1939 item.graph, item.fetch, {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1940 EXPECT_EQ(2, tensors_expected.size());
1941 auto tensors = EvaluateNodes(output, item.fetch,
1942 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1943 EXPECT_EQ(2, tensors.size());
1944 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1945 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1946
1947 v_ctrl_t.flat<bool>()(0) = false;
1948 tensors_expected = EvaluateNodes(item.graph, item.fetch,
1949 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1950 EXPECT_EQ(2, tensors_expected.size());
1951 tensors = EvaluateNodes(output, item.fetch,
1952 {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1953 EXPECT_EQ(2, tensors.size());
1954 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1955 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1956 }
1957
1958 TEST_F(ConstantFoldingTest, MergeNodes) {
1959 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1960
1961 Output x =
1962 ops::RandomNormal(scope.WithOpName("x"), {3, 5}, DataType::DT_FLOAT);
1963 Output y =
1964 ops::RandomNormal(scope.WithOpName("y"), {3, 5}, DataType::DT_FLOAT);
1965 Output const1 =
1966 ops::Const(scope.WithOpName("const1").WithControlDependencies(x), 2.7f,
1967 TensorShape({3, 5}));
1968 Output const2 =
1969 ops::Const(scope.WithOpName("const2"), 3.14f, TensorShape({3, 5}));
1970 Output const3 =
1971 ops::Const(scope.WithOpName("const3").WithControlDependencies(x), 3.14f,
1972 TensorShape({3, 5}));
1973
1974 // Create 3 merge nodes: m1 is foldable, m2 and m3 aren't.
1975 ops::Merge m1(scope.WithOpName("m1"), {x, const1, const2});
1976 ops::Merge m2(scope.WithOpName("m2"), {const1, const3});
1977 ops::Merge m3(scope.WithOpName("m3"), {x, y});
1978 // m4 is not foldable because the only constant input
1979 // has a control input, so we cannot know if it will be
1980 // triggered.
1981 ops::Merge m4(scope.WithOpName("m4"), {x, const1});
1982
1983 ops::Identity out1(scope.WithOpName("out1"), m1.output);
1984 ops::Identity idx1(scope.WithOpName("idx1"), m1.value_index);
1985 ops::Identity out2(scope.WithOpName("out2"), m2.output);
1986 ops::Identity idx2(scope.WithOpName("idx2"), m2.value_index);
1987 ops::Identity out3(scope.WithOpName("out3"), m3.output);
1988 ops::Identity idx3(scope.WithOpName("idx3"), m3.value_index);
1989 ops::Identity out4(scope.WithOpName("out4"), m4.output);
1990 ops::Identity idx4(scope.WithOpName("idx4"), m4.value_index);
1991
1992 GrapplerItem item;
1993 item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3", "out4", "idx4"};
1994 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1995
1996 ConstantFolding optimizer(/*cpu_device=*/nullptr);
1997 GraphDef output;
1998 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1999 TF_EXPECT_OK(status);
2000
2001 EXPECT_EQ(19, output.node_size());
2002 int found_nodes = 0;
2003 for (const auto& node : output.node()) {
2004 if (node.name() == "out1") {
2005 EXPECT_EQ(1, node.input_size());
2006 EXPECT_EQ("^m1", node.input(0));
2007 ++found_nodes;
2008 } else if (node.name() == "idx1") {
2009 EXPECT_EQ(1, node.input_size());
2010 EXPECT_EQ("^m1", node.input(0));
2011 ++found_nodes;
2012 } else if (node.name() == "ConstantFolding/m1") {
2013 EXPECT_EQ("Const", node.op());
2014 EXPECT_EQ(1, node.input_size());
2015 EXPECT_EQ("^m1", node.input(0));
2016 ++found_nodes;
2017 } else if (node.name() == "ConstantFolding/m1_index") {
2018 EXPECT_EQ("Const", node.op());
2019 EXPECT_EQ(1, node.input_size());
2020 EXPECT_EQ("^m1", node.input(0));
2021 ++found_nodes;
2022 } else if (node.name() == "out2") {
2023 EXPECT_EQ(1, node.input_size());
2024 EXPECT_EQ("m2", node.input(0));
2025 ++found_nodes;
2026 } else if (node.name() == "idx2") {
2027 EXPECT_EQ(1, node.input_size());
2028 EXPECT_EQ("m2:1", node.input(0));
2029 ++found_nodes;
2030 } else if (node.name() == "out3") {
2031 EXPECT_EQ(1, node.input_size());
2032 EXPECT_EQ("m3", node.input(0));
2033 ++found_nodes;
2034 } else if (node.name() == "idx3") {
2035 EXPECT_EQ(1, node.input_size());
2036 EXPECT_EQ("m3:1", node.input(0));
2037 ++found_nodes;
2038 } else if (node.name() == "out4") {
2039 EXPECT_EQ(1, node.input_size());
2040 EXPECT_EQ("m4", node.input(0));
2041 ++found_nodes;
2042 } else if (node.name() == "idx4") {
2043 EXPECT_EQ(1, node.input_size());
2044 EXPECT_EQ("m4:1", node.input(0));
2045 ++found_nodes;
2046 }
2047 }
2048 // Make sure the graph contains all the nodes we're expecting.
2049 EXPECT_EQ(8, found_nodes);
2050
2051 std::vector<string> fetch = {"out1", "idx1"};
2052 auto tensors = EvaluateNodes(output, fetch);
2053 EXPECT_EQ(2, tensors.size());
2054 const Tensor& out_value = tensors[0];
2055 EXPECT_EQ(3 * 5, out_value.NumElements());
2056 for (int i = 0; i < 3 * 5; ++i) {
2057 EXPECT_EQ(3.14f, out_value.flat<float>()(i));
2058 }
2059 const Tensor& out_idx = tensors[1];
2060 EXPECT_EQ(1, out_idx.NumElements());
2061 EXPECT_EQ(2, out_idx.flat<int32>()(0));
2062 }
2063
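// A Split with num_split == 1 is a no-op: s1 is rewritten as an Identity
// carrying a control dependency on split_dim, while s2 (num_split == 2)
// stays a real Split.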
2064 TEST_F(ConstantFoldingTest, SplitRemoval) {
2065 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2066
2067 Output in1 =
2068 ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
2069 Output in2 =
2070 ops::Variable(scope.WithOpName("in2"), TensorShape({4}), DT_FLOAT);
2071 auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
2072 ops::Split s1(scope.WithOpName("s1"), split_dim, in1, 1);
2073 ops::Split s2(scope.WithOpName("s2"), split_dim, in2, 2);
2074
2075 ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);
2076
2077 GrapplerItem item;
2078 item.fetch = {"out"};
2079 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2080
2081 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2082 GraphDef got;
2083 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2084 TF_EXPECT_OK(status);
2085
2086 GraphDef want;
2087 AddNode("in1", "VariableV2", {}, {}, &want);
2088 AddNode("in2", "VariableV2", {}, {}, &want);
2089 AddNode("split_dim", "Const", {}, {}, &want);
2090 AddNode("s1", "Identity", {"in1", AsControlDependency("split_dim")}, {},
2091 &want);
2092 AddNode("s2", "Split", {"split_dim", "in2"}, {}, &want);
2093 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2094
2095 CompareGraphs(want, got);
2096
2097 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
2098 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4}));
2099 auto tensors_expected =
2100 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2101 EXPECT_EQ(1, tensors_expected.size());
2102 auto tensors =
2103 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2104 EXPECT_EQ(1, tensors.size());
2105 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2106 }
2107
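// Likewise for SplitV: the single-output split s1 collapses to an Identity
// with control dependencies on its size_splits and split_dim inputs.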
2108 TEST_F(ConstantFoldingTest, SplitVRemoval) {
2109 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2110
2111 Output in1 =
2112 ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
2113 Output in2 =
2114 ops::Variable(scope.WithOpName("in2"), TensorShape({5}), DT_FLOAT);
2115 auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
2116 auto size_splits1 = ops::Const(scope.WithOpName("size_splits1"), {2}, {1});
2117 auto size_splits2 = ops::Const(scope.WithOpName("size_splits2"), {2, 3}, {2});
2118 ops::SplitV s1(scope.WithOpName("s1"), in1, size_splits1, split_dim, 1);
2119 ops::SplitV s2(scope.WithOpName("s2"), in2, size_splits2, split_dim, 2);
2120
2121 ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);
2122
2123 GrapplerItem item;
2124 item.fetch = {"out"};
2125 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2126
2127 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2128 GraphDef got;
2129 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2130 TF_EXPECT_OK(status);
2131
2132 GraphDef want;
2133 AddNode("in1", "VariableV2", {}, {}, &want);
2134 AddNode("in2", "VariableV2", {}, {}, &want);
2135 AddNode("split_dim", "Const", {}, {}, &want);
2136 AddNode("size_splits1", "Const", {}, {}, &want);
2137 AddNode("size_splits2", "Const", {}, {}, &want);
2138 AddNode("s1", "Identity",
2139 {"in1", AsControlDependency("size_splits1"),
2140 AsControlDependency("split_dim")},
2141 {}, &want);
2142 AddNode("s2", "SplitV", {"in2", "size_splits2", "split_dim"}, {}, &want);
2143 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2144
2145 CompareGraphs(want, got);
2146
2147 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
2148 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5}));
2149 auto tensors_expected =
2150 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2151 EXPECT_EQ(1, tensors_expected.size());
2152 auto tensors =
2153 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2154 EXPECT_EQ(1, tensors.size());
2155 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2156 }
2157
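// t2's permutation only moves dimensions of size 1, so it cannot change the
// data layout and is rewritten as an Identity; t1 genuinely permutes data
// and is kept.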
2158 TEST_F(ConstantFoldingTest, TransposeOnSize1DimsRemoval) {
2159 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2160
2161 Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
2162 DT_FLOAT);
2163 Output p1 = ops::Const(scope.WithOpName("p1"), {3, 2, 1, 0}, {4});
2164 Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 4, 2, 1}),
2165 DT_FLOAT);
2166 Output p2 = ops::Const(scope.WithOpName("p2"), {3, 1, 2, 0}, {4});
2167 ops::Transpose t1(scope.WithOpName("t1"), in1, p1);
2168 ops::Transpose t2(scope.WithOpName("t2").WithControlDependencies({in1}), in2,
2169 p2);
2170
2171 ops::Add out1(scope.WithOpName("out1"), t1, t2);
2172
2173 GrapplerItem item;
2174 item.fetch = {"out1"};
2175 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2176
2177 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2178 GraphDef got;
2179 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2180 TF_EXPECT_OK(status);
2181
2182 GraphDef want;
2183 AddNode("in1", "VariableV2", {}, {}, &want);
2184 AddNode("in2", "VariableV2", {}, {}, &want);
2185 AddNode("p1", "Const", {}, {}, &want);
2186 AddNode("p2", "Const", {}, {}, &want);
2187 AddNode("t1", "Transpose", {"in1", "p1"}, {}, &want);
2188 AddNode("t2", "Identity",
2189 {"in2", AsControlDependency("in1"), AsControlDependency("p2")}, {},
2190 &want);
2191 AddNode("out1", "Add", {"t1", "t2"}, {}, &want);
2192
2193 CompareGraphs(want, got);
2194 }
2195
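// RandomShuffle of a scalar is a no-op and is rewritten as an Identity.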
2196 TEST_F(ConstantFoldingTest, RandomShuffleOnScalarRemoval) {
2197 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2198
2199 Output in1 =
2200 ops::Variable(scope.WithOpName("in1"), TensorShape({}), DT_FLOAT);
2201 Output in2 =
2202 ops::Variable(scope.WithOpName("in2"), TensorShape({}), DT_FLOAT);
2203 ops::RandomShuffle s1(scope.WithOpName("s1"), in1);
2204 ops::RandomShuffle s2(scope.WithOpName("s2").WithControlDependencies({in1}),
2205 in2);
2206
2207 ops::Add out1(scope.WithOpName("out1"), s1, s2);
2208 ops::Identity out2(scope.WithOpName("out2"), s2);
2209
2210 GrapplerItem item;
2211 item.fetch = {"out1", "out2"};
2212 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2213
2214 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2215 GraphDef got;
2216 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2217 TF_EXPECT_OK(status);
2218
2219 GraphDef want;
2220 AddNode("in1", "VariableV2", {}, {}, &want);
2221 AddNode("in2", "VariableV2", {}, {}, &want);
2222 AddNode("s1", "Identity", {"in1"}, {}, &want);
2223 AddNode("s2", "Identity", {"in2", AsControlDependency("in1")}, {}, &want);
2224 AddNode("out1", "Add", {"s1", "s2"}, {}, &want);
2225 AddNode("out2", "Identity", {"s2"}, {}, &want);
2226
2227 CompareGraphs(want, got);
2228
2229 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
2230 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
2231 auto tensors_expected =
2232 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2233 EXPECT_EQ(2, tensors_expected.size());
2234 auto tensors =
2235 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2236 EXPECT_EQ(2, tensors.size());
2237 for (int i = 0; i < tensors.size(); i++)
2238 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2239 }
2240
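// r2 only reverses axes of size 1, which is a no-op, so it becomes an
// Identity; r1 reverses non-trivial axes and survives as a ReverseV2.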
2241 TEST_F(ConstantFoldingTest, ReverseOnSize1DimsRemoval) {
2242 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2243
2244 Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
2245 DT_FLOAT);
2246 Output a1 = ops::Const(scope.WithOpName("a1"), {3, 2, 1, 0}, {4});
2247 Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 2, 4, 1}),
2248 DT_FLOAT);
2249 Output a2 = ops::Const(scope.WithOpName("a2"), {0, 3}, {2});
2250 ops::Reverse r1(scope.WithOpName("r1"), in1, a1);
2251 ops::Reverse r2(scope.WithOpName("r2").WithControlDependencies({in1}), in2,
2252 a2);
2253
2254 ops::Add out1(scope.WithOpName("out1"), r1, r2);
2255
2256 GrapplerItem item;
2257 item.fetch = {"out1"};
2258 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2259
2260 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2261 GraphDef got;
2262 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2263 TF_EXPECT_OK(status);
2264
2265 GraphDef want;
2266 AddNode("in1", "VariableV2", {}, {}, &want);
2267 AddNode("in2", "VariableV2", {}, {}, &want);
2268 AddNode("a1", "Const", {}, {}, &want);
2269 AddNode("a2", "Const", {}, {}, &want);
2270 AddNode("r1", "ReverseV2", {"in1", "a1"}, {}, &want);
2271 AddNode("r2", "Identity",
2272 {"in2", AsControlDependency("in1"), AsControlDependency("a2")}, {},
2273 &want);
2274 AddNode("out1", "Add", {"r1", "r2"}, {}, &want);
2275
2276 CompareGraphs(want, got);
2277 }
2278
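// A Slice that spans its entire input (explicitly, or via size = -1
// starting at 0) is a no-op and becomes an Identity; the s2 slices, which
// actually crop, are kept.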
2279 TEST_F(ConstantFoldingTest, SliceWithSameDimensionRemoval) {
2280 { // size = {3, 5}
2281 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2282
2283 auto in1 = ops::Variable(scope.WithOpName("in1"), {3, 5}, DT_FLOAT);
2284 auto begin = ops::Const(scope.WithOpName("begin"), {0, 0}, {2});
2285 auto size = ops::Const(scope.WithOpName("size"), {3, 5}, {2});
2286 Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
2287 ops::Slice s1(scope.WithOpName("s1"), in1, begin, size);
2288 ops::Slice s2(scope.WithOpName("s2"), in2, begin, size);
2289
2290 ops::Add out(scope.WithOpName("out"), s1, s2);
2291
2292 GrapplerItem item;
2293 item.fetch = {"out"};
2294 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2295
2296 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2297 GraphDef got;
2298 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2299 TF_EXPECT_OK(status);
2300
2301 GraphDef want;
2302 AddNode("in1", "VariableV2", {}, {}, &want);
2303 AddNode("in2", "VariableV2", {}, {}, &want);
2304 AddNode("begin", "Const", {}, {}, &want);
2305 AddNode("size", "Const", {}, {}, &want);
2306 AddNode("s1", "Identity",
2307 {"in1", AsControlDependency("begin"), AsControlDependency("size")},
2308 {}, &want);
2309 AddNode("s2", "Slice", {"in2", "begin", "size"}, {}, &want);
2310 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2311
2312 CompareGraphs(want, got);
2313
2314 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5}));
2315 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
2316 auto tensors_expected =
2317 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2318 EXPECT_EQ(1, tensors_expected.size());
2319 auto tensors =
2320 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2321 EXPECT_EQ(1, tensors.size());
2322 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2323 }
2324 { // size = {-1, -1}
2325 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2326
2327 auto in1 =
2328 ops::Variable(scope.WithOpName("in1"), {3, 5}, DataType::DT_FLOAT);
2329 auto begin1 = ops::Const(scope.WithOpName("begin1"), {0, 0}, {2});
2330 auto begin2 = ops::Const(scope.WithOpName("begin2"), {1, 1}, {2});
2331 auto size = ops::Const(scope.WithOpName("size"), {-1, -1}, {2});
2332 Output in2 =
2333 ops::Variable(scope.WithOpName("in2"), {4, 6}, DataType::DT_FLOAT);
2334 ops::Slice s1(scope.WithOpName("s1"), in1, begin1, size);
2335 ops::Slice s2(scope.WithOpName("s2"), in2, begin2, size);
2336
2337 ops::Add out(scope.WithOpName("out"), s1, s2);
2338
2339 GrapplerItem item;
2340 item.fetch = {"out"};
2341 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2342
2343 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2344 GraphDef got;
2345 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2346 TF_EXPECT_OK(status);
2347
2348 GraphDef want;
2349 AddNode("in1", "VariableV2", {}, {}, &want);
2350 AddNode("in2", "VariableV2", {}, {}, &want);
2351 AddNode("begin1", "Const", {}, {}, &want);
2352 AddNode("begin2", "Const", {}, {}, &want);
2353 AddNode("size", "Const", {}, {}, &want);
2354 AddNode("s1", "Identity",
2355 {"in1", AsControlDependency("begin1"), AsControlDependency("size")},
2356 {}, &want);
2357 AddNode("s2", "Slice", {"in2", "begin2", "size"}, {}, &want);
2358 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2359
2360 CompareGraphs(want, got);
2361
2362 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5}));
2363 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
2364 auto tensors_expected =
2365 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2366 EXPECT_EQ(1, tensors_expected.size());
2367 auto tensors =
2368 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2369 EXPECT_EQ(1, tensors.size());
2370 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2371 }
2372 }
2373
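// StridedSlice is a no-op when begin/end/strides (after applying the begin,
// end, and ellipsis masks) cover every dimension exactly; such slices
// become Identity nodes while genuinely cropping ones are kept.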
2374 TEST_F(ConstantFoldingTest, StridedSliceWithSameDimensionRemoval) {
2375 { // no mask
2376 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2377
2378 auto in1 = ops::Variable(scope.WithOpName("in1"), {3, 5, 2}, DT_FLOAT);
2379 auto begin = ops::Const(scope.WithOpName("begin"), {0, 0}, {2});
2380 auto end = ops::Const(scope.WithOpName("end"), {3, 5}, {2});
2381 auto strides = ops::Const(scope.WithOpName("strides"), {1, 1}, {2});
2382 Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6, 2}, DT_FLOAT);
2383 ops::StridedSlice s1(scope.WithOpName("s1"), in1, begin, end, strides);
2384 ops::StridedSlice s2(scope.WithOpName("s2"), in2, begin, end, strides);
2385
2386 ops::Add out(scope.WithOpName("out"), s1, s2);
2387
2388 GrapplerItem item;
2389 item.fetch = {"out"};
2390 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2391
2392 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2393 GraphDef got;
2394 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2395 TF_EXPECT_OK(status);
2396
2397 GraphDef want;
2398 AddNode("in1", "VariableV2", {}, {}, &want);
2399 AddNode("in2", "VariableV2", {}, {}, &want);
2400 AddNode("begin", "Const", {}, {}, &want);
2401 AddNode("end", "Const", {}, {}, &want);
2402 AddNode("strides", "Const", {}, {}, &want);
2403 AddNode("s1", "Identity",
2404 {"in1", AsControlDependency("begin"), AsControlDependency("end"),
2405 AsControlDependency("strides")},
2406 {}, &want);
2407 AddNode("s2", "StridedSlice", {"in2", "begin", "end", "strides"}, {},
2408 &want);
2409 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2410
2411 CompareGraphs(want, got);
2412
2413 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 2}));
2414 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6, 2}));
2415 auto tensors_expected =
2416 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2417 EXPECT_EQ(1, tensors_expected.size());
2418 auto tensors =
2419 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2420 EXPECT_EQ(1, tensors.size());
2421 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2422 }
2423 { // with begin/end/ellipsis mask
2424 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2425
2426 // s1 = in1[:, ..., 0:5, 0:6]
2427 auto in1 =
2428 ops::Variable(scope.WithOpName("in1"), {2, 3, 4, 5, 6}, DT_FLOAT);
2429 auto begin1 = ops::Const(scope.WithOpName("begin1"), {0, 0, 0}, {3});
2430 auto end1 = ops::Const(scope.WithOpName("end1"), {0, 5, 6}, {3});
2431 auto strides1 = ops::Const(scope.WithOpName("strides1"), {1, 1, 1}, {3});
2432 ops::StridedSlice s1(
2433 scope.WithOpName("s1"), in1, begin1, end1, strides1,
2434 ops::StridedSlice::Attrs().BeginMask(1).EndMask(1).EllipsisMask(2));
2435
2436 Output in2 =
2437 ops::Variable(scope.WithOpName("in2"), {5, 8, 5, 6, 9}, DT_FLOAT);
2438 auto begin2 = ops::Const(scope.WithOpName("begin2"), {0, 0, 0, 0, 0}, {5});
2439 auto end2 = ops::Const(scope.WithOpName("end2"), {2, 3, 4, 5, 6}, {5});
2440 auto strides2 =
2441 ops::Const(scope.WithOpName("strides2"), {1, 1, 1, 1, 1}, {5});
2442 ops::StridedSlice s2(scope.WithOpName("s2"), in2, begin2, end2, strides2);
2443
2444 ops::Add out(scope.WithOpName("out"), s1, s2);
2445
2446 GrapplerItem item;
2447 item.fetch = {"out"};
2448 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2449
2450 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2451 GraphDef got;
2452 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2453 TF_EXPECT_OK(status);
2454
2455 GraphDef want;
2456 AddNode("in1", "VariableV2", {}, {}, &want);
2457 AddNode("in2", "VariableV2", {}, {}, &want);
2458 AddNode("begin1", "Const", {}, {}, &want);
2459 AddNode("end1", "Const", {}, {}, &want);
2460 AddNode("strides1", "Const", {}, {}, &want);
2461 AddNode("s1", "Identity",
2462 {"in1", AsControlDependency("begin1"), AsControlDependency("end1"),
2463 AsControlDependency("strides1")},
2464 {}, &want);
2465 AddNode("begin2", "Const", {}, {}, &want);
2466 AddNode("end2", "Const", {}, {}, &want);
2467 AddNode("strides2", "Const", {}, {}, &want);
2468 AddNode("s2", "StridedSlice", {"in2", "begin2", "end2", "strides2"}, {},
2469 &want);
2470 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2471
2472 CompareGraphs(want, got);
2473
2474 auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3, 4, 5, 6}));
2475 auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 8, 5, 6, 9}));
2476 auto tensors_expected =
2477 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2478 EXPECT_EQ(1, tensors_expected.size());
2479 auto tensors =
2480 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2481 EXPECT_EQ(1, tensors.size());
2482 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2483 }
2484 }
2485
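// Tiling with every multiple equal to 1 is a no-op: t1 becomes an Identity,
// while t2 (multiples {1, 2}) is kept.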
2486 TEST_F(ConstantFoldingTest, TileWithMultipliesBeingOne) {
2487 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2488
2489 auto in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
2490 auto in2 = ops::Variable(scope.WithOpName("in2"), {4, 3}, DT_FLOAT);
2491 auto multiplies1 = ops::Const(scope.WithOpName("multiplies1"), {1, 1}, {2});
2492 auto multiplies2 = ops::Const(scope.WithOpName("multiplies2"), {1, 2}, {2});
2493
2494 ops::Tile t1(scope.WithOpName("t1"), in1, multiplies1);
2495 ops::Tile t2(scope.WithOpName("t2"), in2, multiplies2);
2496
2497 ops::Add out(scope.WithOpName("out"), t1, t2);
2498
2499 GrapplerItem item;
2500 item.fetch = {"out"};
2501 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2502
2503 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2504 GraphDef got;
2505 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2506 TF_EXPECT_OK(status);
2507
2508 GraphDef want;
2509 AddNode("in1", "VariableV2", {}, {}, &want);
2510 AddNode("in2", "VariableV2", {}, {}, &want);
2511 AddNode("multiplies1", "Const", {}, {}, &want);
2512 AddNode("multiplies2", "Const", {}, {}, &want);
2513 AddNode("t1", "Identity", {"in1", AsControlDependency("multiplies1")}, {},
2514 &want);
2515 AddNode("t2", "Tile", {"in2", "multiplies2"}, {}, &want);
2516 AddNode("out", "Add", {"t1", "t2"}, {}, &want);
2517
2518 CompareGraphs(want, got);
2519 }
2520
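// Two ConcatV2 nodes concatenating along the same axis are merged into one,
// with c1's inputs spliced into c2 in place.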
2521 TEST_F(ConstantFoldingTest, MergeConcat) {
2522 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2523
2524 Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
2525 Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
2526 Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2527 Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
2528
2529 ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
2530 ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
2531
2532 GrapplerItem item;
2533 item.fetch = {"c2"};
2534 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2535
2536 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2537 GraphDef got;
2538 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2539 TF_EXPECT_OK(status);
2540
2541 GraphDef want;
2542 AddNode("in1", "VariableV2", {}, {}, &want);
2543 AddNode("in2", "VariableV2", {}, {}, &want);
2544 AddNode("in3", "VariableV2", {}, {}, &want);
2545 AddNode("axis", "Const", {}, {}, &want);
2546 AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
2547
2548 CompareGraphs(want, got);
2549 }
2550
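// Merging still works when the inner concat feeds the outer one more than
// once; each occurrence of c1 is expanded in place.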
2551 TEST_F(ConstantFoldingTest, MergeConcat_SameInput) {
2552 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2553
2554 Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
2555 Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
2556 Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2557 Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
2558
2559 ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
2560 ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3, Output(c1)}, axis);
2561
2562 GrapplerItem item;
2563 item.fetch = {"c2"};
2564 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2565
2566 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2567 GraphDef got;
2568 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2569 TF_EXPECT_OK(status);
2570
2571 GraphDef want;
2572 AddNode("in1", "VariableV2", {}, {}, &want);
2573 AddNode("in2", "VariableV2", {}, {}, &want);
2574 AddNode("in3", "VariableV2", {}, {}, &want);
2575 AddNode("axis", "Const", {}, {}, &want);
2576 AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "in1", "in2", "axis"}, {},
2577 &want);
2578
2579 CompareGraphs(want, got);
2580 }
2581
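// The nested concat is flattened into c2 here as well, even though one of
// its inputs (in2) has a scalar shape.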
2582 TEST_F(ConstantFoldingTest, MergeConcat_ConcatWithConst) {
2583 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2584
2585 Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 6}, DT_FLOAT);
2586 Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
2587 Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2588 Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
2589
2590 ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
2591 ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
2592
2593 GrapplerItem item;
2594 item.fetch = {"c2"};
2595 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2596
2597 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2598 GraphDef got;
2599 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2600 TF_EXPECT_OK(status);
2601
2602 GraphDef want;
2603 AddNode("in1", "VariableV2", {}, {}, &want);
2604 AddNode("in2", "VariableV2", {}, {}, &want);
2605 AddNode("in3", "VariableV2", {}, {}, &want);
2606 AddNode("axis", "Const", {}, {}, &want);
2607 AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
2608
2609 CompareGraphs(want, got);
2610 }
2611
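// Concats along different axes must not be merged; c1 and c2 both survive
// unchanged.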
2612 TEST_F(ConstantFoldingTest, MergeConcat_AxisMismatch) {
2613 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2614
2615 Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 5}, DT_FLOAT);
2616 Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
2617 Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2618 Output axis1 = ops::Const(scope.WithOpName("axis1"), 0, {});
2619 Output axis2 = ops::Const(scope.WithOpName("axis2"), 1, {});
2620
2621 ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis2);
2622 ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis1);
2623
2624 GrapplerItem item;
2625 item.fetch = {"c2"};
2626 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2627
2628 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2629 GraphDef got;
2630 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2631 TF_EXPECT_OK(status);
2632
2633 GraphDef want;
2634 AddNode("in1", "VariableV2", {}, {}, &want);
2635 AddNode("in2", "VariableV2", {}, {}, &want);
2636 AddNode("in3", "VariableV2", {}, {}, &want);
2637 AddNode("axis1", "Const", {}, {}, &want);
2638 AddNode("axis2", "Const", {}, {}, &want);
2639 AddNode("c1", "ConcatV2", {"in1", "in2", "axis2"}, {}, &want);
2640 AddNode("c2", "ConcatV2", {"c1", "in3", "axis1"}, {}, &want);
2641
2642 CompareGraphs(want, got);
2643 }
2644
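// When a concat mixes constants with a non-constant input, the constant
// prefix is folded into a single Const and the concat is rebuilt around it
// and the remaining placeholder input.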
2645 TEST_F(ConstantFoldingTest, MergeConcat_PartialFolding) {
2646 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2647 Output c1 = ops::Const(scope.WithOpName("c1"), 1.0f, {2, 2});
2648 Output c2 = ops::Const(scope.WithOpName("c2"), 2.0f, {2, 2});
2649 Output c3 = ops::Const(scope.WithOpName("c3"), 3.0f, {2, 2});
2650 Output c4 = ops::Const(scope.WithOpName("c4"), 4.0f, {2, 2});
2651 Output ph = ops::Placeholder(scope.WithOpName("ph"), DT_FLOAT,
2652 ops::Placeholder::Shape(TensorShape({2, 2})));
2653 Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
2654
2655 ops::Concat concat1(scope.WithOpName("concat1"), {c1, c2, ph}, axis);
2656 ops::Concat concat2(scope.WithOpName("concat2"), {c3, c4, Output(concat1)},
2657 axis);
2658
2659 GrapplerItem item;
2660 item.fetch = {"concat2"};
2661 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2662
2663 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2664 GraphDef got;
2665 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2666 TF_EXPECT_OK(status);
2667
2668 GraphDef want;
2669 AddNode("ConstantFolding/concat2_partial_split_0", "Const", {}, {}, &want);
2670 AddNode("axis", "Const", {}, {}, &want);
2671 AddNode("ph", "Placeholder", {}, {}, &want);
2672 AddNode("concat2", "ConcatV2",
2673 {"ConstantFolding/concat2_partial_split_0", "ph", "axis"}, {}, &want);
2674
2675 CompareGraphs(want, got);
2676 }
2677
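// Thin wrapper driving the templated PaddingWithZeroSize helper for both
// index types.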
2678 TEST_F(ConstantFoldingTest, PaddingWithZeroSize) {
2679 PaddingWithZeroSize<int32>();
2680 PaddingWithZeroSize<int64_t>();
2681 }
2682
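// Squeeze is a no-op when every dimension is larger than 1: s1 becomes an
// Identity, while s2 (whose input has size-1 dimensions) keeps its Squeeze.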
2683 TEST_F(ConstantFoldingTest, SqueezeWithAllDimensionsGreaterThanOne) {
2684 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2685
2686 auto in1 = ops::Variable(scope.WithOpName("in1"), {2, 3}, DT_INT32);
2687 auto in2 = ops::Variable(scope.WithOpName("in2"), {1, 2, 3, 1}, DT_INT32);
2688
2689 ops::Squeeze s1(scope.WithOpName("s1"), in1);
2690 ops::Squeeze s2(scope.WithOpName("s2"), in2);
2691
2692 ops::Add out(scope.WithOpName("out"), s1, s2);
2693
2694 GrapplerItem item;
2695 item.fetch = {"out"};
2696 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2697
2698 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2699 GraphDef got;
2700 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2701 TF_EXPECT_OK(status);
2702
2703 GraphDef want;
2704 AddNode("in1", "VariableV2", {}, {}, &want);
2705 AddNode("in2", "VariableV2", {}, {}, &want);
2706 AddNode("s1", "Identity", {"in1"}, {}, &want);
2707 AddNode("s2", "Squeeze", {"in2"}, {}, &want);
2708 AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2709
2710 CompareGraphs(want, got);
2711
2712 auto in1_t = GenerateRandomTensor<DT_INT32>(TensorShape({2, 3}));
2713 auto in2_t = GenerateRandomTensor<DT_INT32>(TensorShape({1, 2, 3, 1}));
2714 auto tensors_expected =
2715 EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2716 EXPECT_EQ(1, tensors_expected.size());
2717 auto tensors =
2718 EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2719 EXPECT_EQ(1, tensors.size());
2720 test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
2721 }
2722
2723 TEST_F(ConstantFoldingTest, NoOpReduction) {
2724 // Build a simple graph with reductions that can be simplified to the
2725 // identity.
2726 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2727
2728 Output v = ops::Variable(scope.WithOpName("v"), {3, 5, 7}, DT_FLOAT);
2729 Output c =
2730 ops::Const(scope.WithOpName("c").WithControlDependencies(v), 0, {0});
2731 Output i = ops::Identity(scope.WithOpName("i"), c);
2732 Output p = ops::Prod(scope.WithOpName("p"), v, i);
2733 Output s = ops::Square(scope.WithOpName("s"), p);
2734
2735 Output v2 = ops::Variable(scope.WithOpName("v2"), {3, 5, 1}, DT_FLOAT);
2736 Output c2 =
2737 ops::Const(scope.WithOpName("c2").WithControlDependencies(v), 2, {1});
2738 ops::Prod::Attrs attr;
2739 attr = attr.KeepDims(true);
2740 Output p2 = ops::Prod(scope.WithOpName("p2"), v2, c2, attr);
2741
2742 // Test with unknown input shape.
2743 Output a = ops::Placeholder(scope.WithOpName("a"), DT_FLOAT);
2744 Output p3 = ops::Prod(scope.WithOpName("p3"), a, i, attr);
2745
2746 GrapplerItem item;
2747 item.fetch = {"s", "p2", "p3"};
2748 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2749
2750 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2751 GraphDef output;
2752 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2753 TF_EXPECT_OK(status);
2754
2755 int found = 0;
2756 for (const auto& node : output.node()) {
2757 if (node.name() == "p") {
2758 found++;
2759 EXPECT_EQ("Identity", node.op());
2760 EXPECT_EQ(2, node.input_size());
2761 EXPECT_EQ("v", node.input(0));
2762 EXPECT_EQ("^i", node.input(1));
2763 } else if (node.name() == "p2") {
2764 found++;
2765 EXPECT_EQ("Identity", node.op());
2766 EXPECT_EQ(2, node.input_size());
2767 EXPECT_EQ("v2", node.input(0));
2768 EXPECT_EQ("^c2", node.input(1));
2769 } else if (node.name() == "p3") {
2770 found++;
2771 EXPECT_EQ("Identity", node.op());
2772 EXPECT_EQ(2, node.input_size());
2773 EXPECT_EQ("a", node.input(0));
2774 EXPECT_EQ("^i", node.input(1));
2775 }
2776 }
2777 EXPECT_EQ(3, found);
2778
2779 auto v_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7}));
2780 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 1}));
2781 auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7}));
2782 auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
2783 {{"v", v_t}, {"v2", v2_t}, {"a", a_t}});
2784 EXPECT_EQ(3, tensors_expected.size());
2785 auto tensors =
2786 EvaluateNodes(output, item.fetch, {{"v", v_t}, {"v2", v2_t}, {"a", a_t}});
2787 EXPECT_EQ(3, tensors.size());
2788 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2789 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
2790 test::ExpectTensorNear<float>(tensors_expected[2], tensors[2], 1e-5);
2791 }
2792
2793 TEST_F(ConstantFoldingTest, SingleElementEmptyAxisReduction) {
2794 // Build a simple graph with reductions that involve a single-element input
2795 // and no axes to reduce along.
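// Illustrative summary of what the checks below expect (not a full statement
// of the optimizer's rules): where shapes are fully known,
//   mean_1, mean_2 (keep_dims=false) -> Reshape
//   mean_5         (keep_dims=true)  -> Identity
// while mean_3 (missing output properties) and mean_4 (non-constant axes)
// are left as Mean.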
2796 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2797
2798 Output input_var_three_dim = ops::Variable(
2799 scope.WithOpName("input_var_three_dim"), {1, 1, 1}, DT_FLOAT);
2800 Output input_var_one_dim =
2801 ops::Variable(scope.WithOpName("input_var_one_dim"), {1}, DT_FLOAT);
2802 Output one_axis = ops::Const(scope.WithOpName("one_axis"), {0}, {1});
2803 Output multiple_axes =
2804 ops::Const(scope.WithOpName("multiple_axes"), {1, 0}, {2});
2805 Output variable_axis =
2806 ops::Variable(scope.WithOpName("input_var_axis"), {1}, DT_INT32);
2807 ops::Mean::Attrs attr;
2808 attr = attr.KeepDims(false);
2809 // Should be optimized to Reshape.
2810 Output mean_1 = ops::Mean(scope.WithOpName("mean_1"), input_var_three_dim,
2811 one_axis, attr.KeepDims(false));
2812 Output mean_2 = ops::Mean(scope.WithOpName("mean_2"), input_var_three_dim,
2813 multiple_axes, attr.KeepDims(false));
2814 // Should remain as-is, since OutputProperties will not be known for this node.
2815 Output mean_3 = ops::Mean(scope.WithOpName("mean_3"), input_var_one_dim,
2816 one_axis, attr.KeepDims(false));
2817 // Should remain as-is.
2818 Output mean_4 = ops::Mean(scope.WithOpName("mean_4"), input_var_three_dim,
2819 variable_axis, attr.KeepDims(false));
2820 // Should be optimized to Identity, since KeepDims=true.
2821 Output mean_5 = ops::Mean(scope.WithOpName("mean_5"), input_var_three_dim,
2822 multiple_axes, attr.KeepDims(true));
2823
2824 GrapplerItem item;
2825 item.fetch = {"mean_1", "mean_2", "mean_3", "mean_4", "mean_5"};
2826 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2827
2828 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2829 GraphDef output;
2830 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2831 TF_EXPECT_OK(status);
2832
2833 // Ensure each Mean node is rewritten (or left unchanged) as expected.
2834 int found = 0;
2835 for (const auto& node : output.node()) {
2836 if (node.name() == "mean_1" || node.name() == "mean_2") {
2837 found++;
2838 EXPECT_EQ("Reshape", node.op());
2839 EXPECT_EQ(2, node.input_size());
2840 EXPECT_EQ("input_var_three_dim", node.input(0));
2841 } else if (node.name() == "mean_3") {
2842 found++;
2843 EXPECT_EQ("Mean", node.op());
2844 EXPECT_EQ(2, node.input_size());
2845 EXPECT_EQ("input_var_one_dim", node.input(0));
2846 } else if (node.name() == "mean_4") {
2847 found++;
2848 EXPECT_EQ("Mean", node.op());
2849 EXPECT_EQ(2, node.input_size());
2850 EXPECT_EQ("input_var_three_dim", node.input(0));
2851 } else if (node.name() == "mean_5") {
2852 found++;
2853 EXPECT_EQ("Identity", node.op());
2854 EXPECT_EQ(2, node.input_size());
2855 EXPECT_EQ("^multiple_axes", node.input(1));
2856 }
2857 }
2858 EXPECT_EQ(5, found);
2859
2860 // Ensure resultant values from Mean and Reshape are the same.
2861 auto input_var_three_dim_t =
2862 GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
2863 auto input_var_one_dim_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1}));
2864 Tensor input_var_axis_t(DT_INT32, TensorShape({1}));
2865 input_var_axis_t.flat<int32>()(0) = 0;
2866 auto tensors_expected =
2867 EvaluateNodes(item.graph, item.fetch,
2868 {{"input_var_three_dim", input_var_three_dim_t},
2869 {"input_var_one_dim", input_var_one_dim_t},
2870 {"input_var_axis", input_var_axis_t}});
2871 EXPECT_EQ(5, tensors_expected.size());
2872 auto tensors = EvaluateNodes(output, item.fetch,
2873 {{"input_var_three_dim", input_var_three_dim_t},
2874 {"input_var_one_dim", input_var_one_dim_t},
2875 {"input_var_axis", input_var_axis_t}});
2876 EXPECT_EQ(5, tensors.size());
2877 for (int i = 0; i < 5; ++i) {
2878 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2879 }
2880 }
2881
2882 TEST_F(ConstantFoldingTest, NoOpReshape) {
2883 // Build a simple graph with a reshape that can be reduced to the identity.
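// The pattern under test, roughly: when the requested shape is known to match
// the input shape (including -1 wildcards that can only resolve to it, as for
// r4 below), the Reshape is replaced by an Identity and the shape operand
// survives only as a control dependency.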
2884 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2885
2886 // A reshape that can be optimized
2887 Output d1 = ops::Const(scope.WithOpName("d1"), 3.14f, {17});
2888 Output v1 = ops::Variable(scope.WithOpName("v1"), {17}, DT_FLOAT);
2889 Output c1 =
2890 ops::Const(scope.WithOpName("c1").WithControlDependencies(v1), 17, {1});
2891 Output i1 = ops::Identity(scope.WithOpName("i1"), c1);
2892 Output r1 =
2893 ops::Reshape(scope.WithOpName("r1").WithControlDependencies(d1), v1, i1);
2894 Output s1 = ops::Square(scope.WithOpName("s1"), r1);
2895
2896 // A multi-dimensional reshape that can be optimized
2897 Output v3 = ops::Variable(scope.WithOpName("v3"), {5, 5, 5}, DT_FLOAT);
2898 Output c3 =
2899 ops::Const(scope.WithOpName("c3").WithControlDependencies(v3), 5, {3});
2900 Output i3 = ops::Identity(scope.WithOpName("i3"), c3);
2901 Output r3 = ops::Reshape(scope.WithOpName("r3"), v3, i3);
2902 Output s3 = ops::Square(scope.WithOpName("s3"), r3);
2903
2904 // A multi-dimensional, partially defined reshape that can be optimized
2905 Output v4 = ops::Variable(scope.WithOpName("v4"), {5, 5, 5}, DT_FLOAT);
2906 Output c4 = ops::Const(scope.WithOpName("c4").WithControlDependencies(v4),
2907 {5, -1, 5}, {3});
2908 Output i4 = ops::Identity(scope.WithOpName("i4"), c4);
2909 Output r4 = ops::Reshape(scope.WithOpName("r4"), v4, i4);
2910 Output s4 = ops::Square(scope.WithOpName("s4"), r4);
2911
2912 // A reshape that can't be optimized
2913 Output v2 = ops::Variable(scope.WithOpName("v2"), {17, 1}, DT_FLOAT);
2914 Output c2 =
2915 ops::Const(scope.WithOpName("c2").WithControlDependencies(v2), 17, {1});
2916 Output r2 = ops::Reshape(scope.WithOpName("r2"), v2, c2);
2917 Output s2 = ops::Square(scope.WithOpName("s2"), r2);
2918
2919 GrapplerItem item;
2920 item.fetch = {"s1", "s2", "s3", "s4"};
2921 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2922
2923 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2924 GraphDef output;
2925 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2926 TF_EXPECT_OK(status);
2927
2928 int found = 0;
2929 for (const auto& node : output.node()) {
2930 if (node.name() == "r1") {
2931 ++found;
2932 EXPECT_EQ("Identity", node.op());
2933 ASSERT_EQ(3, node.input_size());
2934 EXPECT_EQ("v1", node.input(0));
2935 EXPECT_EQ("^i1", node.input(1));
2936 EXPECT_EQ("^d1", node.input(2));
2937 } else if (node.name() == "r3") {
2938 ++found;
2939 EXPECT_EQ("Identity", node.op());
2940 ASSERT_EQ(2, node.input_size());
2941 EXPECT_EQ("v3", node.input(0));
2942 EXPECT_EQ("^i3", node.input(1));
2943 } else if (node.name() == "r4") {
2944 ++found;
2945 EXPECT_EQ("Identity", node.op());
2946 ASSERT_EQ(2, node.input_size());
2947 EXPECT_EQ("v4", node.input(0));
2948 EXPECT_EQ("^i4", node.input(1));
2949 } else if (node.name() == "r2") {
2950 ++found;
2951 EXPECT_EQ("Reshape", node.op());
2952 ASSERT_EQ(2, node.input_size());
2953 EXPECT_EQ("v2", node.input(0));
2954 EXPECT_EQ("c2", node.input(1));
2955 }
2956 }
2957 EXPECT_EQ(4, found);
2958
2959 auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({17}));
2960 auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({17, 1}));
2961 auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 5, 5}));
2962 auto v4_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 5, 5}));
2963 auto tensors_expected =
2964 EvaluateNodes(item.graph, item.fetch,
2965 {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}, {"v4", v4_t}});
2966 EXPECT_EQ(4, tensors_expected.size());
2967 auto tensors =
2968 EvaluateNodes(output, item.fetch,
2969 {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}, {"v4", v4_t}});
2970 EXPECT_EQ(4, tensors.size());
2971 for (int i = 0; i < tensors.size(); i++)
2972 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2973 }
2974
2975 TEST_F(ConstantFoldingTest, Packing) {
2976 // Build a simple graph with a large constant that can be folded.
2977 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2978 Output c = ops::Const(scope.WithOpName("c"), 3.14f, {1000});
2979 Output i1 = ops::Identity(scope.WithOpName("i1"), c);
2980 Output i2 = ops::Identity(scope.WithOpName("i2"), c);
2981
2982 GrapplerItem item;
2983 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2984
2985 ConstantFolding optimizer(/*cpu_device=*/nullptr);
2986 GraphDef output;
2987 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2988 TF_EXPECT_OK(status);
2989
2990 const std::vector<string> fetch_nodes = {"i1", "i2"};
2991 auto tensors_expected = EvaluateNodes(item.graph, fetch_nodes);
2992 EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
2993 auto tensors = EvaluateNodes(output, fetch_nodes);
2994 EXPECT_EQ(fetch_nodes.size(), tensors.size());
2995 for (int i = 0; i < fetch_nodes.size(); i++)
2996 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2997
2998 // Make sure that the representation of the folded constant is space
2999 // efficient: in particular, the whole message should be smaller than 8k
3000 // (the size needed to naively encode 1000 floats folded twice).
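// Back-of-the-envelope numbers behind the 8k bound: 1000 floats * 4 bytes
// * 2 folded copies = 8000 bytes if each copy were stored as raw values; a
// compact encoding of the repeated 3.14f should stay well under that.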
3001 EXPECT_GT(8000, output.ByteSizeLong());
3002 }
3003
3004 TEST_F(ConstantFoldingTest, LargeConstantNoSizeIncrease) {
3005 // Build a simple graph with a large constant with size greater than
3006 // kMaxConstantSize that can be folded because the resulting size does not
3007 // increase.
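// That is, folding Identity(b_const) into a Const just moves the value: the
// graph still holds at most roughly one uncompressed copy of the constant,
// which is what the ByteSizeLong() bound at the end of the test checks.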
3008 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3009 const int64_t large_constant_size = kMaxConstantSize + 1;
3010 Output a = ops::Variable(scope.WithOpName("a"), {1, 1}, DT_FLOAT);
3011 Output b_const =
3012 ops::Const(scope.WithOpName("b_const"), 3.14f, {1, large_constant_size});
3013 Output b = ops::Identity(scope.WithOpName("b"), b_const);
3014 Output matmul = ops::MatMul(scope.WithOpName("matmul"), a, b);
3015
3016 GrapplerItem item;
3017 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
3018
3019 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3020 GraphDef output;
3021 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3022 TF_EXPECT_OK(status);
3023
3024 item.graph.Swap(&output);
3025 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3026 TF_EXPECT_OK(status);
3027
3028 for (const auto& node : output.node()) {
3029 if (node.name() == "b") {
3030 EXPECT_EQ("Const", node.op());
3031 }
3032 }
3033 EXPECT_EQ(4, output.node_size());
3034 EXPECT_LT(output.ByteSizeLong(), sizeof(float) * large_constant_size + 500);
3035 }
3036
3037 TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) {
3038 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3039 Output a =
3040 ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
3041 ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
3042 Output b = ops::Square(s.WithOpName("b"), a);
3043 Output c = ops::Mul(s.WithOpName("c"), a, b);
3044 Output d = ops::Shape(s.WithOpName("d"), a);
3045 Output e = ops::Shape(s.WithOpName("e"), b);
3046
3047 auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
3048 Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
3049 Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);
3050
3051 Output g = ops::Placeholder(s.WithOpName("g"), DT_FLOAT,
3052 ops::Placeholder::Shape(PartialTensorShape({1})));
3053 Output h = ops::Shape(s.WithOpName("h"), g);
3054 auto i = ops::internal::BroadcastGradientArgs(s.WithOpName("i"), d, h);
3055 Output p1 = ops::Identity(s.WithOpName("p1"), i.r0);
3056 Output p2 = ops::Identity(s.WithOpName("p2"), i.r1);
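// A sketch of the expected materialization, mirroring the checks below: d and
// e describe tensors with identical symbolic shapes, so the
// BroadcastGradientArgs outputs read by o1/o2 can be replaced by empty Const
// nodes gated on ^f. For i, the shapes are not provably equal, so p1/p2 keep
// reading i:0 and i:1 directly.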
3057
3058 GrapplerItem item;
3059 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3060
3061 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3062 GraphDef output;
3063 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3064 TF_EXPECT_OK(status);
3065
3066 std::vector<string> fetch_nodes = {"o1", "o2", "p1", "p2"};
3067 auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 5}));
3068 auto g_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1}));
3069 auto tensors_expected =
3070 EvaluateNodes(item.graph, fetch_nodes, {{"a", a_t}, {"g", g_t}});
3071 EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
3072
3073 // Run a second time to make sure the optimization is idempotent.
3074 item.graph.Swap(&output);
3075 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3076 TF_EXPECT_OK(status);
3077
3078 int found = 0;
3079 for (const auto& node : output.node()) {
3080 if (node.name() == "o1") {
3081 ++found;
3082 EXPECT_EQ(1, node.input_size());
3083 EXPECT_EQ("ConstantFolding/f-bcastargs-0", node.input(0));
3084 } else if (node.name() == "o2") {
3085 ++found;
3086 EXPECT_EQ(1, node.input_size());
3087 EXPECT_EQ("ConstantFolding/f-bcastargs-1", node.input(0));
3088 } else if (node.name() == "ConstantFolding/f-bcastargs-0") {
3089 ++found;
3090 EXPECT_EQ("Const", node.op());
3091 EXPECT_EQ(1, node.input_size());
3092 EXPECT_EQ("^f", node.input(0));
3093 EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
3094 .num_elements());
3095 } else if (node.name() == "ConstantFolding/f-bcastargs-1") {
3096 ++found;
3097 EXPECT_EQ("Const", node.op());
3098 EXPECT_EQ(1, node.input_size());
3099 EXPECT_EQ("^f", node.input(0));
3100 EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
3101 .num_elements());
3102 } else if (node.name() == "p1") {
3103 ++found;
3104 EXPECT_EQ(1, node.input_size());
3105 EXPECT_EQ("i", node.input(0));
3106 } else if (node.name() == "p2") {
3107 ++found;
3108 EXPECT_EQ(1, node.input_size());
3109 EXPECT_EQ("i:1", node.input(0));
3110 }
3111 }
3112 EXPECT_EQ(6, found);
3113
3114 auto tensors = EvaluateNodes(output, fetch_nodes, {{"a", a_t}, {"g", g_t}});
3115 EXPECT_EQ(fetch_nodes.size(), tensors.size());
3116 for (int i = 0; i < fetch_nodes.size(); i++)
3117 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
3118 }
3119
3120 TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) {
3121 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3122 Output a =
3123 ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
3124 ops::Placeholder::Shape(PartialTensorShape({2, 2})));
3125 Output b = ops::Square(s.WithOpName("b"), a);
3126 Output c = ops::Mul(s.WithOpName("c"), a, b);
3127 Output d = ops::Shape(s.WithOpName("d"), a);
3128 Output e = ops::Shape(s.WithOpName("e"), b);
3129
3130 auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
3131 Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
3132 Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);
3133
3134 GrapplerItem item;
3135 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3136
3137 std::vector<string> fetch_nodes = {"o1", "o2"};
3138 auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
3139 auto tensors_expected = EvaluateNodes(item.graph, fetch_nodes, {{"a", a_t}});
3140 EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
3141
3142 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3143 GraphDef output;
3144 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3145 TF_EXPECT_OK(status);
3146
3147 // Run a second time to make sure the optimization is idempotent.
3148 item.graph.Swap(&output);
3149 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3150 TF_EXPECT_OK(status);
3151
3152 EXPECT_EQ(11, output.node_size());
3153 int found = 0;
3154 for (const auto& node : output.node()) {
3155 if (node.name() == "ConstantFolding/f-folded-1") {
3156 ++found;
3157 EXPECT_EQ("Const", node.op());
3158 EXPECT_EQ(2, node.input_size());
3159 EXPECT_EQ("^a", node.input(0));
3160 EXPECT_EQ("^b", node.input(1));
3161 } else if (node.name() == "d") {
3162 ++found;
3163 EXPECT_EQ("Const", node.op());
3164 EXPECT_EQ(1, node.input_size());
3165 EXPECT_EQ("^a", node.input(0));
3166 } else if (node.name() == "e") {
3167 ++found;
3168 EXPECT_EQ("Const", node.op());
3169 EXPECT_EQ(1, node.input_size());
3170 EXPECT_EQ("^b", node.input(0));
3171 } else if (node.name() == "o1") {
3172 ++found;
3173 EXPECT_EQ(1, node.input_size());
3174 EXPECT_EQ("ConstantFolding/f-bcastargs-0", node.input(0));
3175 } else if (node.name() == "o2") {
3176 ++found;
3177 EXPECT_EQ(1, node.input_size());
3178 EXPECT_EQ("ConstantFolding/f-bcastargs-1", node.input(0));
3179 } else if (node.name() == "ConstantFolding/f-bcastargs-0") {
3180 ++found;
3181 EXPECT_EQ("Const", node.op());
3182 EXPECT_EQ(1, node.input_size());
3183 EXPECT_EQ("^ConstantFolding/f-folded-1", node.input(0));
3184 EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
3185 .num_elements());
3186 } else if (node.name() == "ConstantFolding/f-bcastargs-1") {
3187 ++found;
3188 EXPECT_EQ("Const", node.op());
3189 EXPECT_EQ(1, node.input_size());
3190 EXPECT_EQ("^ConstantFolding/f-folded-1", node.input(0));
3191 EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
3192 .num_elements());
3193 }
3194 }
3195 EXPECT_EQ(7, found);
3196 auto tensors = EvaluateNodes(output, fetch_nodes, {{"a", a_t}});
3197 EXPECT_EQ(fetch_nodes.size(), tensors.size());
3198 for (int i = 0; i < fetch_nodes.size(); i++)
3199 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
3200 }
3201
3202 TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
3203 for (bool use_reshape : {true, false}) {
3204 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3205 Output input =
3206 ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
3207 ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
3208 // If use_reshape is false, we need to know the number of indices to apply
3209 // the rewrite.
3210 Output indices = ops::Placeholder(
3211 s.WithOpName("indices"), DT_INT32,
3212 ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2})));
3213 Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
3214 if (use_reshape) {
3215 Output size = ops::Const(s.WithOpName("size"), 1, {1});
3216 Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
3217 }
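// Sketch of the expected rewrite: once shape inference proves the Sum is a
// full reduction over a rank-2 input, the placeholder indices can be
// materialized as a two-element Const (asserted below as
// "ConstantFolding/sum-reduction_indices" with a ^indices control input).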
3218
3219 GrapplerItem item;
3220 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3221 item.fetch.push_back(use_reshape ? "reshape" : "sum");
3222
3223 auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
3224 Tensor indices_t(DT_INT32, TensorShape({2}));
3225 indices_t.flat<int>()(0) = 0;
3226 indices_t.flat<int>()(1) = 1;
3227 auto tensors_expected = EvaluateNodes(
3228 item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
3229 EXPECT_EQ(1, tensors_expected.size());
3230
3231 // Use aggressive mode to force the shape inference to propagate placeholder
3232 // shapes.
3233 ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
3234 /*cpu_device=*/nullptr);
3235 GraphDef output;
3236 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3237 TF_EXPECT_OK(status);
3238
3239 // Run a second time to make sure the optimization is idempotent.
3240 item.graph.Swap(&output);
3241 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3242 TF_EXPECT_OK(status);
3243
3244 int found = 0;
3245 for (const auto& node : output.node()) {
3246 if (node.name() == "ConstantFolding/sum-reduction_indices") {
3247 ++found;
3248 EXPECT_EQ("Const", node.op());
3249 EXPECT_EQ("^indices", node.input(0));
3250 EXPECT_EQ(2,
3251 TensorShape(node.attr().at("value").tensor().tensor_shape())
3252 .num_elements());
3253 } else if (node.name() == "sum") {
3254 ++found;
3255 EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
3256 } else if (node.name() == "indices") {
3257 ++found;
3258 }
3259 }
3260 EXPECT_EQ(3, found);
3261
3262 auto tensors = EvaluateNodes(output, item.fetch,
3263 {{"input", input_t}, {"indices", indices_t}});
3264 EXPECT_EQ(1, tensors.size());
3265 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
3266 }
3267 }
3268
3269 TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) {
3270 for (bool input_rank_known : {true, false}) {
3271 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3272 Output input =
3273 (input_rank_known ? ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
3274 ops::Placeholder::Shape(
3275 PartialTensorShape({-1, -1})))
3276 : ops::Placeholder(s.WithOpName("input"), DT_FLOAT));
3277 Output indices =
3278 ops::Placeholder(s.WithOpName("indices"), DT_INT32,
3279 ops::Placeholder::Shape(
3280 PartialTensorShape({input_rank_known ? 1 : 2})));
3281 Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
3282
3283 GrapplerItem item;
3284 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3285 item.fetch.push_back("sum");
3286
3287 // Use aggressive mode to force the shape inference to propagate placeholder
3288 // shapes.
3289 ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
3290 /*cpu_device=*/nullptr);
3291 GraphDef output;
3292 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3293 TF_EXPECT_OK(status);
3294
3295 CompareGraphs(item.graph, output);
3296 }
3297 }
3298
3299 TEST_F(ConstantFoldingTest, LargeConstant) {
3300 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3301 // Generate a 4k by 4k constant, non-compressible matrix.
3302 Output mat_diag =
3303 ops::Const(scope.WithOpName("mat_diag"), 3.14f, TensorShape({1024 * 4}));
3304 Output mat = ops::Diag(scope.WithOpName("mat"), mat_diag);
3305 Output out = ops::Identity(scope.WithOpName("out"), mat);
3306
3307 GrapplerItem item;
3308 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
3309 item.fetch.push_back("out");
3310
3311 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3312 GraphDef output;
3313 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3314 TF_EXPECT_OK(status);
3315
3316 // Make sure the diag node hasn't been folded, since it would use too much
3317 // memory to encode the corresponding constant.
3318 int found = 0;
3319 for (const NodeDef& node : output.node()) {
3320 if (node.name() == "out") {
3321 EXPECT_EQ(node.op(), "Identity");
3322 ASSERT_EQ(node.input_size(), 1);
3323 EXPECT_EQ(node.input(0), "mat");
3324 ++found;
3325 } else if (node.name() == "mat") {
3326 EXPECT_EQ(node.op(), "Diag");
3327 ASSERT_EQ(node.input_size(), 1);
3328 EXPECT_EQ(node.input(0), "mat_diag");
3329 ++found;
3330 }
3331 }
3332 EXPECT_EQ(found, 2);
3333 // The output should be no larger than the size of the constant "mat_diag"
3334 // plus a small constant amount for the remaining nodes.
3335 EXPECT_LT(output.ByteSizeLong(), sizeof(int) * 4 * 1024 + 500);
3336
3337 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
3338 ASSERT_EQ(tensors_expected.size(), 1);
3339 auto tensors = EvaluateNodes(output, item.fetch);
3340 ASSERT_EQ(tensors.size(), 1);
3341 test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
3342 }
3343
3344 TEST_F(ConstantFoldingTest, SwitchIdenticalInputs) {
3345 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3346 Output x = ops::Placeholder(s.WithOpName("x"), DT_BOOL,
3347 ops::Placeholder::Shape(TensorShape({})));
3348 ops::Switch sw = ops::Switch(s.WithOpName("switch"), x, x);
3349 Output id_false = ops::LogicalNot(s.WithOpName("id_false"), sw.output_false);
3350 Output id_true = ops::LogicalNot(s.WithOpName("id_true"), sw.output_true);
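// Since the same tensor feeds both the data and predicate inputs of the
// Switch, output_false can only ever be observed as false and output_true as
// true. The LogicalNot consumers are therefore constant; the checks below
// expect them folded to Const nodes anchored by control dependencies on
// per-output Identity nodes (ConstantFoldingCtrl/switch_{0,1}).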
3351
3352 GrapplerItem item;
3353 item.fetch.push_back("id_false");
3354 item.fetch.push_back("id_true");
3355 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3356
3357 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3358 GraphDef output;
3359 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3360 TF_EXPECT_OK(status);
3361
3362 EXPECT_EQ(6, output.node_size());
3363 int found = 0;
3364 for (const auto& node : output.node()) {
3365 if (node.name() == "switch" || node.name() == "x") {
3366 ++found;
3367 }
3368 if (node.name() == "id_false") {
3369 EXPECT_EQ("Const", node.op());
3370 EXPECT_EQ(1, node.input_size());
3371 EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
3372 ++found;
3373 }
3374 if (node.name() == "id_true") {
3375 EXPECT_EQ("Const", node.op());
3376 EXPECT_EQ(1, node.input_size());
3377 EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
3378 ++found;
3379 }
3380 if (node.name() == "ConstantFoldingCtrl/switch_0") {
3381 EXPECT_EQ("Identity", node.op());
3382 EXPECT_EQ(1, node.input_size());
3383 EXPECT_EQ("switch", node.input(0));
3384 ++found;
3385 }
3386 if (node.name() == "ConstantFoldingCtrl/switch_1") {
3387 EXPECT_EQ("Identity", node.op());
3388 EXPECT_EQ(1, node.input_size());
3389 EXPECT_EQ("switch:1", node.input(0));
3390 ++found;
3391 }
3392 }
3393 EXPECT_EQ(6, found);
3394
3395 // Evaluate id_true when input tensor x is true.
3396 Tensor x_t(DT_BOOL, TensorShape({}));
3397 x_t.flat<bool>()(0) = true;
3398 auto tensors_expected = EvaluateNodes(item.graph, {"id_true"}, {{"x", x_t}});
3399 EXPECT_EQ(1, tensors_expected.size());
3400 auto tensors = EvaluateNodes(output, {"id_true"}, {{"x", x_t}});
3401 EXPECT_EQ(1, tensors.size());
3402 test::ExpectTensorEqual<bool>(tensors_expected[0], tensors[0]);
3403
3404 // Evaluate id_false when input tensor is false.
3405 x_t.flat<bool>()(0) = false;
3406 tensors_expected = EvaluateNodes(item.graph, {"id_false"}, {{"x", x_t}});
3407 EXPECT_EQ(1, tensors_expected.size());
3408 tensors = EvaluateNodes(output, {"id_false"}, {{"x", x_t}});
3409 EXPECT_EQ(1, tensors.size());
3410 test::ExpectTensorEqual<bool>(tensors_expected[0], tensors[0]);
3411 }
3412
3413 TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) {
3414 std::function<Output(const Scope&, InputList)> addn_fun =
3415 [](const Scope& scope, InputList inputs) {
3416 return ops::AddN(scope, inputs);
3417 };
3418 std::function<Output(const Scope&, InputList)> accumulate_fun =
3419 [](const Scope& scope, InputList inputs) {
3420 return ops::AccumulateNV2(scope, inputs, TensorShape({2, 2}));
3421 };
3422 for (bool use_add_n : {true, false}) {
3423 auto fun = use_add_n ? addn_fun : accumulate_fun;
3424 const string op_name = use_add_n ? "AddN" : "AccumulateNV2";
3425 Scope s = Scope::NewRootScope();
3426 Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
3427 ops::Placeholder::Shape(TensorShape({2, 2})));
3428 Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
3429 ops::Placeholder::Shape(TensorShape({2, 2})));
3430 Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
3431 ops::Placeholder::Shape(TensorShape({2, 2})));
3432 Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
3433 Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
3434 Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2, 2});
3435 Output acc0 = fun(s.WithOpName("acc0"), {c1, c2, c3});
3436 Output acc1 = fun(s.WithOpName("acc1"), {x, y, z});
3437 Output acc2 = fun(s.WithOpName("acc2"), {c1, x, y});
3438 Output acc3 = fun(s.WithOpName("acc3"), {c1, c2, z});
3439 Output acc4 = fun(s.WithOpName("acc4"), {c1, y, c2});
3440 Output acc5 = fun(s.WithOpName("acc5"), {x, c1, c2});
3441 Output acc6 = fun(s.WithOpName("acc6"), {x, c1, y, c2});
3442 Output stack = ops::Stack(s.WithOpName("stack"),
3443 {acc0, acc1, acc2, acc3, acc4, acc5, acc6});
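// Partial folding, sketched: for each AddN/AccumulateNV2 with two or more
// constant operands, those operands are pre-summed into one generated Const
// ("ConstantFolding/<name>_partial_split_<N>" in the checks below), and the
// op keeps only its non-constant inputs plus that partial sum.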
3444
3445 GrapplerItem item;
3446 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3447 item.fetch = {"stack"};
3448
3449 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3450 GraphDef output;
3451 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3452 TF_EXPECT_OK(status);
3453
3454 EXPECT_EQ(16, output.node_size());
3455 for (const NodeDef& node : output.node()) {
3456 if (node.name() == "acc0") {
3457 EXPECT_EQ("Const", node.op());
3458 }
3459 if (node.name() == "acc1") {
3460 EXPECT_EQ(op_name, node.op());
3461 EXPECT_EQ(3, node.input_size());
3462 EXPECT_EQ("x", node.input(0));
3463 EXPECT_EQ("y", node.input(1));
3464 EXPECT_EQ("z", node.input(2));
3465 }
3466 if (node.name() == "acc2") {
3467 EXPECT_EQ(op_name, node.op());
3468 EXPECT_EQ(3, node.input_size());
3469 EXPECT_EQ("c1", node.input(0));
3470 EXPECT_EQ("x", node.input(1));
3471 EXPECT_EQ("y", node.input(2));
3472 }
3473 if (node.name() == "acc3") {
3474 EXPECT_EQ(op_name, node.op());
3475 EXPECT_EQ(2, node.input_size());
3476 EXPECT_EQ("ConstantFolding/acc3_partial_split_2", node.input(0));
3477 EXPECT_EQ("z", node.input(1));
3478 }
3479 if (node.name() == "acc4") {
3480 EXPECT_EQ(op_name, node.op());
3481 EXPECT_EQ(2, node.input_size());
3482 EXPECT_EQ("ConstantFolding/acc4_partial_split_2", node.input(0));
3483 EXPECT_EQ("y", node.input(1));
3484 }
3485 if (node.name() == "acc5") {
3486 EXPECT_EQ(op_name, node.op());
3487 EXPECT_EQ(2, node.input_size());
3488 EXPECT_EQ("x", node.input(0));
3489 EXPECT_EQ("ConstantFolding/acc5_partial_split_2", node.input(1));
3490 }
3491 if (node.name() == "acc6") {
3492 EXPECT_EQ(op_name, node.op());
3493 EXPECT_EQ(3, node.input_size());
3494 EXPECT_EQ("x", node.input(0));
3495 EXPECT_EQ("ConstantFolding/acc6_partial_split_2", node.input(1));
3496 EXPECT_EQ("y", node.input(2));
3497 }
3498 if (absl::StartsWith(node.name(), "ConstantFolding/")) {
3499 EXPECT_EQ("Const", node.op());
3500 }
3501 }
3502
3503 std::vector<string> fetch = {"acc0"};
3504 auto tensors_expected = EvaluateNodes(item.graph, fetch);
3505 auto tensors = EvaluateNodes(output, fetch);
3506 EXPECT_EQ(1, tensors_expected.size());
3507 EXPECT_EQ(1, tensors.size());
3508 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
3509 }
3510 }
3511
3512 TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
3513 Scope s = Scope::NewRootScope();
3514 Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
3515 ops::Placeholder::Shape(TensorShape({2, 2})));
3516 Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
3517 ops::Placeholder::Shape(TensorShape({2, 2})));
3518 Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
3519 ops::Placeholder::Shape(TensorShape({2, 2})));
3520 Output axis = ops::Const(s.WithOpName("axis"), 0, {});
3521 Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
3522 Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
3523 Output concat0 = ops::Concat(s.WithOpName("concat0"), {c1, c2, c1}, axis);
3524 Output concat1 = ops::Concat(s.WithOpName("concat1"), {x, y, z}, axis);
3525 Output concat2 = ops::Concat(s.WithOpName("concat2"), {c1, x, y}, axis);
3526 Output concat3 = ops::Concat(s.WithOpName("concat3"), {c1, c2, z}, axis);
3527 Output concat4 = ops::Concat(s.WithOpName("concat4"), {c1, y, c2}, axis);
3528 Output concat5 = ops::Concat(s.WithOpName("concat5"), {x, c1, c2}, axis);
3529 Output concat6 = ops::Concat(s.WithOpName("concat6"), {x, c1, y, c2}, axis);
3530 Output concat7 = ops::Concat(s.WithOpName("concat7"), {x, y, c1, c2}, axis);
3531 Output concat8 = ops::Concat(s.WithOpName("concat8"), {x, c1, c2, y}, axis);
3532 Output concat9 = ops::Concat(s.WithOpName("concat9"), {c1, c2, x, y}, axis);
3533
3534 GrapplerItem item;
3535 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3536 item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4",
3537 "concat5", "concat6", "concat7", "concat8", "concat9"};
3538
3539 auto tensors_expected = EvaluateNodes(item.graph, {"concat0"});
3540 EXPECT_EQ(1, tensors_expected.size());
3541 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3542 GraphDef output;
3543 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3544 TF_EXPECT_OK(status);
3545 // Run the optimizer twice to make sure the rewrite is idempotent.
3546 item.graph.Swap(&output);
3547 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3548 TF_EXPECT_OK(status);
3549
3550 EXPECT_EQ(21, output.node_size());
3551 for (int i = 0; i < output.node_size(); ++i) {
3552 const NodeDef& node = output.node(i);
3553 if (node.name() == "concat0") {
3554 EXPECT_EQ("Const", node.op());
3555 } else if (node.name() == "concat3") {
3556 EXPECT_EQ(3, node.input_size());
3557 EXPECT_EQ("ConstantFolding/concat3_partial_split_0", node.input(0));
3558 EXPECT_EQ("z", node.input(1));
3559 EXPECT_EQ("axis", node.input(2));
3560 } else if (node.name() == "concat5") {
3561 EXPECT_EQ(3, node.input_size());
3562 EXPECT_EQ("x", node.input(0));
3563 EXPECT_EQ("ConstantFolding/concat5_partial_split_1", node.input(1));
3564 EXPECT_EQ("axis", node.input(2));
3565 } else if (node.name() == "concat7") {
3566 EXPECT_EQ(4, node.input_size());
3567 EXPECT_EQ("x", node.input(0));
3568 EXPECT_EQ("y", node.input(1));
3569 EXPECT_EQ("ConstantFolding/concat7_partial_split_2", node.input(2));
3570 EXPECT_EQ("axis", node.input(3));
3571 } else if (node.name() == "concat8") {
3572 EXPECT_EQ(4, node.input_size());
3573 EXPECT_EQ("x", node.input(0));
3574 EXPECT_EQ("ConstantFolding/concat8_partial_split_1", node.input(1));
3575 EXPECT_EQ("y", node.input(2));
3576 EXPECT_EQ("axis", node.input(3));
3577 } else if (node.name() == "concat9") {
3578 EXPECT_EQ(4, node.input_size());
3579 EXPECT_EQ("ConstantFolding/concat9_partial_split_0", node.input(0));
3580 EXPECT_EQ("x", node.input(1));
3581 EXPECT_EQ("y", node.input(2));
3582 EXPECT_EQ("axis", node.input(3));
3583 } else if (absl::StartsWith(node.name(), "ConstantFolding/")) {
3584 EXPECT_EQ("Const", node.op());
3585 } else {
3586 EXPECT_EQ(item.graph.node(i).DebugString(), node.DebugString());
3587 }
3588 }
3589
3590 auto tensors = EvaluateNodes(output, {"concat0"});
3591 EXPECT_EQ(1, tensors.size());
3592 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
3593 }
3594
3595 TEST_F(ConstantFoldingTest, PartialFolding_IdentityN) {
3596 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3597 Output x = ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
3598 ops::Placeholder::Shape(TensorShape({})));
3599 Output c1 = ops::Const(scope.WithOpName("c1"), 1.0f, {2, 2});
3600 Output c2 = ops::Const(scope.WithOpName("c2"), 2.0f, {2, 2});
3601 auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {c1, x, c2});
3602 auto id0 = ops::Identity(scope.WithOpName("id0"), id_n[0]);
3603 auto id1 = ops::Identity(scope.WithOpName("id1"), id_n[1]);
3604 auto add0 = ops::Add(scope.WithOpName("add0"), id_n[0], id_n[1]);
3605 auto add1 = ops::Add(scope.WithOpName("add1"), id_n[0], id_n[2]);
3606
3607 GrapplerItem item;
3608 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
3609 item.fetch.push_back("id0");
3610 item.fetch.push_back("id1");
3611 item.fetch.push_back("add0");
3612 item.fetch.push_back("add1");
3613
3614 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3615 GraphDef output;
3616 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3617 TF_EXPECT_OK(status);
3618 EXPECT_EQ(8, output.node_size());
3619 for (const auto& node : output.node()) {
3620 // id_n should remain unchanged.
3621 if (node.name() == "id_n") {
3622 EXPECT_EQ(3, node.input_size());
3623 EXPECT_EQ("c1", node.input(0));
3624 EXPECT_EQ("x", node.input(1));
3625 EXPECT_EQ("c2", node.input(2));
3626 }
3627 // id0 should be constant folded and gain a control dependency on id_n.
3628 if (node.name() == "id0") {
3629 EXPECT_EQ("Const", node.op());
3630 EXPECT_EQ(1, node.input_size());
3631 EXPECT_EQ("^id_n", node.input(0));
3632 }
3633 // id1 is unchanged.
3634 if ("id1" == node.name()) {
3635 EXPECT_EQ(1, node.input_size());
3636 EXPECT_EQ("id_n:1", node.input(0));
3637 }
3638
3639 if ("add0" == node.name()) {
3640 EXPECT_EQ(2, node.input_size());
3641 EXPECT_EQ("c1", node.input(0));
3642 EXPECT_EQ("id_n:1", node.input(1));
3643 }
3644 // add1 should be constant folded and have a control dependency on id_n.
3645 if ("add1" == node.name()) {
3646 EXPECT_EQ("Const", node.op());
3647 EXPECT_EQ(1, node.input_size());
3648 EXPECT_EQ("^id_n", node.input(0));
3649 }
3650 }
3651
3652 auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
3653 auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
3654 EXPECT_EQ(4, tensors_expected.size());
3655 auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
3656 EXPECT_EQ(4, tensors.size());
3657 for (int i = 0; i < tensors.size(); i++) {
3658 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
3659 }
3660 }
3661
3662 TEST_F(ConstantFoldingTest, TrivialPack) {
3663 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3664 Output x =
3665 ops::RandomNormal(scope.WithOpName("x"), {2, 2}, DataType::DT_FLOAT);
3666 Output y = ops::Const(scope.WithOpName("y"), {2.0f}, {});
3667 auto stack =
3668 ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x},
3669 ops::Stack::Axis(1));
3670 auto stack_no_axis = ops::Stack(scope.WithOpName("stack_no_axis"), {x});
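// Stacking a single tensor just inserts a unit dimension, so both Stack nodes
// above are expected to become ExpandDims with a materialized constant axis
// node (checked below).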
3671
3672 GrapplerItem item;
3673 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
3674 item.fetch = {"stack", "stack_no_axis"};
3675
3676 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3677 GraphDef output;
3678 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3679 TF_EXPECT_OK(status);
3680 EXPECT_EQ(7, output.node_size());
3681 int found = 0;
3682 for (const auto& node : output.node()) {
3683 if (node.name() == "stack") {
3684 EXPECT_EQ("ExpandDims", node.op());
3685 EXPECT_EQ(3, node.input_size());
3686 EXPECT_EQ("x", node.input(0));
3687 EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1));
3688 EXPECT_EQ("^y", node.input(2));
3689 ++found;
3690 } else if (node.name() == "stack_no_axis") {
3691 EXPECT_EQ("ExpandDims", node.op());
3692 EXPECT_EQ(2, node.input_size());
3693 EXPECT_EQ("x", node.input(0));
3694 EXPECT_EQ("ConstantFolding/stack_no_axis_const_axis", node.input(1));
3695 ++found;
3696 } else if (node.name() == "ConstantFolding/stack_const_axis") {
3697 EXPECT_EQ("Const", node.op());
3698 EXPECT_EQ(1, node.input_size());
3699 EXPECT_EQ("^x", node.input(0));
3700 ++found;
3701 }
3702 }
3703 EXPECT_EQ(found, 3);
3704
3705 std::vector<string> fetch = {"stack", "stack_no_axis"};
3706 auto tensors_expected = EvaluateNodes(item.graph, fetch);
3707 auto tensors = EvaluateNodes(output, fetch);
3708 EXPECT_EQ(2, tensors_expected.size());
3709 EXPECT_EQ(2, tensors.size());
3710 EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
3711 EXPECT_EQ(tensors_expected[1].shape(), tensors[1].shape());
3712 }
3713
3714 // The test does not evaluate the optimized and original graphs to check
3715 // whether their outputs are the same. See b/78233179.
3716 TEST_F(ConstantFoldingTest, Enter) {
3717 GrapplerItem item;
3718 AttrValue frame_name;
3719 frame_name.set_s("foo");
3720 AttrValue is_constant_true;
3721 is_constant_true.set_b(true);
3722 AttrValue is_constant_false;
3723 is_constant_false.set_b(false);
3724 AttrValue type;
3725 type.set_type(DT_FLOAT);
3726 AttrValue value;
3727 Tensor value_tensor(DT_FLOAT, TensorShape({}));
3728 value_tensor.flat<float>()(0) = 1;
3729 value_tensor.AsProtoTensorContent(value.mutable_tensor());
3730
3731 GraphDef& graph = item.graph;
3732 AddNode("x", "Placeholder", {}, {{"dtype", type}}, &graph);
3733 AddNode("c1", "Const", {"^x"}, {{"value", value}, {"dtype", type}}, &graph);
3734 AddNode("enter1", "Enter", {"x"},
3735 {{"T", type},
3736 {"frame_name", frame_name},
3737 {"is_constant", is_constant_true}},
3738 &graph);
3739 AddNode("enter2", "Enter", {"c1"},
3740 {{"T", type},
3741 {"frame_name", frame_name},
3742 {"is_constant", is_constant_true}},
3743 &graph);
3744 AddNode("enter3", "Enter", {"c1"},
3745 {{"T", type},
3746 {"frame_name", frame_name},
3747 {"is_constant", is_constant_false}},
3748 &graph);
3749 AddNode("id1", "Identity", {"enter1"}, {{"T", type}}, &graph);
3750 AddNode("id2", "Identity", {"enter2"}, {{"T", type}}, &graph);
3751 AddNode("id3", "Identity", {"enter2"}, {{"T", type}}, &graph);
3752 AddNode("id4", "Identity", {"enter3"}, {{"T", type}}, &graph);
3753 item.fetch.push_back("id1");
3754 item.fetch.push_back("id2");
3755 item.fetch.push_back("id3");
3756 item.fetch.push_back("id4");
3757
3758 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3759 GraphDef output;
3760 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3761 TF_EXPECT_OK(status);
3762 // Run the optimizer twice to make sure the rewrite is idempotent.
3763 item.graph.Swap(&output);
3764 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3765 TF_EXPECT_OK(status);
3766
3767 EXPECT_EQ(9, output.node_size());
3768 for (const NodeDef& node : output.node()) {
3769 if (node.name() == "id1") {
3770 EXPECT_EQ("Identity", node.op());
3771 EXPECT_EQ(1, node.input_size());
3772 EXPECT_EQ("enter1", node.input(0));
3773 }
3774 if (node.name() == "id2" || node.name() == "id3") {
3775 EXPECT_EQ("Const", node.op());
3776 EXPECT_EQ(1, node.input_size());
3777 EXPECT_EQ("^enter2", node.input(0));
3778 }
3779 if (node.name() == "id4") {
3780 EXPECT_EQ("Identity", node.op());
3781 EXPECT_EQ(1, node.input_size());
3782 EXPECT_EQ("enter3", node.input(0));
3783 }
3784 }
3785 }
3786
3787 TEST_F(ConstantFoldingTest, TensorArraySize) {
3788 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3789 Output size = ops::Const(scope.WithOpName("size"), 5, TensorShape({}));
3790 Output placeholder =
3791 ops::Placeholder(scope.WithOpName("placeholder"), DT_RESOURCE,
3792 ops::Placeholder::Shape(TensorShape({2})));
3793 Output foo = ops::Const(scope.WithOpName("foo"), 5.0f, TensorShape({}));
3794 auto dynamic_array =
3795 ops::TensorArray(scope.WithOpName("dynamic"), size, DT_FLOAT,
3796 ops::TensorArray::DynamicSize(true));
3797 auto static_array =
3798 ops::TensorArray(scope.WithOpName("static"), size, DT_FLOAT,
3799 ops::TensorArray::DynamicSize(false));
3800 auto dynamic_sz = ops::TensorArraySize(
3801 scope.WithOpName("dynamic_sz"), dynamic_array.handle, dynamic_array.flow);
3802 auto static_sz = ops::TensorArraySize(scope.WithOpName("static_sz"),
3803 static_array.handle, static_array.flow);
3804 auto placeholder_sz = ops::TensorArraySize(scope.WithOpName("placeholder_sz"),
3805 placeholder, foo);
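// Only the statically sized TensorArray has a size that is known at graph
// construction time, so only static_sz should fold to a Const; the
// dynamically sized array and the opaque placeholder handle must keep their
// TensorArraySizeV3 ops (see the op checks below).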
3806
3807 GrapplerItem item;
3808 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
3809
3810 auto tensors_expected =
3811 EvaluateNodes(item.graph, {"dynamic_sz", "static_sz"});
3812
3813 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3814 GraphDef output;
3815 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3816 TF_EXPECT_OK(status);
3817 // Run the optimizer twice to make sure the rewrite is idempotent.
3818 item.graph.Swap(&output);
3819 status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3820 TF_EXPECT_OK(status);
3821
3822 EXPECT_EQ(8, output.node_size());
3823 EXPECT_EQ("dynamic_sz", output.node(5).name());
3824 EXPECT_EQ("TensorArraySizeV3", output.node(5).op());
3825 EXPECT_EQ("static_sz", output.node(6).name());
3826 EXPECT_EQ("Const", output.node(6).op());
3827 EXPECT_EQ("placeholder_sz", output.node(7).name());
3828 EXPECT_EQ("TensorArraySizeV3", output.node(7).op());
3829
3830 auto tensors_actual = EvaluateNodes(output, {"dynamic_sz", "static_sz"});
3831 EXPECT_EQ(2, tensors_expected.size());
3832 EXPECT_EQ(2, tensors_actual.size());
3833 test::ExpectTensorEqual<int32>(tensors_expected[0], tensors_actual[0]);
3834 test::ExpectTensorEqual<int32>(tensors_expected[1], tensors_actual[1]);
3835 }
3836
3837 TEST_F(ConstantFoldingTest, FoldingPreservesDenormalFlushing) {
3838 // Multiplying min() by 0.1 gives a denormal without FTZ and zero with FTZ.
3839 // Make sure constant folding behaves the same way as TensorFlow.
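// Concretely (a sketch, not executed by the test):
//   float p = std::numeric_limits<float>::min() * 0.1f;
//   // Without flush-to-zero: p ~= 1.18e-39f, a nonzero subnormal.
//   // With flush-to-zero:    p == 0.0f.
// The folded constant must match whichever behavior the runtime exhibits.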
3840 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3841
3842 Output a =
3843 ops::Const(s.WithOpName("a"), std::numeric_limits<float>::min(), {1});
3844 Output b = ops::Const(s.WithOpName("b"), 0.1f, {1});
3845 Output c = ops::Mul(s.WithOpName("c"), a, b);
3846
3847 GrapplerItem item;
3848 item.fetch.push_back("c");
3849 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3850
3851 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3852 GraphDef output;
3853 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3854 TF_EXPECT_OK(status);
3855
3856 EXPECT_EQ(1, output.node_size());
3857
3858 const NodeDef& node_d = output.node(0);
3859 EXPECT_EQ("c", node_d.name());
3860 EXPECT_EQ("Const", node_d.op());
3861
3862 std::vector<string> fetch = {"c"};
3863 auto tensors_expected = EvaluateNodes(item.graph, fetch);
3864 auto tensors = EvaluateNodes(output, fetch);
3865 EXPECT_EQ(1, tensors_expected.size());
3866 EXPECT_EQ(1, tensors.size());
3867 test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
3868 }
3869
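// A minimal, self-contained sketch of the folding contract exercised above
// (the test name and graph here are illustrative additions, not drawn from
// the original suite): adding two scalar constants should collapse to a
// single Const node that evaluates to the same value as the original graph.
TEST_F(ConstantFoldingTest, FoldsScalarAddSketch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 2.0f, {});
  Output b = ops::Const(s.WithOpName("b"), 3.0f, {});
  Output c = ops::Add(s.WithOpName("c"), a, b);

  GrapplerItem item;
  item.fetch.push_back("c");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Only the fetched node should remain, and it should now be a Const.
  ASSERT_EQ(output.node_size(), 1);
  EXPECT_EQ(output.node(0).name(), "c");
  EXPECT_EQ(output.node(0).op(), "Const");

  // The folded graph must agree with the unoptimized one.
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
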
3870 TEST_F(ConstantFoldingTest, EvaluatingLargeConstantNoFoldingMergingLoop) {
3871 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3872
3873 int size = 10 * 1024 * 1024 / 4 / 2;
3874 Output nonconst =
3875 ops::RandomUniform(s.WithOpName("nonconst"), {size, 1}, DT_FLOAT);
3876 Output const1 = ops::Const(s.WithOpName("const1"), 0.0f, {size, 1});
3877 Output const2 = ops::Const(s.WithOpName("const2"), 1.0f, {size, 1});
3878 Output axis = ops::Const(s.WithOpName("axis"), -1, {});
3879 Output concat1 =
3880 ops::Concat(s.WithOpName("concat1"), {nonconst, const1}, axis);
3881 Output result = ops::Concat(s.WithOpName("result"), {concat1, const2}, axis);
3882
3883 GrapplerItem item;
3884 item.fetch.push_back("result");
3885 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3886
3887 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3888 GraphDef output;
3889 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3890 TF_EXPECT_OK(status);
3891
3892 std::vector<string> fetch = {"result"};
3893 auto tensors_expected = EvaluateNodes(item.graph, fetch);
3894 auto tensors = EvaluateNodes(output, fetch);
3895 EXPECT_EQ(1, tensors_expected.size());
3896 EXPECT_EQ(1, tensors.size());
3897 EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
3898 }
3899
3900 class ConstantFoldingCastConstTest : public GrapplerTest {
3901 protected:
3902 void ConstantFoldingCastConst(bool fetch_const, bool fetch_cast,
3903 bool fetch_const_child, bool fetch_cast_child) {
3904 if (!fetch_const && !fetch_cast && !fetch_const_child &&
3905 !fetch_cast_child) {
3906 return;
3907 }
3908
3909 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3910 CreateCastConstGraph(s);
3911 GrapplerItem item;
3912 int expected_output_size = SetFetch(&item, fetch_const, fetch_cast,
3913 fetch_const_child, fetch_cast_child);
3914 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3915
3916 GraphDef output = ConstantFoldingOptimize(item);
3917 EXPECT_EQ(expected_output_size, output.node_size());
3918
3919 EvaluateAndCompareUnoptimized(item.graph, output, item.fetch);
3920 }
3921
3922 private:
3923 void CreateCastConstGraph(const tensorflow::Scope& s) {
3924 Output const1 = ops::Const(s.WithOpName("const1"), 2, {5, 5});
3925 Output cast = ops::Cast(s.WithOpName("cast"), const1, DT_FLOAT);
3926 Output const1_child = ops::Identity(s.WithOpName("const1_child"), const1);
3927 Output cast_child = ops::Identity(s.WithOpName("cast_child"), cast);
3928 }
3929
3930 int SetFetch(GrapplerItem* item, bool fetch_const, bool fetch_cast,
3931 bool fetch_const_child, bool fetch_cast_child) {
3932 int expected_output_size = 0;
3933 if (fetch_const) {
3934 item->fetch.push_back("const1");
3935 expected_output_size++;
3936 }
3937 if (fetch_cast) {
3938 item->fetch.push_back("cast");
3939 expected_output_size++;
3940 }
3941 if (fetch_const_child) {
3942 item->fetch.push_back("const1_child");
3943 expected_output_size++;
3944 }
3945 if (fetch_cast_child) {
3946 item->fetch.push_back("cast_child");
3947 expected_output_size++;
3948 }
3949 return expected_output_size;
3950 }
3951
3952 GraphDef ConstantFoldingOptimize(const GrapplerItem& item) {
3953 ConstantFolding optimizer(/*cpu_device=*/nullptr);
3954 GraphDef output;
3955 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3956 TF_EXPECT_OK(status);
3957 return output;
3958 }
3959
3960 void EvaluateAndCompareUnoptimized(const GraphDef& unoptimized_graph,
3961 const GraphDef& optimized_graph,
3962 const std::vector<string>& fetch_nodes) {
3963 auto tensors_expected = EvaluateNodes(unoptimized_graph, fetch_nodes);
3964 auto tensors = EvaluateNodes(optimized_graph, fetch_nodes);
3965 ASSERT_EQ(fetch_nodes.size(), tensors_expected.size());
3966 ASSERT_EQ(fetch_nodes.size(), tensors.size());
3967 for (int i = 0; i < fetch_nodes.size(); i++) {
3968 if (fetch_nodes[i] == "const1" || fetch_nodes[i] == "const1_child") {
3969 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
3970 } else {
3971 test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
3972 }
3973 }
3974 }
3975 };
3976
3977 TEST_F(ConstantFoldingCastConstTest, CastConstFolding) {
3978 for (bool fetch_const : {false, true}) {
3979 for (bool fetch_cast : {false, true}) {
3980 for (bool fetch_const_child : {false, true}) {
3981 for (bool fetch_cast_child : {false, true}) {
3982 ConstantFoldingCastConst(fetch_const, fetch_cast, fetch_const_child,
3983 fetch_cast_child);
3984 }
3985 }
3986 }
3987 }
3988 }
3989
3990 TEST_F(ConstantFoldingTest, MaterializeConstantValuedNode) {
3991 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3992
3993 Output x =
3994 ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
3995 ops::Placeholder::Shape(TensorShape({1, 2, 3, 4})));
3996 Output ones_like = ops::OnesLike(scope.WithOpName("ones_like"), x);
3997 Output zeros_like = ops::ZerosLike(scope.WithOpName("zeros_like"), x);
3998 Output fill = ops::Fill(scope.WithOpName("fill"), {4, 3, 2, 1}, 42);
3999
4000 GrapplerItem item;
4001 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
4002 item.fetch = {"ones_like", "zeros_like", "fill"};
4003 auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 3, 4}));
4004 auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
4005
4006 ConstantFolding optimizer(/*cpu_device=*/nullptr);
4007 GraphDef output;
4008 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
4009 TF_EXPECT_OK(status);
4010
4011 EXPECT_EQ(output.node_size(), 6);
4012 for (const auto& node : output.node()) {
4013 if (node.name() != "x") {
4014 EXPECT_EQ(node.op(), "Const");
4015 }
4016 if (node.name() == "ones_like" || node.name() == "zeros_like") {
4017 ASSERT_EQ(node.input_size(), 1);
4018 EXPECT_EQ(node.input(0), "^x");
4019 }
4020 if (node.name() == "fill") {
4021 ASSERT_EQ(node.input_size(), 2);
4022 EXPECT_EQ(node.input(0)[0], '^');
4023 EXPECT_EQ(node.input(1)[0], '^');
4024 }
4025 }
4026 auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
4027 ASSERT_EQ(item.fetch.size(), tensors.size());
4028 ASSERT_EQ(tensors_expected.size(), tensors.size());
4029 for (int i = 0; i < tensors.size(); i++) {
4030 if (item.fetch[i] == "fill") {
4031 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
4032 } else {
4033 test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
4034 }
4035 }
4036 }
4037
4038 TEST_F(ConstantFoldingTest, MaterializeConstantValuedNodeDisableCompression) {
4039 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
4040
4041 Output x =
4042 ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
4043 ops::Placeholder::Shape(TensorShape({1, 2, 3, 4})));
4044 Output ones_like = ops::OnesLike(scope.WithOpName("ones_like"), x);
4045 Output zeros_like = ops::ZerosLike(scope.WithOpName("zeros_like"), x);
4046 Output fill = ops::Fill(scope.WithOpName("fill"), {4, 3, 2, 1}, 42);
4047
4048 GrapplerItem item;
4049 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
4050 item.fetch = {"ones_like", "zeros_like", "fill"};
4051 auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 3, 4}));
4052 auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
4053
4054 ConstantFolding optimizer(/*cpu_device=*/nullptr, true);
4055 GraphDef output;
4056 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
4057 TF_EXPECT_OK(status);
4058
4059 EXPECT_EQ(output.node_size(), 6);
4060 for (const auto& node : output.node()) {
4061 if (node.name() == "ones_like") {
4062 EXPECT_EQ(node.op(), "OnesLike");
4063 ASSERT_EQ(node.input_size(), 1);
4064 EXPECT_EQ(node.input(0), "x");
4065 }
4066 if (node.name() == "zeros_like") {
4067 EXPECT_EQ(node.op(), "ZerosLike");
4068 ASSERT_EQ(node.input_size(), 1);
4069 EXPECT_EQ(node.input(0), "x");
4070 }
4071 if (node.name() == "fill") {
4072 EXPECT_EQ(node.op(), "Fill");
4073 ASSERT_EQ(node.input_size(), 2);
4074 EXPECT_EQ(node.input(0), "Const/Const");
4075 EXPECT_EQ(node.input(1), "Const_1/Const");
4076 }
4077 }
4078 auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
4079 ASSERT_EQ(item.fetch.size(), tensors.size());
4080 ASSERT_EQ(tensors_expected.size(), tensors.size());
4081 for (int i = 0; i < tensors.size(); i++) {
4082 if (item.fetch[i] == "fill") {
4083 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
4084 } else {
4085 test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
4086 }
4087 }
4088 }
4089
4090 TEST_F(ConstantFoldingTest, MaterializeConstantValuedNodeHugeFill) {
4091 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
4092 Output value = ops::Const(scope.WithOpName("value"), 42, {});
4093 Output shape_const = ops::Const(scope.WithOpName("shape"),
4094 {1024, 1024, 1024, 1024, 1024}, {5});
4095 Output fill_huge =
4096 ops::Fill(scope.WithOpName("fill_huge"), shape_const, value);
4097
4098 GrapplerItem item;
4099 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
4100 // Manually convert the input value format to tensor_content to test this
4101 // case.
4102 NodeDef* node = item.graph.mutable_node(0);
4103 ASSERT_EQ(node->name(), "value");
4104 TensorProto* t = (*node->mutable_attr())["value"].mutable_tensor();
4105 t->clear_int_val();
4106 int val = 42;
4107 port::CopyFromArray(t->mutable_tensor_content(),
4108 reinterpret_cast<const char*>(&val), sizeof(int));
4109 item.fetch = {"fill_huge"};
4110 ConstantFolding optimizer(/*cpu_device=*/nullptr);
4111 GraphDef output;
4112 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
4113 TF_EXPECT_OK(status);
4114
4115 EXPECT_EQ(output.node_size(), 3);
4116 for (const auto& node : output.node()) {
4117 EXPECT_EQ(node.op(), "Const");
4118 if (node.name() == "fill_huge") {
4119 ASSERT_EQ(node.input_size(), 2);
4120 EXPECT_EQ(node.input(0), "^shape");
4121 EXPECT_EQ(node.input(1), "^value");
4122 }
4123 }
4124 }
4125
4126 TEST_F(ConstantFoldingTest, BitcastDenormalFloats) {
4127 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
4128
4129 Tensor x_t(DT_INT64, TensorShape({2, 2}));
4130 x_t.flat<int64_t>()(0) = 9223372036854775807L;
4131 x_t.flat<int64_t>()(1) = 1L;
4132 x_t.flat<int64_t>()(2) = 9223372036854775807L;
4133 x_t.flat<int64_t>()(3) = 1L;
4134 Output x = ops::Const(scope.WithOpName("x"), x_t);
4135 Output y = ops::Bitcast(scope.WithOpName("y"), x, DT_FLOAT);
4136 Output z = ops::Bitcast(scope.WithOpName("z"), y, DT_INT64);
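// Why these bit patterns (assuming the usual little-endian layout): each
// int64 reinterprets as two floats, and 0x7FFFFFFFFFFFFFFF yields the NaN
// payloads 0xFFFFFFFF and 0x7FFFFFFF, while 1L yields the subnormal
// 0x00000001 next to +0. Folding the Bitcast pair must round-trip these bits
// exactly; flushing subnormals or canonicalizing NaNs would change z.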
4137
4138 GrapplerItem item;
4139 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
4140 item.fetch = {"z"};
4141 auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});
4142
4143 ConstantFolding optimizer(/*cpu_device=*/nullptr);
4144 GraphDef output;
4145 Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
4146 TF_EXPECT_OK(status);
4147
4148 ASSERT_EQ(output.node_size(), 1);
4149 const NodeDef& node = output.node(0);
4150 EXPECT_EQ(node.name(), "z");
4151 EXPECT_EQ(node.op(), "Const");
4152
4153 auto tensors = EvaluateNodes(output, item.fetch, {});
4154 ASSERT_EQ(tensors.size(), 1);
4155 ASSERT_EQ(tensors_expected.size(), 1);
4156 test::ExpectTensorEqual<int64_t>(tensors[0], tensors_expected[0]);
4157 }
4158
TEST_F(ConstantFoldingTest, SimplifyCase) {
  using test::function::NDef;

  for (int index = 0; index < 2; ++index) {
    // Build a graph to compute y = Case(index, x, XTimesTwo(x), NonZero(x)).
    GrapplerItem item;
    constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
    AttrValue branches;
    auto* f = branches.mutable_list()->add_func();
    f->set_name("XTimesTwo");
    (*f->mutable_attr())["T"].set_type(DT_FLOAT);
    auto* g = branches.mutable_list()->add_func();
    *g = *f;
    g->set_name("NonZero");

    // Add a pair of somewhat arbitrary output shapes to test that they are
    // correctly propagated to the _output_shapes attribute.
    AttrValue output_shapes;
    // The first shape is a scalar.
    output_shapes.mutable_list()->add_shape();
    // The second shape is unknown.
    TensorShapeProto* g_shape = output_shapes.mutable_list()->add_shape();
    g_shape->set_unknown_rank(true);

    const Tensor kZero = test::AsScalar<int32>(0);
    const Tensor kOne = test::AsScalar<int32>(1);
    item.graph = test::function::GDef(
        {NDef("one", "Const", {},
              {{"value", index == 0 ? kZero : kOne}, {"dtype", DT_INT32}},
              kDevice),
         NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
         NDef("case", "Case", {"one", "x"},
              {{"Tin", DataTypeSlice{DT_FLOAT}},
               {"Tout", DataTypeSlice{DT_FLOAT}},
               {"branches", branches},
               {"output_shapes", output_shapes}},
              kDevice),
         NDef("y", "Identity", {"case"}, {{"T", DT_FLOAT}}, kDevice)},
        // FunctionLib
        {
            test::function::XTimesTwo(),
            test::function::NonZero(),
        });
    VLOG(1) << "Before: " << item.graph.DebugString();

    item.fetch = {"y"};
    const Tensor kTwo = test::AsScalar<float>(2.0f);
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}});

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef optimized_graph;
    TF_ASSERT_OK(
        optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
    VLOG(1) << "After: " << optimized_graph.DebugString();

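    // With a constant branch index, the Case should be lowered to a direct
    // PartitionedCall of the selected branch, and only that branch's entry
    // of _output_shapes should be kept.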
    int pco_count = 0;
    for (const auto& node : optimized_graph.node()) {
      EXPECT_NE(node.op(), "Case");
      if (node.op() == "PartitionedCall") {
        ++pco_count;
        const auto& shape_list = node.attr().at("_output_shapes").list();
        ASSERT_EQ(shape_list.shape_size(), 1);
        EXPECT_EQ(shape_list.shape(0).dim_size(), 0);
        if (index == 0) {
          EXPECT_EQ(node.attr().at("f").func().name(), "XTimesTwo");
          EXPECT_EQ(shape_list.shape(0).unknown_rank(), false);
        } else {
          EXPECT_EQ(node.attr().at("f").func().name(), "NonZero");
          EXPECT_EQ(shape_list.shape(0).unknown_rank(), true);
        }
      }
    }
    EXPECT_EQ(pco_count, 1);

    auto tensors = EvaluateNodes(optimized_graph, item.fetch, {{"x", kTwo}});
    ASSERT_EQ(tensors.size(), tensors_expected.size());
    test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
  }
}

TEST_F(ConstantFoldingTest, SimplifySelect) {
  for (bool scalar_pred : {true, false}) {
    for (bool pred_val : {true, false}) {
      tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
      std::unique_ptr<Tensor> if_t;
      if (scalar_pred) {
        if_t = std::make_unique<Tensor>(DT_BOOL, TensorShape());
      } else {
        if_t = std::make_unique<Tensor>(DT_BOOL, TensorShape({2, 2}));
      }
      for (int i = 0; i < (scalar_pred ? 1 : 4); ++i) {
        if_t->flat<bool>()(i) = pred_val;
      }
      Output if_ = ops::Const(scope.WithOpName("if"), *if_t);
      Output then_ =
          ops::Placeholder(scope.WithOpName("then"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 2})));
      Output else_ =
          ops::Placeholder(scope.WithOpName("else"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 2})));
      Output select =
          ops::SelectV2(scope.WithOpName("select"), if_, then_, else_);
      Output id = ops::Identity(scope.WithOpName("id"), select);

      GrapplerItem item;
      TF_CHECK_OK(scope.ToGraphDef(&item.graph));
      item.fetch = {"id"};

      const Tensor kOne =
          test::AsTensor<float>({1.0f, 1.0f, 1.0f, 1.0f}, TensorShape({2, 2}));
      const Tensor kTwo =
          test::AsTensor<float>({2.0f, 2.0f, 2.0f, 2.0f}, TensorShape({2, 2}));
      auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
                                            {{"then", kOne}, {"else", kTwo}});

      // Use aggressive mode to force the shape inference to propagate
      // placeholder shapes.
      ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
                                /*cpu_device=*/nullptr);
      GraphDef optimized_graph;
      TF_EXPECT_OK(
          optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));

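      // With a constant all-true or all-false predicate, SelectV2 should
      // collapse to an Identity of the taken branch; the predicate and the
      // untaken branch survive only as control dependencies.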
      ASSERT_EQ(optimized_graph.node_size(), 5);
      bool found = false;
      for (const auto& node : optimized_graph.node()) {
        if (node.name() == "select") {
          found = true;
          EXPECT_EQ(node.op(), "Identity");
          ASSERT_EQ(node.input_size(), 3);
          EXPECT_EQ(node.input(0), pred_val ? "then" : "else");
          EXPECT_EQ(node.input(1), pred_val ? "^if" : "^then");
          EXPECT_EQ(node.input(2), pred_val ? "^else" : "^if");
        }
      }
      EXPECT_TRUE(found);

      auto tensors = EvaluateNodes(optimized_graph, item.fetch,
                                   {{"then", kOne}, {"else", kTwo}});
      ASSERT_EQ(tensors.size(), 1);
      ASSERT_EQ(tensors_expected.size(), 1);
      test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
    }
  }
}

TEST_F(ConstantFoldingTest, SimplifySelect_BroadcastTo) {
  for (TensorShape pred_shape : {TensorShape{2, 1}, TensorShape{2, 2, 1}}) {
    for (bool pred_val : {true, false}) {
      tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
      auto if_t = std::make_unique<Tensor>(DT_BOOL, pred_shape);
      for (int i = 0; i < pred_shape.num_elements(); ++i) {
        if_t->flat<bool>()(i) = pred_val;
      }
      Output if_ = ops::Const(scope.WithOpName("if"), *if_t);
      Output then_ =
          ops::Placeholder(scope.WithOpName("then"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 1})));
      Output else_ =
          ops::Placeholder(scope.WithOpName("else"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 4})));
      Output select =
          ops::SelectV2(scope.WithOpName("select"), if_, then_, else_);
      Output id = ops::Identity(scope.WithOpName("id"), select);

      GrapplerItem item;
      TF_CHECK_OK(scope.ToGraphDef(&item.graph));
      item.fetch = {"id"};

      const Tensor kOne =
          test::AsTensor<float>({1.0f, 1.0f}, TensorShape({2, 1}));
      const Tensor kTwo = test::AsTensor<float>(
          {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f},
          TensorShape({2, 4}));
      auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
                                            {{"then", kOne}, {"else", kTwo}});

      // Use aggressive mode to force the shape inference to propagate
      // placeholder shapes.
      ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
                                /*cpu_device=*/nullptr);
      GraphDef optimized_graph;
      TF_EXPECT_OK(
          optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));

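      // Here the taken branch's shape does not match the broadcast output
      // shape, so instead of a plain Identity the rewrite must emit a
      // BroadcastTo fed by a materialized shape constant.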
      ASSERT_EQ(optimized_graph.node_size(), 6);
      bool found = false;
      for (const auto& node : optimized_graph.node()) {
        if (node.name() == "select") {
          found = true;
          EXPECT_EQ(node.op(), "BroadcastTo");
          ASSERT_EQ(node.input_size(), 4);
          EXPECT_EQ(node.input(0), pred_val ? "then" : "else");
          EXPECT_EQ(node.input(1),
                    strings::StrCat("ConstantFolding/select-broadcastto_shape-",
                                    pred_val ? 1 : 2));
          EXPECT_EQ(node.input(2), pred_val ? "^else" : "^if");
          EXPECT_EQ(node.input(3), pred_val ? "^if" : "^then");
        }
      }
      EXPECT_TRUE(found);

      auto tensors = EvaluateNodes(optimized_graph, item.fetch,
                                   {{"then", kOne}, {"else", kTwo}});
      ASSERT_EQ(tensors.size(), 1);
      ASSERT_EQ(tensors_expected.size(), 1);
      ASSERT_EQ(tensors[0].shape(), pred_shape.num_elements() == 2
                                        ? TensorShape({2, 4})
                                        : TensorShape({2, 2, 4}));
      test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
    }
  }
}

TEST_F(ConstantFoldingTest, QuantizationEmulation) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(scope.WithOpName("x"), {0.0f, 1.0f, 2.0f, 3.0f}, {4});
  Output min_range = ops::Const(scope.WithOpName("min_range"), 0.0f, {});
  Output max_range = ops::Const(scope.WithOpName("max_range"), 3.0f, {});
  Output y = ops::QuantizeAndDequantizeV2(scope.WithOpName("y"), x, min_range,
                                          max_range);
  Output id = ops::Identity(scope.WithOpName("id"), y);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"id"};

  std::vector<Tensor> expected_tensors = EvaluateNodes(item.graph, item.fetch);

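  // The third constructor argument controls whether quantization-emulation
  // ops such as QuantizeAndDequantizeV2 may be constant-folded; both
  // settings are exercised below and must produce identical values.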
  for (const bool fold_quantization_emulation : {false, true}) {
    ConstantFolding optimizer(/*cpu_device=*/nullptr,
                              /*disable_compressed_tensor_optimization=*/false,
                              fold_quantization_emulation);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);
    int num_quantization_emulation_ops = 0;
    for (const NodeDef& node : output.node()) {
      if (node.op() == "QuantizeAndDequantizeV2") {
        num_quantization_emulation_ops++;
      }
    }
    EXPECT_EQ(fold_quantization_emulation ? 0 : 1,
              num_quantization_emulation_ops);

    std::vector<Tensor> actual_tensors = EvaluateNodes(output, item.fetch);
    for (int i = 0; i < item.fetch.size(); ++i) {
      test::ExpectTensorEqual<float>(expected_tensors[i], actual_tensors[i]);
    }
  }
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow