xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/python_op_gen_main.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/framework/python_op_gen.h"
17 
18 #include <memory>
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/framework/op_def.pb.h"
25 #include "tensorflow/core/framework/op_gen_lib.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/io/inputbuffer.h"
28 #include "tensorflow/core/lib/io/path.h"
29 #include "tensorflow/core/lib/strings/scanner.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/platform/env.h"
32 #include "tensorflow/core/platform/init_main.h"
33 #include "tensorflow/core/platform/logging.h"
34 
35 namespace tensorflow {
36 namespace {
37 
ReadOpListFromFile(const string & filename,std::vector<string> * op_list)38 Status ReadOpListFromFile(const string& filename,
39                           std::vector<string>* op_list) {
40   std::unique_ptr<RandomAccessFile> file;
41   TF_CHECK_OK(Env::Default()->NewRandomAccessFile(filename, &file));
42   std::unique_ptr<io::InputBuffer> input_buffer(
43       new io::InputBuffer(file.get(), 256 << 10));
44   string line_contents;
45   Status s = input_buffer->ReadLine(&line_contents);
46   while (s.ok()) {
47     // The parser assumes that the op name is the first string on each
48     // line with no preceding whitespace, and ignores lines that do
49     // not start with an op name as a comment.
50     strings::Scanner scanner{StringPiece(line_contents)};
51     StringPiece op_name;
52     if (scanner.One(strings::Scanner::LETTER_DIGIT_DOT)
53             .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
54             .GetResult(nullptr, &op_name)) {
55       op_list->emplace_back(op_name);
56     }
57     s = input_buffer->ReadLine(&line_contents);
58   }
59   if (!errors::IsOutOfRange(s)) return s;
60   return OkStatus();
61 }
62 
63 // The argument parsing is deliberately simplistic to support our only
64 // known use cases:
65 //
66 // 1. Read all op names from a file.
67 // 2. Read all op names from the arg as a comma-delimited list.
68 //
69 // Expected command-line argument syntax:
70 // ARG ::= '@' FILENAME
71 //       |  OP_NAME [',' OP_NAME]*
72 //       |  ''
ParseOpListCommandLine(const char * arg,std::vector<string> * op_list)73 Status ParseOpListCommandLine(const char* arg, std::vector<string>* op_list) {
74   std::vector<string> op_names = str_util::Split(arg, ',');
75   if (op_names.size() == 1 && op_names[0].empty()) {
76     return OkStatus();
77   } else if (op_names.size() == 1 && op_names[0].substr(0, 1) == "@") {
78     const string filename = op_names[0].substr(1);
79     return tensorflow::ReadOpListFromFile(filename, op_list);
80   } else {
81     *op_list = std::move(op_names);
82   }
83   return OkStatus();
84 }
85 
86 // Use the name of the current executable to infer the C++ source file
87 // where the REGISTER_OP() call for the operator can be found.
88 // Returns the name of the file.
89 // Returns an empty string if the current executable's name does not
90 // follow a known pattern.
InferSourceFileName(const char * argv_zero)91 string InferSourceFileName(const char* argv_zero) {
92   StringPiece command_str = io::Basename(argv_zero);
93 
94   // For built-in ops, the Bazel build creates a separate executable
95   // with the name gen_<op type>_ops_py_wrappers_cc containing the
96   // operators defined in <op type>_ops.cc
97   const char* kExecPrefix = "gen_";
98   const char* kExecSuffix = "_py_wrappers_cc";
99   if (absl::ConsumePrefix(&command_str, kExecPrefix) &&
100       str_util::EndsWith(command_str, kExecSuffix)) {
101     command_str.remove_suffix(strlen(kExecSuffix));
102     return strings::StrCat(command_str, ".cc");
103   } else {
104     return string("");
105   }
106 }
107 
PrintAllPythonOps(const std::vector<string> & op_list,const std::vector<string> & api_def_dirs,const string & source_file_name,bool op_list_is_allowlist,const std::unordered_set<string> type_annotate_ops)108 void PrintAllPythonOps(const std::vector<string>& op_list,
109                        const std::vector<string>& api_def_dirs,
110                        const string& source_file_name,
111                        bool op_list_is_allowlist,
112                        const std::unordered_set<string> type_annotate_ops) {
113   OpList ops;
114   OpRegistry::Global()->Export(false, &ops);
115 
116   ApiDefMap api_def_map(ops);
117   if (!api_def_dirs.empty()) {
118     Env* env = Env::Default();
119 
120     for (const auto& api_def_dir : api_def_dirs) {
121       std::vector<string> api_files;
122       TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt"),
123                                         &api_files));
124       TF_CHECK_OK(api_def_map.LoadFileList(env, api_files));
125     }
126     api_def_map.UpdateDocs();
127   }
128 
129   if (op_list_is_allowlist) {
130     std::unordered_set<string> allowlist(op_list.begin(), op_list.end());
131     OpList pruned_ops;
132     for (const auto& op_def : ops.op()) {
133       if (allowlist.find(op_def.name()) != allowlist.end()) {
134         *pruned_ops.mutable_op()->Add() = op_def;
135       }
136     }
137     PrintPythonOps(pruned_ops, api_def_map, {}, source_file_name,
138                    type_annotate_ops);
139   } else {
140     PrintPythonOps(ops, api_def_map, op_list, source_file_name,
141                    type_annotate_ops);
142   }
143 }
144 
145 }  // namespace
146 }  // namespace tensorflow
147 
main(int argc,char * argv[])148 int main(int argc, char* argv[]) {
149   tensorflow::port::InitMain(argv[0], &argc, &argv);
150 
151   tensorflow::string source_file_name =
152       tensorflow::InferSourceFileName(argv[0]);
153 
154   // Usage:
155   //   gen_main api_def_dir1,api_def_dir2,...
156   //       [ @FILENAME | OpName[,OpName]* ] [0 | 1]
157   if (argc < 2) {
158     return -1;
159   }
160   std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
161       argv[1], ",", tensorflow::str_util::SkipEmpty());
162 
163   // Add op name here to generate type annotations for it
164   const std::unordered_set<tensorflow::string> type_annotate_ops{};
165 
166   if (argc == 2) {
167     tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
168                                   false /* op_list_is_allowlist */,
169                                   type_annotate_ops);
170   } else if (argc == 3) {
171     std::vector<tensorflow::string> hidden_ops;
172     TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
173     tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
174                                   false /* op_list_is_allowlist */,
175                                   type_annotate_ops);
176   } else if (argc == 4) {
177     std::vector<tensorflow::string> op_list;
178     TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list));
179     tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name,
180                                   tensorflow::string(argv[3]) == "1",
181                                   type_annotate_ops);
182   } else {
183     return -1;
184   }
185   return 0;
186 }
187