1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/graph_properties.h"
17
18 #include "tensorflow/cc/framework/scope.h"
19 #include "tensorflow/cc/ops/functional_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/graph_def_util.h"
22 #include "tensorflow/core/framework/node_def_builder.h"
23 #include "tensorflow/core/framework/tensor.pb.h" // NOLINT
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/tensor_testutil.h"
26 #include "tensorflow/core/framework/types.pb.h"
27 #include "tensorflow/core/framework/versions.pb.h"
28 #include "tensorflow/core/grappler/clusters/single_machine.h"
29 #include "tensorflow/core/grappler/grappler_item.h"
30 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
31 #include "tensorflow/core/grappler/inputs/utils.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/lib/io/path.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/platform/test.h"
37 #ifdef INTEL_MKL
38 #include "tensorflow/core/graph/mkl_graph_util.h"
39 #endif
40
41 namespace tensorflow {
42 namespace grappler {
43 namespace {
44
45 using shape_inference::InferenceContext;
46 using shape_inference::ShapeAndType;
47 using shape_inference::ShapeHandle;
48
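// The .pbtxt test graphs referenced below are loaded from this directory,
// resolved relative to the TensorFlow source root via
// io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath, ...).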
49 const char kTestDataPath[] = "core/grappler/costs/graph_properties_testdata";
50
51 REGISTER_OP("TestOpWithNoInferenceFn")
52 .Input("x: float")
53 .Output("y: float")
54 .Doc(R"doc(
55 Test op with no Inference Function registered.
56 x: input
57 y: output
58 )doc");
59
60 class GraphPropertiesTest : public ::testing::Test {
61 public:
  void SetUp() override {
63 // Provision a single machine with 3 cpu cores
64 cluster_.reset(new SingleMachine(5 * 60, 3, 0));
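    // SingleMachine(timeout_s, num_cpu_cores, num_gpus): a 5-minute timeout,
    // 3 CPU cores, and no GPUs.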
65 TF_ASSERT_OK(cluster_->Provision());
66
    // This function is simply
    //   out = Fill(shape, value),
    // but Fill requires the values in its shape input, not just its shape, to
    // infer the output shape.
71 auto f = FunctionDefHelper::Create(
72 // Name
73 "MyFillFunc",
74 // Inputs
75 {"shape: int32", "value: float"},
76 // Outputs
77 {"out: float"},
78 // Attrs
79 {},
80 // Nodes
81 {
82 {{"a"},
83 "Fill",
84 {"shape", "value"},
85 {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
86 },
87 // Returns
88 {{"out", "a:output:0"}});
89 function_lib_.add_function()->Swap(&f);
90 }
91
  void TearDown() override {
93 TF_ASSERT_OK(cluster_->Shutdown());
94 cluster_.reset();
95 }
96
97 protected:
98 // Returns a string form of <p>, suitable for comparing type and shape.
99 // Example output for 4-d float tensor: "float: [10,2,30,4]"
  string PropToString(const OpInfo::TensorProperties& p) {
101 string s = strings::StrCat(DataTypeString(p.dtype()), ": ");
102 if (p.shape().unknown_rank()) {
103 strings::StrAppend(&s, "?");
104 } else {
105 strings::StrAppend(&s, "[");
106 for (int i = 0; i < p.shape().dim_size(); ++i) {
107 strings::StrAppend(&s, i == 0 ? "" : ",",
108 std::max<int64_t>(p.shape().dim(i).size(), -1));
109 }
110 strings::StrAppend(&s, "]");
111 }
112 return s;
113 }
114
115 // Compare values of integer (DT_INT32 or DT_INT64) tensor against expected
116 // ones.
  void ExpectTensorValues(const std::vector<int64_t>& expected,
118 const TensorProto& tensor_proto_to_compare) {
119 Tensor tensor;
120 ASSERT_TRUE(tensor.FromProto(tensor_proto_to_compare));
121 EXPECT_EQ(expected.size(), tensor.NumElements());
122 // We're interested in only integer tensors as only shapes are exported as
123 // graph properties values.
124 ASSERT_TRUE(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64);
125 if (tensor.dtype() == DT_INT32) {
126 for (int i = 0; i < tensor.NumElements(); i++) {
127 EXPECT_EQ(expected[i], tensor.flat<int32>()(i));
128 }
129 } else {
130 for (int i = 0; i < tensor.NumElements(); i++) {
131 EXPECT_EQ(expected[i], tensor.flat<int64_t>()(i));
132 }
133 }
134 }
135
136 // Compare values of float (DT_FLOAT) tensor against expected
137 // ones.
  void ExpectFloatTensorValues(const std::vector<float>& expected,
139 const TensorProto& tensor_proto_to_compare) {
140 Tensor tensor;
141 ASSERT_TRUE(tensor.FromProto(tensor_proto_to_compare));
142 EXPECT_EQ(expected.size(), tensor.NumElements());
143 ASSERT_EQ(tensor.dtype(), DT_FLOAT);
144 for (int i = 0; i < tensor.NumElements(); i++) {
145 EXPECT_EQ(expected[i], tensor.flat<float>()(i));
146 }
147 }
148
149 std::unique_ptr<SingleMachine> cluster_;
150 FunctionDefLibrary function_lib_;
151 };
152
TEST_F(GraphPropertiesTest, StaticProperties) {
154 TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
155 cluster_->GetDeviceNames());
156 GrapplerItem item;
157 CHECK(fake_input.NextItem(&item));
158
159 GraphProperties properties(item);
160 Status s = properties.InferStatically(true);
161 TF_ASSERT_OK(s);
162
163 for (const auto& node : item.graph.node()) {
164 if (node.op() == "RandomStandardNormal") {
165 // The node has one input (the shape of the tensor to generate).
166 EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
      // The random node has one output.
168 const auto props = properties.GetOutputProperties(node.name());
169 EXPECT_EQ(1, props.size());
170 const OpInfo::TensorProperties& prop = props[0];
171 EXPECT_EQ(DT_FLOAT, prop.dtype());
172 EXPECT_FALSE(prop.shape().unknown_rank());
173 EXPECT_EQ(2, prop.shape().dim_size());
174 EXPECT_EQ(10, prop.shape().dim(0).size());
175 EXPECT_EQ(1, prop.shape().dim(1).size());
176 } else if (node.op() == "AddN") {
177 const auto in_props = properties.GetInputProperties(node.name());
178 EXPECT_EQ(1, in_props.size());
179 const OpInfo::TensorProperties& in_prop = in_props[0];
180 EXPECT_EQ(DT_FLOAT, in_prop.dtype());
181 EXPECT_FALSE(in_prop.shape().unknown_rank());
182 EXPECT_EQ(2, in_prop.shape().dim_size());
183 EXPECT_EQ(10, in_prop.shape().dim(0).size());
184 EXPECT_EQ(1, in_prop.shape().dim(1).size());
185 const auto out_props = properties.GetOutputProperties(node.name());
186 EXPECT_EQ(1, out_props.size());
187 EXPECT_EQ(in_prop.dtype(), out_props[0].dtype());
188 EXPECT_EQ(in_prop.shape().DebugString(),
189 out_props[0].shape().DebugString());
190 }
191 }
192 }
193
TEST_F(GraphPropertiesTest, ClearProperties) {
195 TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
196 cluster_->GetDeviceNames());
197 GrapplerItem item;
198 CHECK(fake_input.NextItem(&item));
199
200 GraphProperties properties(item);
201 Status s = properties.InferStatically(true);
202 TF_ASSERT_OK(s);
203
204 for (const auto& node : item.graph.node()) {
205 if (node.op() == "RandomStandardNormal") {
206 EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
207 const auto props = properties.GetOutputProperties(node.name());
208 properties.ClearOutputProperties(node.name());
209 const auto cleared_props = properties.GetOutputProperties(node.name());
210 EXPECT_TRUE(cleared_props.empty());
211 } else if (node.op() == "AddN") {
212 const auto in_props = properties.GetInputProperties(node.name());
213 EXPECT_EQ(1, in_props.size());
214 properties.ClearInputProperties(node.name());
215 const auto cleared_props = properties.GetInputProperties(node.name());
216 EXPECT_TRUE(cleared_props.empty());
217 }
218 }
219 }
220
TEST_F(GraphPropertiesTest, Clear) {
222 TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
223 cluster_->GetDeviceNames());
224 GrapplerItem item;
225 CHECK(fake_input.NextItem(&item));
226
227 GraphProperties properties(item);
228 Status s = properties.InferStatically(true);
229 TF_ASSERT_OK(s);
230
231 EXPECT_TRUE(properties.has_properties());
232 properties.Clear();
233 EXPECT_FALSE(properties.has_properties());
234 }
235
TEST_F(GraphPropertiesTest, DynamicProperties) {
237 TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
238 cluster_->GetDeviceNames());
239 GrapplerItem item;
240 CHECK(fake_input.NextItem(&item));
241
242 GraphProperties properties(item);
243 TF_ASSERT_OK(cluster_->Initialize(item));
244 Status s = properties.InferDynamically(cluster_.get());
245 TF_ASSERT_OK(s);
246
247 for (const auto& node : item.graph.node()) {
248 if (node.op() == "RandomStandardNormal") {
249 // The random node is missing from the cost graph (why ?)
250 EXPECT_EQ(0, properties.GetInputProperties(node.name()).size());
251 } else if (node.op() == "AddN") {
252 // Since the random node is missing, we can't infer the input properties
253 // of the first AddN node. The other AddN nodes have the expected
254 // properties.
255 if (node.name() == "AddN") {
256 const auto props = properties.GetInputProperties(node.name());
257 EXPECT_EQ(1, props.size());
258 const OpInfo::TensorProperties& prop = props[0];
259 EXPECT_EQ(DT_INVALID, prop.dtype());
260 EXPECT_TRUE(prop.shape().unknown_rank());
261 } else {
262 const auto props = properties.GetInputProperties(node.name());
263 EXPECT_EQ(1, props.size());
264 const OpInfo::TensorProperties& prop = props[0];
265 EXPECT_EQ(DT_FLOAT, prop.dtype());
266 EXPECT_FALSE(prop.shape().unknown_rank());
267 EXPECT_EQ(2, prop.shape().dim_size());
268 EXPECT_EQ(10, prop.shape().dim(0).size());
269 EXPECT_EQ(1, prop.shape().dim(1).size());
270 const auto out_props = properties.GetOutputProperties(node.name());
271 #ifdef INTEL_MKL
272 if (!NativeFormatEnabled()) {
          // The Intel MKL AddN op has two outputs: the real output and one
          // carrying MKL metadata.
275 EXPECT_EQ(2, out_props.size());
276 } else {
277 EXPECT_EQ(1, out_props.size());
278 }
279 #else
280 EXPECT_EQ(1, out_props.size());
281 #endif // INTEL_MKL
282 string prop_str;
283 ::tensorflow::protobuf::TextFormat::PrintToString(prop, &prop_str);
284 string out_prop_str;
285 ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0],
286 &out_prop_str);
287 EXPECT_EQ(prop_str, out_prop_str);
288 }
289 }
290 }
291 }
292
// A test op whose output shape depends on whether input_tensor is available in
// the shape inference context.
295 REGISTER_OP("DetectInputValueInShapeInferenceOp")
296 .Input("a: T")
297 .Output("o: T")
298 .Attr("T: {numbertype, bool}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
300 if (c->input_tensor(0)) {
301 // 10x10 if input_tensor is given to the inference context.
302 c->set_output(0, c->Matrix(10, 10));
303 return OkStatus();
304 }
305 // unknown rank if input_tensor is not provided.
306 return shape_inference::UnknownShape(c);
307 });
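// The tests below use this op to observe whether const tensor values were
// materialized as input_tensors in the shape inference context.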
308
309 // Helper class for testing Const tensor skip.
310 class ConstTensorSkipTestCase {
311 public:
  ConstTensorSkipTestCase(const DataType data_type,
313 const std::vector<int64_t> shape, const double value,
314 const bool expected)
315 : data_type_(data_type),
316 shape_(shape),
317 value_(value),
318 expected_(expected) {}
319
  void RunTestAndValidate() const {
321 LOG(INFO) << "Run Const tensor skip test: "
322 << "data_type: " << data_type_ << ", shape: {"
323 << absl::StrJoin(shape_, ",") << "}, value: " << value_
324 << ", expected: " << expected_;
325 // Build a graph with Const --> Identity --> Detect.
326 GrapplerItem item;
327 const gtl::ArraySlice<int64_t> shape_array_slice(shape_);
328 Tensor const_tensor_value(data_type_, TensorShape(shape_array_slice));
329 // Fill the const tensor value based on data type.
330 switch (data_type_) {
331 case DT_INT32:
332 test::FillIota<int32>(&const_tensor_value, static_cast<int32>(value_));
333 break;
334 case DT_INT64:
335 test::FillIota<int64_t>(&const_tensor_value,
336 static_cast<int64_t>(value_));
337 break;
338 case DT_FLOAT:
339 test::FillIota<float>(&const_tensor_value, static_cast<float>(value_));
340 break;
341 case DT_DOUBLE:
342 test::FillIota<double>(&const_tensor_value,
343 static_cast<double>(value_));
344 break;
345 case DT_BFLOAT16:
346 test::FillIota<Eigen::bfloat16>(&const_tensor_value,
347 static_cast<Eigen::bfloat16>(value_));
348 break;
349 default:
350 CHECK(false) << "Unsupported data type (" << data_type_
351 << ") in this test.";
352 break;
353 }
354 TF_ASSERT_OK(NodeDefBuilder("const", "Const")
355 .Attr("dtype", data_type_)
356 .Attr("value", const_tensor_value)
357 .Finalize(item.graph.add_node()));
358 TF_ASSERT_OK(NodeDefBuilder("const_identity", "Identity")
359 .Attr("dtype", data_type_)
360 .Input("const", 0, data_type_)
361 .Finalize(item.graph.add_node()));
362 TF_ASSERT_OK(NodeDefBuilder("detect", "DetectInputValueInShapeInferenceOp")
363 .Attr("T", data_type_)
364 .Input("const_identity", 0, data_type_)
365 .Finalize(item.graph.add_node()));
366 item.fetch.push_back("const");
367 item.fetch.push_back("const_identity");
368 item.fetch.push_back("detect");
369
370 // Run static shape inference.
371 GraphProperties graph_properties(item);
372 TF_ASSERT_OK(graph_properties.InferStatically(false));
373
374 // Extract input / output properties of interest.
375 const auto& const_output = graph_properties.GetOutputProperties("const");
376 EXPECT_EQ(1, const_output.size());
377 const OpInfo::TensorProperties& const_output0 = const_output[0];
378 const auto& const_identity_input =
379 graph_properties.GetInputProperties("const_identity");
380 EXPECT_EQ(1, const_identity_input.size());
381 const OpInfo::TensorProperties& const_identity_input0 =
382 const_identity_input[0];
383 const auto& const_identity_output =
384 graph_properties.GetOutputProperties("const_identity");
385 EXPECT_EQ(1, const_identity_output.size());
386 const OpInfo::TensorProperties& const_identity_output0 =
387 const_identity_output[0];
388 EXPECT_TRUE(const_output0.has_value());
389 EXPECT_TRUE(const_identity_input0.has_value());
390 EXPECT_TRUE(const_identity_output0.has_value());
391 const auto& detect_input = graph_properties.GetInputProperties("detect");
392 EXPECT_EQ(1, detect_input.size());
393 const OpInfo::TensorProperties& detect_input0 = detect_input[0];
394 const auto& detect_output = graph_properties.GetOutputProperties("detect");
395 EXPECT_EQ(1, detect_output.size());
396 const OpInfo::TensorProperties& detect_output0 = detect_output[0];
397
398 // Tensor protos are propagated, regardless of types and sizes.
399 EXPECT_TRUE(const_output0.has_value());
400 EXPECT_TRUE(const_identity_input0.has_value());
401 EXPECT_TRUE(const_identity_output0.has_value());
402 EXPECT_TRUE(detect_input0.has_value());
403
404 // Detect op outputs 10x10 matrix if it has input_tensor in the shape
405 // inference context. Otherwise, unknown rank.
406 if (expected_) {
407 EXPECT_EQ(detect_output0.shape().dim_size(), 2);
408 EXPECT_EQ(detect_output0.shape().dim(0).size(), 10);
409 EXPECT_EQ(detect_output0.shape().dim(1).size(), 10);
410 } else {
411 EXPECT_TRUE(detect_output0.shape().unknown_rank());
412 }
413 }
414
415 private:
416 DataType data_type_;
417 std::vector<int64_t> shape_;
418 double value_;
419 bool expected_;
420 };
421
TEST_F(GraphPropertiesTest, SkipInstantiatingConstTensor) {
  // We skip const tensor value propagation in shape inference if the const
  // tensor is too large.
425 std::vector<ConstTensorSkipTestCase> test_cases = {
426 // data_type, shape, value, bool: propagate const?
427 {DT_INT32, {16, 8}, 1, true}, // 128 elements; smaller than threshold
428 {DT_INT32, {1, 129}, 2, false}, // 129 elements; larger than threshold
429 {DT_INT64, {8, 8}, 3, true}, // 64 elements; smaller than threshold
430 {DT_INT64, {128, 2}, 0, false}, // 256 elements; larger than threshold
431 {DT_FLOAT, {16, 8}, 1.0, true}, // integer value for float tensor
432 {DT_FLOAT, {16, 8}, 1.3, true}, // fractional value (1.3)
433 {DT_FLOAT, {1, 129}, 0.7, false}, // fractional value (0.7)
434 {DT_DOUBLE, {16, 8}, 1.0, true}, // integer value for float tensor
435 {DT_DOUBLE, {16, 8}, 1.3, true}, // fractional value (1.3)
436 {DT_DOUBLE, {1, 129}, 0.7, false}, // fractional value (0.7)
437 {DT_BFLOAT16, {16, 8}, 1.0, true}, // integer value for float tensor
438 {DT_BFLOAT16, {16, 8}, 1.3, true}, // fractional value (1.3)
439 {DT_BFLOAT16, {1, 129}, 0.7, false}, // fractional value (0.7)
440 };
441 for (const auto& test_case : test_cases) {
442 test_case.RunTestAndValidate();
443 }
444 }
445
TEST_F(GraphPropertiesTest, Variables) {
447 GrapplerItem item;
448 TF_ASSERT_OK(NodeDefBuilder("Var", "Variable")
449 .Attr("dtype", DT_FLOAT)
450 .Attr("shape", TensorShape({3, 7}))
451 .Finalize(item.graph.add_node()));
452 item.fetch.push_back("Var");
453
454 Tensor initial_val(DT_FLOAT, TensorShape({3, 7}));
455 test::FillIota<float>(&initial_val, 0);
456 TF_ASSERT_OK(NodeDefBuilder("InitialVal", "Const")
457 .Attr("dtype", DT_FLOAT)
458 .Attr("value", initial_val)
459 .Finalize(item.graph.add_node()));
460 TF_ASSERT_OK(NodeDefBuilder("InitVar", "Assign")
461 .Input("Var", 0, DT_FLOAT_REF)
462 .Input("InitialVal", 0, DT_FLOAT)
463 .Finalize(item.graph.add_node()));
464 item.init_ops.push_back("InitVar");
465
466 {
467 GraphProperties static_properties(item);
468 TF_ASSERT_OK(static_properties.InferStatically(false));
469
470 const auto props = static_properties.GetOutputProperties("Var");
471 EXPECT_EQ(1, props.size());
472 const OpInfo::TensorProperties& prop = props[0];
473 EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
474 EXPECT_FALSE(prop.shape().unknown_rank());
475 EXPECT_EQ(2, prop.shape().dim_size());
476 EXPECT_EQ(3, prop.shape().dim(0).size());
477 EXPECT_EQ(7, prop.shape().dim(1).size());
478 }
479 {
480 TF_ASSERT_OK(cluster_->Initialize(item));
481 GraphProperties dynamic_properties(item);
482 TF_ASSERT_OK(dynamic_properties.InferDynamically(cluster_.get()));
483
484 const auto props = dynamic_properties.GetOutputProperties("Var");
485 EXPECT_EQ(1, props.size());
486 const OpInfo::TensorProperties& prop = props[0];
487 EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
488 EXPECT_FALSE(prop.shape().unknown_rank());
489 EXPECT_EQ(2, prop.shape().dim_size());
490 EXPECT_EQ(3, prop.shape().dim(0).size());
491 EXPECT_EQ(7, prop.shape().dim(1).size());
492 }
493 }
494
TEST_F(GraphPropertiesTest, ReadVariableOpAfterEnter) {
496 GrapplerItem item;
497 TF_ASSERT_OK(NodeDefBuilder("Var", "VarHandleOp")
498 .Attr("dtype", DT_FLOAT)
499 .Attr("shape", TensorShape({3, 7}))
500 .Finalize(item.graph.add_node()));
501 TF_ASSERT_OK(NodeDefBuilder("Enter", "Enter")
502 .Attr("T", DT_RESOURCE)
503 .Attr("frame_name", "while_context")
504 .Attr("is_constant", true)
505 .Attr("parallel_iterations", 10)
506 .Input("Var", 0, DT_RESOURCE)
507 .Finalize(item.graph.add_node()));
508 TF_ASSERT_OK(NodeDefBuilder("ReadVariableOpAfterEnter", "ReadVariableOp")
509 .Attr("dtype", DT_FLOAT)
510 .Input("Enter", 0, DT_RESOURCE)
511 .Finalize(item.graph.add_node()));
512
513 GraphProperties properties(item);
514 TF_ASSERT_OK(properties.InferStatically(false));
515 const auto props = properties.GetOutputProperties("ReadVariableOpAfterEnter");
516 EXPECT_EQ(1, props.size());
517 const OpInfo::TensorProperties& prop = props[0];
518 EXPECT_EQ(DT_FLOAT, prop.dtype());
519 EXPECT_FALSE(prop.shape().unknown_rank());
520 EXPECT_EQ(2, prop.shape().dim_size());
521 EXPECT_EQ(3, prop.shape().dim(0).size());
522 EXPECT_EQ(7, prop.shape().dim(1).size());
523 }
524
TEST_F(GraphPropertiesTest, VarHandles) {
526 GrapplerItem item;
527 TF_ASSERT_OK(NodeDefBuilder("Var", "VarHandleOp")
528 .Attr("dtype", DT_FLOAT)
529 .Attr("shape", TensorShape({3, 7}))
530 .Finalize(item.graph.add_node()));
531
532 TF_ASSERT_OK(NodeDefBuilder("VarRead", "ReadVariableOp")
533 .Attr("dtype", DT_FLOAT)
534 .Input("Var", 0, DT_RESOURCE)
535 .Finalize(item.graph.add_node()));
536
537 GraphProperties properties(item);
538 TF_ASSERT_OK(properties.InferStatically(false));
539
540 const auto props = properties.GetOutputProperties("VarRead");
541 EXPECT_EQ(1, props.size());
542 const OpInfo::TensorProperties& prop = props[0];
543 EXPECT_EQ(DT_FLOAT, prop.dtype());
544 EXPECT_FALSE(prop.shape().unknown_rank());
545 EXPECT_EQ(2, prop.shape().dim_size());
546 EXPECT_EQ(3, prop.shape().dim(0).size());
547 EXPECT_EQ(7, prop.shape().dim(1).size());
548 }
549
TEST_F(GraphPropertiesTest, WhileLoopWithVarHandleOpInput) {
551 // Test graph is first generated in python using:
552 /*
553 i0 = tf.constant(0)
554 v = tf.get_variable(initializer=i0, name='loop_var', use_resource=True)
555 def cond(i, x):
556 return i < 3
557 def body(i, x):
558 return i + 1, x + x
559 v, y = tf.while_loop(cond, body, loop_vars=[v, tf.constant(1)])
560 */
561 // and then modified by hand such that the ReadVariableOp is inside the loop
562 // body instead of outside the while loop (which is the case when constructed
563 // using the python API), such that we have the following pattern: VarHandleOp
564 // -> Enter -> Switch -> ReadVariableOp -> other parts of loop body. Note
565 // DT_RESOURCE is passed all the way until ReadVariableOp.
566 GrapplerItem item;
567 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
568 "while_loop_var_handle_op.pbtxt");
569 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
570 GraphProperties properties(item);
571 TF_ASSERT_OK(properties.InferStatically(false));
572
573 std::vector<string> resource_nodes{
574 "loop_var", "while/Enter", "while/Merge", "while/Switch",
575 "while/Identity", "while/NextIteration", "while/Exit"};
576 for (const string& node : resource_nodes) {
577 const auto props = properties.GetOutputProperties(node);
578 EXPECT_GE(props.size(), 1); // Merge has 2 outputs.
579 EXPECT_EQ("resource: []", PropToString(props[0]));
580 }
581
582 // After ReadVariableOp, the shape should be recovered.
583 const auto props = properties.GetOutputProperties("while/ReadVariableOp");
584 EXPECT_EQ(1, props.size());
585 EXPECT_EQ("int32: []", PropToString(props[0]));
586 }
587
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_NoShapeAttr) {
589 tensorflow::Scope root = tensorflow::Scope::NewRootScope();
590 auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
591 auto dequeue1 =
592 ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
593
594 GrapplerItem item;
595 TF_ASSERT_OK(root.ToGraphDef(&item.graph));
596
597 GraphProperties properties(item);
598 TF_ASSERT_OK(properties.InferStatically(false));
599
600 const auto props1 = properties.GetOutputProperties("Dequeue1");
601 ASSERT_EQ(1, props1.size());
602 EXPECT_EQ("float: ?", PropToString(props1[0]));
603 }
604
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_ShapeAttr) {
606 tensorflow::Scope root = tensorflow::Scope::NewRootScope();
607 auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT},
608 ops::FIFOQueue::Attrs().Shapes({{3, 7, 1}}));
609 auto dequeue1 =
610 ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
611
612 GrapplerItem item;
613 TF_ASSERT_OK(root.ToGraphDef(&item.graph));
614
615 GraphProperties properties(item);
616 TF_ASSERT_OK(properties.InferStatically(false));
617
618 const auto props1 = properties.GetOutputProperties("Dequeue1");
619 ASSERT_EQ(1, props1.size());
620 EXPECT_EQ("float: [3,7,1]", PropToString(props1[0]));
621 }
622
TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_PartialShapeAttr) {
624 tensorflow::Scope root = tensorflow::Scope::NewRootScope();
625 auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT},
626 ops::FIFOQueue::Attrs().Shapes({{3, 7, -1}}));
627 auto dequeue1 =
628 ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
629
630 GrapplerItem item;
631 TF_ASSERT_OK(root.ToGraphDef(&item.graph));
632
633 GraphProperties properties(item);
634 TF_ASSERT_OK(properties.InferStatically(false));
635
636 const auto props1 = properties.GetOutputProperties("Dequeue1");
637 ASSERT_EQ(1, props1.size());
638 EXPECT_EQ("float: [3,7,-1]", PropToString(props1[0]));
639 }
640
TEST_F(GraphPropertiesTest, Queues) {
642 // Create a graph with known input shapes, and propagate the shapes through a
643 // couple of queues.
644 tensorflow::Scope root = tensorflow::Scope::NewRootScope();
645
646 auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
647 Output rnd =
648 ops::RandomNormal(root.WithOpName("rnd"), {3, 7}, DataType::DT_FLOAT);
649 Output square1 = ops::Square(root.WithOpName("Square1"), rnd);
650 auto enqueue1 = ops::QueueEnqueue(root.WithOpName("Enqueue1"), q1, {square1});
651 auto dequeue1 =
652 ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
653
654 auto q2 =
655 ops::RandomShuffleQueue(root.WithOpName("Queue2"), {DataType::DT_FLOAT});
656 Output square2 = ops::Square(root.WithOpName("Square2"), dequeue1[0]);
657 auto enqueue2 = ops::QueueEnqueue(root.WithOpName("Enqueue2"), q2, {square2});
658 auto dequeue2 =
659 ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});
660
661 auto q4 =
662 ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
663 auto enqueue4 = ops::QueueEnqueue(root.WithOpName("Enqueue4"), q4, {square2});
664 auto enqueue4_2 =
665 ops::QueueEnqueue(root.WithOpName("Enqueue4_2"), q4, {dequeue2[0]});
666 auto dequeue4 =
667 ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT});
668
669 // Create a queue that takes in three tensors.
670 auto q5 = ops::RandomShuffleQueue(
671 root.WithOpName("Queue5"),
672 {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
673 Output rnd2 =
674 ops::RandomNormal(root.WithOpName("rnd2"), {10}, DataType::DT_DOUBLE);
675 Output rnd3 =
676 ops::RandomNormal(root.WithOpName("rnd3"), {1, 2, 3}, DataType::DT_FLOAT);
677 auto enqueue5 =
678 ops::QueueEnqueue(root.WithOpName("Enqueue5"), q5, {rnd, rnd2, rnd3});
679 auto dequeue5 = ops::QueueDequeue(
680 root.WithOpName("Dequeue5"), q5,
681 {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
682
683 GrapplerItem item;
684 TF_ASSERT_OK(root.ToGraphDef(&item.graph));
685
686 GraphProperties properties(item);
687 TF_ASSERT_OK(properties.InferStatically(false));
688
689 const auto props1 = properties.GetOutputProperties("Dequeue1");
690 ASSERT_EQ(1, props1.size());
691 EXPECT_EQ("float: [3,7]", PropToString(props1[0]));
692
693 const auto props2 = properties.GetOutputProperties("Dequeue2");
694 ASSERT_EQ(1, props2.size());
695 EXPECT_EQ("float: [3,7]", PropToString(props2[0]));
696
  // Queue4 is fed by both square2 (known shape) and dequeue2. Verify that we
  // merge the two properly to determine the shape of the data coming out of
  // the queue.
700 const auto props4 = properties.GetOutputProperties("Dequeue4");
701 ASSERT_EQ(1, props4.size());
702 EXPECT_EQ("float: [3,7]", PropToString(props4[0]));
703
704 // The dequeue5 op shape is known.
705 const auto props5 = properties.GetOutputProperties("Dequeue5");
706 ASSERT_EQ(3, props5.size());
707 EXPECT_EQ("float: [3,7]", PropToString(props5[0]));
708 EXPECT_EQ("double: [10]", PropToString(props5[1]));
709 EXPECT_EQ("float: [1,2,3]", PropToString(props5[2]));
710 }
711
TEST_F(GraphPropertiesTest, MergeWithoutLoops) {
713 // Test graph produced in python using:
714 /*
715 with tf.Graph().as_default():
716 x = tf.constant(2)
717 y = tf.constant(5)
718 z = tf.ones([1,1,1])
719 def f1(): return tf.concat([z, z], axis=0)
720 def f2(): return tf.concat([z, z], axis=1)
721 r = tf.cond(tf.less(x, y), f1, f2)
722 tf.concat([r, r], axis=2)
723 with open('/tmp/graph.pbtxt', 'w') as f:
724 f.write(str(tf.get_default_graph().as_graph_def()))
725 */
726
727 GrapplerItem item;
728 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
729 "merge_without_loops.pbtxt");
730 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
731 GraphProperties properties(item);
732 TF_ASSERT_OK(properties.InferStatically(false));
733
734 std::vector<string> nodes{"cond/Merge", "cond/concat", "cond/concat_1"};
735 std::vector<string> expected_outputs{"float: [-1,-1,1]", "float: [2,1,1]",
736 "float: [1,2,1]"};
737 for (int i = 0; i < nodes.size(); i++) {
738 const auto props = properties.GetOutputProperties(nodes[i]);
739 const OpInfo::TensorProperties& prop = props[0];
740 EXPECT_EQ(DT_FLOAT, prop.dtype());
741 EXPECT_EQ(expected_outputs[i], PropToString(prop));
742 }
743
744 // The "Less" node should be fed by 2 int32 scalar constant values.
745 const auto props = properties.GetInputProperties("Less");
746 EXPECT_EQ(2, props.size());
747 for (int i = 0; i < props.size(); ++i) {
748 EXPECT_EQ(DT_INT32, props[i].dtype());
749 EXPECT_TRUE(props[i].has_value());
750 EXPECT_EQ("int32: []", PropToString(props[i]));
751 }
752 }
753
TEST_F(GraphPropertiesTest, WhileLoop) {
755 // Test graph produced in python using:
756 /*
757 with tf.Graph().as_default():
758 i0 = tf.constant(0)
759 m0 = tf.placeholder([-1, 2])
760 c = lambda i, m: i < 10
761 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
762 r = tf.while_loop(
763 c, b, loop_vars=[i0, m0],
764 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
765 with open('/tmp/graph.pbtxt', 'w') as f:
766 f.write(str(tf.get_default_graph().as_graph_def()))
767 */
768
769 GrapplerItem item;
770 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
771 "while_loop.pbtxt");
772 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
773 GraphProperties properties(item);
774 TF_ASSERT_OK(properties.InferStatically(false));
775
776 std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
777 "while/Exit_1"};
778 for (const string& node : nodes) {
779 const auto props = properties.GetOutputProperties(node);
780 const OpInfo::TensorProperties& prop = props[0];
781 EXPECT_EQ(DT_FLOAT, prop.dtype());
782 EXPECT_EQ("float: [-1,2]", PropToString(prop));
783 }
784
  // The loop output's batch dim should be different from the input batch dim
  // since we concatenated along the batch dim.
787 auto shape_in = properties.GetOutputProperties("ones").at(0).shape();
788 auto shape_out = properties.GetOutputProperties("while/Exit_1").at(0).shape();
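  // GraphProperties encodes distinct unknown dims as symbolic values <= -2,
  // so the checks below verify both batch dims are unknown yet different.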
789 EXPECT_GE(-2, shape_in.dim(0).size());
790 EXPECT_GE(-2, shape_out.dim(0).size());
791 EXPECT_NE(shape_in.dim(0).size(), shape_out.dim(0).size());
792 }
793
TEST_F(GraphPropertiesTest, NestedLoop) {
795 // Test graph produced in python using:
796 /*
797 with tf.Graph().as_default():
798 i0 = tf.constant(0)
799
800 def inner(j, y):
801 def inner_cond(j, y):
802 return j < 3
803
804 def inner_body(j, y):
805 return j+1, tf.concat([y, y], axis=2)
806
807 return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y],
808 shape_invariants=[i0.get_shape(),
809 tf.TensorShape([None, 1, None])])
810
811 def outer_cond(i, x):
812 return i < 3
813
814 def outer_body(i, x):
815 j, y = inner(0, x)
816 return i+1, tf.concat([x, x], axis=0)
817
818 r = tf.while_loop(outer_cond, outer_body,
819 loop_vars=[i0, tf.ones([1, 1, 1])],
820 shape_invariants=[i0.get_shape(),
821 tf.TensorShape([None, 1, None])])
822
823 with open('/tmp/graph.pbtxt', 'w') as f:
824 f.write(str(tf.get_default_graph().as_graph_def()))
825 */
826
827 GrapplerItem item;
828 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
829 "nested_loop.pbtxt");
830 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
831 GraphProperties properties(item);
832 TF_ASSERT_OK(properties.InferStatically(false));
833
834 std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
835 "while/Exit_1"};
836 std::vector<string> inner_nodes{"while/while/Merge_1",
837 "while/while/NextIteration_1",
838 "while/while/Exit_1"};
839 for (const string& node : outer_nodes) {
840 const auto props = properties.GetOutputProperties(node);
841 const OpInfo::TensorProperties& prop = props[0];
842 EXPECT_EQ(DT_FLOAT, prop.dtype());
843 EXPECT_EQ("float: [-1,1,1]", PropToString(prop));
844 }
845 for (const string& node : inner_nodes) {
846 const auto props = properties.GetOutputProperties(node);
847 const OpInfo::TensorProperties& prop = props[0];
848 EXPECT_EQ(DT_FLOAT, prop.dtype());
849 EXPECT_EQ("float: [-1,1,-1]", PropToString(prop));
850 }
851 }
852
TEST_F(GraphPropertiesTest, LoopsAndQueues) {
854 // Test graph produced in python using:
855 /*
856 with tf.Graph().as_default():
857 i0 = tf.constant(0)
858 q = tf.FIFOQueue(1, "float")
859
860 def inner(j, y):
861 def inner_cond(j, y):
862 return j < 3
863
864 def inner_body(j, y):
865 return j+1, tf.concat([y, y], axis=0)
866
867 return tf.while_loop(inner_cond, inner_body,
868 loop_vars=[j, y],
869 shape_invariants=[i0.get_shape(),
870 tf.TensorShape(None)])
871
872 def outer_cond(i, x):
873 return i < 3
874
875 def outer_body(i, x):
876 q.enqueue(x)
877 y = tf.concat([x, x], axis=2)
878 inner(0, q.dequeue())
879 return i+1, y
880
881 i, z = tf.while_loop(outer_cond, outer_body,
882 loop_vars=[i0, tf.ones([1, 1, 1])],
883 shape_invariants=[i0.get_shape(),
884 tf.TensorShape([None, 1, None])])
885
886 with open('/tmp/graph.pbtxt', 'w') as f:
887 f.write(str(tf.get_default_graph().as_graph_def()))
888 */
889
890 GrapplerItem item;
891 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
892 "loops_and_queues.pbtxt");
893 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
894 GraphProperties properties(item);
895 TF_ASSERT_OK(properties.InferStatically(false));
896
897 std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
898 "while/Exit_1"};
899 std::vector<string> inner_nodes{"while/while/Merge_1",
900 "while/while/NextIteration_1",
901 "while/while/Exit_1"};
902 for (const string& node : outer_nodes) {
903 const auto props = properties.GetOutputProperties(node);
904 const OpInfo::TensorProperties& prop = props[0];
905 EXPECT_EQ(DT_FLOAT, prop.dtype());
906 EXPECT_EQ("float: [1,1,-1]", PropToString(prop));
907 }
908 for (const string& node : inner_nodes) {
909 const auto props = properties.GetOutputProperties(node);
910 const OpInfo::TensorProperties& prop = props[0];
911 EXPECT_EQ(DT_FLOAT, prop.dtype());
912 EXPECT_EQ("float: [-1,1,-1]", PropToString(prop));
913 }
914 }
915
TEST_F(GraphPropertiesTest, LoopsAndResourceVars) {
917 // Test graph produced in python using:
918 /*
919 with tf.Graph().as_default():
920 i0 = tf.constant(0)
921 with tf.variable_scope(VariableScope(reuse=None, use_resource=True)):
922 v = tf.get_variable(initializer=i0, name='loop_var')
923
924 def inner(j, y):
925 def inner_cond(j, y):
926 return j < 3
927
928 def inner_body(j, y):
929 return j + 1, y + y
930
931 return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y])
932
933 def outer_cond(i, x):
934 return i < 3
935
936 def outer_body(i, x):
937 y = x + x
938 inner(0, v)
939 return i + 1, y
940
941 v, z = tf.while_loop(outer_cond, outer_body,
942 loop_vars=[v, tf.constant(1)])
943
944 with open('/tmp/graph.pbtxt', 'w') as f:
945 f.write(str(tf.get_default_graph().as_graph_def()))
946 */
947
948 GrapplerItem item;
949 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
950 "loops_and_resource_vars.pbtxt");
951 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
952 GraphProperties properties(item);
953 TF_ASSERT_OK(properties.InferStatically(false));
954
955 std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
956 "while/Exit_1"};
957 std::vector<string> inner_nodes{"while/while/Merge_1",
958 "while/while/NextIteration_1",
959 "while/while/Exit_1"};
960 for (const string& node : outer_nodes) {
961 const auto props = properties.GetOutputProperties(node);
962 const OpInfo::TensorProperties& prop = props[0];
963 EXPECT_EQ(DT_INT32, prop.dtype());
964 EXPECT_EQ("int32: []", PropToString(prop));
965 }
966 for (const string& node : inner_nodes) {
967 const auto props = properties.GetOutputProperties(node);
968 const OpInfo::TensorProperties& prop = props[0];
969 EXPECT_EQ(DT_INT32, prop.dtype());
970 EXPECT_EQ("int32: []", PropToString(prop));
971 }
972 }
973
TEST_F(GraphPropertiesTest, QueuesAndLoops) {
975 // Test graph produced in python using:
976 /*
977 with tf.Graph().as_default():
978 i0 = tf.constant(0)
979 q0 = tf.FIFOQueue(1, "float")
980 q0.enqueue(tf.ones([2, 2]))
981 q1 = tf.FIFOQueue(1, "float")
982
983 def c(i, m):
984 return i < 10
985
986 def b(i, m):
987 return i+1, tf.concat([m, m], axis=0)
988
989 i, m = tf.while_loop(
990 c, b, loop_vars=[i0, q0.dequeue()],
991 shape_invariants=[i0.get_shape(), tf.TensorShape(None)])
992
993 q1.enqueue(m)
994 v = q1.dequeue();
995 tf.concat([v, v], axis=1)
996 with open('/tmp/graph.pbtxt', 'w') as f:
997 f.write(str(tf.get_default_graph().as_graph_def()))
998 */
999
1000 GrapplerItem item;
1001 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
1002 "queues_and_loops.pbtxt");
1003 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
1004 GraphProperties properties(item);
1005 TF_ASSERT_OK(properties.InferStatically(false));
1006
1007 std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
1008 "while/Exit_1"};
1009
1010 for (const string& node : nodes) {
1011 const auto props = properties.GetOutputProperties(node);
1012 const OpInfo::TensorProperties& prop = props[0];
1013 EXPECT_EQ(DT_FLOAT, prop.dtype());
1014 EXPECT_EQ("float: [-1,2]", PropToString(prop));
1015 }
1016
1017 const auto props = properties.GetOutputProperties("concat");
1018 const OpInfo::TensorProperties& prop = props[0];
1019 EXPECT_EQ(DT_FLOAT, prop.dtype());
1020 EXPECT_EQ("float: [-1,4]", PropToString(prop));
1021 }
1022
TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
1024 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1025 Output var = ops::Variable(s.WithOpName("var"), TensorShape({128, 256}),
1026 DataType::DT_FLOAT);
1027 Output filename =
1028 ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
1029 Output tensor_name =
1030 ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
1031 Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
1032 DataType::DT_FLOAT);
1033 Output init_restore = ops::Assign(s.WithOpName("init_restore"), var, restore);
1034
1035 Output shape_and_slice = ops::Const(s.WithOpName("shape_and_slice"),
1036 string("256 256 0,128:-"), TensorShape());
1037 Output restore_slice =
1038 ops::RestoreSlice(s.WithOpName("restore_slice"), filename, tensor_name,
1039 shape_and_slice, DataType::DT_FLOAT);
1040 Output init_restore_slice =
1041 ops::Assign(s.WithOpName("init_restore_slice"), var, restore_slice);
1042
1043 Output restore_v2 =
1044 ops::RestoreSlice(s.WithOpName("restore_v2"), filename, tensor_name,
1045 shape_and_slice, DataType::DT_FLOAT);
1046 Output init_restore_v2 =
1047 ops::Assign(s.WithOpName("init_restore_v2"), var, restore_v2);
1048
1049 GrapplerItem item;
1050 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1051 item.fetch.push_back("init_restore");
1052
1053 GraphProperties properties(item);
1054 TF_ASSERT_OK(properties.InferStatically(false));
1055
1056 const auto restore_props = properties.GetOutputProperties("restore");
1057 const OpInfo::TensorProperties& restore_prop = restore_props[0];
1058 EXPECT_EQ(DT_FLOAT, restore_prop.dtype());
1059 EXPECT_EQ("float: [128,256]", PropToString(restore_prop));
1060
1061 const auto restore_slice_props =
1062 properties.GetOutputProperties("restore_slice");
1063 const OpInfo::TensorProperties& restore_slice_prop = restore_slice_props[0];
1064 EXPECT_EQ(DT_FLOAT, restore_slice_prop.dtype());
1065 EXPECT_EQ("float: [128,256]", PropToString(restore_slice_prop));
1066
1067 const auto restorev2_props = properties.GetOutputProperties("restore_v2");
1068 const OpInfo::TensorProperties& restorev2_prop = restorev2_props[0];
1069 EXPECT_EQ(DT_FLOAT, restorev2_prop.dtype());
1070 EXPECT_EQ("float: [128,256]", PropToString(restorev2_prop));
1071
1072 // Check input shapes of assign op are propagated correctly.
1073 const auto input_props = properties.GetInputProperties("init_restore");
1074 ASSERT_EQ(2, input_props.size());
1075 const OpInfo::TensorProperties& input_prop = input_props[1];
1076 EXPECT_EQ(DT_FLOAT, input_prop.dtype());
1077 EXPECT_EQ("float: [128,256]", PropToString(input_prop));
1078 }
1079
TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
1081 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1082 Output var = ops::Variable(s.WithOpName("var"), PartialTensorShape(),
1083 DataType::DT_FLOAT);
1084 Output var2 = ops::Variable(s.WithOpName("var2"), TensorShape({128, 256}),
1085 DataType::DT_FLOAT);
1086 Output filename =
1087 ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
1088 Output tensor_name =
1089 ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
1090 Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
1091 DataType::DT_FLOAT);
1092 Output init = ops::Assign(s.WithOpName("init"), var, restore);
1093 Output init2 = ops::Assign(s.WithOpName("init2"), var2, restore);
1094
1095 GrapplerItem item;
1096 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1097 item.fetch.push_back("init");
1098 item.fetch.push_back("init2");
1099
1100 GraphProperties properties(item);
1101 TF_ASSERT_OK(properties.InferStatically(false));
1102
1103 const auto props = properties.GetOutputProperties("restore");
1104 const OpInfo::TensorProperties& prop = props[0];
1105 EXPECT_EQ(DT_FLOAT, prop.dtype());
1106 EXPECT_EQ("float: [128,256]", PropToString(prop));
1107 }
1108
TEST_F(GraphPropertiesTest, TensorAsShapesPropagation) {
1110 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1111 Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
1112 Output a1 = ops::Identity(s.WithOpName("a1"), a);
1113 Output b = ops::Const(s.WithOpName("b"), 99, {});
1114 Output b1 = ops::Identity(s.WithOpName("b1"), b);
1115 Output c = ops::Const(s.WithOpName("c"), 1, {4, 4, 4});
1116 Output c1 = ops::Identity(s.WithOpName("c1"), c);
1117
1118 GrapplerItem item;
1119 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1120 GraphProperties properties(item);
1121 TF_ASSERT_OK(properties.InferStatically(false));
1122
1123 // Check output shapes.
1124 EXPECT_EQ("int32: [2]", PropToString(properties.GetOutputProperties("a")[0]));
1125 EXPECT_EQ("int32: [2]",
1126 PropToString(properties.GetOutputProperties("a1")[0]));
1127 EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b")[0]));
1128 EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b1")[0]));
1129 EXPECT_EQ("int32: [4,4,4]",
1130 PropToString(properties.GetOutputProperties("c")[0]));
1131 EXPECT_EQ("int32: [4,4,4]",
1132 PropToString(properties.GetOutputProperties("c1")[0]));
1133
1134 // Check has_value.
1135 EXPECT_TRUE(properties.GetOutputProperties("a")[0].has_value());
1136 EXPECT_TRUE(properties.GetInputProperties("a1")[0].has_value());
1137 EXPECT_TRUE(properties.GetOutputProperties("a1")[0].has_value());
1138 EXPECT_TRUE(properties.GetOutputProperties("b")[0].has_value());
1139 EXPECT_TRUE(properties.GetInputProperties("b1")[0].has_value());
1140 EXPECT_TRUE(properties.GetOutputProperties("b1")[0].has_value());
1141 EXPECT_TRUE(properties.GetOutputProperties("c")[0].has_value());
1142 EXPECT_TRUE(properties.GetInputProperties("c1")[0].has_value());
1143 // Note that we propagate tensor value of only 1D vector and scalar.
1144 EXPECT_TRUE(properties.GetOutputProperties("c1")[0].has_value());
1145
1146 // Check values.
1147 ExpectTensorValues({5, 7}, properties.GetOutputProperties("a")[0].value());
1148 ExpectTensorValues({5, 7}, properties.GetInputProperties("a1")[0].value());
1149 ExpectTensorValues({5, 7}, properties.GetOutputProperties("a1")[0].value());
1150 ExpectTensorValues({99}, properties.GetOutputProperties("b")[0].value());
1151 ExpectTensorValues({99}, properties.GetInputProperties("b1")[0].value());
1152 ExpectTensorValues({99}, properties.GetOutputProperties("b1")[0].value());
1153 std::vector<int64_t> c_values;
1154 for (int i = 0; i < 4 * 4 * 4; i++) {
1155 c_values.push_back(1);
1156 }
1157 ExpectTensorValues({c_values},
1158 properties.GetOutputProperties("c")[0].value());
1159 ExpectTensorValues({c_values},
1160 properties.GetInputProperties("c1")[0].value());
1161 ExpectTensorValues({c_values},
1162 properties.GetOutputProperties("c1")[0].value());
1163 }
1164
TEST_F(GraphPropertiesTest, IdentityPassingShape) {
1166 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1167 Output a = ops::Const(s.WithOpName("a"), 5, {2});
1168 Output b = ops::Identity(s.WithOpName("b"), a);
1169 Output c = ops::Const(s.WithOpName("const"), 0.1f, {});
  // Fill needs not only b's shape but also the value of b to figure out the
  // output shape; hence, the Identity op (b) should pass a's value as
  // output_tensors_as_shape.
1173 Output d = ops::Fill(s.WithOpName("fill"), b, c);
1174
1175 GrapplerItem item;
1176 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1177 GraphProperties properties(item);
1178 TF_ASSERT_OK(properties.InferStatically(false));
1179 const auto out_props = properties.GetOutputProperties("fill");
1180 const OpInfo::TensorProperties out_prop0 = out_props[0];
1181 EXPECT_EQ("float: [5,5]", PropToString(out_prop0));
1182 }
1183
TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
1185 // When using aggressive_shape_inference, we run EvaluateNode() for
1186 // allowlisted ops and small input / output tensors. For instance, Fill op is
1187 // evaluated and produces output tensor value if output tensor size is small
1188 // (currently, fewer than 17 elements); otherwise we don't run EvaluateNode().
1189 // This is to avoid wasting time and memory for producing huge tensors (e.g.,
  // initializing a large table using Fill).
1191 // Note that we do not propagate float const tensors with fractional values
1192 // (even if they're small); so this test should use integer values.
1193 {
1194 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1195 Output a = ops::Const(s.WithOpName("a"), 4, {2}); // 4x4
1196 Output b = ops::Const(s.WithOpName("const"), 7, {});
1197 // Shape described by a is small; expect output values of Fill op.
1198 Output c = ops::Fill(s.WithOpName("fill"), a, b);
1199
1200 GrapplerItem item;
1201 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1202 GraphProperties properties(item);
1203 TF_ASSERT_OK(properties.InferStatically(
1204 /*assume_valid_feeds=*/false,
1205 /*aggressive_shape_inference=*/true,
1206 /*include_tensor_values=*/true));
1207 const auto out_props = properties.GetOutputProperties("fill");
1208 const OpInfo::TensorProperties out_prop0 = out_props[0];
1209 EXPECT_EQ("int32: [4,4]", PropToString(out_prop0));
1210 EXPECT_TRUE(out_prop0.has_value());
1211 }
1212 {
1213 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1214 Output a = ops::Const(s.WithOpName("a"), 1000, {4}); // 1000x1000x1000x1000
1215 Output b = ops::Const(s.WithOpName("const"), 7, {});
1216 // Shape described by a is huge; in that case we skip value inference.
1217 // Otherwise, it'd be too much overhead.
1218 Output c = ops::Fill(s.WithOpName("fill"), a, b);
1219
1220 GrapplerItem item;
1221 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1222 GraphProperties properties(item);
1223 TF_ASSERT_OK(properties.InferStatically(
1224 /*assume_valid_feeds=*/false,
1225 /*aggressive_shape_inference=*/true,
1226 /*include_tensor_values=*/true));
1227 const auto out_props = properties.GetOutputProperties("fill");
1228 const OpInfo::TensorProperties out_prop0 = out_props[0];
1229 EXPECT_EQ("int32: [1000,1000,1000,1000]", PropToString(out_prop0));
1230 EXPECT_FALSE(out_prop0.has_value());
1231 }
1232 }
1233
TEST_F(GraphPropertiesTest, PackWithConstInput) {
1235 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1236 Output a = ops::Const(s.WithOpName("a"), 1, {});
1237 Output b = ops::Const(s.WithOpName("b"), 2, {});
1238 Output c = ops::Const(s.WithOpName("c"), 3, {});
1239 Output d = ops::Const(s.WithOpName("d"), 4, {});
1240 // Note ops::Stack instantiates Pack op.
1241 Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
1242 // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
1243 Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
1244 // Fill needs not only e's shape but also its value to figure out output
1245 // shape.
1246 Output g = ops::Fill(s.WithOpName("fill"), e, f);
1247
1248 GrapplerItem item;
1249 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1250 GraphProperties properties(item);
1251 TF_ASSERT_OK(properties.InferStatically(false));
1252 const auto out_props = properties.GetOutputProperties("fill");
1253 const OpInfo::TensorProperties out_prop0 = out_props[0];
1254 EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
1255 }
1256
TEST_F(GraphPropertiesTest, RankOp) {
1258 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1259 Output c = ops::Const(s.WithOpName("Const"), 1, {4, 4, 4});
1260 Output r = ops::Rank(s.WithOpName("Rank"), c);
1261 Output i = ops::Identity(s.WithOpName("Identity"), r);
1262
1263 GrapplerItem item;
1264 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1265 GraphProperties properties(item);
1266 TF_ASSERT_OK(properties.InferStatically(false));
1267 const auto rank_props = properties.GetOutputProperties("Rank");
1268 const OpInfo::TensorProperties rank_prop0 = rank_props[0];
1269 EXPECT_EQ("int32: []", PropToString(rank_prop0));
1270 EXPECT_TRUE(rank_prop0.has_value());
1271 ExpectTensorValues({3}, rank_prop0.value());
1272 const auto identity_props = properties.GetOutputProperties("Identity");
1273 const OpInfo::TensorProperties identity_props0 = identity_props[0];
1274 EXPECT_EQ("int32: []", PropToString(identity_props0));
1275 EXPECT_TRUE(identity_props0.has_value());
1276 ExpectTensorValues({3}, identity_props0.value());
1277 }
1278
TEST_F(GraphPropertiesTest, SizeOp) {
1280 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1281 Output c = ops::Const(s.WithOpName("Const"), 1, {1, 2, 3, 4});
1282 Output r = ops::Size(s.WithOpName("Size"), c);
1283 Output i = ops::Identity(s.WithOpName("Identity"), r);
1284
1285 GrapplerItem item;
1286 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1287 GraphProperties properties(item);
1288 TF_ASSERT_OK(properties.InferStatically(false));
1289 const auto size_props = properties.GetOutputProperties("Size");
1290 const OpInfo::TensorProperties size_props0 = size_props[0];
1291 EXPECT_EQ("int32: []", PropToString(size_props0));
1292 EXPECT_TRUE(size_props0.has_value());
1293 ExpectTensorValues({24}, size_props0.value());
1294 const auto identity_props = properties.GetOutputProperties("Identity");
1295 const OpInfo::TensorProperties identity_props0 = identity_props[0];
1296 EXPECT_EQ("int32: []", PropToString(identity_props0));
1297 EXPECT_TRUE(identity_props0.has_value());
1298 ExpectTensorValues({24}, identity_props0.value());
1299 }
1300
TEST_F(GraphPropertiesTest, PackWithConstMinus1AndReshapes) {
1302 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1303 Output shape0 = ops::Const(s.WithOpName("shape0"), 4, {});
1304 Output shape1 = ops::Const(s.WithOpName("shape1"), -1, {});
1305 Output pack = ops::Stack(s.WithOpName("pack"), {shape0, shape1});
1306 // pack is [2], with values {4, -1}.
1307
1308 Output x0_ = ops::Placeholder(s.WithOpName("x0_"), DataType::DT_FLOAT);
1309 Output x1_ = ops::Placeholder(s.WithOpName("x1_"), DataType::DT_FLOAT);
1310
1311 Output x0 = ops::Reshape(s.WithOpName("x0"), x0_, pack);
1312 Output x1 = ops::Reshape(s.WithOpName("x1"), x1_, pack);
  // Two unknown-rank tensors (x0_ and x1_) are reshaped with pack {4, -1},
  // so their output shapes are [4, -1]. However, even though we use the same
  // shape input to the Reshape ops, their output shapes can be different;
  // i.e., the unknown dim values (-1) of the x0 and x1 shapes are not
  // necessarily the same.
1318
  // s0 and s1 are the 'if' (condition) inputs to the Select ops. Note that s0
  // has a fully defined shape, while s1 has an unknown shape.
1321 Output s0 = ops::Const(s.WithOpName("s0"), true, {4, 16});
1322 Output s1 = ops::Placeholder(s.WithOpName("s1"), DataType::DT_BOOL);
1323
1324 Output y0 = ops::Placeholder(s.WithOpName("y0"), DataType::DT_FLOAT);
1325 Output y1 = ops::Placeholder(s.WithOpName("y1"), DataType::DT_FLOAT);
1326
1327 // We instantiate SelectV2, but will replace it with Select. The shape
1328 // inference function for Select links all inputs and outputs as they should
1329 // have the same shapes.
1330 Output z0 = ops::SelectV2(s.WithOpName("z0"), s0, x0, y0);
1331 Output z1 = ops::SelectV2(s.WithOpName("z1"), s1, x1, y1);
1332
1333 // For z0, as we know the shape of s0, symbolic shape manager in shape
1334 // inference will make the shapes of x0, y0, and z0 equal to the shape of s0,
1335 // which is [4, 16].
1336 // For z1, s0 and y1 are all unknown shapes, so we can infer they're [4, -1]
1337 // at best.
1338 // Note that x0 and x1 share the same shape input to the Reshape op, but
  // -1 in the shape input should not be treated as the same symbolic unknown
  // dim; it is merely the constant value -1 used to mark the unknown dim for
  // the Reshape operation.
1342
1343 GrapplerItem item;
1344 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1345
1346 // Replace SelectV2 op with Select op.
1347 for (int i = 0; i < item.graph.node_size(); ++i) {
1348 auto* node = item.graph.mutable_node(i);
1349 if (node->op() == "SelectV2") {
1350 node->set_op("Select");
1351 }
1352 }
1353
1354 GraphProperties properties(item);
1355 TF_ASSERT_OK(properties.InferStatically(false));
1356 for (const auto& node_name : {"x0", "y0", "z0"}) {
1357 const auto out_props = properties.GetOutputProperties(node_name);
1358 const OpInfo::TensorProperties out_prop0 = out_props[0];
1359 EXPECT_EQ("float: [4,16]", PropToString(out_prop0));
1360 }
1361 {
1362 const auto out_props = properties.GetOutputProperties("s0");
1363 const OpInfo::TensorProperties out_prop0 = out_props[0];
1364 EXPECT_EQ("bool: [4,16]", PropToString(out_prop0));
1365 }
1366
1367 for (const auto& node_name : {"x1", "y1", "z1"}) {
1368 const auto out_props = properties.GetOutputProperties(node_name);
1369 const OpInfo::TensorProperties out_prop0 = out_props[0];
1370 EXPECT_EQ("float: [4,-1]", PropToString(out_prop0));
1371 }
  // The 'if' (condition) input of Select can be either a vector or have the
  // same shape as the input/output; in this case, even if we know the input
  // and output are [4, ?], we can't tell whether the condition is [4, ?] or a
  // vector; hence, its shape should be unknown.
1376 {
1377 const auto out_props = properties.GetOutputProperties("s1");
1378 const OpInfo::TensorProperties out_prop0 = out_props[0];
1379 EXPECT_EQ("bool: ?", PropToString(out_prop0));
1380 }
1381 }
1382
1383 TEST_F(GraphPropertiesTest, PackWithIdentityInput) {
1384 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1385 // Same as the PackWithConstInput test case, but a, b, c, and d are Identity
1386 // ops of Const.
1387 // If output_tensors_as_shape is not set for those Identity ops, or the Pack
1388 // op doesn't take input_tensors_as_shape, the Fill op's input has no known
1389 // value; hence, its output shape becomes unknown.
1390 Output a0 = ops::Const(s.WithOpName("a0"), 1, {});
1391 Output b0 = ops::Const(s.WithOpName("b0"), 2, {});
1392 Output c0 = ops::Const(s.WithOpName("c0"), 3, {});
1393 Output d0 = ops::Const(s.WithOpName("d0"), 4, {});
1394 Output a = ops::Identity(s.WithOpName("a"), a0);
1395 Output b = ops::Identity(s.WithOpName("b"), b0);
1396 Output c = ops::Identity(s.WithOpName("c"), c0);
1397 Output d = ops::Identity(s.WithOpName("d"), d0);
1398 // Note ops::Stack instantiates Pack op.
1399 Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
1400 // e is a rank-1 tensor: shape = {4}, and its value is {1, 2, 3, 4}.
1401 Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
1402 // Fill needs not only e's shape but also its value to figure out output
1403 // shape.
1404 Output g = ops::Fill(s.WithOpName("fill"), e, f);
1405
1406 GrapplerItem item;
1407 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1408 GraphProperties properties(item);
1409 TF_ASSERT_OK(properties.InferStatically(false));
1410 const auto out_props = properties.GetOutputProperties("fill");
1411 const OpInfo::TensorProperties out_prop0 = out_props[0];
1412 EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
1413 }
1414
1415 TEST_F(GraphPropertiesTest, FunctionWithDtResourceInput) {
1416 // Function ops may have a DT_RESOURCE input; if shapes and dtypes are not
1417 // properly set through the DT_RESOURCE _Arg, we cannot infer the output
1418 // shapes of such function ops.
1419 GrapplerItem item;
1420 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
1421 "function_with_dt_resource_input.pbtxt");
1422 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
1423
1424 // This graph evaluates FunctionWithDtResourceInput with two inputs:
1425 // x [DT_FLOAT Const],
1426 // _Arg [DT_RESOURCE _Arg]
1427 // and has two outputs:
1428 // z1 = x + _Arg
1429 // z2 = x
1430 {
1431 GraphProperties properties(item);
1432 TF_ASSERT_OK(properties.InferStatically(false));
1433 const auto out_props =
1434 properties.GetOutputProperties("FunctionWithDtResourceInput");
1435 EXPECT_EQ(out_props.size(), 2);
1436 const OpInfo::TensorProperties out_prop0 = out_props[0];
1437 EXPECT_EQ("float: [1,3]", PropToString(out_prop0));
1438 const OpInfo::TensorProperties out_prop1 = out_props[1];
1439 EXPECT_EQ("float: [1,3]", PropToString(out_prop1));
1440 }
1441
1442 {
1443 // Delete _handle_dtypes and _handle_shapes attr for the input _Arg node.
1444 for (int i = 0; i < item.graph.node_size(); i++) {
1445 auto* node = item.graph.mutable_node(i);
1446 if (node->name() == "y") { // _Arg node with DT_RESOURCE
1447 node->mutable_attr()->erase("_handle_dtypes");
1448 node->mutable_attr()->erase("_handle_shapes");
1449 break;
1450 }
1451 }
1452 // We cannot infer the function output shapes correctly without those attrs,
1453 // but inference still shouldn't fail; moreover, some shapes may still be
1454 // inferable in such a case. In this test graph, z2 of the function node
1455 // just returns the x input; hence, even if _Arg's shape cannot be inferred,
1456 // we can still infer the z2 output shape.
1457 GraphProperties properties(item);
1458 TF_ASSERT_OK(properties.InferStatically(false));
1459 const auto out_props =
1460 properties.GetOutputProperties("FunctionWithDtResourceInput");
1461 EXPECT_EQ(out_props.size(), 2);
1462 const OpInfo::TensorProperties out_prop0 = out_props[0];
1463 // Without shape and dtype attr, we don't know _Arg's shape; hence, unknown
1464 // for x + _Arg.
1465 EXPECT_EQ("float: ?", PropToString(out_prop0));
1466 // The 2nd output is just x, so even if _Arg's shape is unknown, we can
1467 // infer this output shape.
1468 const OpInfo::TensorProperties out_prop1 = out_props[1];
1469 EXPECT_EQ("float: [1,3]", PropToString(out_prop1));
1470 }
1471 }
1472
1473 TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
1474 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1475 TF_ASSERT_OK(s.graph()->AddFunctionLibrary(function_lib_));
1476 Output shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4});
1477 Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
1478 auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
1479 s.graph()->op_registry());
1480 tensorflow::Node* func_op;
1481 auto _shape = tensorflow::ops::AsNodeOut(s, shape);
1482 auto _value = tensorflow::ops::AsNodeOut(s, value);
1483 TF_ASSERT_OK(
1484 builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
1485 GrapplerItem item;
1486 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1487
1488 GraphProperties properties(item);
1489 TF_ASSERT_OK(properties.InferStatically(false));
1490 const auto out_props = properties.GetOutputProperties("MyFillFunc");
1491 const OpInfo::TensorProperties out_prop0 = out_props[0];
1492 EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
1493 }
1494
1495 TEST_F(GraphPropertiesTest, FunctionWithIdentityOfConstInput) {
1496 // Same as FunctionWithConstInput, but the function inputs are Identity of
1497 // Const, so tensor shapes, not tensor values, should be used as the Const
1498 // input to the function.
1499 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1500 TF_ASSERT_OK(s.graph()->AddFunctionLibrary(function_lib_));
1501 Output shape_ = ops::Const(s.WithOpName("shape_"), {1, 2, 3, 4});
1502 Output shape = ops::Identity(s.WithOpName("shape"), shape_);
1503 Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
1504 auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
1505 s.graph()->op_registry());
1506 tensorflow::Node* func_op;
1507 auto _shape = tensorflow::ops::AsNodeOut(s, shape);
1508 auto _value = tensorflow::ops::AsNodeOut(s, value);
1509 TF_ASSERT_OK(
1510 builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
1511 GrapplerItem item;
1512 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1513
1514 GraphProperties properties(item);
1515 TF_ASSERT_OK(properties.InferStatically(false));
1516 const auto out_props = properties.GetOutputProperties("MyFillFunc");
1517 const OpInfo::TensorProperties out_prop0 = out_props[0];
1518 EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
1519 }
1520
1521 TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) {
1522 FunctionDefLibrary library;
1523 *library.add_function() = FunctionDefHelper::Create(
1524 "MyFunc", // Name
1525 {"x: int32"}, // Inputs
1526 {"out: int32"}, // Outputs
1527 {}, // Attrs
1528 {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_INT32}}}}, // Nodes
1529 {{"out", "a:output:0"}}); // Returns
1530 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1531 TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));
1532
1533 // MyFunc takes a Const (shape) and passes it through Identity. Expect the
1534 // function output to have the same shape and value (output_tensors_as_shape)
1535 // as the input Const tensor.
1536 Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
1537 auto _shape = tensorflow::ops::AsNodeOut(s, shape);
1538 auto builder =
1539 tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
1540 tensorflow::Node* func_op;
1541 TF_ASSERT_OK(builder.Input(_shape).Finalize(s.graph(), &func_op));
1542
1543 GrapplerItem item;
1544 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1545
1546 GraphProperties properties(item);
1547 TF_ASSERT_OK(properties.InferStatically(true));
1548 const auto out_props = properties.GetOutputProperties("MyFunc");
1549 const OpInfo::TensorProperties out_prop0 = out_props[0];
1550 EXPECT_EQ("int32: [2]", PropToString(out_prop0));
1551 EXPECT_TRUE(out_prop0.has_value());
1552 ExpectTensorValues({5, 7}, out_prop0.value());
1553 ExpectTensorValues({5, 7},
1554 properties.GetInputProperties("MyFunc")[0].value());
1555 }
1556
1557 TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
1558 FunctionDefLibrary library;
1559 // Function that adds two input values.
1560 *library.add_function() = FunctionDefHelper::Create(
1561 "MyFunc", // Name
1562 {"x: int32", "y: int32"}, // Inputs
1563 {"out: int32"}, // Outputs
1564 {}, // Attrs
1565 {{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_INT32}}}}, // Nodes
1566 {{"out", "a:z:0"}}); // Returns
1567 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1568 TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));
1569
1570 Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
1571 auto _shape = tensorflow::ops::AsNodeOut(s, shape);
1572 auto builder =
1573 tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
1574 tensorflow::Node* func_op;
1575 TF_ASSERT_OK(
1576 builder.Input(_shape).Input(_shape).Finalize(s.graph(), &func_op));
1577
1578 GrapplerItem item;
1579 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1580 {
1581 GraphProperties properties(item);
1582 // Without aggressive_shape_inference, the internal function does not
1583 // evaluate output value.
1584 TF_ASSERT_OK(properties.InferStatically(
1585 /*assume_valid_feeds=*/true,
1586 /*aggressive_shape_inference=*/false,
1587 /*include_tensor_values=*/true));
1588 const auto out_props = properties.GetOutputProperties("MyFunc");
1589 const OpInfo::TensorProperties out_prop0 = out_props[0];
1590 EXPECT_EQ("int32: [2]", PropToString(out_prop0));
1591 EXPECT_FALSE(out_prop0.has_value());
1592 }
1593
1594 {
1595 GraphProperties properties(item);
1596 // With aggressive_shape_inference, output value is evaluated.
1597 TF_ASSERT_OK(properties.InferStatically(
1598 /*assume_valid_feeds=*/true,
1599 /*aggressive_shape_inference=*/true,
1600 /*include_tensor_values=*/true));
1601 const auto out_props = properties.GetOutputProperties("MyFunc");
1602 const OpInfo::TensorProperties out_prop0 = out_props[0];
1603 EXPECT_EQ("int32: [2]", PropToString(out_prop0));
1604 EXPECT_TRUE(out_prop0.has_value());
1605
1606 ExpectTensorValues({10, 14}, out_prop0.value());
1607 ExpectTensorValues({5, 7},
1608 properties.GetInputProperties("MyFunc")[0].value());
1609 ExpectTensorValues({5, 7},
1610 properties.GetInputProperties("MyFunc")[1].value());
1611 }
1612 }
1613
1614 // Same as the above, but with float values; also, one of the function inputs
1615 // is Identity of Const.
1616 TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValueFloat) {
1617 FunctionDefLibrary library;
1618 // Function that adds two input values.
1619 *library.add_function() = FunctionDefHelper::Create(
1620 "MyFunc", // Name
1621 {"x: float", "y: float"}, // Inputs
1622 {"out: float"}, // Outputs
1623 {}, // Attrs
1624 {{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_FLOAT}}}}, // Nodes
1625 {{"out", "a:z:0"}}); // Returns
1626 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1627 TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));
1628
1629 Output x1 = ops::Const(s.WithOpName("x1"), {5.0f, 7.0f}, {2});
1630 Output x2 = ops::Identity(s.WithOpName("x2"), x1);
1631 auto _x1 = tensorflow::ops::AsNodeOut(s, x1);
1632 auto _x2 = tensorflow::ops::AsNodeOut(s, x2);
1633 auto builder =
1634 tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
1635 tensorflow::Node* func_op;
1636 TF_ASSERT_OK(builder.Input(_x1).Input(_x2).Finalize(s.graph(), &func_op));
1637
1638 GrapplerItem item;
1639 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1640 {
1641 GraphProperties properties(item);
1642 // Without aggressive_shape_inference, the internal function does not
1643 // evaluate output value.
1644 TF_ASSERT_OK(properties.InferStatically(
1645 /*assume_valid_feeds=*/true,
1646 /*aggressive_shape_inference=*/false,
1647 /*include_tensor_values=*/true));
1648 const auto out_props = properties.GetOutputProperties("MyFunc");
1649 const OpInfo::TensorProperties out_prop0 = out_props[0];
1650 EXPECT_EQ("float: [2]", PropToString(out_prop0));
1651 EXPECT_FALSE(out_prop0.has_value());
1652 }
1653
1654 {
1655 GraphProperties properties(item);
1656 // With aggressive_shape_inference, output value is evaluated.
1657 TF_ASSERT_OK(properties.InferStatically(
1658 /*assume_valid_feeds=*/true,
1659 /*aggressive_shape_inference=*/true,
1660 /*include_tensor_values=*/true));
1661 const auto out_props = properties.GetOutputProperties("MyFunc");
1662 const OpInfo::TensorProperties out_prop0 = out_props[0];
1663 EXPECT_EQ("float: [2]", PropToString(out_prop0));
1664 EXPECT_TRUE(out_prop0.has_value());
1665
1666 ExpectFloatTensorValues({10.0, 14.0}, out_prop0.value());
1667 ExpectFloatTensorValues({5.0, 7.0},
1668 properties.GetInputProperties("MyFunc")[0].value());
1669 ExpectFloatTensorValues({5.0, 7.0},
1670 properties.GetInputProperties("MyFunc")[1].value());
1671 }
1672 }
1673
1674 TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
1675 // Create a graph with a function that takes a scalar value, so that a scalar
1676 // Placeholder is used as the input to the function's shape inference.
1677 // Placeholder -> Identity -> MyFunc, where MyFunc simply takes the Identity
1678 // of the input; all tensors are scalars.
1679 FunctionDefLibrary library;
1680 *library.add_function() = FunctionDefHelper::Create(
1681 "MyFunc", // Name
1682 {"x: float"}, // Inputs
1683 {"out: float"}, // Outputs
1684 {}, // Attrs
1685 {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_FLOAT}}}}, // Nodes
1686 {{"out", "a:output:0"}}); // Returns
1687 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1688 TF_ASSERT_OK(s.graph()->AddFunctionLibrary(library));
1689 Output placeholder =
1690 ops::Placeholder(s.WithOpName("Placeholder"), DataType::DT_FLOAT,
1691 ops::Placeholder::Shape(TensorShape({})));
1692 Output identity = ops::Identity(s.WithOpName("Identity"), placeholder);
1693 auto _identity = tensorflow::ops::AsNodeOut(s, identity);
1694 auto builder =
1695 tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
1696 tensorflow::Node* func_op;
1697 TF_ASSERT_OK(builder.Input(_identity).Finalize(s.graph(), &func_op));
1698 GrapplerItem item;
1699 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1700
1701 // A TensorFlow producer version < 21 infers the output shape of a Placeholder
1702 // with an empty shape as unknown, instead of scalar.
1703 EXPECT_GT(item.graph.versions().producer(), 21);
1704
1705 // MyFunc output shouldn't be unknown rank.
1706 GraphProperties properties(item);
1707 TF_ASSERT_OK(properties.InferStatically(true));
1708 const auto out_props = properties.GetOutputProperties("MyFunc");
1709 const OpInfo::TensorProperties out_prop0 = out_props[0];
1710 EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
1711 EXPECT_FALSE(out_prop0.shape().unknown_rank());
1712 }
1713
1714 TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
1715 // Test graph produced in python using:
1716 /*
1717 @function.Defun(*[tf.float32] * 2, noinline=True)
1718 def MyAdd(x, y):
1719 return tf.add(x,y)
1720
1721 with tf.Graph().as_default():
1722 x = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
1723 y = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
1724 z = MyAdd(x, y)
1725 z = MyAdd(x, z)
1726 */
1727 GrapplerItem item;
1728 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
1729 "simple_function.pbtxt");
1730 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
1731 GraphProperties properties(item);
1732 TF_ASSERT_OK(properties.InferStatically(false));
1733 const auto out_props = properties.GetOutputProperties("MyAdd_55e046a8");
1734 const OpInfo::TensorProperties& out_prop = out_props[0];
1735 EXPECT_EQ(DT_FLOAT, out_prop.dtype());
1736 EXPECT_FALSE(out_prop.shape().unknown_rank());
1737 EXPECT_EQ(2, out_prop.shape().dim_size());
1738 EXPECT_EQ(1, out_prop.shape().dim(0).size());
1739 EXPECT_EQ(2, out_prop.shape().dim(1).size());
1740
1741 const auto in_props = properties.GetInputProperties("MyAdd_55e046a8");
1742 EXPECT_EQ(2, in_props.size());
1743
1744 const OpInfo::TensorProperties& in_prop = in_props[0];
1745 EXPECT_EQ("float: [1,2]", PropToString(in_prop));
1746
1747 const OpInfo::TensorProperties& in_prop1 = in_props[1];
1748 EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
1749 }
1750
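// Loads a larger function-call graph from testdata and checks that the input
// and output shapes of the function node y0 are all inferred as fully defined
// shapes.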
1751 TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
1752 GrapplerItem item;
1753 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
1754 "large_function_graph.pbtxt");
1755 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
1756 GraphProperties properties(item);
1757 TF_ASSERT_OK(properties.InferStatically(false));
1758
1759 const auto out_props = properties.GetOutputProperties("y0");
1760 EXPECT_EQ(2, out_props.size());
1761
1762 const OpInfo::TensorProperties& out_prop0 = out_props[0];
1763 EXPECT_EQ("float: [128,112,112,64]", PropToString(out_prop0));
1764
1765 const OpInfo::TensorProperties& out_prop1 = out_props[1];
1766 EXPECT_EQ("float: [128,112,112,24]", PropToString(out_prop1));
1767
1768 const auto in_props = properties.GetInputProperties("y0");
1769 EXPECT_EQ(4, in_props.size());
1770
1771 const OpInfo::TensorProperties& in_prop0 = in_props[0];
1772 EXPECT_EQ("float: [64]", PropToString(in_prop0));
1773
1774 const OpInfo::TensorProperties& in_prop1 = in_props[1];
1775 EXPECT_EQ("float: [1,1,24,64]", PropToString(in_prop1));
1776
1777 const OpInfo::TensorProperties& in_prop2 = in_props[2];
1778 EXPECT_EQ("float: [128,224,224,3]", PropToString(in_prop2));
1779
1780 const OpInfo::TensorProperties& in_prop3 = in_props[3];
1781 EXPECT_EQ("float: [7,7,3,8]", PropToString(in_prop3));
1782 }
1783
1784 TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
1785 // Test graph produced in python using:
1786 /*
1787 @function.Defun(noinline=True)
1788 def MyFunc():
1789 @function.Defun(*[tf.float32] * 2)
1790 def Cond(n, unused_x):
1791 return n > 0
1792
1793 @function.Defun(*[tf.float32] * 2)
1794 def Body(n, x):
1795 return n - 1, x + n
1796
1797 i = tf.constant(10)
1798 return functional_ops.While([i, 0.], Cond, Body)
1799
1800 with tf.Graph().as_default():
1801 z = MyFunc()
1802 */
1803 GrapplerItem item;
1804 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
1805 "function_functional_while.pbtxt");
1806 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
1807 GraphProperties properties(item);
1808 TF_ASSERT_OK(properties.InferStatically(false));
1809
1810 const auto out_props = properties.GetOutputProperties("MyFunc_AenMyWWx1Us");
1811 EXPECT_EQ(2, out_props.size());
1812
1813 const OpInfo::TensorProperties& out_prop0 = out_props[0];
1814 EXPECT_EQ(DT_INT32, out_prop0.dtype());
1815 EXPECT_FALSE(out_prop0.shape().unknown_rank());
1816
1817 const OpInfo::TensorProperties& out_prop1 = out_props[1];
1818 EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
1819 EXPECT_FALSE(out_prop1.shape().unknown_rank());
1820 }
1821
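// Loads a function graph from testdata whose output shape cannot be inferred;
// inference should still succeed, leaving the output rank unknown while the
// input shapes are fully inferred.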
1822 TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
1823 GrapplerItem item;
1824 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
1825 "function_error.pbtxt");
1826 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
1827 GraphProperties properties(item);
1828 TF_ASSERT_OK(properties.InferStatically(false));
1829
1830 const auto out_props = properties.GetOutputProperties("MyAdd_yabA4wXEdM4");
1831 EXPECT_EQ(1, out_props.size());
1832
1833 const OpInfo::TensorProperties& out_prop = out_props[0];
1834 EXPECT_EQ(DT_FLOAT, out_prop.dtype());
1835 EXPECT_TRUE(out_prop.shape().unknown_rank());
1836
1837 const auto in_props = properties.GetInputProperties("MyAdd_yabA4wXEdM4");
1838 EXPECT_EQ(2, in_props.size());
1839
1840 const OpInfo::TensorProperties& in_prop = in_props[0];
1841 EXPECT_EQ("float: [1,2]", PropToString(in_prop));
1842
1843 const OpInfo::TensorProperties& in_prop1 = in_props[1];
1844 EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
1845 }
1846
1847 TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
1848 // Test graph produced in python using:
1849 /*
1850 @function.Defun(*[tf.float32] * 2, noinline=True)
1851 def MyAdd(x, y):
1852 return tf.add(x, y)
1853
1854 with tf.Graph().as_default():
1855 x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
1856 y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
1857 z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
1858 z2 = MyAdd(tf.case([(tf.less(0, 1), x)], default=y), z)
1859 */
1860 GrapplerItem item;
1861 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
1862 "function_switch.pbtxt");
1863 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
1864 GraphProperties properties(item);
1865 TF_ASSERT_OK(properties.InferStatically(false));
1866 const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
1867 const OpInfo::TensorProperties& out_prop = out_props[0];
1868 EXPECT_EQ(DT_FLOAT, out_prop.dtype());
1869 EXPECT_EQ("float: [1,2]", PropToString(out_prop));
1870
1871 const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
1872 EXPECT_EQ(2, in_props.size());
1873
1874 const OpInfo::TensorProperties& in_prop = in_props[0];
1875 EXPECT_EQ("float: [1,2]", PropToString(in_prop));
1876
1877 const OpInfo::TensorProperties& in_prop1 = in_props[1];
1878 EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
1879 }
1880
1881 TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
1882 // Test graph produced in python using:
1883 /*
1884 @function.Defun(*[tf.float32] * 2, noinline=True)
1885 def MyAdd(x, y):
1886 return tf.add(x, y)
1887
1888 with tf.Graph().as_default():
1889 x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
1890 y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
1891 z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
1892 z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
1893 */
1894 GrapplerItem item;
1895 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
1896 "function_switch_2.pbtxt");
1897 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
1898 GraphProperties properties(item);
1899 TF_ASSERT_OK(properties.InferStatically(false));
1900 const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
1901 const OpInfo::TensorProperties& out_prop = out_props[0];
1902 EXPECT_EQ("float: [1,2]", PropToString(out_prop));
1903
1904 const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
1905 EXPECT_EQ(2, in_props.size());
1906
1907 const OpInfo::TensorProperties& in_prop = in_props[0];
1908 EXPECT_EQ("float: [1,2]", PropToString(in_prop));
1909
1910 const OpInfo::TensorProperties& in_prop1 = in_props[1];
1911 EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
1912 }
1913
1914 TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
1915 // Test graph produced in python using:
1916 /*
1917 @function.Defun(*[tf.float32] * 2, noinline=True)
1918 def MyAdd(x, y):
1919 a = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
1920 b = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
1921 c = tf.add(x, a)
1922 d = tf.add(y, b)
1923 return c
1924
1925 with tf.Graph().as_default():
1926 x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
1927 y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
1928 z = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
1929 z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
1930 */
1931 GrapplerItem item;
1932 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
1933 "function_switch_shapes.pbtxt");
1934 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
1935 GraphProperties properties(item);
1936 TF_ASSERT_OK(properties.InferStatically(false));
1937 const auto out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I");
1938 const OpInfo::TensorProperties& out_prop = out_props[0];
1939 EXPECT_EQ("float: [1,2]", PropToString(out_prop));
1940
1941 const auto in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I");
1942 EXPECT_EQ(2, in_props.size());
1943
1944 const OpInfo::TensorProperties& in_prop = in_props[0];
1945 EXPECT_EQ("float: [1,2]", PropToString(in_prop));
1946
1947 const OpInfo::TensorProperties& in_prop1 = in_props[1];
1948 EXPECT_EQ("float: [1,3]", PropToString(in_prop1));
1949 }
1950
1951 TEST_F(GraphPropertiesTest, SymbolicShapes) {
1952 // Build a simple graph with placeholders of unknown dimensions. These
1953 // dimensions will be encoded symbolically.
1954 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1955
1956 Output a =
1957 ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
1958 ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
1959 Output b =
1960 ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
1961 ops::Placeholder::Shape(PartialTensorShape({-1})));
1962 Output c = ops::Identity(s.WithOpName("c"), a);
1963 Output d = ops::Identity(s.WithOpName("d"), b);
1964 Output e = ops::Add(s.WithOpName("e"), c, d);
1965 Output f = ops::Add(s.WithOpName("f"), a, c);
1966
1967 Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
1968 Output g = ops::Shape(s.WithOpName("g"), c);
1969 Output h = ops::Fill(s.WithOpName("h"), g, zero);
1970 Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1});
1971 Output j = ops::Sum(s.WithOpName("j"), a, zero_idx);
1972
1973 GrapplerItem item;
1974 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
1975
1976 GraphProperties properties(item);
1977 TF_ASSERT_OK(properties.InferStatically(false));
1978 const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
1979 const auto shape_c = properties.GetOutputProperties("c").at(0).shape();
1980 EXPECT_EQ(2, shape_a.dim_size());
1981 EXPECT_EQ(shape_a.dim_size(), shape_c.dim_size());
1982 EXPECT_GE(-2, shape_a.dim(0).size());
1983 EXPECT_EQ(shape_a.dim(0).size(), shape_c.dim(0).size());
1984 EXPECT_GE(-2, shape_a.dim(1).size());
1985 EXPECT_EQ(shape_a.dim(1).size(), shape_c.dim(1).size());
1986
1987 PartialTensorShape shape(shape_a);
1988 EXPECT_FALSE(shape.IsFullyDefined());
1989 EXPECT_FALSE(shape.unknown_rank());
1990
1991 const auto shape_b = properties.GetOutputProperties("b").at(0).shape();
1992 const auto shape_d = properties.GetOutputProperties("d").at(0).shape();
1993 EXPECT_EQ(1, shape_b.dim_size());
1994 EXPECT_EQ(shape_b.dim_size(), shape_d.dim_size());
1995 EXPECT_GE(-2, shape_b.dim(0).size());
1996 EXPECT_NE(shape_a.dim(0).size(), shape_b.dim(0).size());
1997 EXPECT_EQ(shape_b.dim(0).size(), shape_d.dim(0).size());
1998
1999 const auto shape_e = properties.GetOutputProperties("e").at(0).shape();
2000 ASSERT_EQ(2, shape_e.dim_size());
2001 EXPECT_EQ(shape_e.dim(0).size(), shape_c.dim(0).size());
2002 EXPECT_NE(shape_e.dim(1).size(), shape_c.dim(1).size());
2003 EXPECT_NE(shape_e.dim(0).size(), shape_d.dim(0).size());
2004
2005 const auto shape_f = properties.GetOutputProperties("f").at(0).shape();
2006 ASSERT_EQ(2, shape_f.dim_size());
2007 EXPECT_EQ(shape_f.dim(0).size(), shape_a.dim(0).size());
2008 EXPECT_EQ(shape_f.dim(1).size(), shape_a.dim(1).size());
2009
2010 const auto shape_h = properties.GetOutputProperties("h").at(0).shape();
2011 ASSERT_EQ(2, shape_h.dim_size());
2012 EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size());
2013 EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size());
2014
2015 const auto shape_j = properties.GetOutputProperties("j").at(0).shape();
2016 ASSERT_EQ(1, shape_j.dim_size());
2017 EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size());
2018 }
2019
2020 TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
2021 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2022 Output a = ops::Const(s.WithOpName("a"), 1.0f, {1});
2023 Output b = ops::Const(s.WithOpName("b"), 2.0f, {1});
2024 Output c = ops::Const(s.WithOpName("c").ColocateWith(a), 3.0f, {1});
2025 GrapplerItem item;
2026 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
2027 // Create a graph with node a removed (say by some graph optimization
2028 // pass), noting that node c is colocated with a. This is fine because, at
2029 // this late stage of graph execution, the colocation constraints have
2030 // already been validated and the device placement of nodes has completed.
2031 GraphDef optimized_graph;
2032 for (const auto& node : item.graph.node()) {
2033 if (node.name() != "a") {
2034 *optimized_graph.add_node() = node;
2035 }
2036 }
2037 item.graph.Swap(&optimized_graph);
2038 GraphProperties properties(item);
2039 // This function should return OK, since it doesn't validate the colocation
2040 // constraints internally.
2041 TF_EXPECT_OK(properties.InferStatically(false));
2042 }
2043
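// Shapes computed by ShapeN and consumed by Fill should be tracked, so o1 and
// o2 end up with the same (symbolic) shapes as placeholders a and b.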
2044 TEST_F(GraphPropertiesTest, ShapeTracking) {
2045 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2046 Output a =
2047 ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
2048 ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
2049 Output b =
2050 ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
2051 ops::Placeholder::Shape(PartialTensorShape({-1})));
2052 Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
2053 auto shp = ops::ShapeN(s.WithOpName("shapes"), {a, b});
2054 Output o1 = ops::Fill(s.WithOpName("o1"), shp[0], zero);
2055 Output o2 = ops::Fill(s.WithOpName("o2"), shp[1], zero);
2056
2057 GrapplerItem item;
2058 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
2059
2060 GraphProperties properties(item);
2061 TF_ASSERT_OK(properties.InferStatically(false));
2062 const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
2063 const auto shape_b = properties.GetOutputProperties("b").at(0).shape();
2064 const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape();
2065 const auto shape_o2 = properties.GetOutputProperties("o2").at(0).shape();
2066 EXPECT_EQ(shape_a.DebugString(), shape_o1.DebugString());
2067 EXPECT_EQ(shape_b.DebugString(), shape_o2.DebugString());
2068 }
2069
2070 TEST_F(GraphPropertiesTest, FedNodes) {
2071 TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
2072 cluster_->GetDeviceNames());
2073 GrapplerItem item;
2074 CHECK(fake_input.NextItem(&item));
2075
2076 {
2077 // Conservative shape analysis: the shape of fed ports should be unknown
2078 GraphProperties properties(item);
2079 Status s = properties.InferStatically(false);
2080 TF_ASSERT_OK(s);
2081 for (const auto& node : item.graph.node()) {
2082 if (node.op() == "Const") {
2083 continue;
2084 }
2085 const auto in_props = properties.GetInputProperties(node.name());
2086 EXPECT_EQ(1, in_props.size());
2087 const OpInfo::TensorProperties& in_prop = in_props[0];
2088 const auto out_props = properties.GetOutputProperties(node.name());
2089 EXPECT_EQ(1, out_props.size());
2090 const OpInfo::TensorProperties& out_prop = out_props[0];
2091
2092 if (node.name() == "x") {
2093 // x is fed: its input should have a known shape, while its output
2094 // doesn't
2095 EXPECT_FALSE(in_prop.shape().unknown_rank());
2096 EXPECT_EQ(1, in_prop.shape().dim_size());
2097 EXPECT_EQ(2, in_prop.shape().dim(0).size());
2098 EXPECT_TRUE(out_prop.shape().unknown_rank());
2099 } else if (node.op() == "Square" || node.op() == "AddN") {
2100 // These nodes are in the fanout of x: their shapes should be unknown.
2101 EXPECT_TRUE(in_prop.shape().unknown_rank());
2102 EXPECT_TRUE(out_prop.shape().unknown_rank());
2103 }
2104 }
2105 }
2106 {
2107 // Optimistic shape analysis: the shape of fed ports should be derived from
2108 // the shape of the fanin.
2109 GraphProperties properties(item);
2110 Status s = properties.InferStatically(true);
2111 TF_ASSERT_OK(s);
2112 for (const auto& node : item.graph.node()) {
2113 if (node.op() == "Square" || node.op() == "AddN") {
2114 const auto in_props = properties.GetInputProperties(node.name());
2115 EXPECT_EQ(1, in_props.size());
2116 const OpInfo::TensorProperties& in_prop = in_props[0];
2117 EXPECT_EQ(DT_FLOAT, in_prop.dtype());
2118 EXPECT_FALSE(in_prop.shape().unknown_rank());
2119 EXPECT_EQ(2, in_prop.shape().dim_size());
2120 const auto out_props = properties.GetOutputProperties(node.name());
2121 EXPECT_EQ(1, out_props.size());
2122 const OpInfo::TensorProperties& out_prop = out_props[0];
2123 EXPECT_EQ(in_prop.dtype(), out_prop.dtype());
2124 EXPECT_EQ(in_prop.shape().DebugString(),
2125 out_prop.shape().DebugString());
2126 }
2127 }
2128 }
2129 }
2130
2131 TEST_F(GraphPropertiesTest, Performance) {
2132 // Load a large graph with many nested loops to make sure we can infer shapes
2133 // quickly.
2134 GrapplerItem item;
2135 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
2136 "large_graph.pbtxt.html");
2137 TF_ASSERT_OK(ReadGraphDefFromFile(filename, &item.graph));
2138 TF_ASSERT_OK(AddDefaultAttrsToGraphDef(
2139 &item.graph,
2140 FunctionLibraryDefinition(OpRegistry::Global(), item.graph.library()), 0,
2141 true));
2142
2143 GraphProperties properties(item);
2144 TF_ASSERT_OK(properties.InferStatically(false));
2145 }
2146
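// Strided slices of a tensor's Shape fed into Fill should propagate the
// corresponding symbolic dimensions: o1 picks up a's first dim and o2 picks up
// a's second dim.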
2147 TEST_F(GraphPropertiesTest, StridedSlicesOfShapes) {
2148 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2149 Output a =
2150 ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
2151 ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
2152 auto shp = ops::Shape(s.WithOpName("shape"), {a});
2153
2154 Output index1 = ops::Const(s.WithOpName("index1"), 0, {1});
2155 Output index2 = ops::Const(s.WithOpName("index2"), 1, {1});
2156 Output index3 = ops::Const(s.WithOpName("index3"), 2, {1});
2157
2158 Output b = ops::StridedSlice(s.WithOpName("b"), shp, index1, index2, index2);
2159 Output c = ops::StridedSlice(s.WithOpName("c"), shp, index2, index3, index2);
2160
2161 Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
2162 Output o1 = ops::Fill(s.WithOpName("o1"), b, zero);
2163 Output o2 = ops::Fill(s.WithOpName("o2"), c, zero);
2164
2165 GrapplerItem item;
2166 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
2167
2168 GraphProperties properties(item);
2169 TF_ASSERT_OK(properties.InferStatically(false));
2170 const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
2171 const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape();
2172 const auto shape_o2 = properties.GetOutputProperties("o2").at(0).shape();
2173 EXPECT_EQ(2, shape_a.dim_size());
2174 EXPECT_EQ(1, shape_o1.dim_size());
2175 EXPECT_EQ(1, shape_o2.dim_size());
2176 EXPECT_EQ(shape_a.dim(0).size(), shape_o1.dim(0).size());
2177 EXPECT_EQ(shape_a.dim(1).size(), shape_o2.dim(0).size());
2178 }
2179
2180 TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
2181 tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2182 Output placeholder =
2183 ops::Placeholder(scope.WithOpName("input_placeholder"), DT_FLOAT,
2184 ops::Placeholder::Shape(TensorShape({5, 480, 40, 1})));
2185 auto input_shape = ops::Shape(scope.WithOpName("input_shape"), placeholder);
2186
2187 Output begin = ops::Const(scope.WithOpName("begin"), {0}, {1});
2188 Output end = ops::Const(scope.WithOpName("end"), {3}, {1});
2189 Output stride = ops::Const(scope.WithOpName("stride"), {1}, {1});
2190
2191 Output slice =
2192 ops::StridedSlice(scope.WithOpName("slice"), input_shape, begin, end,
2193 stride, ops::StridedSlice::ShrinkAxisMask(1));
2194
2195 GrapplerItem item;
2196 TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2197
2198 // Without aggressive shape inference, it cannot infer output value of
2199 // StridedSlice with ShrinkAxisMask.
2200 {
2201 GraphProperties properties(item);
2202 TF_ASSERT_OK(properties.InferStatically(
2203 /*assume_valid_feeds=*/false,
2204 /*aggressive_shape_inference=*/false,
2205 /*include_tensor_values=*/true));
2206 EXPECT_FALSE(properties.GetOutputProperties("slice").at(0).has_value());
2207 }
2208
2209 // InferStatically with aggressive shape inference can infer output value of
2210 // StridedSlice with ShrinkAxisMask.
2211 {
2212 GraphProperties properties(item);
2213 TF_ASSERT_OK(properties.InferStatically(
2214 /*assume_valid_feeds=*/false,
2215 /*aggressive_shape_inference=*/true,
2216 /*include_tensor_values=*/true));
2217 EXPECT_TRUE(properties.GetOutputProperties("slice").at(0).has_value());
2218 const auto slice_value =
2219 properties.GetOutputProperties("slice").at(0).value();
2220 ExpectTensorValues({5}, slice_value);
2221 }
2222 }
2223
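// Constant tensor values should propagate through a chain of OnesLike/Add ops
// when aggressive shape inference evaluates node outputs.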
2224 TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
2225 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2226 Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
2227 Output b = ops::Const(s.WithOpName("b"), {8, 8}, {2});
2228 Output c = ops::Const(s.WithOpName("c"), {2, 2}, {2});
2229
2230 Output a1 = ops::OnesLike(s.WithOpName("a1"), a);
2231 Output a_plus_one = ops::Add(s.WithOpName("a_plus_one"), a, a1);
2232 Output a_plus_a = ops::Add(s.WithOpName("a_plus_a"), a, a);
2233 Output b_plus_2a = ops::Add(s.WithOpName("b_plus_2a"), b, a_plus_a);
2234 Output c_plus_b_plus_2a =
2235 ops::Add(s.WithOpName("c_plus_b_plus_2a"), c, b_plus_2a);
2236
2237 GrapplerItem item;
2238 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
2239 GraphProperties properties(item);
2240 TF_ASSERT_OK(properties.InferStatically(
2241 /*assume_valid_feeds=*/false,
2242 /*aggressive_shape_inference=*/true,
2243 /*include_tensor_values=*/true));
2244
2245 // Check output shapes and values.
2246 const auto& a_plus_one_prop = properties.GetOutputProperties("a_plus_one")[0];
2247 EXPECT_EQ("int32: [2]", PropToString(a_plus_one_prop));
2248 EXPECT_TRUE(a_plus_one_prop.has_value());
2249 ExpectTensorValues({6, 8}, a_plus_one_prop.value());
2250
2251 const auto& a_plus_a_prop = properties.GetOutputProperties("a_plus_a")[0];
2252 EXPECT_EQ("int32: [2]", PropToString(a_plus_a_prop));
2253 EXPECT_TRUE(a_plus_a_prop.has_value());
2254 ExpectTensorValues({10, 14}, a_plus_a_prop.value());
2255
2256 const auto& b_plus_2a_prop = properties.GetOutputProperties("b_plus_2a")[0];
2257 EXPECT_EQ("int32: [2]", PropToString(b_plus_2a_prop));
2258 EXPECT_TRUE(b_plus_2a_prop.has_value());
2259 ExpectTensorValues({18, 22}, b_plus_2a_prop.value());
2260
2261 const auto& c_plus_b_plus_2a_prop =
2262 properties.GetOutputProperties("c_plus_b_plus_2a")[0];
2263 EXPECT_EQ("int32: [2]", PropToString(c_plus_b_plus_2a_prop));
2264 EXPECT_TRUE(c_plus_b_plus_2a_prop.has_value());
2265 ExpectTensorValues({20, 24}, c_plus_b_plus_2a_prop.value());
2266 }
2267
2268 TEST_F(GraphPropertiesTest, ShapeAnnotation) {
2269 GrapplerItem item;
2270 TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
2271 .Attr("dtype", DT_FLOAT)
2272 .Attr("shape", PartialTensorShape({-1, -1}))
2273 .Finalize(item.graph.add_node()));
2274 // Annotate shapes.
2275 TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
2276 .Attr("dtype", DT_FLOAT)
2277 .Attr("_same_output_for_iterations", true)
2278 .Attr("_output_shape_vector", {TensorShape({5, 7})})
2279 .Input("Input", 0, DT_FLOAT)
2280 .Finalize(item.graph.add_node()));
2281 {
2282 GraphProperties properties(item);
2283 // Without aggressive_shape_inference, ignore annotated information.
2284 TF_ASSERT_OK(properties.InferStatically(
2285 /*assume_valid_feeds=*/false,
2286 /*aggressive_shape_inference=*/false,
2287 /*include_tensor_values=*/true));
2288 const auto props = properties.GetOutputProperties("Identity");
2289 EXPECT_EQ(1, props.size());
2290 const OpInfo::TensorProperties& prop = props[0];
2291 EXPECT_EQ(DT_FLOAT, prop.dtype());
2292 EXPECT_EQ(2, prop.shape().dim_size());
2293 // Get unknown shapes without using annotated information.
2294 EXPECT_EQ("float: [-1,-1]", PropToString(prop));
2295 }
2296 {
2297 GraphProperties properties(item);
2298 // Use annotated information.
2299 TF_ASSERT_OK(properties.InferStatically(
2300 /*assume_valid_feeds=*/false,
2301 /*aggressive_shape_inference=*/true,
2302 /*include_tensor_values=*/true));
2303 const auto props = properties.GetOutputProperties("Identity");
2304 EXPECT_EQ(1, props.size());
2305 const OpInfo::TensorProperties& prop = props[0];
2306 EXPECT_EQ(DT_FLOAT, prop.dtype());
2307 EXPECT_EQ(2, prop.shape().dim_size());
2308 // Update output shape using annotated shapes.
2309 EXPECT_EQ("float: [5,7]", PropToString(prop));
2310 }
2311 }
2312
2313 TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
2314 GrapplerItem item;
2315 TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
2316 .Attr("dtype", DT_FLOAT)
2317 .Attr("shape", PartialTensorShape({-1, 100}))
2318 .Finalize(item.graph.add_node()));
2319 // Annotate shapes.
2320 TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
2321 .Attr("dtype", DT_FLOAT)
2322 .Attr("_same_output_for_iterations", true)
2323 .Attr("_output_shape_vector", {TensorShape({10, 100})})
2324 .Input("Input", 0, DT_FLOAT)
2325 .Finalize(item.graph.add_node()));
2326 GraphProperties properties(item);
2327 // Use annotated information.
2328 TF_ASSERT_OK(properties.InferStatically(
2329 /*assume_valid_feeds=*/false,
2330 /*aggressive_shape_inference=*/true,
2331 /*include_tensor_values=*/true));
2332 const auto props = properties.GetOutputProperties("Identity");
2333 EXPECT_EQ(1, props.size());
2334 const OpInfo::TensorProperties& prop = props[0];
2335 EXPECT_EQ(DT_FLOAT, prop.dtype());
2336 EXPECT_EQ(2, prop.shape().dim_size());
2337 // Compatible shapes. Update output shape using annotated shapes.
2338 EXPECT_EQ("float: [10,100]", PropToString(prop));
2339 }
2340
2341 TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
2342 GrapplerItem item;
2343 TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
2344 .Attr("dtype", DT_FLOAT)
2345 .Attr("shape", PartialTensorShape({-1, 100}))
2346 .Finalize(item.graph.add_node()));
2347 // Annotate shapes.
2348 TF_ASSERT_OK(NodeDefBuilder("Identity", "Identity")
2349 .Attr("dtype", DT_FLOAT)
2350 .Attr("_same_output_for_iterations", true)
2351 .Attr("_output_shape_vector", {TensorShape({10, 10})})
2352 .Input("Input", 0, DT_FLOAT)
2353 .Finalize(item.graph.add_node()));
2354 GraphProperties properties(item);
2355 // Use annotated information.
2356 TF_ASSERT_OK(properties.InferStatically(
2357 /*assume_valid_feeds=*/false,
2358 /*aggressive_shape_inference=*/true,
2359 /*include_tensor_values=*/true));
2360 const auto props = properties.GetOutputProperties("Identity");
2361 EXPECT_EQ(1, props.size());
2362 const OpInfo::TensorProperties& prop = props[0];
2363 EXPECT_EQ(DT_FLOAT, prop.dtype());
2364 EXPECT_EQ(2, prop.shape().dim_size());
2365 // Incompatible shapes. Do not use annotated shapes.
2366 EXPECT_EQ("float: [-1,100]", PropToString(prop));
2367 }
2368
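// An op with no registered shape inference function should still pick up its
// annotated output shape when aggressive shape inference is enabled.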
2369 TEST_F(GraphPropertiesTest, ShapeAnnotationWithoutInferenceFn) {
2370 GrapplerItem item;
2371 TF_ASSERT_OK(NodeDefBuilder("Input", "Placeholder")
2372 .Attr("dtype", DT_FLOAT)
2373 .Attr("shape", PartialTensorShape({-1, -1}))
2374 .Finalize(item.graph.add_node()));
2375 // Annotate shapes.
2376 TF_ASSERT_OK(
2377 NodeDefBuilder("TestOpWithNoInferenceFn", "TestOpWithNoInferenceFn")
2378 .Attr("_same_output_for_iterations", true)
2379 .Attr("_output_shape_vector", {TensorShape({10, 100})})
2380 .Input("Input", 0, DT_FLOAT)
2381 .Finalize(item.graph.add_node()));
2382 GraphProperties properties(item);
2383 // Use annotated information.
2384 TF_ASSERT_OK(properties.InferStatically(
2385 /*assume_valid_feeds=*/false,
2386 /*aggressive_shape_inference=*/true,
2387 /*include_tensor_values=*/true));
2388 const auto props = properties.GetOutputProperties("TestOpWithNoInferenceFn");
2389 EXPECT_EQ(1, props.size());
2390 const OpInfo::TensorProperties& prop = props[0];
2391 EXPECT_EQ(DT_FLOAT, prop.dtype());
2392 EXPECT_EQ(2, prop.shape().dim_size());
2393 EXPECT_EQ("float: [10,100]", PropToString(prop));
2394 }
2395
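// A PartitionedCall that invokes an identity function should propagate the
// shape of its Const input ([4]) to the call's output.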
2396 TEST_F(GraphPropertiesTest, PartitionedCallOp) {
2397 Scope root = Scope::NewRootScope().ExitOnError();
2398 FunctionDefLibrary library;
2399 FunctionDef called_func = FunctionDefHelper::Create(
2400 "identity_function",
2401 /*in_def=*/{"arg0: int32"},
2402 /*out_def=*/{"ret0: int32"},
2403 /*attr_def=*/{},
2404 {{{"Identity"}, "Identity", {"arg0"}, {{"T", DT_INT32}}}},
2405 /*ret_def=*/{{"ret0", "Identity:output:0"}});
2406 *library.add_function() = called_func;
2407 TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
2408
2409 Output in = ops::Const(root, {3, 1, 2, 0});
2410 NameAttrList b_name_attr;
2411 b_name_attr.set_name("identity_function");
2412 ops::PartitionedCall call(root.WithOpName("identity_call"), {in}, {DT_INT32},
2413 b_name_attr);
2414
2415 GrapplerItem item;
2416 TF_ASSERT_OK(root.ToGraphDef(&item.graph));
2417
2418 GraphProperties properties(item);
2419 TF_ASSERT_OK(properties.InferStatically(
2420 /*assume_valid_feeds=*/true,
2421 /*aggressive_shape_inference=*/false,
2422 /*include_tensor_values=*/true));
2423
2424 EXPECT_EQ("int32: [4]",
2425 PropToString(properties.GetOutputProperties("identity_call")[0]));
2426 }
2427
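// A PartitionedCall whose function adds two placeholders of partially known
// shape [2,2,-1] should produce an output with the same partial shape.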
2428 TEST_F(GraphPropertiesTest, NonTrivialInputPartitionedCallOp) {
2429 auto f = FunctionDefHelper::Create(
2430 // Name
2431 "FunctionWhichAdds",
2432 // Inputs
2433 {"arg0: int32", "arg1: int32"},
2434 // Outputs
2435 {"ret0: int32"},
2436 /*attr_def=*/{},
2437 // Nodes
2438 {{{"a"}, "Add", {"arg0", "arg1"}, {{"T", DT_INT32}}}},
2439 /*ret_def=*/{{"ret0", "a:z:0"}});
2440
2441 FunctionDefLibrary function_lib;
2442 function_lib.add_function()->Swap(&f);
2443 tensorflow::Scope root = tensorflow::Scope::NewRootScope();
2444 TF_ASSERT_OK(root.graph()->AddFunctionLibrary(function_lib));
2445
2446 PartialTensorShape input_shape({2, 2, -1});
2447 Output in1 =
2448 ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
2449 Output in2 =
2450 ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
2451 NameAttrList b_name_attr;
2452 b_name_attr.set_name("FunctionWhichAdds");
2453 ops::PartitionedCall call(root.WithOpName("add_call"), {in1, in2}, {DT_INT32},
2454 b_name_attr);
2455
2456 GrapplerItem item;
2457 TF_ASSERT_OK(root.ToGraphDef(&item.graph));
2458
2459 GraphProperties properties(item);
2460 TF_ASSERT_OK(properties.InferStatically(
2461 /*assume_valid_feeds=*/true,
2462 /*aggressive_shape_inference=*/false,
2463 /*include_tensor_values=*/true));
2464
2465 EXPECT_EQ("int32: [2,2,-1]",
2466 PropToString(properties.GetOutputProperties("add_call")[0]));
2467 }
2468
2469 TEST_F(GraphPropertiesTest, ShapeAnnotatedFunctionOp) {
2470 // A function whose output shape cannot be inferred statically.
2471 auto f = FunctionDefHelper::Create(
2472 // Name
2473 "FuncShapeCannotBeInferred",
2474 // Inputs
2475 {},
2476 // Outputs
2477 {"output: float"},
2478 // Attrs
2479 {},
2480 // Nodes
2481 {
2482 // Placeholder without shape attr; unknown rank.
2483 {{"p"}, "Placeholder", {}, {{"dtype", DataType::DT_FLOAT}}},
2484 },
2485 // Returns
2486 {{"output", "p:output:0"}});
2487 FunctionDefLibrary function_lib;
2488 function_lib.add_function()->Swap(&f);
2489 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2490 TF_ASSERT_OK(s.graph()->AddFunctionLibrary(function_lib));
2491 tensorflow::Node* func_op;
2492 TensorShapeProto output_shape;
2493 output_shape.set_unknown_rank(false);
2494 output_shape.add_dim()->set_size(1);
2495 output_shape.add_dim()->set_size(2);
2496 output_shape.add_dim()->set_size(3);
2497 output_shape.add_dim()->set_size(4);
2498 // The function node, f, includes shape annotation.
2499 TF_ASSERT_OK(tensorflow::NodeBuilder("f", "FuncShapeCannotBeInferred",
2500 s.graph()->op_registry())
2501 .Attr("_execution_count", 1)
2502 .Attr("_same_output_for_iterations", true)
2503 .Attr("_output_dtype_vector", {DataType::DT_FLOAT})
2504 .Attr("_output_shape_vector", {output_shape})
2505 .Finalize(s.graph(), &func_op));
2506 GrapplerItem item;
2507 TF_ASSERT_OK(s.ToGraphDef(&item.graph));
2508
2509 // InferStatically without aggressive_shape_inference fails to infer
2510 // the output shape of the node f.
2511 {
2512 GraphProperties properties(item);
2513 TF_ASSERT_OK(properties.InferStatically(
2514 /*assume_valid_feeds=*/false,
2515 /*aggressive_shape_inference=*/false,
2516 /*include_tensor_values=*/false));
2517 const auto out_props = properties.GetOutputProperties("f");
2518 const OpInfo::TensorProperties out_prop0 = out_props[0];
2519 EXPECT_EQ("float: ?", PropToString(out_prop0));
2520 }
2521 // With aggressive_shape_inference, it skips recursively calling
2522 // InferStatically for the function node and outputs the annotated shape info.
2523 {
2524 GraphProperties properties(item);
2525 TF_ASSERT_OK(properties.InferStatically(
2526 /*assume_valid_feeds=*/false,
2527 /*aggressive_shape_inference=*/true,
2528 /*include_tensor_values=*/true));
2529 const auto out_props = properties.GetOutputProperties("f");
2530 const OpInfo::TensorProperties out_prop0 = out_props[0];
2531 EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
2532 }
2533 }
2534
2535 TEST_F(GraphPropertiesTest,
2536 SymbolicShapeInferenceWithReshapeOpsSharingShapeVector) {
2537 GrapplerItem item;
2538 // This graph creates a shape vector [-1, 10] from Concat(Const, Const),
2539 // which is used by two Reshape ops. One Reshape op produces the segment_ids
2540 // input to the UnsortedSegmentSum op, which applies MergePrefix in its shape
2541 // function. segment_ids has shape [-1, 10] (from the Reshape), but
2542 // MergePrefix with the data input ([10, 10, 10, 10]) turns the -1 (unknown
2543 // dim) into 10 via SymbolicShapeRefiner.
2544 // This dim value (10), however, should not affect the other Reshape op, even
2545 // though it shares the shape input; -1 in the shape input of a Reshape op is
2546 // a special case of a computed output dim, not an unknown dim.
2547 // data and num_segments are inputs to UnsortedSegmentSum.
2548
2549 TF_ASSERT_OK(NodeDefBuilder("data", "Placeholder")
2550 .Attr("dtype", DT_FLOAT)
2551 .Attr("shape", TensorShape({10, 10, 10, 10}))
2552 .Finalize(item.graph.add_node()));
2553 Tensor num_segments(DT_INT32, TensorShape({}));
2554 // Build the segment_ids input to UnsortedSegmentSum from Const, ConcatV2,
2555 // and Reshape ops. tensors_as_shape from the Const ops are propagated to the
2556 // ConcatV2 output to form the shape vector [-1, 10] fed to Reshape.
2557 test::FillIota<int>(&num_segments, 3);
2558 TF_ASSERT_OK(NodeDefBuilder("num_segments", "Const")
2559 .Attr("dtype", DT_INT32)
2560 .Attr("value", num_segments)
2561 .Finalize(item.graph.add_node()));
2562 Tensor minus_one(DT_INT32, TensorShape({1}));
2563 test::FillIota<int>(&minus_one, -1);
2564 TF_ASSERT_OK(NodeDefBuilder("minus_one", "Const")
2565 .Attr("dtype", DT_INT32)
2566 .Attr("value", minus_one)
2567 .Finalize(item.graph.add_node()));
2568 Tensor plus_ten(DT_INT32, TensorShape({1}));
2569 test::FillIota<int>(&plus_ten, 10);
2570 TF_ASSERT_OK(NodeDefBuilder("plus_ten", "Const")
2571 .Attr("dtype", DT_INT32)
2572 .Attr("value", plus_ten)
2573 .Finalize(item.graph.add_node()));
2574 Tensor axis(DT_INT32, TensorShape({}));
2575 test::FillIota<int>(&axis, -1);
2576 TF_ASSERT_OK(NodeDefBuilder("axis", "Const")
2577 .Attr("dtype", DT_INT32)
2578 .Attr("value", axis)
2579 .Finalize(item.graph.add_node()));
2580 std::vector<NodeDefBuilder::NodeOut> inputs(2);
2581 inputs[0] = NodeDefBuilder::NodeOut{"minus_one", 0, DT_INT32};
2582 inputs[1] = NodeDefBuilder::NodeOut{"plus_ten", 0, DT_INT32};
2583 TF_ASSERT_OK(NodeDefBuilder("concat", "ConcatV2")
2584 .Input(inputs)
2585 .Input("axis", 0, DT_INT32)
2586 .Attr("N", 2)
2587 .Attr("T", DT_INT32)
2588 .Attr("Tidx", DT_INT32)
2589 .Finalize(item.graph.add_node()));
2590 TF_ASSERT_OK(NodeDefBuilder("segment_ids_", "Placeholder")
2591 .Attr("dtype", DT_FLOAT)
2592 .Finalize(item.graph.add_node()));
2593 TF_ASSERT_OK(NodeDefBuilder("segment_ids_shape_before_reshape", "Shape")
2594 .Input("segment_ids_", 0, DT_FLOAT)
2595 .Attr("T", DT_FLOAT)
2596 .Attr("out_type", DT_INT32)
2597 .Finalize(item.graph.add_node()));
2598 TF_ASSERT_OK(NodeDefBuilder("segment_ids", "Reshape")
2599 .Input("segment_ids_", 0, DT_FLOAT)
2600 .Input("concat", 0, DT_INT32)
2601 .Attr("T", DT_FLOAT)
2602 .Attr("Tshape", DT_INT32)
2603 .Finalize(item.graph.add_node()));
2604 // Shape function of UnsortedSegmentSum applies MergePrefix to data and
2605 // segment_ids (the latter being prefix). data shape is [10,10,10,10] and
2606 // segment_ids shape is [-1, 10], but MergePrefix and symbolic shape inference
2607 // assign 10 from data shape to the unknown dim in segment_ids.
2608 TF_ASSERT_OK(NodeDefBuilder("y", "UnsortedSegmentSum")
2609 .Input("data", 0, DT_FLOAT)
2610 .Input("segment_ids", 0, DT_INT32)
2611 .Input("num_segments", 0, DT_INT32)
2612 .Attr("T", DT_FLOAT)
2613 .Attr("Tindices", DT_INT32)
2614 .Attr("Tnumsegments", DT_INT32)
2615 .Finalize(item.graph.add_node()));
2616 // Note that y1=Reshape(x1) uses the same shape vector as segment_ids, but
2617 // y1's shape shouldn't be affected by symbolic shape inference on segment_ids.
2618 TF_ASSERT_OK(NodeDefBuilder("x1", "Placeholder")
2619 .Attr("dtype", DT_FLOAT)
2620 .Finalize(item.graph.add_node()));
2621 TF_ASSERT_OK(NodeDefBuilder("y1", "Reshape")
2622 .Input("x1", 0, DT_FLOAT)
2623 .Input("concat", 0, DT_INT32)
2624 .Attr("T", DT_FLOAT)
2625 .Attr("Tshape", DT_INT32)
2626 .Finalize(item.graph.add_node()));
2627
2628 GraphProperties properties(item);
2629 TF_ASSERT_OK(properties.InferStatically(true));
2630 const auto& y1_output_properties = properties.GetOutputProperties("y1");
2631 // y1=Reshape(x1), but x1's shape is unknown, so y1 should be [-1, 10].
2632 // The first dimension should not be 10.
2633 EXPECT_EQ(y1_output_properties.size(), 1);
2634 EXPECT_EQ(y1_output_properties[0].shape().dim_size(), 2);
2635 EXPECT_LT(y1_output_properties[0].shape().dim(0).size(), 0);
2636 EXPECT_EQ(y1_output_properties[0].shape().dim(1).size(), 10);
2637 }
2638
2639 TEST(HelperFunctions, IsShapeFullyDefinedIntegerVectorOrScalar) {
2640 // Make a dummy InferenceContext.
2641 NodeDef node_def;
2642 OpRegistrationData op_reg_data;
2643 OpDefBuilder b("dummy");
2644 CHECK(b.Finalize(&op_reg_data).ok());
2645 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
2646 input_handle_shapes_and_types;
2647 InferenceContext ic(/*graph_def_version=*/0, node_def, op_reg_data.op_def,
2648 /*input_shapes=*/{},
2649 /*input_tensors=*/{},
2650 /*input_tensors_as_shapes=*/{},
2651 std::move(input_handle_shapes_and_types));
2652
2653 // ShapeHandles for testing.
2654 ShapeHandle fully_defined_vector = ic.MakeShape(
2655 {ic.MakeDim(4), ic.MakeDim(5), ic.MakeDim(6), ic.MakeDim(7)});
2656 ShapeHandle vector_with_unknown = ic.MakeShape(
2657 {ic.MakeDim(4), ic.MakeDim(5), ic.UnknownDim(), ic.MakeDim(7)});
2658 // INT64_MAX is used as unknown from Const. See kUnknownFromConst const in
2659 // graph_properties.cc
2660 ShapeHandle vector_with_unknown_from_const = ic.MakeShape(
2661 {ic.MakeDim(4), ic.MakeDim(INT64_MAX), ic.MakeDim(6), ic.MakeDim(7)});
2662 ShapeHandle rank_1_vector = ic.MakeShape({ic.MakeDim(4)});
2663
2664 // Rank-1 shape and fully defined tensor_as_shape with INT32 or INT64.
2665 EXPECT_TRUE(IsShapeFullyDefinedIntegerVectorOrScalar(
2666 &ic, rank_1_vector, fully_defined_vector, DT_INT32));
2667 EXPECT_TRUE(IsShapeFullyDefinedIntegerVectorOrScalar(
2668 &ic, rank_1_vector, fully_defined_vector, DT_INT64));
2669
2670 // Non-integer data type.
2671 EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
2672 &ic, rank_1_vector, fully_defined_vector, DT_FLOAT));
2673
2674 // tensor_as_shape including Unknown or UnknownFromConst.
2675 EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
2676 &ic, rank_1_vector, vector_with_unknown, DT_INT32));
2677 EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
2678 &ic, rank_1_vector, vector_with_unknown_from_const, DT_INT32));
2679 EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
2680 &ic, rank_1_vector, ic.UnknownShape(), DT_INT32));
2681 EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
2682 &ic, ic.UnknownShape(), fully_defined_vector, DT_INT32));
2683
2684 // shape rank > 1.
2685 EXPECT_FALSE(IsShapeFullyDefinedIntegerVectorOrScalar(
2686 &ic, fully_defined_vector, vector_with_unknown_from_const, DT_INT32));
2687 }
2688 } // namespace
2689 } // namespace grappler
2690 } // namespace tensorflow
2691