1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/platform/test_benchmark.h"
26 #include "tensorflow/core/public/session.h"
27 #include "tensorflow/tools/graph_transforms/transform_utils.h"
28
29 namespace tensorflow {
30 namespace graph_transforms {
31
// Declare here, so we don't need a public header.
// The definitions live elsewhere in the graph_transforms library, so these
// signatures must match them exactly. Each transform reads input_graph_def,
// writes the rewritten graph into *output_graph_def, and returns a Status
// that the tests in this file check with TF_ASSERT_OK.

// Backports ConcatV2 nodes to the older Concat op.
Status BackportConcatV2Transform(const GraphDef& input_graph_def,
                                 const TransformFuncContext& context,
                                 GraphDef* output_graph_def);
// Backports TensorArray*V3 ops to their *V2 counterparts.
Status BackportTensorArrayV3Transform(const GraphDef& input_graph_def,
                                      const TransformFuncContext& context,
                                      GraphDef* output_graph_def);
39
40 class BackportConcatV2Test : public ::testing::Test {
41 protected:
TestBackportConcatV2()42 void TestBackportConcatV2() {
43 GraphDef graph_def;
44
45 NodeDef* mul_node1 = graph_def.add_node();
46 mul_node1->set_name("mul_node1");
47 mul_node1->set_op("Mul");
48 mul_node1->add_input("add_node2");
49 mul_node1->add_input("add_node3");
50
51 NodeDef* add_node2 = graph_def.add_node();
52 add_node2->set_name("add_node2");
53 add_node2->set_op("Add");
54 add_node2->add_input("const_node1");
55 add_node2->add_input("const_node2");
56
57 NodeDef* add_node3 = graph_def.add_node();
58 add_node3->set_name("add_node3");
59 add_node3->set_op("Add");
60 add_node3->add_input("const_node1");
61 add_node3->add_input("const_node3");
62
63 NodeDef* const_node1 = graph_def.add_node();
64 const_node1->set_name("const_node1");
65 const_node1->set_op("Const");
66
67 NodeDef* const_node2 = graph_def.add_node();
68 const_node2->set_name("const_node2");
69 const_node2->set_op("Const");
70
71 NodeDef* const_node3 = graph_def.add_node();
72 const_node3->set_name("const_node3");
73 const_node3->set_op("Const");
74
75 NodeDef* concat_node = graph_def.add_node();
76 concat_node->set_name("concat_node");
77 concat_node->set_op("ConcatV2");
78 concat_node->add_input("const_node1");
79 concat_node->add_input("const_node2");
80 concat_node->add_input("const_node3");
81 SetNodeAttr("Tidx", DT_INT32, concat_node);
82
83 GraphDef result;
84 TransformFuncContext context;
85 context.input_names = {};
86 context.output_names = {"concat_node"};
87 TF_ASSERT_OK(BackportConcatV2Transform(graph_def, context, &result));
88
89 std::map<string, const NodeDef*> node_lookup;
90 MapNamesToNodes(result, &node_lookup);
91 EXPECT_EQ(1, node_lookup.count("concat_node"));
92 EXPECT_EQ("Concat", node_lookup.at("concat_node")->op());
93 EXPECT_EQ(0, node_lookup.at("concat_node")->attr().count("Tidx"));
94 EXPECT_EQ("const_node3", node_lookup.at("concat_node")->input(0));
95 EXPECT_EQ("const_node1", node_lookup.at("concat_node")->input(1));
96 EXPECT_EQ("const_node2", node_lookup.at("concat_node")->input(2));
97 EXPECT_EQ(1, node_lookup.count("const_node1"));
98 EXPECT_EQ("Const", node_lookup.at("const_node1")->op());
99 EXPECT_EQ(1, node_lookup.count("const_node2"));
100 EXPECT_EQ("Const", node_lookup.at("const_node2")->op());
101 EXPECT_EQ(1, node_lookup.count("const_node3"));
102 EXPECT_EQ("Const", node_lookup.at("const_node3")->op());
103 }
104 };
105
// gtest entry point; the actual scenario is implemented by the fixture's
// TestBackportConcatV2() member function.
TEST_F(BackportConcatV2Test, TestBackportConcatV2) {
  TestBackportConcatV2();
}
107
// Checks that BackportTensorArrayV3Transform rewrites TensorArrayV3 and
// TensorArrayGradV3 into their V2 equivalents. The V3 ops have a second
// "flow" output (":1") that the V2 ops lack, so after the transform no node
// may still consume output 1 of either rewritten node.
TEST(BackportTensorArrayV3Test, TestBackportTensorArrayV3) {
  GraphDef graph_def;

  // Scalar int32 size input for the TensorArray.
  NodeDef* size_node = graph_def.add_node();
  size_node->set_name("size_node");
  size_node->set_op("Const");
  SetNodeAttr("dtype", DT_INT32, size_node);
  Tensor size_tensor(DT_INT32, {});
  size_tensor.flat<int32>()(0) = 1;
  // The template argument names the tensor's element type, which is int32.
  SetNodeTensorAttr<int32>("value", size_tensor, size_node);

  NodeDef* tensor_array_node = graph_def.add_node();
  tensor_array_node->set_name("tensor_array_node");
  tensor_array_node->set_op("TensorArrayV3");
  tensor_array_node->add_input("size_node");
  SetNodeAttr("dtype", DT_FLOAT, tensor_array_node);
  SetNodeAttr("element_shape", TensorShape({1, 2}), tensor_array_node);
  SetNodeAttr("dynamic_size", false, tensor_array_node);
  SetNodeAttr("clear_after_read", true, tensor_array_node);
  SetNodeAttr("tensor_array_name", "some_name", tensor_array_node);

  // Consumers of the TensorArray's handle (:0) and flow (:1) outputs.
  NodeDef* handle_output_node = graph_def.add_node();
  handle_output_node->set_name("handle_output_node");
  handle_output_node->set_op("Identity");
  handle_output_node->add_input("tensor_array_node:0");

  NodeDef* flow_output_node = graph_def.add_node();
  flow_output_node->set_name("flow_output_node");
  flow_output_node->set_op("Identity");
  flow_output_node->add_input("tensor_array_node:1");

  NodeDef* tensor_array_grad_node = graph_def.add_node();
  tensor_array_grad_node->set_name("tensor_array_grad_node");
  tensor_array_grad_node->set_op("TensorArrayGradV3");
  tensor_array_grad_node->add_input("tensor_array_node:0");
  tensor_array_grad_node->add_input("tensor_array_node:1");
  SetNodeAttr("source", "foo", tensor_array_grad_node);

  // Consumers of the gradient node's handle (:0) and flow (:1) outputs.
  NodeDef* grad_handle_output_node = graph_def.add_node();
  grad_handle_output_node->set_name("grad_handle_output_node");
  grad_handle_output_node->set_op("Identity");
  grad_handle_output_node->add_input("tensor_array_grad_node:0");

  NodeDef* grad_flow_output_node = graph_def.add_node();
  grad_flow_output_node->set_name("grad_flow_output_node");
  grad_flow_output_node->set_op("Identity");
  grad_flow_output_node->add_input("tensor_array_grad_node:1");

  GraphDef result;
  TransformFuncContext context;
  context.input_names = {};
  context.output_names = {"handle_output_node", "grad_handle_output_node"};
  TF_ASSERT_OK(BackportTensorArrayV3Transform(graph_def, context, &result));

  std::map<string, const NodeDef*> node_lookup;
  MapNamesToNodes(result, &node_lookup);
  // Guard each at() with an ASSERT so a missing node fails cleanly instead of
  // throwing std::out_of_range.
  ASSERT_EQ(1, node_lookup.count("tensor_array_node"));
  EXPECT_EQ("TensorArrayV2", node_lookup.at("tensor_array_node")->op());
  ASSERT_EQ(1, node_lookup.count("tensor_array_grad_node"));
  EXPECT_EQ("TensorArrayGradV2",
            node_lookup.at("tensor_array_grad_node")->op());

  // The consumers of the removed flow outputs must still be in the graph.
  EXPECT_EQ(1, node_lookup.count("flow_output_node"));
  EXPECT_EQ(1, node_lookup.count("grad_flow_output_node"));

  // Neither V2 op has output 1, so no input may still reference it.
  for (const NodeDef& node : result.node()) {
    for (const string& input : node.input()) {
      EXPECT_NE("tensor_array_node:1", input) << "in node " << node.name();
      EXPECT_NE("tensor_array_grad_node:1", input)
          << "in node " << node.name();
    }
  }
}
174
// Checks that each auxiliary TensorArray*V3 op is renamed to the matching
// *V2 op, i.e. the same op name with its trailing "V3" replaced by "V2".
TEST(BackportTensorArrayV3Test, TestBackportTensorArrayV3Subtypes) {
  const std::vector<string> v3_ops = {
      "TensorArrayWriteV3",   "TensorArrayReadV3",   "TensorArrayGatherV3",
      "TensorArrayScatterV3", "TensorArrayConcatV3", "TensorArraySplitV3",
      "TensorArraySizeV3",    "TensorArrayCloseV3"};
  for (const string& v3_op : v3_ops) {
    // Identify which op failed when an expectation below trips.
    SCOPED_TRACE(v3_op);

    GraphDef graph_def;
    NodeDef* v3_node = graph_def.add_node();
    v3_node->set_name("v3_node");
    v3_node->set_op(v3_op);

    GraphDef result;
    TransformFuncContext context;
    context.input_names = {};
    // Name the graph's only node as the output instead of an empty string.
    context.output_names = {"v3_node"};
    TF_ASSERT_OK(BackportTensorArrayV3Transform(graph_def, context, &result));

    std::map<string, const NodeDef*> node_lookup;
    MapNamesToNodes(result, &node_lookup);
    ASSERT_EQ(1, node_lookup.count("v3_node"));

    // Compare the exact op rather than only its "V2" suffix, so that a rename
    // into the wrong op family (e.g. Read -> Write) would also be caught.
    const string expected_v2_op = v3_op.substr(0, v3_op.size() - 2) + "V2";
    EXPECT_EQ(expected_v2_op, node_lookup.at("v3_node")->op());
  }
}
198
199 } // namespace graph_transforms
200 } // namespace tensorflow
201