/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/framework/python_op_gen.h"

#include <stdio.h>

#include <sstream>
#include <unordered_map>

#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/framework/python_op_gen_internal.h"

namespace tensorflow {
namespace {

const int kRightMargin = 78;

constexpr char kEagerFallbackSuffix[] = "_eager_fallback";

// Maps the Python dtype expressions produced by DataTypeToPython
// (e.g. "_dtypes.float32") to the corresponding DType class names used in
// type annotations.
const std::unordered_map<string, string> dtype_type{
    {"_dtypes.float16", "_dtypes.Float16"},
    {"_dtypes.half", "_dtypes.Half"},
    {"_dtypes.float32", "_dtypes.Float32"},
    {"_dtypes.float64", "_dtypes.Float64"},
    {"_dtypes.bfloat16", "_dtypes.BFloat16"},
    {"_dtypes.complex64", "_dtypes.Complex64"},
    {"_dtypes.complex128", "_dtypes.Complex128"},
    {"_dtypes.int8", "_dtypes.Int8"},
    {"_dtypes.uint8", "_dtypes.UInt8"},
    {"_dtypes.uint16", "_dtypes.UInt16"},
    {"_dtypes.uint32", "_dtypes.UInt32"},
    {"_dtypes.uint64", "_dtypes.UInt64"},
    {"_dtypes.int16", "_dtypes.Int16"},
    {"_dtypes.int32", "_dtypes.Int32"},
    {"_dtypes.int64", "_dtypes.Int64"},
    {"_dtypes.bool", "_dtypes.Bool"},
    {"_dtypes.string", "_dtypes.String"},
    {"_dtypes.qint8", "_dtypes.QInt8"},
    {"_dtypes.quint8", "_dtypes.QUInt8"},
    {"_dtypes.qint16", "_dtypes.QInt16"},
    {"_dtypes.quint16", "_dtypes.QUInt16"},
    {"_dtypes.qint32", "_dtypes.QInt32"},
    {"_dtypes.resource", "_dtypes.Resource"},
    {"_dtypes.variant", "_dtypes.Variant"}};

string AttrVarName(const string& attr_name,
                   std::unordered_map<string, string>* attr_expressions) {
  const string var = strings::StrCat("_attr_", attr_name);
  if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var;
  return var;
}

void AddInferredAttr(const string& indentation, const string& attr_name,
                     const string& value_expression, string* result,
                     std::unordered_map<string, string>* attr_expressions) {
  strings::StrAppend(result, indentation,
                     AttrVarName(attr_name, attr_expressions), " = ",
                     value_expression, "\n");
}

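// Illustrative examples (comment added for this listing; names hypothetical):
// VectorToTuple({"a", "b"}) returns "(a, b)", while the single-element case
// VectorToTuple({"a"}) returns "(a,)" so the generated Python expression is
// still a tuple.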
string VectorToTuple(const std::vector<string>& l) {
  if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
  string ret = "(";
  for (int i = 0, end = l.size(); i < end; ++i) {
    if (i > 0) {
      strings::StrAppend(&ret, ", ");
    }
    strings::StrAppend(&ret, l[i]);
  }
  strings::StrAppend(&ret, ")");
  return ret;
}

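// Emits Python that re-groups a flat result list. Illustrative example
// (hypothetical names): with prefix "  ", var "_result", and
// output_sizes {"", "_attr_N", ""} (only the middle output is a list),
// the generated line is:
//
//   _result = _result[:1] + [_result[1:1 + _attr_N]] + _result[1 + _attr_N:]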
void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
               const string& var, string* result) {
  for (int i = 0, end = output_sizes.size(); i < end; ++i) {
    if (!output_sizes[i].empty()) {
      strings::StrAppend(result, prefix, var, " = ");
      if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
      if (i + 1 < end) {
        // Special case i == 0 to avoid "0 +" in the generated code.
        if (i == 0) {
          strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
                             var, "[", output_sizes[i], ":]");
        } else {
          strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
                             output_sizes[i], "]] + ", var, "[", i, " + ",
                             output_sizes[i], ":]");
        }
      } else {
        strings::StrAppend(result, "[", var, "[", i, ":]]");
      }
      strings::StrAppend(result, "\n");
    }
  }
}

string TensorPBString(const TensorProto& pb) {
  // Note: This gets used in the argument list, and so must survive naive
  // word wrapping.
  return strings::StrCat("\"\"\"", pb.ShortDebugString(), "\"\"\"");
}

class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
 public:
  GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
                   const string& function_name, bool add_type_annotations)
      : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name,
                                            add_type_annotations) {
    op_name_ = function_name_;
    absl::ConsumePrefix(&op_name_, "_");
  }
  ~GenEagerPythonOp() override {}

  string Code() override;

 protected:
  void HandleGraphMode(const string& function_setup,
                       const std::vector<string>& output_sizes);

  string GetEagerNotAllowedError();
  void ExpectListArg(const string& indentation, const string& arg_name,
                     string* output);
  bool GetEagerFunctionSetup(const string& indentation, string* function_setup);
  void GetOutputSizesAndNumOutputsExpr(std::vector<string>* output_sizes,
                                       string* num_outputs_expr);

  void AddEagerFunctionTeardown(const string& indentation,
                                const std::vector<string>& output_sizes,
                                bool execute_record_gradient);

  bool AddEagerFastPathAndGraphCode(
      const string& parameters, const std::vector<string>& output_sizes,
      const string& eager_not_allowed_error,
      const std::unordered_map<string, string>& type_annotations);
  bool AddEagerFallbackCode(
      const string& parameters, const std::vector<string>& output_sizes,
      const string& num_outputs_expr, const string& eager_not_allowed_error,
      const std::unordered_map<string, string>& type_annotations);
  void AddEagerFastPathExecute();

  void AddEagerInferredAttrs(const string& indentation);
  void AddEagerInputCasts(const string& indentation);
  void AddEagerAttrs(const string& indentation);
  void AddEagerExecute(const string& indentation,
                       const string& num_outputs_expr);
  void AddFallbackDispatch(const string& prefix);
  void AddTypeBasedDispatch(const string& prefix);
  void AddTypeBasedDispatcherAlias();

  void AddRawOpExport(const string& parameters);

  std::unordered_map<string, string> GetTypeAnnotations();

  void GenerateTypeVars(
      const std::unordered_map<string, string>& type_annotations);

  void AddReturnTypeAnnotation(
      const std::unordered_map<string, string>& type_annotations);

  void AddAttrForArg(const string& attr, int arg_index) {
    gtl::InsertIfNotPresent(&inferred_attrs_, attr,
                            op_def_.input_arg(arg_index).name());
    auto iter = attr_to_args_.find(attr);
    if (iter == attr_to_args_.end()) {
      attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
    } else {
      iter->second.push_back(arg_index);
    }
  }

  // Returns a string expression representing a flattened list of all
  // the inputs given by `*input_indices` (or all inputs if
  // `input_indices` is nullptr).  `*output_sizes` can be used to unflatten.
  string FlattenInputs(const std::vector<int>* input_indices,
                       std::vector<string>* output_sizes) const;

  StringPiece op_name_;
  typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
  AttrToArgMap attr_to_args_;
  std::unordered_map<string, string> attr_expressions_;
  // This has all the input args followed by those attrs that don't have
  // defaults.
  std::vector<python_op_gen_internal::ParamNames> params_no_default_;
  // The parameters with defaults (these have to be listed after those without).
  // No input args are included, just attrs.
  std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
      params_with_default_;
};

string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
                        const string& function_name,
                        bool add_type_annotations) {
  return GenEagerPythonOp(op_def, api_def, function_name, add_type_annotations)
      .Code();
}

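// Illustrative example (hypothetical op): for inputs (x, ys, z) where ys has
// number_attr "N", FlattenInputs(nullptr, &sizes) returns the expression
// "[x] + list(ys) + [z]" and sets sizes to {"", "_attr_N", ""}.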
string GenEagerPythonOp::FlattenInputs(
    const std::vector<int>* input_indices,
    std::vector<string>* output_sizes) const {
  string inputs;
  enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
  const int n = input_indices != nullptr ? input_indices->size()
                                         : op_def_.input_arg_size();
  for (int j = 0; j < n; ++j) {
    const int i = input_indices ? (*input_indices)[j] : j;
    const auto& arg(op_def_.input_arg(i));
    const bool is_list =
        !arg.type_list_attr().empty() || !arg.number_attr().empty();
    if (is_list) {
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, "] + ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + ");
      }
      strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
      inputs_state = WAS_LIST_INPUT;
      if (output_sizes != nullptr) {
        if (!arg.number_attr().empty()) {
          output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
        } else {
          output_sizes->emplace_back(
              strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
        }
      }
    } else {
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, ", ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + [");
      } else {
        strings::StrAppend(&inputs, "[");
      }
      strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
      inputs_state = WAS_SOLO_INPUT;
      if (output_sizes != nullptr) output_sizes->emplace_back();
    }
  }
  if (inputs_state == STARTING) return "[]";
  if (inputs_state == WAS_SOLO_INPUT) {
    strings::StrAppend(&inputs, "]");
  }
  return inputs;
}

string GenEagerPythonOp::Code() {
  if (api_def_.visibility() == ApiDef::SKIP) {
    return "";
  }

  for (int i = 0; i < api_def_.arg_order_size(); ++i) {
    const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
    const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
    params_no_default_.emplace_back(api_def_arg.name(),
                                    api_def_arg.rename_to());
    if (!arg.type_attr().empty()) {
      AddAttrForArg(arg.type_attr(), i);
    } else if (!arg.type_list_attr().empty()) {
      AddAttrForArg(arg.type_list_attr(), i);
    }
    if (!arg.number_attr().empty()) {
      AddAttrForArg(arg.number_attr(), i);
    }
  }
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    const auto& api_def_attr(api_def_.attr(i));
    // Do not add inferred attrs to the Python function signature.
    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
      if (api_def_attr.has_default_value()) {
        if (attr.type() == "tensor") {
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat(
                  "_execute.make_tensor(",
                  TensorPBString(api_def_attr.default_value().tensor()), ", \"",
                  api_def_attr.rename_to(), "\")"));
        } else if (attr.type() == "list(tensor)") {
          std::vector<string> pbtxt;
          for (const auto& pb : api_def_attr.default_value().list().tensor()) {
            pbtxt.emplace_back(TensorPBString(pb));
          }
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat("[_execute.make_tensor(_pb, \"",
                              api_def_attr.rename_to(), "\") for _pb in ",
                              VectorToTuple(pbtxt), "]"));
        } else {
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
      } else {
        params_no_default_.emplace_back(api_def_attr.name(),
                                        api_def_attr.rename_to());
      }
    }
  }

  // Save the list of attr parameters (attrs that won't be inferred);
  // those with defaults go at the end.
  // Get the attrs in the order we want by taking the attrs without defaults
  // from the end of params_no_default_, and then adding params_with_default_.
  attrs_.reserve(params_no_default_.size() - op_def_.input_arg_size() +
                 params_with_default_.size());
  for (int i = op_def_.input_arg_size(), end = params_no_default_.size();
       i < end; ++i) {
    attrs_.push_back(params_no_default_[i].GetName());
  }
  for (const auto& p : params_with_default_) {
    attrs_.push_back(p.first.GetName());
  }

  // TODO(slebedev): call AvoidPythonReserved on each param?
  param_names_.reserve(params_no_default_.size() + params_with_default_.size());
  param_names_.insert(param_names_.begin(), params_no_default_.begin(),
                      params_no_default_.end());
  for (const auto& param_and_default : params_with_default_) {
    param_names_.push_back(param_and_default.first);
  }

  std::unordered_map<string, string> type_annotations;
  // Only populate the map for allowlisted ops.
  if (add_type_annotations_) {
    type_annotations = GetTypeAnnotations();
  }

  string parameters;
  // A param can be an input or an attr.
  for (const auto& param : params_no_default_) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    strings::StrAppend(&parameters, param.GetRenameTo());

    if (type_annotations.find(param.GetName()) != type_annotations.end()) {
      strings::StrAppend(&parameters, ": ",
                         type_annotations.at(param.GetName()));
    }
  }

  string parameters_with_defaults = parameters;
  for (const auto& param_and_default : params_with_default_) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    if (!parameters_with_defaults.empty())
      strings::StrAppend(&parameters_with_defaults, ", ");

    strings::StrAppend(&parameters, param_and_default.first.GetRenameTo());
    strings::StrAppend(&parameters_with_defaults,
                       param_and_default.first.GetRenameTo());
    if (type_annotations.find(param_and_default.first.GetName()) !=
        type_annotations.end()) {
      const string param_type =
          type_annotations.at(param_and_default.first.GetName());
      // Append to parameters and parameters_with_defaults because multiple
      // functions are generated by AddEagerFastPathAndGraphCode() and
      // AddEagerFallbackCode()
      strings::StrAppend(&parameters, ": ", param_type);
      strings::StrAppend(&parameters_with_defaults, ":", param_type);
    }

    strings::StrAppend(&parameters_with_defaults, "=",
                       param_and_default.second);
  }

  strings::StrAppend(&parameters, parameters.empty() ? "" : ", ", "name");
  strings::StrAppend(&parameters_with_defaults,
                     parameters_with_defaults.empty() ? "" : ", ", "name=None");

  // Add attr_expressions_ for attrs that are params.
  for (int i = 0, end = attrs_.size(); i < end; ++i) {
    const string& attr_name = attrs_[i];
    const string& attr_api_name =
        param_names_[i + op_def_.input_arg_size()].GetRenameTo();
    attr_expressions_[attr_name] = attr_api_name;
  }
  // Add attr_expressions_ for attrs that are inferred.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    if (attr.type() == "int") {
      auto arg_list = attr_to_args_.find(attr.name());
      if (arg_list != attr_to_args_.end()) {
        AttrVarName(attr.name(), &attr_expressions_);
      }
    }
  }

  string num_outputs_expr;
  std::vector<string> output_sizes(num_outs_);
  GetOutputSizesAndNumOutputsExpr(&output_sizes, &num_outputs_expr);

  string eager_not_allowed_error = GetEagerNotAllowedError();

  if (!AddEagerFastPathAndGraphCode(parameters_with_defaults, output_sizes,
                                    eager_not_allowed_error,
                                    type_annotations)) {
    return result_;
  }

  if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
                            eager_not_allowed_error, type_annotations)) {
    return result_;
  }

  return prelude_ + result_;
}

std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() {
  std::unordered_map<string, string> type_annotations;
  // Map attrs to TypeVars
  for (const auto& attr : op_def_.attr()) {
    if (attr.type() == "type") {
      const string type_var_name = "TV_" + op_def_.name() + "_" + attr.name();
      type_annotations[attr.name()] = type_var_name;
    } else if (attr.type() == "bool" || attr.type() == "float" ||
               attr.type() == "int" || attr.type() == "bytes") {
      type_annotations[attr.name()] = attr.type();
    } else if (attr.type() == "string") {
      type_annotations[attr.name()] = "str";
    }
  }

  // Map input Tensors to their types
  for (const auto& arg : op_def_.input_arg()) {
    // TODO(rahulkamat): Add type annotations to args that accept a sequence of
    // Tensors
    if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) continue;
    type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
  }

  // TODO(rahulkamat): Add type annotations to handle return types of a sequence
  // of Tensors. Map output Tensor to its type
  if (op_def_.output_arg_size() == 1) {
    const auto& arg = op_def_.output_arg(0);
    if (arg.number_attr().empty() && arg.type_list_attr().empty())
      type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
  }

  return type_annotations;
}

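// Illustrative example of the TypeVar lines emitted below (hypothetical op
// "Foo" with a "type" attr "T" allowing {float32, int32}):
//
//   TV_Foo_T = TypeVar("TV_Foo_T", _dtypes.Float32, _dtypes.Int32)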
// Generate TypeVars using attrs
void GenEagerPythonOp::GenerateTypeVars(
    const std::unordered_map<string, string>& type_annotations) {
  bool added_typevar = false;
  for (const auto& attr : op_def_.attr()) {
    if (attr.type() == "type") {
      std::vector<string> allowed_types;
      for (int t : attr.allowed_values().list().type()) {
        DataType dtype = static_cast<DataType>(t);
        const string py_dtype =
            python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
        allowed_types.emplace_back(dtype_type.at(py_dtype));
      }

      // When a Tensor does not have any dtypes specified, all dtypes are
      // allowed
      if (allowed_types.empty()) {
        for (std::pair<string, string> map_dtype : dtype_type) {
          allowed_types.emplace_back(map_dtype.second);
        }
      }

      std::sort(allowed_types.begin(), allowed_types.end());

      string typevar_dtypes;
      for (std::vector<string>::iterator it = allowed_types.begin();
           it != allowed_types.end(); ++it) {
        if (!typevar_dtypes.empty()) strings::StrAppend(&typevar_dtypes, ", ");
        strings::StrAppend(&typevar_dtypes, *it);
      }

      const string type_var_name = type_annotations.at(attr.name());
      strings::StrAppend(&result_, type_var_name, " = TypeVar(\"",
                         type_var_name, "\", ", typevar_dtypes, ")\n");
      added_typevar = true;
    }
  }

  if (added_typevar) strings::StrAppend(&result_, "\n");
}

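// Sketch of the effect of AddReturnTypeAnnotation() (hypothetical names): a
// def line such as
//
//   def foo(x: _ops.Tensor[TV_Foo_T], name=None):
//
// becomes
//
//   def foo(x: _ops.Tensor[TV_Foo_T], name=None) -> _ops.Tensor[TV_Foo_T]: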
void GenEagerPythonOp::AddReturnTypeAnnotation(
    const std::unordered_map<string, string>& type_annotations) {
  if (op_def_.output_arg_size() == 1) {
    const auto& arg = op_def_.output_arg(0);
    if (arg.number_attr().empty() && arg.type_list_attr().empty()) {
      const string return_type = type_annotations.at(arg.name());
      // TODO(rahulkamat): Modify AddDefLine() to add return type annotation to
      // avoid erasing ":\n" from the end of the def line
      result_.erase(result_.length() - 2);
      strings::StrAppend(&result_, " -> ", return_type, ":\n");
    }
  }
}

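// Rough shape of the graph-mode body emitted below (illustrative sketch for
// a visible op "Foo" with one fixed output; the keyword arguments and the
// dispatch clauses are filled in by AddBodyNoReturn() and
// AddFallbackDispatch()):
//
//   # Add nodes to the TensorFlow graph.
//   try:
//     _, _, _op, _outputs = _op_def_library._apply_op_helper(
//         "Foo", ...)
//   except (TypeError, ValueError):
//     ...
//   _result = _outputs[:]
//   if _execute.must_record_gradient():
//     ...
//   _result, = _result
//   return _result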
void GenEagerPythonOp::HandleGraphMode(
    const string& function_setup, const std::vector<string>& output_sizes) {
  if (api_def_.visibility() == ApiDef::VISIBLE) {
    strings::StrAppend(&result_, "  else:\n");
    AddTypeBasedDispatch("    ");
  }
  strings::StrAppend(&result_, "  # Add nodes to the TensorFlow graph.\n");
  strings::StrAppend(&result_, function_setup);
  if (api_def_.visibility() == ApiDef::VISIBLE) {
    strings::StrAppend(&result_, "  try:\n  ");
  }
  strings::StrAppend(
      &result_, "  _, _, _op, _outputs = _op_def_library._apply_op_helper(\n");
  AddBodyNoReturn(strings::StrCat("        \"", op_def_.name(), "\", "));
  AddFallbackDispatch("  ");

  if (num_outs_ > 0) {
    strings::StrAppend(&result_, "  _result = _outputs[:]\n");
    // Special case handling for stateful op with single list output
    // that might be empty.
    if (num_outs_ == 1 && op_def_.is_stateful() &&
        (!op_def_.output_arg(0).number_attr().empty() ||
         !op_def_.output_arg(0).type_list_attr().empty())) {
      // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
      // a constraint indicating that this can never be empty.
      strings::StrAppend(&result_,
                         "  if not _result:\n"
                         "    return _op\n");
    }

    // Compute graph-mode attrs when we need to record a gradient.
    strings::StrAppend(&result_, "  if _execute.must_record_gradient():\n");
    if (op_def_.attr_size() > 0) {
      string attr_values;
      for (int i = 0; i < op_def_.attr_size(); ++i) {
        if (i > 0) strings::StrAppend(&attr_values, ", ");
        const auto& attr_name(op_def_.attr(i).name());
        if (op_def_.attr(i).type() == "type") {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op._get_attr_type(\"", attr_name, "\")");
        } else if (op_def_.attr(i).type() == "bool") {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op._get_attr_bool(\"", attr_name, "\")");
        } else if (op_def_.attr(i).type() == "int") {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op._get_attr_int(\"", attr_name, "\")");
        } else {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op.get_attr(\"", attr_name, "\")");
        }
      }
      strings::StrAppend(&attr_values, ")");
      strings::StrAppend(&result_,
                         WordWrap("    _attrs = (", attr_values, kRightMargin),
                         "\n");

    } else {
      strings::StrAppend(&result_, "    _attrs = ()\n");
    }

    strings::StrAppend(&result_, "    _inputs_flat = _op.inputs\n");
    strings::StrAppend(&result_, "    _execute.record_gradient(\n",
                       "        \"", op_def_.name(),
                       "\", _inputs_flat, _attrs, _result)\n");

    if (num_outs_ == 1 && !output_sizes[0].empty()) {
      // Single list result.
    } else if (num_outs_ == 1) {
      // Execute returns a single-element list which we need to destructure.
      strings::StrAppend(&result_, "  ", "_result, = _result\n");
    } else {
      // Have multiple outputs, so we will need to reformat the return
      // value of execute() to be a list with one entry per op output
      // (that entry will be a list of tensors if that output is of list
      // type).
      // For list outputs, convert the right subrange of _result into a list.
      Unflatten("  ", output_sizes, "_result", &result_);
      // Convert to a named tuple.
      strings::StrAppend(
          &result_, "  _result = _",
          python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
          "Output._make(_result)\n");
    }
    strings::StrAppend(&result_, "  return _result\n\n");
  } else {
    strings::StrAppend(&result_, "  return _op\n");
  }
}

string GenEagerPythonOp::GetEagerNotAllowedError() {
  bool eager_allowed = true;
  string ref_arg;
  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
    const auto& arg = op_def_.input_arg(i);
    if (arg.is_ref()) {
      eager_allowed = false;
      DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name());
      ref_arg = api_def_.in_arg(i).rename_to();
    }
  }
  for (int i = 0; i < op_def_.output_arg_size(); ++i) {
    const auto& arg = op_def_.output_arg(i);
    if (arg.is_ref()) {
      eager_allowed = false;
      DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name());
      ref_arg = api_def_.out_arg(i).rename_to();
    }
  }

  if (eager_allowed) return "";

  return strings::StrCat("raise RuntimeError(\"", op_name_,
                         " op does not support eager execution. ", "Arg '",
                         ref_arg, "' is a ref.\")\n");
}

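// The generated check looks like (illustrative, for an argument "values" of a
// hypothetical 'concat' op, with two-space indentation):
//
//   if not isinstance(values, (list, tuple)):
//     raise TypeError(
//         "Expected list for 'values' argument to "
//         "'concat' Op, not %r." % values)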
void GenEagerPythonOp::ExpectListArg(const string& indentation,
                                     const string& arg_name, string* output) {
  strings::StrAppend(output, indentation, "if not isinstance(", arg_name,
                     ", (list, tuple)):\n", indentation, "  raise TypeError(\n",
                     indentation, "      \"Expected list for '", arg_name,
                     "' argument to \"\n", indentation, "      \"'", op_name_,
                     "' Op, not %r.\" % ", arg_name, ")\n");
}

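// Illustrative fragment of the setup code this produces (hypothetical attr
// "axis" of type "int" with a default of 0):
//
//   if axis is None:
//     axis = 0
//   axis = _execute.make_int(axis, "axis")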
bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation,
                                             string* function_setup) {
  // Validate list inputs, infer length attrs.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    if (attr.type() == "int") {
      auto arg_list = attr_to_args_.find(attr.name());
      if (arg_list != attr_to_args_.end()) {
        // Inferred int attrs are the lengths of inputs. Validate those
        // inputs are lists and have the same length.
        for (auto iter = arg_list->second.begin();
             iter != arg_list->second.end(); ++iter) {
          const string& arg_api_name = param_names_[*iter].GetRenameTo();
          ExpectListArg(indentation, arg_api_name, function_setup);
          if (iter == arg_list->second.begin()) {
            AddInferredAttr(indentation, attr.name(),
                            strings::StrCat("len(", arg_api_name, ")"),
                            function_setup, &attr_expressions_);
          } else {
            const auto& attr_var = attr_expressions_[attr.name()];
            strings::StrAppend(
                function_setup, indentation, "if len(", arg_api_name,
                ") != ", attr_var, ":\n", indentation, "  raise ValueError(\n",
                indentation, "      \"List argument '", arg_api_name, "' to '",
                op_name_, "' Op with length %d \"\n", indentation,
                "      \"must match length %d of argument '",
                inferred_attrs_[attr.name()], "'.\" %\n", indentation,
                "      (len(", arg_api_name, "), ", attr_var, "))\n");
          }
        }
      }
    }
  }

  for (int i = 0, end = attrs_.size(); i < end; ++i) {
    const string& attr_name = attrs_[i];
    const auto& param = param_names_[i + op_def_.input_arg_size()];
    const auto& attr = *FindAttr(attr_name, op_def_);
    const string& attr_api_name = param.GetRenameTo();
    StringPiece attr_type = attr.type();
    attr_expressions_[attr_name] = attr_api_name;
    const int default_index = i - (attrs_.size() - params_with_default_.size());
    if (default_index >= 0) {
      const string& default_value = params_with_default_[default_index].second;
      strings::StrAppend(function_setup, indentation, "if ", attr_api_name,
                         " is None:\n");
      strings::StrAppend(function_setup, indentation, "  ", attr_api_name,
                         " = ", default_value, "\n");
    }
    if (absl::StartsWith(attr_type, "list(")) {
      ExpectListArg(indentation, attr_api_name, function_setup);
    }

    if (attr_type == "string") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_str(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(string)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_str(_s, \"", attr_api_name,
                         "\") for _s in ", attr_api_name, "]\n");
    } else if (attr_type == "int") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_int(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(int)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_int(_i, \"", attr_api_name,
                         "\") for _i in ", attr_api_name, "]\n");
    } else if (attr_type == "float") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_float(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(float)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_float(_f, \"", attr_api_name,
                         "\") for _f in ", attr_api_name, "]\n");
    } else if (attr_type == "bool") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_bool(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(bool)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_bool(_b, \"", attr_api_name,
                         "\") for _b in ", attr_api_name, "]\n");
    } else if (attr_type == "type") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_type(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(type)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_type(_t, \"", attr_api_name,
                         "\") for _t in ", attr_api_name, "]\n");
    } else if (attr_type == "shape") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_shape(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(shape)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_shape(_s, \"", attr_api_name,
                         "\") for _s in ", attr_api_name, "]\n");
    } else if (attr_type == "tensor") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_tensor(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(tensor)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_tensor(_t, \"", attr_api_name,
                         "\") for _t in ", attr_api_name, "]\n");
    } else if (attr_type != "func" && attr_type != "list(func)") {
      *function_setup =
          strings::StrCat("# No definition for ", function_name_,
                          " since we don't support attrs with type\n"
                          "# '",
                          attr_type, "' right now.\n\n");
      return false;
    }
  }
  return true;
}

// If output i is a list output, output_sizes[i] will be set to a
// string with the Python expression that will evaluate to its
// length. output_sizes[i] is empty for non-list outputs.
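// Illustrative example (hypothetical op): for outputs (values: N * T,
// idx: int32), with N inferred from an input list, this yields
// output_sizes = {"_attr_N", ""} and num_outputs_expr = "_attr_N + 1".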
void GenEagerPythonOp::GetOutputSizesAndNumOutputsExpr(
    std::vector<string>* output_sizes, string* num_outputs_expr) {
  // Expression representing the number of outputs.
  int num_fixed_outputs = 0;
  for (int i = 0; i < num_outs_; ++i) {
    const auto& arg(op_def_.output_arg(i));
    if (!arg.number_attr().empty()) {
      if (!num_outputs_expr->empty()) {
        strings::StrAppend(num_outputs_expr, " + ");
      }
      (*output_sizes)[i] = attr_expressions_[arg.number_attr()];
      strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
    } else if (!arg.type_list_attr().empty()) {
      if (!num_outputs_expr->empty()) {
        strings::StrAppend(num_outputs_expr, " + ");
      }
      // Have to be careful to use an expression that works in both
      // graph and eager paths here.
      const auto iter = inferred_attrs_.find(arg.type_list_attr());
      if (iter == inferred_attrs_.end()) {
        (*output_sizes)[i] = strings::StrCat(
            "len(", attr_expressions_[arg.type_list_attr()], ")");
      } else {
        (*output_sizes)[i] = strings::StrCat("len(", iter->second, ")");
      }
      strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
    } else {
      ++num_fixed_outputs;
    }
  }
  if (num_fixed_outputs > 0) {
    if (!num_outputs_expr->empty()) {
      strings::StrAppend(num_outputs_expr, " + ");
    }
    strings::StrAppend(num_outputs_expr, num_fixed_outputs);
  } else if (num_outputs_expr->empty()) {
    *num_outputs_expr = "0";
  }
}

void GenEagerPythonOp::AddEagerFunctionTeardown(
    const string& indentation, const std::vector<string>& output_sizes,
    bool execute_record_gradient) {
  if (num_outs_ > 0) {
    if (execute_record_gradient) {
      strings::StrAppend(&result_, indentation,
                         "if _execute.must_record_gradient():\n");
      strings::StrAppend(&result_, indentation, "  _execute.record_gradient(\n",
                         "        \"", op_def_.name(),
                         "\", _inputs_flat, _attrs, _result)\n");
    }
    if (num_outs_ == 1 && !output_sizes[0].empty()) {
      // Single list result.
    } else if (num_outs_ == 1) {
      // Execute returns a single-element list which we need to destructure.
      strings::StrAppend(&result_, indentation, "_result, = _result\n");
    } else {
      // Have multiple outputs, so we will need to reformat the return
      // value of execute() to be a list with one entry per op output
      // (that entry will be a list of tensors if that output is of list
      // type).
      // For list outputs, convert the right subrange of _result into a list.
      Unflatten(indentation, output_sizes, "_result", &result_);
      // Convert to a named tuple.
      strings::StrAppend(
          &result_, indentation, "_result = _",
          python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
          "Output._make(_result)\n");
    }
  } else {
    strings::StrAppend(&result_, indentation, "_result = None\n");
  }
  strings::StrAppend(&result_, indentation, "return _result\n\n");
}

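// The wrapper assembled below has roughly this shape (illustrative sketch
// for a visible op "Foo"; docstring, export decorator, and bodies elided):
//
//   @_dispatch.add_fallback_dispatch_list
//   @_dispatch.add_type_based_api_dispatcher
//   @tf_export("foo")
//   def foo(x, name=None):
//     """..."""
//     _ctx = _context._context or _context.context()
//     tld = _ctx._thread_local_data
//     if tld.is_eager:
//       ...fast path...
//     ...graph mode...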
bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
    const string& parameters, const std::vector<string>& output_sizes,
    const string& eager_not_allowed_error,
    const std::unordered_map<string, string>& type_annotations) {
  if (add_type_annotations_) {
    GenerateTypeVars(type_annotations);
  }
  if (api_def_.visibility() == ApiDef::VISIBLE) {
    strings::StrAppend(&result_, "@_dispatch.add_fallback_dispatch_list\n");
    strings::StrAppend(&result_, "@_dispatch.add_type_based_api_dispatcher\n");
  }

  AddExport();
  AddDefLine(function_name_, parameters);
  if (add_type_annotations_) {
    AddReturnTypeAnnotation(type_annotations);
  }
  AddDocStringDescription();
  AddDocStringArgs();
  AddDocStringInputs();
  AddDocStringAttrs();
  AddDocStringNameArg();
  AddOutputGlobals();  // Added to prelude_
  AddDocStringOutputs();
  strings::StrAppend(&result_, "  \"\"\"\n");

  strings::StrAppend(&result_,
                     "  _ctx = _context._context or _context.context()\n"
                     "  tld = _ctx._thread_local_data\n",
                     "  if tld.is_eager:", "\n");
  if (eager_not_allowed_error.empty()) {
    AddEagerFastPathExecute();
  } else {
    strings::StrAppend(&result_, "    ", eager_not_allowed_error);
  }

  // Handle graph-mode case
  string function_setup;
  if (!GetEagerFunctionSetup("  ", &function_setup)) {
    result_ = function_setup;
    return false;
  }
  HandleGraphMode(function_setup, output_sizes);

  AddRawOpExport(parameters);
  AddTypeBasedDispatcherAlias();
  strings::StrAppend(&result_, "\n\n");
  return true;
}

bool GenEagerPythonOp::AddEagerFallbackCode(
    const string& parameters, const std::vector<string>& output_sizes,
    const string& num_outputs_expr, const string& eager_not_allowed_error,
    const std::unordered_map<string, string>& type_annotations) {
  AddDefLine(
      strings::StrCat(function_name_, kEagerFallbackSuffix),
      strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx"));
  if (add_type_annotations_) {
    AddReturnTypeAnnotation(type_annotations);
  }
  if (!eager_not_allowed_error.empty()) {
    strings::StrAppend(&result_, "  ", eager_not_allowed_error);
    return true;
  }

  string function_setup;
  if (!GetEagerFunctionSetup("  ", &function_setup)) {
    result_ = function_setup;
    return false;
  }
  strings::StrAppend(&result_, function_setup);

  AddEagerInferredAttrs("  ");
  AddEagerInputCasts("  ");
  strings::StrAppend(
      &result_, "  _inputs_flat = ", FlattenInputs(nullptr, nullptr), "\n");
  AddEagerAttrs("  ");
  AddEagerExecute("  ", num_outputs_expr);

  AddEagerFunctionTeardown("  ", output_sizes,
                           true /* execute_record_gradient */);

  return true;
}

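// The eager fast path emitted here looks like (illustrative, hypothetical op
// "Foo" with one input x and one non-inferred attr "axis"):
//
//   try:
//     _result = pywrap_tfe.TFE_Py_FastPathExecute(
//       _ctx, "Foo", name, x, "axis", axis)
//     return _result
//   except _core._NotOkStatusException as e:
//     _ops.raise_from_not_ok_status(e, name)
//   except _core._FallbackException:
//     pass
//   try:
//     return foo_eager_fallback(
//         x, axis=axis, name=name, ctx=_ctx)
//   except _core._SymbolicException:
//     pass  # Add nodes to the TensorFlow graph.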
void GenEagerPythonOp::AddEagerFastPathExecute() {
  string fastpath_execute_params =
      strings::StrCat("_ctx, \"", op_def_.name(), "\", ", "name");
  string fallback_params;

  for (int i = 0; i < api_def_.in_arg_size(); i++) {
    const string param_name = param_names_[i].GetRenameTo();
    strings::StrAppend(&fastpath_execute_params, ", ", param_name);
    if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
    strings::StrAppend(&fallback_params, param_name);
  }

  for (const auto& attr : api_def_.attr()) {
    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
      strings::StrAppend(&fastpath_execute_params, ", \"", attr.name(), "\", ",
                         attr.rename_to());

      if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
      strings::StrAppend(&fallback_params, attr.rename_to(), "=",
                         attr.rename_to());
    }
  }

  if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
  strings::StrAppend(&fallback_params, "name=name");

  strings::StrAppend(&result_, "    try:\n");
  strings::StrAppend(
      &result_, "      ", "_result = pywrap_tfe.TFE_Py_FastPathExecute(\n",
      WordWrap(strings::StrCat("        "),
               strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
      "\n");

  if (op_def_.output_arg_size() > 1) {
    const string output_tuple_name = strings::StrCat(
        "_", python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
        "Output");
    strings::StrAppend(&result_, "      ", "_result = ", output_tuple_name,
                       "._make(_result)\n");
  }
  strings::StrAppend(&result_, "      ", "return _result\n");

  // Handle fallback.
  if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
  strings::StrAppend(&fallback_params, "ctx=_ctx");

  // Any errors thrown from execute need to be unwrapped from
  // _NotOkStatusException.
  strings::StrAppend(&result_, "    ",
                     "except _core._NotOkStatusException as e:\n");
  strings::StrAppend(&result_, "      ",
                     "_ops.raise_from_not_ok_status(e, name)\n");

  strings::StrAppend(&result_, "    ", "except _core._FallbackException:\n");
  strings::StrAppend(&result_, "      pass\n");
  strings::StrAppend(&result_, "    try:\n");
  AddTypeBasedDispatch("      ");
  strings::StrAppend(
      &result_, "      ", "return ", function_name_, kEagerFallbackSuffix,
      "(\n",
      WordWrap(strings::StrCat("          "),
               strings::StrCat(fallback_params, ")"), kRightMargin),
      "\n");
  strings::StrAppend(&result_, "    except _core._SymbolicException:\n");
  strings::StrAppend(&result_,
                     "      pass  # Add nodes to the TensorFlow graph.\n");
  AddFallbackDispatch("    ");
}

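// Illustrative example (hypothetical op with inputs x and y sharing an
// inferred "type" attr T restricted to {float32, int32}); the emitted
// Python is:
//
//   _attr_T, _inputs_T = _execute.args_to_matching_eager([x, y], ctx,
//       [_dtypes.float32, _dtypes.int32, ])
//   (x, y) = _inputs_T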
void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {
  // Figure out values for inferred attrs, and cast to eager tensors.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    const auto& api_def_attr(api_def_.attr(i));
    auto arg_list = attr_to_args_.find(attr.name());
    if (arg_list != attr_to_args_.end()) {
      if (attr.type() == "type") {
        std::vector<string> output_sizes;
        const string flattened =
            FlattenInputs(&arg_list->second, &output_sizes);
        string conversion = strings::StrCat("_execute.args_to_matching_eager(",
                                            flattened, ", ctx");

        strings::StrAppend(&conversion, ", [");
        for (int t : attr.allowed_values().list().type()) {
          DataType dtype = static_cast<DataType>(t);
          const string py_dtype =
              python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
          strings::StrAppend(&conversion, py_dtype, ", ");
        }
        strings::StrAppend(&conversion, "]");

        if (attr.has_default_value()) {
          strings::StrAppend(
              &conversion, ", ",
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
        strings::StrAppend(&conversion, ")");
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        if (output_sizes.size() == 1) {
          // Avoid creating a temporary variable in the case where
          // we can easily assign to the right value directly.
          const string inputs_var =
              param_names_[arg_list->second.front()].GetRenameTo();
          if (output_sizes.front().empty()) {
            strings::StrAppend(&result_, indentation, var_name, ", (",
                               inputs_var, ",) = ", conversion, "\n");
          } else {
            strings::StrAppend(&result_, indentation, var_name, ", ",
                               inputs_var, " = ", conversion, "\n");
          }
        } else {
          const string inputs_var = strings::StrCat("_inputs_", attr.name());
          strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
                             " = ", conversion, "\n");
          // Convert from a flat list of eager tensors back to the
          // parameter variables.
          Unflatten(indentation, output_sizes, inputs_var, &result_);
          std::vector<string> p;
          for (int j : arg_list->second) {
            p.emplace_back(param_names_[j].GetRenameTo());
          }
          strings::StrAppend(&result_, indentation, VectorToTuple(p), " = ",
                             inputs_var, "\n");
        }
      } else if (attr.type() == "list(type)") {
        // NOTE: We ignore default values for these attrs, since it is
        // unclear how you would use it, and the one use case is
        // parse_single_sequence_example which only needs it for
        // backwards compatibility.
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        string inputs_var;
        string conversion;
        if (arg_list->second.size() > 1) {
          // If you have more than one list(tensor) argument, their types
          // have to match.
          std::vector<string> lists;
          for (auto iter = arg_list->second.begin();
               iter != arg_list->second.end(); ++iter) {
            lists.push_back(param_names_[*iter].GetRenameTo());
          }
          inputs_var = VectorToTuple(lists);
          conversion = "_execute.args_to_mixed_eager_tensors";
        } else {
          // For one list(tensor) argument, we just convert every
          // element of the list to an eager tensor.
          inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
          conversion = "_execute.convert_to_mixed_eager_tensors";
        }
        strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
                           " = ", conversion, "(", inputs_var, ", ctx)\n");
      }
    }
  }
}

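// Illustrative example (hypothetical input "images" of fixed dtype float32):
//
//   images = _ops.convert_to_tensor(images, _dtypes.float32)
//
// or, when the input carries a number_attr (a same-typed list):
//
//   images = _ops.convert_n_to_tensor(images, _dtypes.float32)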
void GenEagerPythonOp::AddEagerInputCasts(const string& indentation) {
  // Cast remaining args to eager tensors
  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
    const auto& arg(op_def_.input_arg(i));
    if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
    const string& param = param_names_[i].GetRenameTo();
    const string fn = arg.number_attr().empty() ? "" : "n_";
    const string dtype =
        python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
    strings::StrAppend(&result_, indentation, param, " = _ops.convert_", fn,
                       "to_tensor(", param, ", ", dtype, ")\n");
  }
}

void GenEagerPythonOp::AddEagerAttrs(const string& indentation) {
  // Compute eager attrs
  if (op_def_.attr_size() > 0) {
    string attr_values;
    for (int i = 0; i < op_def_.attr_size(); ++i) {
      if (i > 0) strings::StrAppend(&attr_values, ", ");
      const auto& attr_name(op_def_.attr(i).name());
      strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
                         attr_expressions_[attr_name]);
    }
    strings::StrAppend(&attr_values, ")");
    strings::StrAppend(
        &result_,
        WordWrap(indentation, strings::StrCat("_attrs = (", attr_values),
                 kRightMargin),
        "\n");
  } else {
    strings::StrAppend(&result_, indentation, "_attrs = None\n");
  }
}

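// The emitted execute call looks like (illustrative, hypothetical op "Foo"
// with a single output):
//
//   _result = _execute.execute(b"Foo", 1, inputs=_inputs_flat, attrs=_attrs,
//                              ctx=ctx, name=name)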
void GenEagerPythonOp::AddEagerExecute(const string& indentation,
                                       const string& num_outputs_expr) {
  const string return_prefix =
      strings::StrCat(indentation, "_result = _execute.execute(");
  const string return_args = strings::StrCat(
      "b\"", op_def_.name(), "\", ", num_outputs_expr,
      ", inputs=_inputs_flat, attrs=_attrs, ctx=ctx, name=name)");
  strings::StrAppend(&result_,
                     // Wrap the arguments, and indent to the (.
                     WordWrap(return_prefix, return_args, kRightMargin), "\n");
}

void GenEagerPythonOp::AddFallbackDispatch(const string& prefix) {
  if (api_def_.visibility() != ApiDef::VISIBLE) return;

  strings::StrAppend(&result_, prefix, "except (TypeError, ValueError):\n");
  strings::StrAppend(&result_, prefix, "  _result = _dispatch.dispatch(\n");
  AddBodyNoReturn(strings::StrCat(prefix, "        ", function_name_,
                                  ", "
                                  "(), dict("));
  strings::StrAppend(&result_, prefix, "      )\n");
  strings::StrAppend(&result_, prefix,
                     "  if _result is not "
                     "_dispatch.OpDispatcher.NOT_SUPPORTED:\n");
  strings::StrAppend(&result_, prefix, "    return _result\n");
  strings::StrAppend(&result_, prefix, "  raise\n");
}

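// The dispatch stub emitted by AddTypeBasedDispatch() below looks like
// (illustrative, hypothetical op "foo" with one input x, prefix "    "):
//
//   _result = _dispatcher_for_foo(
//       (x, name,), None)
//   if _result is not NotImplemented:
//     return _result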
void GenEagerPythonOp::AddTypeBasedDispatcherAlias() {
  // It's possible for the name of a parameter to be the same as the name of
  // an op, in which case the parameter shadows the op's function.  To avoid
  // this, we add a private variable with the dispatcher, and access that
  // directly.
  if (api_def_.visibility() == ApiDef::VISIBLE) {
    strings::StrAppend(&result_, "_dispatcher_for_", function_name_,
                       " = ", function_name_,
                       "._tf_type_based_dispatcher.Dispatch\n");
  }
}
void GenEagerPythonOp::AddTypeBasedDispatch(const string& prefix) {
  if (api_def_.visibility() != ApiDef::VISIBLE) return;
  std::string args("(");
  for (const auto& name : param_names_) {
    strings::StrAppend(&args, name.GetRenameTo(), ", ");
  }
  strings::StrAppend(&args, "name,), None");

  strings::StrAppend(
      &result_, prefix, "_result = ", "_dispatcher_for_", function_name_, "(\n",
      WordWrap(strings::StrCat(prefix, "    "), args, kRightMargin), ")\n");
  strings::StrAppend(&result_, prefix, "if _result is not NotImplemented:\n",
                     prefix, "  return _result\n");
}

void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
  // Example:
  //
  // Identity = tf_export("raw_ops.Identity")(_ops.to_raw_op(identity))
  const string raw_function_name =
      python_op_gen_internal::AvoidPythonReserved(op_def_.name());
  strings::StrAppend(&result_, raw_function_name, " = tf_export(\"raw_ops.",
                     raw_function_name, "\")", "(_ops.to_raw_op(",
                     function_name_, "))\n");
}

string GetPythonOpsImpl(
    const OpList& ops, const ApiDefMap& api_defs,
    const std::vector<string>& hidden_ops, const string& source_file_name = "",
    const std::unordered_set<string> type_annotate_ops = {}) {
  string result;
  // Header
  // TODO(josh11b): Mention the library for which wrappers are being generated.
  strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
)");

  // Mention the original source file so someone tracing back through
  // generated Python code will know where to look next.
  if (!source_file_name.empty()) {
    strings::StrAppend(&result, "Original C++ source file: ");
    strings::StrAppend(&result, source_file_name);
    strings::StrAppend(&result, "\n");
  }

  strings::StrAppend(&result, R"("""

import collections

from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes

from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export

from typing import TypeVar
)");

  for (const auto& op_def : ops.op()) {
    const auto* api_def = api_defs.GetApiDef(op_def.name());

    if (api_def->visibility() == ApiDef::SKIP) {
      continue;
    }
    // An op is hidden if either its ApiDef visibility is HIDDEN
    // or it is in the hidden_ops list.
    bool is_hidden = api_def->visibility() == ApiDef::HIDDEN;
    bool hidden_by_api_def = is_hidden;
    if (!is_hidden) {
      for (const string& hidden : hidden_ops) {
        if (op_def.name() == hidden) {
          is_hidden = true;
          break;
        }
      }
    }

    string function_name;
    python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
                                                    &function_name);
    bool is_reserved = python_op_gen_internal::IsPythonReserved(function_name);

    // Prefix an op with an underscore if the op is listed in hidden_ops, its
    // name is reserved, or it is one of the exceptions in
    // IsOpWithUnderscorePrefix. Otherwise, do not add underscores to ops that
    // are set to HIDDEN in the ApiDef.
    // TODO(annarev): don't prefix with underscores even if op is in hidden_ops.
    if (is_hidden) {
      if (!hidden_by_api_def || is_reserved ||
          python_op_gen_internal::IsOpWithUnderscorePrefix(function_name)) {
        function_name = strings::StrCat("_", function_name);
      }
    } else if (is_reserved) {
      // When users create custom python wrappers, they may link in the
      // default op registry by accident, and because they can't
      // enumerate all 'hidden' symbols, this guard is to prevent
      // instantiating a python reserved word in their wrapper.
      continue;
    }

    auto iter = type_annotate_ops.find(op_def.name());
    bool add_type_annotations = iter != type_annotate_ops.end();

    strings::StrAppend(&result,
                       GetEagerPythonOp(op_def, *api_def, function_name,
                                        add_type_annotations));
  }

  return result;
}

}  // namespace

string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                    const std::vector<string>& hidden_ops,
                    const string& source_file_name,
                    const std::unordered_set<string> type_annotate_ops) {
  return GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name,
                          type_annotate_ops);
}

void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                    const std::vector<string>& hidden_ops,
                    const string& source_file_name,
                    const std::unordered_set<string> type_annotate_ops) {
  printf("%s", GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name,
                                type_annotate_ops)
                   .c_str());
}

string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
  OpList ops;
  ops.ParseFromArray(op_list_buf, op_list_len);

  ApiDefMap api_def_map(ops);
  return GetPythonOpsImpl(ops, api_def_map, {});
}

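// Illustrative examples (hypothetical): for an arg whose type_attr "T" maps
// to the TypeVar "TV_Foo_T", this returns "_ops.Tensor[TV_Foo_T]"; for an
// arg with a fixed DT_INT32 type it returns "_ops.Tensor[_dtypes.Int32]".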
string GetArgAnnotation(
    const OpDef::ArgDef& arg,
    const std::unordered_map<string, string>& type_annotations) {
  if (!arg.type_attr().empty()) {
    // Get the correct TypeVar if arg maps to an attr
    return "_ops.Tensor[" + type_annotations.at(arg.type_attr()) + "]";
  } else {
    // Get the dtype of the Tensor
    const string py_dtype =
        python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
    return "_ops.Tensor[" + dtype_type.at(py_dtype) + "]";
  }
}

}  // namespace tensorflow