xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/costs/graph_properties_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/graph_properties.h"
17 
18 #include "tensorflow/cc/framework/scope.h"
19 #include "tensorflow/cc/ops/functional_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/graph_def_util.h"
22 #include "tensorflow/core/framework/node_def_builder.h"
23 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/tensor_testutil.h"
26 #include "tensorflow/core/framework/types.pb.h"
27 #include "tensorflow/core/framework/versions.pb.h"
28 #include "tensorflow/core/grappler/clusters/single_machine.h"
29 #include "tensorflow/core/grappler/grappler_item.h"
30 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
31 #include "tensorflow/core/grappler/inputs/utils.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/lib/io/path.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/platform/test.h"
37 #ifdef INTEL_MKL
38 #include "tensorflow/core/graph/mkl_graph_util.h"
39 #endif
40 
41 namespace tensorflow {
42 namespace grappler {
43 namespace {
44 
// Bring the shape-inference types into this file's scope under short names.
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;

// Directory containing the .pbtxt graphs loaded by the tests in this file. The
// tests join it onto testing::TensorFlowSrcRoot() to build a full file path.
const char kTestDataPath[] = "core/grappler/costs/graph_properties_testdata";
50 
// Registers a test-only op with one float input and one float output. No
// SetShapeFn is attached, so the op has no shape inference function.
REGISTER_OP("TestOpWithNoInferenceFn")
    .Input("x: float")
    .Output("y: float")
    .Doc(R"doc(
Test op with no Inference Function registered.
x: input
y: output
)doc");
59 
60 class GraphPropertiesTest : public ::testing::Test {
61  public:
SetUp()62   void SetUp() override {
63     // Provision a single machine with 3 cpu cores
64     cluster_.reset(new SingleMachine(5 * 60, 3, 0));
65     TF_ASSERT_OK(cluster_->Provision());
66 
67     // This function is simply
68     // out = Fill(shape, value), but
69     // Fill requires values in the shape input, not just shape of it, to infer
70     // output shape.
71     auto f = FunctionDefHelper::Create(
72         // Name
73         "MyFillFunc",
74         // Inputs
75         {"shape: int32", "value: float"},
76         // Outputs
77         {"out: float"},
78         // Attrs
79         {},
80         // Nodes
81         {
82             {{"a"},
83              "Fill",
84              {"shape", "value"},
85              {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
86         },
87         // Returns
88         {{"out", "a:output:0"}});
89     function_lib_.add_function()->Swap(&f);
90   }
91 
TearDown()92   void TearDown() override {
93     TF_ASSERT_OK(cluster_->Shutdown());
94     cluster_.reset();
95   }
96 
97  protected:
98   // Returns a string form of <p>, suitable for comparing type and shape.
99   // Example output for 4-d float tensor: "float: [10,2,30,4]"
PropToString(const OpInfo::TensorProperties & p)100   string PropToString(const OpInfo::TensorProperties& p) {
101     string s = strings::StrCat(DataTypeString(p.dtype()), ": ");
102     if (p.shape().unknown_rank()) {
103       strings::StrAppend(&s, "?");
104     } else {
105       strings::StrAppend(&s, "[");
106       for (int i = 0; i < p.shape().dim_size(); ++i) {
107         strings::StrAppend(&s, i == 0 ? "" : ",",
108                            std::max<int64_t>(p.shape().dim(i).size(), -1));
109       }
110       strings::StrAppend(&s, "]");
111     }
112     return s;
113   }
114 
115   // Compare values of integer (DT_INT32 or DT_INT64) tensor against expected
116   // ones.
ExpectTensorValues(const std::vector<int64_t> & expected,const TensorProto & tensor_proto_to_compare)117   void ExpectTensorValues(const std::vector<int64_t>& expected,
118                           const TensorProto& tensor_proto_to_compare) {
119     Tensor tensor;
120     ASSERT_TRUE(tensor.FromProto(tensor_proto_to_compare));
121     EXPECT_EQ(expected.size(), tensor.NumElements());
122     // We're interested in only integer tensors as only shapes are exported as
123     // graph properties values.
124     ASSERT_TRUE(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64);
125     if (tensor.dtype() == DT_INT32) {
126       for (int i = 0; i < tensor.NumElements(); i++) {
127         EXPECT_EQ(expected[i], tensor.flat<int32>()(i));
128       }
129     } else {
130       for (int i = 0; i < tensor.NumElements(); i++) {
131         EXPECT_EQ(expected[i], tensor.flat<int64_t>()(i));
132       }
133     }
134   }
135 
136   // Compare values of float (DT_FLOAT) tensor against expected
137   // ones.
ExpectFloatTensorValues(const std::vector<float> & expected,const TensorProto & tensor_proto_to_compare)138   void ExpectFloatTensorValues(const std::vector<float>& expected,
139                                const TensorProto& tensor_proto_to_compare) {
140     Tensor tensor;
141     ASSERT_TRUE(tensor.FromProto(tensor_proto_to_compare));
142     EXPECT_EQ(expected.size(), tensor.NumElements());
143     ASSERT_EQ(tensor.dtype(), DT_FLOAT);
144     for (int i = 0; i < tensor.NumElements(); i++) {
145       EXPECT_EQ(expected[i], tensor.flat<float>()(i));
146     }
147   }
148 
149   std::unique_ptr<SingleMachine> cluster_;
150   FunctionDefLibrary function_lib_;
151 };
152 
// Runs static inference on the trivial test graph and checks that the random
// node's output and every AddN input are float tensors of shape [10,1], and
// that each AddN output matches its input.
TEST_F(GraphPropertiesTest, StaticProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(true));

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      // The node has one input (the shape of the tensor to generate).
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      // The random node has one output.
      const auto props = properties.GetOutputProperties(node.name());
      // Size and rank checks are fatal because the elements are indexed
      // right after; a non-fatal check would let the test index out of range.
      ASSERT_EQ(1, props.size());
      const OpInfo::TensorProperties& prop = props[0];
      EXPECT_EQ(DT_FLOAT, prop.dtype());
      EXPECT_FALSE(prop.shape().unknown_rank());
      ASSERT_EQ(2, prop.shape().dim_size());
      EXPECT_EQ(10, prop.shape().dim(0).size());
      EXPECT_EQ(1, prop.shape().dim(1).size());
    } else if (node.op() == "AddN") {
      const auto in_props = properties.GetInputProperties(node.name());
      ASSERT_EQ(1, in_props.size());
      const OpInfo::TensorProperties& in_prop = in_props[0];
      EXPECT_EQ(DT_FLOAT, in_prop.dtype());
      EXPECT_FALSE(in_prop.shape().unknown_rank());
      ASSERT_EQ(2, in_prop.shape().dim_size());
      EXPECT_EQ(10, in_prop.shape().dim(0).size());
      EXPECT_EQ(1, in_prop.shape().dim(1).size());
      const auto out_props = properties.GetOutputProperties(node.name());
      ASSERT_EQ(1, out_props.size());
      // The output must carry the same dtype and shape as the input.
      EXPECT_EQ(in_prop.dtype(), out_props[0].dtype());
      EXPECT_EQ(in_prop.shape().DebugString(),
                out_props[0].shape().DebugString());
    }
  }
}
193 
// Checks that ClearOutputProperties() / ClearInputProperties() drop the
// properties previously inferred for a single node.
TEST_F(GraphPropertiesTest, ClearProperties) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(true));

  for (const auto& node : item.graph.node()) {
    if (node.op() == "RandomStandardNormal") {
      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      // Confirm there is something to clear first; otherwise the emptiness
      // check below would pass even if nothing had been inferred.
      const auto props = properties.GetOutputProperties(node.name());
      EXPECT_EQ(1, props.size());
      properties.ClearOutputProperties(node.name());
      const auto cleared_props = properties.GetOutputProperties(node.name());
      EXPECT_TRUE(cleared_props.empty());
    } else if (node.op() == "AddN") {
      const auto in_props = properties.GetInputProperties(node.name());
      EXPECT_EQ(1, in_props.size());
      properties.ClearInputProperties(node.name());
      const auto cleared_props = properties.GetInputProperties(node.name());
      EXPECT_TRUE(cleared_props.empty());
    }
  }
}
220 
// After static inference the object reports that it has properties; Clear()
// must make it report none.
TEST_F(GraphPropertiesTest, Clear) {
  GrapplerItem graph_item;
  TrivialTestGraphInputYielder yielder(4, 1, 10, false,
                                       cluster_->GetDeviceNames());
  CHECK(yielder.NextItem(&graph_item));

  GraphProperties inferred(graph_item);
  TF_ASSERT_OK(inferred.InferStatically(true));
  EXPECT_TRUE(inferred.has_properties());

  inferred.Clear();
  EXPECT_FALSE(inferred.has_properties());
}
235 
TEST_F(GraphPropertiesTest,DynamicProperties)236 TEST_F(GraphPropertiesTest, DynamicProperties) {
237   TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
238                                           cluster_->GetDeviceNames());
239   GrapplerItem item;
240   CHECK(fake_input.NextItem(&item));
241 
242   GraphProperties properties(item);
243   TF_ASSERT_OK(cluster_->Initialize(item));
244   Status s = properties.InferDynamically(cluster_.get());
245   TF_ASSERT_OK(s);
246 
247   for (const auto& node : item.graph.node()) {
248     if (node.op() == "RandomStandardNormal") {
249       // The random node is missing from the cost graph (why ?)
250       EXPECT_EQ(0, properties.GetInputProperties(node.name()).size());
251     } else if (node.op() == "AddN") {
252       // Since the random node is missing, we can't infer the input properties
253       // of the first AddN node. The other AddN nodes have the expected
254       // properties.
255       if (node.name() == "AddN") {
256         const auto props = properties.GetInputProperties(node.name());
257         EXPECT_EQ(1, props.size());
258         const OpInfo::TensorProperties& prop = props[0];
259         EXPECT_EQ(DT_INVALID, prop.dtype());
260         EXPECT_TRUE(prop.shape().unknown_rank());
261       } else {
262         const auto props = properties.GetInputProperties(node.name());
263         EXPECT_EQ(1, props.size());
264         const OpInfo::TensorProperties& prop = props[0];
265         EXPECT_EQ(DT_FLOAT, prop.dtype());
266         EXPECT_FALSE(prop.shape().unknown_rank());
267         EXPECT_EQ(2, prop.shape().dim_size());
268         EXPECT_EQ(10, prop.shape().dim(0).size());
269         EXPECT_EQ(1, prop.shape().dim(1).size());
270         const auto out_props = properties.GetOutputProperties(node.name());
271 #ifdef INTEL_MKL
272         if (!NativeFormatEnabled()) {
273           // Intel MKL AddN OP would have two output.
274           // One is the real output, another one for MKL metadata
275           EXPECT_EQ(2, out_props.size());
276         } else {
277           EXPECT_EQ(1, out_props.size());
278         }
279 #else
280         EXPECT_EQ(1, out_props.size());
281 #endif  // INTEL_MKL
282         string prop_str;
283         ::tensorflow::protobuf::TextFormat::PrintToString(prop, &prop_str);
284         string out_prop_str;
285         ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0],
286                                                           &out_prop_str);
287         EXPECT_EQ(prop_str, out_prop_str);
288       }
289     }
290   }
291 }
292 
293 // A test op that outputs different shape based on input_tensor in the shape
294 // inference context.
295 REGISTER_OP("DetectInputValueInShapeInferenceOp")
296     .Input("a: T")
297     .Output("o: T")
298     .Attr("T: {numbertype, bool}")
__anon3628d9fb0202(shape_inference::InferenceContext* c) 299     .SetShapeFn([](shape_inference::InferenceContext* c) {
300       if (c->input_tensor(0)) {
301         // 10x10 if input_tensor is given to the inference context.
302         c->set_output(0, c->Matrix(10, 10));
303         return OkStatus();
304       }
305       // unknown rank if input_tensor is not provided.
306       return shape_inference::UnknownShape(c);
307     });
308 
309 // Helper class for testing Const tensor skip.
310 class ConstTensorSkipTestCase {
311  public:
ConstTensorSkipTestCase(const DataType data_type,const std::vector<int64_t> shape,const double value,const bool expected)312   ConstTensorSkipTestCase(const DataType data_type,
313                           const std::vector<int64_t> shape, const double value,
314                           const bool expected)
315       : data_type_(data_type),
316         shape_(shape),
317         value_(value),
318         expected_(expected) {}
319 
RunTestAndValidate() const320   void RunTestAndValidate() const {
321     LOG(INFO) << "Run Const tensor skip test: "
322               << "data_type: " << data_type_ << ", shape: {"
323               << absl::StrJoin(shape_, ",") << "}, value: " << value_
324               << ", expected: " << expected_;
325     // Build a graph with Const --> Identity --> Detect.
326     GrapplerItem item;
327     const gtl::ArraySlice<int64_t> shape_array_slice(shape_);
328     Tensor const_tensor_value(data_type_, TensorShape(shape_array_slice));
329     // Fill the const tensor value based on data type.
330     switch (data_type_) {
331       case DT_INT32:
332         test::FillIota<int32>(&const_tensor_value, static_cast<int32>(value_));
333         break;
334       case DT_INT64:
335         test::FillIota<int64_t>(&const_tensor_value,
336                                 static_cast<int64_t>(value_));
337         break;
338       case DT_FLOAT:
339         test::FillIota<float>(&const_tensor_value, static_cast<float>(value_));
340         break;
341       case DT_DOUBLE:
342         test::FillIota<double>(&const_tensor_value,
343                                static_cast<double>(value_));
344         break;
345       case DT_BFLOAT16:
346         test::FillIota<Eigen::bfloat16>(&const_tensor_value,
347                                         static_cast<Eigen::bfloat16>(value_));
348         break;
349       default:
350         CHECK(false) << "Unsupported data type (" << data_type_
351                      << ") in this test.";
352         break;
353     }
354     TF_ASSERT_OK(NodeDefBuilder("const", "Const")
355                      .Attr("dtype", data_type_)
356                      .Attr("value", const_tensor_value)
357                      .Finalize(item.graph.add_node()));
358     TF_ASSERT_OK(NodeDefBuilder("const_identity", "Identity")
359                      .Attr("dtype", data_type_)
360                      .Input("const", 0, data_type_)
361                      .Finalize(item.graph.add_node()));
362     TF_ASSERT_OK(NodeDefBuilder("detect", "DetectInputValueInShapeInferenceOp")
363                      .Attr("T", data_type_)
364                      .Input("const_identity", 0, data_type_)
365                      .Finalize(item.graph.add_node()));
366     item.fetch.push_back("const");
367     item.fetch.push_back("const_identity");
368     item.fetch.push_back("detect");
369 
370     // Run static shape inference.
371     GraphProperties graph_properties(item);
372     TF_ASSERT_OK(graph_properties.InferStatically(false));
373 
374     // Extract input / output properties of interest.
375     const auto& const_output = graph_properties.GetOutputProperties("const");
376     EXPECT_EQ(1, const_output.size());
377     const OpInfo::TensorProperties& const_output0 = const_output[0];
378     const auto& const_identity_input =
379         graph_properties.GetInputProperties("const_identity");
380     EXPECT_EQ(1, const_identity_input.size());
381     const OpInfo::TensorProperties& const_identity_input0 =
382         const_identity_input[0];
383     const auto& const_identity_output =
384         graph_properties.GetOutputProperties("const_identity");
385     EXPECT_EQ(1, const_identity_output.size());
386     const OpInfo::TensorProperties& const_identity_output0 =
387         const_identity_output[0];
388     EXPECT_TRUE(const_output0.has_value());
389     EXPECT_TRUE(const_identity_input0.has_value());
390     EXPECT_TRUE(const_identity_output0.has_value());
391     const auto& detect_input = graph_properties.GetInputProperties("detect");
392     EXPECT_EQ(1, detect_input.size());
393     const OpInfo::TensorProperties& detect_input0 = detect_input[0];
394     const auto& detect_output = graph_properties.GetOutputProperties("detect");
395     EXPECT_EQ(1, detect_output.size());
396     const OpInfo::TensorProperties& detect_output0 = detect_output[0];
397 
398     // Tensor protos are propagated, regardless of types and sizes.
399     EXPECT_TRUE(const_output0.has_value());
400     EXPECT_TRUE(const_identity_input0.has_value());
401     EXPECT_TRUE(const_identity_output0.has_value());
402     EXPECT_TRUE(detect_input0.has_value());
403 
404     // Detect op outputs 10x10 matrix if it has input_tensor in the shape
405     // inference context. Otherwise, unknown rank.
406     if (expected_) {
407       EXPECT_EQ(detect_output0.shape().dim_size(), 2);
408       EXPECT_EQ(detect_output0.shape().dim(0).size(), 10);
409       EXPECT_EQ(detect_output0.shape().dim(1).size(), 10);
410     } else {
411       EXPECT_TRUE(detect_output0.shape().unknown_rank());
412     }
413   }
414 
415  private:
416   DataType data_type_;
417   std::vector<int64_t> shape_;
418   double value_;
419   bool expected_;
420 };
421 
// Const tensor values are not propagated into shape inference when the const
// tensor is too large. Each row below is one scenario; the last field says
// whether the value is expected to be propagated.
TEST_F(GraphPropertiesTest, SkipInstantiatingConstTensor) {
  const std::vector<ConstTensorSkipTestCase> scenarios = {
      // {data_type, shape, value, propagate const?}
      // int32
      {DT_INT32, {16, 8}, 1, true},    // 128 elements: propagated
      {DT_INT32, {1, 129}, 2, false},  // 129 elements: skipped
      // int64
      {DT_INT64, {8, 8}, 3, true},     // 64 elements: propagated
      {DT_INT64, {128, 2}, 0, false},  // 256 elements: skipped
      // float
      {DT_FLOAT, {16, 8}, 1.0, true},    // whole-number start value
      {DT_FLOAT, {16, 8}, 1.3, true},    // fractional start value (1.3)
      {DT_FLOAT, {1, 129}, 0.7, false},  // fractional start value (0.7)
      // double
      {DT_DOUBLE, {16, 8}, 1.0, true},    // whole-number start value
      {DT_DOUBLE, {16, 8}, 1.3, true},    // fractional start value (1.3)
      {DT_DOUBLE, {1, 129}, 0.7, false},  // fractional start value (0.7)
      // bfloat16
      {DT_BFLOAT16, {16, 8}, 1.0, true},    // whole-number start value
      {DT_BFLOAT16, {16, 8}, 1.3, true},    // fractional start value (1.3)
      {DT_BFLOAT16, {1, 129}, 0.7, false},  // fractional start value (0.7)
  };
  for (const ConstTensorSkipTestCase& scenario : scenarios) {
    scenario.RunTestAndValidate();
  }
}
445 
// A [3,7] float Variable must be reported as a float-ref [3,7] output by both
// static and dynamic inference.
TEST_F(GraphPropertiesTest, Variables) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Var", "Variable")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({3, 7}))
                   .Finalize(item.graph.add_node()));
  item.fetch.push_back("Var");

  // Initialize the variable from a constant through an Assign node that is
  // registered as an init op.
  Tensor initial_val(DT_FLOAT, TensorShape({3, 7}));
  test::FillIota<float>(&initial_val, 0);
  TF_ASSERT_OK(NodeDefBuilder("InitialVal", "Const")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("value", initial_val)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("InitVar", "Assign")
                   .Input("Var", 0, DT_FLOAT_REF)
                   .Input("InitialVal", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  item.init_ops.push_back("InitVar");

  // In both sections the size and rank checks are fatal because the elements
  // are indexed immediately afterwards.
  {
    // Static inference.
    GraphProperties static_properties(item);
    TF_ASSERT_OK(static_properties.InferStatically(false));

    const auto props = static_properties.GetOutputProperties("Var");
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
    EXPECT_FALSE(prop.shape().unknown_rank());
    ASSERT_EQ(2, prop.shape().dim_size());
    EXPECT_EQ(3, prop.shape().dim(0).size());
    EXPECT_EQ(7, prop.shape().dim(1).size());
  }
  {
    // Dynamic inference on the test cluster.
    TF_ASSERT_OK(cluster_->Initialize(item));
    GraphProperties dynamic_properties(item);
    TF_ASSERT_OK(dynamic_properties.InferDynamically(cluster_.get()));

    const auto props = dynamic_properties.GetOutputProperties("Var");
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
    EXPECT_FALSE(prop.shape().unknown_rank());
    ASSERT_EQ(2, prop.shape().dim_size());
    EXPECT_EQ(3, prop.shape().dim(0).size());
    EXPECT_EQ(7, prop.shape().dim(1).size());
  }
}
494 
// A ReadVariableOp that reads a VarHandleOp through an Enter node must still
// be inferred as a float [3,7] tensor.
TEST_F(GraphPropertiesTest, ReadVariableOpAfterEnter) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Var", "VarHandleOp")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({3, 7}))
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("Enter", "Enter")
                   .Attr("T", DT_RESOURCE)
                   .Attr("frame_name", "while_context")
                   .Attr("is_constant", true)
                   .Attr("parallel_iterations", 10)
                   .Input("Var", 0, DT_RESOURCE)
                   .Finalize(item.graph.add_node()));
  TF_ASSERT_OK(NodeDefBuilder("ReadVariableOpAfterEnter", "ReadVariableOp")
                   .Attr("dtype", DT_FLOAT)
                   .Input("Enter", 0, DT_RESOURCE)
                   .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto props = properties.GetOutputProperties("ReadVariableOpAfterEnter");
  // Size and rank checks are fatal because the elements are indexed right
  // after.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_FALSE(prop.shape().unknown_rank());
  ASSERT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ(3, prop.shape().dim(0).size());
  EXPECT_EQ(7, prop.shape().dim(1).size());
}
524 
// A ReadVariableOp fed directly by a VarHandleOp must be inferred as a float
// [3,7] tensor, i.e. the dtype and shape declared on the handle.
TEST_F(GraphPropertiesTest, VarHandles) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Var", "VarHandleOp")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({3, 7}))
                   .Finalize(item.graph.add_node()));

  TF_ASSERT_OK(NodeDefBuilder("VarRead", "ReadVariableOp")
                   .Attr("dtype", DT_FLOAT)
                   .Input("Var", 0, DT_RESOURCE)
                   .Finalize(item.graph.add_node()));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto props = properties.GetOutputProperties("VarRead");
  // Size and rank checks are fatal because the elements are indexed right
  // after.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_FALSE(prop.shape().unknown_rank());
  ASSERT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ(3, prop.shape().dim(0).size());
  EXPECT_EQ(7, prop.shape().dim(1).size());
}
549 
// Checks that a resource handle threaded through a while loop keeps a
// "resource: []" property on every loop node, and that the ReadVariableOp
// inside the loop body recovers the variable's int32 scalar shape.
TEST_F(GraphPropertiesTest, WhileLoopWithVarHandleOpInput) {
  // Test graph is first generated in python using:
  /*
    i0 = tf.constant(0)
    v = tf.get_variable(initializer=i0, name='loop_var', use_resource=True)
    def cond(i, x):
      return i < 3
    def body(i, x):
      return i + 1, x + x
    v, y = tf.while_loop(cond, body, loop_vars=[v, tf.constant(1)])
  */
  // and then modified by hand such that the ReadVariableOp is inside the loop
  // body instead of outside the while loop (which is the case when constructed
  // using the python API), such that we have the following pattern: VarHandleOp
  // -> Enter -> Switch -> ReadVariableOp -> other parts of loop body. Note
  // DT_RESOURCE is passed all the way until ReadVariableOp.
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "while_loop_var_handle_op.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> resource_nodes{
      "loop_var",       "while/Enter",         "while/Merge", "while/Switch",
      "while/Identity", "while/NextIteration", "while/Exit"};
  for (const string& node : resource_nodes) {
    const auto props = properties.GetOutputProperties(node);
    // Fatal: props[0] is read on the next line.
    ASSERT_GE(props.size(), 1);  // Merge has 2 outputs.
    EXPECT_EQ("resource: []", PropToString(props[0]));
  }

  // After ReadVariableOp, the shape should be recovered.
  const auto props = properties.GetOutputProperties("while/ReadVariableOp");
  ASSERT_EQ(1, props.size());
  EXPECT_EQ("int32: []", PropToString(props[0]));
}
587 
// A FIFO queue created without a shapes attr and with no enqueue op: the
// dequeued tensor is float with unknown rank.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_NoShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto queue =
      ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT});
  auto dequeue = ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue,
                                   {DataType::DT_FLOAT});

  GrapplerItem graph_item;
  TF_ASSERT_OK(scope.ToGraphDef(&graph_item.graph));

  GraphProperties inferred(graph_item);
  TF_ASSERT_OK(inferred.InferStatically(false));

  const auto dequeued = inferred.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeued.size());
  EXPECT_EQ("float: ?", PropToString(dequeued[0]));
}
604 
// A FIFO queue with a fully specified shapes attr and no enqueue op: the
// dequeued tensor takes the shape from the attr.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_ShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto queue =
      ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT},
                     ops::FIFOQueue::Attrs().Shapes({{3, 7, 1}}));
  auto dequeue = ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue,
                                   {DataType::DT_FLOAT});

  GrapplerItem graph_item;
  TF_ASSERT_OK(scope.ToGraphDef(&graph_item.graph));

  GraphProperties inferred(graph_item);
  TF_ASSERT_OK(inferred.InferStatically(false));

  const auto dequeued = inferred.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeued.size());
  EXPECT_EQ("float: [3,7,1]", PropToString(dequeued[0]));
}
622 
// A FIFO queue whose shapes attr has a -1 last dimension and no enqueue op:
// the dequeued tensor keeps that dimension as -1.
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_PartialShapeAttr) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto queue =
      ops::FIFOQueue(scope.WithOpName("Queue1"), {DataType::DT_FLOAT},
                     ops::FIFOQueue::Attrs().Shapes({{3, 7, -1}}));
  auto dequeue = ops::QueueDequeue(scope.WithOpName("Dequeue1"), queue,
                                   {DataType::DT_FLOAT});

  GrapplerItem graph_item;
  TF_ASSERT_OK(scope.ToGraphDef(&graph_item.graph));

  GraphProperties inferred(graph_item);
  TF_ASSERT_OK(inferred.InferStatically(false));

  const auto dequeued = inferred.GetOutputProperties("Dequeue1");
  ASSERT_EQ(1, dequeued.size());
  EXPECT_EQ("float: [3,7,-1]", PropToString(dequeued[0]));
}
640 
TEST_F(GraphPropertiesTest,Queues)641 TEST_F(GraphPropertiesTest, Queues) {
642   // Create a graph with known input shapes, and propagate the shapes through a
643   // couple of queues.
644   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
645 
646   auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
647   Output rnd =
648       ops::RandomNormal(root.WithOpName("rnd"), {3, 7}, DataType::DT_FLOAT);
649   Output square1 = ops::Square(root.WithOpName("Square1"), rnd);
650   auto enqueue1 = ops::QueueEnqueue(root.WithOpName("Enqueue1"), q1, {square1});
651   auto dequeue1 =
652       ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
653 
654   auto q2 =
655       ops::RandomShuffleQueue(root.WithOpName("Queue2"), {DataType::DT_FLOAT});
656   Output square2 = ops::Square(root.WithOpName("Square2"), dequeue1[0]);
657   auto enqueue2 = ops::QueueEnqueue(root.WithOpName("Enqueue2"), q2, {square2});
658   auto dequeue2 =
659       ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});
660 
661   auto q4 =
662       ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
663   auto enqueue4 = ops::QueueEnqueue(root.WithOpName("Enqueue4"), q4, {square2});
664   auto enqueue4_2 =
665       ops::QueueEnqueue(root.WithOpName("Enqueue4_2"), q4, {dequeue2[0]});
666   auto dequeue4 =
667       ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT});
668 
669   // Create a queue that takes in three tensors.
670   auto q5 = ops::RandomShuffleQueue(
671       root.WithOpName("Queue5"),
672       {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
673   Output rnd2 =
674       ops::RandomNormal(root.WithOpName("rnd2"), {10}, DataType::DT_DOUBLE);
675   Output rnd3 =
676       ops::RandomNormal(root.WithOpName("rnd3"), {1, 2, 3}, DataType::DT_FLOAT);
677   auto enqueue5 =
678       ops::QueueEnqueue(root.WithOpName("Enqueue5"), q5, {rnd, rnd2, rnd3});
679   auto dequeue5 = ops::QueueDequeue(
680       root.WithOpName("Dequeue5"), q5,
681       {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
682 
683   GrapplerItem item;
684   TF_ASSERT_OK(root.ToGraphDef(&item.graph));
685 
686   GraphProperties properties(item);
687   TF_ASSERT_OK(properties.InferStatically(false));
688 
689   const auto props1 = properties.GetOutputProperties("Dequeue1");
690   ASSERT_EQ(1, props1.size());
691   EXPECT_EQ("float: [3,7]", PropToString(props1[0]));
692 
693   const auto props2 = properties.GetOutputProperties("Dequeue2");
694   ASSERT_EQ(1, props2.size());
695   EXPECT_EQ("float: [3,7]", PropToString(props2[0]));
696 
697   // The dequeue3 op shape is unknown. The square2 op shape is known. Verify
698   // that we merge the 2 properly to determine the shape of the data coming out
699   // of the queue.
700   const auto props4 = properties.GetOutputProperties("Dequeue4");
701   ASSERT_EQ(1, props4.size());
702   EXPECT_EQ("float: [3,7]", PropToString(props4[0]));
703 
704   // The dequeue5 op shape is known.
705   const auto props5 = properties.GetOutputProperties("Dequeue5");
706   ASSERT_EQ(3, props5.size());
707   EXPECT_EQ("float: [3,7]", PropToString(props5[0]));
708   EXPECT_EQ("double: [10]", PropToString(props5[1]));
709   EXPECT_EQ("float: [1,2,3]", PropToString(props5[2]));
710 }
711 
// Verifies shape inference through a tf.cond (Switch/Merge) graph without
// loops. The two branches concatenate along different axes, so the Merge
// output is expected to keep only the dimension both branches agree on.
TEST_F(GraphPropertiesTest, MergeWithoutLoops) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      x = tf.constant(2)
      y = tf.constant(5)
      z = tf.ones([1,1,1])
      def f1(): return tf.concat([z, z], axis=0)
      def f2(): return tf.concat([z, z], axis=1)
      r = tf.cond(tf.less(x, y), f1, f2)
      tf.concat([r, r], axis=2)
      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
   */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "merge_without_loops.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> nodes{"cond/Merge", "cond/concat", "cond/concat_1"};
  std::vector<string> expected_outputs{"float: [-1,-1,1]", "float: [2,1,1]",
                                       "float: [1,2,1]"};
  // The two vectors are indexed in lockstep below; keep them the same length.
  ASSERT_EQ(nodes.size(), expected_outputs.size());
  for (size_t i = 0; i < nodes.size(); ++i) {
    const auto& props = properties.GetOutputProperties(nodes[i]);
    // Fail cleanly rather than indexing into an empty result.
    ASSERT_FALSE(props.empty()) << "no output properties for " << nodes[i];
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(expected_outputs[i], PropToString(prop));
  }

  // The "Less" node should be fed by 2 int32 scalar constant values.
  const auto& less_inputs = properties.GetInputProperties("Less");
  EXPECT_EQ(2, less_inputs.size());
  for (const auto& input : less_inputs) {
    EXPECT_EQ(DT_INT32, input.dtype());
    EXPECT_TRUE(input.has_value());
    EXPECT_EQ("int32: []", PropToString(input));
  }
}
753 
// Verifies shape inference for a single while loop whose body concatenates
// the loop variable with itself along axis 0 on every iteration.
TEST_F(GraphPropertiesTest, WhileLoop) {
  // Test graph produced in python using:
  /*
     with tf.Graph().as_default():
       i0 = tf.constant(0)
       m0 = tf.placeholder([-1, 2])
       c = lambda i, m: i < 10
       b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
       r = tf.while_loop(
              c, b, loop_vars=[i0, m0],
              shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
       with open('/tmp/graph.pbtxt', 'w') as f:
         f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "while_loop.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
                            "while/Exit_1"};
  for (const string& node : nodes) {
    const auto& props = properties.GetOutputProperties(node);
    // Fail cleanly rather than indexing into an empty result.
    ASSERT_FALSE(props.empty()) << "no output properties for " << node;
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ("float: [-1,2]", PropToString(prop));
  }

  // The loop outputs batch dim should be different from the input batch dim
  // since we concatenated along the batch dim.
  const auto& ones_props = properties.GetOutputProperties("ones");
  const auto& exit_props = properties.GetOutputProperties("while/Exit_1");
  ASSERT_FALSE(ones_props.empty());
  ASSERT_FALSE(exit_props.empty());
  const auto& shape_in = ones_props[0].shape();
  const auto& shape_out = exit_props[0].shape();
  // dim(0) is read below, so both shapes need at least one dimension.
  ASSERT_GE(shape_in.dim_size(), 1);
  ASSERT_GE(shape_out.dim_size(), 1);
  // Both batch dims must be <= -2 and distinct from each other.
  // NOTE(review): values <= -2 presumably encode distinct symbolic unknown
  // dims -- confirm against GraphProperties.
  EXPECT_GE(-2, shape_in.dim(0).size());
  EXPECT_GE(-2, shape_out.dim(0).size());
  EXPECT_NE(shape_in.dim(0).size(), shape_out.dim(0).size());
}
793 
// Checks the inferred shapes of the Merge/NextIteration/Exit nodes of an
// outer while loop and of a while loop nested inside its body.
TEST_F(GraphPropertiesTest, NestedLoop) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j+1, tf.concat([y, y], axis=2)

        return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y],
                             shape_invariants=[i0.get_shape(),
                                              tf.TensorShape([None, 1, None])])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        j, y = inner(0, x)
        return i+1, tf.concat([x, x], axis=0)

      r = tf.while_loop(outer_cond, outer_body,
                        loop_vars=[i0, tf.ones([1, 1, 1])],
                        shape_invariants=[i0.get_shape(),
                                          tf.TensorShape([None, 1, None])])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  const string graph_path = io::JoinPath(
      testing::TensorFlowSrcRoot(), kTestDataPath, "nested_loop.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(graph_path, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Expects output 0 of every named node to be a float tensor whose printed
  // properties equal `expected`.
  const auto expect_first_output = [&properties](
                                       const std::vector<string>& names,
                                       const string& expected) {
    for (const string& name : names) {
      const auto node_props = properties.GetOutputProperties(name);
      const OpInfo::TensorProperties& first = node_props[0];
      EXPECT_EQ(DT_FLOAT, first.dtype());
      EXPECT_EQ(expected, PropToString(first));
    }
  };

  // Outer loop nodes first, then the nested loop nodes.
  expect_first_output(
      {"while/Merge_1", "while/NextIteration_1", "while/Exit_1"},
      "float: [-1,1,1]");
  expect_first_output({"while/while/Merge_1", "while/while/NextIteration_1",
                       "while/while/Exit_1"},
                      "float: [-1,1,-1]");
}
852 
// Checks inferred shapes for nested while loops where the inner loop is fed
// from a queue that the outer loop body enqueues into.
TEST_F(GraphPropertiesTest, LoopsAndQueues) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      q = tf.FIFOQueue(1, "float")

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j+1, tf.concat([y, y], axis=0)

        return tf.while_loop(inner_cond, inner_body,
                             loop_vars=[j, y],
                             shape_invariants=[i0.get_shape(),
                                               tf.TensorShape(None)])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        q.enqueue(x)
        y = tf.concat([x, x], axis=2)
        inner(0, q.dequeue())
        return i+1, y

      i, z = tf.while_loop(outer_cond, outer_body,
                           loop_vars=[i0, tf.ones([1, 1, 1])],
                           shape_invariants=[i0.get_shape(),
                                             tf.TensorShape([None, 1, None])])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
   */

  GrapplerItem item;
  const string graph_path = io::JoinPath(
      testing::TensorFlowSrcRoot(), kTestDataPath, "loops_and_queues.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(graph_path, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Expects output 0 of every named node to be a float tensor whose printed
  // properties equal `expected`.
  const auto expect_first_output = [&properties](
                                       const std::vector<string>& names,
                                       const string& expected) {
    for (const string& name : names) {
      const auto node_props = properties.GetOutputProperties(name);
      const OpInfo::TensorProperties& first = node_props[0];
      EXPECT_EQ(DT_FLOAT, first.dtype());
      EXPECT_EQ(expected, PropToString(first));
    }
  };

  // Outer loop nodes first, then the nested loop nodes.
  expect_first_output(
      {"while/Merge_1", "while/NextIteration_1", "while/Exit_1"},
      "float: [1,1,-1]");
  expect_first_output({"while/while/Merge_1", "while/while/NextIteration_1",
                       "while/while/Exit_1"},
                      "float: [-1,1,-1]");
}
915 
// Checks inferred shapes for nested while loops whose loop variables involve
// a resource variable; every checked node is expected to be an int32 scalar.
TEST_F(GraphPropertiesTest, LoopsAndResourceVars) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      with tf.variable_scope(VariableScope(reuse=None, use_resource=True)):
        v = tf.get_variable(initializer=i0, name='loop_var')

      def inner(j, y):
        def inner_cond(j, y):
          return j < 3

        def inner_body(j, y):
          return j + 1, y + y

        return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y])

      def outer_cond(i, x):
        return i < 3

      def outer_body(i, x):
        y = x + x
        inner(0, v)
        return i + 1, y

      v, z = tf.while_loop(outer_cond, outer_body,
                           loop_vars=[v, tf.constant(1)])

      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  const string graph_path =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                   "loops_and_resource_vars.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(graph_path, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Outer loop nodes followed by the nested loop nodes; all of them share the
  // same expectation, so they are checked in a single pass.
  const std::vector<string> loop_nodes{
      "while/Merge_1",       "while/NextIteration_1",
      "while/Exit_1",        "while/while/Merge_1",
      "while/while/NextIteration_1", "while/while/Exit_1"};
  for (const string& name : loop_nodes) {
    const auto node_props = properties.GetOutputProperties(name);
    const OpInfo::TensorProperties& first = node_props[0];
    EXPECT_EQ(DT_INT32, first.dtype());
    EXPECT_EQ("int32: []", PropToString(first));
  }
}
973 
// Checks shape inference when a while loop is fed from one queue and its
// result travels through a second queue before a final concat.
TEST_F(GraphPropertiesTest, QueuesAndLoops) {
  // Test graph produced in python using:
  /*
    with tf.Graph().as_default():
      i0 = tf.constant(0)
      q0 = tf.FIFOQueue(1, "float")
      q0.enqueue(tf.ones([2, 2]))
      q1 = tf.FIFOQueue(1, "float")

      def c(i, m):
        return i < 10

      def b(i, m):
        return i+1, tf.concat([m, m], axis=0)

      i, m = tf.while_loop(
          c, b, loop_vars=[i0,  q0.dequeue()],
          shape_invariants=[i0.get_shape(), tf.TensorShape(None)])

      q1.enqueue(m)
      v = q1.dequeue();
      tf.concat([v, v], axis=1)
      with open('/tmp/graph.pbtxt', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
  */

  GrapplerItem item;
  const string graph_path = io::JoinPath(
      testing::TensorFlowSrcRoot(), kTestDataPath, "queues_and_loops.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(graph_path, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Loop plumbing nodes: output 0 of each is expected to be float [-1,2].
  for (const char* loop_node :
       {"while/Merge_1", "while/NextIteration_1", "while/Exit_1"}) {
    const auto loop_props = properties.GetOutputProperties(loop_node);
    const OpInfo::TensorProperties& first = loop_props[0];
    EXPECT_EQ(DT_FLOAT, first.dtype());
    EXPECT_EQ("float: [-1,2]", PropToString(first));
  }

  // The final concat joins the dequeued value with itself along axis 1.
  const auto concat_props = properties.GetOutputProperties("concat");
  const OpInfo::TensorProperties& concat_out = concat_props[0];
  EXPECT_EQ(DT_FLOAT, concat_out.dtype());
  EXPECT_EQ("float: [-1,4]", PropToString(concat_out));
}
1022 
TEST_F(GraphPropertiesTest,InferRestoreOpShape)1023 TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
1024   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1025   Output var = ops::Variable(s.WithOpName("var"), TensorShape({128, 256}),
1026                              DataType::DT_FLOAT);
1027   Output filename =
1028       ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
1029   Output tensor_name =
1030       ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
1031   Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
1032                                 DataType::DT_FLOAT);
1033   Output init_restore = ops::Assign(s.WithOpName("init_restore"), var, restore);
1034 
1035   Output shape_and_slice = ops::Const(s.WithOpName("shape_and_slice"),
1036                                       string("256 256 0,128:-"), TensorShape());
1037   Output restore_slice =
1038       ops::RestoreSlice(s.WithOpName("restore_slice"), filename, tensor_name,
1039                         shape_and_slice, DataType::DT_FLOAT);
1040   Output init_restore_slice =
1041       ops::Assign(s.WithOpName("init_restore_slice"), var, restore_slice);
1042 
1043   Output restore_v2 =
1044       ops::RestoreSlice(s.WithOpName("restore_v2"), filename, tensor_name,
1045                         shape_and_slice, DataType::DT_FLOAT);
1046   Output init_restore_v2 =
1047       ops::Assign(s.WithOpName("init_restore_v2"), var, restore_v2);
1048 
1049   GrapplerItem item;
1050   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1051   item.fetch.push_back("init_restore");
1052 
1053   GraphProperties properties(item);
1054   TF_ASSERT_OK(properties.InferStatically(false));
1055 
1056   const auto restore_props = properties.GetOutputProperties("restore");
1057   const OpInfo::TensorProperties& restore_prop = restore_props[0];
1058   EXPECT_EQ(DT_FLOAT, restore_prop.dtype());
1059   EXPECT_EQ("float: [128,256]", PropToString(restore_prop));
1060 
1061   const auto restore_slice_props =
1062       properties.GetOutputProperties("restore_slice");
1063   const OpInfo::TensorProperties& restore_slice_prop = restore_slice_props[0];
1064   EXPECT_EQ(DT_FLOAT, restore_slice_prop.dtype());
1065   EXPECT_EQ("float: [128,256]", PropToString(restore_slice_prop));
1066 
1067   const auto restorev2_props = properties.GetOutputProperties("restore_v2");
1068   const OpInfo::TensorProperties& restorev2_prop = restorev2_props[0];
1069   EXPECT_EQ(DT_FLOAT, restorev2_prop.dtype());
1070   EXPECT_EQ("float: [128,256]", PropToString(restorev2_prop));
1071 
1072   // Check input shapes of assign op are propagated correctly.
1073   const auto input_props = properties.GetInputProperties("init_restore");
1074   ASSERT_EQ(2, input_props.size());
1075   const OpInfo::TensorProperties& input_prop = input_props[1];
1076   EXPECT_EQ(DT_FLOAT, input_prop.dtype());
1077   EXPECT_EQ("float: [128,256]", PropToString(input_prop));
1078 }
1079 
// A single Restore output feeds two Assign ops: one targets a variable with
// an unspecified shape, the other a [128,256] variable. The restore output is
// expected to be inferred as [128,256].
TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  // Checkpoint location inputs for the Restore op.
  Output ckpt_file =
      ops::Const(root.WithOpName("filename"), string("model"), TensorShape());
  Output ckpt_tensor =
      ops::Const(root.WithOpName("tensorname"), string("a"), TensorShape());

  // Two assignment targets: one without a shape, one fully defined.
  Output unshaped_var = ops::Variable(
      root.WithOpName("var"), PartialTensorShape(), DataType::DT_FLOAT);
  Output shaped_var = ops::Variable(
      root.WithOpName("var2"), TensorShape({128, 256}), DataType::DT_FLOAT);

  Output restored = ops::Restore(root.WithOpName("restore"), ckpt_file,
                                 ckpt_tensor, DataType::DT_FLOAT);
  Output assign_unshaped =
      ops::Assign(root.WithOpName("init"), unshaped_var, restored);
  Output assign_shaped =
      ops::Assign(root.WithOpName("init2"), shaped_var, restored);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));
  item.fetch.push_back("init");
  item.fetch.push_back("init2");

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto restore_outputs = properties.GetOutputProperties("restore");
  const OpInfo::TensorProperties& restored_prop = restore_outputs[0];
  EXPECT_EQ(DT_FLOAT, restored_prop.dtype());
  EXPECT_EQ("float: [128,256]", PropToString(restored_prop));
}
1108 
TEST_F(GraphPropertiesTest,TensorAsShapesPropagation)1109 TEST_F(GraphPropertiesTest, TensorAsShapesPropagation) {
1110   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1111   Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
1112   Output a1 = ops::Identity(s.WithOpName("a1"), a);
1113   Output b = ops::Const(s.WithOpName("b"), 99, {});
1114   Output b1 = ops::Identity(s.WithOpName("b1"), b);
1115   Output c = ops::Const(s.WithOpName("c"), 1, {4, 4, 4});
1116   Output c1 = ops::Identity(s.WithOpName("c1"), c);
1117 
1118   GrapplerItem item;
1119   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1120   GraphProperties properties(item);
1121   TF_ASSERT_OK(properties.InferStatically(false));
1122 
1123   // Check output shapes.
1124   EXPECT_EQ("int32: [2]", PropToString(properties.GetOutputProperties("a")[0]));
1125   EXPECT_EQ("int32: [2]",
1126             PropToString(properties.GetOutputProperties("a1")[0]));
1127   EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b")[0]));
1128   EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b1")[0]));
1129   EXPECT_EQ("int32: [4,4,4]",
1130             PropToString(properties.GetOutputProperties("c")[0]));
1131   EXPECT_EQ("int32: [4,4,4]",
1132             PropToString(properties.GetOutputProperties("c1")[0]));
1133 
1134   // Check has_value.
1135   EXPECT_TRUE(properties.GetOutputProperties("a")[0].has_value());
1136   EXPECT_TRUE(properties.GetInputProperties("a1")[0].has_value());
1137   EXPECT_TRUE(properties.GetOutputProperties("a1")[0].has_value());
1138   EXPECT_TRUE(properties.GetOutputProperties("b")[0].has_value());
1139   EXPECT_TRUE(properties.GetInputProperties("b1")[0].has_value());
1140   EXPECT_TRUE(properties.GetOutputProperties("b1")[0].has_value());
1141   EXPECT_TRUE(properties.GetOutputProperties("c")[0].has_value());
1142   EXPECT_TRUE(properties.GetInputProperties("c1")[0].has_value());
1143   // Note that we propagate tensor value of only 1D vector and scalar.
1144   EXPECT_TRUE(properties.GetOutputProperties("c1")[0].has_value());
1145 
1146   // Check values.
1147   ExpectTensorValues({5, 7}, properties.GetOutputProperties("a")[0].value());
1148   ExpectTensorValues({5, 7}, properties.GetInputProperties("a1")[0].value());
1149   ExpectTensorValues({5, 7}, properties.GetOutputProperties("a1")[0].value());
1150   ExpectTensorValues({99}, properties.GetOutputProperties("b")[0].value());
1151   ExpectTensorValues({99}, properties.GetInputProperties("b1")[0].value());
1152   ExpectTensorValues({99}, properties.GetOutputProperties("b1")[0].value());
1153   std::vector<int64_t> c_values;
1154   for (int i = 0; i < 4 * 4 * 4; i++) {
1155     c_values.push_back(1);
1156   }
1157   ExpectTensorValues({c_values},
1158                      properties.GetOutputProperties("c")[0].value());
1159   ExpectTensorValues({c_values},
1160                      properties.GetInputProperties("c1")[0].value());
1161   ExpectTensorValues({c_values},
1162                      properties.GetOutputProperties("c1")[0].value());
1163 }
1164 
// Checks that an Identity op forwards the value of its constant input so a
// downstream Fill op can derive its output shape from it.
TEST_F(GraphPropertiesTest, IdentityPassingShape) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dims_const = ops::Const(scope.WithOpName("a"), 5, {2});
  Output dims = ops::Identity(scope.WithOpName("b"), dims_const);
  Output fill_value = ops::Const(scope.WithOpName("const"), 0.1f, {});
  // Fill needs the value of its dims input (b), not just b's shape, to
  // determine its output shape; hence the Identity op (b) should pass a's
  // value as output_tensors_as_shape.
  Output filled = ops::Fill(scope.WithOpName("fill"), dims, fill_value);

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto fill_outputs = properties.GetOutputProperties("fill");
  const OpInfo::TensorProperties& fill_out = fill_outputs[0];
  EXPECT_EQ("float: [5,5]", PropToString(fill_out));
}
1183 
// When using aggressive_shape_inference, we run EvaluateNode() for
// allowlisted ops and small input / output tensors. For instance, Fill op is
// evaluated and produces output tensor value if output tensor size is small
// (currently, fewer than 17 elements); otherwise we don't run EvaluateNode().
// This is to avoid wasting time and memory for producing huge tensors (e.g.,
// initializing a large table using Fill).
// Note that we do not propagate float const tensors with fractional values
// (even if they're small); so this test should use integer values.
TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
  // Case 1: the dims input describes a small 4x4 output, so the Fill result
  // is expected to carry a value.
  {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    Output small_dims = ops::Const(scope.WithOpName("a"), 4, {2});  // 4x4
    Output fill_value = ops::Const(scope.WithOpName("const"), 7, {});
    Output filled = ops::Fill(scope.WithOpName("fill"), small_dims, fill_value);

    GrapplerItem item;
    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));

    const auto fill_outputs = properties.GetOutputProperties("fill");
    const OpInfo::TensorProperties& fill_out = fill_outputs[0];
    EXPECT_EQ("int32: [4,4]", PropToString(fill_out));
    EXPECT_TRUE(fill_out.has_value());
  }

  // Case 2: the dims input describes a huge output; value inference is
  // expected to be skipped because it would be too much overhead.
  {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    Output huge_dims =
        ops::Const(scope.WithOpName("a"), 1000, {4});  // 1000x1000x1000x1000
    Output fill_value = ops::Const(scope.WithOpName("const"), 7, {});
    Output filled = ops::Fill(scope.WithOpName("fill"), huge_dims, fill_value);

    GrapplerItem item;
    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
    GraphProperties properties(item);
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));

    const auto fill_outputs = properties.GetOutputProperties("fill");
    const OpInfo::TensorProperties& fill_out = fill_outputs[0];
    EXPECT_EQ("int32: [1000,1000,1000,1000]", PropToString(fill_out));
    EXPECT_FALSE(fill_out.has_value());
  }
}
1233 
// Checks that a Pack op built from four scalar constants exposes its value,
// so a Fill op consuming it as dims gets the output shape [1,2,3,4].
TEST_F(GraphPropertiesTest, PackWithConstInput) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output one = ops::Const(scope.WithOpName("a"), 1, {});
  Output two = ops::Const(scope.WithOpName("b"), 2, {});
  Output three = ops::Const(scope.WithOpName("c"), 3, {});
  Output four = ops::Const(scope.WithOpName("d"), 4, {});
  // Note ops::Stack instantiates Pack op. The packed result is a rank 1
  // tensor: shape = {4}, value = {1, 2, 3, 4}.
  Output packed =
      ops::Stack(scope.WithOpName("pack"), {one, two, three, four});
  Output fill_value = ops::Const(scope.WithOpName("const"), 0.1f, {});
  // Fill needs the value of the packed tensor, not only its shape, to figure
  // out its own output shape.
  Output filled = ops::Fill(scope.WithOpName("fill"), packed, fill_value);

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto fill_outputs = properties.GetOutputProperties("fill");
  const OpInfo::TensorProperties& fill_out = fill_outputs[0];
  EXPECT_EQ("float: [1,2,3,4]", PropToString(fill_out));
}
1256 
// Checks that Rank of a [4,4,4] constant is inferred as the scalar value 3
// and that the value survives a following Identity op.
TEST_F(GraphPropertiesTest, RankOp) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(scope.WithOpName("Const"), 1, {4, 4, 4});
  Output rank = ops::Rank(scope.WithOpName("Rank"), input);
  Output forwarded = ops::Identity(scope.WithOpName("Identity"), rank);

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Rank node itself.
  const auto rank_outputs = properties.GetOutputProperties("Rank");
  const OpInfo::TensorProperties& rank_out = rank_outputs[0];
  EXPECT_EQ("int32: []", PropToString(rank_out));
  EXPECT_TRUE(rank_out.has_value());
  ExpectTensorValues({3}, rank_out.value());

  // Identity fed by Rank.
  const auto identity_outputs = properties.GetOutputProperties("Identity");
  const OpInfo::TensorProperties& identity_out = identity_outputs[0];
  EXPECT_EQ("int32: []", PropToString(identity_out));
  EXPECT_TRUE(identity_out.has_value());
  ExpectTensorValues({3}, identity_out.value());
}
1278 
// Checks that Size of a [1,2,3,4] constant is inferred as the scalar value 24
// and that the value survives a following Identity op.
TEST_F(GraphPropertiesTest, SizeOp) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(scope.WithOpName("Const"), 1, {1, 2, 3, 4});
  Output size = ops::Size(scope.WithOpName("Size"), input);
  Output forwarded = ops::Identity(scope.WithOpName("Identity"), size);

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Size node itself.
  const auto size_outputs = properties.GetOutputProperties("Size");
  const OpInfo::TensorProperties& size_out = size_outputs[0];
  EXPECT_EQ("int32: []", PropToString(size_out));
  EXPECT_TRUE(size_out.has_value());
  ExpectTensorValues({24}, size_out.value());

  // Identity fed by Size.
  const auto identity_outputs = properties.GetOutputProperties("Identity");
  const OpInfo::TensorProperties& identity_out = identity_outputs[0];
  EXPECT_EQ("int32: []", PropToString(identity_out));
  EXPECT_TRUE(identity_out.has_value());
  ExpectTensorValues({24}, identity_out.value());
}
1300 
TEST_F(GraphPropertiesTest,PackWithConstMinus1AndReshapes)1301 TEST_F(GraphPropertiesTest, PackWithConstMinus1AndReshapes) {
1302   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1303   Output shape0 = ops::Const(s.WithOpName("shape0"), 4, {});
1304   Output shape1 = ops::Const(s.WithOpName("shape1"), -1, {});
1305   Output pack = ops::Stack(s.WithOpName("pack"), {shape0, shape1});
1306   // pack is [2], with values {4, -1}.
1307 
1308   Output x0_ = ops::Placeholder(s.WithOpName("x0_"), DataType::DT_FLOAT);
1309   Output x1_ = ops::Placeholder(s.WithOpName("x1_"), DataType::DT_FLOAT);
1310 
1311   Output x0 = ops::Reshape(s.WithOpName("x0"), x0_, pack);
1312   Output x1 = ops::Reshape(s.WithOpName("x1"), x1_, pack);
1313   // Two unknown rank tensors (x0_ and x1_) are reshaped with pack {4, -1},
1314   // their output shapes would be [4, -1]. However, though we use the same
1315   // shape input to the Reshape ops, their output shapes can be different;
1316   // i.e., unknown dim values (-1) of x0 and x1 shapes are not necessarily
1317   // the same.
1318 
1319   // if input to the Select ops. Note that s0 has a fully defined shape, while
1320   // s1 has unknown shape.
1321   Output s0 = ops::Const(s.WithOpName("s0"), true, {4, 16});
1322   Output s1 = ops::Placeholder(s.WithOpName("s1"), DataType::DT_BOOL);
1323 
1324   Output y0 = ops::Placeholder(s.WithOpName("y0"), DataType::DT_FLOAT);
1325   Output y1 = ops::Placeholder(s.WithOpName("y1"), DataType::DT_FLOAT);
1326 
1327   // We instantiate SelectV2, but will replace it with Select. The shape
1328   // inference function for Select links all inputs and outputs as they should
1329   // have the same shapes.
1330   Output z0 = ops::SelectV2(s.WithOpName("z0"), s0, x0, y0);
1331   Output z1 = ops::SelectV2(s.WithOpName("z1"), s1, x1, y1);
1332 
1333   // For z0, as we know the shape of s0, symbolic shape manager in shape
1334   // inference will make the shapes of x0, y0, and z0 equal to the shape of s0,
1335   // which is [4, 16].
1336   // For z1, s0 and y1 are all unknown shapes, so we can infer they're [4, -1]
1337   // at best.
1338   // Note that x0 and x1 share the same shape input to the Reshape op, but
1339   // -1 in the shape input should not be treated as the same symoblic unknown
1340   // dim; it is merely a constant value -1 for identitying unknown dim for
1341   // Reshape operation.
1342 
1343   GrapplerItem item;
1344   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1345 
1346   // Replace SelectV2 op with Select op.
1347   for (int i = 0; i < item.graph.node_size(); ++i) {
1348     auto* node = item.graph.mutable_node(i);
1349     if (node->op() == "SelectV2") {
1350       node->set_op("Select");
1351     }
1352   }
1353 
1354   GraphProperties properties(item);
1355   TF_ASSERT_OK(properties.InferStatically(false));
1356   for (const auto& node_name : {"x0", "y0", "z0"}) {
1357     const auto out_props = properties.GetOutputProperties(node_name);
1358     const OpInfo::TensorProperties out_prop0 = out_props[0];
1359     EXPECT_EQ("float: [4,16]", PropToString(out_prop0));
1360   }
1361   {
1362     const auto out_props = properties.GetOutputProperties("s0");
1363     const OpInfo::TensorProperties out_prop0 = out_props[0];
1364     EXPECT_EQ("bool: [4,16]", PropToString(out_prop0));
1365   }
1366 
1367   for (const auto& node_name : {"x1", "y1", "z1"}) {
1368     const auto out_props = properties.GetOutputProperties(node_name);
1369     const OpInfo::TensorProperties out_prop0 = out_props[0];
1370     EXPECT_EQ("float: [4,-1]", PropToString(out_prop0));
1371   }
1372   // if input of Select can be either vector or the same shape to the
1373   // input/output; in this case, even if we know input and output are
1374   // [4, ?], we can't say it's [4, ?] or a vector; hence, it should be
1375   // unknown.
1376   {
1377     const auto out_props = properties.GetOutputProperties("s1");
1378     const OpInfo::TensorProperties out_prop0 = out_props[0];
1379     EXPECT_EQ("bool: ?", PropToString(out_prop0));
1380   }
1381 }
1382 
// Same as the PackWithConstInput test case, but the Pack inputs are Identity
// ops fed by Const ops.
// If output_tensors_as_shape is not set for those Identity ops, or the Pack
// op doesn't take input_tensors_as_shape, Fill op's input doesn't have a
// value; hence, its output shape becomes unknown.
TEST_F(GraphPropertiesTest, PackWithIdentityInput) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  // Scalar constants 1..4.
  Output const_a = ops::Const(scope.WithOpName("a0"), 1, {});
  Output const_b = ops::Const(scope.WithOpName("b0"), 2, {});
  Output const_c = ops::Const(scope.WithOpName("c0"), 3, {});
  Output const_d = ops::Const(scope.WithOpName("d0"), 4, {});

  // Identity wrappers around the constants.
  Output id_a = ops::Identity(scope.WithOpName("a"), const_a);
  Output id_b = ops::Identity(scope.WithOpName("b"), const_b);
  Output id_c = ops::Identity(scope.WithOpName("c"), const_c);
  Output id_d = ops::Identity(scope.WithOpName("d"), const_d);

  // Note ops::Stack instantiates Pack op. The packed result is a rank 1
  // tensor: shape = {4}, value = {1, 2, 3, 4}.
  Output packed =
      ops::Stack(scope.WithOpName("pack"), {id_a, id_b, id_c, id_d});
  Output fill_value = ops::Const(scope.WithOpName("const"), 0.1f, {});
  // Fill needs the value of the packed tensor, not only its shape, to figure
  // out its own output shape.
  Output filled = ops::Fill(scope.WithOpName("fill"), packed, fill_value);

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto fill_outputs = properties.GetOutputProperties("fill");
  const OpInfo::TensorProperties& fill_out = fill_outputs[0];
  EXPECT_EQ("float: [1,2,3,4]", PropToString(fill_out));
}
1414 
TEST_F(GraphPropertiesTest,FunctionWithDtResourceInput)1415 TEST_F(GraphPropertiesTest, FunctionWithDtResourceInput) {
1416   // Function ops may have DT_RESOURCE input; if not properly set shapes and
1417   // dtypes through the DT_RESOURCE _Arg, we cannot infer output shapes of such
1418   // function ops.
1419   GrapplerItem item;
1420   string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
1421                                  "function_with_dt_resource_input.pbtxt");
1422   TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
1423 
1424   // This graph evaluates FunctionWithDtResourceInput with two inputs:
1425   // x [DT_FLOAT Const],
1426   // _Arg [DT_RESOURCE _Arg]
1427   // and has two outputs:
1428   // z1 = x + _Arg
1429   // z2 = x
1430   {
1431     GraphProperties properties(item);
1432     TF_ASSERT_OK(properties.InferStatically(false));
1433     const auto out_props =
1434         properties.GetOutputProperties("FunctionWithDtResourceInput");
1435     EXPECT_EQ(out_props.size(), 2);
1436     const OpInfo::TensorProperties out_prop0 = out_props[0];
1437     EXPECT_EQ("float: [1,3]", PropToString(out_prop0));
1438     const OpInfo::TensorProperties out_prop1 = out_props[1];
1439     EXPECT_EQ("float: [1,3]", PropToString(out_prop1));
1440   }
1441 
1442   {
1443     // Delete _handle_dtypes and _handle_shapes attr for the input _Arg node.
1444     for (int i = 0; i < item.graph.node_size(); i++) {
1445       auto* node = item.graph.mutable_node(i);
1446       if (node->name() == "y") {  // _Arg node with DT_RESOURCE
1447         node->mutable_attr()->erase("_handle_dtypes");
1448         node->mutable_attr()->erase("_handle_shapes");
1449         break;
1450       }
1451     }
1452     // We cannot infer the function output shape correctly without those attr,
1453     // but still it shouldn't fail; also, there can be some shapes we can
1454     // infer in such a case. In this test graph,
1455     // z2 of the function node just returns x input; hence, even if _Arg's shape
1456     // cannot be inferred, we can infer z2 output shape.
1457     GraphProperties properties(item);
1458     TF_ASSERT_OK(properties.InferStatically(false));
1459     const auto out_props =
1460         properties.GetOutputProperties("FunctionWithDtResourceInput");
1461     EXPECT_EQ(out_props.size(), 2);
1462     const OpInfo::TensorProperties out_prop0 = out_props[0];
1463     // Without shape and dtype attr, we don't know _Arg's shape; hence, unknown
1464     // for x + _Arg.
1465     EXPECT_EQ("float: ?", PropToString(out_prop0));
1466     // The 2nd output is just x, so even if _Arg's shape is unknown, we can
1467     // infer this output shape.
1468     const OpInfo::TensorProperties out_prop1 = out_props[1];
1469     EXPECT_EQ("float: [1,3]", PropToString(out_prop1));
1470   }
1471 }
1472 
// Calls MyFillFunc with Const inputs and expects the function node's output
// shape to match the values held by the "shape" Const.
TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(function_lib_));
  Output dims = ops::Const(root.WithOpName("shape"), {1, 2, 3, 4});
  Output fill_value = ops::Const(root.WithOpName("value"), 0.1f, {});
  auto func_builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
                                              root.graph()->op_registry());
  tensorflow::Node* func_node;
  auto dims_in = tensorflow::ops::AsNodeOut(root, dims);
  auto value_in = tensorflow::ops::AsNodeOut(root, fill_value);
  // Wire both Const outputs into the function node and add it to the graph.
  func_builder.Input(dims_in);
  func_builder.Input(value_in);
  TF_ASSERT_OK(func_builder.Finalize(root.graph(), &func_node));
  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& func_outputs = properties.GetOutputProperties("MyFillFunc");
  EXPECT_EQ("float: [1,2,3,4]", PropToString(func_outputs[0]));
}
1494 
// Variant of FunctionWithConstInput where the shape input reaches the function
// through an Identity op instead of directly from a Const.
TEST_F(GraphPropertiesTest, FunctionWithIdentityOfConstInput) {
  // Because the function input is an Identity of a Const, the tensor shapes
  // (rather than the tensor value) should be used as the Const input to the
  // function.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(function_lib_));
  Output dims_const = ops::Const(root.WithOpName("shape_"), {1, 2, 3, 4});
  Output dims = ops::Identity(root.WithOpName("shape"), dims_const);
  Output fill_value = ops::Const(root.WithOpName("value"), 0.1f, {});
  auto func_builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
                                              root.graph()->op_registry());
  tensorflow::Node* func_node;
  auto dims_in = tensorflow::ops::AsNodeOut(root, dims);
  auto value_in = tensorflow::ops::AsNodeOut(root, fill_value);
  func_builder.Input(dims_in);
  func_builder.Input(value_in);
  TF_ASSERT_OK(func_builder.Finalize(root.graph(), &func_node));
  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto& func_outputs = properties.GetOutputProperties("MyFillFunc");
  EXPECT_EQ("float: [1,2,3,4]", PropToString(func_outputs[0]));
}
1520 
// Checks that a function which forwards its input through Identity reports
// both the shape and the tensor value of that input on its output.
TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) {
  FunctionDefLibrary func_lib;
  *func_lib.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: int32"},                                               // Inputs
      {"out: int32"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_INT32}}}},  // Nodes
      {{"out", "a:output:0"}});                                   // Returns
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(func_lib));

  // MyFunc receives a Const (shape) and passes it through Identity. The
  // function output is expected to carry the same shape and the same value
  // (output_tensors_as_shape) as the input Const tensor.
  Output dims = ops::Const(root.WithOpName("shape"), {5, 7}, {2});
  auto dims_in = tensorflow::ops::AsNodeOut(root, dims);
  auto func_builder =
      tensorflow::NodeBuilder("MyFunc", "MyFunc", root.graph()->op_registry());
  tensorflow::Node* func_node;
  func_builder.Input(dims_in);
  TF_ASSERT_OK(func_builder.Finalize(root.graph(), &func_node));

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(true));
  const auto& func_outputs = properties.GetOutputProperties("MyFunc");
  const OpInfo::TensorProperties& out0 = func_outputs[0];
  EXPECT_EQ("int32: [2]", PropToString(out0));
  EXPECT_TRUE(out0.has_value());
  ExpectTensorValues({5, 7}, out0.value());
  // The function node's input must expose the same value.
  const auto& func_inputs = properties.GetInputProperties("MyFunc");
  ExpectTensorValues({5, 7}, func_inputs[0].value());
}
1556 
TEST_F(GraphPropertiesTest,ArithmeticFunctionReturnTensorValue)1557 TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
1558   FunctionDefLibrary library;
1559   // Function that adds two input values.
1560   *library.add_function() = FunctionDefHelper::Create(
1561       "MyFunc",                                                   // Name
1562       {"x: int32", "y: int32"},                                   // Inputs
1563       {"out: int32"},                                             // Outputs
1564       {},                                                         // Attrs
1565       {{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_INT32}}}},  // Nodes
1566       {{"out", "a:z:0"}});                                        // Returns
1567   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1568   TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));
1569 
1570   Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
1571   auto _shape = tensorflow::ops::AsNodeOut(s, shape);
1572   auto builder =
1573       tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
1574   tensorflow::Node* func_op;
1575   TF_ASSERT_OK(
1576       builder.Input(_shape).Input(_shape).Finalize(s.graph(), &func_op));
1577 
1578   GrapplerItem item;
1579   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1580   {
1581     GraphProperties properties(item);
1582     // Without aggressive_shape_inference, the internal function does not
1583     // evaluate output value.
1584     TF_ASSERT_OK(properties.InferStatically(
1585         /*assume_valid_feeds=*/true,
1586         /*aggressive_shape_inference=*/false,
1587         /*include_tensor_values=*/true));
1588     const auto out_props = properties.GetOutputProperties("MyFunc");
1589     const OpInfo::TensorProperties out_prop0 = out_props[0];
1590     EXPECT_EQ("int32: [2]", PropToString(out_prop0));
1591     EXPECT_FALSE(out_prop0.has_value());
1592   }
1593 
1594   {
1595     GraphProperties properties(item);
1596     // With aggressive_shape_inference, output value is evaluated.
1597     TF_ASSERT_OK(properties.InferStatically(
1598         /*assume_valid_feeds=*/true,
1599         /*aggressive_shape_inference=*/true,
1600         /*include_tensor_values=*/true));
1601     const auto out_props = properties.GetOutputProperties("MyFunc");
1602     const OpInfo::TensorProperties out_prop0 = out_props[0];
1603     EXPECT_EQ("int32: [2]", PropToString(out_prop0));
1604     EXPECT_TRUE(out_prop0.has_value());
1605 
1606     ExpectTensorValues({10, 14}, out_prop0.value());
1607     ExpectTensorValues({5, 7},
1608                        properties.GetInputProperties("MyFunc")[0].value());
1609     ExpectTensorValues({5, 7},
1610                        properties.GetInputProperties("MyFunc")[1].value());
1611   }
1612 }
1613 
1614 // Same as the above, but float values; also, one of the function input is
1615 // Identity of Const.
TEST_F(GraphPropertiesTest,ArithmeticFunctionReturnTensorValueFloat)1616 TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValueFloat) {
1617   FunctionDefLibrary library;
1618   // Function that adds two input values.
1619   *library.add_function() = FunctionDefHelper::Create(
1620       "MyFunc",                                                   // Name
1621       {"x: float", "y: float"},                                   // Inputs
1622       {"out: float"},                                             // Outputs
1623       {},                                                         // Attrs
1624       {{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_FLOAT}}}},  // Nodes
1625       {{"out", "a:z:0"}});                                        // Returns
1626   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1627   TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));
1628 
1629   Output x1 = ops::Const(s.WithOpName("x1"), {5.0f, 7.0f}, {2});
1630   Output x2 = ops::Identity(s.WithOpName("x1"), x1);
1631   auto _x1 = tensorflow::ops::AsNodeOut(s, x1);
1632   auto _x2 = tensorflow::ops::AsNodeOut(s, x2);
1633   auto builder =
1634       tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
1635   tensorflow::Node* func_op;
1636   TF_ASSERT_OK(builder.Input(_x1).Input(_x2).Finalize(s.graph(), &func_op));
1637 
1638   GrapplerItem item;
1639   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1640   {
1641     GraphProperties properties(item);
1642     // Without aggressive_shape_inference, the internal function does not
1643     // evaluate output value.
1644     TF_ASSERT_OK(properties.InferStatically(
1645         /*assume_valid_feeds=*/true,
1646         /*aggressive_shape_inference=*/false,
1647         /*include_tensor_values=*/true));
1648     const auto out_props = properties.GetOutputProperties("MyFunc");
1649     const OpInfo::TensorProperties out_prop0 = out_props[0];
1650     EXPECT_EQ("float: [2]", PropToString(out_prop0));
1651     EXPECT_FALSE(out_prop0.has_value());
1652   }
1653 
1654   {
1655     GraphProperties properties(item);
1656     // With aggressive_shape_inference, output value is evaluated.
1657     TF_ASSERT_OK(properties.InferStatically(
1658         /*assume_valid_feeds=*/true,
1659         /*aggressive_shape_inference=*/true,
1660         /*include_tensor_values=*/true));
1661     const auto out_props = properties.GetOutputProperties("MyFunc");
1662     const OpInfo::TensorProperties out_prop0 = out_props[0];
1663     EXPECT_EQ("float: [2]", PropToString(out_prop0));
1664     EXPECT_TRUE(out_prop0.has_value());
1665 
1666     ExpectFloatTensorValues({10.0, 14.0}, out_prop0.value());
1667     ExpectFloatTensorValues({5.0, 7.0},
1668                             properties.GetInputProperties("MyFunc")[0].value());
1669     ExpectFloatTensorValues({5.0, 7.0},
1670                             properties.GetInputProperties("MyFunc")[1].value());
1671   }
1672 }
1673 
// Checks that a function fed by a scalar Placeholder (through Identity) does
// not end up with an unknown-rank output.
TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
  // Graph under test: Placeholder -> Identity -> MyFunc, where MyFunc simply
  // applies Identity to its input. All tensors are scalars, so function shape
  // inference sees a scalar Placeholder as its input.
  FunctionDefLibrary func_lib;
  *func_lib.add_function() = FunctionDefHelper::Create(
      "MyFunc",                                                   // Name
      {"x: float"},                                               // Inputs
      {"out: float"},                                             // Outputs
      {},                                                         // Attrs
      {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_FLOAT}}}},  // Nodes
      {{"out", "a:output:0"}});                                   // Returns
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(func_lib));
  Output scalar_input =
      ops::Placeholder(root.WithOpName("Placeholder"), DataType::DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({})));
  Output forwarded = ops::Identity(root.WithOpName("Identity"), scalar_input);
  auto func_input = tensorflow::ops::AsNodeOut(root, forwarded);
  auto func_builder =
      tensorflow::NodeBuilder("MyFunc", "MyFunc", root.graph()->op_registry());
  tensorflow::Node* func_node;
  func_builder.Input(func_input);
  TF_ASSERT_OK(func_builder.Finalize(root.graph(), &func_node));
  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  // Older TensorFlow graph versions infer the output shape of a Placeholder
  // with an empty shape as unknown instead of scalar, so this test requires a
  // producer version greater than 21.
  EXPECT_GT(item.graph.versions().producer(), 21);

  // The output of MyFunc must not have unknown rank.
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(true));
  const auto& func_outputs = properties.GetOutputProperties("MyFunc");
  const OpInfo::TensorProperties& out0 = func_outputs[0];
  EXPECT_EQ(DT_FLOAT, out0.dtype());
  EXPECT_FALSE(out0.shape().unknown_rank());
}
1713 
// Checks static shape inference of the input and output properties of a
// function call node loaded from simple_function.pbtxt.
TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x,y)

    with tf.Graph().as_default():
      x = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = MyAdd(x, y)
      z = MyAdd(x, z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "simple_function.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto out_props = properties.GetOutputProperties("MyAdd_55e046a8");
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_FALSE(out_prop.shape().unknown_rank());
  // Fatal on mismatch: dim(0) and dim(1) are accessed right below.
  ASSERT_EQ(2, out_prop.shape().dim_size());
  EXPECT_EQ(1, out_prop.shape().dim(0).size());
  EXPECT_EQ(2, out_prop.shape().dim(1).size());

  const auto in_props = properties.GetInputProperties("MyAdd_55e046a8");
  // Fatal on mismatch: both inputs are indexed right below.
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1750 
// Checks the statically inferred input and output shapes of node "y0" in the
// graph loaded from large_function_graph.pbtxt.
TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "large_function_graph.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto out_props = properties.GetOutputProperties("y0");
  // Fatal on mismatch: both outputs are indexed right below.
  ASSERT_EQ(2, out_props.size());

  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ("float: [128,112,112,64]", PropToString(out_prop0));

  const OpInfo::TensorProperties& out_prop1 = out_props[1];
  EXPECT_EQ("float: [128,112,112,24]", PropToString(out_prop1));

  const auto in_props = properties.GetInputProperties("y0");
  // Fatal on mismatch: all four inputs are indexed right below.
  ASSERT_EQ(4, in_props.size());

  const OpInfo::TensorProperties& in_prop0 = in_props[0];
  EXPECT_EQ("float: [64]", PropToString(in_prop0));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,1,24,64]", PropToString(in_prop1));

  const OpInfo::TensorProperties& in_prop2 = in_props[2];
  EXPECT_EQ("float: [128,224,224,3]", PropToString(in_prop2));

  const OpInfo::TensorProperties& in_prop3 = in_props[3];
  EXPECT_EQ("float: [7,7,3,8]", PropToString(in_prop3));
}
1783 
// Checks that both outputs of a function wrapping a functional While have the
// expected dtypes and a known rank.
TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
  // Test graph produced in python using:
  /*
    @function.Defun(noinline=True)
    def MyFunc():
      @function.Defun(*[tf.float32] * 2)
      def Cond(n, unused_x):
        return n > 0

      @function.Defun(*[tf.float32] * 2)
      def Body(n, x):
        return n - 1, x + n

      i = tf.constant(10)
      return functional_ops.While([i, 0.], Cond, Body)

    with tf.Graph().as_default():
      z = MyFunc()
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_functional_while.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto out_props = properties.GetOutputProperties("MyFunc_AenMyWWx1Us");
  // Fatal on mismatch: both outputs are indexed right below.
  ASSERT_EQ(2, out_props.size());

  const OpInfo::TensorProperties& out_prop0 = out_props[0];
  EXPECT_EQ(DT_INT32, out_prop0.dtype());
  EXPECT_FALSE(out_prop0.shape().unknown_rank());

  const OpInfo::TensorProperties& out_prop1 = out_props[1];
  EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
  EXPECT_FALSE(out_prop1.shape().unknown_rank());
}
1821 
// Checks that inference on the graph from function_error.pbtxt still succeeds
// and reports an unknown-rank output for the function node while its input
// shapes remain known.
TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_error.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  const auto out_props = properties.GetOutputProperties("MyAdd_yabA4wXEdM4");
  // Fatal on mismatch: the single output is indexed right below.
  ASSERT_EQ(1, out_props.size());

  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_TRUE(out_prop.shape().unknown_rank());

  const auto in_props = properties.GetInputProperties("MyAdd_yabA4wXEdM4");
  // Fatal on mismatch: both inputs are indexed right below.
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1846 
// Checks shape inference for a function node whose first input comes from a
// tf.case construct (see the generating Python snippet below).
TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x, y)

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(0, 1), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
  // Fatal on mismatch: both inputs are indexed right below.
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1880 
// Same setup as FunctionSwitchStaticShapeInference, but the tf.case predicate
// in the generating snippet is tf.less(1, 0) instead of tf.less(0, 1).
TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      return tf.add(x, y)

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch_2.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
  // Fatal on mismatch: both inputs are indexed right below.
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
1913 
// Variant of the function-switch tests in which the two function inputs have
// different shapes ([1,2] and [1,3]).
TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
  // Test graph produced in python using:
  /*
    @function.Defun(*[tf.float32] * 2, noinline=True)
    def MyAdd(x, y):
      a = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      b = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
      c = tf.add(x, a)
      d = tf.add(y, b)
      return c

    with tf.Graph().as_default():
      x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
      z = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
      z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
  */
  GrapplerItem item;
  string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                 "function_switch_shapes.pbtxt");
  TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I");
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(out_prop));

  const auto in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I");
  // Fatal on mismatch: both inputs are indexed right below.
  ASSERT_EQ(2, in_props.size());

  const OpInfo::TensorProperties& in_prop = in_props[0];
  EXPECT_EQ("float: [1,2]", PropToString(in_prop));

  const OpInfo::TensorProperties& in_prop1 = in_props[1];
  EXPECT_EQ("float: [1,3]", PropToString(in_prop1));
}
1950 
TEST_F(GraphPropertiesTest,SymbolicShapes)1951 TEST_F(GraphPropertiesTest, SymbolicShapes) {
1952   // Build a simple graph with placeholders of unknown dimensions. These
1953   // dimensions will be encoded symbolically.
1954   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1955 
1956   Output a =
1957       ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
1958                        ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
1959   Output b =
1960       ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
1961                        ops::Placeholder::Shape(PartialTensorShape({-1})));
1962   Output c = ops::Identity(s.WithOpName("c"), a);
1963   Output d = ops::Identity(s.WithOpName("d"), b);
1964   Output e = ops::Add(s.WithOpName("e"), c, d);
1965   Output f = ops::Add(s.WithOpName("f"), a, c);
1966 
1967   Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
1968   Output g = ops::Shape(s.WithOpName("g"), c);
1969   Output h = ops::Fill(s.WithOpName("h"), g, zero);
1970   Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1});
1971   Output j = ops::Sum(s.WithOpName("j"), a, zero_idx);
1972 
1973   GrapplerItem item;
1974   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1975 
1976   GraphProperties properties(item);
1977   TF_ASSERT_OK(properties.InferStatically(false));
1978   const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
1979   const auto shape_c = properties.GetOutputProperties("c").at(0).shape();
1980   EXPECT_EQ(2, shape_a.dim_size());
1981   EXPECT_EQ(shape_a.dim_size(), shape_c.dim_size());
1982   EXPECT_GE(-2, shape_a.dim(0).size());
1983   EXPECT_EQ(shape_a.dim(0).size(), shape_c.dim(0).size());
1984   EXPECT_GE(-2, shape_a.dim(1).size());
1985   EXPECT_EQ(shape_a.dim(1).size(), shape_c.dim(1).size());
1986 
1987   PartialTensorShape shape(shape_a);
1988   EXPECT_FALSE(shape.IsFullyDefined());
1989   EXPECT_FALSE(shape.unknown_rank());
1990 
1991   const auto shape_b = properties.GetOutputProperties("b").at(0).shape();
1992   const auto shape_d = properties.GetOutputProperties("d").at(0).shape();
1993   EXPECT_EQ(1, shape_b.dim_size());
1994   EXPECT_EQ(shape_b.dim_size(), shape_d.dim_size());
1995   EXPECT_GE(-2, shape_b.dim(0).size());
1996   EXPECT_NE(shape_a.dim(0).size(), shape_b.dim(0).size());
1997   EXPECT_EQ(shape_b.dim(0).size(), shape_d.dim(0).size());
1998 
1999   const auto shape_e = properties.GetOutputProperties("e").at(0).shape();
2000   ASSERT_EQ(2, shape_e.dim_size());
2001   EXPECT_EQ(shape_e.dim(0).size(), shape_c.dim(0).size());
2002   EXPECT_NE(shape_e.dim(1).size(), shape_c.dim(1).size());
2003   EXPECT_NE(shape_e.dim(0).size(), shape_d.dim(0).size());
2004 
2005   const auto shape_f = properties.GetOutputProperties("f").at(0).shape();
2006   ASSERT_EQ(2, shape_f.dim_size());
2007   EXPECT_EQ(shape_f.dim(0).size(), shape_a.dim(0).size());
2008   EXPECT_EQ(shape_f.dim(1).size(), shape_a.dim(1).size());
2009 
2010   const auto shape_h = properties.GetOutputProperties("h").at(0).shape();
2011   ASSERT_EQ(2, shape_f.dim_size());
2012   EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size());
2013   EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size());
2014 
2015   const auto shape_j = properties.GetOutputProperties("j").at(0).shape();
2016   ASSERT_EQ(1, shape_j.dim_size());
2017   EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size());
2018 }
2019 
// Checks that static inference succeeds on a graph in which a node is
// colocated with another node that is no longer part of the graph.
TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(root.WithOpName("a"), 1.0f, {1});
  Output b = ops::Const(root.WithOpName("b"), 2.0f, {1});
  Output c = ops::Const(root.WithOpName("c").ColocateWith(a), 3.0f, {1});
  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));
  // Build a copy of the graph without node a (as if a graph optimization pass
  // had removed it), even though node c is colocated with a. This is fine at
  // a late stage of graph execution: colocation constraints have already been
  // validated and device placement of the nodes has completed.
  GraphDef pruned_graph;
  for (const auto& node_def : item.graph.node()) {
    if (node_def.name() == "a") {
      continue;
    }
    *pruned_graph.add_node() = node_def;
  }
  item.graph.Swap(&pruned_graph);
  GraphProperties properties(item);
  // Inference is expected to return OK because it does not validate the
  // colocation constraints internally.
  TF_EXPECT_OK(properties.InferStatically(false));
}
2043 
// Filling a tensor with the shape reported by ShapeN must reproduce the
// (partially unknown) shape of the tensor that ShapeN inspected.
TEST_F(GraphPropertiesTest, ShapeTracking) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output a =
      ops::Placeholder(root.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output b =
      ops::Placeholder(root.WithOpName("b"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1})));
  Output zero = ops::Const(root.WithOpName("zero"), 0.0f, {});
  auto shapes = ops::ShapeN(root.WithOpName("shapes"), {a, b});
  Output o1 = ops::Fill(root.WithOpName("o1"), shapes[0], zero);
  Output o2 = ops::Fill(root.WithOpName("o2"), shapes[1], zero);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));

  // Compare each placeholder with the Fill built from its own shape.
  const auto a_shape = properties.GetOutputProperties("a").at(0).shape();
  const auto o1_shape = properties.GetOutputProperties("o1").at(0).shape();
  EXPECT_EQ(a_shape.DebugString(), o1_shape.DebugString());

  const auto b_shape = properties.GetOutputProperties("b").at(0).shape();
  const auto o2_shape = properties.GetOutputProperties("o2").at(0).shape();
  EXPECT_EQ(b_shape.DebugString(), o2_shape.DebugString());
}
2069 
TEST_F(GraphPropertiesTest,FedNodes)2070 TEST_F(GraphPropertiesTest, FedNodes) {
2071   TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
2072                                           cluster_->GetDeviceNames());
2073   GrapplerItem item;
2074   CHECK(fake_input.NextItem(&item));
2075 
2076   {
2077     // Conservative shape analysis: the shape of fed ports should be unknown
2078     GraphProperties properties(item);
2079     Status s = properties.InferStatically(false);
2080     TF_ASSERT_OK(s);
2081     for (const auto& node : item.graph.node()) {
2082       if (node.op() == "Const") {
2083         continue;
2084       }
2085       const auto in_props = properties.GetInputProperties(node.name());
2086       EXPECT_EQ(1, in_props.size());
2087       const OpInfo::TensorProperties& in_prop = in_props[0];
2088       const auto out_props = properties.GetOutputProperties(node.name());
2089       EXPECT_EQ(1, out_props.size());
2090       const OpInfo::TensorProperties& out_prop = out_props[0];
2091 
2092       if (node.name() == "x") {
2093         // x is fed: its input should have a known shape, while its output
2094         // doesn't
2095         EXPECT_FALSE(in_prop.shape().unknown_rank());
2096         EXPECT_EQ(1, in_prop.shape().dim_size());
2097         EXPECT_EQ(2, in_prop.shape().dim(0).size());
2098         EXPECT_TRUE(out_prop.shape().unknown_rank());
2099       } else if (node.op() == "Square" || node.op() == "AddN") {
2100         // These nodes are in the fanout of x: their shapes should be unknown.
2101         EXPECT_TRUE(in_prop.shape().unknown_rank());
2102         EXPECT_TRUE(out_prop.shape().unknown_rank());
2103       }
2104     }
2105   }
2106   {
2107     // Optimistic shape analysis: the shape of fed ports should be derived from
2108     // the shape of the fanin.
2109     GraphProperties properties(item);
2110     Status s = properties.InferStatically(true);
2111     TF_ASSERT_OK(s);
2112     for (const auto& node : item.graph.node()) {
2113       if (node.op() == "Square" || node.op() == "AddN") {
2114         const auto in_props = properties.GetInputProperties(node.name());
2115         EXPECT_EQ(1, in_props.size());
2116         const OpInfo::TensorProperties& in_prop = in_props[0];
2117         EXPECT_EQ(DT_FLOAT, in_prop.dtype());
2118         EXPECT_FALSE(in_prop.shape().unknown_rank());
2119         EXPECT_EQ(2, in_prop.shape().dim_size());
2120         const auto out_props = properties.GetOutputProperties(node.name());
2121         EXPECT_EQ(1, out_props.size());
2122         const OpInfo::TensorProperties& out_prop = out_props[0];
2123         EXPECT_EQ(in_prop.dtype(), out_prop.dtype());
2124         EXPECT_EQ(in_prop.shape().DebugString(),
2125                   out_prop.shape().DebugString());
2126       }
2127     }
2128   }
2129 }
2130 
// Runs static inference over a large test graph with many nested loops to make
// sure shapes can be inferred quickly.
TEST_F(GraphPropertiesTest, Performance) {
  const string graph_path =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                   "large_graph.pbtxt.html");

  GrapplerItem item;
  TF_ASSERT_OK(ReadGraphDefFromFile(graph_path, &item.graph));

  // Fill in default attribute values using the global op registry together
  // with the functions defined in the graph's own library.
  const FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                                   item.graph.library());
  TF_ASSERT_OK(
      AddDefaultAttrsToGraphDef(&item.graph, function_library, 0, true));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
}
2146 
// Slices individual entries out of Shape(a) and uses them as the shape of a
// Fill op; the resulting dimensions must match the corresponding (symbolic)
// dimensions of `a`.
TEST_F(GraphPropertiesTest, StridedSlicesOfShapes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a =
      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  auto shp = ops::Shape(s.WithOpName("shape"), {a});

  Output index1 = ops::Const(s.WithOpName("index1"), 0, {1});
  Output index2 = ops::Const(s.WithOpName("index2"), 1, {1});
  Output index3 = ops::Const(s.WithOpName("index3"), 2, {1});

  // b = shape[0:1], c = shape[1:2] (index2 doubles as the stride of 1).
  Output b = ops::StridedSlice(s.WithOpName("b"), shp, index1, index2, index2);
  Output c = ops::StridedSlice(s.WithOpName("c"), shp, index2, index3, index2);

  Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
  Output o1 = ops::Fill(s.WithOpName("o1"), b, zero);
  Output o2 = ops::Fill(s.WithOpName("o2"), c, zero);

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(false));
  const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
  const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape();
  const auto shape_o2 = properties.GetOutputProperties("o2").at(0).shape();
  // The rank checks are fatal because dim(0)/dim(1) are read below; a wrong
  // rank must stop the test instead of reading a non-existent dimension.
  ASSERT_EQ(2, shape_a.dim_size());
  ASSERT_EQ(1, shape_o1.dim_size());
  ASSERT_EQ(1, shape_o2.dim_size());
  EXPECT_EQ(shape_a.dim(0).size(), shape_o1.dim(0).size());
  EXPECT_EQ(shape_a.dim(1).size(), shape_o2.dim(0).size());
}
2179 
// The value of a StridedSlice (with ShrinkAxisMask) taken from a Shape op is
// only inferred when aggressive shape inference is enabled.
TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(root.WithOpName("input_placeholder"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({5, 480, 40, 1})));
  auto shape_of_input = ops::Shape(root.WithOpName("input_shape"), input);

  Output begin = ops::Const(root.WithOpName("begin"), {0}, {1});
  Output end = ops::Const(root.WithOpName("end"), {3}, {1});
  Output stride = ops::Const(root.WithOpName("stride"), {1}, {1});

  Output slice =
      ops::StridedSlice(root.WithOpName("slice"), shape_of_input, begin, end,
                        stride, ops::StridedSlice::ShrinkAxisMask(1));

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  {
    // Non-aggressive inference is not able to produce the output value of a
    // StridedSlice that uses ShrinkAxisMask.
    GraphProperties conservative(item);
    TF_ASSERT_OK(conservative.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/true));
    const auto& slice_prop = conservative.GetOutputProperties("slice").at(0);
    EXPECT_FALSE(slice_prop.has_value());
  }

  {
    // Aggressive inference does produce it; the expected value is 5, the
    // leading entry of the placeholder's shape.
    GraphProperties aggressive(item);
    TF_ASSERT_OK(aggressive.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& slice_prop = aggressive.GetOutputProperties("slice").at(0);
    EXPECT_TRUE(slice_prop.has_value());
    ExpectTensorValues({5}, slice_prop.value());
  }
}
2223 
// Constant values must be propagated through OnesLike and chained Add ops
// when tensor values are requested together with aggressive inference.
TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(root.WithOpName("a"), {5, 7}, {2});
  Output b = ops::Const(root.WithOpName("b"), {8, 8}, {2});
  Output c = ops::Const(root.WithOpName("c"), {2, 2}, {2});

  Output ones = ops::OnesLike(root.WithOpName("a1"), a);
  Output a_plus_one = ops::Add(root.WithOpName("a_plus_one"), a, ones);
  Output a_plus_a = ops::Add(root.WithOpName("a_plus_a"), a, a);
  Output b_plus_2a = ops::Add(root.WithOpName("b_plus_2a"), b, a_plus_a);
  Output c_plus_b_plus_2a =
      ops::Add(root.WithOpName("c_plus_b_plus_2a"), c, b_plus_2a);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));
  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));

  // Every output is an int32 vector of two elements carrying a known value.

  // a + 1 = {5, 7} + {1, 1}
  const auto& plus_one = properties.GetOutputProperties("a_plus_one")[0];
  EXPECT_EQ("int32: [2]", PropToString(plus_one));
  EXPECT_TRUE(plus_one.has_value());
  ExpectTensorValues({6, 8}, plus_one.value());

  // a + a = {5, 7} + {5, 7}
  const auto& doubled = properties.GetOutputProperties("a_plus_a")[0];
  EXPECT_EQ("int32: [2]", PropToString(doubled));
  EXPECT_TRUE(doubled.has_value());
  ExpectTensorValues({10, 14}, doubled.value());

  // b + 2a = {8, 8} + {10, 14}
  const auto& b_sum = properties.GetOutputProperties("b_plus_2a")[0];
  EXPECT_EQ("int32: [2]", PropToString(b_sum));
  EXPECT_TRUE(b_sum.has_value());
  ExpectTensorValues({18, 22}, b_sum.value());

  // c + b + 2a = {2, 2} + {18, 22}
  const auto& c_sum = properties.GetOutputProperties("c_plus_b_plus_2a")[0];
  EXPECT_EQ("int32: [2]", PropToString(c_sum));
  EXPECT_TRUE(c_sum.has_value());
  ExpectTensorValues({20, 24}, c_sum.value());
}
2267 
// An Identity node annotated with `_output_shape_vector` reports the annotated
// shape only when aggressive shape inference is enabled; otherwise the shape
// propagated from its input is kept.
TEST_F(GraphPropertiesTest, ShapeAnnotation) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, -1}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_shape_vector", {TensorShape({5, 7})})
                   .Input("Input", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  {
    GraphProperties properties(item);
    // Without aggressive_shape_inference, ignore annotated information.
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/true));
    const auto& props = properties.GetOutputProperties("Identity");
    // Fatal: props[0] is read next and would be out of bounds otherwise.
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Get unknown shapes without using annotated information.
    EXPECT_EQ("float: [-1,-1]", PropToString(prop));
  }
  {
    GraphProperties properties(item);
    // Use annotated information.
    TF_ASSERT_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true,
        /*include_tensor_values=*/true));
    const auto& props = properties.GetOutputProperties("Identity");
    // Fatal: props[0] is read next and would be out of bounds otherwise.
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Update output shape using annotated shapes.
    EXPECT_EQ("float: [5,7]", PropToString(prop));
  }
}
2312 
// When the annotated shape ([10,100]) is compatible with the inferred one
// ([-1,100]), aggressive inference reports the annotated shape.
TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, 100}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_shape_vector", {TensorShape({10, 100})})
                   .Input("Input", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));
  const auto& props = properties.GetOutputProperties("Identity");
  // Fatal: props[0] is read next and would be out of bounds otherwise.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Compatible shapes. Update output shape using annotated shapes.
  EXPECT_EQ("float: [10,100]", PropToString(prop));
}
2340 
// When the annotated shape ([10,10]) conflicts with the inferred one
// ([-1,100]), the annotation is ignored and the inferred shape is reported.
TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, 100}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("_same_output_for_iterations", true)
                   .Attr("_output_shape_vector", {TensorShape({10, 10})})
                   .Input("Input", 0, DT_FLOAT)
                   .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));
  const auto& props = properties.GetOutputProperties("Identity");
  // Fatal: props[0] is read next and would be out of bounds otherwise.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Incompatible shapes. Do not use annotated shapes.
  EXPECT_EQ("float: [-1,100]", PropToString(prop));
}
2368 
// A node of op type "TestOpWithNoInferenceFn" that carries an annotated output
// shape must report that shape under aggressive inference.
TEST_F(GraphPropertiesTest, ShapeAnnotationWithoutInferenceFn) {
  GrapplerItem item;
  TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", PartialTensorShape({-1, -1}))
                   .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_ASSERT_OK(
      NodeDefBuilder("TestOpWithNoInferenceFn", "TestOpWithNoInferenceFn")
          .Attr("_same_output_for_iterations", true)
          .Attr("_output_shape_vector", {TensorShape({10, 100})})
          .Input("Input", 0, DT_FLOAT)
          .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true,
      /*include_tensor_values=*/true));
  const auto& props = properties.GetOutputProperties("TestOpWithNoInferenceFn");
  // Fatal: props[0] is read next and would be out of bounds otherwise.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  EXPECT_EQ("float: [10,100]", PropToString(prop));
}
2395 
// A PartitionedCall that invokes an identity function must report the shape
// and dtype of its constant input.
TEST_F(GraphPropertiesTest, PartitionedCallOp) {
  Scope root = Scope::NewRootScope().ExitOnError();

  // Register a single-node function that forwards its int32 argument.
  FunctionDefLibrary function_lib;
  *function_lib.add_function() = FunctionDefHelper::Create(
      "identity_function",
      /*in_def=*/{"arg0: int32"},
      /*out_def=*/{"ret0: int32"},
      /*attr_def=*/{},
      {{{"Identity"}, "Identity", {"arg0"}, {{"T", DT_INT32}}}},
      /*ret_def=*/{{"ret0", "Identity:output:0"}});
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(function_lib));

  Output in = ops::Const(root, {3, 1, 2, 0});
  NameAttrList callee;
  callee.set_name("identity_function");
  ops::PartitionedCall call(root.WithOpName("identity_call"), {in}, {DT_INT32},
                            callee);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/true,
      /*aggressive_shape_inference=*/false,
      /*include_tensor_values=*/true));

  const auto& call_prop = properties.GetOutputProperties("identity_call")[0];
  EXPECT_EQ("int32: [4]", PropToString(call_prop));
}
2427 
// A PartitionedCall that invokes a two-argument Add function on partially
// known placeholders must report the placeholders' [2,2,-1] shape.
TEST_F(GraphPropertiesTest, NonTrivialInputPartitionedCallOp) {
  FunctionDef adder = FunctionDefHelper::Create(
      /*function_name=*/"FunctionWhichAdds",
      /*in_def=*/{"arg0: int32", "arg1: int32"},
      /*out_def=*/{"ret0: int32"},
      /*attr_def=*/{},
      /*node_def=*/{{{"a"}, "Add", {"arg0", "arg1"}, {{"T", DT_INT32}}}},
      /*ret_def=*/{{"ret0", "a:z:0"}});

  FunctionDefLibrary function_lib;
  function_lib.add_function()->Swap(&adder);
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(function_lib));

  // Both call arguments share the same partially known shape.
  PartialTensorShape arg_shape({2, 2, -1});
  Output lhs =
      ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(arg_shape));
  Output rhs =
      ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(arg_shape));
  NameAttrList callee;
  callee.set_name("FunctionWhichAdds");
  ops::PartitionedCall call(root.WithOpName("add_call"), {lhs, rhs}, {DT_INT32},
                            callee);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_ASSERT_OK(properties.InferStatically(
      /*assume_valid_feeds=*/true,
      /*aggressive_shape_inference=*/false,
      /*include_tensor_values=*/true));

  const auto& call_prop = properties.GetOutputProperties("add_call")[0];
  EXPECT_EQ("int32: [2,2,-1]", PropToString(call_prop));
}
2468 
TEST_F(GraphPropertiesTest,ShapeAnnotatedFunctionOp)2469 TEST_F(GraphPropertiesTest, ShapeAnnotatedFunctionOp) {
2470   // A function, which we cannot infer output shape statically.
2471   auto f = FunctionDefHelper::Create(
2472       // Name
2473       "FuncShapeCannotBeInferred",
2474       // Inputs
2475       {},
2476       // Outputs
2477       {"output: float"},
2478       // Attrs
2479       {},
2480       // Nodes
2481       {
2482           // Placeholder without shape attr; unknown rank.
2483           {{"p"}, "Placeholder", {}, {{"dtype", DataType::DT_FLOAT}}},
2484       },
2485       // Returns
2486       {{"output", "p:output:0"}});
2487   FunctionDefLibrary function_lib;
2488   function_lib.add_function()->Swap(&f);
2489   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2490   TF_ASSERT_OK(s.graph()->AddFunctionLibrary(function_lib));
2491   tensorflow::Node* func_op;
2492   TensorShapeProto output_shape;
2493   output_shape.set_unknown_rank(false);
2494   output_shape.add_dim()->set_size(1);
2495   output_shape.add_dim()->set_size(2);
2496   output_shape.add_dim()->set_size(3);
2497   output_shape.add_dim()->set_size(4);
2498   // The function node, f, includes shape annotation.
2499   TF_ASSERT_OK(tensorflow::NodeBuilder("f", "FuncShapeCannotBeInferred",
2500                                        s.graph()->op_registry())
2501                    .Attr("_execution_count", 1)
2502                    .Attr("_same_output_for_iterations", true)
2503                    .Attr("_output_dtype_vector", {DataType::DT_FLOAT})
2504                    .Attr("_output_shape_vector", {output_shape})
2505                    .Finalize(s.graph(), &func_op));
2506   GrapplerItem item;
2507   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
2508 
2509   // InferStatically with aggressive_shape_inference would fail to infer
2510   // the output shape of the node f.
2511   {
2512     GraphProperties properties(item);
2513     TF_ASSERT_OK(properties.InferStatically(
2514         /*assume_valid_feeds=*/false,
2515         /*aggressive_shape_inference=*/false,
2516         /*include_tensor_values=*/false));
2517     const auto out_props = properties.GetOutputProperties("f");
2518     const OpInfo::TensorProperties out_prop0 = out_props[0];
2519     EXPECT_EQ("float: ?", PropToString(out_prop0));
2520   }
2521   // With aggressive_shape_inference, it skips recursively callying
2522   // InferStatically for the function node and outputs annotated shape info.
2523   {
2524     GraphProperties properties(item);
2525     TF_ASSERT_OK(properties.InferStatically(
2526         /*assume_valid_feeds=*/false,
2527         /*aggressive_shape_inference=*/true,
2528         /*include_tensor_values=*/true));
2529     const auto out_props = properties.GetOutputProperties("f");
2530     const OpInfo::TensorProperties out_prop0 = out_props[0];
2531     EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
2532   }
2533 }
2534 
TEST_F(GraphPropertiesTest,SymbolicShapeInferenceWithReshapeOpsSharingShapeVector)2535 TEST_F(GraphPropertiesTest,
2536        SymbolicShapeInferenceWithReshapeOpsSharingShapeVector) {
2537   GrapplerItem item;
2538   // This graph creates a shape vector [-1, 10] from Concat(Const, Const)
2539   // used for two reshape ops. One reshape op is segment_ids input to
2540   // UnsortedSegmentSum op, which applies MergePrefix from its shape function.
2541   // segment_ids has a shape [-1, 10] (from reshape), but MergePrefix with
2542   // data input ([10, 10, 10, 10]) makes -1, or unknown dim, 10, with
2543   // SymbolicShapeRefiner.
2544   // This dim value (10), however, should not affect the other reshape op, even
2545   // though it shares the shape input; -1 in the shape input of Reshape op is
2546   // a special case of computed output dim, not unknown dim.
2547   // data and num_segments are inputs to UnsortedSegmenetSum.
2548 
2549   TF_ASSERT_OK(NodeDefBuilder("data", "Placeholder")
2550                    .Attr("dtype", DT_FLOAT)
2551                    .Attr("shape", TensorShape({10, 10, 10, 10}))
2552                    .Finalize(item.graph.add_node()));
2553   Tensor num_segments(DT_INT32, TensorShape({}));
2554   // Build semgent_ids input to UnsortedSegmentSum from Const ops, ConcatV2,
2555   // and Reshape ops. tensors_as_shape from Const ops are propagated to ConcatV2
2556   // output to form shape vector [-1, 10] to Reshape.
2557   test::FillIota<int>(&num_segments, 3);
2558   TF_ASSERT_OK(NodeDefBuilder("num_segments", "Const")
2559                    .Attr("dtype", DT_INT32)
2560                    .Attr("value", num_segments)
2561                    .Finalize(item.graph.add_node()));
2562   Tensor minus_one(DT_INT32, TensorShape({1}));
2563   test::FillIota<int>(&minus_one, -1);
2564   TF_ASSERT_OK(NodeDefBuilder("minus_one", "Const")
2565                    .Attr("dtype", DT_INT32)
2566                    .Attr("value", minus_one)
2567                    .Finalize(item.graph.add_node()));
2568   Tensor plus_ten(DT_INT32, TensorShape({1}));
2569   test::FillIota<int>(&plus_ten, 10);
2570   TF_ASSERT_OK(NodeDefBuilder("plus_ten", "Const")
2571                    .Attr("dtype", DT_INT32)
2572                    .Attr("value", plus_ten)
2573                    .Finalize(item.graph.add_node()));
2574   Tensor axis(DT_INT32, TensorShape({}));
2575   test::FillIota<int>(&axis, -1);
2576   TF_ASSERT_OK(NodeDefBuilder("axis", "Const")
2577                    .Attr("dtype", DT_INT32)
2578                    .Attr("value", axis)
2579                    .Finalize(item.graph.add_node()));
2580   std::vector<NodeDefBuilder::NodeOut> inputs(2);
2581   inputs[0] = NodeDefBuilder::NodeOut{"minus_one", 0, DT_INT32};
2582   inputs[1] = NodeDefBuilder::NodeOut{"plus_ten", 0, DT_INT32};
2583   TF_ASSERT_OK(NodeDefBuilder("concat", "ConcatV2")
2584                    .Input(inputs)
2585                    .Input("axis", 0, DT_INT32)
2586                    .Attr("N", 2)
2587                    .Attr("T", DT_INT32)
2588                    .Attr("Tidx", DT_INT32)
2589                    .Finalize(item.graph.add_node()));
2590   TF_ASSERT_OK(NodeDefBuilder("segment_ids_", "Placeholder")
2591                    .Attr("dtype", DT_FLOAT)
2592                    .Finalize(item.graph.add_node()));
2593   TF_ASSERT_OK(NodeDefBuilder("segment_ids_shape_before_reshape", "Shape")
2594                    .Input("segment_ids_", 0, DT_FLOAT)
2595                    .Attr("T", DT_FLOAT)
2596                    .Attr("out_type", DT_INT32)
2597                    .Finalize(item.graph.add_node()));
2598   TF_ASSERT_OK(NodeDefBuilder("segment_ids", "Reshape")
2599                    .Input("segment_ids_", 0, DT_FLOAT)
2600                    .Input("concat", 0, DT_INT32)
2601                    .Attr("T", DT_FLOAT)
2602                    .Attr("Tshape", DT_INT32)
2603                    .Finalize(item.graph.add_node()));
2604   // Shape function of UnsortedSegmentSum applies MergePrefix to data and
2605   // segment_ids (the latter being prefix). data shape is [10,10,10,10] and
2606   // segment_ids shape is [-1, 10], but MergePrefix and symbolic shape inference
2607   // assign 10 from data shape to the unknown dim in segment_ids.
2608   TF_ASSERT_OK(NodeDefBuilder("y", "UnsortedSegmentSum")
2609                    .Input("data", 0, DT_FLOAT)
2610                    .Input("segment_ids", 0, DT_INT32)
2611                    .Input("num_segments", 0, DT_INT32)
2612                    .Attr("T", DT_FLOAT)
2613                    .Attr("Tindices", DT_INT32)
2614                    .Attr("Tnumsegments", DT_INT32)
2615                    .Finalize(item.graph.add_node()));
2616   // Note that y2=Reshape(x1) using the same shape vector as segment_ids, but
2617   // y2 shape shouldn't be affected by symbolic shape inference w/ segment_ids.
2618   TF_ASSERT_OK(NodeDefBuilder("x1", "Placeholder")
2619                    .Attr("dtype", DT_FLOAT)
2620                    .Finalize(item.graph.add_node()));
2621   TF_ASSERT_OK(NodeDefBuilder("y1", "Reshape")
2622                    .Input("x1", 0, DT_FLOAT)
2623                    .Input("concat", 0, DT_INT32)
2624                    .Attr("T", DT_FLOAT)
2625                    .Attr("Tshape", DT_INT32)
2626                    .Finalize(item.graph.add_node()));
2627 
2628   GraphProperties properties(item);
2629   TF_ASSERT_OK(properties.InferStatically(true));
2630   const auto& y1_output_properties = properties.GetOutputProperties("y1");
2631   // y1=reshape(x1), but x1's shape in unknown, so y1 should be [-1, 10].
2632   // The first dimension should not be 10.
2633   EXPECT_EQ(y1_output_properties.size(), 1);
2634   EXPECT_EQ(y1_output_properties[0].shape().dim_size(), 2);
2635   EXPECT_LT(y1_output_properties[0].shape().dim(0).size(), 0);
2636   EXPECT_EQ(y1_output_properties[0].shape().dim(1).size(), 10);
2637 }
2638 
// Exercises IsShapeFullyDefinedIntegerVectorOrScalar with combinations of
// shape, tensor-as-shape and dtype arguments.
TEST(HelperFunctions, IsShapeFullyDefinedIntegerVectorOrScalar) {
  // A dummy InferenceContext is enough to create the shape handles below.
  OpRegistrationData op_reg_data;
  OpDefBuilder op_builder("dummy");
  CHECK(op_builder.Finalize(&op_reg_data).ok());
  NodeDef node_def;
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      handle_shapes_and_types;
  InferenceContext context(/*graph_def_version=*/0, node_def,
                           op_reg_data.op_def,
                           /*input_shapes=*/{},
                           /*input_tensors=*/{},
                           /*input_tensors_as_shapes=*/{},
                           std::move(handle_shapes_and_types));

  // Shape handles used as test inputs.
  ShapeHandle all_known = context.MakeShape({context.MakeDim(4),
                                             context.MakeDim(5),
                                             context.MakeDim(6),
                                             context.MakeDim(7)});
  ShapeHandle has_unknown_dim = context.MakeShape({context.MakeDim(4),
                                                   context.MakeDim(5),
                                                   context.UnknownDim(),
                                                   context.MakeDim(7)});
  // INT64_MAX is used as unknown from Const. See kUnknownFromConst const in
  // graph_properties.cc
  ShapeHandle has_unknown_from_const =
      context.MakeShape({context.MakeDim(4), context.MakeDim(INT64_MAX),
                         context.MakeDim(6), context.MakeDim(7)});
  ShapeHandle rank_one = context.MakeShape({context.MakeDim(4)});

  // A rank-1 shape with a fully defined tensor_as_shape is accepted for both
  // INT32 and INT64.
  for (const DataType integer_type : {DT_INT32, DT_INT64}) {
    EXPECT_TRUE(IsShapeFullyDefinedIntegerVectorOrScalar(
        &context, rank_one, all_known, integer_type));
  }

  // A non-integer data type is rejected.
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(&context, rank_one,
                                                        all_known, DT_FLOAT));

  // tensor_as_shape containing Unknown or UnknownFromConst is rejected, and so
  // is an unknown shape in either argument position.
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &context, rank_one, has_unknown_dim, DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &context, rank_one, has_unknown_from_const, DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &context, rank_one, context.UnknownShape(), DT_INT32));
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &context, context.UnknownShape(), all_known, DT_INT32));

  // A shape of rank > 1 is rejected.
  EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
      &context, all_known, has_unknown_from_const, DT_INT32));
}
2688 }  // namespace
2689 }  // namespace grappler
2690 }  // namespace tensorflow
2691