xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/data/function_utils_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
17 
18 #include "tensorflow/core/framework/function_testlib.h"
19 #include "tensorflow/core/framework/graph.pb.h"
20 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 #include "tensorflow/core/platform/test.h"
23 #include "tensorflow/tools/graph_transforms/transform_utils.h"
24 
25 namespace tensorflow {
26 namespace grappler {
27 namespace function_utils {
28 namespace {
29 
// Checks how FunctionDefTensorDesc breaks a tensor string into its fields,
// both for a "node:output:position" string and for a bare argument name.
TEST(FunctionDefTensorDesc, Parsing) {
  // Three-part form: every field is expected to be populated.
  FunctionDefTensorDesc qualified("Cast:y:0");
  EXPECT_EQ(qualified.full_str, "Cast:y:0");
  EXPECT_EQ(qualified.node_name, "Cast");
  EXPECT_EQ(qualified.node_output, "y");
  EXPECT_EQ(qualified.position, 0);

  // Bare name: the test expects an empty output name and a position of -1.
  FunctionDefTensorDesc bare("Arg0");
  EXPECT_EQ(bare.full_str, "Arg0");
  EXPECT_EQ(bare.node_name, "Arg0");
  EXPECT_EQ(bare.node_output, "");
  EXPECT_EQ(bare.position, -1);
}
43 
// Checks that ReplaceReferences rewrites a tensor reference both where it is
// used as a node input and where it is used as a function return value.
TEST(ReplaceReferencesTest, ReplaceReferencesTest) {
  // "outer" returns "MapDefun:output:0" as "out" and "Cast:y:0" as "out2".
  FunctionDef outer_fn = FunctionDefHelper::Create(
      "outer", {"arg0: int32"}, {"out: int32", "out2: int64"}, {}, {},
      {{"out", "MapDefun:output:0"}, {"out2", "Cast:y:0"}});

  // Node "X" consumes the same tensor that "out" returns.
  NodeDef* consumer =
      AddNode("X", "Some_Op", {"MapDefun:output:0"}, {}, &outer_fn);

  ReplaceReferences("MapDefun:output:0", "arg0", &outer_fn);

  // The retval of "outer" and the input of "X" must both point at "arg0" now.
  EXPECT_EQ(outer_fn.ret().at("out"), "arg0");
  EXPECT_EQ(consumer->input(0), "arg0");
}
55 
// XTimesTwo is expected to already have an output "y" (see the
// ContainsFunctionOutputWithName test), so adding another output with the
// prefix "y" is expected to land under the name "y/_1".
TEST(FunctionUtilsTest, AddFunctionOutputWithUniqueName) {
  FunctionDef fn = test::function::XTimesTwo();
  AddFunctionOutputWithUniqueName("y", "two", &fn, DT_INT64);

  const auto& retvals = fn.ret();
  EXPECT_TRUE(ContainsFunctionOutputWithName("y/_1", fn));
  EXPECT_EQ(retvals.at("y/_1"), "two");
}
62 
// Checks that AddFunctionInput appends arguments to the signature in call
// order, and that the value it returns is the stored entry with the requested
// name and type.
TEST(FunctionUtilsTest, AddFunctionInput) {
  FunctionDef fn;
  auto first = AddFunctionInput("arg0", &fn, DT_INT32);
  auto second = AddFunctionInput("arg1", &fn, DT_BOOL);

  const auto& inputs = fn.signature().input_arg();

  EXPECT_EQ(inputs.data()[0], first);
  EXPECT_EQ(first->name(), "arg0");
  EXPECT_EQ(first->type(), DT_INT32);

  EXPECT_EQ(inputs.data()[1], second);
  EXPECT_EQ(second->name(), "arg1");
  EXPECT_EQ(second->type(), DT_BOOL);
}
74 
// Node-name lookup on XTimesTwo: an unknown name must be rejected and the
// node named "two" must be found.
TEST(FunctionUtilsTest, ContainsFunctionNodeWithName) {
  const FunctionDef fn = test::function::XTimesTwo();

  const bool has_bogus_node =
      ContainsFunctionNodeWithName("weird_name_that_should_not_be_there", fn);
  EXPECT_FALSE(has_bogus_node);

  const bool has_two_node = ContainsFunctionNodeWithName("two", fn);
  EXPECT_TRUE(has_two_node);
}
81 
// Op-type lookup on XTimesTwo: an unknown op must be rejected and a "Mul"
// node must be found.
TEST(FunctionUtilsTest, ContainsFunctionNodeWithOp) {
  const FunctionDef fn = test::function::XTimesTwo();

  const bool has_bogus_op =
      ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there", fn);
  EXPECT_FALSE(has_bogus_op);

  const bool has_mul_op = ContainsFunctionNodeWithOp("Mul", fn);
  EXPECT_TRUE(has_mul_op);
}
88 
// Output-name lookup on XTimesTwo: "y" is expected to be an output, while the
// tensor-style string "Add:z:0" is not.
TEST(FunctionUtilsTest, ContainsFunctionOutputWithName) {
  const FunctionDef fn = test::function::XTimesTwo();

  const bool has_y = ContainsFunctionOutputWithName("y", fn);
  EXPECT_TRUE(has_y);

  const bool has_add_z = ContainsFunctionOutputWithName("Add:z:0", fn);
  EXPECT_FALSE(has_add_z);
}
94 
// FindFunctionNodeWithName is expected to return -1 for an unknown node name
// and some other value for the existing node "two".
TEST(FunctionUtilsTest, FindFunctionNodeWithName) {
  const FunctionDef fn = test::function::XTimesTwo();

  const auto missing =
      FindFunctionNodeWithName("weird_name_that_should_not_be_there", fn);
  EXPECT_EQ(missing, -1);

  const auto found = FindFunctionNodeWithName("two", fn);
  EXPECT_NE(found, -1);
}
102 
// FindFunctionNodeWithOp is expected to return -1 for an unknown op type and
// some other value for the op "Mul".
TEST(FunctionUtilsTest, FindFunctionNodeWithOp) {
  const FunctionDef fn = test::function::XTimesTwo();

  const auto missing =
      FindFunctionNodeWithOp("weird_op_that_should_not_be_there", fn);
  EXPECT_EQ(missing, -1);

  const auto found = FindFunctionNodeWithOp("Mul", fn);
  EXPECT_NE(found, -1);
}
110 
// Input lookup on XTimesTwo: "x" is expected at index 0, and an unknown name
// is expected to yield -1.
TEST(FunctionUtilsTest, FindFunctionInputWithName) {
  const FunctionDef fn = test::function::XTimesTwo();

  const auto x_index = FindFunctionInputWithName("x", fn);
  EXPECT_EQ(x_index, 0);

  const auto unknown_index = FindFunctionInputWithName("not_a_name", fn);
  EXPECT_EQ(unknown_index, -1);
}
116 
// Output lookup on XTimesTwo: "y" is expected at index 0, and "Add:z:0" is
// expected to yield -1.
TEST(FunctionUtilsTest, FindFunctionOutputWithName) {
  const FunctionDef fn = test::function::XTimesTwo();

  const auto y_index = FindFunctionOutputWithName("y", fn);
  EXPECT_EQ(y_index, 0);

  const auto unknown_index = FindFunctionOutputWithName("Add:z:0", fn);
  EXPECT_EQ(unknown_index, -1);
}
122 
// Checks that SetUniqueFunctionNodeName picks a name that does not collide
// with nodes already in the function, including a node that was named by an
// earlier call with the same prefix.
TEST(FunctionUtilsTest, SetUniqueFunctionNodeName) {
  FunctionDef fn = test::function::XTimesTwo();

  // First request: the chosen name must differ from every existing node.
  NodeDef first;
  SetUniqueFunctionNodeName("abc", &fn, &first);
  for (int i = 0; i < fn.node_def_size(); ++i) {
    EXPECT_NE(first.name(), fn.node_def(i).name());
  }

  // Store the first node in the function so its name is now taken.
  NodeDef* stored = fn.add_node_def();
  *stored = first;

  // Second request with the same prefix must not reuse the stored name.
  NodeDef second;
  SetUniqueFunctionNodeName("abc", &fn, &second);
  EXPECT_NE(second.name(), stored->name());
}
137 
// Exercises AddNode in three configurations: an explicit name with no inputs
// or attributes, an empty name with inputs, and an empty name with attributes.
//
// Each lookup result is asserted before it is used as an index, so a missing
// node fails the test cleanly instead of indexing node_def() with -1.
TEST(FunctionUtilsTest, AddNodeToFunctionDef) {
  FunctionDef func;
  const char* op_name = "xxx";

  // Case 1: explicit name, no inputs, no attributes.
  AddNode(op_name, op_name, {}, {}, &func);
  const int index1 = FindFunctionNodeWithName("xxx", func);
  ASSERT_NE(index1, -1);
  const NodeDef& node1 = func.node_def(index1);
  EXPECT_EQ(node1.op(), op_name);
  EXPECT_EQ(node1.input_size(), 0);
  EXPECT_EQ(node1.attr_size(), 0);

  // Case 2: empty name plus inputs. The test expects the node to be stored
  // under the generated name "xxx/_2".
  const std::vector<string> inputs({"input1", "input2"});
  AddNode("", op_name, inputs, {}, &func);
  const int index2 = FindFunctionNodeWithName("xxx/_2", func);
  ASSERT_NE(index2, -1);
  const NodeDef& node2 = func.node_def(index2);
  EXPECT_EQ(node2.op(), op_name);
  EXPECT_EQ(node2.attr_size(), 0);
  // Assert (not expect) the size: the loop below indexes node2.input(i).
  const int num_inputs = static_cast<int>(inputs.size());
  ASSERT_EQ(node2.input_size(), num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    EXPECT_EQ(node2.input(i), inputs[i]);
  }

  // Case 3: empty name plus attributes. The test expects the node to be
  // stored under the generated name "xxx/_3".
  AttrValue a1, a2;
  a1.set_type(DT_INT32);
  a2.set_type(DT_INT64);
  const std::vector<std::pair<string, AttrValue>> attrs(
      {{"attr1", a1}, {"attr2", a2}});
  AddNode("", op_name, {}, attrs, &func);
  const int index3 = FindFunctionNodeWithName("xxx/_3", func);
  ASSERT_NE(index3, -1);
  const NodeDef& node3 = func.node_def(index3);
  EXPECT_EQ(node3.op(), op_name);
  EXPECT_EQ(node3.input_size(), 0);
  EXPECT_EQ(node3.attr_size(), static_cast<int>(attrs.size()));
  for (const auto& expected : attrs) {
    // Look the key up explicitly so a missing attribute is reported as a test
    // failure rather than through Map::at() on an absent key.
    const auto found = node3.attr().find(expected.first);
    ASSERT_TRUE(found != node3.attr().end());
    EXPECT_EQ(expected.second.type(), found->second.type());
  }
}
174 
175 // Graph containing function with "If" and "Assert" Op.
176 /*
177   @eager_function.defun
178   def test_function():
179     pred = constant_op.constant(True)
180 
181     def fn1():
182       return control_flow_ops.no_op()
183 
184     def fn2():
185       return control_flow_ops.Assert(False, ["Wrong branch!!!"])
186 
187     return control_flow_ops.cond(pred, fn1, fn2)
188 
189   r = test_function()
190 */
// The following proto was generated in Python from the code block above. To
// regenerate it, get the graph_def from the default (or specified) graph for
// that code block, e.g. ops.get_default_graph().as_graph_def().
194 constexpr char kCondGraphProto[] = R"proto(
195   node {
196     name: "StatefulPartitionedCall"
197     op: "StatefulPartitionedCall"
198     attr {
199       key: "Tin"
200       value { list {} }
201     }
202     attr {
203       key: "Tout"
204       value { list { type: DT_BOOL } }
205     }
206     attr {
207       key: "_gradient_op_type"
208       value { s: "PartitionedCall-20" }
209     }
210     attr {
211       key: "config"
212       value { s: "" }
213     }
214     attr {
215       key: "config_proto"
216       value { s: "" }
217     }
218     attr {
219       key: "executor_type"
220       value { s: "" }
221     }
222     attr {
223       key: "f"
224       value { func { name: "__inference_test_function_19" } }
225     }
226   }
227   library {
228     function {
229       signature {
230         name: "cond_true_3"
231         input_arg { name: "identity_const" type: DT_BOOL }
232         output_arg { name: "identity_1" type: DT_BOOL }
233       }
234       node_def { name: "NoOp" op: "NoOp" }
235       node_def {
236         name: "Identity"
237         op: "Identity"
238         input: "identity_const"
239         input: "^NoOp"
240         attr {
241           key: "T"
242           value { type: DT_BOOL }
243         }
244       }
245       node_def {
246         name: "Identity_1"
247         op: "Identity"
248         input: "Identity:output:0"
249         attr {
250           key: "T"
251           value { type: DT_BOOL }
252         }
253       }
254       ret { key: "identity_1" value: "Identity_1:output:0" }
255     }
256     function {
257       signature {
258         name: "cond_false_4"
259         input_arg { name: "identity_const" type: DT_BOOL }
260         output_arg { name: "identity_1" type: DT_BOOL }
261         is_stateful: true
262       }
263       node_def {
264         name: "Assert/Const"
265         op: "Const"
266         attr {
267           key: "dtype"
268           value { type: DT_STRING }
269         }
270         attr {
271           key: "value"
272           value {
273             tensor {
274               dtype: DT_STRING
275               tensor_shape {}
276               string_val: "Wrong branch!!!"
277             }
278           }
279         }
280       }
281       node_def {
282         name: "Assert/Assert/condition"
283         op: "Const"
284         attr {
285           key: "dtype"
286           value { type: DT_BOOL }
287         }
288         attr {
289           key: "value"
290           value {
291             tensor {
292               dtype: DT_BOOL
293               tensor_shape {}
294               bool_val: false
295             }
296           }
297         }
298       }
299       node_def {
300         name: "Assert/Assert/data_0"
301         op: "Const"
302         attr {
303           key: "dtype"
304           value { type: DT_STRING }
305         }
306         attr {
307           key: "value"
308           value {
309             tensor {
310               dtype: DT_STRING
311               tensor_shape {}
312               string_val: "Wrong branch!!!"
313             }
314           }
315         }
316       }
317       node_def {
318         name: "Assert/Assert"
319         op: "Assert"
320         input: "Assert/Assert/condition:output:0"
321         input: "Assert/Assert/data_0:output:0"
322         attr {
323           key: "T"
324           value { list { type: DT_STRING } }
325         }
326         attr {
327           key: "summarize"
328           value { i: 3 }
329         }
330       }
331       node_def {
332         name: "Identity"
333         op: "Identity"
334         input: "identity_const"
335         input: "^Assert/Assert"
336         attr {
337           key: "T"
338           value { type: DT_BOOL }
339         }
340       }
341       node_def {
342         name: "Identity_1"
343         op: "Identity"
344         input: "Identity:output:0"
345         input: "^Assert/Assert"
346         attr {
347           key: "T"
348           value { type: DT_BOOL }
349         }
350       }
351       ret { key: "identity_1" value: "Identity_1:output:0" }
352     }
353     function {
354       signature {
355         name: "__inference_test_function_19"
356         output_arg { name: "identity" type: DT_BOOL }
357         is_stateful: true
358       }
359       node_def {
360         name: "Const"
361         op: "Const"
362         attr {
363           key: "dtype"
364           value { type: DT_BOOL }
365         }
366         attr {
367           key: "value"
368           value {
369             tensor {
370               dtype: DT_BOOL
371               tensor_shape {}
372               bool_val: true
373             }
374           }
375         }
376       }
377       node_def {
378         name: "cond"
379         op: "If"
380         input: "Const:output:0"
381         input: "Const:output:0"
382         attr {
383           key: "Tcond"
384           value { type: DT_BOOL }
385         }
386         attr {
387           key: "Tin"
388           value { list { type: DT_BOOL } }
389         }
390         attr {
391           key: "Tout"
392           value { list { type: DT_BOOL } }
393         }
394         attr {
395           key: "_lower_using_switch_merge"
396           value { b: true }
397         }
398         attr {
399           key: "else_branch"
400           value { func { name: "cond_false_4" } }
401         }
402         attr {
403           key: "output_shapes"
404           value { list { shape {} } }
405         }
406         attr {
407           key: "then_branch"
408           value { func { name: "cond_true_3" } }
409         }
410       }
411       node_def {
412         name: "cond/Identity"
413         op: "Identity"
414         input: "cond:output:0"
415         attr {
416           key: "T"
417           value { type: DT_BOOL }
418         }
419       }
420       node_def {
421         name: "Identity"
422         op: "Identity"
423         input: "cond/Identity:output:0"
424         input: "^cond"
425         attr {
426           key: "T"
427           value { type: DT_BOOL }
428         }
429       }
430       ret { key: "identity" value: "Identity:output:0" }
431     }
432   }
433   versions { producer: 27 min_consumer: 12 })proto";
434 
435 // Graph containing function with "While" Op in python.
436 /*
437   @eager_function.defun
438   def test_function():
439     return control_flow_ops.while_loop(
440         lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
441 
442   r = test_function()
443 */
// The following proto was generated in Python from the code block above. To
// regenerate it, get the graph_def from the default (or specified) graph for
// that code block, e.g. ops.get_default_graph().as_graph_def().
447 constexpr char kWhileGraphProto[] = R"proto(
448   node {
449     name: "StatefulPartitionedCall"
450     op: "StatefulPartitionedCall"
451     attr {
452       key: "Tin"
453       value { list {} }
454     }
455     attr {
456       key: "Tout"
457       value { list { type: DT_INT32 } }
458     }
459     attr {
460       key: "_gradient_op_type"
461       value { s: "PartitionedCall-35" }
462     }
463     attr {
464       key: "config"
465       value { s: "" }
466     }
467     attr {
468       key: "config_proto"
469       value { s: "" }
470     }
471     attr {
472       key: "executor_type"
473       value { s: "" }
474     }
475     attr {
476       key: "f"
477       value { func { name: "__inference_test_function_34" } }
478     }
479   }
480   library {
481     function {
482       signature {
483         name: "while_body_5"
484         input_arg { name: "while_loop_counter" type: DT_INT32 }
485         input_arg { name: "const" type: DT_INT32 }
486         input_arg { name: "maximum_iterations" type: DT_INT32 }
487         output_arg { name: "identity" type: DT_INT32 }
488         output_arg { name: "identity_1" type: DT_INT32 }
489         output_arg { name: "identity_2" type: DT_INT32 }
490       }
491       node_def {
492         name: "add/y"
493         op: "Const"
494         attr {
495           key: "dtype"
496           value { type: DT_INT32 }
497         }
498         attr {
499           key: "value"
500           value {
501             tensor {
502               dtype: DT_INT32
503               tensor_shape {}
504               int_val: 1
505             }
506           }
507         }
508       }
509       node_def {
510         name: "add"
511         op: "Add"
512         input: "const"
513         input: "add/y:output:0"
514         attr {
515           key: "T"
516           value { type: DT_INT32 }
517         }
518       }
519       node_def {
520         name: "add_1/y"
521         op: "Const"
522         attr {
523           key: "dtype"
524           value { type: DT_INT32 }
525         }
526         attr {
527           key: "value"
528           value {
529             tensor {
530               dtype: DT_INT32
531               tensor_shape {}
532               int_val: 1
533             }
534           }
535         }
536       }
537       node_def {
538         name: "add_1"
539         op: "Add"
540         input: "while_loop_counter"
541         input: "add_1/y:output:0"
542         attr {
543           key: "T"
544           value { type: DT_INT32 }
545         }
546       }
547       node_def {
548         name: "Identity"
549         op: "Identity"
550         input: "add_1:z:0"
551         attr {
552           key: "T"
553           value { type: DT_INT32 }
554         }
555       }
556       node_def {
557         name: "Identity_1"
558         op: "Identity"
559         input: "add:z:0"
560         attr {
561           key: "T"
562           value { type: DT_INT32 }
563         }
564       }
565       node_def {
566         name: "Identity_2"
567         op: "Identity"
568         input: "maximum_iterations"
569         attr {
570           key: "T"
571           value { type: DT_INT32 }
572         }
573       }
574       ret { key: "identity" value: "Identity:output:0" }
575       ret { key: "identity_1" value: "Identity_1:output:0" }
576       ret { key: "identity_2" value: "Identity_2:output:0" }
577     }
578     function {
579       signature {
580         name: "__inference_test_function_34"
581         output_arg { name: "identity" type: DT_INT32 }
582         is_stateful: true
583       }
584       node_def {
585         name: "maximum_iterations"
586         op: "Const"
587         attr {
588           key: "dtype"
589           value { type: DT_INT32 }
590         }
591         attr {
592           key: "value"
593           value {
594             tensor {
595               dtype: DT_INT32
596               tensor_shape {}
597               int_val: 1
598             }
599           }
600         }
601       }
602       node_def {
603         name: "Const"
604         op: "Const"
605         attr {
606           key: "dtype"
607           value { type: DT_INT32 }
608         }
609         attr {
610           key: "value"
611           value {
612             tensor {
613               dtype: DT_INT32
614               tensor_shape {}
615               int_val: 0
616             }
617           }
618         }
619       }
620       node_def {
621         name: "while/loop_counter"
622         op: "Const"
623         attr {
624           key: "dtype"
625           value { type: DT_INT32 }
626         }
627         attr {
628           key: "value"
629           value {
630             tensor {
631               dtype: DT_INT32
632               tensor_shape {}
633               int_val: 0
634             }
635           }
636         }
637       }
638       node_def {
639         name: "while"
640         op: "While"
641         input: "while/loop_counter:output:0"
642         input: "Const:output:0"
643         input: "maximum_iterations:output:0"
644         attr {
645           key: "T"
646           value { list { type: DT_INT32 type: DT_INT32 type: DT_INT32 } }
647         }
648         attr {
649           key: "_lower_using_switch_merge"
650           value { b: true }
651         }
652         attr {
653           key: "body"
654           value { func { name: "while_body_5" } }
655         }
656         attr {
657           key: "cond"
658           value { func { name: "while_cond_4" } }
659         }
660         attr {
661           key: "output_shapes"
662           value {
663             list {
664               shape {}
665               shape {}
666               shape {}
667             }
668           }
669         }
670       }
671       node_def {
672         name: "while/Identity"
673         op: "Identity"
674         input: "while:output:0"
675         attr {
676           key: "T"
677           value { type: DT_INT32 }
678         }
679       }
680       node_def {
681         name: "while/Identity_1"
682         op: "Identity"
683         input: "while:output:1"
684         attr {
685           key: "T"
686           value { type: DT_INT32 }
687         }
688       }
689       node_def {
690         name: "while/Identity_2"
691         op: "Identity"
692         input: "while:output:2"
693         attr {
694           key: "T"
695           value { type: DT_INT32 }
696         }
697       }
698       node_def {
699         name: "Identity"
700         op: "Identity"
701         input: "while/Identity_1:output:0"
702         input: "^while"
703         attr {
704           key: "T"
705           value { type: DT_INT32 }
706         }
707       }
708       ret { key: "identity" value: "Identity:output:0" }
709     }
710     function {
711       signature {
712         name: "while_cond_4"
713         input_arg { name: "while_loop_counter" type: DT_INT32 }
714         input_arg { name: "const" type: DT_INT32 }
715         input_arg { name: "less_maximum_iterations" type: DT_INT32 }
716         output_arg { name: "identity" type: DT_BOOL }
717       }
718       node_def {
719         name: "Less"
720         op: "Less"
721         input: "while_loop_counter"
722         input: "less_maximum_iterations"
723         attr {
724           key: "T"
725           value { type: DT_INT32 }
726         }
727       }
728       node_def {
729         name: "Less_1/y"
730         op: "Const"
731         attr {
732           key: "dtype"
733           value { type: DT_INT32 }
734         }
735         attr {
736           key: "value"
737           value {
738             tensor {
739               dtype: DT_INT32
740               tensor_shape {}
741               int_val: 3
742             }
743           }
744         }
745       }
746       node_def {
747         name: "Less_1"
748         op: "Less"
749         input: "const"
750         input: "Less_1/y:output:0"
751         attr {
752           key: "T"
753           value { type: DT_INT32 }
754         }
755       }
756       node_def {
757         name: "LogicalAnd"
758         op: "LogicalAnd"
759         input: "Less:z:0"
760         input: "Less_1:z:0"
761       }
762       node_def {
763         name: "Identity"
764         op: "Identity"
765         input: "LogicalAnd:z:0"
766         attr {
767           key: "T"
768           value { type: DT_BOOL }
769         }
770       }
771       ret { key: "identity" value: "Identity:output:0" }
772     }
773   }
774   versions { producer: 27 min_consumer: 12 })proto";
775 
776 // TODO(shivaniagrawal): split the test into multiple tests for better
777 // readability and add full coverage i.e. add/separate out the tests for all
778 // branches of IsNodeStateful and IsFunctionStateful:
779 // - test for IsNodeStateful for Cond that has a stateful branch
// - test for IsNodeStateful for Cond that has no stateful branches
781 // - test for IsNodeStateful for While that has a stateful branch
// - test for IsNodeStateful for While that has no stateful branches
783 // - test for IsNodeStateful for Assert
784 // - test for IsNodeStateful for a stateful op
785 // - test for IsNodeStateful for a stateless op
786 //
787 // - test for IsFunctionStateful for a function that contains a Cond
788 // - test for IsFunctionStateful for a function that contains a While
789 // - test for IsFunctionStateful for a function that contains an Assert (and no
790 //   other stateful op)
791 // - test for IsFunctionStateful for a function that contains a stateful op
792 //   other than Assert
793 // - test for IsFunctionStateful for a function that does not contain a stateful
794 //   op
795 
TEST(FunctionUtilsTest,IsFunctionStateful)796 TEST(FunctionUtilsTest, IsFunctionStateful) {
797   GraphDef graph_def;
798   MutableGraphView graph(&graph_def);
799 
800   NodeDef* nodeA = graph_utils::AddNode("", "A", {}, {}, &graph);
801   FunctionDef* function = graph_def.mutable_library()->add_function();
802   *function = test::function::XTimesTwo();
803 
804   FunctionLibraryDefinition lib_def(OpRegistry::Global(),
805                                     *graph_def.mutable_library());
806 
807   EXPECT_FALSE(IsFunctionStateful(lib_def, *function));
808 
809   // Op "A" is not a registered Op.
810   EXPECT_TRUE(IsNodeStateful(lib_def, *nodeA));
811 
812   // Get graph_def for the graph `kCondGraphProto`, graph with function
813   // containing "If" and "Assert" Op.
814 
815   GraphDef graph_def_cond;
816   protobuf::TextFormat::ParseFromString(kCondGraphProto, &graph_def_cond);
817   FunctionLibraryDefinition cond_lib(OpRegistry::Global(),
818                                      graph_def_cond.library());
819 
820   const FunctionDef* no_op_fnc = cond_lib.Find("cond_true_3");
821 
822   EXPECT_FALSE(IsFunctionStateful(cond_lib, *no_op_fnc));
823   EXPECT_FALSE(IsFunctionStateful(cond_lib, *no_op_fnc, true));
824 
825   const FunctionDef* assert_func = cond_lib.Find("cond_false_4");
826 
827   EXPECT_TRUE(IsFunctionStateful(cond_lib, *assert_func));
828   EXPECT_FALSE(IsFunctionStateful(cond_lib, *assert_func, true));
829 
830   EXPECT_TRUE(ContainsFunctionNodeWithOp("Const", *assert_func));
831   EXPECT_TRUE(ContainsFunctionNodeWithOp("Assert", *assert_func));
832 
833   for (auto node : assert_func->node_def()) {
834     if (node.op() == "Const") {
835       EXPECT_FALSE(IsNodeStateful(lib_def, node));
836     }
837     if (node.op() == "Assert") {
838       EXPECT_TRUE(IsNodeStateful(lib_def, node));
839       EXPECT_FALSE(IsNodeStateful(lib_def, node, true));
840     }
841   }
842 
843   const FunctionDef* cond_func = cond_lib.Find("__inference_test_function_19");
844 
845   EXPECT_TRUE(IsFunctionStateful(cond_lib, *cond_func));
846   EXPECT_FALSE(IsFunctionStateful(cond_lib, *cond_func, true));
847 
848   // Get graph def for the graph `kWhileGraphProto`, graph with function
849   // containing "While" Op.
850 
851   GraphDef graph_def_while;
852   protobuf::TextFormat::ParseFromString(kWhileGraphProto, &graph_def_while);
853 
854   FunctionLibraryDefinition while_lib(OpRegistry::Global(),
855                                       graph_def_while.library());
856   const FunctionDef* while_function =
857       while_lib.Find("__inference_test_function_34");
858   EXPECT_FALSE(IsFunctionStateful(while_lib, *while_function));
859   EXPECT_FALSE(IsFunctionStateful(while_lib, *while_function, true));
860 }
861 }  // namespace
862 }  // namespace function_utils
863 }  // namespace grappler
864 }  // namespace tensorflow
865