1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
17
18 #include "tensorflow/core/framework/function_testlib.h"
19 #include "tensorflow/core/framework/graph.pb.h"
20 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 #include "tensorflow/core/platform/test.h"
23 #include "tensorflow/tools/graph_transforms/transform_utils.h"
24
25 namespace tensorflow {
26 namespace grappler {
27 namespace function_utils {
28 namespace {
29
// Verifies how FunctionDefTensorDesc splits a tensor reference string into
// its node name, output name and output position.
TEST(FunctionDefTensorDesc, Parsing) {
  // Fully qualified "node:output:position" reference.
  FunctionDefTensorDesc qualified("Cast:y:0");
  EXPECT_EQ(qualified.full_str, "Cast:y:0");
  EXPECT_EQ(qualified.node_name, "Cast");
  EXPECT_EQ(qualified.node_output, "y");
  EXPECT_EQ(qualified.position, 0);

  // A bare name has an empty output name and position -1.
  FunctionDefTensorDesc bare("Arg0");
  EXPECT_EQ(bare.full_str, "Arg0");
  EXPECT_EQ(bare.node_name, "Arg0");
  EXPECT_EQ(bare.node_output, "");
  EXPECT_EQ(bare.position, -1);
}
43
// Verifies that ReplaceReferences rewrites every occurrence of one tensor
// reference (node inputs and function retvals) and leaves others alone.
TEST(ReplaceReferencesTest, ReplaceReferencesTest) {
  FunctionDef outer = FunctionDefHelper::Create(
      "outer", {"arg0: int32"}, {"out: int32", "out2: int64"}, {}, {},
      {{"out", "MapDefun:output:0"}, {"out2", "Cast:y:0"}});
  NodeDef* x_node =
      AddNode("X", "Some_Op", {"MapDefun:output:0"}, {}, &outer);

  // Both the input to "X" and the "out" retval of "outer" refer to
  // "MapDefun:output:0" and must be replaced.
  ReplaceReferences("MapDefun:output:0", "arg0", &outer);
  EXPECT_EQ(outer.ret().at("out"), "arg0");
  EXPECT_EQ(x_node->input(0), "arg0");

  // References to other tensors must not be touched, and no inputs may be
  // added or dropped by the rewrite.
  EXPECT_EQ(outer.ret().at("out2"), "Cast:y:0");
  EXPECT_EQ(x_node->input_size(), 1);
}
55
// Verifies that requesting an output name that already exists yields a
// de-duplicated name wired to the requested tensor.
TEST(FunctionUtilsTest, AddFunctionOutputWithUniqueName) {
  FunctionDef fdef = test::function::XTimesTwo();
  // XTimesTwo already exposes an output called "y".
  AddFunctionOutputWithUniqueName("y", "two", &fdef, DT_INT64);

  const string deduped_name = "y/_1";
  EXPECT_TRUE(ContainsFunctionOutputWithName(deduped_name, fdef));
  EXPECT_EQ(fdef.ret().at(deduped_name), "two");
}
62
// Verifies that AddFunctionInput appends input args in order and returns
// pointers to the args stored in the function signature.
TEST(FunctionUtilsTest, AddFunctionInput) {
  FunctionDef fdef;
  auto* first = AddFunctionInput("arg0", &fdef, DT_INT32);
  auto* second = AddFunctionInput("arg1", &fdef, DT_BOOL);

  const auto& stored_args = fdef.signature().input_arg();
  EXPECT_EQ(stored_args.data()[0], first);
  EXPECT_EQ(stored_args.data()[1], second);

  EXPECT_EQ(first->name(), "arg0");
  EXPECT_EQ(first->type(), DT_INT32);
  EXPECT_EQ(second->name(), "arg1");
  EXPECT_EQ(second->type(), DT_BOOL);
}
74
// Checks node-name lookup against XTimesTwo: "two" is one of its nodes, the
// other name is not.
TEST(FunctionUtilsTest, ContainsFunctionNodeWithName) {
  const FunctionDef function = test::function::XTimesTwo();
  EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
  EXPECT_FALSE(ContainsFunctionNodeWithName(
      "weird_name_that_should_not_be_there", function));
}
81
// Checks op lookup against XTimesTwo: it contains a "Mul" node but no node
// with the made-up op.
TEST(FunctionUtilsTest, ContainsFunctionNodeWithOp) {
  const FunctionDef function = test::function::XTimesTwo();
  EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function));
  EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there",
                                          function));
}
88
// Checks output-name lookup: "y" is an output of XTimesTwo, while a string in
// "node:output:index" form that is not an output name must not match.
TEST(FunctionUtilsTest, ContainsFunctionOutputWithName) {
  const FunctionDef function = test::function::XTimesTwo();
  const bool has_y = ContainsFunctionOutputWithName("y", function);
  const bool has_tensor_ref = ContainsFunctionOutputWithName("Add:z:0", function);
  EXPECT_TRUE(has_y);
  EXPECT_FALSE(has_tensor_ref);
}
94
// Checks that node-name search returns -1 for a missing node and something
// other than -1 for the existing "two" node.
TEST(FunctionUtilsTest, FindFunctionNodeWithName) {
  const FunctionDef function = test::function::XTimesTwo();
  const auto missing_index =
      FindFunctionNodeWithName("weird_name_that_should_not_be_there", function);
  const auto present_index = FindFunctionNodeWithName("two", function);
  EXPECT_EQ(missing_index, -1);
  EXPECT_NE(present_index, -1);
}
102
// Checks that op search returns -1 for an unknown op and something other
// than -1 for "Mul".
TEST(FunctionUtilsTest, FindFunctionNodeWithOp) {
  const FunctionDef function = test::function::XTimesTwo();
  const auto missing_index =
      FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function);
  const auto mul_index = FindFunctionNodeWithOp("Mul", function);
  EXPECT_EQ(missing_index, -1);
  EXPECT_NE(mul_index, -1);
}
110
// XTimesTwo's first input is "x"; unknown input names map to -1.
TEST(FunctionUtilsTest, FindFunctionInputWithName) {
  const FunctionDef function = test::function::XTimesTwo();
  const auto x_index = FindFunctionInputWithName("x", function);
  const auto unknown_index = FindFunctionInputWithName("not_a_name", function);
  EXPECT_EQ(x_index, 0);
  EXPECT_EQ(unknown_index, -1);
}
116
// XTimesTwo's first output is "y"; a name that is not an output maps to -1.
TEST(FunctionUtilsTest, FindFunctionOutputWithName) {
  const FunctionDef function = test::function::XTimesTwo();
  const auto y_index = FindFunctionOutputWithName("y", function);
  const auto unknown_index = FindFunctionOutputWithName("Add:z:0", function);
  EXPECT_EQ(y_index, 0);
  EXPECT_EQ(unknown_index, -1);
}
122
// Verifies that SetUniqueFunctionNodeName never reuses a name that is already
// taken by a node in the function, including one it generated earlier.
TEST(FunctionUtilsTest, SetUniqueFunctionNodeName) {
  FunctionDef function = test::function::XTimesTwo();

  // The first generated name must not collide with any existing node.
  NodeDef first;
  SetUniqueFunctionNodeName("abc", &function, &first);
  for (const NodeDef& existing : function.node_def()) {
    EXPECT_NE(first.name(), existing.name());
  }

  // After `first` joins the function, the same prefix must yield a new name.
  NodeDef* added = function.add_node_def();
  *added = first;

  NodeDef second;
  SetUniqueFunctionNodeName("abc", &function, &second);
  EXPECT_NE(second.name(), added->name());
}
137
// Verifies that AddNode adds nodes with the requested op, inputs and
// attributes, and derives a unique name when none is given.
TEST(FunctionUtilsTest, AddNodeToFunctionDef) {
  FunctionDef func;
  const char* op_name = "xxx";

  // Explicitly named node without inputs or attributes.
  AddNode(op_name, op_name, {}, {}, &func);
  // FindFunctionNodeWithName returns -1 for a missing node; check it before
  // indexing node_def() so a failure is reported instead of reading out of
  // range.
  const int node1_index = FindFunctionNodeWithName("xxx", func);
  ASSERT_NE(node1_index, -1);
  const NodeDef& node1 = func.node_def(node1_index);
  EXPECT_EQ(node1.op(), op_name);
  EXPECT_EQ(node1.input_size(), 0);
  EXPECT_EQ(node1.attr_size(), 0);

  // Unnamed node with inputs; the test expects the name "xxx/_2".
  const std::vector<string> inputs({"input1", "input2"});
  AddNode("", op_name, inputs, {}, &func);
  const int node2_index = FindFunctionNodeWithName("xxx/_2", func);
  ASSERT_NE(node2_index, -1);
  const NodeDef& node2 = func.node_def(node2_index);
  EXPECT_EQ(node2.op(), op_name);
  EXPECT_EQ(node2.attr_size(), 0);
  // ASSERT so the loop below never reads past node2's inputs; the cast avoids
  // a signed/unsigned comparison.
  ASSERT_EQ(node2.input_size(), static_cast<int>(inputs.size()));
  for (int i = 0; i < node2.input_size(); ++i) {
    EXPECT_EQ(node2.input(i), inputs[i]);
  }

  // Unnamed node with attributes; the test expects the name "xxx/_3".
  AttrValue a1, a2;
  a1.set_type(DT_INT32);
  a2.set_type(DT_INT64);
  const std::vector<std::pair<string, AttrValue>> attrs(
      {{"attr1", a1}, {"attr2", a2}});
  AddNode("", op_name, {}, attrs, &func);
  const int node3_index = FindFunctionNodeWithName("xxx/_3", func);
  ASSERT_NE(node3_index, -1);
  const NodeDef& node3 = func.node_def(node3_index);
  EXPECT_EQ(node3.op(), op_name);
  EXPECT_EQ(node3.input_size(), 0);
  EXPECT_EQ(node3.attr_size(), static_cast<int>(attrs.size()));
  for (const auto& attr : attrs) {
    // protobuf Map::at() CHECK-fails on a missing key; check presence first
    // so a missing attribute is a test failure, not a crash.
    ASSERT_EQ(node3.attr().count(attr.first), 1u);
    EXPECT_EQ(attr.second.type(), node3.attr().at(attr.first).type());
  }
}
174
// Graph containing a function with "If" and "Assert" ops.
/*
@eager_function.defun
def test_function():
  pred = constant_op.constant(True)

  def fn1():
    return control_flow_ops.no_op()

  def fn2():
    return control_flow_ops.Assert(False, ["Wrong branch!!!"])

  return control_flow_ops.cond(pred, fn1, fn2)

r = test_function()
*/
// The following proto was generated in Python from the code block above. To
// regenerate it, run that code and dump the GraphDef of the default (or
// specified) graph, e.g. `ops.get_default_graph().as_graph_def()`.
// Text-format GraphDef with one "StatefulPartitionedCall" node that calls
// "__inference_test_function_19", plus a function library containing:
//   - "cond_true_3": NoOp and Identity nodes; its signature is not marked
//     is_stateful.
//   - "cond_false_4": Const, Assert and Identity nodes; its signature is
//     marked is_stateful: true.
//   - "__inference_test_function_19": an "If" node whose then_branch is
//     "cond_true_3" and whose else_branch is "cond_false_4".
// Parsed by the IsFunctionStateful test. Regenerate it from the Python
// snippet rather than hand-editing.
194 constexpr char kCondGraphProto[] = R"proto(
195 node {
196 name: "StatefulPartitionedCall"
197 op: "StatefulPartitionedCall"
198 attr {
199 key: "Tin"
200 value { list {} }
201 }
202 attr {
203 key: "Tout"
204 value { list { type: DT_BOOL } }
205 }
206 attr {
207 key: "_gradient_op_type"
208 value { s: "PartitionedCall-20" }
209 }
210 attr {
211 key: "config"
212 value { s: "" }
213 }
214 attr {
215 key: "config_proto"
216 value { s: "" }
217 }
218 attr {
219 key: "executor_type"
220 value { s: "" }
221 }
222 attr {
223 key: "f"
224 value { func { name: "__inference_test_function_19" } }
225 }
226 }
227 library {
228 function {
229 signature {
230 name: "cond_true_3"
231 input_arg { name: "identity_const" type: DT_BOOL }
232 output_arg { name: "identity_1" type: DT_BOOL }
233 }
234 node_def { name: "NoOp" op: "NoOp" }
235 node_def {
236 name: "Identity"
237 op: "Identity"
238 input: "identity_const"
239 input: "^NoOp"
240 attr {
241 key: "T"
242 value { type: DT_BOOL }
243 }
244 }
245 node_def {
246 name: "Identity_1"
247 op: "Identity"
248 input: "Identity:output:0"
249 attr {
250 key: "T"
251 value { type: DT_BOOL }
252 }
253 }
254 ret { key: "identity_1" value: "Identity_1:output:0" }
255 }
256 function {
257 signature {
258 name: "cond_false_4"
259 input_arg { name: "identity_const" type: DT_BOOL }
260 output_arg { name: "identity_1" type: DT_BOOL }
261 is_stateful: true
262 }
263 node_def {
264 name: "Assert/Const"
265 op: "Const"
266 attr {
267 key: "dtype"
268 value { type: DT_STRING }
269 }
270 attr {
271 key: "value"
272 value {
273 tensor {
274 dtype: DT_STRING
275 tensor_shape {}
276 string_val: "Wrong branch!!!"
277 }
278 }
279 }
280 }
281 node_def {
282 name: "Assert/Assert/condition"
283 op: "Const"
284 attr {
285 key: "dtype"
286 value { type: DT_BOOL }
287 }
288 attr {
289 key: "value"
290 value {
291 tensor {
292 dtype: DT_BOOL
293 tensor_shape {}
294 bool_val: false
295 }
296 }
297 }
298 }
299 node_def {
300 name: "Assert/Assert/data_0"
301 op: "Const"
302 attr {
303 key: "dtype"
304 value { type: DT_STRING }
305 }
306 attr {
307 key: "value"
308 value {
309 tensor {
310 dtype: DT_STRING
311 tensor_shape {}
312 string_val: "Wrong branch!!!"
313 }
314 }
315 }
316 }
317 node_def {
318 name: "Assert/Assert"
319 op: "Assert"
320 input: "Assert/Assert/condition:output:0"
321 input: "Assert/Assert/data_0:output:0"
322 attr {
323 key: "T"
324 value { list { type: DT_STRING } }
325 }
326 attr {
327 key: "summarize"
328 value { i: 3 }
329 }
330 }
331 node_def {
332 name: "Identity"
333 op: "Identity"
334 input: "identity_const"
335 input: "^Assert/Assert"
336 attr {
337 key: "T"
338 value { type: DT_BOOL }
339 }
340 }
341 node_def {
342 name: "Identity_1"
343 op: "Identity"
344 input: "Identity:output:0"
345 input: "^Assert/Assert"
346 attr {
347 key: "T"
348 value { type: DT_BOOL }
349 }
350 }
351 ret { key: "identity_1" value: "Identity_1:output:0" }
352 }
353 function {
354 signature {
355 name: "__inference_test_function_19"
356 output_arg { name: "identity" type: DT_BOOL }
357 is_stateful: true
358 }
359 node_def {
360 name: "Const"
361 op: "Const"
362 attr {
363 key: "dtype"
364 value { type: DT_BOOL }
365 }
366 attr {
367 key: "value"
368 value {
369 tensor {
370 dtype: DT_BOOL
371 tensor_shape {}
372 bool_val: true
373 }
374 }
375 }
376 }
377 node_def {
378 name: "cond"
379 op: "If"
380 input: "Const:output:0"
381 input: "Const:output:0"
382 attr {
383 key: "Tcond"
384 value { type: DT_BOOL }
385 }
386 attr {
387 key: "Tin"
388 value { list { type: DT_BOOL } }
389 }
390 attr {
391 key: "Tout"
392 value { list { type: DT_BOOL } }
393 }
394 attr {
395 key: "_lower_using_switch_merge"
396 value { b: true }
397 }
398 attr {
399 key: "else_branch"
400 value { func { name: "cond_false_4" } }
401 }
402 attr {
403 key: "output_shapes"
404 value { list { shape {} } }
405 }
406 attr {
407 key: "then_branch"
408 value { func { name: "cond_true_3" } }
409 }
410 }
411 node_def {
412 name: "cond/Identity"
413 op: "Identity"
414 input: "cond:output:0"
415 attr {
416 key: "T"
417 value { type: DT_BOOL }
418 }
419 }
420 node_def {
421 name: "Identity"
422 op: "Identity"
423 input: "cond/Identity:output:0"
424 input: "^cond"
425 attr {
426 key: "T"
427 value { type: DT_BOOL }
428 }
429 }
430 ret { key: "identity" value: "Identity:output:0" }
431 }
432 }
433 versions { producer: 27 min_consumer: 12 })proto";
434
// Graph containing a function with a "While" op, built in Python.
/*
@eager_function.defun
def test_function():
  return control_flow_ops.while_loop(
      lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)

r = test_function()
*/
// The following proto was generated in Python from the code block above. To
// regenerate it, run that code and dump the GraphDef of the default (or
// specified) graph, e.g. `ops.get_default_graph().as_graph_def()`.
// Text-format GraphDef with one "StatefulPartitionedCall" node that calls
// "__inference_test_function_34", plus a function library containing:
//   - "while_body_5": loop body built from Const, Add and Identity nodes.
//   - "__inference_test_function_34": a "While" node whose body is
//     "while_body_5" and whose cond is "while_cond_4". Its signature is
//     marked is_stateful: true.
//   - "while_cond_4": loop condition built from Less, LogicalAnd and Identity
//     nodes.
// Note: the IsFunctionStateful test expects "__inference_test_function_34"
// to be reported as NOT stateful despite the signature flag. Presumably the
// check inspects the function's nodes rather than is_stateful; confirm
// against function_utils.cc. Regenerate this text from the Python snippet
// rather than hand-editing.
447 constexpr char kWhileGraphProto[] = R"proto(
448 node {
449 name: "StatefulPartitionedCall"
450 op: "StatefulPartitionedCall"
451 attr {
452 key: "Tin"
453 value { list {} }
454 }
455 attr {
456 key: "Tout"
457 value { list { type: DT_INT32 } }
458 }
459 attr {
460 key: "_gradient_op_type"
461 value { s: "PartitionedCall-35" }
462 }
463 attr {
464 key: "config"
465 value { s: "" }
466 }
467 attr {
468 key: "config_proto"
469 value { s: "" }
470 }
471 attr {
472 key: "executor_type"
473 value { s: "" }
474 }
475 attr {
476 key: "f"
477 value { func { name: "__inference_test_function_34" } }
478 }
479 }
480 library {
481 function {
482 signature {
483 name: "while_body_5"
484 input_arg { name: "while_loop_counter" type: DT_INT32 }
485 input_arg { name: "const" type: DT_INT32 }
486 input_arg { name: "maximum_iterations" type: DT_INT32 }
487 output_arg { name: "identity" type: DT_INT32 }
488 output_arg { name: "identity_1" type: DT_INT32 }
489 output_arg { name: "identity_2" type: DT_INT32 }
490 }
491 node_def {
492 name: "add/y"
493 op: "Const"
494 attr {
495 key: "dtype"
496 value { type: DT_INT32 }
497 }
498 attr {
499 key: "value"
500 value {
501 tensor {
502 dtype: DT_INT32
503 tensor_shape {}
504 int_val: 1
505 }
506 }
507 }
508 }
509 node_def {
510 name: "add"
511 op: "Add"
512 input: "const"
513 input: "add/y:output:0"
514 attr {
515 key: "T"
516 value { type: DT_INT32 }
517 }
518 }
519 node_def {
520 name: "add_1/y"
521 op: "Const"
522 attr {
523 key: "dtype"
524 value { type: DT_INT32 }
525 }
526 attr {
527 key: "value"
528 value {
529 tensor {
530 dtype: DT_INT32
531 tensor_shape {}
532 int_val: 1
533 }
534 }
535 }
536 }
537 node_def {
538 name: "add_1"
539 op: "Add"
540 input: "while_loop_counter"
541 input: "add_1/y:output:0"
542 attr {
543 key: "T"
544 value { type: DT_INT32 }
545 }
546 }
547 node_def {
548 name: "Identity"
549 op: "Identity"
550 input: "add_1:z:0"
551 attr {
552 key: "T"
553 value { type: DT_INT32 }
554 }
555 }
556 node_def {
557 name: "Identity_1"
558 op: "Identity"
559 input: "add:z:0"
560 attr {
561 key: "T"
562 value { type: DT_INT32 }
563 }
564 }
565 node_def {
566 name: "Identity_2"
567 op: "Identity"
568 input: "maximum_iterations"
569 attr {
570 key: "T"
571 value { type: DT_INT32 }
572 }
573 }
574 ret { key: "identity" value: "Identity:output:0" }
575 ret { key: "identity_1" value: "Identity_1:output:0" }
576 ret { key: "identity_2" value: "Identity_2:output:0" }
577 }
578 function {
579 signature {
580 name: "__inference_test_function_34"
581 output_arg { name: "identity" type: DT_INT32 }
582 is_stateful: true
583 }
584 node_def {
585 name: "maximum_iterations"
586 op: "Const"
587 attr {
588 key: "dtype"
589 value { type: DT_INT32 }
590 }
591 attr {
592 key: "value"
593 value {
594 tensor {
595 dtype: DT_INT32
596 tensor_shape {}
597 int_val: 1
598 }
599 }
600 }
601 }
602 node_def {
603 name: "Const"
604 op: "Const"
605 attr {
606 key: "dtype"
607 value { type: DT_INT32 }
608 }
609 attr {
610 key: "value"
611 value {
612 tensor {
613 dtype: DT_INT32
614 tensor_shape {}
615 int_val: 0
616 }
617 }
618 }
619 }
620 node_def {
621 name: "while/loop_counter"
622 op: "Const"
623 attr {
624 key: "dtype"
625 value { type: DT_INT32 }
626 }
627 attr {
628 key: "value"
629 value {
630 tensor {
631 dtype: DT_INT32
632 tensor_shape {}
633 int_val: 0
634 }
635 }
636 }
637 }
638 node_def {
639 name: "while"
640 op: "While"
641 input: "while/loop_counter:output:0"
642 input: "Const:output:0"
643 input: "maximum_iterations:output:0"
644 attr {
645 key: "T"
646 value { list { type: DT_INT32 type: DT_INT32 type: DT_INT32 } }
647 }
648 attr {
649 key: "_lower_using_switch_merge"
650 value { b: true }
651 }
652 attr {
653 key: "body"
654 value { func { name: "while_body_5" } }
655 }
656 attr {
657 key: "cond"
658 value { func { name: "while_cond_4" } }
659 }
660 attr {
661 key: "output_shapes"
662 value {
663 list {
664 shape {}
665 shape {}
666 shape {}
667 }
668 }
669 }
670 }
671 node_def {
672 name: "while/Identity"
673 op: "Identity"
674 input: "while:output:0"
675 attr {
676 key: "T"
677 value { type: DT_INT32 }
678 }
679 }
680 node_def {
681 name: "while/Identity_1"
682 op: "Identity"
683 input: "while:output:1"
684 attr {
685 key: "T"
686 value { type: DT_INT32 }
687 }
688 }
689 node_def {
690 name: "while/Identity_2"
691 op: "Identity"
692 input: "while:output:2"
693 attr {
694 key: "T"
695 value { type: DT_INT32 }
696 }
697 }
698 node_def {
699 name: "Identity"
700 op: "Identity"
701 input: "while/Identity_1:output:0"
702 input: "^while"
703 attr {
704 key: "T"
705 value { type: DT_INT32 }
706 }
707 }
708 ret { key: "identity" value: "Identity:output:0" }
709 }
710 function {
711 signature {
712 name: "while_cond_4"
713 input_arg { name: "while_loop_counter" type: DT_INT32 }
714 input_arg { name: "const" type: DT_INT32 }
715 input_arg { name: "less_maximum_iterations" type: DT_INT32 }
716 output_arg { name: "identity" type: DT_BOOL }
717 }
718 node_def {
719 name: "Less"
720 op: "Less"
721 input: "while_loop_counter"
722 input: "less_maximum_iterations"
723 attr {
724 key: "T"
725 value { type: DT_INT32 }
726 }
727 }
728 node_def {
729 name: "Less_1/y"
730 op: "Const"
731 attr {
732 key: "dtype"
733 value { type: DT_INT32 }
734 }
735 attr {
736 key: "value"
737 value {
738 tensor {
739 dtype: DT_INT32
740 tensor_shape {}
741 int_val: 3
742 }
743 }
744 }
745 }
746 node_def {
747 name: "Less_1"
748 op: "Less"
749 input: "const"
750 input: "Less_1/y:output:0"
751 attr {
752 key: "T"
753 value { type: DT_INT32 }
754 }
755 }
756 node_def {
757 name: "LogicalAnd"
758 op: "LogicalAnd"
759 input: "Less:z:0"
760 input: "Less_1:z:0"
761 }
762 node_def {
763 name: "Identity"
764 op: "Identity"
765 input: "LogicalAnd:z:0"
766 attr {
767 key: "T"
768 value { type: DT_BOOL }
769 }
770 }
771 ret { key: "identity" value: "Identity:output:0" }
772 }
773 }
774 versions { producer: 27 min_consumer: 12 })proto";
775
// TODO(shivaniagrawal): Split this test into multiple tests for better
// readability and add full coverage, i.e. add (or separate out) tests for all
// branches of IsNodeStateful and IsFunctionStateful:
// - IsNodeStateful for a Cond that has a stateful branch
// - IsNodeStateful for a Cond that has no stateful branches
// - IsNodeStateful for a While that has a stateful branch
// - IsNodeStateful for a While that has no stateful branches
// - IsNodeStateful for Assert
// - IsNodeStateful for a stateful op
// - IsNodeStateful for a stateless op
//
// - IsFunctionStateful for a function that contains a Cond
// - IsFunctionStateful for a function that contains a While
// - IsFunctionStateful for a function that contains an Assert (and no other
//   stateful op)
// - IsFunctionStateful for a function that contains a stateful op other than
//   Assert
// - IsFunctionStateful for a function that does not contain a stateful op
795
// Exercises IsFunctionStateful / IsNodeStateful on:
//   - XTimesTwo and an unregistered op,
//   - the If/Assert functions from kCondGraphProto,
//   - the While function from kWhileGraphProto.
// Passing `true` as the extra trailing argument makes Assert ops not count
// towards statefulness, as the cond_false_4 expectations below show.
TEST(FunctionUtilsTest, IsFunctionStateful) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* nodeA = graph_utils::AddNode("", "A", {}, {}, &graph);
  FunctionDef* function = graph_def.mutable_library()->add_function();
  *function = test::function::XTimesTwo();

  FunctionLibraryDefinition lib_def(OpRegistry::Global(), graph_def.library());

  EXPECT_FALSE(IsFunctionStateful(lib_def, *function));

  // Op "A" is not a registered op; the expectation is that it is reported as
  // stateful.
  EXPECT_TRUE(IsNodeStateful(lib_def, *nodeA));

  // Parse `kCondGraphProto`: a function containing "If" and "Assert" ops.
  // Assert on the parse result and on every Find() so a broken proto is
  // reported as a test failure instead of a null dereference.
  GraphDef graph_def_cond;
  ASSERT_TRUE(
      protobuf::TextFormat::ParseFromString(kCondGraphProto, &graph_def_cond));
  FunctionLibraryDefinition cond_lib(OpRegistry::Global(),
                                     graph_def_cond.library());

  const FunctionDef* no_op_fnc = cond_lib.Find("cond_true_3");
  ASSERT_NE(no_op_fnc, nullptr);
  EXPECT_FALSE(IsFunctionStateful(cond_lib, *no_op_fnc));
  EXPECT_FALSE(IsFunctionStateful(cond_lib, *no_op_fnc, true));

  const FunctionDef* assert_func = cond_lib.Find("cond_false_4");
  ASSERT_NE(assert_func, nullptr);
  EXPECT_TRUE(IsFunctionStateful(cond_lib, *assert_func));
  EXPECT_FALSE(IsFunctionStateful(cond_lib, *assert_func, true));

  EXPECT_TRUE(ContainsFunctionNodeWithOp("Const", *assert_func));
  EXPECT_TRUE(ContainsFunctionNodeWithOp("Assert", *assert_func));

  // These nodes belong to a function in `cond_lib`, so check them against that
  // library. Bind by reference to avoid copying each NodeDef.
  for (const NodeDef& node : assert_func->node_def()) {
    if (node.op() == "Const") {
      EXPECT_FALSE(IsNodeStateful(cond_lib, node));
    }
    if (node.op() == "Assert") {
      EXPECT_TRUE(IsNodeStateful(cond_lib, node));
      EXPECT_FALSE(IsNodeStateful(cond_lib, node, true));
    }
  }

  const FunctionDef* cond_func = cond_lib.Find("__inference_test_function_19");
  ASSERT_NE(cond_func, nullptr);
  EXPECT_TRUE(IsFunctionStateful(cond_lib, *cond_func));
  EXPECT_FALSE(IsFunctionStateful(cond_lib, *cond_func, true));

  // Parse `kWhileGraphProto`: a function containing a "While" op.
  GraphDef graph_def_while;
  ASSERT_TRUE(
      protobuf::TextFormat::ParseFromString(kWhileGraphProto, &graph_def_while));
  FunctionLibraryDefinition while_lib(OpRegistry::Global(),
                                      graph_def_while.library());

  const FunctionDef* while_function =
      while_lib.Find("__inference_test_function_34");
  ASSERT_NE(while_function, nullptr);
  EXPECT_FALSE(IsFunctionStateful(while_lib, *while_function));
  EXPECT_FALSE(IsFunctionStateful(while_lib, *while_function, true));
}
861 } // namespace
862 } // namespace function_utils
863 } // namespace grappler
864 } // namespace tensorflow
865