xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/op_gen_lib.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/framework/op_gen_lib.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/proto/proto_utils.h"
30 
31 namespace tensorflow {
32 
WordWrap(StringPiece prefix,StringPiece str,int width)33 string WordWrap(StringPiece prefix, StringPiece str, int width) {
34   const string indent_next_line = "\n" + Spaces(prefix.size());
35   width -= prefix.size();
36   string result;
37   strings::StrAppend(&result, prefix);
38 
39   while (!str.empty()) {
40     if (static_cast<int>(str.size()) <= width) {
41       // Remaining text fits on one line.
42       strings::StrAppend(&result, str);
43       break;
44     }
45     auto space = str.rfind(' ', width);
46     if (space == StringPiece::npos) {
47       // Rather make a too-long line and break at a space.
48       space = str.find(' ');
49       if (space == StringPiece::npos) {
50         strings::StrAppend(&result, str);
51         break;
52       }
53     }
54     // Breaking at character at position <space>.
55     StringPiece to_append = str.substr(0, space);
56     str.remove_prefix(space + 1);
57     // Remove spaces at break.
58     while (str_util::EndsWith(to_append, " ")) {
59       to_append.remove_suffix(1);
60     }
61     while (absl::ConsumePrefix(&str, " ")) {
62     }
63 
64     // Go on to the next line.
65     strings::StrAppend(&result, to_append);
66     if (!str.empty()) strings::StrAppend(&result, indent_next_line);
67   }
68 
69   return result;
70 }
71 
ConsumeEquals(StringPiece * description)72 bool ConsumeEquals(StringPiece* description) {
73   if (absl::ConsumePrefix(description, "=")) {
74     while (absl::ConsumePrefix(description,
75                                " ")) {  // Also remove spaces after "=".
76     }
77     return true;
78   }
79   return false;
80 }
81 
82 // Split `*orig` into two pieces at the first occurrence of `split_ch`.
83 // Returns whether `split_ch` was found. Afterwards, `*before_split`
84 // contains the maximum prefix of the input `*orig` that doesn't
85 // contain `split_ch`, and `*orig` contains everything after the
86 // first `split_ch`.
SplitAt(char split_ch,StringPiece * orig,StringPiece * before_split)87 static bool SplitAt(char split_ch, StringPiece* orig,
88                     StringPiece* before_split) {
89   auto pos = orig->find(split_ch);
90   if (pos == StringPiece::npos) {
91     *before_split = *orig;
92     *orig = StringPiece();
93     return false;
94   } else {
95     *before_split = orig->substr(0, pos);
96     orig->remove_prefix(pos + 1);
97     return true;
98   }
99 }
100 
101 // Does this line start with "<spaces><field>:" where "<field>" is
102 // in multi_line_fields? Sets *colon_pos to the position of the colon.
StartsWithFieldName(StringPiece line,const std::vector<string> & multi_line_fields)103 static bool StartsWithFieldName(StringPiece line,
104                                 const std::vector<string>& multi_line_fields) {
105   StringPiece up_to_colon;
106   if (!SplitAt(':', &line, &up_to_colon)) return false;
107   while (absl::ConsumePrefix(&up_to_colon, " "))
108     ;  // Remove leading spaces.
109   for (const auto& field : multi_line_fields) {
110     if (up_to_colon == field) {
111       return true;
112     }
113   }
114   return false;
115 }
116 
ConvertLine(StringPiece line,const std::vector<string> & multi_line_fields,string * ml)117 static bool ConvertLine(StringPiece line,
118                         const std::vector<string>& multi_line_fields,
119                         string* ml) {
120   // Is this a field we should convert?
121   if (!StartsWithFieldName(line, multi_line_fields)) {
122     return false;
123   }
124   // Has a matching field name, so look for "..." after the colon.
125   StringPiece up_to_colon;
126   StringPiece after_colon = line;
127   SplitAt(':', &after_colon, &up_to_colon);
128   while (absl::ConsumePrefix(&after_colon, " "))
129     ;  // Remove leading spaces.
130   if (!absl::ConsumePrefix(&after_colon, "\"")) {
131     // We only convert string fields, so don't convert this line.
132     return false;
133   }
134   auto last_quote = after_colon.rfind('\"');
135   if (last_quote == StringPiece::npos) {
136     // Error: we don't see the expected matching quote, abort the conversion.
137     return false;
138   }
139   StringPiece escaped = after_colon.substr(0, last_quote);
140   StringPiece suffix = after_colon.substr(last_quote + 1);
141   // We've now parsed line into '<up_to_colon>: "<escaped>"<suffix>'
142 
143   string unescaped;
144   if (!absl::CUnescape(escaped, &unescaped, nullptr)) {
145     // Error unescaping, abort the conversion.
146     return false;
147   }
148   // No more errors possible at this point.
149 
150   // Find a string to mark the end that isn't in unescaped.
151   string end = "END";
152   for (int s = 0; unescaped.find(end) != string::npos; ++s) {
153     end = strings::StrCat("END", s);
154   }
155 
156   // Actually start writing the converted output.
157   strings::StrAppend(ml, up_to_colon, ": <<", end, "\n", unescaped, "\n", end);
158   if (!suffix.empty()) {
159     // Output suffix, in case there was a trailing comment in the source.
160     strings::StrAppend(ml, suffix);
161   }
162   strings::StrAppend(ml, "\n");
163   return true;
164 }
165 
PBTxtToMultiline(StringPiece pbtxt,const std::vector<string> & multi_line_fields)166 string PBTxtToMultiline(StringPiece pbtxt,
167                         const std::vector<string>& multi_line_fields) {
168   string ml;
169   // Probably big enough, since the input and output are about the
170   // same size, but just a guess.
171   ml.reserve(pbtxt.size() * (17. / 16));
172   StringPiece line;
173   while (!pbtxt.empty()) {
174     // Split pbtxt into its first line and everything after.
175     SplitAt('\n', &pbtxt, &line);
176     // Convert line or output it unchanged
177     if (!ConvertLine(line, multi_line_fields, &ml)) {
178       strings::StrAppend(&ml, line, "\n");
179     }
180   }
181   return ml;
182 }
183 
184 // Given a single line of text `line` with first : at `colon`, determine if
185 // there is an "<<END" expression after the colon and if so return true and set
186 // `*end` to everything after the "<<".
FindMultiline(StringPiece line,size_t colon,string * end)187 static bool FindMultiline(StringPiece line, size_t colon, string* end) {
188   if (colon == StringPiece::npos) return false;
189   line.remove_prefix(colon + 1);
190   while (absl::ConsumePrefix(&line, " ")) {
191   }
192   if (absl::ConsumePrefix(&line, "<<")) {
193     *end = string(line);
194     return true;
195   }
196   return false;
197 }
198 
PBTxtFromMultiline(StringPiece multiline_pbtxt)199 string PBTxtFromMultiline(StringPiece multiline_pbtxt) {
200   string pbtxt;
201   // Probably big enough, since the input and output are about the
202   // same size, but just a guess.
203   pbtxt.reserve(multiline_pbtxt.size() * (33. / 32));
204   StringPiece line;
205   while (!multiline_pbtxt.empty()) {
206     // Split multiline_pbtxt into its first line and everything after.
207     if (!SplitAt('\n', &multiline_pbtxt, &line)) {
208       strings::StrAppend(&pbtxt, line);
209       break;
210     }
211 
212     string end;
213     auto colon = line.find(':');
214     if (!FindMultiline(line, colon, &end)) {
215       // Normal case: not a multi-line string, just output the line as-is.
216       strings::StrAppend(&pbtxt, line, "\n");
217       continue;
218     }
219 
220     // Multi-line case:
221     //     something: <<END
222     // xx
223     // yy
224     // END
225     // Should be converted to:
226     //     something: "xx\nyy"
227 
228     // Output everything up to the colon ("    something:").
229     strings::StrAppend(&pbtxt, line.substr(0, colon + 1));
230 
231     // Add every line to unescaped until we see the "END" string.
232     string unescaped;
233     bool first = true;
234     while (!multiline_pbtxt.empty()) {
235       SplitAt('\n', &multiline_pbtxt, &line);
236       if (absl::ConsumePrefix(&line, end)) break;
237       if (first) {
238         first = false;
239       } else {
240         unescaped.push_back('\n');
241       }
242       strings::StrAppend(&unescaped, line);
243       line = StringPiece();
244     }
245 
246     // Escape what we extracted and then output it in quotes.
247     strings::StrAppend(&pbtxt, " \"", absl::CEscape(unescaped), "\"", line,
248                        "\n");
249   }
250   return pbtxt;
251 }
252 
// Replaces every non-overlapping occurrence of `from` in `*s` with `to`,
// scanning left to right; inserted text is not rescanned. An empty `from`
// matches nothing, so `*s` is left unchanged.
static void StringReplace(const std::string& from, const std::string& to,
                          std::string* s) {
  // Searching for "" would match at every position without ever advancing.
  if (from.empty()) return;
  std::string result;
  result.reserve(s->size());
  std::string::size_type pos = 0;
  for (auto found = s->find(from); found != std::string::npos;
       found = s->find(from, pos)) {
    result.append(*s, pos, found - pos);  // Text before this match.
    result.append(to);
    pos = found + from.size();
  }
  result.append(*s, pos, std::string::npos);  // Text after the last match.
  s->swap(result);
}
273 
RenameInDocs(const string & from,const string & to,ApiDef * api_def)274 static void RenameInDocs(const string& from, const string& to,
275                          ApiDef* api_def) {
276   const string from_quoted = strings::StrCat("`", from, "`");
277   const string to_quoted = strings::StrCat("`", to, "`");
278   for (int i = 0; i < api_def->in_arg_size(); ++i) {
279     if (!api_def->in_arg(i).description().empty()) {
280       StringReplace(from_quoted, to_quoted,
281                     api_def->mutable_in_arg(i)->mutable_description());
282     }
283   }
284   for (int i = 0; i < api_def->out_arg_size(); ++i) {
285     if (!api_def->out_arg(i).description().empty()) {
286       StringReplace(from_quoted, to_quoted,
287                     api_def->mutable_out_arg(i)->mutable_description());
288     }
289   }
290   for (int i = 0; i < api_def->attr_size(); ++i) {
291     if (!api_def->attr(i).description().empty()) {
292       StringReplace(from_quoted, to_quoted,
293                     api_def->mutable_attr(i)->mutable_description());
294     }
295   }
296   if (!api_def->summary().empty()) {
297     StringReplace(from_quoted, to_quoted, api_def->mutable_summary());
298   }
299   if (!api_def->description().empty()) {
300     StringReplace(from_quoted, to_quoted, api_def->mutable_description());
301   }
302 }
303 
304 namespace {
305 
306 // Initializes given ApiDef with data in OpDef.
InitApiDefFromOpDef(const OpDef & op_def,ApiDef * api_def)307 void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) {
308   api_def->set_graph_op_name(op_def.name());
309   api_def->set_visibility(ApiDef::VISIBLE);
310 
311   auto* endpoint = api_def->add_endpoint();
312   endpoint->set_name(op_def.name());
313 
314   for (const auto& op_in_arg : op_def.input_arg()) {
315     auto* api_in_arg = api_def->add_in_arg();
316     api_in_arg->set_name(op_in_arg.name());
317     api_in_arg->set_rename_to(op_in_arg.name());
318     api_in_arg->set_description(op_in_arg.description());
319 
320     *api_def->add_arg_order() = op_in_arg.name();
321   }
322   for (const auto& op_out_arg : op_def.output_arg()) {
323     auto* api_out_arg = api_def->add_out_arg();
324     api_out_arg->set_name(op_out_arg.name());
325     api_out_arg->set_rename_to(op_out_arg.name());
326     api_out_arg->set_description(op_out_arg.description());
327   }
328   for (const auto& op_attr : op_def.attr()) {
329     auto* api_attr = api_def->add_attr();
330     api_attr->set_name(op_attr.name());
331     api_attr->set_rename_to(op_attr.name());
332     if (op_attr.has_default_value()) {
333       *api_attr->mutable_default_value() = op_attr.default_value();
334     }
335     api_attr->set_description(op_attr.description());
336   }
337   api_def->set_summary(op_def.summary());
338   api_def->set_description(op_def.description());
339 }
340 
341 // Updates base_arg based on overrides in new_arg.
MergeArg(ApiDef::Arg * base_arg,const ApiDef::Arg & new_arg)342 void MergeArg(ApiDef::Arg* base_arg, const ApiDef::Arg& new_arg) {
343   if (!new_arg.rename_to().empty()) {
344     base_arg->set_rename_to(new_arg.rename_to());
345   }
346   if (!new_arg.description().empty()) {
347     base_arg->set_description(new_arg.description());
348   }
349 }
350 
351 // Updates base_attr based on overrides in new_attr.
MergeAttr(ApiDef::Attr * base_attr,const ApiDef::Attr & new_attr)352 void MergeAttr(ApiDef::Attr* base_attr, const ApiDef::Attr& new_attr) {
353   if (!new_attr.rename_to().empty()) {
354     base_attr->set_rename_to(new_attr.rename_to());
355   }
356   if (new_attr.has_default_value()) {
357     *base_attr->mutable_default_value() = new_attr.default_value();
358   }
359   if (!new_attr.description().empty()) {
360     base_attr->set_description(new_attr.description());
361   }
362 }
363 
364 // Updates base_api_def based on overrides in new_api_def.
MergeApiDefs(ApiDef * base_api_def,const ApiDef & new_api_def)365 Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) {
366   // Merge visibility
367   if (new_api_def.visibility() != ApiDef::DEFAULT_VISIBILITY) {
368     base_api_def->set_visibility(new_api_def.visibility());
369   }
370   // Merge endpoints
371   if (new_api_def.endpoint_size() > 0) {
372     base_api_def->clear_endpoint();
373     std::copy(
374         new_api_def.endpoint().begin(), new_api_def.endpoint().end(),
375         protobuf::RepeatedFieldBackInserter(base_api_def->mutable_endpoint()));
376   }
377   // Merge args
378   for (const auto& new_arg : new_api_def.in_arg()) {
379     bool found_base_arg = false;
380     for (int i = 0; i < base_api_def->in_arg_size(); ++i) {
381       auto* base_arg = base_api_def->mutable_in_arg(i);
382       if (base_arg->name() == new_arg.name()) {
383         MergeArg(base_arg, new_arg);
384         found_base_arg = true;
385         break;
386       }
387     }
388     if (!found_base_arg) {
389       return errors::FailedPrecondition("Argument ", new_arg.name(),
390                                         " not defined in base api for ",
391                                         base_api_def->graph_op_name());
392     }
393   }
394   for (const auto& new_arg : new_api_def.out_arg()) {
395     bool found_base_arg = false;
396     for (int i = 0; i < base_api_def->out_arg_size(); ++i) {
397       auto* base_arg = base_api_def->mutable_out_arg(i);
398       if (base_arg->name() == new_arg.name()) {
399         MergeArg(base_arg, new_arg);
400         found_base_arg = true;
401         break;
402       }
403     }
404     if (!found_base_arg) {
405       return errors::FailedPrecondition("Argument ", new_arg.name(),
406                                         " not defined in base api for ",
407                                         base_api_def->graph_op_name());
408     }
409   }
410   // Merge arg order
411   if (new_api_def.arg_order_size() > 0) {
412     // Validate that new arg_order is correct.
413     if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) {
414       return errors::FailedPrecondition(
415           "Invalid number of arguments ", new_api_def.arg_order_size(), " for ",
416           base_api_def->graph_op_name(),
417           ". Expected: ", base_api_def->arg_order_size());
418     }
419     if (!std::is_permutation(new_api_def.arg_order().begin(),
420                              new_api_def.arg_order().end(),
421                              base_api_def->arg_order().begin())) {
422       return errors::FailedPrecondition(
423           "Invalid arg_order: ", absl::StrJoin(new_api_def.arg_order(), ", "),
424           " for ", base_api_def->graph_op_name(),
425           ". All elements in arg_order override must match base arg_order: ",
426           absl::StrJoin(base_api_def->arg_order(), ", "));
427     }
428 
429     base_api_def->clear_arg_order();
430     std::copy(
431         new_api_def.arg_order().begin(), new_api_def.arg_order().end(),
432         protobuf::RepeatedFieldBackInserter(base_api_def->mutable_arg_order()));
433   }
434   // Merge attributes
435   for (const auto& new_attr : new_api_def.attr()) {
436     bool found_base_attr = false;
437     for (int i = 0; i < base_api_def->attr_size(); ++i) {
438       auto* base_attr = base_api_def->mutable_attr(i);
439       if (base_attr->name() == new_attr.name()) {
440         MergeAttr(base_attr, new_attr);
441         found_base_attr = true;
442         break;
443       }
444     }
445     if (!found_base_attr) {
446       return errors::FailedPrecondition("Attribute ", new_attr.name(),
447                                         " not defined in base api for ",
448                                         base_api_def->graph_op_name());
449     }
450   }
451   // Merge summary
452   if (!new_api_def.summary().empty()) {
453     base_api_def->set_summary(new_api_def.summary());
454   }
455   // Merge description
456   auto description = new_api_def.description().empty()
457                          ? base_api_def->description()
458                          : new_api_def.description();
459 
460   if (!new_api_def.description_prefix().empty()) {
461     description =
462         strings::StrCat(new_api_def.description_prefix(), "\n", description);
463   }
464   if (!new_api_def.description_suffix().empty()) {
465     description =
466         strings::StrCat(description, "\n", new_api_def.description_suffix());
467   }
468   base_api_def->set_description(description);
469   return OkStatus();
470 }
471 }  // namespace
472 
ApiDefMap(const OpList & op_list)473 ApiDefMap::ApiDefMap(const OpList& op_list) {
474   for (const auto& op : op_list.op()) {
475     ApiDef api_def;
476     InitApiDefFromOpDef(op, &api_def);
477     map_[op.name()] = api_def;
478   }
479 }
480 
~ApiDefMap()481 ApiDefMap::~ApiDefMap() {}
482 
LoadFileList(Env * env,const std::vector<string> & filenames)483 Status ApiDefMap::LoadFileList(Env* env, const std::vector<string>& filenames) {
484   for (const auto& filename : filenames) {
485     TF_RETURN_IF_ERROR(LoadFile(env, filename));
486   }
487   return OkStatus();
488 }
489 
LoadFile(Env * env,const string & filename)490 Status ApiDefMap::LoadFile(Env* env, const string& filename) {
491   if (filename.empty()) return OkStatus();
492   string contents;
493   TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents));
494   Status status = LoadApiDef(contents);
495   if (!status.ok()) {
496     // Return failed status annotated with filename to aid in debugging.
497     return errors::CreateWithUpdatedMessage(
498         status, strings::StrCat("Error parsing ApiDef file ", filename, ": ",
499                                 status.error_message()));
500   }
501   return OkStatus();
502 }
503 
LoadApiDef(const string & api_def_file_contents)504 Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) {
505   const string contents = PBTxtFromMultiline(api_def_file_contents);
506   ApiDefs api_defs;
507   TF_RETURN_IF_ERROR(
508       proto_utils::ParseTextFormatFromString(contents, &api_defs));
509   for (const auto& api_def : api_defs.op()) {
510     // Check if the op definition is loaded. If op definition is not
511     // loaded, then we just skip this ApiDef.
512     if (map_.find(api_def.graph_op_name()) != map_.end()) {
513       // Overwrite current api def with data in api_def.
514       TF_RETURN_IF_ERROR(MergeApiDefs(&map_[api_def.graph_op_name()], api_def));
515     }
516   }
517   return OkStatus();
518 }
519 
UpdateDocs()520 void ApiDefMap::UpdateDocs() {
521   for (auto& name_and_api_def : map_) {
522     auto& api_def = name_and_api_def.second;
523     CHECK_GT(api_def.endpoint_size(), 0);
524     const string canonical_name = api_def.endpoint(0).name();
525     if (api_def.graph_op_name() != canonical_name) {
526       RenameInDocs(api_def.graph_op_name(), canonical_name, &api_def);
527     }
528     for (const auto& in_arg : api_def.in_arg()) {
529       if (in_arg.name() != in_arg.rename_to()) {
530         RenameInDocs(in_arg.name(), in_arg.rename_to(), &api_def);
531       }
532     }
533     for (const auto& out_arg : api_def.out_arg()) {
534       if (out_arg.name() != out_arg.rename_to()) {
535         RenameInDocs(out_arg.name(), out_arg.rename_to(), &api_def);
536       }
537     }
538     for (const auto& attr : api_def.attr()) {
539       if (attr.name() != attr.rename_to()) {
540         RenameInDocs(attr.name(), attr.rename_to(), &api_def);
541       }
542     }
543   }
544 }
545 
GetApiDef(const string & name) const546 const tensorflow::ApiDef* ApiDefMap::GetApiDef(const string& name) const {
547   return gtl::FindOrNull(map_, name);
548 }
549 }  // namespace tensorflow
550