xref: /aosp_15_r20/external/tensorflow/tensorflow/cc/framework/cc_op_gen_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/cc/framework/cc_op_gen.h"
17 
18 #include "tensorflow/core/framework/op_def.pb.h"
19 #include "tensorflow/core/framework/op_gen_lib.h"
20 #include "tensorflow/core/lib/core/status_test_util.h"
21 #include "tensorflow/core/lib/io/path.h"
22 #include "tensorflow/core/lib/strings/str_util.h"
23 #include "tensorflow/core/platform/test.h"
24 
25 namespace tensorflow {
26 namespace {
27 
// Text-format OpList shared by every test in this file. It declares a single
// op, "Foo", with inputs "images" and "dim", a DT_FLOAT output "output", and a
// type attr "T" restricted to DT_UINT8 / DT_INT8. Each test parses this text
// into an OpList and then layers an ApiDef on top of it, so the exact contents
// matter to the expectations below.
//
// NOTE(review): "images" declares neither `type` nor `type_attr`, and T's
// default_value is written as `i: 1` rather than a `type:` value. Confirm this
// is intentional before extending the tests to depend on T.
constexpr char kBaseOpDef[] = R"opdef(
op {
  name: "Foo"
  input_arg {
    name: "images"
    description: "Images to process."
  }
  input_arg {
    name: "dim"
    description: "Description for dim."
    type: DT_FLOAT
  }
  output_arg {
    name: "output"
    description: "Description for output."
    type: DT_FLOAT
  }
  attr {
    name: "T"
    type: "type"
    description: "Type for images"
    allowed_values {
      list {
        type: DT_UINT8
        type: DT_INT8
      }
    }
    default_value {
      i: 1
    }
  }
  summary: "Summary for op Foo."
  description: "Description for op Foo."
}
)opdef";
63 
ExpectHasSubstr(StringPiece s,StringPiece expected)64 void ExpectHasSubstr(StringPiece s, StringPiece expected) {
65   EXPECT_TRUE(absl::StrContains(s, expected))
66       << "'" << s << "' does not contain '" << expected << "'";
67 }
68 
ExpectDoesNotHaveSubstr(StringPiece s,StringPiece expected)69 void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) {
70   EXPECT_FALSE(absl::StrContains(s, expected))
71       << "'" << s << "' contains '" << expected << "'";
72 }
73 
ExpectSubstrOrder(const string & s,const string & before,const string & after)74 void ExpectSubstrOrder(const string& s, const string& before,
75                        const string& after) {
76   int before_pos = s.find(before);
77   int after_pos = s.find(after);
78   ASSERT_NE(std::string::npos, before_pos);
79   ASSERT_NE(std::string::npos, after_pos);
80   EXPECT_LT(before_pos, after_pos)
81       << before << " is not before " << after << " in " << s;
82 }
83 
84 // Runs WriteCCOps and stores output in (internal_)cc_file_path and
85 // (internal_)h_file_path.
GenerateCcOpFiles(Env * env,const OpList & ops,const ApiDefMap & api_def_map,string * h_file_text,string * internal_h_file_text)86 void GenerateCcOpFiles(Env* env, const OpList& ops,
87                        const ApiDefMap& api_def_map, string* h_file_text,
88                        string* internal_h_file_text) {
89   const string& tmpdir = testing::TmpDir();
90 
91   const auto h_file_path = io::JoinPath(tmpdir, "test.h");
92   const auto cc_file_path = io::JoinPath(tmpdir, "test.cc");
93   const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h");
94   const auto internal_cc_file_path = io::JoinPath(tmpdir, "test_internal.cc");
95 
96   WriteCCOps(ops, api_def_map, h_file_path, cc_file_path);
97 
98   TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text));
99   TF_ASSERT_OK(
100       ReadFileToString(env, internal_h_file_path, internal_h_file_text));
101 }
102 
// Marking Foo as HIDDEN through an ApiDef should move its generated class out
// of the public header and into the internal header.
TEST(CcOpGenTest, TestVisibilityChangedToHidden) {
  const string api_def = R"(
op {
  graph_op_name: "Foo"
  visibility: HIDDEN
}
)";
  Env* env = Env::Default();
  OpList op_defs;
  // Fail fast if the fixture does not parse. Otherwise every expectation below
  // would run against an empty op list and fail with a misleading message.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string h_file_text, internal_h_file_text;
  // Without ApiDef
  ASSERT_NO_FATAL_FAILURE(GenerateCcOpFiles(env, op_defs, api_def_map,
                                            &h_file_text,
                                            &internal_h_file_text));
  ExpectHasSubstr(h_file_text, "class Foo");
  ExpectDoesNotHaveSubstr(internal_h_file_text, "class Foo");

  // With ApiDef
  TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
  ASSERT_NO_FATAL_FAILURE(GenerateCcOpFiles(env, op_defs, api_def_map,
                                            &h_file_text,
                                            &internal_h_file_text));
  ExpectHasSubstr(internal_h_file_text, "class Foo");
  ExpectDoesNotHaveSubstr(h_file_text, "class Foo");
}
129 
// An ApiDef `arg_order` should reorder the generated Input parameters: by
// default "images" precedes "dim", and the ApiDef swaps them.
TEST(CcOpGenTest, TestArgNameChanges) {
  const string api_def = R"(
op {
  graph_op_name: "Foo"
  arg_order: "dim"
  arg_order: "images"
}
)";
  Env* env = Env::Default();
  OpList op_defs;
  // Fail fast if the fixture does not parse. Otherwise every expectation below
  // would run against an empty op list and fail with a misleading message.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));

  ApiDefMap api_def_map(op_defs);
  string h_file_text, internal_h_file_text;
  // Without ApiDef
  ASSERT_NO_FATAL_FAILURE(GenerateCcOpFiles(env, op_defs, api_def_map,
                                            &h_file_text,
                                            &internal_h_file_text));
  ExpectSubstrOrder(h_file_text, "Input images", "Input dim");

  // With ApiDef
  TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
  ASSERT_NO_FATAL_FAILURE(GenerateCcOpFiles(env, op_defs, api_def_map,
                                            &h_file_text,
                                            &internal_h_file_text));
  ExpectSubstrOrder(h_file_text, "Input dim", "Input images");
}
156 
// With multiple ApiDef endpoints, the first endpoint becomes the generated
// class name and later endpoints become typedef aliases of it. The original
// class name "Foo" is no longer emitted.
TEST(CcOpGenTest, TestEndpoints) {
  const string api_def = R"(
op {
  graph_op_name: "Foo"
  endpoint {
    name: "Foo1"
  }
  endpoint {
    name: "Foo2"
  }
}
)";
  Env* env = Env::Default();
  OpList op_defs;
  // Fail fast if the fixture does not parse. Otherwise every expectation below
  // would run against an empty op list and fail with a misleading message.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));

  ApiDefMap api_def_map(op_defs);
  string h_file_text, internal_h_file_text;
  // Without ApiDef
  ASSERT_NO_FATAL_FAILURE(GenerateCcOpFiles(env, op_defs, api_def_map,
                                            &h_file_text,
                                            &internal_h_file_text));
  ExpectHasSubstr(h_file_text, "class Foo {");
  ExpectDoesNotHaveSubstr(h_file_text, "class Foo1");
  ExpectDoesNotHaveSubstr(h_file_text, "class Foo2");

  // With ApiDef
  TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
  ASSERT_NO_FATAL_FAILURE(GenerateCcOpFiles(env, op_defs, api_def_map,
                                            &h_file_text,
                                            &internal_h_file_text));
  ExpectHasSubstr(h_file_text, "class Foo1");
  ExpectHasSubstr(h_file_text, "typedef Foo1 Foo2");
  ExpectDoesNotHaveSubstr(h_file_text, "class Foo {");
}
191 }  // namespace
192 }  // namespace tensorflow
193