1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/tools/graph_transforms/transform_utils.h"
17 #include "tensorflow/cc/ops/const_op.h"
18 #include "tensorflow/cc/ops/image_ops.h"
19 #include "tensorflow/cc/ops/nn_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/lib/io/path.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/platform/test_benchmark.h"
26
27 namespace tensorflow {
28 namespace graph_transforms {
29
30 class TransformUtilsTest : public ::testing::Test {
31 protected:
TestMapNamesToNodes()32 void TestMapNamesToNodes() {
33 auto root = tensorflow::Scope::NewRootScope();
34 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
35
36 const int width = 100;
37
38 Tensor a_data(DT_FLOAT, TensorShape({width}));
39 test::FillIota<float>(&a_data, 1.0f);
40 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
41
42 Tensor b_data(DT_FLOAT, TensorShape({width}));
43 test::FillIota<float>(&b_data, 1.0f);
44 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
45
46 Output add = Add(root.WithOpName("add"), a_const, b_const);
47
48 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
49
50 Output mul = Mul(root.WithOpName("output"), add, placeholder);
51
52 GraphDef graph_def;
53 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
54 std::map<string, const NodeDef*> node_map;
55 MapNamesToNodes(graph_def, &node_map);
56
57 EXPECT_EQ(1, node_map.count("a"));
58 EXPECT_EQ(1, node_map.count("b"));
59 EXPECT_EQ(1, node_map.count("add"));
60 EXPECT_EQ(1, node_map.count("placeholder"));
61 EXPECT_EQ(1, node_map.count("output"));
62 EXPECT_EQ(0, node_map.count("no_such_node"));
63 }
64
TestMapNodesToOutputs()65 void TestMapNodesToOutputs() {
66 auto root = tensorflow::Scope::NewRootScope();
67 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
68
69 const int width = 100;
70
71 Tensor a_data(DT_FLOAT, TensorShape({width}));
72 test::FillIota<float>(&a_data, 1.0f);
73 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
74
75 Tensor b_data(DT_FLOAT, TensorShape({width}));
76 test::FillIota<float>(&b_data, 1.0f);
77 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
78
79 Output add = Add(root.WithOpName("add"), a_const, b_const);
80
81 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
82
83 Output mul = Mul(root.WithOpName("output"), add, placeholder);
84
85 GraphDef graph_def;
86 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
87
88 std::map<string, std::vector<const NodeDef*>> outputs_map;
89 MapNodesToOutputs(graph_def, &outputs_map);
90
91 EXPECT_EQ(1, outputs_map.count("a"));
92 EXPECT_EQ(1, outputs_map["a"].size());
93 EXPECT_EQ("add", outputs_map["a"][0]->name());
94
95 EXPECT_EQ(1, outputs_map.count("b"));
96 EXPECT_EQ(1, outputs_map["b"].size());
97 EXPECT_EQ("add", outputs_map["b"][0]->name());
98
99 EXPECT_EQ(1, outputs_map.count("add"));
100 EXPECT_EQ(1, outputs_map["add"].size());
101 EXPECT_EQ("output", outputs_map["add"][0]->name());
102
103 EXPECT_EQ(1, outputs_map.count("placeholder"));
104 EXPECT_EQ(1, outputs_map["placeholder"].size());
105 EXPECT_EQ("output", outputs_map["placeholder"][0]->name());
106
107 EXPECT_EQ(0, outputs_map.count("output"));
108 EXPECT_EQ(0, outputs_map.count("no_such_node"));
109 }
110
TestNodeNamePartsFromInput()111 void TestNodeNamePartsFromInput() {
112 string prefix;
113 string node_name;
114 string suffix;
115
116 NodeNamePartsFromInput("some_node_name", &prefix, &node_name, &suffix);
117 EXPECT_EQ("", prefix);
118 EXPECT_EQ("some_node_name", node_name);
119 EXPECT_EQ("", suffix);
120
121 NodeNamePartsFromInput("some_node_name/with/slashes", &prefix, &node_name,
122 &suffix);
123 EXPECT_EQ("", prefix);
124 EXPECT_EQ("some_node_name/with/slashes", node_name);
125 EXPECT_EQ("", suffix);
126
127 NodeNamePartsFromInput("some_node_name:0", &prefix, &node_name, &suffix);
128 EXPECT_EQ("", prefix);
129 EXPECT_EQ("some_node_name", node_name);
130 EXPECT_EQ(":0", suffix);
131
132 NodeNamePartsFromInput("^some_node_name", &prefix, &node_name, &suffix);
133 EXPECT_EQ("^", prefix);
134 EXPECT_EQ("some_node_name", node_name);
135 EXPECT_EQ("", suffix);
136
137 NodeNamePartsFromInput("^some_node_name:99", &prefix, &node_name, &suffix);
138 EXPECT_EQ("^", prefix);
139 EXPECT_EQ("some_node_name", node_name);
140 EXPECT_EQ(":99", suffix);
141 }
142
TestNodeNameFromInput()143 void TestNodeNameFromInput() {
144 EXPECT_EQ("node_name", NodeNameFromInput("node_name"));
145 EXPECT_EQ("node_name", NodeNameFromInput("node_name:0"));
146 EXPECT_EQ("node_name", NodeNameFromInput("^node_name"));
147 EXPECT_EQ("node_name", NodeNameFromInput("^node_name:42"));
148 }
149
TestCanonicalInputName()150 void TestCanonicalInputName() {
151 EXPECT_EQ("node_name:0", CanonicalInputName("node_name"));
152 EXPECT_EQ("node_name:0", CanonicalInputName("node_name:0"));
153 EXPECT_EQ("^node_name:0", CanonicalInputName("^node_name"));
154 EXPECT_EQ("^node_name:42", CanonicalInputName("^node_name:42"));
155 }
156
TestAddNodeInput()157 void TestAddNodeInput() {
158 NodeDef node;
159 AddNodeInput("foo", &node);
160 EXPECT_EQ("foo", node.input(0));
161 }
162
TestCopyNodeAttr()163 void TestCopyNodeAttr() {
164 NodeDef node;
165 auto mutable_attr = node.mutable_attr();
166 (*mutable_attr)["foo"].set_i(3);
167
168 NodeDef copied_node;
169 CopyNodeAttr(node, "foo", "bar", &copied_node);
170 EXPECT_EQ(3, copied_node.attr().at("bar").i());
171 }
172
TestSetNodeAttr()173 void TestSetNodeAttr() {
174 NodeDef node;
175 int32_t value_i = 32;
176 SetNodeAttr("foo", value_i, &node);
177 EXPECT_EQ(32, node.attr().at("foo").i());
178 string value_s = "some_value";
179 SetNodeAttr("bar", value_s, &node);
180 EXPECT_EQ("some_value", node.attr().at("bar").s());
181 }
182
TestSetNodeTensorAttr()183 void TestSetNodeTensorAttr() {
184 NodeDef node;
185 SetNodeTensorAttr<int32>("foo", {3, 1}, {1, 2, 3}, &node);
186 TensorProto tensor_proto = node.attr().at("foo").tensor();
187 Tensor tensor;
188 CHECK(tensor.FromProto(tensor_proto));
189 EXPECT_EQ(DT_INT32, tensor.dtype());
190 EXPECT_EQ(3, tensor.shape().dim_size(0));
191 EXPECT_EQ(1, tensor.shape().dim_size(1));
192 EXPECT_EQ(1, tensor.flat<int32>()(0));
193 EXPECT_EQ(2, tensor.flat<int32>()(1));
194 EXPECT_EQ(3, tensor.flat<int32>()(2));
195 }
196
TestSetNodeTensorAttrWithTensor()197 void TestSetNodeTensorAttrWithTensor() {
198 NodeDef node;
199 Tensor input_tensor(DT_INT32, {4, 5});
200 test::FillIota<int32>(&input_tensor, 1);
201 SetNodeTensorAttr<int32>("foo", input_tensor, &node);
202 TensorProto tensor_proto = node.attr().at("foo").tensor();
203 Tensor tensor;
204 CHECK(tensor.FromProto(tensor_proto));
205 test::ExpectTensorEqual<int32>(input_tensor, tensor);
206 }
207
TestGetNodeTensorAttr()208 void TestGetNodeTensorAttr() {
209 NodeDef node;
210 Tensor input_tensor(DT_INT32, {4, 5});
211 test::FillIota<int32>(&input_tensor, 1);
212 TensorProto tensor_proto;
213 input_tensor.AsProtoTensorContent(&tensor_proto);
214 SetNodeAttr("foo", tensor_proto, &node);
215 Tensor result = GetNodeTensorAttr(node, "foo");
216 test::ExpectTensorEqual<int32>(input_tensor, result);
217 }
218
TestFilterGraphDef()219 void TestFilterGraphDef() {
220 auto root = tensorflow::Scope::NewRootScope();
221 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
222
223 const int width = 100;
224
225 Tensor a_data(DT_FLOAT, TensorShape({width}));
226 test::FillIota<float>(&a_data, 1.0f);
227 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
228
229 Tensor b_data(DT_FLOAT, TensorShape({width}));
230 test::FillIota<float>(&b_data, 1.0f);
231 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
232
233 Output add = Add(root.WithOpName("add"), a_const, b_const);
234
235 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
236
237 Output mul = Mul(root.WithOpName("output"), add, placeholder);
238
239 Output remove_me = Add(root.WithOpName("remove_me"), mul, add);
240
241 GraphDef graph_def;
242 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
243
244 GraphDef result_graph_def;
245 FilterGraphDef(
246 graph_def,
247 [](const NodeDef& node) { return (node.name() != "remove_me"); },
248 &result_graph_def);
249
250 std::map<string, const NodeDef*> node_map;
251 MapNamesToNodes(result_graph_def, &node_map);
252 EXPECT_EQ(1, node_map.count("a"));
253 EXPECT_EQ(1, node_map.count("b"));
254 EXPECT_EQ(1, node_map.count("add"));
255 EXPECT_EQ(1, node_map.count("placeholder"));
256 EXPECT_EQ(1, node_map.count("output"));
257 EXPECT_EQ(0, node_map.count("remove_me"));
258 }
259
TestRemoveAttributes()260 void TestRemoveAttributes() {
261 auto root = tensorflow::Scope::NewRootScope();
262 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
263
264 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
265
266 GraphDef graph_def;
267 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
268
269 GraphDef result_graph_def;
270 RemoveAttributes(graph_def, {"dtype"}, &result_graph_def);
271
272 std::map<string, const NodeDef*> node_map;
273 MapNamesToNodes(result_graph_def, &node_map);
274 const NodeDef* removed_placeholder = node_map["placeholder"];
275 EXPECT_EQ(nullptr,
276 tensorflow::AttrSlice(*removed_placeholder).Find("dtype"));
277 }
278
TestGetOpTypeMatches()279 void TestGetOpTypeMatches() {
280 auto root = tensorflow::Scope::NewRootScope();
281 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
282
283 const int width = 100;
284
285 Tensor a_data(DT_FLOAT, TensorShape({width}));
286 test::FillIota<float>(&a_data, 1.0f);
287 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
288
289 Tensor b_data(DT_FLOAT, TensorShape({width}));
290 test::FillIota<float>(&b_data, 1.0f);
291 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
292
293 Output add = Add(root.WithOpName("add"), a_const, b_const);
294
295 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
296
297 Output mul = Mul(root.WithOpName("output"), add, placeholder);
298
299 GraphDef graph_def;
300 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
301
302 GraphMatcher matcher(graph_def);
303
304 std::vector<NodeMatch> const_matches;
305 TF_ASSERT_OK(matcher.GetOpTypeMatches({"Const"}, &const_matches));
306 EXPECT_EQ(2, const_matches.size());
307 for (const NodeMatch& match : const_matches) {
308 EXPECT_EQ("Const", match.node.op());
309 EXPECT_TRUE(("a" == match.node.name()) || ("b" == match.node.name()))
310 << "match.node.name()=" << match.node.name();
311 }
312
313 std::vector<NodeMatch> add_matches;
314 TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add"}, &add_matches));
315 EXPECT_EQ(1, add_matches.size());
316 EXPECT_EQ("Add", add_matches[0].node.op());
317 EXPECT_EQ("add", add_matches[0].node.name());
318
319 std::vector<NodeMatch> add_child_matches;
320 TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add", {{"Const"}, {"Const"}}},
321 &add_child_matches));
322 EXPECT_EQ(1, add_child_matches.size());
323 EXPECT_EQ("Add", add_child_matches[0].node.op());
324 EXPECT_EQ("add", add_child_matches[0].node.name());
325 EXPECT_EQ(2, add_child_matches[0].inputs.size());
326 for (const NodeMatch& match : add_child_matches[0].inputs) {
327 EXPECT_EQ("Const", match.node.op());
328 EXPECT_TRUE(("a" == match.node.name()) || ("b" == match.node.name()))
329 << "match.node.name()=" << match.node.name();
330 }
331
332 std::vector<NodeMatch> no_such_matches;
333 TF_ASSERT_OK(matcher.GetOpTypeMatches({"NoSuch"}, &no_such_matches));
334 EXPECT_EQ(0, no_such_matches.size());
335
336 std::vector<NodeMatch> all_matches;
337 TF_ASSERT_OK(matcher.GetOpTypeMatches(
338 {"Mul", {{"Add", {{"Const"}, {"Const"}}}, {"Placeholder"}}},
339 &all_matches));
340 EXPECT_EQ(1, all_matches.size());
341 EXPECT_EQ("Mul", all_matches[0].node.op());
342 EXPECT_EQ("output", all_matches[0].node.name());
343 EXPECT_EQ(2, all_matches[0].inputs.size());
344 EXPECT_EQ("Add", all_matches[0].inputs[0].node.op());
345 EXPECT_EQ("add", all_matches[0].inputs[0].node.name());
346 EXPECT_EQ(2, all_matches[0].inputs[0].inputs.size());
347 EXPECT_EQ("Const", all_matches[0].inputs[0].inputs[0].node.op());
348 EXPECT_EQ("a", all_matches[0].inputs[0].inputs[0].node.name());
349 EXPECT_EQ(0, all_matches[0].inputs[0].inputs[0].inputs.size());
350 EXPECT_EQ("Const", all_matches[0].inputs[0].inputs[1].node.op());
351 EXPECT_EQ("b", all_matches[0].inputs[0].inputs[1].node.name());
352 EXPECT_EQ(0, all_matches[0].inputs[0].inputs[1].inputs.size());
353 EXPECT_EQ("Placeholder", all_matches[0].inputs[1].node.op());
354 EXPECT_EQ("placeholder", all_matches[0].inputs[1].node.name());
355 EXPECT_EQ(0, all_matches[0].inputs[1].inputs.size());
356
357 std::vector<NodeMatch> wildcard_matches;
358 TF_ASSERT_OK(
359 matcher.GetOpTypeMatches({"*", {{"*"}, {"*"}}}, &wildcard_matches));
360 EXPECT_EQ(1, wildcard_matches.size());
361 EXPECT_EQ("Add", wildcard_matches[0].node.op());
362 EXPECT_EQ("Const", wildcard_matches[0].inputs[0].node.op());
363 EXPECT_EQ("a", wildcard_matches[0].inputs[0].node.name());
364 EXPECT_EQ("Const", wildcard_matches[0].inputs[1].node.op());
365 EXPECT_EQ("b", wildcard_matches[0].inputs[1].node.name());
366
367 std::vector<NodeMatch> or_matches;
368 TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add|Mul"}, &or_matches));
369 EXPECT_EQ(2, or_matches.size());
370 EXPECT_EQ("Add", or_matches[0].node.op());
371 EXPECT_EQ("add", or_matches[0].node.name());
372 EXPECT_EQ("Mul", or_matches[1].node.op());
373 EXPECT_EQ("output", or_matches[1].node.name());
374 }
375
TestGetOpTypeMatchesDAG()376 void TestGetOpTypeMatchesDAG() {
377 auto root = tensorflow::Scope::NewRootScope();
378 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
379
380 const int width = 100;
381
382 Tensor a_data(DT_FLOAT, TensorShape({width}));
383 test::FillIota<float>(&a_data, 1.0f);
384 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
385
386 Output add = Add(root.WithOpName("add"), a_const, a_const);
387
388 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
389
390 Output mul = Mul(root.WithOpName("output"), add, placeholder);
391
392 GraphDef graph_def;
393 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
394
395 GraphMatcher matcher(graph_def);
396
397 std::vector<NodeMatch> add_matches;
398 TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add", {{"Const"}, {"Const"}}},
399 &add_matches));
400 EXPECT_EQ(1, add_matches.size());
401 EXPECT_EQ("Add", add_matches[0].node.op());
402 EXPECT_EQ("add", add_matches[0].node.name());
403 EXPECT_EQ("Const", add_matches[0].inputs[0].node.op());
404 EXPECT_EQ("a", add_matches[0].inputs[0].node.name());
405 EXPECT_EQ("Const", add_matches[0].inputs[1].node.op());
406 EXPECT_EQ("a", add_matches[0].inputs[1].node.name());
407 }
408
TestReplaceMatchingOpTypes()409 void TestReplaceMatchingOpTypes() {
410 auto root = tensorflow::Scope::NewRootScope();
411 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
412
413 const int width = 10;
414
415 Tensor a_data(DT_FLOAT, TensorShape({width}));
416 test::FillIota<float>(&a_data, 1.0f);
417 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
418
419 Tensor b_data(DT_FLOAT, TensorShape({width}));
420 test::FillIota<float>(&b_data, 1.0f);
421 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
422
423 Output add = Add(root.WithOpName("add"), a_const, b_const);
424
425 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
426
427 Output mul = Mul(root.WithOpName("output"), add, placeholder);
428
429 GraphDef graph_def;
430 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
431
432 GraphDef replaced_graph_def;
433 TF_ASSERT_OK(ReplaceMatchingOpTypes(
434 graph_def, {"*"},
435 [](const NodeMatch& match, const std::set<string>& input_nodes,
436 const std::set<string>& output_nodes,
437 std::vector<NodeDef>* new_nodes) {
438 NodeDef original_copy;
439 original_copy = match.node;
440 const string original_name = match.node.name();
441 original_copy.set_name(original_name + "_before_identity");
442 new_nodes->push_back(original_copy);
443
444 NodeDef identity_node;
445 identity_node.set_op("Identity");
446 identity_node.set_name(original_name);
447 *(identity_node.mutable_input()->Add()) = original_copy.name();
448 new_nodes->push_back(identity_node);
449
450 return OkStatus();
451 },
452 {}, &replaced_graph_def));
453
454 EXPECT_EQ(10, replaced_graph_def.node_size());
455 for (const NodeDef& node : replaced_graph_def.node()) {
456 if (node.name() == "output") {
457 EXPECT_EQ("Identity", node.op());
458 EXPECT_EQ("output_before_identity", node.input(0));
459 } else if (node.name() == "output_before_identity") {
460 EXPECT_EQ("Mul", node.op());
461 EXPECT_EQ("add", node.input(0));
462 EXPECT_EQ("placeholder", node.input(1));
463 } else if (node.name() == "placeholder") {
464 EXPECT_EQ("Identity", node.op());
465 EXPECT_EQ("placeholder_before_identity", node.input(0));
466 } else if (node.name() == "placeholder_before_identity") {
467 EXPECT_EQ("Placeholder", node.op());
468 } else if (node.name() == "add") {
469 EXPECT_EQ("Identity", node.op());
470 EXPECT_EQ("add_before_identity", node.input(0));
471 } else if (node.name() == "add_before_identity") {
472 EXPECT_EQ("Add", node.op());
473 EXPECT_EQ("a", node.input(0));
474 EXPECT_EQ("b", node.input(1));
475 } else if (node.name() == "a") {
476 EXPECT_EQ("Identity", node.op());
477 EXPECT_EQ("a_before_identity", node.input(0));
478 } else if (node.name() == "a_before_identity") {
479 EXPECT_EQ("Const", node.op());
480 } else if (node.name() == "b") {
481 EXPECT_EQ("Identity", node.op());
482 EXPECT_EQ("b_before_identity", node.input(0));
483 } else if (node.name() == "b_before_identity") {
484 EXPECT_EQ("Const", node.op());
485 } else {
486 EXPECT_EQ(true, false) << "Unexpected node name found: " << node.name();
487 }
488 }
489 }
490
TestMatchedNodesAsArray()491 void TestMatchedNodesAsArray() {
492 NodeMatch fourth;
493 fourth.node.set_name("fourth");
494
495 NodeMatch second;
496 second.node.set_name("second");
497 second.inputs.push_back(fourth);
498
499 NodeMatch third;
500 third.node.set_name("third");
501 third.inputs.push_back(fourth);
502
503 NodeMatch first;
504 first.node.set_name("first");
505 first.inputs.push_back(second);
506 first.inputs.push_back(third);
507
508 std::vector<NodeDef> result;
509 MatchedNodesAsArray(first, &result);
510
511 EXPECT_EQ(4, result.size());
512 EXPECT_EQ("first", result[0].name());
513 EXPECT_EQ("second", result[1].name());
514 EXPECT_EQ("third", result[2].name());
515 EXPECT_EQ("fourth", result[3].name());
516 }
517
TestRenameNodeInputs()518 void TestRenameNodeInputs() {
519 auto root = tensorflow::Scope::NewRootScope();
520 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
521
522 const int width = 10;
523
524 Tensor a_data(DT_FLOAT, TensorShape({width}));
525 test::FillIota<float>(&a_data, 1.0f);
526 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
527
528 Tensor b_data(DT_FLOAT, TensorShape({width}));
529 test::FillIota<float>(&b_data, 1.0f);
530 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
531
532 Output add = Add(root.WithOpName("add"), a_const, a_const);
533
534 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
535
536 Output mul = Mul(root.WithOpName("output"), add, placeholder);
537
538 GraphDef graph_def;
539 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
540
541 GraphDef renamed_graph_def;
542 TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}},
543 std::unordered_set<string>(),
544 &renamed_graph_def));
545
546 std::map<string, const NodeDef*> node_map;
547 MapNamesToNodes(renamed_graph_def, &node_map);
548 EXPECT_EQ("b", node_map.at("add")->input(0));
549 EXPECT_EQ("b", node_map.at("add")->input(1));
550 }
551
TestRenameNodeInputsWithRedirects()552 void TestRenameNodeInputsWithRedirects() {
553 auto root = tensorflow::Scope::NewRootScope();
554 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
555
556 const int width = 10;
557
558 Tensor a_data(DT_FLOAT, TensorShape({width}));
559 test::FillIota<float>(&a_data, 1.0f);
560 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
561
562 Tensor b_data(DT_FLOAT, TensorShape({width}));
563 test::FillIota<float>(&b_data, 1.0f);
564 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
565
566 Tensor c_data(DT_FLOAT, TensorShape({width}));
567 test::FillIota<float>(&c_data, 1.0f);
568 Output c_const = Const(root.WithOpName("c"), Input::Initializer(c_data));
569
570 Output add = Add(root.WithOpName("add"), a_const, b_const);
571
572 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
573
574 Output mul = Mul(root.WithOpName("output"), add, placeholder);
575
576 GraphDef graph_def;
577 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
578
579 GraphDef renamed_graph_def;
580 TF_ASSERT_OK(RenameNodeInputs(
581 graph_def, {{"a", "f"}, {"f", "e"}, {"e", "d"}, {"d", "c"}},
582 std::unordered_set<string>(), &renamed_graph_def));
583
584 std::map<string, const NodeDef*> node_map;
585 MapNamesToNodes(renamed_graph_def, &node_map);
586 EXPECT_EQ("c", node_map.at("add")->input(0));
587 EXPECT_EQ("b", node_map.at("add")->input(1));
588 }
589
TestRenameNodeInputsWithCycle()590 void TestRenameNodeInputsWithCycle() {
591 auto root = tensorflow::Scope::NewRootScope();
592 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
593
594 const int width = 10;
595
596 Tensor a_data(DT_FLOAT, TensorShape({width}));
597 test::FillIota<float>(&a_data, 1.0f);
598 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
599
600 Tensor b_data(DT_FLOAT, TensorShape({width}));
601 test::FillIota<float>(&b_data, 1.0f);
602 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
603
604 Tensor c_data(DT_FLOAT, TensorShape({width}));
605 test::FillIota<float>(&c_data, 1.0f);
606 Output c_const = Const(root.WithOpName("c"), Input::Initializer(c_data));
607
608 Output add = Add(root.WithOpName("add"), a_const, b_const);
609
610 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
611
612 Output mul = Mul(root.WithOpName("output"), add, placeholder);
613
614 GraphDef graph_def;
615 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
616
617 GraphDef renamed_graph_def;
618 Status rename_status =
619 RenameNodeInputs(graph_def, {{"a", "d"}, {"d", "a"}},
620 std::unordered_set<string>(), &renamed_graph_def);
621 EXPECT_FALSE(rename_status.ok());
622 }
623
TestRenameNodeInputsWithWildcard()624 void TestRenameNodeInputsWithWildcard() {
625 auto root = tensorflow::Scope::DisabledShapeInferenceScope();
626 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
627
628 const int width = 10;
629
630 Tensor a_data(DT_FLOAT, TensorShape({width}));
631 test::FillIota<float>(&a_data, 1.0f);
632 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
633
634 QuantizeV2 quantize_a(root.WithOpName("quantize_a"), a_const, a_const,
635 a_const, DT_QUINT8,
636 QuantizeV2::Attrs().Mode("MIN_FIRST"));
637
638 Tensor b_data(DT_FLOAT, TensorShape({width}));
639 test::FillIota<float>(&b_data, 1.0f);
640 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
641
642 QuantizeV2 quantize_b(root.WithOpName("quantize_b"), b_const, b_const,
643 b_const, DT_QUINT8,
644 QuantizeV2::Attrs().Mode("MIN_FIRST"));
645
646 Output add = Add(root.WithOpName("add"), quantize_a.output_min,
647 quantize_a.output_max);
648
649 GraphDef graph_def;
650 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
651
652 GraphDef renamed_graph_def;
653 TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"quantize_a:*", "quantize_b"}},
654 std::unordered_set<string>(),
655 &renamed_graph_def));
656
657 std::map<string, const NodeDef*> node_map;
658 MapNamesToNodes(renamed_graph_def, &node_map);
659 EXPECT_EQ("quantize_b:1", node_map.at("add")->input(0));
660 EXPECT_EQ("quantize_b:2", node_map.at("add")->input(1));
661 }
662
TestRenameNodeInputsWithIgnores()663 void TestRenameNodeInputsWithIgnores() {
664 auto root = tensorflow::Scope::NewRootScope();
665 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
666
667 const int width = 10;
668
669 Tensor a_data(DT_FLOAT, TensorShape({width}));
670 test::FillIota<float>(&a_data, 1.0f);
671 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
672
673 Tensor b_data(DT_FLOAT, TensorShape({width}));
674 test::FillIota<float>(&b_data, 1.0f);
675 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
676
677 Output add = Add(root.WithOpName("add"), a_const, a_const);
678
679 Output add2 = Add(root.WithOpName("add2"), a_const, a_const);
680
681 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
682
683 Output mul = Mul(root.WithOpName("mul"), add, placeholder);
684
685 Output mul2 = Mul(root.WithOpName("output"), mul, add2);
686
687 GraphDef graph_def;
688 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
689
690 GraphDef renamed_graph_def;
691 TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}}, {"add2"},
692 &renamed_graph_def));
693
694 std::map<string, const NodeDef*> node_map;
695 MapNamesToNodes(renamed_graph_def, &node_map);
696 EXPECT_EQ("b", node_map.at("add")->input(0));
697 EXPECT_EQ("b", node_map.at("add")->input(1));
698 EXPECT_EQ("a", node_map.at("add2")->input(0));
699 EXPECT_EQ("a", node_map.at("add2")->input(1));
700 }
701
TestFindInvalidInputs()702 void TestFindInvalidInputs() {
703 GraphDef graph_def;
704
705 NodeDef* mul_node = graph_def.mutable_node()->Add();
706 mul_node->set_op("Mul");
707 mul_node->set_name("mul_node");
708 *(mul_node->mutable_input()->Add()) = "add_node1";
709 *(mul_node->mutable_input()->Add()) = "add_node2:0";
710 *(mul_node->mutable_input()->Add()) = "^const_node1:0";
711
712 NodeDef* add_node1 = graph_def.mutable_node()->Add();
713 add_node1->set_op("Add");
714 add_node1->set_name("add_node1");
715 *(add_node1->mutable_input()->Add()) = "missing_input1";
716 *(add_node1->mutable_input()->Add()) = "const_node1:0";
717 *(add_node1->mutable_input()->Add()) = "missing_input2";
718
719 NodeDef* add_node2 = graph_def.mutable_node()->Add();
720 add_node2->set_op("Add");
721 add_node2->set_name("add_node2");
722 *(add_node2->mutable_input()->Add()) = "missing_input3";
723 *(add_node2->mutable_input()->Add()) = "const_node1:0";
724 *(add_node2->mutable_input()->Add()) = "^const_node2";
725
726 NodeDef* const_node1 = graph_def.mutable_node()->Add();
727 const_node1->set_op("Const");
728 const_node1->set_name("const_node1");
729
730 NodeDef* const_node2 = graph_def.mutable_node()->Add();
731 const_node2->set_op("Const");
732 const_node2->set_name("const_node2");
733
734 std::vector<std::pair<string, string>> invalid_inputs;
735 FindInvalidInputs(graph_def, &invalid_inputs);
736 EXPECT_EQ(3, invalid_inputs.size());
737 for (const std::pair<string, string>& invalid_input : invalid_inputs) {
738 EXPECT_TRUE((invalid_input.first == "add_node1") ||
739 (invalid_input.first == "add_node2"));
740 if (invalid_input.first == "add_node1") {
741 EXPECT_TRUE((invalid_input.second == "missing_input1") ||
742 (invalid_input.second == "missing_input2"))
743 << invalid_input.second;
744 } else if (invalid_input.first == "add_node2") {
745 EXPECT_EQ("missing_input3", invalid_input.second);
746 }
747 }
748 }
749
TestIsGraphValid()750 void TestIsGraphValid() {
751 GraphDef invalid_graph_def;
752
753 NodeDef* mul_node = invalid_graph_def.mutable_node()->Add();
754 mul_node->set_op("Mul");
755 mul_node->set_name("mul_node");
756 *(mul_node->mutable_input()->Add()) = "add_node1";
757 *(mul_node->mutable_input()->Add()) = "add_node2:0";
758 *(mul_node->mutable_input()->Add()) = "^const_node1:0";
759
760 NodeDef* add_node1 = invalid_graph_def.mutable_node()->Add();
761 add_node1->set_op("Add");
762 add_node1->set_name("add_node1");
763 *(add_node1->mutable_input()->Add()) = "missing_input1";
764 *(add_node1->mutable_input()->Add()) = "const_node1:0";
765 *(add_node1->mutable_input()->Add()) = "missing_input2";
766
767 NodeDef* add_node2 = invalid_graph_def.mutable_node()->Add();
768 add_node2->set_op("Add");
769 add_node2->set_name("add_node2");
770 *(add_node2->mutable_input()->Add()) = "missing_input3";
771 *(add_node2->mutable_input()->Add()) = "const_node1:0";
772 *(add_node2->mutable_input()->Add()) = "^const_node2";
773
774 NodeDef* const_node1 = invalid_graph_def.mutable_node()->Add();
775 const_node1->set_op("Const");
776 const_node1->set_name("const_node1");
777
778 NodeDef* const_node2 = invalid_graph_def.mutable_node()->Add();
779 const_node2->set_op("Const");
780 const_node2->set_name("const_node2");
781
782 EXPECT_FALSE(IsGraphValid(invalid_graph_def).ok());
783
784 GraphDef valid_graph_def;
785
786 NodeDef* const_node3 = valid_graph_def.mutable_node()->Add();
787 const_node3->set_op("Const");
788 const_node3->set_name("const_node2");
789
790 EXPECT_TRUE(IsGraphValid(valid_graph_def).ok());
791 }
792
TestGetInOutTypes()793 void TestGetInOutTypes() {
794 auto root = tensorflow::Scope::NewRootScope();
795 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
796
797 const int width = 20;
798
799 Tensor float_data(DT_FLOAT, TensorShape({width}));
800 test::FillIota<float>(&float_data, 1.0f);
801 Output float_const =
802 Const(root.WithOpName("float_const"), Input::Initializer(float_data));
803
804 Tensor int_data(DT_INT32, TensorShape({width}));
805 test::FillIota<int32>(&int_data, 1);
806 Output int_const =
807 Const(root.WithOpName("int_const"), Input::Initializer(int_data));
808
809 Output float_relu = Relu(root.WithOpName("float_relu"), float_const);
810
811 Output int_relu = Relu(root.WithOpName("int_relu"), int_const);
812
813 GraphDef graph_def;
814 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
815
816 std::map<string, const NodeDef*> node_map;
817 MapNamesToNodes(graph_def, &node_map);
818
819 const NodeDef* float_const_def = node_map.at("float_const");
820 DataTypeVector float_const_inputs;
821 DataTypeVector float_const_outputs;
822 TF_EXPECT_OK(GetInOutTypes(*float_const_def, &float_const_inputs,
823 &float_const_outputs));
824 ASSERT_EQ(0, float_const_inputs.size());
825 ASSERT_EQ(1, float_const_outputs.size());
826 EXPECT_EQ(DT_FLOAT, float_const_outputs[0]);
827
828 const NodeDef* int_const_def = node_map.at("int_const");
829 DataTypeVector int_const_inputs;
830 DataTypeVector int_const_outputs;
831 TF_EXPECT_OK(
832 GetInOutTypes(*int_const_def, &int_const_inputs, &int_const_outputs));
833 ASSERT_EQ(0, int_const_inputs.size());
834 ASSERT_EQ(1, int_const_outputs.size());
835 EXPECT_EQ(DT_INT32, int_const_outputs[0]);
836
837 const NodeDef* float_relu_def = node_map.at("float_relu");
838 DataTypeVector float_relu_inputs;
839 DataTypeVector float_relu_outputs;
840 TF_EXPECT_OK(GetInOutTypes(*float_relu_def, &float_relu_inputs,
841 &float_relu_outputs));
842 ASSERT_EQ(1, float_relu_inputs.size());
843 EXPECT_EQ(DT_FLOAT, float_relu_inputs[0]);
844 ASSERT_EQ(1, float_relu_outputs.size());
845 EXPECT_EQ(DT_FLOAT, float_relu_outputs[0]);
846
847 const NodeDef* int_relu_def = node_map.at("int_relu");
848 DataTypeVector int_relu_inputs;
849 DataTypeVector int_relu_outputs;
850 TF_EXPECT_OK(
851 GetInOutTypes(*int_relu_def, &int_relu_inputs, &int_relu_outputs));
852 ASSERT_EQ(1, int_relu_inputs.size());
853 EXPECT_EQ(DT_INT32, int_relu_inputs[0]);
854 ASSERT_EQ(1, int_relu_outputs.size());
855 EXPECT_EQ(DT_INT32, int_relu_outputs[0]);
856 }
857
TestCopyOriginalMatch()858 void TestCopyOriginalMatch() {
859 NodeDef a;
860 a.set_op("Relu");
861 a.set_name("a");
862 AddNodeInput("b", &a);
863
864 NodeDef b;
865 b.set_op("Const");
866 b.set_name("b");
867
868 NodeMatch b_match;
869 b_match.node = b;
870
871 NodeMatch a_match;
872 a_match.node = a;
873 a_match.inputs.push_back(b_match);
874
875 std::vector<NodeDef> new_nodes;
876 CopyOriginalMatch(a_match, &new_nodes);
877 EXPECT_EQ(2, new_nodes.size());
878 EXPECT_EQ("a", new_nodes[0].name());
879 EXPECT_EQ("Relu", new_nodes[0].op());
880 EXPECT_EQ("b", new_nodes[1].name());
881 EXPECT_EQ("Const", new_nodes[1].op());
882 }
883
TestHashNodeDef()884 void TestHashNodeDef() {
885 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
886
887 const int width = 10;
888
889 auto a_root = tensorflow::Scope::NewRootScope();
890 Tensor a_data(DT_FLOAT, TensorShape({width}));
891 test::FillIota<float>(&a_data, 1.0f);
892 Output a_const = Const(a_root.WithOpName("a"), Input::Initializer(a_data));
893 GraphDef a_graph_def;
894 TF_ASSERT_OK(a_root.ToGraphDef(&a_graph_def));
895 const NodeDef& a_node_def = a_graph_def.node(0);
896
897 auto b_root = tensorflow::Scope::NewRootScope();
898 Tensor b_data(DT_FLOAT, TensorShape({width}));
899 test::FillIota<float>(&b_data, 1.0f);
900 Output b_const = Const(b_root.WithOpName("a"), Input::Initializer(b_data));
901 GraphDef b_graph_def;
902 TF_ASSERT_OK(b_root.ToGraphDef(&b_graph_def));
903 const NodeDef& b_node_def = b_graph_def.node(0);
904
905 auto c_root = tensorflow::Scope::NewRootScope();
906 Tensor c_data(DT_FLOAT, TensorShape({width}));
907 test::FillIota<float>(&c_data, 2.0f);
908 Output c_const = Const(c_root.WithOpName("a"), Input::Initializer(c_data));
909 GraphDef c_graph_def;
910 TF_ASSERT_OK(c_root.ToGraphDef(&c_graph_def));
911 const NodeDef& c_node_def = c_graph_def.node(0);
912
913 auto d_root = tensorflow::Scope::NewRootScope();
914 Tensor d_data(DT_FLOAT, TensorShape({width}));
915 test::FillIota<float>(&d_data, 1.0f);
916 Output d_const = Const(d_root.WithOpName("d"), Input::Initializer(d_data));
917 GraphDef d_graph_def;
918 TF_ASSERT_OK(d_root.ToGraphDef(&d_graph_def));
919 const NodeDef& d_node_def = d_graph_def.node(0);
920
921 auto e_root = tensorflow::Scope::NewRootScope();
922 Tensor e_data(DT_INT32, TensorShape({width}));
923 test::FillIota<int32>(&e_data, 1);
924 Output e_const = Const(e_root.WithOpName("a"), Input::Initializer(e_data));
925 GraphDef e_graph_def;
926 TF_ASSERT_OK(e_root.ToGraphDef(&e_graph_def));
927 const NodeDef& e_node_def = e_graph_def.node(0);
928
929 auto f_root = tensorflow::Scope::NewRootScope();
930 Tensor f_data(DT_FLOAT, TensorShape({width - 1}));
931 test::FillIota<float>(&f_data, 1.0f);
932 Output f_const = Const(f_root.WithOpName("a"), Input::Initializer(f_data));
933 GraphDef f_graph_def;
934 TF_ASSERT_OK(f_root.ToGraphDef(&f_graph_def));
935 const NodeDef& f_node_def = f_graph_def.node(0);
936
937 auto g_root = tensorflow::Scope::NewRootScope();
938 Tensor g_data(DT_FLOAT, TensorShape({width}));
939 test::FillIota<float>(&g_data, 1);
940 Output g_const = Const(g_root.WithOpName("a").WithDevice("some_device"),
941 Input::Initializer(g_data));
942 GraphDef g_graph_def;
943 TF_ASSERT_OK(g_root.ToGraphDef(&g_graph_def));
944 const NodeDef& g_node_def = g_graph_def.node(0);
945
946 NodeDef relu1_node_def;
947 relu1_node_def.set_op("Relu");
948 relu1_node_def.set_name("a");
949 relu1_node_def.add_input("foo");
950
951 NodeDef relu2_node_def;
952 relu2_node_def.set_op("Relu");
953 relu2_node_def.set_name("a");
954 relu2_node_def.add_input("bar");
955
956 EXPECT_EQ(HashNodeDef(a_node_def), HashNodeDef(b_node_def));
957 EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(c_node_def));
958 EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(d_node_def));
959 EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(e_node_def));
960 EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(f_node_def));
961 EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(g_node_def));
962 EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(relu1_node_def));
963 EXPECT_NE(HashNodeDef(relu1_node_def), HashNodeDef(relu2_node_def));
964 }
965
TestCountParameters()966 void TestCountParameters() {
967 TransformFuncContext context;
968 context.params.insert({"foo", {"a", "b"}});
969 context.params.insert({"bar", {"c"}});
970 EXPECT_EQ(2, context.CountParameters("foo"));
971 EXPECT_EQ(1, context.CountParameters("bar"));
972 EXPECT_EQ(0, context.CountParameters("not_present"));
973 }
974
TestGetOneStringParameter()975 void TestGetOneStringParameter() {
976 TransformFuncContext context;
977 context.params.insert({"foo", {"a", "b"}});
978 context.params.insert({"bar", {"c"}});
979 string value;
980 TF_EXPECT_OK(context.GetOneStringParameter("bar", "d", &value));
981 EXPECT_EQ("c", value);
982 EXPECT_FALSE(context.GetOneStringParameter("foo", "d", &value).ok());
983 TF_EXPECT_OK(context.GetOneStringParameter("not_present", "d", &value));
984 EXPECT_EQ("d", value);
985 }
986
TestGetOneInt32Parameter()987 void TestGetOneInt32Parameter() {
988 TransformFuncContext context;
989 context.params.insert({"foo", {"10", "20"}});
990 context.params.insert({"bar", {"-23"}});
991 context.params.insert({"not_a_number", {"not_numerical"}});
992 context.params.insert({"float", {"-23.232323"}});
993 int32_t value;
994 TF_EXPECT_OK(context.GetOneInt32Parameter("bar", 0, &value));
995 EXPECT_EQ(-23, value);
996 EXPECT_FALSE(context.GetOneInt32Parameter("foo", 0, &value).ok());
997 TF_EXPECT_OK(context.GetOneInt32Parameter("not_present", 10, &value));
998 EXPECT_EQ(10, value);
999 EXPECT_FALSE(context.GetOneInt32Parameter("not_a_number", 0, &value).ok());
1000 EXPECT_FALSE(context.GetOneInt32Parameter("float", 0, &value).ok());
1001 }
1002
TestGetOneInt64Parameter()1003 void TestGetOneInt64Parameter() {
1004 TransformFuncContext context;
1005 context.params.insert({"foo", {"10", "20"}});
1006 context.params.insert({"bar", {"-23"}});
1007 context.params.insert({"not_a_number", {"not_numerical"}});
1008 context.params.insert({"float", {"-23.232323"}});
1009 int64_t value;
1010 TF_EXPECT_OK(context.GetOneInt64Parameter("bar", 0, &value));
1011 EXPECT_EQ(-23, value);
1012 EXPECT_FALSE(context.GetOneInt64Parameter("foo", 0, &value).ok());
1013 TF_EXPECT_OK(context.GetOneInt64Parameter("not_present", 10, &value));
1014 EXPECT_EQ(10, value);
1015 EXPECT_FALSE(context.GetOneInt64Parameter("not_a_number", 0, &value).ok());
1016 EXPECT_FALSE(context.GetOneInt64Parameter("float", 0, &value).ok());
1017 }
1018
TestGetOneFloatParameter()1019 void TestGetOneFloatParameter() {
1020 TransformFuncContext context;
1021 context.params.insert({"foo", {"10.0", "20.0"}});
1022 context.params.insert({"bar", {"-23.2323"}});
1023 context.params.insert({"not_a_number", {"not_numerical"}});
1024 float value;
1025 TF_EXPECT_OK(context.GetOneFloatParameter("bar", 0, &value));
1026 EXPECT_NEAR(-23.2323f, value, 1e-5f);
1027 EXPECT_FALSE(context.GetOneFloatParameter("foo", 0, &value).ok());
1028 TF_EXPECT_OK(context.GetOneFloatParameter("not_present", 10.5f, &value));
1029 EXPECT_NEAR(10.5f, value, 1e-5f);
1030 EXPECT_FALSE(context.GetOneFloatParameter("not_a_number", 0, &value).ok());
1031 }
1032
TestGetOneBoolParameter()1033 void TestGetOneBoolParameter() {
1034 TransformFuncContext context;
1035 context.params.insert({"foo", {"true", "false"}});
1036 context.params.insert({"true", {"true"}});
1037 context.params.insert({"false", {"false"}});
1038 context.params.insert({"one", {"1"}});
1039 context.params.insert({"zero", {"0"}});
1040 context.params.insert({"not_a_bool", {"not_boolean"}});
1041
1042 bool value;
1043 EXPECT_FALSE(context.GetOneBoolParameter("foo", 0, &value).ok());
1044
1045 value = false;
1046 TF_EXPECT_OK(context.GetOneBoolParameter("true", false, &value));
1047 EXPECT_TRUE(value);
1048
1049 value = true;
1050 TF_EXPECT_OK(context.GetOneBoolParameter("false", true, &value));
1051 EXPECT_FALSE(value);
1052
1053 value = false;
1054 TF_EXPECT_OK(context.GetOneBoolParameter("one", false, &value));
1055 EXPECT_TRUE(value);
1056
1057 value = true;
1058 TF_EXPECT_OK(context.GetOneBoolParameter("zero", true, &value));
1059 EXPECT_FALSE(value);
1060
1061 EXPECT_FALSE(context.GetOneBoolParameter("not_a_bool", false, &value).ok());
1062
1063 value = false;
1064 TF_EXPECT_OK(context.GetOneBoolParameter("not_present", true, &value));
1065 EXPECT_TRUE(value);
1066 }
1067 };
1068
// gtest registrations. Each TEST_F below defines one test case whose body
// only calls the TransformUtilsTest fixture method of the same name; the
// actual checks live in those methods.
TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); }

