xref: /aosp_15_r20/external/tensorflow/tensorflow/tools/graph_transforms/transform_graph.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/tools/graph_transforms/transform_graph.h"
17 
18 #include "tensorflow/core/framework/function.pb.h"
19 #include "tensorflow/core/lib/strings/scanner.h"
20 #include "tensorflow/core/lib/strings/str_util.h"
21 #include "tensorflow/core/platform/env.h"
22 #include "tensorflow/core/platform/init_main.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/util/command_line_flags.h"
25 #include "tensorflow/tools/graph_transforms/file_utils.h"
26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
27 #if !defined(PLATFORM_WINDOWS)
28 #include <pwd.h>
29 #include <unistd.h>
30 #endif
31 
32 namespace tensorflow {
33 namespace graph_transforms {
34 
35 using tensorflow::strings::Scanner;
36 
ParseTransformParameters(const string & transforms_string,TransformParameters * params_list)37 Status ParseTransformParameters(const string& transforms_string,
38                                 TransformParameters* params_list) {
39   params_list->clear();
40   enum {
41     TRANSFORM_NAME,
42     TRANSFORM_PARAM_NAME,
43     TRANSFORM_PARAM_VALUE,
44   } state = TRANSFORM_NAME;
45   StringPiece remaining(transforms_string);
46   StringPiece match;
47   StringPiece transform_name;
48   StringPiece parameter_name;
49   StringPiece parameter_value;
50   TransformFuncParameters func_parameters;
51   while (!remaining.empty()) {
52     if (state == TRANSFORM_NAME) {
53       // Reset the list of parameters.
54       func_parameters.clear();
55       // Eat up any leading spaces.
56       Scanner(remaining).AnySpace().GetResult(&remaining, &match);
57       if (remaining.empty()) {
58         // Nothing remains after consuming trailing spaces.
59         // Consumed all transform parameter string without errors.
60         return OkStatus();
61       }
62       // See if we have a valid transform name.
63       const bool found_transform_name =
64           Scanner(remaining)
65               .Many(Scanner::LETTER_DIGIT_UNDERSCORE)
66               .GetResult(&remaining, &transform_name);
67       if (!found_transform_name) {
68         return errors::InvalidArgument("Looking for transform name, but found ",
69                                        string(remaining).c_str());
70       }
71       if (Scanner(remaining).OneLiteral("(").GetResult(&remaining, &match)) {
72         state = TRANSFORM_PARAM_NAME;
73       } else {
74         // Add a transform with no parameters.
75         params_list->push_back({string(transform_name), func_parameters});
76         transform_name = "";
77         state = TRANSFORM_NAME;
78       }
79     } else if (state == TRANSFORM_PARAM_NAME) {
80       if (Scanner(remaining).OneLiteral(")").GetResult(&remaining, &match)) {
81         params_list->push_back({string(transform_name), func_parameters});
82         transform_name = "";
83         state = TRANSFORM_NAME;
84       } else {
85         // Eat up any leading spaces or commas.
86         Scanner(remaining).ZeroOrOneLiteral(",").GetResult(&remaining, &match);
87         Scanner(remaining).AnySpace().GetResult(&remaining, &match);
88         // See if we have a valid parameter name.
89         const bool found_parameter_name =
90             Scanner(remaining)
91                 .Many(Scanner::LETTER_DIGIT_UNDERSCORE)
92                 .GetResult(&remaining, &parameter_name);
93         if (!found_parameter_name) {
94           return errors::InvalidArgument(
95               "Looking for parameter name, but found ",
96               string(remaining).c_str());
97         }
98         if (Scanner(remaining).OneLiteral("=").GetResult(&remaining, &match)) {
99           state = TRANSFORM_PARAM_VALUE;
100         } else {
101           return errors::InvalidArgument("Looking for =, but found ",
102                                          string(remaining).c_str());
103         }
104       }
105     } else if (state == TRANSFORM_PARAM_VALUE) {
106       bool found_parameter_value;
107       // Deal with quoted values.
108       if (Scanner(remaining).OneLiteral("\"").GetResult(&remaining, &match)) {
109         found_parameter_value =
110             Scanner(remaining).ScanEscapedUntil('"').GetResult(
111                 &remaining, &parameter_value);
112         if (found_parameter_value) {
113           Scanner(remaining).OneLiteral("\"").GetResult(&remaining, &match);
114         }
115       } else {
116         // See if we have a valid parameter name.
117         found_parameter_value =
118             Scanner(remaining)
119                 .Many(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
120                 .GetResult(&remaining, &parameter_value);
121       }
122       if (!found_parameter_value) {
123         return errors::InvalidArgument("Looking for parameter name, but found ",
124                                        string(remaining).c_str());
125       }
126       func_parameters[string(parameter_name)].emplace_back(parameter_value);
127       // Eat up any trailing quotes.
128       Scanner(remaining).ZeroOrOneLiteral("\"").GetResult(&remaining, &match);
129       Scanner(remaining).ZeroOrOneLiteral("'").GetResult(&remaining, &match);
130       state = TRANSFORM_PARAM_NAME;
131     }
132   }
133   return OkStatus();
134 }
135 
// Expands a leading '~' in `path_string`:
//   "~" or "~/rest"         -> $HOME (or, if HOME is unset, the passwd entry's
//                              home directory for the current uid) + "/rest"
//   "~user" or "~user/rest" -> home directory of `user` from getpwnam()
//                              + "/rest"
// A '/' is inserted between the home directory and the rest only when the
// home directory does not already end with one. Paths that do not start with
// '~', and paths whose home directory cannot be determined, are returned
// unchanged. On Windows the input is always returned as-is.
std::string ExpandPath(const std::string& path_string) {
#if defined(PLATFORM_WINDOWS)
  return path_string;
#else
  const bool starts_with_tilde =
      !path_string.empty() && path_string.front() == '~';
  if (!starts_with_tilde) {
    return path_string;
  }

  const std::string::size_type slash_pos = path_string.find('/');
  const bool is_current_user = path_string.size() == 1 || slash_pos == 1;

  const char* home_dir = nullptr;
  if (is_current_user) {
    home_dir = getenv("HOME");
    if (home_dir == nullptr) {
      // HOME is unset: fall back to the password database for our uid.
      if (const struct passwd* entry = getpwuid(getuid())) {
        home_dir = entry->pw_dir;
      }
    }
  } else {
    // Everything between the '~' and the first '/' (or the end) is a user.
    const std::string::size_type name_length =
        (slash_pos == std::string::npos) ? std::string::npos : slash_pos - 1;
    const std::string user_name = path_string.substr(1, name_length);
    if (const struct passwd* entry = getpwnam(user_name.c_str())) {
      home_dir = entry->pw_dir;
    }
  }

  if (home_dir == nullptr) {
    return path_string;
  }

  std::string expanded = home_dir;
  if (slash_pos == std::string::npos) {
    return expanded;
  }
  if (expanded.empty() || expanded.back() != '/') {
    expanded.push_back('/');
  }
  expanded.append(path_string, slash_pos + 1, std::string::npos);
  return expanded;
#endif
}
183 
ParseFlagsAndTransformGraph(int argc,char * argv[],bool init_main)184 int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
185   string in_graph_string = "";
186   string out_graph_string = "";
187   string inputs_string = "";
188   string outputs_string = "";
189   string transforms_string = "";
190   bool output_as_text = false;
191   std::vector<Flag> flag_list = {
192       Flag("in_graph", &in_graph_string, "input graph file name"),
193       Flag("out_graph", &out_graph_string, "output graph file name"),
194       Flag("inputs", &inputs_string, "inputs"),
195       Flag("outputs", &outputs_string, "outputs"),
196       Flag("transforms", &transforms_string, "list of transforms"),
197       Flag("output_as_text", &output_as_text,
198            "whether to write the graph in text protobuf format"),
199   };
200   string usage = Flags::Usage(argv[0], flag_list);
201   usage += "\nTransforms are:\n";
202   TransformRegistry* transform_registry = GetTransformRegistry();
203   for (const auto& pair : *transform_registry) {
204     usage += pair.first + "\n";
205   }
206 
207   const bool parse_result = Flags::Parse(&argc, argv, flag_list);
208   // We need to call this to set up global state for TensorFlow.
209   if (init_main) {
210     port::InitMain(argv[0], &argc, &argv);
211   }
212   if (!parse_result) {
213     LOG(ERROR) << usage;
214     return -1;
215   }
216   if (argc > 1) {
217     LOG(ERROR) << "Unknown argument " << argv[1] << ".\n" << usage;
218     return -1;
219   }
220   if (in_graph_string.empty()) {
221     LOG(ERROR) << "in_graph graph can't be empty.\n" << usage;
222     return -1;
223   }
224   if (out_graph_string.empty()) {
225     LOG(ERROR) << "out_graph graph can't be empty.\n" << usage;
226     return -1;
227   }
228   if (transforms_string.empty()) {
229     LOG(ERROR) << "You must specify at least one transform.\n" << usage;
230     return -1;
231   }
232 
233   string in_graph = ExpandPath(in_graph_string);
234   string out_graph = ExpandPath(out_graph_string);
235 
236   std::vector<string> inputs = str_util::Split(inputs_string, ',');
237   std::vector<string> outputs = str_util::Split(outputs_string, ',');
238   TransformParameters transform_params;
239   Status parse_status =
240       ParseTransformParameters(transforms_string, &transform_params);
241   if (!parse_status.ok()) {
242     LOG(ERROR) << "Failed to parse --transform argument, error was "
243                << parse_status.error_message();
244     return -1;
245   }
246   if (transform_params.empty()) {
247     LOG(ERROR) << "You must specify at least one transform.\n" << usage;
248     return -1;
249   }
250 
251   GraphDef graph_def;
252   Status load_status = LoadTextOrBinaryGraphFile(in_graph, &graph_def);
253   if (!load_status.ok()) {
254     LOG(ERROR) << "Loading graph '" << in_graph_string << "' failed with "
255                << load_status.error_message();
256     LOG(ERROR) << usage;
257     return -1;
258   }
259 
260   Status transform_result =
261       TransformGraph(inputs, outputs, transform_params, &graph_def);
262 
263   if (!transform_result.ok()) {
264     LOG(ERROR) << transform_result.error_message();
265     LOG(ERROR) << usage;
266     return -1;
267   }
268 
269   Status save_status;
270   if (output_as_text) {
271     save_status = WriteTextProto(Env::Default(), out_graph, graph_def);
272   } else {
273     save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def);
274   }
275   if (!save_status.ok()) {
276     LOG(ERROR) << "Saving graph '" << out_graph_string << "' failed with "
277                << save_status.error_message();
278     return -1;
279   }
280 
281   return 0;
282 }
283 
ShouldIgnoreErrors(const TransformFuncParameters & transform_params,bool * ignore_errors)284 Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params,
285                           bool* ignore_errors) {
286   *ignore_errors = false;
287   if (transform_params.count("ignore_errors") &&
288       (!transform_params.at("ignore_errors").empty())) {
289     const string& ignore_errors_string =
290         absl::AsciiStrToLower(transform_params.at("ignore_errors").at(0));
291     if (ignore_errors_string == "true") {
292       *ignore_errors = true;
293     } else if (ignore_errors_string == "false") {
294       *ignore_errors = false;
295     } else {
296       return errors::InvalidArgument(
297           "ignore_errors should be true or false, found ",
298           ignore_errors_string);
299     }
300   }
301   return OkStatus();
302 }
303 
TransformGraph(const std::vector<string> & inputs,const std::vector<string> & outputs,const TransformParameters & transform_params,GraphDef * graph_def)304 Status TransformGraph(const std::vector<string>& inputs,
305                       const std::vector<string>& outputs,
306                       const TransformParameters& transform_params,
307                       GraphDef* graph_def) {
308   TransformRegistry* transform_registry = GetTransformRegistry();
309   for (const auto& transform_info : transform_params) {
310     const string& transform_name = transform_info.first;
311     if (transform_name.empty()) {
312       continue;
313     }
314     if (!transform_registry->count(transform_name)) {
315       return errors::InvalidArgument("Transform '", transform_name,
316                                      "' not recognized.");
317     }
318     LOG(INFO) << "Applying " << transform_name;
319     const TransformFunc& transform_func =
320         transform_registry->at(transform_name);
321     TransformFuncContext context;
322     context.input_names = inputs;
323     context.output_names = outputs;
324     context.params = transform_info.second;
325     bool ignore_errors;
326     TF_RETURN_IF_ERROR(
327         ShouldIgnoreErrors(transform_info.second, &ignore_errors));
328     GraphDef transformed_graph_def;
329     Status transform_result =
330         transform_func(*graph_def, context, &transformed_graph_def);
331     if (!transform_result.ok()) {
332       if (ignore_errors) {
333         LOG(ERROR) << transform_name << ": Ignoring error "
334                    << transform_result.error_message();
335         transformed_graph_def = *graph_def;
336       } else {
337         return transform_result;
338       }
339     }
340     // Copy over the library from the original input graph.
341     *transformed_graph_def.mutable_library() = graph_def->library();
342     TF_RETURN_IF_ERROR(IsGraphValid(transformed_graph_def));
343 
344     *graph_def = transformed_graph_def;
345   }
346   return OkStatus();
347 }
348 }  // namespace graph_transforms
349 }  // namespace tensorflow
350