/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/constant_folding.h"

#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tensor_coding.h"

namespace tensorflow {
namespace grappler {
namespace {

class ConstantFoldingTest : public GrapplerTest {
 protected:
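  // Exercises neutral- and absorbing-element rewrites for one element type:
  // x*0 should fold to a constant, while x*1 and x+0 should become Identity
  // (or Snapshot when a ref input is present). For DT_BOOL the same pattern
  // is built from LogicalAnd/LogicalOr instead of Mul/Add.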
  template <DataType DTYPE>
  void SimpleNeutralElementTest() {
    for (bool use_snapshot : {false, true}) {
      typedef typename EnumToDataType<DTYPE>::Type T;
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
      Output x = ops::Placeholder(s.WithOpName("x"), DTYPE,
                                  ops::Placeholder::Shape(TensorShape({2, 2})));
      Output v = ops::Variable(s.WithOpName("v"), {2, 2}, DTYPE);
      Tensor zeros_t(DTYPE, TensorShape({2, 2}));
      Tensor ones_t(DTYPE, TensorShape({2, 2}));
      Tensor x_t(DTYPE, TensorShape({2, 2}));
      for (int i = 0; i < 4; ++i) {
        zeros_t.flat<T>()(i) = T(0);
        ones_t.flat<T>()(i) = T(1);
        x_t.flat<T>()(i) = T(i + 1);
      }
      Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
      Output ones = ops::Const(s.WithOpName("ones"), ones_t);
      Output mul1;
      Output mul2;
      Output add1;
      Output add2;
      if (DTYPE == DT_BOOL) {
        mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
        mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
        add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
        add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
      } else {
        if (DTYPE == DT_FLOAT) {
          mul1 = ops::MulNoNan(s.WithOpName("mul1"), x, zeros);
        } else {
          mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
        }
        mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
        add1 = ops::Add(s.WithOpName("add1"), x, zeros);
        add2 = ops::Add(s.WithOpName("add2"), x, ones);
      }
      if (use_snapshot) {
        // Add an op with ref input to prevent Snapshot from being
        // turned into Identity.
        ops::Assign(s.WithOpName("assign"), v, ones);
      }
      GrapplerItem item;
      TF_CHECK_OK(s.ToGraphDef(&item.graph));
      item.fetch = {"mul1", "mul2", "add1", "add2"};
      ConstantFolding optimizer(/*cpu_device=*/nullptr);
      GraphDef output;
      Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
      TF_EXPECT_OK(status);

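      // Expected rewrites: mul1 (x*0) folds to a Const, mul2 (x*1) and add1
      // (x+0) become Identity (or Snapshot), and add2 folds to a Const only
      // for bool (x||1 == 1); the replaced inputs survive as control
      // dependencies such as "^zeros".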
      EXPECT_EQ(7, output.node_size());
      const string snapshot_or_identity =
          use_snapshot ? "Snapshot" : "Identity";
      for (int i = 0; i < output.node_size(); ++i) {
        const NodeDef& node = output.node(i);
        const string& name = node.name();
        if (name == "mul1") {
          EXPECT_EQ("Const", node.op());
          EXPECT_EQ("^x", node.input(0));
          EXPECT_EQ("^zeros", node.input(1));
        } else if (name == "mul2") {
          EXPECT_EQ(snapshot_or_identity, node.op());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("^ones", node.input(1));
        } else if (name == "add1") {
          EXPECT_EQ(snapshot_or_identity, node.op());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("^zeros", node.input(1));
        } else if (name == "add2") {
          if (DTYPE == DT_BOOL) {
            EXPECT_EQ("Const", node.op());
            EXPECT_EQ("^x", node.input(0));
            EXPECT_EQ("^ones", node.input(1));
          } else {
            EXPECT_EQ("Add", node.op());
            EXPECT_EQ("x", node.input(0));
            EXPECT_EQ("ones", node.input(1));
          }
        }
      }
      auto tensors_expected =
          EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
      auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
      EXPECT_EQ(4, tensors_expected.size());
      EXPECT_EQ(4, tensors.size());
      for (int i = 0; i < item.fetch.size(); ++i) {
        test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
      }
    }
  }

  void MulConvPushDownTest(const TensorShape& input_shape,
                           const TensorShape& filter_shape,
                           const TensorShape& mul_const_input_shape,
                           const bool use_3d_conv, const char* padding,
                           const char* data_format, const bool expect_folded) {
    // Tests if the following rewrite is performed:
    //
    //         *                       Conv2D
    //        / \                       / \
    //       c  Conv2D        -->      x  (c * filter)
    //           / \
    //          x  filter
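    //
    // The push-down is only expected when c's shape can be merged into the
    // filter constant; callers use expect_folded to assert the negative
    // cases, where a shape mismatch must leave the graph unchanged.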
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    Tensor filter_values(DT_FLOAT, filter_shape);
    for (int i = 0; i < filter_values.NumElements(); ++i) {
      filter_values.flat<float>()(i) = std::sqrt(static_cast<float>(i));
    }
    Output filter =
        ops::Const(s.WithOpName("filter"), Input::Initializer(filter_values));

    Output input = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                    ops::Placeholder::Shape(input_shape));

    Output conv;
    if (use_3d_conv) {
      conv = ops::Conv3D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1, 1},
                         padding, ops::Conv3D::DataFormat(data_format));
    } else {
      conv = ops::Conv2D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1},
                         padding, ops::Conv2D::DataFormat(data_format));
    }
    Tensor mul_const_input(DT_FLOAT, mul_const_input_shape);
    for (int i = 0; i < mul_const_input.NumElements(); ++i) {
      mul_const_input.flat<float>()(i) = static_cast<float>(i + 3);
    }
    Output c =
        ops::Const(s.WithOpName("c"), Input::Initializer(mul_const_input));
    Output mul = ops::Mul(s.WithOpName("mul"), c, conv);

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    EXPECT_EQ(5, output.node_size());
    int found = 0;
    if (expect_folded) {
      for (const auto& node : output.node()) {
        if (node.name() == "mul") {
          found++;
          EXPECT_EQ(use_3d_conv ? "Conv3D" : "Conv2D", node.op());
          EXPECT_EQ(2, node.input_size());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("conv/merged_input", node.input(1));
        } else if (node.name() == "conv/merged_input") {
          found++;
          EXPECT_EQ("Const", node.op());
          EXPECT_EQ(0, node.input_size());
        }
      }
    } else {
      for (const auto& node : output.node()) {
        if (node.name() == "mul") {
          found++;
          EXPECT_EQ("Mul", node.op());
          EXPECT_EQ(2, node.input_size());
          EXPECT_EQ("c", node.input(0));
          EXPECT_EQ("conv", node.input(1));
        } else if (node.name() == "conv") {
          found++;
          EXPECT_EQ(use_3d_conv ? "Conv3D" : "Conv2D", node.op());
          EXPECT_EQ(2, node.input_size());
          EXPECT_EQ("x", node.input(0));
          EXPECT_EQ("filter", node.input(1));
        }
      }
    }
    EXPECT_EQ(2, found);

    // Check that the const-folded multiplication node has the expected value.
    std::vector<string> fetch = {"mul"};
    Tensor value(DT_FLOAT, input_shape);
    for (int i = 0; i < value.NumElements(); ++i) {
      value.flat<float>()(i) = i;
    }
    auto actual = EvaluateNodes(output, fetch, {{"x", value}});
    auto expected = EvaluateNodes(item.graph, fetch, {{"x", value}});
    test::ExpectTensorEqual<float>(expected[0], actual[0]);
  }

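  // PadV2 with all-zero paddings is a no-op: p1 should be rewritten to an
  // Identity of in1 (keeping its dead operands as control dependencies),
  // while p2, whose paddings are non-zero, must be left untouched.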
  template <typename T>
  void PaddingWithZeroSize() {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

    auto in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_INT32);
    auto in2 = ops::Variable(scope.WithOpName("in2"), {2, 2}, DT_INT32);
    auto paddings1 =
        ops::Const<T>(scope.WithOpName("paddings1"), {0, 0, 0, 0}, {2, 2});
    auto paddings2 =
        ops::Const<T>(scope.WithOpName("paddings2"), {1, 1, 2, 2}, {2, 2});
    auto c1 = ops::Const(scope.WithOpName("c1"), 1);
    auto c2 = ops::Const(scope.WithOpName("c2"), 1);

    ops::PadV2 p1(scope.WithOpName("p1"), in1, paddings1, c1);
    ops::PadV2 p2(scope.WithOpName("p2"), in2, paddings2, c2);

    ops::Add out(scope.WithOpName("out"), p1, p2);

    GrapplerItem item;
    item.fetch = {"out"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef got;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
    TF_EXPECT_OK(status);

    GraphDef want;
    AddNode("in1", "VariableV2", {}, {}, &want);
    AddNode("in2", "VariableV2", {}, {}, &want);
    AddNode("paddings1", "Const", {}, {}, &want);
    AddNode("paddings2", "Const", {}, {}, &want);
    AddNode("c1", "Const", {}, {}, &want);
    AddNode("c2", "Const", {}, {}, &want);
    AddNode(
        "p1", "Identity",
        {"in1", AsControlDependency("paddings1"), AsControlDependency("c1")},
        {}, &want);
    AddNode("p2", "PadV2", {"in2", "paddings2", "c2"}, {}, &want);
    AddNode("out", "Add", {"p1", "p2"}, {}, &want);

    CompareGraphs(want, got);

    auto in1_t = GenerateRandomTensor<DT_INT32>(TensorShape({4, 6}));
    auto in2_t = GenerateRandomTensor<DT_INT32>(TensorShape({2, 2}));
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors_expected.size());
    auto tensors =
        EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
  }
};

TEST_F(ConstantFoldingTest, SimpleFolding) {
  // Build a simple graph with a few trivially prunable ops.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 1.0f, {1});
  Output b = ops::Const(s.WithOpName("b"), 2.0f, {1});
  Output c = ops::AddN(s.WithOpName("c").WithDevice("/CPU:0"), {a, b});
  Output d = ops::AddN(s.WithOpName("d"), {b, c});

  GrapplerItem item;
  item.fetch.push_back("d");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

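  // All inputs are constant, so the optimizer should collapse the entire
  // graph into a single Const node "d" holding b + (a + b) = 5.0f.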
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(1, output.node_size());

  const NodeDef& node_d = output.node(0);
  EXPECT_EQ("d", node_d.name());
  EXPECT_EQ("Const", node_d.op());

  std::vector<string> fetch = {"d"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(1, tensors_expected.size());
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(ConstantFoldingTest, AddTree) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {1});
  Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
  Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output add_child = ops::Add(s.WithOpName("add_child"), c2, x);
  Output add_parent = ops::Add(s.WithOpName("add_parent"), c1, add_child);

  Output c4 = ops::Const(s.WithOpName("c4"), 4.0f, {2});
  Output c5 = ops::Const(s.WithOpName("c5"), 5.0f, {2});
  Output c20 = ops::Const(s.WithOpName("c20"), 20.0f, {2});
  Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output mul_child = ops::Mul(s.WithOpName("mul_child"), c4, y);
  Output mul_parent = ops::Mul(s.WithOpName("mul_parent"), c5, mul_child);
  Output addmul_child = ops::Add(s.WithOpName("addmul_child"), c4, x);
  Output addmul_parent =
      ops::Mul(s.WithOpName("addmul_parent"), c5, addmul_child);

  GrapplerItem item;
  item.fetch = {"add_parent", "mul_parent", "addmul_parent"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // We expect the following rewrite(s) to occur:
  //
  //    +                +             +
  //   / \              / \           / \
  // 1.0  +     -->    x   +    -->  x  3.0
  //     / \              / \
  //   2.0  x           1.0 2.0
  //
  //    *                *             *
  //   / \              / \           / \
  // 4.0  *     -->    y   *    -->  y  20.0
  //     / \              / \
  //   5.0  y           4.0 5.0

  EXPECT_EQ(10, output.node_size());
  for (const auto& node : output.node()) {
    if (node.name() == "add_child") {
      EXPECT_EQ("Const", node.op());
      TensorProto t = node.attr().at("value").tensor();
      ASSERT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(2, t.tensor_shape().dim(0).size());
    } else if (node.name() == "add_parent") {
      EXPECT_EQ("Add", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("add_child", node.input(1));
    } else if (node.name() == "mul_child") {
      EXPECT_EQ("Const", node.op());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(2, t.tensor_shape().dim(0).size());
    } else if (node.name() == "mul_parent") {
      EXPECT_EQ("Mul", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("mul_child", node.input(1));
    } else if (node.name() == "addmul_child") {
      // Unchanged.
      EXPECT_EQ("Add", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("c4", node.input(0));
      EXPECT_EQ("x", node.input(1));
    }
  }

  // Check that the result nodes have the expected value.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));

  std::vector<string> fetch = {"add_parent", "mul_parent"};
  auto tensor_expected =
      EvaluateNodes(item.graph, fetch, {{"x", x_t}, {"y", y_t}});
  ASSERT_EQ(fetch.size(), tensor_expected.size());
  fetch = {"add_parent", "mul_parent"};
  auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}, {"y", y_t}});
  ASSERT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
  }
}

TEST_F(ConstantFoldingTest, AddSubtractTree) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {1});
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output sub_child = ops::Sub(s.WithOpName("sub_child"), x, x);
  Output add_parent = ops::Add(s.WithOpName("add_parent"), sub_child, c1);

  GrapplerItem item;
  item.fetch = {"add_parent"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // We expect the following rewrite(s) to occur:
  //
  //     +                +
  //    / \              / \
  //   -   1     -->    -   x
  //  / \              / \
  // x   x            1   x

  EXPECT_EQ(4, output.node_size());
  for (const auto& node : output.node()) {
    if (node.name() == "sub_child") {
      EXPECT_EQ("Sub", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("c1", node.input(0));
      EXPECT_EQ("x", node.input(1));
    } else if (node.name() == "add_parent") {
      EXPECT_EQ("Add", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("sub_child", node.input(1));
    }
  }

  // Check that the result nodes have the expected value.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));

  std::vector<string> fetch = {"add_parent"};
  auto tensor_expected = EvaluateNodes(item.graph, fetch, {{"x", x_t}});
  ASSERT_EQ(fetch.size(), tensor_expected.size());
  fetch = {"add_parent"};
  auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}});
  ASSERT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
  }
}

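// Exhaustively enumerates parent/child pairs of Add/Sub and Mul/Div with a
// constant in each leaf position, and verifies that constant push-down never
// changes the value of the fetched output, whether or not a rewrite fires.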
TEST_F(ConstantFoldingTest, ConstantPushDown) {
  for (bool is_add : {true, false}) {
    for (bool is_parent_commutative : {true, false}) {
      for (bool is_child_commutative : {true, false}) {
        for (bool is_left_child_const : {true, false}) {
          for (bool is_left_leaf_const : {true, false}) {
            tensorflow::Scope s = tensorflow::Scope::NewRootScope();
            Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
            Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
            Output x =
                ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                 ops::Placeholder::Shape(TensorShape({2, 2})));

            auto get_op = [&](bool is_commutative, bool is_left_arg_const,
                              const string& name, const Output& const_arg,
                              const Output& non_const_arg) -> Output {
              if (is_add) {
                if (is_commutative) {
                  return ops::Add(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                } else {
                  return ops::Sub(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                }
              } else {
                if (is_commutative) {
                  return ops::Mul(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                } else {
                  return ops::Div(
                      s.WithOpName(name),
                      is_left_arg_const ? const_arg : non_const_arg,
                      is_left_arg_const ? non_const_arg : const_arg);
                }
              }
            };

            Output child = get_op(is_child_commutative, is_left_leaf_const,
                                  "child", c2, x);
            Output parent = get_op(is_parent_commutative, is_left_child_const,
                                   "parent", c3, child);
            GrapplerItem item;
            item.fetch = {"parent"};
            TF_CHECK_OK(s.ToGraphDef(&item.graph));

            ConstantFolding optimizer(/*cpu_device=*/nullptr);
            GraphDef output;
            Status status =
                optimizer.Optimize(/*cluster=*/nullptr, item, &output);
            TF_EXPECT_OK(status);

            // Check that the result nodes have the expected value.
            auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
            std::vector<string> fetch = {"parent"};
            auto tensor_expected =
                EvaluateNodes(item.graph, fetch, {{"x", x_t}});
            ASSERT_EQ(fetch.size(), tensor_expected.size());
            fetch = {"parent"};
            auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}});
            ASSERT_EQ(fetch.size(), tensors.size());
            for (int i = 0; i < fetch.size(); i++) {
              test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
            }
          }
        }
      }
    }
  }
}

TEST_F(ConstantFoldingTest, ConstantPushDownBiasAdd) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c_mat = ops::Const(s.WithOpName("c_mat"), 2.0f, {2, 2});
  Output c_vec = ops::Const(s.WithOpName("c_vec"), 3.0f, {2});
  Output x_mat = ops::Placeholder(s.WithOpName("x_mat"), DT_FLOAT,
                                  ops::Placeholder::Shape(TensorShape({2, 2})));
  Output x_vec = ops::Placeholder(s.WithOpName("x_vec"), DT_FLOAT,
                                  ops::Placeholder::Shape(TensorShape({2})));
  // Rewrite expected for cases 1 through 3 and their symmetric equivalents,
  // and case 4.
  Output child1 = ops::BiasAdd(s.WithOpName("child1"), c_mat, x_vec);
  Output parent1 = ops::Add(s.WithOpName("parent1"), child1, c_vec);
  Output child1a = ops::BiasAdd(s.WithOpName("child1a"), c_mat, x_vec);
  Output parent1a = ops::Add(s.WithOpName("parent1a"), c_vec, child1a);

  Output child2 = ops::BiasAdd(s.WithOpName("child2"), x_mat, c_vec);
  Output parent2 = ops::Add(s.WithOpName("parent2"), child2, c_mat);
  Output child2a = ops::BiasAdd(s.WithOpName("child2a"), x_mat, c_vec);
  Output parent2a = ops::Add(s.WithOpName("parent2a"), c_mat, child2a);

  Output child3 = ops::Add(s.WithOpName("child3"), c_mat, x_vec);
  Output parent3 = ops::BiasAdd(s.WithOpName("parent3"), child3, c_vec);
  Output child3a = ops::Add(s.WithOpName("child3a"), x_vec, c_mat);
  Output parent3a = ops::BiasAdd(s.WithOpName("parent3a"), child3a, c_vec);

  Output child4 = ops::BiasAdd(s.WithOpName("child4"), c_mat, x_vec);
  Output parent4 = ops::BiasAdd(s.WithOpName("parent4"), child4, c_vec);

  // No rewrite expected.
  Output child5 = ops::Add(s.WithOpName("child5"), x_vec, x_vec);
  Output parent5 = ops::BiasAdd(s.WithOpName("parent5"), c_mat, child5);
  Output child6 = ops::Add(s.WithOpName("child6"), x_vec, c_vec);
  Output parent6 = ops::BiasAdd(s.WithOpName("parent6"), c_mat, child6);
  Output child7 = ops::Add(s.WithOpName("child7"), x_mat, c_vec);
  Output parent7 = ops::BiasAdd(s.WithOpName("parent7"), child7, c_vec);

  GrapplerItem item;
  item.fetch = {"parent1",  "parent2", "parent3", "parent1a", "parent2a",
                "parent3a", "parent4", "parent5", "parent6",  "parent7"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(24, output.node_size());
  for (const auto& node : output.node()) {
    if (node.name() == "child1" || node.name() == "child1a" ||
        node.name() == "child2" || node.name() == "child2a" ||
        node.name() == "child3" || node.name() == "child3a" ||
        node.name() == "child4") {
      EXPECT_EQ(node.op(), "Const") << " node: " << node.name();
    } else if (node.name() != "c_mat" && node.name() != "c_vec") {
      EXPECT_NE(node.op(), "Const") << " node: " << node.name();
    }
  }
  // Check that the result nodes have the expected value.
  auto x_mat_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_vec_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
  std::vector<string> fetch = item.fetch;
  auto tensor_expected = EvaluateNodes(
      item.graph, fetch, {{"x_vec", x_vec_t}, {"x_mat", x_mat_t}});
  ASSERT_EQ(fetch.size(), tensor_expected.size());
  auto tensors =
      EvaluateNodes(output, fetch, {{"x_vec", x_vec_t}, {"x_mat", x_mat_t}});
  ASSERT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
  }
}

// This test fails on ROCm platform (see commit message for details)
#ifndef TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_ScalarConst) {
  for (string data_format : {
         "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
             "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3}
                                              : TensorShape{4, 3, 10, 10},
        /*filter_shape=*/{2, 2, 3, 5},
        /*mul_const_input_shape=*/{},
        /*use_3d_conv=*/false,
        /*padding=*/"VALID", data_format.c_str(),
        /*expect_folded=*/true);
  }
}
#endif

// This test fails on ROCm platform (see commit message for details)
#ifndef TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_SingletonConst) {
  for (string data_format : {
         "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
             "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    for (auto mul_const_input_shape :
         {TensorShape{1}, TensorShape{1, 1, 1, 1}}) {
      MulConvPushDownTest(
          /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3}
                                                : TensorShape{4, 3, 10, 10},
          /*filter_shape=*/{2, 2, 3, 5}, mul_const_input_shape,
          /*use_3d_conv=*/false,
          /*padding=*/"VALID", data_format.c_str(),
          /*expect_folded=*/true);
    }
  }
}
#endif

TEST_F(ConstantFoldingTest,
       MulConvPushDownTest_Conv2D_SingletonConst_ShapeMismatch) {
  for (string data_format : {
         "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
             "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/data_format == "NHWC" ? TensorShape{4, 10, 10, 3}
                                              : TensorShape{4, 3, 10, 10},
        /*filter_shape=*/{2, 2, 3, 5},
        /*mul_const_input_shape=*/{1, 1, 1, 1, 1},
        /*use_3d_conv=*/false,
        /*padding=*/"VALID", data_format.c_str(),
        /*expect_folded=*/false);
  }
}

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_3x1x3Const) {
  for (auto data_format : {
         "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
             "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3},
        /*mul_const_input_shape=*/{3, 1, 3},
        /*use_3d_conv=*/false,
        /*padding=*/"SAME", data_format,
        /*expect_folded=*/false);
  }
}

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_NHWC_VectorLikeConst) {
  for (auto mul_const_input_shape :
       {TensorShape{3}, TensorShape{1, 3}, TensorShape{1, 1, 1, 3}}) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3}, mul_const_input_shape,
        /*use_3d_conv=*/false,
        /*padding=*/"SAME",
        /*data_format=*/"NHWC",
        /*expect_folded=*/true);
  }
}

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_NCHW_VectorLikeConst) {
  for (auto mul_const_input_shape :
       {TensorShape{3}, TensorShape{3, 1, 1}, TensorShape{1, 3, 1, 1}}) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3}, mul_const_input_shape,
        /*use_3d_conv=*/false,
        /*padding=*/"SAME",
        /*data_format=*/"NCHW",
        // TODO(laigd): optimization should happen in this case.
        /*expect_folded=*/false);
  }
}
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_3x1Const) {
  for (auto data_format : {
         "NHWC",
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
             "NCHW"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       }) {
    MulConvPushDownTest(
        /*input_shape=*/{3, 3, 3, 3},
        /*filter_shape=*/{3, 3, 3, 3},
        /*mul_const_input_shape=*/{3, 1},
        /*use_3d_conv=*/false,
        /*padding=*/"SAME", data_format,
        /*expect_folded=*/false);
  }
}

// This test fails on ROCm platform (see commit message for details)
#ifndef TENSORFLOW_USE_ROCM
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv3D_NDHWC_1x1x3Const) {
  MulConvPushDownTest(
      /*input_shape=*/{3, 3, 3, 3, 3},
      /*filter_shape=*/{3, 3, 3, 3, 3},
      /*mul_const_input_shape=*/{1, 1, 3},
      /*use_3d_conv=*/true,
      /*padding=*/"SAME",
      /*data_format=*/"NDHWC",
      /*expect_folded=*/true);
}
#endif

TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv3D_NCDHW_3x1x1x1Const) {
  MulConvPushDownTest(
      /*input_shape=*/{3, 3, 3, 3, 3},
      /*filter_shape=*/{3, 3, 3, 3, 3},
      /*mul_const_input_shape=*/{3, 1, 1, 1},
      /*use_3d_conv=*/true,
      /*padding=*/"SAME",
      /*data_format=*/"NDHWC",
      // TODO(laigd): optimization should happen in this case.
      /*expect_folded=*/false);
}

TEST_F(ConstantFoldingTest, NeutralElement) {
  int kConst = 0;
  int kLike = 1;
  int kFill = 2;
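  // const_type selects how the neutral element is materialized: a literal
  // Const, a ZerosLike/OnesLike of x, or a Fill of the scalar value.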
  for (int const_type : {kConst, kLike, kFill}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({3, 2})));
    Output b = ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 3})));
    Output bias = ops::Placeholder(s.WithOpName("bias"), DT_FLOAT,
                                   ops::Placeholder::Shape(TensorShape({2})));
    Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
    Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
    Output zeros_const_bcast =
        ops::Const(s.WithOpName("zeros_const_bcast"), 0.0f, {2, 2, 2});
    Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
    Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
    Output zeros = const_type == kConst
                       ? zeros_const
                       : (const_type == kLike ? zeros_like : zeros_fill);
    Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
    Output ones_const_bcast =
        ops::Const(s.WithOpName("ones_const_bcast"), 1.0f, {2, 2, 2});
    Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
    Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
    Output ones = const_type == kConst
                      ? ones_const
                      : (const_type == kLike ? ones_like : ones_fill);
    Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
    Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
    Output mul1_bcast =
        ops::Mul(s.WithOpName("mul1_bcast"), x, ones_const_bcast);
    Output mul2_bcast =
        ops::Mul(s.WithOpName("mul2_bcast"), ones_const_bcast, y);
    Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
    Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
    Output mul5 = ops::MulNoNan(s.WithOpName("mul5"), x, zeros_1d);
    Output mul6 = ops::MulNoNan(s.WithOpName("mul6"), zeros_1d, y);
    Output div1 = ops::Div(s.WithOpName("div1"), x, ones);
    Output div2 = ops::Div(s.WithOpName("div2"), ones, y);
    Output floordiv = ops::FloorDiv(s.WithOpName("floordiv"), x, ones);
    Output matmul1 = ops::MatMul(s.WithOpName("matmul1"), x, zeros);
    Output matmul2 = ops::MatMul(s.WithOpName("matmul2"), zeros, y);
    Output matmul3 = ops::MatMul(s.WithOpName("matmul3"), a, zeros);
    Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b);
    Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
    Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
    Output add1_bcast =
        ops::Add(s.WithOpName("add1_bcast"), x, zeros_const_bcast);
    Output add2_bcast =
        ops::Add(s.WithOpName("add2_bcast"), zeros_const_bcast, y);
    Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d);
    Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias);
    Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
    Output sub2 = ops::Sub(s.WithOpName("sub2"), zeros, y);
    Output stack = ops::Stack(
        s.WithOpName("stack"),
        {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, floordiv, matmul1,
         matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch = {"stack",      "matmul3",    "matmul4",   "mul1_bcast",
                  "mul2_bcast", "add1_bcast", "add2_bcast"};

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);

    const string suffix =
        (const_type == kConst ? "_const"
                              : (const_type == kLike ? "_like" : "_fill"));
    const string zeros_name = strings::StrCat("zeros", suffix);
    const string ones_name = strings::StrCat("ones", suffix);
    const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
    const string ctrl_ones_name = strings::StrCat("^ones", suffix);

    EXPECT_EQ(const_type == kFill ? 43 : 39, output.node_size());
    for (int i = 0; i < output.node_size(); ++i) {
      const NodeDef& node = output.node(i);
      const string& name = node.name();
      if (name == "mul1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "mul2") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "mul1_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^ones_const_bcast", node.input(2));
      } else if (name == "mul2_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^ones_const_bcast", node.input(2));
      } else if (name == "mul3") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "mul4") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "mul5") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ("^zeros_1d", node.input(1));
      } else if (name == "mul6") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^zeros_1d", node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "div1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "div2") {
        EXPECT_EQ("Reciprocal", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_ones_name, node.input(1));
      } else if (name == "floordiv") {
        EXPECT_EQ("FloorDiv", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ones_name, node.input(1));
      } else if (name == "matmul1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "matmul2") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "matmul3") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^a", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(3, t.tensor_shape().dim(0).size());
        EXPECT_EQ(2, t.tensor_shape().dim(1).size());
      } else if (name == "matmul4") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^b", node.input(1));
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
        EXPECT_EQ(3, t.tensor_shape().dim(1).size());
      } else if (name == "add1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "add2") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "add1_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros_const_bcast", node.input(2));
      } else if (name == "add2_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^zeros_const_bcast", node.input(2));
      } else if (name == "bias_add1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros_1d", node.input(1));
      } else if (name == "bias_add2") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("bias", node.input(0));
        EXPECT_EQ("ConstantFolding/bias_add2-broadcastto_shape-1",
                  node.input(1));
        EXPECT_EQ(ctrl_zeros_name, node.input(2));
      } else if (name == "ConstantFolding/bias_add2-broadcastto_shape-1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ(node.attr().at("dtype").type(), DT_INT32);
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(DT_INT32, t.dtype());
        EXPECT_EQ(1, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
      } else if (name == "sub1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "sub2") {
        EXPECT_EQ("Neg", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      }
      const std::set<string> square_zero_const{"mul1", "mul2",    "mul5",
                                               "mul6", "matmul1", "matmul2"};
      if (square_zero_const.count(name) > 0) {
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
        EXPECT_EQ(2, t.tensor_shape().dim(1).size());
      }
    }
    auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 2}));
    auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
    auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
    auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
    auto bias_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));

    auto tensors_expected = EvaluateNodes(
        item.graph, item.fetch,
        {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
    EXPECT_EQ(item.fetch.size(), tensors_expected.size());
    auto tensors = EvaluateNodes(
        output, item.fetch,
        {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
    EXPECT_EQ(item.fetch.size(), tensors.size());
    for (int i = 0; i < item.fetch.size(); ++i) {
      test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
    }
  }
}

TEST_F(ConstantFoldingTest, NeutralElement_ShortFloats) {
  SimpleNeutralElementTest<DT_BOOL>();
  SimpleNeutralElementTest<DT_HALF>();
  SimpleNeutralElementTest<DT_BFLOAT16>();
}

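// Division by a floating-point constant should be strength-reduced to
// multiplication by a constant-folded reciprocal; integer division is left
// untouched, since multiplying by a reciprocal would not preserve its
// truncation semantics.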
TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output cf_half = ops::Const(s.WithOpName("cf_half"), 0.5f, {1});
  Output xf = ops::Placeholder(s.WithOpName("xf"), DT_FLOAT,
                               ops::Placeholder::Shape(TensorShape({2, 2})));
  Output xi = ops::Placeholder(s.WithOpName("xi"), DT_INT32,
                               ops::Placeholder::Shape(TensorShape({2, 2})));
  Output ci = ops::Const(s.WithOpName("ci"), 2, {1});
  Output cf = ops::Const(s.WithOpName("cf"), 2.0f, {1});
  Output div_i = ops::Div(s.WithOpName("div_i"), xi, ci);
  Output div_f = ops::Div(s.WithOpName("div_f"), xf, cf);
  Output realdiv = ops::RealDiv(s.WithOpName("realdiv"), xf, cf);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"div_f", "div_i", "realdiv"};
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(8, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (name == "div_i") {
      // Integer division is unchanged.
      EXPECT_EQ("Div", node.op());
      EXPECT_EQ("xi", node.input(0));
      EXPECT_EQ("ci", node.input(1));
    } else if (name == "div_f") {
      EXPECT_EQ("Mul", node.op());
      EXPECT_EQ("xf", node.input(0));
      EXPECT_EQ("ConstantFolding/div_f_recip", node.input(1));
    } else if (name == "realdiv") {
      EXPECT_EQ("Mul", node.op());
      EXPECT_EQ("xf", node.input(0));
      EXPECT_EQ("ConstantFolding/realdiv_recip", node.input(1));
    } else if (name == "ConstantFolding/div_f_recip") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(DT_FLOAT, t.dtype());
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(1, t.tensor_shape().dim(0).size());
    } else if (name == "ConstantFolding/realdiv_recip") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(DT_FLOAT, t.dtype());
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(1, t.tensor_shape().dim(0).size());
    }
  }

  // Check that the reciprocals have the expected value: both should equal
  // cf_half, i.e. 1 / 2.0f.
  std::vector<string> fetch = {"cf_half"};
  auto tensor_expected = EvaluateNodes(item.graph, fetch);
  EXPECT_EQ(fetch.size(), tensor_expected.size());
  fetch = {"ConstantFolding/div_f_recip", "ConstantFolding/realdiv_recip"};
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<float>(tensor_expected[0], tensors[i]);
  }
}

TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x_known =
      ops::Placeholder(s.WithOpName("x_known"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({2, 2})));
  Output x_partially_known =
      ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
  Output zeros_known = ops::ZerosLike(s.WithOpName("zeros_known"), x_known);
  Output zeros_partially_known =
      ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
  Output zeros_unknown =
      ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);

  // Multiplies without any additional ops to supply the output shape.
  int count = 0;
  std::vector<Output> muls;
  std::unordered_set<string> not_converted;
  std::unordered_set<string> to_const;
  std::unordered_set<string> to_identity;
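  // Classify each product by how much shape information is available: if
  // either operand's shape is fully known, the Mul can fold to a zero Const;
  // if both shapes are only partially known, it is rewritten to an Identity;
  // if either shape is fully unknown, the Mul must be left alone.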
  for (const auto* x : {&x_known, &x_partially_known, &x_unknown}) {
    for (const auto* zeros :
         {&zeros_known, &zeros_partially_known, &zeros_unknown}) {
      const string name = strings::StrCat("mul_", count++);
      muls.push_back(ops::Mul(s.WithOpName(name), *x, *zeros));
      if (x == &x_partially_known && zeros == &zeros_partially_known) {
        to_identity.insert(name);
      } else if (x == &x_unknown || zeros == &zeros_unknown) {
        not_converted.insert(name);
      } else {
        to_const.insert(name);
      }
    }
  }

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(15, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (to_const.count(name) > 0) {
      EXPECT_EQ("Const", node.op()) << node.name();
    } else if (to_identity.count(name) > 0) {
      EXPECT_EQ("Identity", node.op()) << node.name();
    } else if (not_converted.count(name) > 0) {
      EXPECT_EQ("Mul", node.op()) << node.name();
    }
  }

  const std::vector<string> fetch = {"mul_0", "mul_4", "mul_8"};
  auto x_known_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_partially_unknown_t =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
  auto x_unknown_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
  auto expected_tensors =
      EvaluateNodes(item.graph, fetch,
                    {{"x_known", x_known_t},
                     {"x_partially_unknown", x_partially_unknown_t},
                     {"x_unknown", x_unknown_t}});
  EXPECT_EQ(fetch.size(), expected_tensors.size());
  auto tensors = EvaluateNodes(output, fetch,
                               {{"x_known", x_known_t},
                                {"x_partially_unknown", x_partially_unknown_t},
                                {"x_unknown", x_unknown_t}});
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < tensors.size(); i++)
    test::ExpectTensorNear<float>(expected_tensors[i], tensors[i], 1e-5);
}

TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output known_shape = ops::Const(s.WithOpName("known_shape"), 0.0f, {2, 2});
  Output x_partially_known =
      ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
  Output zeros_partially_known =
      ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
  Output zeros_unknown =
      ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);

  // If at least one of the inputs to AddN has a known shape, shape inference
  // will propagate the shape back to the inputs of AddN, making the
  // output shapes of all its inputs known.
  std::vector<Output> muls_deduced_output_shape;
  std::unordered_set<string> to_const;
  int count = 0;
  for (const auto& x : {x_partially_known, x_unknown}) {
    for (const auto& zeros : {zeros_partially_known, zeros_unknown}) {
      const string name = strings::StrCat("mul_", count++);
      muls_deduced_output_shape.push_back(
          ops::Mul(s.WithOpName(name), x, zeros));
      to_const.insert(name);
    }
  }
  // We add a known shape as input to AddN to propagate it back to the
  // multiplies above, which means they can all be turned into Const nodes.
  muls_deduced_output_shape.push_back(known_shape);
  Output addn1 = ops::AddN(s.WithOpName("addn1"), muls_deduced_output_shape);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(10, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (to_const.count(name) > 0) {
      EXPECT_EQ("Const", node.op()) << node.name();
      EXPECT_EQ(2, node.input_size());
      EXPECT_TRUE(IsControlInput(node.input(0)));
      EXPECT_TRUE(IsControlInput(node.input(1)));
    }
  }
  const std::vector<string> fetch = {"addn1"};
  auto x_partially_unknown_t =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_unknown_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto expected_tensors =
      EvaluateNodes(item.graph, fetch,
                    {{"x_partially_unknown", x_partially_unknown_t},
                     {"x_unknown", x_unknown_t}});
  EXPECT_EQ(1, expected_tensors.size());
  auto tensors = EvaluateNodes(output, fetch,
                               {{"x_partially_unknown", x_partially_unknown_t},
                                {"x_unknown", x_unknown_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(expected_tensors[0], tensors[0], 1e-5);
}

TEST_F(ConstantFoldingTest, CreateConstNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

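  // For each element type, build const -> mul -> identity so that the Mul
  // folds into a new Const, then verify below (via CHECK_RESULT) that the
  // folded value lands in the TensorProto field matching the dtype.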
1209 #define MAKE_TEST_GRAPH(TYPE)                                               \
1210   Output TYPE##_const =                                                     \
1211       ops::Const(s.WithOpName(#TYPE "_const"), static_cast<TYPE>(10), {5}); \
1212   Output TYPE##_mul =                                                       \
1213       ops::Mul(s.WithOpName(#TYPE "_mul"), TYPE##_const, TYPE##_const);     \
1214   Output TYPE##_id = ops::Identity(s.WithOpName(#TYPE "_id"), TYPE##_mul)
1215 
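  // For instance, MAKE_TEST_GRAPH(float) declares a "float_const" Const of
  // shape {5} filled with 10, a "float_mul" multiplying it by itself, and a
  // "float_id" Identity; folding should turn each *_mul into a Const.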
  MAKE_TEST_GRAPH(float);
  MAKE_TEST_GRAPH(double);
  MAKE_TEST_GRAPH(int64_t);
  MAKE_TEST_GRAPH(int32);
  MAKE_TEST_GRAPH(int16);
  MAKE_TEST_GRAPH(int8);
  MAKE_TEST_GRAPH(uint8);
#undef MAKE_TEST_GRAPH

  Output bool_const = ops::Const(s.WithOpName("bool_const"), true, {5});
  Output bool_and =
      ops::LogicalAnd(s.WithOpName("bool_and"), bool_const, bool_const);
  Output bool_id = ops::Identity(s.WithOpName("bool_id"), bool_and);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

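  // 7 numeric type chains of 3 nodes each, plus the 3 bool nodes: 24 total.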
  EXPECT_EQ(24, output.node_size());
  for (const NodeDef& node : output.node()) {
#define CHECK_RESULT(TYPE, FIELD)                                             \
  if (node.name() == #TYPE "_mul") {                                          \
    EXPECT_EQ(5,                                                              \
              node.attr().at("value").tensor().tensor_shape().dim(0).size()); \
    EXPECT_EQ(1, node.attr().at("value").tensor().FIELD##_val_size());        \
    EXPECT_EQ(10 * 10, node.attr().at("value").tensor().FIELD##_val(0));      \
  }

    CHECK_RESULT(float, float);
    CHECK_RESULT(double, double);
    CHECK_RESULT(int64_t, int64);
    CHECK_RESULT(int32, int);
    CHECK_RESULT(int16, int);
    CHECK_RESULT(int8, int);
    CHECK_RESULT(uint8, int);
#undef CHECK_RESULT

    if (node.name() == "bool_and") {
      EXPECT_EQ(5,
                node.attr().at("value").tensor().tensor_shape().dim(0).size());
      EXPECT_EQ(1, node.attr().at("value").tensor().bool_val_size());
      EXPECT_EQ(true && true, node.attr().at("value").tensor().bool_val(0));
    }
  }
}

TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
  // Build a simple graph with a few trivially prunable ops.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 10, {5});
  auto b = ops::Unique(s.WithOpName("b"), {a});
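  // Unique has two outputs: b.y (the unique values) and b.idx (the index in
  // y of each input element), so folding it must produce two constants.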
  Output c = ops::Identity(s.WithOpName("c"), {b.y});
  Output d = ops::Identity(s.WithOpName("d"), {b.idx});
  Output e = ops::Identity(s.WithOpName("e"), {c});
  Output f = ops::Identity(s.WithOpName("f"), {d});

  GrapplerItem item;
  item.fetch.push_back("e");
  item.fetch.push_back("f");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(2, output.node_size());

  const NodeDef& new_c = output.node(0);
  EXPECT_EQ("e", new_c.name());
  EXPECT_EQ("Const", new_c.op());

  const NodeDef& new_d = output.node(1);
  EXPECT_EQ("f", new_d.name());
  EXPECT_EQ("Const", new_d.op());

  std::vector<string> fetch = {"e", "f"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(fetch.size(), tensors_expected.size());
  EXPECT_EQ(fetch.size(), tensors.size());
  for (int i = 0; i < fetch.size(); i++) {
    test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
  }
}

TEST_F(ConstantFoldingTest, ControlDependencies) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
  Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
  Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
  Output i1 = ops::Identity(scope.WithOpName("i1"), {c});
  Output i2 =
      ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1});
  Output i3 = ops::Identity(scope.WithOpName("i3"), {i2});

  GrapplerItem item;
  item.fetch.push_back("i3");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

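  // i3 folds to a Const, but the control dependencies picked up along the
  // chain (p1 via c, p2 via i2) must survive as control inputs on the result.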
  std::vector<string> expected_nodes = {"dflt", "p1", "p2", "i3"};
  EXPECT_EQ(output.node_size(), expected_nodes.size());
  int i = 0;
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_EQ(expected_nodes[i], output.node(i).name());
    i++;
    if (node.name() == "i3") {
      EXPECT_EQ("Const", node.op());
      ++found;
      auto folded = EvaluateNodes(output, {"i3"});
      auto expected = EvaluateNodes(item.graph, {"i3"});
      EXPECT_EQ(1, expected.size());
      EXPECT_EQ(1, folded.size());
      test::ExpectTensorEqual<int>(folded[0], expected[0]);
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
      EXPECT_EQ("^p2", node.input(1));
    }
  }
  EXPECT_EQ(1, found);
}

TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
  Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
  Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
  Output i1 = ops::Identity(scope.WithOpName("i1"), {c});
  Output i2 =
      ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1});
  Output i3 = ops::Identity(scope.WithOpName("e"), {i2});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

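  // With no fetch nodes, everything foldable is folded: i1 and i2 both
  // become Consts, each keeping the control dependencies accumulated along
  // its input chain.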
  std::vector<string> expected_nodes = {"dflt", "p1", "p2", "c",
                                        "i1",   "i2", "e"};
  EXPECT_EQ(output.node_size(), expected_nodes.size());
  int i = 0;
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_EQ(expected_nodes[i], output.node(i).name());
    i++;
    if (node.name() == "i1") {
      EXPECT_EQ("Const", node.op());
      ++found;
      auto folded = EvaluateNodes(output, {"i1"});
      auto expected = EvaluateNodes(item.graph, {"i1"});
      EXPECT_EQ(1, expected.size());
      EXPECT_EQ(1, folded.size());
      test::ExpectTensorEqual<int>(folded[0], expected[0]);
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
    }
    if (node.name() == "i2") {
      EXPECT_EQ("Const", node.op());
      ++found;
      auto folded = EvaluateNodes(output, {"i2"});
      auto expected = EvaluateNodes(item.graph, {"i2"});
      EXPECT_EQ(1, expected.size());
      EXPECT_EQ(1, folded.size());
      test::ExpectTensorEqual<int>(folded[0], expected[0]);
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
      EXPECT_EQ("^p2", node.input(1));
    }
  }
  EXPECT_EQ(2, found);
}

TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
  Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
  Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
  Output i1 = ops::Identity(scope.WithOpName("i1")
                                .WithControlDependencies(p2)
                                .WithControlDependencies(p1),
                            {c});
  Output i2 = ops::Identity(scope.WithOpName("i2"), {i1});

  GrapplerItem item;
  item.fetch.push_back("i2");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

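  // p1 is a control dependency of both c and i1, so the folded i2 would
  // inherit it twice; the optimizer must emit the duplicate only once.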
  std::vector<string> expected_nodes = {"dflt", "p1", "p2", "i2"};
  EXPECT_EQ(output.node_size(), expected_nodes.size());
  int i = 0;
  for (const auto& node : output.node()) {
    EXPECT_EQ(expected_nodes[i], output.node(i).name());
    i++;
    if (node.name() == "i2") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
      EXPECT_EQ("^p2", node.input(1));
    }
  }
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}

TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // Add a DynamicPartition node to the graph.
  Output input = ops::Const(scope.WithOpName("in0"), 314, {3, 4, 5});
  Output indices = ops::Const(scope.WithOpName("indices"), 1, {3, 4});
  int num_partitions = 4;
  ops::DynamicPartition part(scope.WithOpName("partition"), input, indices,
                             num_partitions);

  std::vector<string> outputs;
  for (int i = 0; i < num_partitions; ++i) {
    string part_out_name = strings::StrCat("part_out", i);
    ops::Identity partition_out(scope.WithOpName(part_out_name),
                                {part.outputs[i]});
    outputs.push_back(part_out_name);
  }

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Add a ConcatOffset node to the graph.
  Tensor initial_val(DT_INT32, TensorShape({3}));
  test::FillIota<int>(&initial_val, 7);
  for (int i = 1; i < 5; ++i) {
    TF_CHECK_OK(NodeDefBuilder(strings::StrCat("in", i), "Const")
                    .Attr("dtype", DT_INT32)
                    .Attr("value", initial_val)
                    .Finalize(item.graph.add_node()));
  }
  Tensor concat_dim(DT_INT32, TensorShape({}));
  test::FillIota<int>(&concat_dim, 0);
  TF_CHECK_OK(NodeDefBuilder("concat_dim", "Const")
                  .Attr("dtype", DT_INT32)
                  .Attr("value", concat_dim)
                  .Finalize(item.graph.add_node()));

  TF_CHECK_OK(NodeDefBuilder("concat_offsets", "ConcatOffset")
                  .Input("concat_dim", 0, DT_INT32)
                  .Input({NodeDefBuilder::NodeOut("in1", 0, DT_INT32),
                          NodeDefBuilder::NodeOut("in2", 0, DT_INT32),
                          NodeDefBuilder::NodeOut("in3", 0, DT_INT32),
                          NodeDefBuilder::NodeOut("in4", 0, DT_INT32)})
                  .Finalize(item.graph.add_node()));

  for (int i = 0; i < 4; ++i) {
    string concat_offset_out_name = strings::StrCat("concat_offset_out", i);
    TF_CHECK_OK(NodeDefBuilder(concat_offset_out_name, "Identity")
                    .Attr("T", DT_INT32)
                    .Input("concat_offsets", i, DT_INT32)
                    .Finalize(item.graph.add_node()));
    outputs.push_back(concat_offset_out_name);
  }

  item.fetch = outputs;
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

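  // All four DynamicPartition outputs and all four ConcatOffset outputs
  // should have been folded to Consts.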
  int constant_folded = 0;
  for (const auto& node : output.node()) {
    if (node.name().find("part_out") != string::npos ||
        node.name().find("concat_offset_out") != string::npos) {
      ++constant_folded;
      EXPECT_EQ("Const", node.op());
    }
  }
  EXPECT_EQ(8, constant_folded);

  auto expected = EvaluateNodes(item.graph, outputs);
  auto optimized = EvaluateNodes(output, outputs);
  ASSERT_EQ(expected.size(), optimized.size());
  for (int i = 0; i < expected.size(); ++i) {
    test::ExpectTensorEqual<int>(expected[i], optimized[i]);
  }
}

TEST_F(ConstantFoldingTest, ShapeMaterialization) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
  Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
  Output rank = ops::Rank(scope.WithOpName("rank"), v1);
  Output shape = ops::Shape(scope.WithOpName("shape"), v2);
  Output size = ops::Size(scope.WithOpName("size"), v3);
  Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
  Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);

  GrapplerItem item;
  item.fetch.push_back("p2");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

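  // rank, shape and size are all deducible from the variable shapes, so p2
  // folds to a Const and the variables it transitively depended on are
  // retained as control inputs.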
  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "p2") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(3, node.input_size());
      EXPECT_EQ("^v3", node.input(0));
      EXPECT_EQ("^v1", node.input(1));
      EXPECT_EQ("^v2", node.input(2));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      // rank = 1, shape = (5, 7), size = 143 = 11*13
      // p2 = (715, 1001) = (5*143, 7*143)
      EXPECT_EQ(715, value.flat<int>()(0));
      EXPECT_EQ(1001, value.flat<int>()(1));
    }
  }
  EXPECT_EQ(1, found);
  auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
  auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
  auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({11, 13}));

  auto tensors_expected = EvaluateNodes(
      item.graph, item.fetch, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(1, tensors_expected.size());
  auto tensors = EvaluateNodes(output, item.fetch,
                               {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}

TEST_F(ConstantFoldingTest, ShapeMaterializationEmptyFetch) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
  Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
  Output rank = ops::Rank(scope.WithOpName("rank"), v1);
  Output shape = ops::Shape(scope.WithOpName("shape"), v2);
  Output size = ops::Size(scope.WithOpName("size"), v3);
  Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
  Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

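  // With no fetch nodes, the Rank/Shape/Size nodes themselves are rewritten
  // as Consts, each anchored to its variable by a control input.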
  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "size") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^v3", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      EXPECT_EQ(11 * 13, value.flat<int>()(0));
    } else if (node.name() == "rank") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^v1", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      EXPECT_EQ(1, value.flat<int>()(0));
    } else if (node.name() == "shape") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^v2", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      EXPECT_EQ(5, value.flat<int>()(0));
      EXPECT_EQ(7, value.flat<int>()(1));
    }
  }
  EXPECT_EQ(3, found);

  auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
  auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
  auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({11, 13}));
  std::vector<string> fetch_nodes = {"p2"};
  auto tensors_expected = EvaluateNodes(
      item.graph, fetch_nodes, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(1, tensors_expected.size());
  auto tensors = EvaluateNodes(output, fetch_nodes,
                               {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}

TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {}, DT_FLOAT);
  Output v3 = ops::Variable(scope.WithOpName("v3"), {4, 6}, DT_FLOAT);
  auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2, v3});
  Output i1a = ops::Identity(scope.WithOpName("i1a"), s[0]);
  Output i1b = ops::Identity(scope.WithOpName("i1b"), s[0]);
  Output i2a = ops::Identity(scope.WithOpName("i2a"), s[1]);
  Output i2b = ops::Identity(scope.WithOpName("i2b"), s[1]);
  Output i2c = ops::Identity(scope.WithOpName("i2c"), s[1]);
  Output i3a = ops::Identity(scope.WithOpName("i3a"), s[2]);
  Output i3b = ops::Identity(scope.WithOpName("i3b"), s[2]);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
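  // v1's shape is only partially known and v2's is unknown, so only v3's
  // shape ({4, 6}) can be materialized (as s-matshapes-2); consumers of the
  // other two ShapeN outputs keep reading them directly.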
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
              node.name());
    EXPECT_NE(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
              node.name());
    if (node.name() == "i1a" || node.name() == "i1b") {
      ++found;
      EXPECT_EQ("s", node.input(0));
    }
    if (node.name() == "i2a" || node.name() == "i2b" || node.name() == "i2c") {
      ++found;
      EXPECT_EQ("s:1", node.input(0));
    }
    if (node.name() == "i3a" || node.name() == "i3b") {
      ++found;
      EXPECT_EQ(AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst),
                node.input(0));
    }
    if (node.name() == "s") {
      ++found;
      EXPECT_EQ("ShapeN", node.op());
      EXPECT_EQ("v1", node.input(0));
      EXPECT_EQ("v2", node.input(1));
      EXPECT_EQ("v3", node.input(2));
    }
    if (node.name() ==
        AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst)) {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ("^s", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      EXPECT_EQ(4, value.flat<int>()(0));
      EXPECT_EQ(6, value.flat<int>()(1));
    }
  }
  EXPECT_EQ(9, found);

  auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
  auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 6}));
  auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
  const std::vector<string> fetch_nodes = {"i1a", "i1b", "i2a", "i2b",
                                           "i2c", "i3a", "i3b"};
  auto tensors_expected = EvaluateNodes(
      item.graph, fetch_nodes, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
  auto tensors = EvaluateNodes(output, fetch_nodes,
                               {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  EXPECT_EQ(fetch_nodes.size(), tensors.size());
  for (int i = 0; i < fetch_nodes.size(); i++)
    test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
}

TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN_MultipleOutputs) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {4, 6}, DT_FLOAT);
  auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2});
  auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {s[0], s[1]});
  Output ia = ops::Identity(scope.WithOpName("ia"), id_n[0]);
  Output ib = ops::Identity(scope.WithOpName("ib"), id_n[1]);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("ia");
  item.fetch.push_back("ib");

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

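  // s[1] (v2's shape, {4, 6}) is known, so ib folds to a Const; s[0] is only
  // partially known, so ia keeps reading through id_n.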
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
              node.name());
    if (node.name() == "s") {
      ++found;
      EXPECT_EQ("ShapeN", node.op());
      EXPECT_EQ("v1", node.input(0));
      EXPECT_EQ("v2", node.input(1));
    }
    if (node.name() == "id_n") {
      ++found;
      EXPECT_EQ("IdentityN", node.op());
      EXPECT_EQ("s", node.input(0));
      EXPECT_EQ(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
                node.input(1));
    }
    if (node.name() == "ia") {
      ++found;
      EXPECT_EQ("id_n", node.input(0));
    }
    if (node.name() == "ib") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ("^s", node.input(0));
      EXPECT_EQ("^id_n", node.input(1));
    }
  }
  EXPECT_EQ(4, found);

  auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
  auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"v1", v1_t}, {"v2", v2_t}});
  EXPECT_EQ(2, tensors_expected.size());
  auto tensors =
      EvaluateNodes(output, item.fetch, {{"v1", v1_t}, {"v2", v2_t}});
  EXPECT_EQ(2, tensors.size());
  for (int i = 0; i < tensors.size(); i++)
    test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
}

TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
  ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
  ops::Identity i(scope.WithOpName("i"), s1.output_true);
  ops::Size size(scope.WithOpName("size"), i);
  ops::Square p1(scope.WithOpName("p1"), rank);
  ops::Square p2(scope.WithOpName("p2"), size);
  ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});

  Output predicate =
      ops::Const(scope.WithOpName("false"), false, TensorShape({}));
  Output constant =
      ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
  ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
  ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
  ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
  ops::Merge m2(scope.WithOpName("m2"),
                {statically_known.output, never_generated.output});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

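  // rank and size are statically known even though they sit behind a Switch;
  // the folded constants stay anchored to the taken branch via control
  // inputs (the ConstantFoldingCtrl anchor for rank, i for size). switch2's
  // predicate is a constant false, so i2 folds to a Const outright while i3,
  // on the never-taken branch, remains an Identity.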
  std::set<string> present_nodes = {"v_in",     "v_ctrl",
                                    "switch",   "i",
                                    "p1",       "p2",
                                    "m",        "false",
                                    "constant", "switch2",
                                    "i2",       "i3",
                                    "m2",       "ConstantFoldingCtrl/switch_0",
                                    "rank",     "size"};
  std::set<string> not_present_nodes = {"ConstantFolding/switch2-0"};
  EXPECT_EQ(present_nodes.size(), output.node_size());
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end())
        << node.name();
    EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end())
        << node.name();
    present_nodes.erase(node.name());
    not_present_nodes.erase(node.name());
    if (node.name() == "rank") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
    }
    if (node.name() == "size") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^i", node.input(0));
    }
    if (node.name() == "i2") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(0, node.input_size());
    }
    if (node.name() == "i3") {
      ++found;
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("switch2:1", node.input(0));
    }
  }
  EXPECT_EQ(4, found);

  auto v_in_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
  Tensor v_ctrl_t(DT_BOOL, TensorShape({}));

  v_ctrl_t.flat<bool>()(0) = true;
  std::vector<string> fetch_nodes = {"m", "m2"};
  auto tensors_expected = EvaluateNodes(
      item.graph, fetch_nodes, {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors_expected.size());
  auto tensors = EvaluateNodes(output, fetch_nodes,
                               {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
  test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);

  v_ctrl_t.flat<bool>()(0) = false;
  tensors_expected = EvaluateNodes(item.graph, fetch_nodes,
                                   {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors_expected.size());
  tensors = EvaluateNodes(output, fetch_nodes,
                          {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
  test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
}

TEST_F(ConstantFoldingTest, SwitchNodes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
  ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
  ops::Identity i(scope.WithOpName("i"), s1.output_true);
  ops::Size size(scope.WithOpName("size"), i);
  ops::Square p1(scope.WithOpName("p1"), rank);
  ops::Square p2(scope.WithOpName("p2"), size);
  ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});

  Output predicate =
      ops::Const(scope.WithOpName("false"), false, TensorShape({}));
  Output constant =
      ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
  ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
  ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
  ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
  ops::Merge m2(scope.WithOpName("m2"),
                {statically_known.output, never_generated.output});

  GrapplerItem item;
  item.fetch.push_back("m");
  item.fetch.push_back("m2");

  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
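  // Same graph as above, but with m and m2 fetched: this time rank and size
  // should be folded away entirely rather than kept as Consts.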
  std::set<string> present_nodes = {"v_in",     "v_ctrl",
                                    "switch",   "i",
                                    "p1",       "p2",
                                    "m",        "false",
                                    "constant", "switch2",
                                    "i2",       "i3",
                                    "m2",       "ConstantFoldingCtrl/switch_0"};
  std::set<string> not_present_nodes = {"rank", "size",
                                        "ConstantFolding/switch2-0"};
  EXPECT_EQ(present_nodes.size(), output.node_size());

  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end());
    EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end());
    present_nodes.erase(node.name());
    not_present_nodes.erase(node.name());
    if (node.name() == "i2") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(0, node.input_size());
    }
    if (node.name() == "i3") {
      ++found;
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("switch2:1", node.input(0));
    }
  }
  EXPECT_EQ(2, found);

  auto v_in_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
  Tensor v_ctrl_t(DT_BOOL, TensorShape({}));
  v_ctrl_t.flat<bool>()(0) = true;
  auto tensors_expected = EvaluateNodes(
      item.graph, item.fetch, {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors_expected.size());
  auto tensors = EvaluateNodes(output, item.fetch,
                               {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
  test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);

  v_ctrl_t.flat<bool>()(0) = false;
  tensors_expected = EvaluateNodes(item.graph, item.fetch,
                                   {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors_expected.size());
  tensors = EvaluateNodes(output, item.fetch,
                          {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
  EXPECT_EQ(2, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
  test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
}

TEST_F(ConstantFoldingTest, MergeNodes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output x =
      ops::RandomNormal(scope.WithOpName("x"), {3, 5}, DataType::DT_FLOAT);
  Output y =
      ops::RandomNormal(scope.WithOpName("y"), {3, 5}, DataType::DT_FLOAT);
  Output const1 =
      ops::Const(scope.WithOpName("const1").WithControlDependencies(x), 2.7f,
                 TensorShape({3, 5}));
  Output const2 =
      ops::Const(scope.WithOpName("const2"), 3.14f, TensorShape({3, 5}));
  Output const3 =
      ops::Const(scope.WithOpName("const3").WithControlDependencies(x), 3.14f,
                 TensorShape({3, 5}));

  // Create 3 merge nodes: m1 is foldable, m2 and m3 aren't.
  ops::Merge m1(scope.WithOpName("m1"), {x, const1, const2});
  ops::Merge m2(scope.WithOpName("m2"), {const1, const3});
  ops::Merge m3(scope.WithOpName("m3"), {x, y});
  // m4 is not foldable because the only constant input
  // has a control input, so we cannot know if it will be
  // triggered.
  ops::Merge m4(scope.WithOpName("m4"), {x, const1});

  ops::Identity out1(scope.WithOpName("out1"), m1.output);
  ops::Identity idx1(scope.WithOpName("idx1"), m1.value_index);
  ops::Identity out2(scope.WithOpName("out2"), m2.output);
  ops::Identity idx2(scope.WithOpName("idx2"), m2.value_index);
  ops::Identity out3(scope.WithOpName("out3"), m3.output);
  ops::Identity idx3(scope.WithOpName("idx3"), m3.value_index);
  ops::Identity out4(scope.WithOpName("out4"), m4.output);
  ops::Identity idx4(scope.WithOpName("idx4"), m4.value_index);

  GrapplerItem item;
  item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3", "out4", "idx4"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

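  // 17 original nodes plus the two materialized for m1: ConstantFolding/m1
  // (the value of const2, the only constant input without a control
  // dependency) and ConstantFolding/m1_index (the winning input index, 2).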
  EXPECT_EQ(19, output.node_size());
  int found_nodes = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "out1") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "idx1") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "ConstantFolding/m1") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "ConstantFolding/m1_index") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "out2") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m2", node.input(0));
      ++found_nodes;
    } else if (node.name() == "idx2") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m2:1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "out3") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m3", node.input(0));
      ++found_nodes;
    } else if (node.name() == "idx3") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m3:1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "out4") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m4", node.input(0));
      ++found_nodes;
    } else if (node.name() == "idx4") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m4:1", node.input(0));
      ++found_nodes;
    }
  }
  // Make sure the graph contains all the nodes we're expecting: all ten
  // branches above should have matched exactly once.
  EXPECT_EQ(10, found_nodes);

  std::vector<string> fetch = {"out1", "idx1"};
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(2, tensors.size());
  const Tensor& out_value = tensors[0];
  EXPECT_EQ(3 * 5, out_value.NumElements());
  for (int i = 0; i < 3 * 5; ++i) {
    EXPECT_EQ(3.14f, out_value.flat<float>()(i));
  }
  const Tensor& out_idx = tensors[1];
  EXPECT_EQ(1, out_idx.NumElements());
  EXPECT_EQ(2, out_idx.flat<int32>()(0));
}

TEST_F(ConstantFoldingTest, SplitRemoval) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 =
      ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
  Output in2 =
      ops::Variable(scope.WithOpName("in2"), TensorShape({4}), DT_FLOAT);
  auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
  ops::Split s1(scope.WithOpName("s1"), split_dim, in1, 1);
  ops::Split s2(scope.WithOpName("s2"), split_dim, in2, 2);

  ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);

  GrapplerItem item;
  item.fetch = {"out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

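  // A Split with num_split == 1 is a no-op: s1 should collapse to an
  // Identity of in1, with split_dim demoted to a control dependency.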
  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("split_dim", "Const", {}, {}, &want);
  AddNode("s1", "Identity", {"in1", AsControlDependency("split_dim")}, {},
          &want);
  AddNode("s2", "Split", {"split_dim", "in2"}, {}, &want);
  AddNode("out", "Add", {"s1", "s2"}, {}, &want);

  CompareGraphs(want, got);

  auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
  auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  EXPECT_EQ(1, tensors_expected.size());
  auto tensors =
      EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}

TEST_F(ConstantFoldingTest, SplitVRemoval) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 =
      ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
  Output in2 =
      ops::Variable(scope.WithOpName("in2"), TensorShape({5}), DT_FLOAT);
  auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
  auto size_splits1 = ops::Const(scope.WithOpName("size_splits1"), {2}, {1});
  auto size_splits2 = ops::Const(scope.WithOpName("size_splits2"), {2, 3}, {2});
  ops::SplitV s1(scope.WithOpName("s1"), in1, size_splits1, split_dim, 1);
  ops::SplitV s2(scope.WithOpName("s2"), in2, size_splits2, split_dim, 2);

  ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);

  GrapplerItem item;
  item.fetch = {"out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

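  // Likewise, a SplitV producing a single piece is a no-op: s1 becomes an
  // Identity and both of its former const inputs turn into control
  // dependencies.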
  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("split_dim", "Const", {}, {}, &want);
  AddNode("size_splits1", "Const", {}, {}, &want);
  AddNode("size_splits2", "Const", {}, {}, &want);
  AddNode("s1", "Identity",
          {"in1", AsControlDependency("size_splits1"),
           AsControlDependency("split_dim")},
          {}, &want);
  AddNode("s2", "SplitV", {"in2", "size_splits2", "split_dim"}, {}, &want);
  AddNode("out", "Add", {"s1", "s2"}, {}, &want);

  CompareGraphs(want, got);

  auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
  auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  EXPECT_EQ(1, tensors_expected.size());
  auto tensors =
      EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}

TEST_F(ConstantFoldingTest, TransposeOnSize1DimsRemoval) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
                             DT_FLOAT);
  Output p1 = ops::Const(scope.WithOpName("p1"), {3, 2, 1, 0}, {4});
  Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 4, 2, 1}),
                             DT_FLOAT);
  Output p2 = ops::Const(scope.WithOpName("p2"), {3, 1, 2, 0}, {4});
  ops::Transpose t1(scope.WithOpName("t1"), in1, p1);
  ops::Transpose t2(scope.WithOpName("t2").WithControlDependencies({in1}), in2,
                    p2);

  ops::Add out1(scope.WithOpName("out1"), t1, t2);

  GrapplerItem item;
  item.fetch = {"out1"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

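  // t2's permutation {3, 1, 2, 0} only moves the two size-1 dimensions of
  // in2, so it is semantically an Identity; t1 actually swaps the size-2 and
  // size-4 dimensions and must be kept.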
  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("p1", "Const", {}, {}, &want);
  AddNode("p2", "Const", {}, {}, &want);
  AddNode("t1", "Transpose", {"in1", "p1"}, {}, &want);
  AddNode("t2", "Identity",
          {"in2", AsControlDependency("in1"), AsControlDependency("p2")}, {},
          &want);
  AddNode("out1", "Add", {"t1", "t2"}, {}, &want);

  CompareGraphs(want, got);
}

TEST_F(ConstantFoldingTest, RandomShuffleOnScalarRemoval) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 =
      ops::Variable(scope.WithOpName("in1"), TensorShape({}), DT_FLOAT);
  Output in2 =
      ops::Variable(scope.WithOpName("in2"), TensorShape({}), DT_FLOAT);
  ops::RandomShuffle s1(scope.WithOpName("s1"), in1);
  ops::RandomShuffle s2(scope.WithOpName("s2").WithControlDependencies({in1}),
                        in2);

  ops::Add out1(scope.WithOpName("out1"), s1, s2);
  ops::Identity out2(scope.WithOpName("out2"), s2);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

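  // Shuffling a scalar is a no-op, so both RandomShuffle nodes reduce to
  // Identities (s2 keeps its control dependency on in1).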
  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("s1", "Identity", {"in1"}, {}, &want);
  AddNode("s2", "Identity", {"in2", AsControlDependency("in1")}, {}, &want);
  AddNode("out1", "Add", {"s1", "s2"}, {}, &want);
  AddNode("out2", "Identity", {"s2"}, {}, &want);

  CompareGraphs(want, got);

  auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
  auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  EXPECT_EQ(2, tensors_expected.size());
  auto tensors =
      EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  EXPECT_EQ(2, tensors.size());
  for (int i = 0; i < tensors.size(); i++)
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
}

TEST_F(ConstantFoldingTest, ReverseOnSize1DimsRemoval) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
                             DT_FLOAT);
  Output a1 = ops::Const(scope.WithOpName("a1"), {3, 2, 1, 0}, {4});
  Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 2, 4, 1}),
                             DT_FLOAT);
  Output a2 = ops::Const(scope.WithOpName("a2"), {0, 3}, {2});
  ops::Reverse r1(scope.WithOpName("r1"), in1, a1);
  ops::Reverse r2(scope.WithOpName("r2").WithControlDependencies({in1}), in2,
                  a2);

  ops::Add out1(scope.WithOpName("out1"), r1, r2);

  GrapplerItem item;
  item.fetch = {"out1"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

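  // r2 only reverses axes 0 and 3, both of size 1, so it reduces to an
  // Identity; r1 reverses axes of size 2 and 4 and stays a ReverseV2.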
  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("a1", "Const", {}, {}, &want);
  AddNode("a2", "Const", {}, {}, &want);
  AddNode("r1", "ReverseV2", {"in1", "a1"}, {}, &want);
  AddNode("r2", "Identity",
          {"in2", AsControlDependency("in1"), AsControlDependency("a2")}, {},
          &want);
  AddNode("out1", "Add", {"r1", "r2"}, {}, &want);

  CompareGraphs(want, got);
}

TEST_F(ConstantFoldingTest, SliceWithSameDimensionRemoval) {
  {  // size = {3, 5}
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

    auto in1 = ops::Variable(scope.WithOpName("in1"), {3, 5}, DT_FLOAT);
    auto begin = ops::Const(scope.WithOpName("begin"), {0, 0}, {2});
    auto size = ops::Const(scope.WithOpName("size"), {3, 5}, {2});
    Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
    ops::Slice s1(scope.WithOpName("s1"), in1, begin, size);
    ops::Slice s2(scope.WithOpName("s2"), in2, begin, size);

    ops::Add out(scope.WithOpName("out"), s1, s2);

    GrapplerItem item;
    item.fetch = {"out"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef got;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
    TF_EXPECT_OK(status);

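    // s1 reads all of in1 (begin {0, 0} with size equal to the full shape
    // {3, 5}), so it reduces to an Identity; s2's output is smaller than in2
    // and must stay a Slice.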
    GraphDef want;
    AddNode("in1", "VariableV2", {}, {}, &want);
    AddNode("in2", "VariableV2", {}, {}, &want);
    AddNode("begin", "Const", {}, {}, &want);
    AddNode("size", "Const", {}, {}, &want);
    AddNode("s1", "Identity",
            {"in1", AsControlDependency("begin"), AsControlDependency("size")},
            {}, &want);
    AddNode("s2", "Slice", {"in2", "begin", "size"}, {}, &want);
    AddNode("out", "Add", {"s1", "s2"}, {}, &want);

    CompareGraphs(want, got);

    auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5}));
    auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors_expected.size());
    auto tensors =
        EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
  }
  {  // size = {-1, -1}
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

    auto in1 =
        ops::Variable(scope.WithOpName("in1"), {3, 5}, DataType::DT_FLOAT);
    auto begin1 = ops::Const(scope.WithOpName("begin1"), {0, 0}, {2});
    auto begin2 = ops::Const(scope.WithOpName("begin2"), {1, 1}, {2});
    auto size = ops::Const(scope.WithOpName("size"), {-1, -1}, {2});
    Output in2 =
        ops::Variable(scope.WithOpName("in2"), {4, 6}, DataType::DT_FLOAT);
    ops::Slice s1(scope.WithOpName("s1"), in1, begin1, size);
    ops::Slice s2(scope.WithOpName("s2"), in2, begin2, size);

    ops::Add out(scope.WithOpName("out"), s1, s2);

    GrapplerItem item;
    item.fetch = {"out"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef got;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
    TF_EXPECT_OK(status);

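    // A size of -1 means "to the end of the dimension", so s1 (begin {0, 0})
    // again covers all of in1, while s2 (begin {1, 1}) is a real slice.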
    GraphDef want;
    AddNode("in1", "VariableV2", {}, {}, &want);
    AddNode("in2", "VariableV2", {}, {}, &want);
    AddNode("begin1", "Const", {}, {}, &want);
    AddNode("begin2", "Const", {}, {}, &want);
    AddNode("size", "Const", {}, {}, &want);
    AddNode("s1", "Identity",
            {"in1", AsControlDependency("begin1"), AsControlDependency("size")},
            {}, &want);
    AddNode("s2", "Slice", {"in2", "begin2", "size"}, {}, &want);
    AddNode("out", "Add", {"s1", "s2"}, {}, &want);

    CompareGraphs(want, got);

    auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5}));
    auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors_expected.size());
    auto tensors =
        EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
  }
}

TEST_F(ConstantFoldingTest, StridedSliceWithSameDimensionRemoval) {
  {  // no mask
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

    auto in1 = ops::Variable(scope.WithOpName("in1"), {3, 5, 2}, DT_FLOAT);
    auto begin = ops::Const(scope.WithOpName("begin"), {0, 0}, {2});
    auto end = ops::Const(scope.WithOpName("end"), {3, 5}, {2});
    auto strides = ops::Const(scope.WithOpName("strides"), {1, 1}, {2});
    Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6, 2}, DT_FLOAT);
    ops::StridedSlice s1(scope.WithOpName("s1"), in1, begin, end, strides);
    ops::StridedSlice s2(scope.WithOpName("s2"), in2, begin, end, strides);

    ops::Add out(scope.WithOpName("out"), s1, s2);

    GrapplerItem item;
    item.fetch = {"out"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef got;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
    TF_EXPECT_OK(status);

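    // s1 spans the full extent of the first two dimensions and leaves the
    // third untouched, so it is an Identity; s2 genuinely crops in2.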
    GraphDef want;
    AddNode("in1", "VariableV2", {}, {}, &want);
    AddNode("in2", "VariableV2", {}, {}, &want);
    AddNode("begin", "Const", {}, {}, &want);
    AddNode("end", "Const", {}, {}, &want);
    AddNode("strides", "Const", {}, {}, &want);
    AddNode("s1", "Identity",
            {"in1", AsControlDependency("begin"), AsControlDependency("end"),
             AsControlDependency("strides")},
            {}, &want);
    AddNode("s2", "StridedSlice", {"in2", "begin", "end", "strides"}, {},
            &want);
    AddNode("out", "Add", {"s1", "s2"}, {}, &want);

    CompareGraphs(want, got);

    auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 2}));
    auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6, 2}));
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors_expected.size());
    auto tensors =
        EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
  }
  {  // with begin/end/ellipsis mask
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

    // s1 = in1[:, ..., 0:5, 0:6]
    auto in1 =
        ops::Variable(scope.WithOpName("in1"), {2, 3, 4, 5, 6}, DT_FLOAT);
    auto begin1 = ops::Const(scope.WithOpName("begin1"), {0, 0, 0}, {3});
    auto end1 = ops::Const(scope.WithOpName("end1"), {0, 5, 6}, {3});
    auto strides1 = ops::Const(scope.WithOpName("strides1"), {1, 1, 1}, {3});
    ops::StridedSlice s1(
        scope.WithOpName("s1"), in1, begin1, end1, strides1,
        ops::StridedSlice::Attrs().BeginMask(1).EndMask(1).EllipsisMask(2));

    Output in2 =
        ops::Variable(scope.WithOpName("in2"), {5, 8, 5, 6, 9}, DT_FLOAT);
    auto begin2 = ops::Const(scope.WithOpName("begin2"), {0, 0, 0, 0, 0}, {5});
    auto end2 = ops::Const(scope.WithOpName("end2"), {2, 3, 4, 5, 6}, {5});
    auto strides2 =
        ops::Const(scope.WithOpName("strides2"), {1, 1, 1, 1, 1}, {5});
    ops::StridedSlice s2(scope.WithOpName("s2"), in2, begin2, end2, strides2);

    ops::Add out(scope.WithOpName("out"), s1, s2);

    GrapplerItem item;
    item.fetch = {"out"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef got;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
    TF_EXPECT_OK(status);

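    // With the begin/end/ellipsis masks applied, s1 is in1[:, ..., 0:5, 0:6];
    // given in1's shape {2, 3, 4, 5, 6} those ranges are full, so s1 is an
    // Identity. s2 crops every dimension of in2 and must stay a StridedSlice.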
    GraphDef want;
    AddNode("in1", "VariableV2", {}, {}, &want);
    AddNode("in2", "VariableV2", {}, {}, &want);
    AddNode("begin1", "Const", {}, {}, &want);
    AddNode("end1", "Const", {}, {}, &want);
    AddNode("strides1", "Const", {}, {}, &want);
    AddNode("s1", "Identity",
            {"in1", AsControlDependency("begin1"), AsControlDependency("end1"),
             AsControlDependency("strides1")},
            {}, &want);
    AddNode("begin2", "Const", {}, {}, &want);
    AddNode("end2", "Const", {}, {}, &want);
    AddNode("strides2", "Const", {}, {}, &want);
    AddNode("s2", "StridedSlice", {"in2", "begin2", "end2", "strides2"}, {},
            &want);
    AddNode("out", "Add", {"s1", "s2"}, {}, &want);

    CompareGraphs(want, got);

    auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3, 4, 5, 6}));
    auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 8, 5, 6, 9}));
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors_expected.size());
    auto tensors =
        EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
  }
}

2486 TEST_F(ConstantFoldingTest, TileWithMultipliesBeingOne) {
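  // Tiling with all multiples equal to one is a no-op: t1 should be rewritten
  // to an Identity of in1 that keeps "multiplies1" as a control dependency,
  // while t2, whose multiples are not all one, must remain a Tile.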
2487   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2488 
2489   auto in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
2490   auto in2 = ops::Variable(scope.WithOpName("in2"), {4, 3}, DT_FLOAT);
2491   auto multiplies1 = ops::Const(scope.WithOpName("multiplies1"), {1, 1}, {2});
2492   auto multiplies2 = ops::Const(scope.WithOpName("multiplies2"), {1, 2}, {2});
2493 
2494   ops::Tile t1(scope.WithOpName("t1"), in1, multiplies1);
2495   ops::Tile t2(scope.WithOpName("t2"), in2, multiplies2);
2496 
2497   ops::Add out(scope.WithOpName("out"), t1, t2);
2498 
2499   GrapplerItem item;
2500   item.fetch = {"out"};
2501   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2502 
2503   ConstantFolding optimizer(/*cpu_device=*/nullptr);
2504   GraphDef got;
2505   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2506   TF_EXPECT_OK(status);
2507 
2508   GraphDef want;
2509   AddNode("in1", "VariableV2", {}, {}, &want);
2510   AddNode("in2", "VariableV2", {}, {}, &want);
2511   AddNode("multiplies1", "Const", {}, {}, &want);
2512   AddNode("multiplies2", "Const", {}, {}, &want);
2513   AddNode("t1", "Identity", {"in1", AsControlDependency("multiplies1")}, {},
2514           &want);
2515   AddNode("t2", "Tile", {"in2", "multiplies2"}, {}, &want);
2516   AddNode("out", "Add", {"t1", "t2"}, {}, &want);
2517 
2518   CompareGraphs(want, got);
2519 }
2520 
2521 TEST_F(ConstantFoldingTest, MergeConcat) {
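  // Two ConcatV2 nodes along the same axis, where the inner one feeds the
  // outer one, should be merged into a single ConcatV2 over all inputs.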
2522   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2523 
2524   Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
2525   Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
2526   Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2527   Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
2528 
2529   ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
2530   ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
2531 
2532   GrapplerItem item;
2533   item.fetch = {"c2"};
2534   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2535 
2536   ConstantFolding optimizer(/*cpu_device=*/nullptr);
2537   GraphDef got;
2538   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2539   TF_EXPECT_OK(status);
2540 
2541   GraphDef want;
2542   AddNode("in1", "VariableV2", {}, {}, &want);
2543   AddNode("in2", "VariableV2", {}, {}, &want);
2544   AddNode("in3", "VariableV2", {}, {}, &want);
2545   AddNode("axis", "Const", {}, {}, &want);
2546   AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
2547 
2548   CompareGraphs(want, got);
2549 }
2550 
2551 TEST_F(ConstantFoldingTest, MergeConcat_SameInput) {
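  // The inner concat feeds the outer concat twice; each occurrence should be
  // expanded in place, so the merged node lists in1 and in2 twice.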
2552   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2553 
2554   Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
2555   Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
2556   Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2557   Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
2558 
2559   ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
2560   ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3, Output(c1)}, axis);
2561 
2562   GrapplerItem item;
2563   item.fetch = {"c2"};
2564   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2565 
2566   ConstantFolding optimizer(/*cpu_device=*/nullptr);
2567   GraphDef got;
2568   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2569   TF_EXPECT_OK(status);
2570 
2571   GraphDef want;
2572   AddNode("in1", "VariableV2", {}, {}, &want);
2573   AddNode("in2", "VariableV2", {}, {}, &want);
2574   AddNode("in3", "VariableV2", {}, {}, &want);
2575   AddNode("axis", "Const", {}, {}, &want);
2576   AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "in1", "in2", "axis"}, {},
2577           &want);
2578 
2579   CompareGraphs(want, got);
2580 }
2581 
2582 TEST_F(ConstantFoldingTest, MergeConcat_ConcatWithConst) {
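  // The merge should still fire here: the nested concats c1 and c2 collapse
  // into a single ConcatV2 over all three inputs.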
2583   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2584 
2585   Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 6}, DT_FLOAT);
2586   Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
2587   Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2588   Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
2589 
2590   ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
2591   ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
2592 
2593   GrapplerItem item;
2594   item.fetch = {"c2"};
2595   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2596 
2597   ConstantFolding optimizer(/*cpu_device=*/nullptr);
2598   GraphDef got;
2599   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2600   TF_EXPECT_OK(status);
2601 
2602   GraphDef want;
2603   AddNode("in1", "VariableV2", {}, {}, &want);
2604   AddNode("in2", "VariableV2", {}, {}, &want);
2605   AddNode("in3", "VariableV2", {}, {}, &want);
2606   AddNode("axis", "Const", {}, {}, &want);
2607   AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
2608 
2609   CompareGraphs(want, got);
2610 }
2611 
2612 TEST_F(ConstantFoldingTest, MergeConcat_AxisMismatch) {
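  // The two concats concatenate along different axes, so they must not be
  // merged; both ConcatV2 nodes are expected to survive unchanged.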
2613   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2614 
2615   Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 5}, DT_FLOAT);
2616   Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
2617   Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
2618   Output axis1 = ops::Const(scope.WithOpName("axis1"), 0, {});
2619   Output axis2 = ops::Const(scope.WithOpName("axis2"), 1, {});
2620 
2621   ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis2);
2622   ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis1);
2623 
2624   GrapplerItem item;
2625   item.fetch = {"c2"};
2626   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2627 
2628   ConstantFolding optimizer(/*cpu_device=*/nullptr);
2629   GraphDef got;
2630   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2631   TF_EXPECT_OK(status);
2632 
2633   GraphDef want;
2634   AddNode("in1", "VariableV2", {}, {}, &want);
2635   AddNode("in2", "VariableV2", {}, {}, &want);
2636   AddNode("in3", "VariableV2", {}, {}, &want);
2637   AddNode("axis1", "Const", {}, {}, &want);
2638   AddNode("axis2", "Const", {}, {}, &want);
2639   AddNode("c1", "ConcatV2", {"in1", "in2", "axis2"}, {}, &want);
2640   AddNode("c2", "ConcatV2", {"c1", "in3", "axis1"}, {}, &want);
2641 
2642   CompareGraphs(want, got);
2643 }
2644 
2645 TEST_F(ConstantFoldingTest, MergeConcat_PartialFolding) {
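  // After merging the nested concats, the leading run of constant inputs
  // (c3, c4, c1, c2) folds into a single generated Const, leaving one
  // ConcatV2 over that Const and the placeholder.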
2646   Scope scope = Scope::NewRootScope();
2647   Output c1 = ops::Const(scope.WithOpName("c1"), 1.0f, {2, 2});
2648   Output c2 = ops::Const(scope.WithOpName("c2"), 2.0f, {2, 2});
2649   Output c3 = ops::Const(scope.WithOpName("c3"), 3.0f, {2, 2});
2650   Output c4 = ops::Const(scope.WithOpName("c4"), 4.0f, {2, 2});
2651   Output ph = ops::Placeholder(scope.WithOpName("ph"), DT_FLOAT,
2652                                ops::Placeholder::Shape(TensorShape({2, 2})));
2653   Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
2654 
2655   ops::Concat concat1(scope.WithOpName("concat1"), {c1, c2, ph}, axis);
2656   ops::Concat concat2(scope.WithOpName("concat2"), {c3, c4, Output(concat1)},
2657                       axis);
2658 
2659   GrapplerItem item;
2660   item.fetch = {"concat2"};
2661   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2662 
2663   ConstantFolding optimizer(nullptr);
2664   GraphDef got;
2665   Status status = optimizer.Optimize(nullptr, item, &got);
2666   TF_EXPECT_OK(status);
2667 
2668   GraphDef want;
2669   AddNode("ConstantFolding/concat2_partial_split_0", "Const", {}, {}, &want);
2670   AddNode("axis", "Const", {}, {}, &want);
2671   AddNode("ph", "Placeholder", {}, {}, &want);
2672   AddNode("concat2", "ConcatV2",
2673           {"ConstantFolding/concat2_partial_split_0", "ph", "axis"}, {}, &want);
2674 
2675   CompareGraphs(want, got);
2676 }
2677 
2678 TEST_F(ConstantFoldingTest, PaddingWithZeroSize) {
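  // Exercise the zero-size padding rewrite for both 32-bit and 64-bit
  // padding index types.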
2679   PaddingWithZeroSize<int32>();
2680   PaddingWithZeroSize<int64_t>();
2681 }
2682 
2683 TEST_F(ConstantFoldingTest, SqueezeWithAllDimensionsGreaterThanOne) {
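  // Squeezing a tensor with no size-one dimensions is a no-op: s1 should
  // become an Identity of in1, while s2, whose input has size-one dimensions,
  // must remain a Squeeze.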
2684   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2685 
2686   auto in1 = ops::Variable(scope.WithOpName("in1"), {2, 3}, DT_INT32);
2687   auto in2 = ops::Variable(scope.WithOpName("in2"), {1, 2, 3, 1}, DT_INT32);
2688 
2689   ops::Squeeze s1(scope.WithOpName("s1"), in1);
2690   ops::Squeeze s2(scope.WithOpName("s2"), in2);
2691 
2692   ops::Add out(scope.WithOpName("out"), s1, s2);
2693 
2694   GrapplerItem item;
2695   item.fetch = {"out"};
2696   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2697 
2698   ConstantFolding optimizer(/*cpu_device=*/nullptr);
2699   GraphDef got;
2700   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2701   TF_EXPECT_OK(status);
2702 
2703   GraphDef want;
2704   AddNode("in1", "VariableV2", {}, {}, &want);
2705   AddNode("in2", "VariableV2", {}, {}, &want);
2706   AddNode("s1", "Identity", {"in1"}, {}, &want);
2707   AddNode("s2", "Squeeze", {"in2"}, {}, &want);
2708   AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2709 
2710   CompareGraphs(want, got);
2711 
2712   auto in1_t = GenerateRandomTensor<DT_INT32>(TensorShape({2, 3}));
2713   auto in2_t = GenerateRandomTensor<DT_INT32>(TensorShape({1, 2, 3, 1}));
2714   auto tensors_expected =
2715       EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2716   EXPECT_EQ(1, tensors_expected.size());
2717   auto tensors =
2718       EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2719   EXPECT_EQ(1, tensors.size());
2720   test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
2721 }
2722 
2723 TEST_F(ConstantFoldingTest, NoOpReduction) {
2724   // Build a simple graph with reductions that can be reduced to the
2725   // identity.
2726   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2727 
2728   Output v = ops::Variable(scope.WithOpName("v"), {3, 5, 7}, DT_FLOAT);
2729   Output c =
2730       ops::Const(scope.WithOpName("c").WithControlDependencies(v), 0, {0});
2731   Output i = ops::Identity(scope.WithOpName("i"), c);
2732   Output p = ops::Prod(scope.WithOpName("p"), v, i);
2733   Output s = ops::Square(scope.WithOpName("s"), p);
2734 
2735   Output v2 = ops::Variable(scope.WithOpName("v2"), {3, 5, 1}, DT_FLOAT);
2736   Output c2 =
2737       ops::Const(scope.WithOpName("c2").WithControlDependencies(v), 2, {1});
2738   ops::Prod::Attrs attr;
2739   attr = attr.KeepDims(true);
2740   Output p2 = ops::Prod(scope.WithOpName("p2"), v2, c2, attr);
2741 
2742   // Test with unknown input shape.
2743   Output a = ops::Placeholder(scope.WithOpName("a"), DT_FLOAT);
2744   Output p3 = ops::Prod(scope.WithOpName("p3"), a, i, attr);
2745 
2746   GrapplerItem item;
2747   item.fetch = {"s", "p2", "p3"};
2748   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2749 
2750   ConstantFolding optimizer(/*cpu_device=*/nullptr);
2751   GraphDef output;
2752   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2753   TF_EXPECT_OK(status);
2754 
2755   int found = 0;
2756   for (const auto& node : output.node()) {
2757     if (node.name() == "p") {
2758       found++;
2759       EXPECT_EQ("Identity", node.op());
2760       EXPECT_EQ(2, node.input_size());
2761       EXPECT_EQ("v", node.input(0));
2762       EXPECT_EQ("^i", node.input(1));
2763     } else if (node.name() == "p2") {
2764       found++;
2765       EXPECT_EQ("Identity", node.op());
2766       EXPECT_EQ(2, node.input_size());
2767       EXPECT_EQ("v2", node.input(0));
2768       EXPECT_EQ("^c2", node.input(1));
2769     } else if (node.name() == "p3") {
2770       found++;
2771       EXPECT_EQ("Identity", node.op());
2772       EXPECT_EQ(2, node.input_size());
2773       EXPECT_EQ("a", node.input(0));
2774       EXPECT_EQ("^i", node.input(1));
2775     }
2776   }
2777   EXPECT_EQ(3, found);
2778 
2779   auto v_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7}));
2780   auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 1}));
2781   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7}));
2782   auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
2783                                         {{"v", v_t}, {"v2", v2_t}, {"a", a_t}});
2784   EXPECT_EQ(3, tensors_expected.size());
2785   auto tensors =
2786       EvaluateNodes(output, item.fetch, {{"v", v_t}, {"v2", v2_t}, {"a", a_t}});
2787   EXPECT_EQ(3, tensors.size());
2788   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2789   test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
2790   test::ExpectTensorNear<float>(tensors_expected[2], tensors[2], 1e-5);
2791 }
2792 
2793 TEST_F(ConstantFoldingTest, SingleElementEmptyAxisReduction) {
2794   // Build a simple graph with reductions that involve single-element input and
2795   // no axes to reduce along.
2796   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2797 
2798   Output input_var_three_dim = ops::Variable(
2799       scope.WithOpName("input_var_three_dim"), {1, 1, 1}, DT_FLOAT);
2800   Output input_var_one_dim =
2801       ops::Variable(scope.WithOpName("input_var_one_dim"), {1}, DT_FLOAT);
2802   Output one_axis = ops::Const(scope.WithOpName("one_axis"), {0}, {1});
2803   Output multiple_axes =
2804       ops::Const(scope.WithOpName("multiple_axes"), {1, 0}, {2});
2805   Output variable_axis =
2806       ops::Variable(scope.WithOpName("input_var_axis"), {1}, DT_INT32);
2807   ops::Mean::Attrs attr;
2808   attr = attr.KeepDims(false);
2809   // Should be optimized to Reshape.
2810   Output mean_1 = ops::Mean(scope.WithOpName("mean_1"), input_var_three_dim,
2811                             one_axis, attr.KeepDims(false));
2812   Output mean_2 = ops::Mean(scope.WithOpName("mean_2"), input_var_three_dim,
2813                             multiple_axes, attr.KeepDims(false));
2814   // Should remain as-is, since OutputProperties will not be known for this node.
2815   Output mean_3 = ops::Mean(scope.WithOpName("mean_3"), input_var_one_dim,
2816                             one_axis, attr.KeepDims(false));
2817   // Should remain as-is.
2818   Output mean_4 = ops::Mean(scope.WithOpName("mean_4"), input_var_three_dim,
2819                             variable_axis, attr.KeepDims(false));
2820   // Should be optimized to Identity, since KeepDims=true.
2821   Output mean_5 = ops::Mean(scope.WithOpName("mean_5"), input_var_three_dim,
2822                             multiple_axes, attr.KeepDims(true));
2823 
2824   GrapplerItem item;
2825   item.fetch = {"mean_1", "mean_2", "mean_3", "mean_4", "mean_5"};
2826   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2827 
2828   ConstantFolding optimizer(/*cpu_device=*/nullptr);
2829   GraphDef output;
2830   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2831   TF_EXPECT_OK(status);
2832 
2833   // Ensure Mean node is optimized to Reshape.
2834   int found = 0;
2835   for (const auto& node : output.node()) {
2836     if (node.name() == "mean_1" || node.name() == "mean_2") {
2837       found++;
2838       EXPECT_EQ("Reshape", node.op());
2839       EXPECT_EQ(2, node.input_size());
2840       EXPECT_EQ("input_var_three_dim", node.input(0));
2841     } else if (node.name() == "mean_3") {
2842       found++;
2843       EXPECT_EQ("Mean", node.op());
2844       EXPECT_EQ(2, node.input_size());
2845       EXPECT_EQ("input_var_one_dim", node.input(0));
2846     } else if (node.name() == "mean_4") {
2847       found++;
2848       EXPECT_EQ("Mean", node.op());
2849       EXPECT_EQ(2, node.input_size());
2850       EXPECT_EQ("input_var_three_dim", node.input(0));
2851     } else if (node.name() == "mean_5") {
2852       found++;
2853       EXPECT_EQ("Identity", node.op());
2854       EXPECT_EQ(2, node.input_size());
2855       EXPECT_EQ("^multiple_axes", node.input(1));
2856     }
2857   }
2858   EXPECT_EQ(5, found);
2859 
2860   // Ensure resultant values from Mean and Reshape are the same.
2861   auto input_var_three_dim_t =
2862       GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
2863   auto input_var_one_dim_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1}));
2864   Tensor input_var_axis_t(DT_INT32, TensorShape({1}));
2865   input_var_axis_t.flat<int32>()(0) = 0;
2866   auto tensors_expected =
2867       EvaluateNodes(item.graph, item.fetch,
2868                     {{"input_var_three_dim", input_var_three_dim_t},
2869                      {"input_var_one_dim", input_var_one_dim_t},
2870                      {"input_var_axis", input_var_axis_t}});
2871   EXPECT_EQ(5, tensors_expected.size());
2872   auto tensors = EvaluateNodes(output, item.fetch,
2873                                {{"input_var_three_dim", input_var_three_dim_t},
2874                                 {"input_var_one_dim", input_var_one_dim_t},
2875                                 {"input_var_axis", input_var_axis_t}});
2876   EXPECT_EQ(5, tensors.size());
2877   for (int i = 0; i < 5; ++i) {
2878     test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2879   }
2880 }
2881 
2882 TEST_F(ConstantFoldingTest, NoOpReshape) {
2883   // Build a simple graph with a reshape that can be reduced to the identity.
2884   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2885 
2886   // A reshape that can be optimized
2887   Output d1 = ops::Const(scope.WithOpName("d1"), 3.14f, {17});
2888   Output v1 = ops::Variable(scope.WithOpName("v1"), {17}, DT_FLOAT);
2889   Output c1 =
2890       ops::Const(scope.WithOpName("c1").WithControlDependencies(v1), 17, {1});
2891   Output i1 = ops::Identity(scope.WithOpName("i1"), c1);
2892   Output r1 =
2893       ops::Reshape(scope.WithOpName("r1").WithControlDependencies(d1), v1, i1);
2894   Output s1 = ops::Square(scope.WithOpName("s1"), r1);
2895 
2896   // A multi-dimensional reshape that can be optimized
2897   Output v3 = ops::Variable(scope.WithOpName("v3"), {5, 5, 5}, DT_FLOAT);
2898   Output c3 =
2899       ops::Const(scope.WithOpName("c3").WithControlDependencies(v3), 5, {3});
2900   Output i3 = ops::Identity(scope.WithOpName("i3"), c3);
2901   Output r3 = ops::Reshape(scope.WithOpName("r3"), v3, i3);
2902   Output s3 = ops::Square(scope.WithOpName("s3"), r3);
2903 
2904   // A multi-dimensional, partially defined reshape that can be optimized
2905   Output v4 = ops::Variable(scope.WithOpName("v4"), {5, 5, 5}, DT_FLOAT);
2906   Output c4 = ops::Const(scope.WithOpName("c4").WithControlDependencies(v4),
2907                          {5, -1, 5}, {3});
2908   Output i4 = ops::Identity(scope.WithOpName("i4"), c4);
2909   Output r4 = ops::Reshape(scope.WithOpName("r4"), v4, i4);
2910   Output s4 = ops::Square(scope.WithOpName("s4"), r4);
2911 
2912   // A reshape that can't be optimized
2913   Output v2 = ops::Variable(scope.WithOpName("v2"), {17, 1}, DT_FLOAT);
2914   Output c2 =
2915       ops::Const(scope.WithOpName("c2").WithControlDependencies(v2), 17, {1});
2916   Output r2 = ops::Reshape(scope.WithOpName("r2"), v2, c2);
2917   Output s2 = ops::Square(scope.WithOpName("s2"), r2);
2918 
2919   GrapplerItem item;
2920   item.fetch = {"s1", "s2", "s3", "s4"};
2921   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2922 
2923   ConstantFolding optimizer(/*cpu_device=*/nullptr);
2924   GraphDef output;
2925   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2926   TF_EXPECT_OK(status);
2927 
2928   int found = 0;
2929   for (const auto& node : output.node()) {
2930     if (node.name() == "r1") {
2931       ++found;
2932       EXPECT_EQ("Identity", node.op());
2933       ASSERT_EQ(3, node.input_size());
2934       EXPECT_EQ("v1", node.input(0));
2935       EXPECT_EQ("^i1", node.input(1));
2936       EXPECT_EQ("^d1", node.input(2));
2937     } else if (node.name() == "r3") {
2938       ++found;
2939       EXPECT_EQ("Identity", node.op());
2940       ASSERT_EQ(2, node.input_size());
2941       EXPECT_EQ("v3", node.input(0));
2942       EXPECT_EQ("^i3", node.input(1));
2943     } else if (node.name() == "r4") {
2944       ++found;
2945       EXPECT_EQ("Identity", node.op());
2946       ASSERT_EQ(2, node.input_size());
2947       EXPECT_EQ("v4", node.input(0));
2948       EXPECT_EQ("^i4", node.input(1));
2949     } else if (node.name() == "r2") {
2950       ++found;
2951       EXPECT_EQ("Reshape", node.op());
2952       ASSERT_EQ(2, node.input_size());
2953       EXPECT_EQ("v2", node.input(0));
2954       EXPECT_EQ("c2", node.input(1));
2955     }
2956   }
2957   EXPECT_EQ(4, found);
2958 
2959   auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({17}));
2960   auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({17, 1}));
2961   auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 5, 5}));
2962   auto v4_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 5, 5}));
2963   auto tensors_expected =
2964       EvaluateNodes(item.graph, item.fetch,
2965                     {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}, {"v4", v4_t}});
2966   EXPECT_EQ(4, tensors_expected.size());
2967   auto tensors =
2968       EvaluateNodes(output, item.fetch,
2969                     {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}, {"v4", v4_t}});
2970   EXPECT_EQ(4, tensors.size());
2971   for (int i = 0; i < tensors.size(); i++)
2972     test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2973 }
2974 
2975 TEST_F(ConstantFoldingTest, Packing) {
2976   // Build a simple graph with a large constant that can be folded.
2977   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2978   Output c = ops::Const(scope.WithOpName("c"), 3.14f, {1000});
2979   Output i1 = ops::Identity(scope.WithOpName("i1"), c);
2980   Output i2 = ops::Identity(scope.WithOpName("i2"), c);
2981 
2982   GrapplerItem item;
2983   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2984 
2985   ConstantFolding optimizer(/*cpu_device=*/nullptr);
2986   GraphDef output;
2987   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2988   TF_EXPECT_OK(status);
2989 
2990   const std::vector<string> fetch_nodes = {"i1", "i2"};
2991   auto tensors_expected = EvaluateNodes(item.graph, fetch_nodes);
2992   EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
2993   auto tensors = EvaluateNodes(output, fetch_nodes);
2994   EXPECT_EQ(fetch_nodes.size(), tensors.size());
2995   for (int i = 0; i < fetch_nodes.size(); i++)
2996     test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2997 
2998   // Make sure that the representation of the folded constant is space
2999   // efficient: in particular, the whole message should be smaller than 8k
3000   // (the size needed to naively encode 1000 floats folded twice).
3001   EXPECT_GT(8000, output.ByteSizeLong());
3002 }
3003 
3004 TEST_F(ConstantFoldingTest, LargeConstantNoSizeIncrease) {
3005   // Build a simple graph with a large constant with size greater than
3006   // kMaxConstantSize that can be folded because the resulting size does not
3007   // increase.
3008   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3009   const int64_t large_constant_size = kMaxConstantSize + 1;
3010   Output a = ops::Variable(scope.WithOpName("a"), {1, 1}, DT_FLOAT);
3011   Output b_const =
3012       ops::Const(scope.WithOpName("b_const"), 3.14f, {1, large_constant_size});
3013   Output b = ops::Identity(scope.WithOpName("b"), b_const);
3014   Output matmul = ops::MatMul(scope.WithOpName("matmul"), a, b);
3015 
3016   GrapplerItem item;
3017   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
3018 
3019   ConstantFolding optimizer(/*cpu_device=*/nullptr);
3020   GraphDef output;
3021   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3022   TF_EXPECT_OK(status);
3023 
3024   item.graph.Swap(&output);
3025   status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3026   TF_EXPECT_OK(status);
3027 
3028   for (const auto& node : output.node()) {
3029     if (node.name() == "b") {
3030       EXPECT_EQ("Const", node.op());
3031     }
3032   }
3033   EXPECT_EQ(4, output.node_size());
3034   EXPECT_LT(output.ByteSizeLong(), sizeof(float) * large_constant_size + 500);
3035 }
3036 
3037 TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) {
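  // d and e are the shapes of a and b = Square(a), which are always equal, so
  // the outputs of f are empty and can be materialized as Const nodes. The
  // second instance, i, mixes an unknown shape with a known one and must be
  // left intact.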
3038   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3039   Output a =
3040       ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
3041                        ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
3042   Output b = ops::Square(s.WithOpName("b"), a);
3043   Output c = ops::Mul(s.WithOpName("c"), a, b);
3044   Output d = ops::Shape(s.WithOpName("d"), a);
3045   Output e = ops::Shape(s.WithOpName("e"), b);
3046 
3047   auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
3048   Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
3049   Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);
3050 
3051   Output g = ops::Placeholder(s.WithOpName("g"), DT_FLOAT,
3052                               ops::Placeholder::Shape(PartialTensorShape({1})));
3053   Output h = ops::Shape(s.WithOpName("h"), g);
3054   auto i = ops::internal::BroadcastGradientArgs(s.WithOpName("i"), d, h);
3055   Output p1 = ops::Identity(s.WithOpName("p1"), i.r0);
3056   Output p2 = ops::Identity(s.WithOpName("p2"), i.r1);
3057 
3058   GrapplerItem item;
3059   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3060 
3061   ConstantFolding optimizer(/*cpu_device=*/nullptr);
3062   GraphDef output;
3063   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3064   TF_EXPECT_OK(status);
3065 
3066   std::vector<string> fetch_nodes = {"o1", "o2", "p1", "p2"};
3067   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 5}));
3068   auto g_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1}));
3069   auto tensors_expected =
3070       EvaluateNodes(item.graph, fetch_nodes, {{"a", a_t}, {"g", g_t}});
3071   EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
3072 
3073   // Run a second time to make sure the optimization is idempotent.
3074   item.graph.Swap(&output);
3075   status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3076   TF_EXPECT_OK(status);
3077 
3078   int found = 0;
3079   for (const auto& node : output.node()) {
3080     if (node.name() == "o1") {
3081       ++found;
3082       EXPECT_EQ(1, node.input_size());
3083       EXPECT_EQ("ConstantFolding/f-bcastargs-0", node.input(0));
3084     } else if (node.name() == "o2") {
3085       ++found;
3086       EXPECT_EQ(1, node.input_size());
3087       EXPECT_EQ("ConstantFolding/f-bcastargs-1", node.input(0));
3088     } else if (node.name() == "ConstantFolding/f-bcastargs-0") {
3089       ++found;
3090       EXPECT_EQ("Const", node.op());
3091       EXPECT_EQ(1, node.input_size());
3092       EXPECT_EQ("^f", node.input(0));
3093       EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
3094                        .num_elements());
3095     } else if (node.name() == "ConstantFolding/f-bcastargs-1") {
3096       ++found;
3097       EXPECT_EQ("Const", node.op());
3098       EXPECT_EQ(1, node.input_size());
3099       EXPECT_EQ("^f", node.input(0));
3100       EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
3101                        .num_elements());
3102     } else if (node.name() == "p1") {
3103       ++found;
3104       EXPECT_EQ(1, node.input_size());
3105       EXPECT_EQ("i", node.input(0));
3106     } else if (node.name() == "p2") {
3107       ++found;
3108       EXPECT_EQ(1, node.input_size());
3109       EXPECT_EQ("i:1", node.input(0));
3110     }
3111   }
3112   EXPECT_EQ(6, found);
3113 
3114   auto tensors = EvaluateNodes(output, fetch_nodes, {{"a", a_t}, {"g", g_t}});
3115   EXPECT_EQ(fetch_nodes.size(), tensors.size());
3116   for (int i = 0; i < fetch_nodes.size(); i++)
3117     test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
3118 }
3119 
3120 TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) {
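  // With fully known input shapes the BroadcastGradientArgs node itself also
  // folds to a Const; running the optimizer a second time below verifies that
  // the materialization does not keep rewriting its own output.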
3121   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3122   Output a =
3123       ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
3124                        ops::Placeholder::Shape(PartialTensorShape({2, 2})));
3125   Output b = ops::Square(s.WithOpName("b"), a);
3126   Output c = ops::Mul(s.WithOpName("c"), a, b);
3127   Output d = ops::Shape(s.WithOpName("d"), a);
3128   Output e = ops::Shape(s.WithOpName("e"), b);
3129 
3130   auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
3131   Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
3132   Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);
3133 
3134   GrapplerItem item;
3135   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3136 
3137   std::vector<string> fetch_nodes = {"o1", "o2"};
3138   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
3139   auto tensors_expected = EvaluateNodes(item.graph, fetch_nodes, {{"a", a_t}});
3140   EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
3141 
3142   ConstantFolding optimizer(/*cpu_device=*/nullptr);
3143   GraphDef output;
3144   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3145   TF_EXPECT_OK(status);
3146 
3147   // Run a second time to make sure the optimization is idempotent.
3148   item.graph.Swap(&output);
3149   status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3150   TF_EXPECT_OK(status);
3151 
3152   EXPECT_EQ(11, output.node_size());
3153   int found = 0;
3154   for (const auto& node : output.node()) {
3155     if (node.name() == "ConstantFolding/f-folded-1") {
3156       ++found;
3157       EXPECT_EQ("Const", node.op());
3158       EXPECT_EQ(2, node.input_size());
3159       EXPECT_EQ("^a", node.input(0));
3160       EXPECT_EQ("^b", node.input(1));
3161     } else if (node.name() == "d") {
3162       ++found;
3163       EXPECT_EQ("Const", node.op());
3164       EXPECT_EQ(1, node.input_size());
3165       EXPECT_EQ("^a", node.input(0));
3166     } else if (node.name() == "e") {
3167       ++found;
3168       EXPECT_EQ("Const", node.op());
3169       EXPECT_EQ(1, node.input_size());
3170       EXPECT_EQ("^b", node.input(0));
3171     } else if (node.name() == "o1") {
3172       ++found;
3173       EXPECT_EQ(1, node.input_size());
3174       EXPECT_EQ("ConstantFolding/f-bcastargs-0", node.input(0));
3175     } else if (node.name() == "o2") {
3176       ++found;
3177       EXPECT_EQ(1, node.input_size());
3178       EXPECT_EQ("ConstantFolding/f-bcastargs-1", node.input(0));
3179     } else if (node.name() == "ConstantFolding/f-bcastargs-0") {
3180       ++found;
3181       EXPECT_EQ("Const", node.op());
3182       EXPECT_EQ(1, node.input_size());
3183       EXPECT_EQ("^ConstantFolding/f-folded-1", node.input(0));
3184       EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
3185                        .num_elements());
3186     } else if (node.name() == "ConstantFolding/f-bcastargs-1") {
3187       ++found;
3188       EXPECT_EQ("Const", node.op());
3189       EXPECT_EQ(1, node.input_size());
3190       EXPECT_EQ("^ConstantFolding/f-folded-1", node.input(0));
3191       EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
3192                        .num_elements());
3193     }
3194   }
3195   EXPECT_EQ(7, found);
3196   auto tensors = EvaluateNodes(output, fetch_nodes, {{"a", a_t}});
3197   EXPECT_EQ(fetch_nodes.size(), tensors.size());
3198   for (int i = 0; i < fetch_nodes.size(); i++)
3199     test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
3200 }
3201 
3202 TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
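  // When a Sum is known to be a full reduction (here via the input rank and
  // either a reshape to one element or the indices' static size), the
  // placeholder reduction indices can be materialized as a Const.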
3203   for (bool use_reshape : {true, false}) {
3204     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3205     Output input =
3206         ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
3207                          ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
3208     // If use_reshape is false, we need to know the number of indices to apply
3209     // the rewrite.
3210     Output indices = ops::Placeholder(
3211         s.WithOpName("indices"), DT_INT32,
3212         ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2})));
3213     Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
3214     if (use_reshape) {
3215       Output size = ops::Const(s.WithOpName("size"), 1, {1});
3216       Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
3217     }
3218 
3219     GrapplerItem item;
3220     TF_CHECK_OK(s.ToGraphDef(&item.graph));
3221     item.fetch.push_back(use_reshape ? "reshape" : "sum");
3222 
3223     auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
3224     Tensor indices_t(DT_INT32, TensorShape({2}));
3225     indices_t.flat<int>()(0) = 0;
3226     indices_t.flat<int>()(1) = 1;
3227     auto tensors_expected = EvaluateNodes(
3228         item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
3229     EXPECT_EQ(1, tensors_expected.size());
3230 
3231     // Use aggressive mode to force the shape inference to propagate placeholder
3232     // shapes.
3233     ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
3234                               /*cpu_device=*/nullptr);
3235     GraphDef output;
3236     Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3237     TF_EXPECT_OK(status);
3238 
3239     // Run a second time to make sure the optimization is idempotent.
3240     item.graph.Swap(&output);
3241     status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3242     TF_EXPECT_OK(status);
3243 
3244     int found = 0;
3245     for (const auto& node : output.node()) {
3246       if (node.name() == "ConstantFolding/sum-reduction_indices") {
3247         ++found;
3248         EXPECT_EQ("Const", node.op());
3249         EXPECT_EQ("^indices", node.input(0));
3250         EXPECT_EQ(2,
3251                   TensorShape(node.attr().at("value").tensor().tensor_shape())
3252                       .num_elements());
3253       } else if (node.name() == "sum") {
3254         ++found;
3255         EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
3256       } else if (node.name() == "indices") {
3257         ++found;
3258       }
3259     }
3260     EXPECT_EQ(3, found);
3261 
3262     auto tensors = EvaluateNodes(output, item.fetch,
3263                                  {{"input", input_t}, {"indices", indices_t}});
3264     EXPECT_EQ(1, tensors.size());
3265     test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
3266   }
3267 }
3268 
3269 TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) {
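  // If the indices cannot be proven to cover every input dimension, they must
  // not be materialized and the graph should be left unchanged.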
3270   for (bool input_rank_known : {true, false}) {
3271     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3272     Output input =
3273         (input_rank_known ? ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
3274                                              ops::Placeholder::Shape(
3275                                                  PartialTensorShape({-1, -1})))
3276                           : ops::Placeholder(s.WithOpName("input"), DT_FLOAT));
3277     Output indices =
3278         ops::Placeholder(s.WithOpName("indices"), DT_INT32,
3279                          ops::Placeholder::Shape(
3280                              PartialTensorShape({input_rank_known ? 1 : 2})));
3281     Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
3282 
3283     GrapplerItem item;
3284     TF_CHECK_OK(s.ToGraphDef(&item.graph));
3285     item.fetch.push_back("sum");
3286 
3287     // Use aggressive mode to force the shape inference to propagate placeholder
3288     // shapes.
3289     ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
3290                               /*cpu_device=*/nullptr);
3291     GraphDef output;
3292     Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3293     TF_EXPECT_OK(status);
3294 
3295     CompareGraphs(item.graph, output);
3296   }
3297 }
3298 
3299 TEST_F(ConstantFoldingTest, LargeConstant) {
3300   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3301   // Generate a 4k-by-4k non-compressible constant matrix.
3302   Output mat_diag =
3303       ops::Const(scope.WithOpName("mat_diag"), 3.14f, TensorShape({1024 * 4}));
3304   Output mat = ops::Diag(scope.WithOpName("mat"), mat_diag);
3305   Output out = ops::Identity(scope.WithOpName("out"), mat);
3306 
3307   GrapplerItem item;
3308   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
3309   item.fetch.push_back("out");
3310 
3311   ConstantFolding optimizer(/*cpu_device=*/nullptr);
3312   GraphDef output;
3313   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3314   TF_EXPECT_OK(status);
3315 
3316   // Make sure the diag node hasn't been folded, since it would use too much
3317   // memory to encode the corresponding constant.
3318   int found = 0;
3319   for (const NodeDef& node : output.node()) {
3320     if (node.name() == "out") {
3321       EXPECT_EQ(node.op(), "Identity");
3322       ASSERT_EQ(node.input_size(), 1);
3323       EXPECT_EQ(node.input(0), "mat");
3324       ++found;
3325     } else if (node.name() == "mat") {
3326       EXPECT_EQ(node.op(), "Diag");
3327       ASSERT_EQ(node.input_size(), 1);
3328       EXPECT_EQ(node.input(0), "mat_diag");
3329       ++found;
3330     }
3331   }
3332   EXPECT_EQ(found, 2);
3333   // The output should be no larger than the size of the constant "mat_diag"
3334   // plus a small constant amount for the remaining nodes.
3335   EXPECT_LT(output.ByteSizeLong(), sizeof(int) * 4 * 1024 + 500);
3336 
3337   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
3338   ASSERT_EQ(tensors_expected.size(), 1);
3339   auto tensors = EvaluateNodes(output, item.fetch);
3340   ASSERT_EQ(tensors.size(), 1);
3341   test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
3342 }
3343 
3344 TEST_F(ConstantFoldingTest, SwitchIdenticalInputs) {
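  // A Switch whose data and predicate are the same boolean tensor produces
  // statically known values on each output port, so the downstream LogicalNot
  // nodes fold to Consts anchored by control dependencies on per-port
  // Identity nodes.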
3345   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3346   Output x = ops::Placeholder(s.WithOpName("x"), DT_BOOL,
3347                               ops::Placeholder::Shape(TensorShape({})));
3348   ops::Switch sw = ops::Switch(s.WithOpName("switch"), x, x);
3349   Output id_false = ops::LogicalNot(s.WithOpName("id_false"), sw.output_false);
3350   Output id_true = ops::LogicalNot(s.WithOpName("id_true"), sw.output_true);
3351 
3352   GrapplerItem item;
3353   item.fetch.push_back("id_false");
3354   item.fetch.push_back("id_true");
3355   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3356 
3357   ConstantFolding optimizer(/*cpu_device=*/nullptr);
3358   GraphDef output;
3359   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3360   TF_EXPECT_OK(status);
3361 
3362   EXPECT_EQ(6, output.node_size());
3363   int found = 0;
3364   for (const auto& node : output.node()) {
3365     if (node.name() == "switch" || node.name() == "x") {
3366       ++found;
3367     }
3368     if (node.name() == "id_false") {
3369       EXPECT_EQ("Const", node.op());
3370       EXPECT_EQ(1, node.input_size());
3371       EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
3372       ++found;
3373     }
3374     if (node.name() == "id_true") {
3375       EXPECT_EQ("Const", node.op());
3376       EXPECT_EQ(1, node.input_size());
3377       EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
3378       ++found;
3379     }
3380     if (node.name() == "ConstantFoldingCtrl/switch_0") {
3381       EXPECT_EQ("Identity", node.op());
3382       EXPECT_EQ(1, node.input_size());
3383       EXPECT_EQ("switch", node.input(0));
3384       ++found;
3385     }
3386     if (node.name() == "ConstantFoldingCtrl/switch_1") {
3387       EXPECT_EQ("Identity", node.op());
3388       EXPECT_EQ(1, node.input_size());
3389       EXPECT_EQ("switch:1", node.input(0));
3390       ++found;
3391     }
3392   }
3393   EXPECT_EQ(6, found);
3394 
3395   // Evaluate id_true when input tensor x is true.
3396   Tensor x_t(DT_BOOL, TensorShape({}));
3397   x_t.flat<bool>()(0) = true;
3398   auto tensors_expected = EvaluateNodes(item.graph, {"id_true"}, {{"x", x_t}});
3399   EXPECT_EQ(1, tensors_expected.size());
3400   auto tensors = EvaluateNodes(output, {"id_true"}, {{"x", x_t}});
3401   EXPECT_EQ(1, tensors.size());
3402   test::ExpectTensorEqual<bool>(tensors_expected[0], tensors[0]);
3403 
3404   // Evaluate id_false when input tensor is false.
3405   x_t.flat<bool>()(0) = false;
3406   tensors_expected = EvaluateNodes(item.graph, {"id_false"}, {{"x", x_t}});
3407   EXPECT_EQ(1, tensors_expected.size());
3408   tensors = EvaluateNodes(output, {"id_false"}, {{"x", x_t}});
3409   EXPECT_EQ(1, tensors.size());
3410   test::ExpectTensorEqual<bool>(tensors_expected[0], tensors[0]);
3411 }
3412 
3413 TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) {
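  // AddN and AccumulateNV2 are associative and commutative, so any two or
  // more constant inputs can be folded into a single generated Const even
  // when the remaining inputs are unknown; with fewer than two constants
  // there is nothing to fold.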
3414   std::function<Output(const Scope&, InputList)> addn_fun =
3415       [](const Scope& scope, InputList inputs) {
3416         return ops::AddN(scope, inputs);
3417       };
3418   std::function<Output(const Scope&, InputList)> accumulate_fun =
3419       [](const Scope& scope, InputList inputs) {
3420         return ops::AccumulateNV2(scope, inputs, TensorShape({2, 2}));
3421       };
3422   for (bool use_add_n : {true, false}) {
3423     auto fun = use_add_n ? addn_fun : accumulate_fun;
3424     const string op_name = use_add_n ? "AddN" : "AccumulateNV2";
3425     Scope s = Scope::NewRootScope();
3426     Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
3427                                 ops::Placeholder::Shape(TensorShape({2, 2})));
3428     Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
3429                                 ops::Placeholder::Shape(TensorShape({2, 2})));
3430     Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
3431                                 ops::Placeholder::Shape(TensorShape({2, 2})));
3432     Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
3433     Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
3434     Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2, 2});
3435     Output acc0 = fun(s.WithOpName("acc0"), {c1, c2, c3});
3436     Output acc1 = fun(s.WithOpName("acc1"), {x, y, z});
3437     Output acc2 = fun(s.WithOpName("acc2"), {c1, x, y});
3438     Output acc3 = fun(s.WithOpName("acc3"), {c1, c2, z});
3439     Output acc4 = fun(s.WithOpName("acc4"), {c1, y, c2});
3440     Output acc5 = fun(s.WithOpName("acc5"), {x, c1, c2});
3441     Output acc6 = fun(s.WithOpName("acc6"), {x, c1, y, c2});
3442     Output stack = ops::Stack(s.WithOpName("stack"),
3443                               {acc0, acc1, acc2, acc3, acc4, acc5, acc6});
3444 
3445     GrapplerItem item;
3446     TF_CHECK_OK(s.ToGraphDef(&item.graph));
3447     item.fetch = {"stack"};
3448 
3449     ConstantFolding optimizer(/*cpu_device=*/nullptr);
3450     GraphDef output;
3451     Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3452     TF_EXPECT_OK(status);
3453 
3454     EXPECT_EQ(16, output.node_size());
3455     for (const NodeDef& node : output.node()) {
3456       if (node.name() == "acc0") {
3457         EXPECT_EQ("Const", node.op());
3458       }
3459       if (node.name() == "acc1") {
3460         EXPECT_EQ(op_name, node.op());
3461         EXPECT_EQ(3, node.input_size());
3462         EXPECT_EQ("x", node.input(0));
3463         EXPECT_EQ("y", node.input(1));
3464         EXPECT_EQ("z", node.input(2));
3465       }
3466       if (node.name() == "acc2") {
3467         EXPECT_EQ(op_name, node.op());
3468         EXPECT_EQ(3, node.input_size());
3469         EXPECT_EQ("c1", node.input(0));
3470         EXPECT_EQ("x", node.input(1));
3471         EXPECT_EQ("y", node.input(2));
3472       }
3473       if (node.name() == "acc3") {
3474         EXPECT_EQ(op_name, node.op());
3475         EXPECT_EQ(2, node.input_size());
3476         EXPECT_EQ("ConstantFolding/acc3_partial_split_2", node.input(0));
3477         EXPECT_EQ("z", node.input(1));
3478       }
3479       if (node.name() == "acc4") {
3480         EXPECT_EQ(op_name, node.op());
3481         EXPECT_EQ(2, node.input_size());
3482         EXPECT_EQ("ConstantFolding/acc4_partial_split_2", node.input(0));
3483         EXPECT_EQ("y", node.input(1));
3484       }
3485       if (node.name() == "acc5") {
3486         EXPECT_EQ(op_name, node.op());
3487         EXPECT_EQ(2, node.input_size());
3488         EXPECT_EQ("x", node.input(0));
3489         EXPECT_EQ("ConstantFolding/acc5_partial_split_2", node.input(1));
3490       }
3491       if (node.name() == "acc6") {
3492         EXPECT_EQ(op_name, node.op());
3493         EXPECT_EQ(3, node.input_size());
3494         EXPECT_EQ("x", node.input(0));
3495         EXPECT_EQ("ConstantFolding/acc6_partial_split_2", node.input(1));
3496         EXPECT_EQ("y", node.input(2));
3497       }
3498       if (absl::StartsWith(node.name(), "ConstantFolding/")) {
3499         EXPECT_EQ("Const", node.op());
3500       }
3501     }
3502 
3503     std::vector<string> fetch = {"acc0"};
3504     auto tensors_expected = EvaluateNodes(item.graph, fetch);
3505     auto tensors = EvaluateNodes(output, fetch);
3506     EXPECT_EQ(1, tensors_expected.size());
3507     EXPECT_EQ(1, tensors.size());
3508     test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
3509   }
3510 }
3511 
3512 TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
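  // Concat is order-sensitive, so only a contiguous run of at least two
  // constant inputs can be folded into a generated Const; nodes with isolated
  // or non-adjacent constants (concat2, concat4, concat6) stay unchanged.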
3513   Scope s = Scope::NewRootScope();
3514   Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
3515                               ops::Placeholder::Shape(TensorShape({2, 2})));
3516   Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
3517                               ops::Placeholder::Shape(TensorShape({2, 2})));
3518   Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
3519                               ops::Placeholder::Shape(TensorShape({2, 2})));
3520   Output axis = ops::Const(s.WithOpName("axis"), 0, {});
3521   Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
3522   Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
3523   Output concat0 = ops::Concat(s.WithOpName("concat0"), {c1, c2, c1}, axis);
3524   Output concat1 = ops::Concat(s.WithOpName("concat1"), {x, y, z}, axis);
3525   Output concat2 = ops::Concat(s.WithOpName("concat2"), {c1, x, y}, axis);
3526   Output concat3 = ops::Concat(s.WithOpName("concat3"), {c1, c2, z}, axis);
3527   Output concat4 = ops::Concat(s.WithOpName("concat4"), {c1, y, c2}, axis);
3528   Output concat5 = ops::Concat(s.WithOpName("concat5"), {x, c1, c2}, axis);
3529   Output concat6 = ops::Concat(s.WithOpName("concat6"), {x, c1, y, c2}, axis);
3530   Output concat7 = ops::Concat(s.WithOpName("concat7"), {x, y, c1, c2}, axis);
3531   Output concat8 = ops::Concat(s.WithOpName("concat8"), {x, c1, c2, y}, axis);
3532   Output concat9 = ops::Concat(s.WithOpName("concat9"), {c1, c2, x, y}, axis);
3533 
3534   GrapplerItem item;
3535   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3536   item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4",
3537                 "concat5", "concat6", "concat7", "concat8", "concat9"};
3538 
3539   auto tensors_expected = EvaluateNodes(item.graph, {"concat0"});
3540   EXPECT_EQ(1, tensors_expected.size());
3541   ConstantFolding optimizer(/*cpu_device=*/nullptr);
3542   GraphDef output;
3543   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3544   TF_EXPECT_OK(status);
3545   // Run the optimizer twice to make sure the rewrite is idempotent.
3546   item.graph.Swap(&output);
3547   status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3548   TF_EXPECT_OK(status);
3549 
3550   EXPECT_EQ(21, output.node_size());
3551   for (int i = 0; i < output.node_size(); ++i) {
3552     const NodeDef& node = output.node(i);
3553     if (node.name() == "concat0") {
3554       EXPECT_EQ("Const", node.op());
3555     } else if (node.name() == "concat3") {
3556       EXPECT_EQ(3, node.input_size());
3557       EXPECT_EQ("ConstantFolding/concat3_partial_split_0", node.input(0));
3558       EXPECT_EQ("z", node.input(1));
3559       EXPECT_EQ("axis", node.input(2));
3560     } else if (node.name() == "concat5") {
3561       EXPECT_EQ(3, node.input_size());
3562       EXPECT_EQ("x", node.input(0));
3563       EXPECT_EQ("ConstantFolding/concat5_partial_split_1", node.input(1));
3564       EXPECT_EQ("axis", node.input(2));
3565     } else if (node.name() == "concat7") {
3566       EXPECT_EQ(4, node.input_size());
3567       EXPECT_EQ("x", node.input(0));
3568       EXPECT_EQ("y", node.input(1));
3569       EXPECT_EQ("ConstantFolding/concat7_partial_split_2", node.input(2));
3570       EXPECT_EQ("axis", node.input(3));
3571     } else if (node.name() == "concat8") {
3572       EXPECT_EQ(4, node.input_size());
3573       EXPECT_EQ("x", node.input(0));
3574       EXPECT_EQ("ConstantFolding/concat8_partial_split_1", node.input(1));
3575       EXPECT_EQ("y", node.input(2));
3576       EXPECT_EQ("axis", node.input(3));
3577     } else if (node.name() == "concat9") {
3578       EXPECT_EQ(4, node.input_size());
3579       EXPECT_EQ("ConstantFolding/concat9_partial_split_0", node.input(0));
3580       EXPECT_EQ("x", node.input(1));
3581       EXPECT_EQ("y", node.input(2));
3582       EXPECT_EQ("axis", node.input(3));
3583     } else if (absl::StartsWith(node.name(), "ConstantFolding/")) {
3584       EXPECT_EQ("Const", node.op());
3585     } else {
3586       EXPECT_EQ(item.graph.node(i).DebugString(), node.DebugString());
3587     }
3588   }
3589 
3590   auto tensors = EvaluateNodes(output, {"concat0"});
3591   EXPECT_EQ(1, tensors.size());
3592   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
3593 }
3594 
3595 TEST_F(ConstantFoldingTest, PartialFolding_IdentityN) {
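  // Consumers of IdentityN outputs that forward constants can be folded, but
  // the IdentityN node itself must be preserved for its control semantics.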
3596   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3597   Output x = ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
3598                               ops::Placeholder::Shape(TensorShape({})));
3599   Output c1 = ops::Const(scope.WithOpName("c1"), 1.0f, {2, 2});
3600   Output c2 = ops::Const(scope.WithOpName("c2"), 2.0f, {2, 2});
3601   auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {c1, x, c2});
3602   auto id0 = ops::Identity(scope.WithOpName("id0"), id_n[0]);
3603   auto id1 = ops::Identity(scope.WithOpName("id1"), id_n[1]);
3604   auto add0 = ops::Add(scope.WithOpName("add0"), id_n[0], id_n[1]);
3605   auto add1 = ops::Add(scope.WithOpName("add1"), id_n[0], id_n[2]);
3606 
3607   GrapplerItem item;
3608   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
3609   item.fetch.push_back("id0");
3610   item.fetch.push_back("id1");
3611   item.fetch.push_back("add0");
3612   item.fetch.push_back("add1");
3613 
3614   ConstantFolding optimizer(/*cpu_device=*/nullptr);
3615   GraphDef output;
3616   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3617   TF_EXPECT_OK(status);
3618   EXPECT_EQ(8, output.node_size());
3619   for (const auto& node : output.node()) {
3620     // id_n should remain unchanged.
3621     if (node.name() == "id_n") {
3622       EXPECT_EQ(3, node.input_size());
3623       EXPECT_EQ("c1", node.input(0));
3624       EXPECT_EQ("x", node.input(1));
3625       EXPECT_EQ("c2", node.input(2));
3626     }
3627     // id0 should be constant folded and have a control dependency from id_n.
3628     if (node.name() == "id0") {
3629       EXPECT_EQ("Const", node.op());
3630       EXPECT_EQ(1, node.input_size());
3631       EXPECT_EQ("^id_n", node.input(0));
3632     }
3633     // id1 is unchanged.
3634     if ("id1" == node.name()) {
3635       EXPECT_EQ(1, node.input_size());
3636       EXPECT_EQ("id_n:1", node.input(0));
3637     }
3638 
3639     if ("add0" == node.name()) {
3640       EXPECT_EQ(2, node.input_size());
3641       EXPECT_EQ("c1", node.input(0));
3642       EXPECT_EQ("id_n:1", node.input(1));
3643     }
3644     // add1 should be constant folded and have a control dependency from id_n.
3645     if ("add1" == node.name()) {
3646       EXPECT_EQ("Const", node.op());
3647       EXPECT_EQ(1, node.input_size());
3648       EXPECT_EQ("^id_n", node.input(0));
3649     }
3650   }
3651 
3652   auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
3653   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
3654   EXPECT_EQ(4, tensors_expected.size());
3655   auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
3656   EXPECT_EQ(4, tensors.size());
3657   for (int i = 0; i < tensors.size(); i++) {
3658     test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
3659   }
3660 }
3661 
3662 TEST_F(ConstantFoldingTest, TrivialPack) {
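  // Packing a single tensor is equivalent to ExpandDims along the pack axis;
  // the axis is materialized as a generated Const node.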
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x =
      ops::RandomNormal(scope.WithOpName("x"), {2, 2}, DataType::DT_FLOAT);
  Output y = ops::Const(scope.WithOpName("y"), {2.0f}, {});
  auto stack =
      ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x},
                 ops::Stack::Axis(1));
  auto stack_no_axis = ops::Stack(scope.WithOpName("stack_no_axis"), {x});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"stack", "stack_no_axis"};

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_EQ(7, output.node_size());
  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "stack") {
      EXPECT_EQ("ExpandDims", node.op());
      EXPECT_EQ(3, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1));
      EXPECT_EQ("^y", node.input(2));
      ++found;
    } else if (node.name() == "stack_no_axis") {
      EXPECT_EQ("ExpandDims", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("ConstantFolding/stack_no_axis_const_axis", node.input(1));
      ++found;
    } else if (node.name() == "ConstantFolding/stack_const_axis") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^x", node.input(0));
      ++found;
    }
  }
  EXPECT_EQ(found, 3);

  std::vector<string> fetch = {"stack", "stack_no_axis"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(2, tensors_expected.size());
  EXPECT_EQ(2, tensors.size());
  EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
  EXPECT_EQ(tensors_expected[1].shape(), tensors[1].shape());
}

// This test does not evaluate the optimized and original graphs to check
// whether their outputs are the same. See b/78233179.
TEST_F(ConstantFoldingTest, Enter) {
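  // Identity nodes reading from an Enter with is_constant=true and a constant
  // input (enter2) can be folded; those reading from a non-constant input
  // (enter1) or from an Enter with is_constant=false (enter3) must be left
  // unchanged.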
  GrapplerItem item;
  AttrValue frame_name;
  frame_name.set_s("foo");
  AttrValue is_constant_true;
  is_constant_true.set_b(true);
  AttrValue is_constant_false;
  is_constant_false.set_b(false);
  AttrValue type;
  type.set_type(DT_FLOAT);
  AttrValue value;
  Tensor value_tensor(DT_FLOAT, TensorShape({}));
  value_tensor.flat<float>()(0) = 1;
  value_tensor.AsProtoTensorContent(value.mutable_tensor());

  GraphDef& graph = item.graph;
  AddNode("x", "Placeholder", {}, {{"dtype", type}}, &graph);
  AddNode("c1", "Const", {"^x"}, {{"value", value}, {"dtype", type}}, &graph);
  AddNode("enter1", "Enter", {"x"},
          {{"T", type},
           {"frame_name", frame_name},
           {"is_constant", is_constant_true}},
          &graph);
  AddNode("enter2", "Enter", {"c1"},
          {{"T", type},
           {"frame_name", frame_name},
           {"is_constant", is_constant_true}},
          &graph);
  AddNode("enter3", "Enter", {"c1"},
          {{"T", type},
           {"frame_name", frame_name},
           {"is_constant", is_constant_false}},
          &graph);
  AddNode("id1", "Identity", {"enter1"}, {{"T", type}}, &graph);
  AddNode("id2", "Identity", {"enter2"}, {{"T", type}}, &graph);
  AddNode("id3", "Identity", {"enter2"}, {{"T", type}}, &graph);
  AddNode("id4", "Identity", {"enter3"}, {{"T", type}}, &graph);
  item.fetch.push_back("id1");
  item.fetch.push_back("id2");
  item.fetch.push_back("id3");
  item.fetch.push_back("id4");

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(9, output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "id1") {
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("enter1", node.input(0));
    }
    if (node.name() == "id2" || node.name() == "id3") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^enter2", node.input(0));
    }
    if (node.name() == "id4") {
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("enter3", node.input(0));
    }
  }
}

TEST_F(ConstantFoldingTest, TensorArraySize) {
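  // TensorArraySize of a statically sized TensorArray (dynamic_size=false)
  // folds to a Const; a dynamically sized array, or a handle coming from a
  // placeholder, cannot be folded.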
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output size = ops::Const(scope.WithOpName("size"), 5, TensorShape({}));
  Output placeholder =
      ops::Placeholder(scope.WithOpName("placeholder"), DT_RESOURCE,
                       ops::Placeholder::Shape(TensorShape({2})));
  Output foo = ops::Const(scope.WithOpName("foo"), 5.0f, TensorShape({}));
  auto dynamic_array =
      ops::TensorArray(scope.WithOpName("dynamic"), size, DT_FLOAT,
                       ops::TensorArray::DynamicSize(true));
  auto static_array =
      ops::TensorArray(scope.WithOpName("static"), size, DT_FLOAT,
                       ops::TensorArray::DynamicSize(false));
  auto dynamic_sz = ops::TensorArraySize(
      scope.WithOpName("dynamic_sz"), dynamic_array.handle, dynamic_array.flow);
  auto static_sz = ops::TensorArraySize(scope.WithOpName("static_sz"),
                                        static_array.handle, static_array.flow);
  auto placeholder_sz = ops::TensorArraySize(scope.WithOpName("placeholder_sz"),
                                             placeholder, foo);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected =
      EvaluateNodes(item.graph, {"dynamic_sz", "static_sz"});

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(8, output.node_size());
  EXPECT_EQ("dynamic_sz", output.node(5).name());
  EXPECT_EQ("TensorArraySizeV3", output.node(5).op());
  EXPECT_EQ("static_sz", output.node(6).name());
  EXPECT_EQ("Const", output.node(6).op());
  EXPECT_EQ("placeholder_sz", output.node(7).name());
  EXPECT_EQ("TensorArraySizeV3", output.node(7).op());

  auto tensors_actual = EvaluateNodes(output, {"dynamic_sz", "static_sz"});
  EXPECT_EQ(2, tensors_expected.size());
  EXPECT_EQ(2, tensors_actual.size());
  test::ExpectTensorEqual<int32>(tensors_expected[0], tensors_actual[0]);
  test::ExpectTensorEqual<int32>(tensors_expected[1], tensors_actual[1]);
}

TEST_F(ConstantFoldingTest, FoldingPreservesDenormalFlushing) {
  // Multiplying min() with 0.1 gives a denormal without FTZ and zero with FTZ.
  // Make sure constant folding behaves the same way as TensorFlow.
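  // (std::numeric_limits<float>::min() is the smallest normal float,
  // ~1.18e-38, so multiplying it by 0.1 yields ~1.18e-39, which is only
  // representable as a denormal.)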
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a =
      ops::Const(s.WithOpName("a"), std::numeric_limits<float>::min(), {1});
  Output b = ops::Const(s.WithOpName("b"), 0.1f, {1});
  Output c = ops::Mul(s.WithOpName("c"), a, b);

  GrapplerItem item;
  item.fetch.push_back("c");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(1, output.node_size());

  const NodeDef& node_d = output.node(0);
  EXPECT_EQ("c", node_d.name());
  EXPECT_EQ("Const", node_d.op());

  std::vector<string> fetch = {"c"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(1, tensors_expected.size());
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(ConstantFoldingTest, EvaluatingLargeConstantNoFoldingMergingLoop) {
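  // size is chosen so that each constant below is 5 MB of floats
  // (10 MB / 4 bytes / 2). The test only checks that optimization completes
  // and that the output shape is unchanged; the large output values
  // themselves are not compared.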
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  int size = 10 * 1024 * 1024 / 4 / 2;
  Output nonconst =
      ops::RandomUniform(s.WithOpName("nonconst"), {size, 1}, DT_FLOAT);
  Output const1 = ops::Const(s.WithOpName("const1"), 0.0f, {size, 1});
  Output const2 = ops::Const(s.WithOpName("const2"), 1.0f, {size, 1});
  Output axis = ops::Const(s.WithOpName("axis"), -1, {});
  Output concat1 =
      ops::Concat(s.WithOpName("concat1"), {nonconst, const1}, axis);
  Output result = ops::Concat(s.WithOpName("result"), {concat1, const2}, axis);

  GrapplerItem item;
  item.fetch.push_back("result");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::vector<string> fetch = {"result"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(1, tensors_expected.size());
  EXPECT_EQ(1, tensors.size());
  EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
}

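// Fixture for checking that a Const feeding a Cast (each with an Identity
// consumer) is folded correctly for every combination of fetched nodes.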
class ConstantFoldingCastConstTest : public GrapplerTest {
 protected:
  void ConstantFoldingCastConst(bool fetch_const, bool fetch_cast,
                                bool fetch_const_child, bool fetch_cast_child) {
    if (!fetch_const && !fetch_cast && !fetch_const_child &&
        !fetch_cast_child) {
      return;
    }

    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    CreateCastConstGraph(s);
    GrapplerItem item;
    int expected_output_size = SetFetch(&item, fetch_const, fetch_cast,
                                        fetch_const_child, fetch_cast_child);
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    GraphDef output = ConstantFoldingOptimize(item);
    EXPECT_EQ(expected_output_size, output.node_size());

    EvaluateAndCompareUnoptimized(item.graph, output, item.fetch);
  }

 private:
  void CreateCastConstGraph(const tensorflow::Scope& s) {
    Output const1 = ops::Const(s.WithOpName("const1"), 2, {5, 5});
    Output cast = ops::Cast(s.WithOpName("cast"), const1, DT_FLOAT);
    Output const1_child = ops::Identity(s.WithOpName("const1_child"), const1);
    Output cast_child = ops::Identity(s.WithOpName("cast_child"), cast);
  }

  int SetFetch(GrapplerItem* item, bool fetch_const, bool fetch_cast,
               bool fetch_const_child, bool fetch_cast_child) {
    int expected_output_size = 0;
    if (fetch_const) {
      item->fetch.push_back("const1");
      expected_output_size++;
    }
    if (fetch_cast) {
      item->fetch.push_back("cast");
      expected_output_size++;
    }
    if (fetch_const_child) {
      item->fetch.push_back("const1_child");
      expected_output_size++;
    }
    if (fetch_cast_child) {
      item->fetch.push_back("cast_child");
      expected_output_size++;
    }
    return expected_output_size;
  }

  GraphDef ConstantFoldingOptimize(const GrapplerItem& item) {
    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);
    return output;
  }

  void EvaluateAndCompareUnoptimized(const GraphDef& unoptimized_graph,
                                     const GraphDef& optimized_graph,
                                     const std::vector<string>& fetch_nodes) {
    auto tensors_expected = EvaluateNodes(unoptimized_graph, fetch_nodes);
    auto tensors = EvaluateNodes(optimized_graph, fetch_nodes);
    ASSERT_EQ(fetch_nodes.size(), tensors_expected.size());
    ASSERT_EQ(fetch_nodes.size(), tensors.size());
    for (int i = 0; i < fetch_nodes.size(); i++) {
      if (fetch_nodes[i] == "const1" || fetch_nodes[i] == "const1_child") {
        test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
      } else {
        test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
      }
    }
  }
};

TEST_F(ConstantFoldingCastConstTest, CastConstFolding) {
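  // Exhaustively exercise all 16 fetch combinations (the helper returns early
  // for the empty one).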
  for (bool fetch_const : {false, true}) {
    for (bool fetch_cast : {false, true}) {
      for (bool fetch_const_child : {false, true}) {
        for (bool fetch_cast_child : {false, true}) {
          ConstantFoldingCastConst(fetch_const, fetch_cast, fetch_const_child,
                                   fetch_cast_child);
        }
      }
    }
  }
}

TEST_F(ConstantFoldingTest, MaterializeConstantValuedNode) {
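  // OnesLike/ZerosLike of an input with a known shape, and a Fill with
  // constant arguments, are all materialized as Const nodes, with their former
  // data inputs demoted to control dependencies.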
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output x =
      ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({1, 2, 3, 4})));
  Output ones_like = ops::OnesLike(scope.WithOpName("ones_like"), x);
  Output zeros_like = ops::ZerosLike(scope.WithOpName("zeros_like"), x);
  Output fill = ops::Fill(scope.WithOpName("fill"), {4, 3, 2, 1}, 42);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"ones_like", "zeros_like", "fill"};
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 3, 4}));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(output.node_size(), 6);
  for (const auto& node : output.node()) {
    if (node.name() != "x") {
      EXPECT_EQ(node.op(), "Const");
    }
    if (node.name() == "ones_like" || node.name() == "zeros_like") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "^x");
    }
    if (node.name() == "fill") {
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0)[0], '^');
      EXPECT_EQ(node.input(1)[0], '^');
    }
  }
  auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
  ASSERT_EQ(item.fetch.size(), tensors.size());
  ASSERT_EQ(tensors_expected.size(), tensors.size());
  for (int i = 0; i < tensors.size(); i++) {
    if (item.fetch[i] == "fill") {
      test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
    } else {
      test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
    }
  }
}

TEST_F(ConstantFoldingTest, MaterializeConstantValuedNodeDisableCompression) {
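  // Same graph as above, but with compressed-tensor optimization disabled the
  // constant-valued nodes must be left as OnesLike/ZerosLike/Fill.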
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output x =
      ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({1, 2, 3, 4})));
  Output ones_like = ops::OnesLike(scope.WithOpName("ones_like"), x);
  Output zeros_like = ops::ZerosLike(scope.WithOpName("zeros_like"), x);
  Output fill = ops::Fill(scope.WithOpName("fill"), {4, 3, 2, 1}, 42);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"ones_like", "zeros_like", "fill"};
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 3, 4}));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});

  ConstantFolding optimizer(/*cpu_device=*/nullptr,
                            /*disable_compressed_tensor_optimization=*/true);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(output.node_size(), 6);
  for (const auto& node : output.node()) {
    if (node.name() == "ones_like") {
      EXPECT_EQ(node.op(), "OnesLike");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
    }
    if (node.name() == "zeros_like") {
      EXPECT_EQ(node.op(), "ZerosLike");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
    }
    if (node.name() == "fill") {
      EXPECT_EQ(node.op(), "Fill");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "Const/Const");
      EXPECT_EQ(node.input(1), "Const_1/Const");
    }
  }
  auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
  ASSERT_EQ(item.fetch.size(), tensors.size());
  ASSERT_EQ(tensors_expected.size(), tensors.size());
  for (int i = 0; i < tensors.size(); i++) {
    if (item.fetch[i] == "fill") {
      test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
    } else {
      test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
    }
  }
}

TEST_F(ConstantFoldingTest, MaterializeConstantValuedNodeHugeFill) {
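  // A Fill with a huge (1024^5-element) output can still be folded to a Const
  // because the result can be represented as a single repeated value rather
  // than an expanded tensor.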
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output value = ops::Const(scope.WithOpName("value"), 42, {});
  Output shape_const = ops::Const(scope.WithOpName("shape"),
                                  {1024, 1024, 1024, 1024, 1024}, {5});
  Output fill_huge =
      ops::Fill(scope.WithOpName("fill_huge"), shape_const, value);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  // Manually convert the input value format to tensor_content to test this
  // case.
  NodeDef* node = item.graph.mutable_node(0);
  ASSERT_EQ(node->name(), "value");
  TensorProto* t = (*node->mutable_attr())["value"].mutable_tensor();
  t->clear_int_val();
  int val = 42;
  port::CopyFromArray(t->mutable_tensor_content(),
                      reinterpret_cast<const char*>(&val), sizeof(int));
  item.fetch = {"fill_huge"};
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(output.node_size(), 3);
  for (const auto& node : output.node()) {
    EXPECT_EQ(node.op(), "Const");
    if (node.name() == "fill_huge") {
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "^shape");
      EXPECT_EQ(node.input(1), "^value");
    }
  }
}

TEST_F(ConstantFoldingTest, BitcastDenormalFloats) {
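  // Bitcasting these int64 values to float produces denormal (and other
  // unusual) bit patterns; folding the Bitcast pair must round-trip the bits
  // exactly instead of flushing denormals to zero.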
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Tensor x_t(DT_INT64, TensorShape({2, 2}));
  x_t.flat<int64_t>()(0) = 9223372036854775807L;
  x_t.flat<int64_t>()(1) = 1L;
  x_t.flat<int64_t>()(2) = 9223372036854775807L;
  x_t.flat<int64_t>()(3) = 1L;
  Output x = ops::Const(scope.WithOpName("x"), x_t);
  Output y = ops::Bitcast(scope.WithOpName("y"), x, DT_FLOAT);
  Output z = ops::Bitcast(scope.WithOpName("z"), y, DT_INT64);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"z"};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  ASSERT_EQ(output.node_size(), 1);
  const NodeDef& node = output.node(0);
  EXPECT_EQ(node.name(), "z");
  EXPECT_EQ(node.op(), "Const");

  auto tensors = EvaluateNodes(output, item.fetch, {});
  ASSERT_EQ(tensors.size(), 1);
  ASSERT_EQ(tensors_expected.size(), 1);
  test::ExpectTensorEqual<int64_t>(tensors[0], tensors_expected[0]);
}

TEST_F(ConstantFoldingTest, SimplifyCase) {
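  // A Case op whose branch index is a constant is rewritten into a direct
  // PartitionedCall of the selected branch function.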
  using test::function::NDef;

  for (int index = 0; index < 2; ++index) {
    // Build a graph to compute y = Case(index, x, XTimesTwo(x), NonZero(x))
    GrapplerItem item;
    constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
    AttrValue branches;
    auto* f = branches.mutable_list()->add_func();
    f->set_name("XTimesTwo");
    (*f->mutable_attr())["T"].set_type(DT_FLOAT);
    auto* g = branches.mutable_list()->add_func();
    *g = *f;
    g->set_name("NonZero");

    // Add a pair of somewhat arbitrary output shapes to test that they are
    // correctly propagated to the _output_shapes attribute.
    AttrValue output_shapes;
    // The first shape is a scalar.
    output_shapes.mutable_list()->add_shape();
    // The second shape is unknown.
    TensorShapeProto* g_shape = output_shapes.mutable_list()->add_shape();
    g_shape->set_unknown_rank(true);

    const Tensor kZero = test::AsScalar<int32>(0);
    const Tensor kOne = test::AsScalar<int32>(1);
    item.graph = test::function::GDef(
        {NDef("one", "Const", {},
              {{"value", index == 0 ? kZero : kOne}, {"dtype", DT_INT32}},
              kDevice),
         NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
         NDef("case", "Case", {"one", "x"},
              {{"Tin", DataTypeSlice{DT_FLOAT}},
               {"Tout", DataTypeSlice{DT_FLOAT}},
               {"branches", branches},
               {"output_shapes", output_shapes}},
              kDevice),
         NDef("y", "Identity", {"case"}, {{"T", DT_FLOAT}}, kDevice)},
        // FunctionLib
        {
            test::function::XTimesTwo(),
            test::function::NonZero(),
        });
    VLOG(1) << "Before: " << item.graph.DebugString();

    item.fetch = {"y"};
    const Tensor kTwo = test::AsScalar<float>(2.0f);
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}});

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef optimized_graph;
    TF_ASSERT_OK(
        optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
    VLOG(1) << "After: " << optimized_graph.DebugString();

    int pco_count = 0;
    for (const auto& node : optimized_graph.node()) {
      EXPECT_NE(node.op(), "Case");
      if (node.op() == "PartitionedCall") {
        ++pco_count;
        const auto& shape_list = node.attr().at("_output_shapes").list();
        ASSERT_EQ(shape_list.shape_size(), 1);
        EXPECT_EQ(shape_list.shape(0).dim_size(), 0);
        if (index == 0) {
          EXPECT_EQ(node.attr().at("f").func().name(), "XTimesTwo");
          EXPECT_EQ(shape_list.shape(0).unknown_rank(), false);
        } else {
          EXPECT_EQ(node.attr().at("f").func().name(), "NonZero");
          EXPECT_EQ(shape_list.shape(0).unknown_rank(), true);
        }
      }
    }
    EXPECT_EQ(pco_count, 1);

    auto tensors = EvaluateNodes(optimized_graph, item.fetch, {{"x", kTwo}});
    ASSERT_EQ(tensors.size(), tensors_expected.size());
    test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
  }
}

TEST_F(ConstantFoldingTest, SimplifySelect) {
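  // A SelectV2 whose predicate is a constant (all-true or all-false) is
  // rewritten to an Identity of the taken branch, with the remaining inputs
  // demoted to control dependencies.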
  for (bool scalar_pred : {true, false}) {
    for (bool pred_val : {true, false}) {
      tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
      std::unique_ptr<Tensor> if_t;
      if (scalar_pred) {
        if_t.reset(new Tensor(DT_BOOL, TensorShape()));
      } else {
        if_t.reset(new Tensor(DT_BOOL, TensorShape({2, 2})));
      }
      for (int i = 0; i < (scalar_pred ? 1 : 4); ++i) {
        if_t->flat<bool>()(i) = pred_val;
      }
      Output if_ = ops::Const(scope.WithOpName("if"), *if_t);
      Output then_ =
          ops::Placeholder(scope.WithOpName("then"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 2})));
      Output else_ =
          ops::Placeholder(scope.WithOpName("else"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 2})));
      Output select =
          ops::SelectV2(scope.WithOpName("select"), if_, then_, else_);
      Output id = ops::Identity(scope.WithOpName("id"), select);

      GrapplerItem item;
      TF_CHECK_OK(scope.ToGraphDef(&item.graph));
      item.fetch = {"id"};

      const Tensor kOne =
          test::AsTensor<float>({1.0f, 1.0f, 1.0f, 1.0f}, TensorShape({2, 2}));
      const Tensor kTwo =
          test::AsTensor<float>({2.0f, 2.0f, 2.0f, 2.0f}, TensorShape({2, 2}));
      auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
                                            {{"then", kOne}, {"else", kTwo}});

      // Use aggressive mode to force the shape inference to propagate
      // placeholder shapes.
      ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
                                /*cpu_device=*/nullptr);
      GraphDef optimized_graph;
      TF_EXPECT_OK(
          optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));

      ASSERT_EQ(optimized_graph.node_size(), 5);
      bool found = false;
      for (const auto& node : optimized_graph.node()) {
        if (node.name() == "select") {
          found = true;
          EXPECT_EQ(node.op(), "Identity");
          ASSERT_EQ(node.input_size(), 3);
          EXPECT_EQ(node.input(0), pred_val ? "then" : "else");
          EXPECT_EQ(node.input(1), pred_val ? "^if" : "^then");
          EXPECT_EQ(node.input(2), pred_val ? "^else" : "^if");
        }
      }
      EXPECT_TRUE(found);

      auto tensors = EvaluateNodes(optimized_graph, item.fetch,
                                   {{"then", kOne}, {"else", kTwo}});
      ASSERT_EQ(tensors.size(), 1);
      ASSERT_EQ(tensors_expected.size(), 1);
      test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
    }
  }
}

TEST_F(ConstantFoldingTest, SimplifySelect_BroadcastTo) {
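  // Same as above, but when the taken branch is smaller than the select's
  // output shape, the rewrite must insert a BroadcastTo (with a materialized
  // shape constant) instead of a plain Identity.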
  for (TensorShape pred_shape : {TensorShape{2, 1}, TensorShape{2, 2, 1}}) {
    for (bool pred_val : {true, false}) {
      tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
      std::unique_ptr<Tensor> if_t;
      if_t.reset(new Tensor(DT_BOOL, pred_shape));
      for (int i = 0; i < pred_shape.num_elements(); ++i) {
        if_t->flat<bool>()(i) = pred_val;
      }
      Output if_ = ops::Const(scope.WithOpName("if"), *if_t);
      Output then_ =
          ops::Placeholder(scope.WithOpName("then"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 1})));
      Output else_ =
          ops::Placeholder(scope.WithOpName("else"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 4})));
      Output select =
          ops::SelectV2(scope.WithOpName("select"), if_, then_, else_);
      Output id = ops::Identity(scope.WithOpName("id"), select);

      GrapplerItem item;
      TF_CHECK_OK(scope.ToGraphDef(&item.graph));
      item.fetch = {"id"};

      const Tensor kOne =
          test::AsTensor<float>({1.0f, 1.0f}, TensorShape({2, 1}));
      const Tensor kTwo = test::AsTensor<float>(
          {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f},
          TensorShape({2, 4}));
      auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
                                            {{"then", kOne}, {"else", kTwo}});

      // Use aggressive mode to force the shape inference to propagate
      // placeholder shapes.
      ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
                                /*cpu_device=*/nullptr);
      GraphDef optimized_graph;
      TF_EXPECT_OK(
          optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));

      ASSERT_EQ(optimized_graph.node_size(), 6);
      bool found = false;
      for (const auto& node : optimized_graph.node()) {
        if (node.name() == "select") {
          found = true;
          EXPECT_EQ(node.op(), "BroadcastTo");
          ASSERT_EQ(node.input_size(), 4);
          EXPECT_EQ(node.input(0), pred_val ? "then" : "else");
          EXPECT_EQ(node.input(1),
                    strings::StrCat("ConstantFolding/select-broadcastto_shape-",
                                    pred_val ? 1 : 2));
          EXPECT_EQ(node.input(2), pred_val ? "^else" : "^if");
          EXPECT_EQ(node.input(3), pred_val ? "^if" : "^then");
        }
      }
      EXPECT_TRUE(found);

      auto tensors = EvaluateNodes(optimized_graph, item.fetch,
                                   {{"then", kOne}, {"else", kTwo}});
      ASSERT_EQ(tensors.size(), 1);
      ASSERT_EQ(tensors_expected.size(), 1);
      ASSERT_EQ(tensors[0].shape(), pred_shape.num_elements() == 2
                                        ? TensorShape({2, 4})
                                        : TensorShape({2, 2, 4}));
      test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
    }
  }
}

TEST_F(ConstantFoldingTest, QuantizationEmulation) {
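  // QuantizeAndDequantizeV2 nodes must only be folded away when
  // fold_quantization_emulation is enabled.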
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(scope.WithOpName("x"), {0.0f, 1.0f, 2.0f, 3.0f}, {4});
  Output min_range = ops::Const(scope.WithOpName("min_range"), 0.0f, {});
  Output max_range = ops::Const(scope.WithOpName("max_range"), 3.0f, {});
  Output y = ops::QuantizeAndDequantizeV2(scope.WithOpName("y"), x, min_range,
                                          max_range);
  Output id = ops::Identity(scope.WithOpName("id"), y);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"id"};

  std::vector<Tensor> expected_tensors = EvaluateNodes(item.graph, item.fetch);

  for (const bool fold_quantization_emulation : {false, true}) {
    ConstantFolding optimizer(/*cpu_device=*/nullptr,
                              /*disable_compressed_tensor_optimization=*/false,
                              fold_quantization_emulation);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    TF_EXPECT_OK(status);
    int num_quantization_emulation_ops = 0;
    for (const NodeDef& node : output.node()) {
      if (node.op() == "QuantizeAndDequantizeV2") {
        num_quantization_emulation_ops++;
      }
    }
    EXPECT_EQ(fold_quantization_emulation ? 0 : 1,
              num_quantization_emulation_ops);

    std::vector<Tensor> actual_tensors = EvaluateNodes(output, item.fetch);
    for (int i = 0; i < item.fetch.size(); ++i) {
      test::ExpectTensorEqual<float>(expected_tensors[i], actual_tensors[i]);
    }
  }
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow