1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/framework/cc_op_gen.h"
17
18 #include "tensorflow/core/framework/op_def.pb.h"
19 #include "tensorflow/core/framework/op_gen_lib.h"
20 #include "tensorflow/core/lib/core/status_test_util.h"
21 #include "tensorflow/core/lib/io/path.h"
22 #include "tensorflow/core/lib/strings/str_util.h"
23 #include "tensorflow/core/platform/test.h"
24
25 namespace tensorflow {
26 namespace {
27
// Text-format OpList used by every test in this file. It describes a single
// op "Foo" with:
//   - two inputs, "images" and "dim" (DT_FLOAT);
//   - one output, "output" (DT_FLOAT);
//   - a type attr "T" restricted to DT_UINT8 / DT_INT8.
// Each test parses this text into an OpList, builds an ApiDefMap from it,
// optionally loads an ApiDef override, and inspects the generated C++ headers.
//
// NOTE(review): "images" declares neither `type` nor `type_attr`, and the
// default for "T" is written as `i: 1` rather than a `type:` value. The tests
// here only pass this through ApiDefMap/WriteCCOps, so this appears to be
// tolerated. Confirm before reusing the fixture anywhere OpDefs are validated.
constexpr char kBaseOpDef[] = R"(
op {
  name: "Foo"
  input_arg {
    name: "images"
    description: "Images to process."
  }
  input_arg {
    name: "dim"
    description: "Description for dim."
    type: DT_FLOAT
  }
  output_arg {
    name: "output"
    description: "Description for output."
    type: DT_FLOAT
  }
  attr {
    name: "T"
    type: "type"
    description: "Type for images"
    allowed_values {
      list {
        type: DT_UINT8
        type: DT_INT8
      }
    }
    default_value {
      i: 1
    }
  }
  summary: "Summary for op Foo."
  description: "Description for op Foo."
}
)";
63
ExpectHasSubstr(StringPiece s,StringPiece expected)64 void ExpectHasSubstr(StringPiece s, StringPiece expected) {
65 EXPECT_TRUE(absl::StrContains(s, expected))
66 << "'" << s << "' does not contain '" << expected << "'";
67 }
68
ExpectDoesNotHaveSubstr(StringPiece s,StringPiece expected)69 void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) {
70 EXPECT_FALSE(absl::StrContains(s, expected))
71 << "'" << s << "' contains '" << expected << "'";
72 }
73
ExpectSubstrOrder(const string & s,const string & before,const string & after)74 void ExpectSubstrOrder(const string& s, const string& before,
75 const string& after) {
76 int before_pos = s.find(before);
77 int after_pos = s.find(after);
78 ASSERT_NE(std::string::npos, before_pos);
79 ASSERT_NE(std::string::npos, after_pos);
80 EXPECT_LT(before_pos, after_pos)
81 << before << " is not before " << after << " in " << s;
82 }
83
84 // Runs WriteCCOps and stores output in (internal_)cc_file_path and
85 // (internal_)h_file_path.
GenerateCcOpFiles(Env * env,const OpList & ops,const ApiDefMap & api_def_map,string * h_file_text,string * internal_h_file_text)86 void GenerateCcOpFiles(Env* env, const OpList& ops,
87 const ApiDefMap& api_def_map, string* h_file_text,
88 string* internal_h_file_text) {
89 const string& tmpdir = testing::TmpDir();
90
91 const auto h_file_path = io::JoinPath(tmpdir, "test.h");
92 const auto cc_file_path = io::JoinPath(tmpdir, "test.cc");
93 const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h");
94 const auto internal_cc_file_path = io::JoinPath(tmpdir, "test_internal.cc");
95
96 WriteCCOps(ops, api_def_map, h_file_path, cc_file_path);
97
98 TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text));
99 TF_ASSERT_OK(
100 ReadFileToString(env, internal_h_file_path, internal_h_file_text));
101 }
102
// Verifies that an ApiDef with `visibility: HIDDEN` moves the generated
// `class Foo` from the public header into the internal header.
TEST(CcOpGenTest, TestVisibilityChangedToHidden) {
  const string api_def = R"(
op {
  graph_op_name: "Foo"
  visibility: HIDDEN
}
)";
  Env* env = Env::Default();
  OpList op_defs;
  // Fail immediately on a malformed fixture rather than running the checks
  // below against a partially parsed OpList.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string h_file_text, internal_h_file_text;
  // Without ApiDef: Foo is public.
  GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
                    &internal_h_file_text);
  ExpectHasSubstr(h_file_text, "class Foo");
  ExpectDoesNotHaveSubstr(internal_h_file_text, "class Foo");

  // With ApiDef: Foo is hidden, so it is emitted only in the internal header.
  TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
  GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
                    &internal_h_file_text);
  ExpectHasSubstr(internal_h_file_text, "class Foo");
  ExpectDoesNotHaveSubstr(h_file_text, "class Foo");
}
129
// Verifies that an ApiDef `arg_order` reorders the generated constructor's
// Input parameters: "images" precedes "dim" by default and follows it once
// the ApiDef is loaded.
TEST(CcOpGenTest, TestArgNameChanges) {
  const string api_def = R"(
op {
  graph_op_name: "Foo"
  arg_order: "dim"
  arg_order: "images"
}
)";
  Env* env = Env::Default();
  OpList op_defs;
  // Fail immediately on a malformed fixture rather than running the checks
  // below against a partially parsed OpList.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));

  ApiDefMap api_def_map(op_defs);
  string h_file_text, internal_h_file_text;
  // Without ApiDef: OpDef input order.
  GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
                    &internal_h_file_text);
  ExpectSubstrOrder(h_file_text, "Input images", "Input dim");

  // With ApiDef: arg_order takes effect.
  TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
  GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
                    &internal_h_file_text);
  ExpectSubstrOrder(h_file_text, "Input dim", "Input images");
}
156
// Verifies that ApiDef endpoints rename the generated class. The first
// endpoint ("Foo1") becomes the class, later endpoints ("Foo2") become
// typedefs of it, and the original "Foo" class is no longer emitted.
TEST(CcOpGenTest, TestEndpoints) {
  const string api_def = R"(
op {
  graph_op_name: "Foo"
  endpoint {
    name: "Foo1"
  }
  endpoint {
    name: "Foo2"
  }
}
)";
  Env* env = Env::Default();
  OpList op_defs;
  // Fail immediately on a malformed fixture rather than running the checks
  // below against a partially parsed OpList.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));

  ApiDefMap api_def_map(op_defs);
  string h_file_text, internal_h_file_text;
  // Without ApiDef: only the graph op name is used.
  GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
                    &internal_h_file_text);
  ExpectHasSubstr(h_file_text, "class Foo {");
  ExpectDoesNotHaveSubstr(h_file_text, "class Foo1");
  ExpectDoesNotHaveSubstr(h_file_text, "class Foo2");

  // With ApiDef: endpoint names replace the graph op name.
  TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
  GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
                    &internal_h_file_text);
  ExpectHasSubstr(h_file_text, "class Foo1");
  ExpectHasSubstr(h_file_text, "typedef Foo1 Foo2");
  ExpectDoesNotHaveSubstr(h_file_text, "class Foo {");
}
191 } // namespace
192 } // namespace tensorflow
193