TEST_F(TransformUtilsTest, TestMapNodesToOutputs) { TestMapNodesToOutputs(); }

TEST_F(TransformUtilsTest, TestNodeNamePartsFromInput) {
  TestNodeNamePartsFromInput();
}

TEST_F(TransformUtilsTest, TestCanonicalInputName) { TestCanonicalInputName(); }

TEST_F(TransformUtilsTest, TestAddNodeInput) { TestAddNodeInput(); }

TEST_F(TransformUtilsTest, TestCopyNodeAttr) { TestCopyNodeAttr(); }

TEST_F(TransformUtilsTest, TestSetNodeAttr) { TestSetNodeAttr(); }

TEST_F(TransformUtilsTest, TestSetNodeTensorAttr) { TestSetNodeTensorAttr(); }

TEST_F(TransformUtilsTest, TestSetNodeTensorAttrWithTensor) {
  TestSetNodeTensorAttrWithTensor();
}

TEST_F(TransformUtilsTest, TestGetNodeTensorAttr) { TestGetNodeTensorAttr(); }

TEST_F(TransformUtilsTest, TestNodeNameFromInput) { TestNodeNameFromInput(); }

TEST_F(TransformUtilsTest, TestFilterGraphDef) { TestFilterGraphDef(); }

TEST_F(TransformUtilsTest, TestRemoveAttributes) { TestRemoveAttributes(); }

TEST_F(TransformUtilsTest, TestGetOpTypeMatches) { TestGetOpTypeMatches(); }

TEST_F(TransformUtilsTest, TestGetOpTypeMatchesDAG) {
  TestGetOpTypeMatchesDAG();
}

TEST_F(TransformUtilsTest, TestReplaceMatchingOpTypes) {
  TestReplaceMatchingOpTypes();
}

TEST_F(TransformUtilsTest, TestMatchedNodesAsArray) {
  TestMatchedNodesAsArray();
}

TEST_F(TransformUtilsTest, TestRenameNodeInputs) { TestRenameNodeInputs(); }

TEST_F(TransformUtilsTest, TestRenameNodeInputsWithRedirects) {
  TestRenameNodeInputsWithRedirects();
}

TEST_F(TransformUtilsTest, TestRenameNodeInputsWithCycle) {
  TestRenameNodeInputsWithCycle();
}

TEST_F(TransformUtilsTest, TestRenameNodeInputsWithWildcard) {
  TestRenameNodeInputsWithWildcard();
}

TEST_F(TransformUtilsTest, TestRenameNodeInputsWithIgnores) {
  TestRenameNodeInputsWithIgnores();
}

TEST_F(TransformUtilsTest, TestFindInvalidInputs) { TestFindInvalidInputs(); }

TEST_F(TransformUtilsTest, TestIsGraphValid) { TestIsGraphValid(); }

TEST_F(TransformUtilsTest, TestGetInOutTypes) { TestGetInOutTypes(); }

TEST_F(TransformUtilsTest, TestCopyOriginalMatch) { TestCopyOriginalMatch(); }

TEST_F(TransformUtilsTest, TestHashNodeDef) { TestHashNodeDef(); }

TEST_F(TransformUtilsTest, TestCountParameters) { TestCountParameters(); }

TEST_F(TransformUtilsTest, TestGetOneStringParameter) {
  TestGetOneStringParameter();
}

TEST_F(TransformUtilsTest, TestGetOneInt32Parameter) {
  TestGetOneInt32Parameter();
}

TEST_F(TransformUtilsTest, TestGetOneInt64Parameter) {
  TestGetOneInt64Parameter();
}

TEST_F(TransformUtilsTest, TestGetOneFloatParameter) {
  TestGetOneFloatParameter();
}

TEST_F(TransformUtilsTest, TestGetOneBoolParameter) {
  TestGetOneBoolParameter();
}
1162
1163 } // namespace graph_transforms
1164 } // namespace tensorflow
1165