xref: /aosp_15_r20/external/tensorflow/tensorflow/tools/graph_transforms/backports_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/platform/test_benchmark.h"
26 #include "tensorflow/core/public/session.h"
27 #include "tensorflow/tools/graph_transforms/transform_utils.h"
28 
29 namespace tensorflow {
30 namespace graph_transforms {
31 
32 // Declare here, so we don't need a public header.
33 Status BackportConcatV2Transform(const GraphDef& input_graph_def,
34                                  const TransformFuncContext& context,
35                                  GraphDef* output_graph_def);
36 Status BackportTensorArrayV3Transform(const GraphDef& input_graph_def,
37                                       const TransformFuncContext& context,
38                                       GraphDef* output_graph_def);
39 
40 class BackportConcatV2Test : public ::testing::Test {
41  protected:
TestBackportConcatV2()42   void TestBackportConcatV2() {
43     GraphDef graph_def;
44 
45     NodeDef* mul_node1 = graph_def.add_node();
46     mul_node1->set_name("mul_node1");
47     mul_node1->set_op("Mul");
48     mul_node1->add_input("add_node2");
49     mul_node1->add_input("add_node3");
50 
51     NodeDef* add_node2 = graph_def.add_node();
52     add_node2->set_name("add_node2");
53     add_node2->set_op("Add");
54     add_node2->add_input("const_node1");
55     add_node2->add_input("const_node2");
56 
57     NodeDef* add_node3 = graph_def.add_node();
58     add_node3->set_name("add_node3");
59     add_node3->set_op("Add");
60     add_node3->add_input("const_node1");
61     add_node3->add_input("const_node3");
62 
63     NodeDef* const_node1 = graph_def.add_node();
64     const_node1->set_name("const_node1");
65     const_node1->set_op("Const");
66 
67     NodeDef* const_node2 = graph_def.add_node();
68     const_node2->set_name("const_node2");
69     const_node2->set_op("Const");
70 
71     NodeDef* const_node3 = graph_def.add_node();
72     const_node3->set_name("const_node3");
73     const_node3->set_op("Const");
74 
75     NodeDef* concat_node = graph_def.add_node();
76     concat_node->set_name("concat_node");
77     concat_node->set_op("ConcatV2");
78     concat_node->add_input("const_node1");
79     concat_node->add_input("const_node2");
80     concat_node->add_input("const_node3");
81     SetNodeAttr("Tidx", DT_INT32, concat_node);
82 
83     GraphDef result;
84     TransformFuncContext context;
85     context.input_names = {};
86     context.output_names = {"concat_node"};
87     TF_ASSERT_OK(BackportConcatV2Transform(graph_def, context, &result));
88 
89     std::map<string, const NodeDef*> node_lookup;
90     MapNamesToNodes(result, &node_lookup);
91     EXPECT_EQ(1, node_lookup.count("concat_node"));
92     EXPECT_EQ("Concat", node_lookup.at("concat_node")->op());
93     EXPECT_EQ(0, node_lookup.at("concat_node")->attr().count("Tidx"));
94     EXPECT_EQ("const_node3", node_lookup.at("concat_node")->input(0));
95     EXPECT_EQ("const_node1", node_lookup.at("concat_node")->input(1));
96     EXPECT_EQ("const_node2", node_lookup.at("concat_node")->input(2));
97     EXPECT_EQ(1, node_lookup.count("const_node1"));
98     EXPECT_EQ("Const", node_lookup.at("const_node1")->op());
99     EXPECT_EQ(1, node_lookup.count("const_node2"));
100     EXPECT_EQ("Const", node_lookup.at("const_node2")->op());
101     EXPECT_EQ(1, node_lookup.count("const_node3"));
102     EXPECT_EQ("Const", node_lookup.at("const_node3")->op());
103   }
104 };
105 
// Runs the fixture's ConcatV2 backport scenario.
TEST_F(BackportConcatV2Test, TestBackportConcatV2) {
  TestBackportConcatV2();
}
107 
TEST(BackportTensorArrayV3Test,TestBackportTensorArrayV3)108 TEST(BackportTensorArrayV3Test, TestBackportTensorArrayV3) {
109   GraphDef graph_def;
110 
111   NodeDef* size_node = graph_def.add_node();
112   size_node->set_name("size_node");
113   size_node->set_op("Const");
114   Tensor size_tensor(DT_INT32, {});
115   size_tensor.flat<int32>()(0) = 1;
116   SetNodeTensorAttr<float>("value", size_tensor, size_node);
117 
118   NodeDef* tensor_array_node = graph_def.add_node();
119   tensor_array_node->set_name("tensor_array_node");
120   tensor_array_node->set_op("TensorArrayV3");
121   tensor_array_node->add_input("size_node");
122   SetNodeAttr("dtype", DT_FLOAT, tensor_array_node);
123   SetNodeAttr("element_shape", TensorShape({1, 2}), tensor_array_node);
124   SetNodeAttr("dynamic_size", false, tensor_array_node);
125   SetNodeAttr("clear_after_read", true, tensor_array_node);
126   SetNodeAttr("tensor_array_name", "some_name", tensor_array_node);
127 
128   NodeDef* handle_output_node = graph_def.add_node();
129   handle_output_node->set_name("handle_output_node");
130   handle_output_node->set_op("Identity");
131   handle_output_node->add_input("tensor_array_node:0");
132 
133   NodeDef* flow_output_node = graph_def.add_node();
134   flow_output_node->set_name("flow_output_node");
135   flow_output_node->set_op("Identity");
136   flow_output_node->add_input("tensor_array_node:1");
137 
138   NodeDef* tensor_array_grad_node = graph_def.add_node();
139   tensor_array_grad_node->set_name("tensor_array_grad_node");
140   tensor_array_grad_node->set_op("TensorArrayGradV3");
141   tensor_array_grad_node->add_input("tensor_array_node:0");
142   tensor_array_grad_node->add_input("tensor_array_node:1");
143   SetNodeAttr("source", "foo", tensor_array_grad_node);
144 
145   NodeDef* grad_handle_output_node = graph_def.add_node();
146   grad_handle_output_node->set_name("grad_handle_output_node");
147   grad_handle_output_node->set_op("Identity");
148   grad_handle_output_node->add_input("tensor_array_grad_node:0");
149 
150   NodeDef* grad_flow_output_node = graph_def.add_node();
151   grad_flow_output_node->set_name("grad_flow_output_node");
152   grad_flow_output_node->set_op("Identity");
153   grad_flow_output_node->add_input("tensor_array_grad_node:1");
154 
155   GraphDef result;
156   TransformFuncContext context;
157   context.input_names = {};
158   context.output_names = {"handle_output_node", "grad_handle_output_node"};
159   TF_ASSERT_OK(BackportTensorArrayV3Transform(graph_def, context, &result));
160 
161   std::map<string, const NodeDef*> node_lookup;
162   MapNamesToNodes(result, &node_lookup);
163   ASSERT_EQ(1, node_lookup.count("tensor_array_node"));
164   EXPECT_EQ("TensorArrayV2", node_lookup.at("tensor_array_node")->op());
165   EXPECT_EQ("TensorArrayGradV2",
166             node_lookup.at("tensor_array_grad_node")->op());
167 
168   for (const NodeDef& node : result.node()) {
169     for (const string& input : node.input()) {
170       EXPECT_NE("tensor_array_node:1", input);
171     }
172   }
173 }
174 
// For each listed TensorArray*V3 op, builds a one-node graph, runs
// BackportTensorArrayV3Transform, and checks that the node's op is renamed to
// the same name with the trailing "V3" replaced by "V2".
TEST(BackportTensorArrayV3Test, TestBackportTensorArrayV3Subtypes) {
  const std::vector<string> v3_ops = {
      "TensorArrayWriteV3",   "TensorArrayReadV3",   "TensorArrayGatherV3",
      "TensorArrayScatterV3", "TensorArrayConcatV3", "TensorArraySplitV3",
      "TensorArraySizeV3",    "TensorArrayCloseV3"};
  for (const string& v3_op : v3_ops) {
    // Tag any failure inside this iteration with the op being exercised.
    SCOPED_TRACE(v3_op);

    GraphDef graph_def;
    NodeDef* v3_node = graph_def.add_node();
    v3_node->set_name("v3_node");
    v3_node->set_op(v3_op);

    GraphDef result;
    TransformFuncContext context;
    context.input_names = {};
    // Name the node that actually exists rather than an empty output name.
    context.output_names = {"v3_node"};
    TF_ASSERT_OK(BackportTensorArrayV3Transform(graph_def, context, &result));

    std::map<string, const NodeDef*> node_lookup;
    MapNamesToNodes(result, &node_lookup);
    ASSERT_EQ(1, node_lookup.count("v3_node"));
    // Every entry in v3_ops ends in "V3"; the backported op should be the
    // exact same name ending in "V2" (stricter than a suffix check).
    const string expected_op = v3_op.substr(0, v3_op.size() - 1) + "2";
    EXPECT_EQ(expected_op, node_lookup.at("v3_node")->op());
  }
}
198 
199 }  // namespace graph_transforms
200 }  // namespace tensorflow
201