1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/python/framework/python_op_gen.h"

#include <stdio.h>

#include <algorithm>
#include <sstream>
#include <unordered_map>

#include "absl/strings/escaping.h"
#include "absl/strings/match.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/framework/python_op_gen_internal.h"
40
41 namespace tensorflow {
42 namespace {
43
// Column limit handed to WordWrap() when wrapping generated Python lines.
constexpr int kRightMargin = 78;

// Suffix for the eager-fallback variant of a generated function; its use is
// outside this section of the file (presumably appended to function names).
constexpr char kEagerFallbackSuffix[] = "_eager_fallback";
47
// Lookup table used when emitting TypeVar annotations. Keys are the Python
// dtype expressions produced by
// python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.") (for example
// "_dtypes.float32"). They are strings, not C++ DataType enum values. Values
// are the capitalized names written into the generated `TypeVar(...)`
// declarations (for example "_dtypes.Float32").
// Both "_dtypes.float16" and "_dtypes.half" have entries. Iterating over the
// whole map therefore yields both "_dtypes.Float16" and "_dtypes.Half".
const std::unordered_map<string, string> dtype_type{
    {"_dtypes.float16", "_dtypes.Float16"},
    {"_dtypes.half", "_dtypes.Half"},
    {"_dtypes.float32", "_dtypes.Float32"},
    {"_dtypes.float64", "_dtypes.Float64"},
    {"_dtypes.bfloat16", "_dtypes.BFloat16"},
    {"_dtypes.complex64", "_dtypes.Complex64"},
    {"_dtypes.complex128", "_dtypes.Complex128"},
    {"_dtypes.int8", "_dtypes.Int8"},
    {"_dtypes.uint8", "_dtypes.UInt8"},
    {"_dtypes.uint16", "_dtypes.UInt16"},
    {"_dtypes.uint32", "_dtypes.UInt32"},
    {"_dtypes.uint64", "_dtypes.UInt64"},
    {"_dtypes.int16", "_dtypes.Int16"},
    {"_dtypes.int32", "_dtypes.Int32"},
    {"_dtypes.int64", "_dtypes.Int64"},
    {"_dtypes.bool", "_dtypes.Bool"},
    {"_dtypes.string", "_dtypes.String"},
    {"_dtypes.qint8", "_dtypes.QInt8"},
    {"_dtypes.quint8", "_dtypes.QUInt8"},
    {"_dtypes.qint16", "_dtypes.QInt16"},
    {"_dtypes.quint16", "_dtypes.QUInt16"},
    {"_dtypes.qint32", "_dtypes.QInt32"},
    {"_dtypes.resource", "_dtypes.Resource"},
    {"_dtypes.variant", "_dtypes.Variant"}};
74
// Returns the generated-Python local variable name used for an inferred attr
// ("_attr_" + attr_name). When `attr_expressions` is non-null, also records
// that name as the expression for `attr_name`, replacing any previous entry.
string AttrVarName(const string& attr_name,
                   std::unordered_map<string, string>* attr_expressions) {
  const string var_name = strings::StrCat("_attr_", attr_name);
  if (attr_expressions == nullptr) {
    return var_name;
  }
  (*attr_expressions)[attr_name] = var_name;
  return var_name;
}
81
AddInferredAttr(const string & indentation,const string & attr_name,const string & value_expression,string * result,std::unordered_map<string,string> * attr_expressions)82 void AddInferredAttr(const string& indentation, const string& attr_name,
83 const string& value_expression, string* result,
84 std::unordered_map<string, string>* attr_expressions) {
85 strings::StrAppend(result, indentation,
86 AttrVarName(attr_name, attr_expressions), " = ",
87 value_expression, "\n");
88 }
89
VectorToTuple(const std::vector<string> & l)90 string VectorToTuple(const std::vector<string>& l) {
91 if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
92 string ret = "(";
93 for (int i = 0, end = l.size(); i < end; ++i) {
94 if (i > 0) {
95 strings::StrAppend(&ret, ", ");
96 }
97 strings::StrAppend(&ret, l[i]);
98 }
99 strings::StrAppend(&ret, ")");
100 return ret;
101 }
102
Unflatten(const string & prefix,const std::vector<string> & output_sizes,const string & var,string * result)103 void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
104 const string& var, string* result) {
105 for (int i = 0, end = output_sizes.size(); i < end; ++i) {
106 if (!output_sizes[i].empty()) {
107 strings::StrAppend(result, prefix, var, " = ");
108 if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
109 if (i + 1 < end) {
110 // Special case i == 0 to avoid "0 +" in the generated code.
111 if (i == 0) {
112 strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
113 var, "[", output_sizes[i], ":]");
114 } else {
115 strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
116 output_sizes[i], "]] + ", var, "[", i, " + ",
117 output_sizes[i], ":]");
118 }
119 } else {
120 strings::StrAppend(result, "[", var, "[", i, ":]]");
121 }
122 strings::StrAppend(result, "\n");
123 }
124 }
125 }
126
TensorPBString(const TensorProto & pb)127 string TensorPBString(const TensorProto& pb) {
128 // Note: This gets used in the argument list, and so must survive naive
129 // word wrapping.
130 return strings::StrCat("\"\"\"", pb.ShortDebugString(), "\"\"\"");
131 }
132
133 class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
134 public:
GenEagerPythonOp(const OpDef & op_def,const ApiDef & api_def,const string & function_name,bool add_type_annotations)135 GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
136 const string& function_name, bool add_type_annotations)
137 : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name,
138 add_type_annotations) {
139 op_name_ = function_name_;
140 absl::ConsumePrefix(&op_name_, "_");
141 }
~GenEagerPythonOp()142 ~GenEagerPythonOp() override {}
143
144 string Code() override;
145
146 protected:
147 void HandleGraphMode(const string& function_setup,
148 const std::vector<string>& output_sizes);
149
150 string GetEagerNotAllowedError();
151 void ExpectListArg(const string& indentation, const string& arg_name,
152 string* output);
153 bool GetEagerFunctionSetup(const string& indentation, string* function_setup);
154 void GetOutputSizesAndNumOutputsExpr(std::vector<string>* output_sizes,
155 string* num_outputs_expr);
156
157 void AddEagerFunctionTeardown(const string& indentation,
158 const std::vector<string>& output_sizes,
159 bool execute_record_gradient);
160
161 bool AddEagerFastPathAndGraphCode(
162 const string& parameters, const std::vector<string>& output_sizes,
163 const string& eager_not_allowed_error,
164 const std::unordered_map<string, string>& type_annotations);
165 bool AddEagerFallbackCode(
166 const string& parameters, const std::vector<string>& output_sizes,
167 const string& num_outputs_expr, const string& eager_not_allowed_error,
168 const std::unordered_map<string, string>& type_annotations);
169 void AddEagerFastPathExecute();
170
171 void AddEagerInferredAttrs(const string& indentation);
172 void AddEagerInputCasts(const string& indentation);
173 void AddEagerAttrs(const string& indentation);
174 void AddEagerExecute(const string& indentation,
175 const string& num_outputs_expr);
176 void AddFallbackDispatch(const string& prefix);
177 void AddTypeBasedDispatch(const string& prefix);
178 void AddTypeBasedDispatcherAlias();
179
180 void AddRawOpExport(const string& parameters);
181
182 std::unordered_map<string, string> GetTypeAnnotations();
183
184 void GenerateTypeVars(
185 const std::unordered_map<string, string>& type_annotations);
186
187 void AddReturnTypeAnnotation(
188 const std::unordered_map<string, string>& type_annotations);
189
AddAttrForArg(const string & attr,int arg_index)190 void AddAttrForArg(const string& attr, int arg_index) {
191 gtl::InsertIfNotPresent(&inferred_attrs_, attr,
192 op_def_.input_arg(arg_index).name());
193 auto iter = attr_to_args_.find(attr);
194 if (iter == attr_to_args_.end()) {
195 attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
196 } else {
197 iter->second.push_back(arg_index);
198 }
199 }
200
201 // Returns a string expression representing a flattened list of all
202 // the inputs given by `*input_indices` (or all inputs if
203 // `input_indices` is nullptr). `*output_sizes` can be used to unflatten.
204 string FlattenInputs(const std::vector<int>* input_indices,
205 std::vector<string>* output_sizes) const;
206
207 StringPiece op_name_;
208 typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
209 AttrToArgMap attr_to_args_;
210 std::unordered_map<string, string> attr_expressions_;
211 // This has all the input args followed by those attrs that don't have
212 // defaults.
213 std::vector<python_op_gen_internal::ParamNames> params_no_default_;
214 // The parameters with defaults (these have to be listed after those without).
215 // No input args are included, just attrs.
216 std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
217 params_with_default_;
218 };
219
// Entry point: returns the generated Python source for one op, as produced by
// GenEagerPythonOp::Code().
string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
                        const string& function_name,
                        bool add_type_annotations) {
  GenEagerPythonOp generator(op_def, api_def, function_name,
                             add_type_annotations);
  return generator.Code();
}
226
FlattenInputs(const std::vector<int> * input_indices,std::vector<string> * output_sizes) const227 string GenEagerPythonOp::FlattenInputs(
228 const std::vector<int>* input_indices,
229 std::vector<string>* output_sizes) const {
230 string inputs;
231 enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
232 const int n = input_indices != nullptr ? input_indices->size()
233 : op_def_.input_arg_size();
234 for (int j = 0; j < n; ++j) {
235 const int i = input_indices ? (*input_indices)[j] : j;
236 const auto& arg(op_def_.input_arg(i));
237 const bool is_list =
238 !arg.type_list_attr().empty() || !arg.number_attr().empty();
239 if (is_list) {
240 if (inputs_state == WAS_SOLO_INPUT) {
241 strings::StrAppend(&inputs, "] + ");
242 } else if (inputs_state == WAS_LIST_INPUT) {
243 strings::StrAppend(&inputs, " + ");
244 }
245 strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
246 inputs_state = WAS_LIST_INPUT;
247 if (output_sizes != nullptr) {
248 if (!arg.number_attr().empty()) {
249 output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
250 } else {
251 output_sizes->emplace_back(
252 strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
253 }
254 }
255 } else {
256 if (inputs_state == WAS_SOLO_INPUT) {
257 strings::StrAppend(&inputs, ", ");
258 } else if (inputs_state == WAS_LIST_INPUT) {
259 strings::StrAppend(&inputs, " + [");
260 } else {
261 strings::StrAppend(&inputs, "[");
262 }
263 strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
264 inputs_state = WAS_SOLO_INPUT;
265 if (output_sizes != nullptr) output_sizes->emplace_back();
266 }
267 }
268 if (inputs_state == STARTING) return "[]";
269 if (inputs_state == WAS_SOLO_INPUT) {
270 strings::StrAppend(&inputs, "]");
271 }
272 return inputs;
273 }
274
Code()275 string GenEagerPythonOp::Code() {
276 if (api_def_.visibility() == ApiDef::SKIP) {
277 return "";
278 }
279
280 for (int i = 0; i < api_def_.arg_order_size(); ++i) {
281 const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
282 const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
283 params_no_default_.emplace_back(api_def_arg.name(),
284 api_def_arg.rename_to());
285 if (!arg.type_attr().empty()) {
286 AddAttrForArg(arg.type_attr(), i);
287 } else if (!arg.type_list_attr().empty()) {
288 AddAttrForArg(arg.type_list_attr(), i);
289 }
290 if (!arg.number_attr().empty()) {
291 AddAttrForArg(arg.number_attr(), i);
292 }
293 }
294 for (int i = 0; i < op_def_.attr_size(); ++i) {
295 const auto& attr(op_def_.attr(i));
296 const auto& api_def_attr(api_def_.attr(i));
297 // Do not add inferred attrs to the Python function signature.
298 if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
299 if (api_def_attr.has_default_value()) {
300 if (attr.type() == "tensor") {
301 params_with_default_.emplace_back(
302 python_op_gen_internal::ParamNames(api_def_attr.name(),
303 api_def_attr.rename_to()),
304 strings::StrCat(
305 "_execute.make_tensor(",
306 TensorPBString(api_def_attr.default_value().tensor()), ", \"",
307 api_def_attr.rename_to(), "\")"));
308 } else if (attr.type() == "list(tensor)") {
309 std::vector<string> pbtxt;
310 for (const auto& pb : api_def_attr.default_value().list().tensor()) {
311 pbtxt.emplace_back(TensorPBString(pb));
312 }
313 params_with_default_.emplace_back(
314 python_op_gen_internal::ParamNames(api_def_attr.name(),
315 api_def_attr.rename_to()),
316 strings::StrCat("[_execute.make_tensor(_pb, \"",
317 api_def_attr.rename_to(), "\") for _pb in ",
318 VectorToTuple(pbtxt), "]"));
319 } else {
320 params_with_default_.emplace_back(
321 python_op_gen_internal::ParamNames(api_def_attr.name(),
322 api_def_attr.rename_to()),
323 python_op_gen_internal::AttrValueToPython(
324 attr.type(), api_def_attr.default_value(), "_dtypes."));
325 }
326 } else {
327 params_no_default_.emplace_back(api_def_attr.name(),
328 api_def_attr.rename_to());
329 }
330 }
331 }
332
333 // Save the list of attr parameters (attrs that won't be inferred),
334 // those with defaults go at the end.
335 // Get the attrs in the order we want by taking the attrs without defaults
336 // from the end of params_no_default_, and adding params_no_default_.
337 attrs_.reserve(params_no_default_.size() - op_def_.input_arg_size() +
338 params_with_default_.size());
339 for (int i = op_def_.input_arg_size(), end = params_no_default_.size();
340 i < end; ++i) {
341 attrs_.push_back(params_no_default_[i].GetName());
342 }
343 for (const auto& p : params_with_default_) {
344 attrs_.push_back(p.first.GetName());
345 }
346
347 // TODO(slebedev): call AvoidPythonReserved on each param?
348 param_names_.reserve(params_no_default_.size() + params_with_default_.size());
349 param_names_.insert(param_names_.begin(), params_no_default_.begin(),
350 params_no_default_.end());
351 for (const auto& param_and_default : params_with_default_) {
352 param_names_.push_back(param_and_default.first);
353 }
354
355 std::unordered_map<string, string> type_annotations;
356 // Only populate map for allowlisted ops
357 if (add_type_annotations_) {
358 type_annotations = GetTypeAnnotations();
359 }
360
361 string parameters;
362 // Param can be an input or an attr
363 for (const auto& param : params_no_default_) {
364 if (!parameters.empty()) strings::StrAppend(¶meters, ", ");
365 strings::StrAppend(¶meters, param.GetRenameTo());
366
367 if (type_annotations.find(param.GetName()) != type_annotations.end()) {
368 strings::StrAppend(¶meters, ": ",
369 type_annotations.at(param.GetName()));
370 }
371 }
372
373 string parameters_with_defaults = parameters;
374 for (const auto& param_and_default : params_with_default_) {
375 if (!parameters.empty()) strings::StrAppend(¶meters, ", ");
376 if (!parameters_with_defaults.empty())
377 strings::StrAppend(¶meters_with_defaults, ", ");
378
379 strings::StrAppend(¶meters, param_and_default.first.GetRenameTo());
380 strings::StrAppend(¶meters_with_defaults,
381 param_and_default.first.GetRenameTo());
382 if (type_annotations.find(param_and_default.first.GetName()) !=
383 type_annotations.end()) {
384 const string param_type =
385 type_annotations.at(param_and_default.first.GetName());
386 // Append to parameters and parameters_with_defaults because multiple
387 // functions are generated by AddEagerFastPathAndGraphCode() and
388 // AddEagerFallbackCode()
389 strings::StrAppend(¶meters, ": ", param_type);
390 strings::StrAppend(¶meters_with_defaults, ":", param_type);
391 }
392
393 strings::StrAppend(¶meters_with_defaults, "=",
394 param_and_default.second);
395 }
396
397 strings::StrAppend(¶meters, parameters.empty() ? "" : ", ", "name");
398 strings::StrAppend(¶meters_with_defaults,
399 parameters_with_defaults.empty() ? "" : ", ", "name=None");
400
401 // Add attr_expressions_ for attrs that are params.
402 for (int i = 0, end = attrs_.size(); i < end; ++i) {
403 const string& attr_name = attrs_[i];
404 const string& attr_api_name =
405 param_names_[i + op_def_.input_arg_size()].GetRenameTo();
406 attr_expressions_[attr_name] = attr_api_name;
407 }
408 // Add attr_expressions_ for attrs that are inferred.
409 for (int i = 0; i < op_def_.attr_size(); ++i) {
410 const auto& attr(op_def_.attr(i));
411 if (attr.type() == "int") {
412 auto arg_list = attr_to_args_.find(attr.name());
413 if (arg_list != attr_to_args_.end()) {
414 AttrVarName(attr.name(), &attr_expressions_);
415 }
416 }
417 }
418
419 string num_outputs_expr;
420 std::vector<string> output_sizes(num_outs_);
421 GetOutputSizesAndNumOutputsExpr(&output_sizes, &num_outputs_expr);
422
423 string eager_not_allowed_error = GetEagerNotAllowedError();
424
425 if (!AddEagerFastPathAndGraphCode(parameters_with_defaults, output_sizes,
426 eager_not_allowed_error,
427 type_annotations)) {
428 return result_;
429 }
430
431 if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
432 eager_not_allowed_error, type_annotations)) {
433 return result_;
434 }
435
436 return prelude_ + result_;
437 }
438
GetTypeAnnotations()439 std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() {
440 std::unordered_map<string, string> type_annotations;
441 // Map attrs to TypeVars
442 for (const auto& attr : op_def_.attr()) {
443 if (attr.type() == "type") {
444 const string type_var_name = "TV_" + op_def_.name() + "_" + attr.name();
445 type_annotations[attr.name()] = type_var_name;
446 } else if (attr.type() == "bool" || attr.type() == "float" ||
447 attr.type() == "int" || attr.type() == "bytes") {
448 type_annotations[attr.name()] = attr.type();
449 } else if (attr.type() == "string") {
450 type_annotations[attr.name()] = "str";
451 }
452 }
453
454 // Map input Tensors to their types
455 for (const auto& arg : op_def_.input_arg()) {
456 // TODO(rahulkamat): Add type annotations to args that accept a sequence of
457 // Tensors
458 if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) continue;
459 type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
460 }
461
462 // TODO(rahulkamat): Add type annotations to handle return types of a sequence
463 // of Tensors. Map output Tensor to its type
464 if (op_def_.output_arg_size() == 1) {
465 const auto& arg = op_def_.output_arg(0);
466 if (arg.number_attr().empty() && arg.type_list_attr().empty())
467 type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
468 }
469
470 return type_annotations;
471 }
472
473 // Generate TypeVars using attrs
GenerateTypeVars(const std::unordered_map<string,string> & type_annotations)474 void GenEagerPythonOp::GenerateTypeVars(
475 const std::unordered_map<string, string>& type_annotations) {
476 bool added_typevar = false;
477 for (const auto& attr : op_def_.attr()) {
478 if (attr.type() == "type") {
479 std::vector<string> allowed_types;
480 for (int t : attr.allowed_values().list().type()) {
481 DataType dtype = static_cast<DataType>(t);
482 const string py_dtype =
483 python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
484 allowed_types.emplace_back(dtype_type.at(py_dtype));
485 }
486
487 // When a Tensor does not have any dtypes specified, all dtypes are
488 // allowed
489 if (allowed_types.empty()) {
490 for (std::pair<string, string> map_dtype : dtype_type) {
491 allowed_types.emplace_back(map_dtype.second);
492 }
493 }
494
495 std::sort(allowed_types.begin(), allowed_types.end());
496
497 string typevar_dtypes;
498 for (std::vector<string>::iterator it = allowed_types.begin();
499 it != allowed_types.end(); ++it) {
500 if (!typevar_dtypes.empty()) strings::StrAppend(&typevar_dtypes, ", ");
501 strings::StrAppend(&typevar_dtypes, *it);
502 }
503
504 const string type_var_name = type_annotations.at(attr.name());
505 strings::StrAppend(&result_, type_var_name, " = TypeVar(\"",
506 type_var_name, "\", ", typevar_dtypes, ")\n");
507 added_typevar = true;
508 }
509 }
510
511 if (added_typevar) strings::StrAppend(&result_, "\n");
512 }
513
AddReturnTypeAnnotation(const std::unordered_map<string,string> & type_annotations)514 void GenEagerPythonOp::AddReturnTypeAnnotation(
515 const std::unordered_map<string, string>& type_annotations) {
516 if (op_def_.output_arg_size() == 1) {
517 const auto& arg = op_def_.output_arg(0);
518 if (arg.number_attr().empty() && arg.type_list_attr().empty()) {
519 const string return_type = type_annotations.at(arg.name());
520 // TODO(rahulkamat): Modify AddDefLine() to add return type annotation to
521 // avoid erasing ":\n" from the end of the def line
522 result_.erase(result_.length() - 2);
523 strings::StrAppend(&result_, " -> ", return_type, ":\n");
524 }
525 }
526 }
527
HandleGraphMode(const string & function_setup,const std::vector<string> & output_sizes)528 void GenEagerPythonOp::HandleGraphMode(
529 const string& function_setup, const std::vector<string>& output_sizes) {
530 if (api_def_.visibility() == ApiDef::VISIBLE) {
531 strings::StrAppend(&result_, " else:\n");
532 AddTypeBasedDispatch(" ");
533 }
534 strings::StrAppend(&result_, " # Add nodes to the TensorFlow graph.\n");
535 strings::StrAppend(&result_, function_setup);
536 if (api_def_.visibility() == ApiDef::VISIBLE) {
537 strings::StrAppend(&result_, " try:\n ");
538 }
539 strings::StrAppend(
540 &result_, " _, _, _op, _outputs = _op_def_library._apply_op_helper(\n");
541 AddBodyNoReturn(strings::StrCat(" \"", op_def_.name(), "\", "));
542 AddFallbackDispatch(" ");
543
544 if (num_outs_ > 0) {
545 strings::StrAppend(&result_, " _result = _outputs[:]\n");
546 // Special case handling for stateful op with single list output
547 // that might be empty.
548 if (num_outs_ == 1 && op_def_.is_stateful() &&
549 (!op_def_.output_arg(0).number_attr().empty() ||
550 !op_def_.output_arg(0).type_list_attr().empty())) {
551 // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
552 // a constraint indicating that this can never be empty.
553 strings::StrAppend(&result_,
554 " if not _result:\n"
555 " return _op\n");
556 }
557
558 // Compute graph-mode attrs when we need to record a gradient.
559 strings::StrAppend(&result_, " if _execute.must_record_gradient():\n");
560 if (op_def_.attr_size() > 0) {
561 string attr_values;
562 for (int i = 0; i < op_def_.attr_size(); ++i) {
563 if (i > 0) strings::StrAppend(&attr_values, ", ");
564 const auto& attr_name(op_def_.attr(i).name());
565 if (op_def_.attr(i).type() == "type") {
566 strings::StrAppend(&attr_values, "\"", attr_name,
567 "\", _op._get_attr_type(\"", attr_name, "\")");
568 } else if (op_def_.attr(i).type() == "bool") {
569 strings::StrAppend(&attr_values, "\"", attr_name,
570 "\", _op._get_attr_bool(\"", attr_name, "\")");
571 } else if (op_def_.attr(i).type() == "int") {
572 strings::StrAppend(&attr_values, "\"", attr_name,
573 "\", _op._get_attr_int(\"", attr_name, "\")");
574 } else {
575 strings::StrAppend(&attr_values, "\"", attr_name,
576 "\", _op.get_attr(\"", attr_name, "\")");
577 }
578 }
579 strings::StrAppend(&attr_values, ")");
580 strings::StrAppend(&result_,
581 WordWrap(" _attrs = (", attr_values, kRightMargin),
582 "\n");
583
584 } else {
585 strings::StrAppend(&result_, " _attrs = ()\n");
586 }
587
588 strings::StrAppend(&result_, " _inputs_flat = _op.inputs\n");
589 strings::StrAppend(&result_, " _execute.record_gradient(\n",
590 " \"", op_def_.name(),
591 "\", _inputs_flat, _attrs, _result)\n");
592
593 if (num_outs_ == 1 && !output_sizes[0].empty()) {
594 // Single list result.
595 } else if (num_outs_ == 1) {
596 // Execute returns a single-element list which we need to destructure.
597 strings::StrAppend(&result_, " ", "_result, = _result\n");
598 } else {
599 // Have multiple outputs, so we will need to reformat the return
600 // value of execute() to be a list with one entry per op output
601 // (that entry will be a list of tensors if that output is of list
602 // type).
603 // For list outputs, convert the right subrange of _result into a list.
604 Unflatten(" ", output_sizes, "_result", &result_);
605 // Convert to a named tuple.
606 strings::StrAppend(
607 &result_, " _result = _",
608 python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
609 "Output._make(_result)\n");
610 }
611 strings::StrAppend(&result_, " return _result\n\n");
612 } else {
613 strings::StrAppend(&result_, " return _op\n");
614 }
615 }
616
GetEagerNotAllowedError()617 string GenEagerPythonOp::GetEagerNotAllowedError() {
618 bool eager_allowed = true;
619 string ref_arg;
620 for (int i = 0; i < op_def_.input_arg_size(); ++i) {
621 const auto& arg = op_def_.input_arg(i);
622 if (arg.is_ref()) {
623 eager_allowed = false;
624 DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name());
625 ref_arg = api_def_.in_arg(i).rename_to();
626 }
627 }
628 for (int i = 0; i < op_def_.output_arg_size(); ++i) {
629 const auto& arg = op_def_.output_arg(i);
630 if (arg.is_ref()) {
631 eager_allowed = false;
632 DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name());
633 ref_arg = api_def_.out_arg(i).rename_to();
634 }
635 }
636
637 if (eager_allowed) return "";
638
639 return strings::StrCat("raise RuntimeError(\"", op_name_,
640 " op does not support eager execution. ", "Arg '",
641 ref_arg, "' is a ref.\")\n");
642 }
643
ExpectListArg(const string & indentation,const string & arg_name,string * output)644 void GenEagerPythonOp::ExpectListArg(const string& indentation,
645 const string& arg_name, string* output) {
646 strings::StrAppend(output, indentation, "if not isinstance(", arg_name,
647 ", (list, tuple)):\n", indentation, " raise TypeError(\n",
648 indentation, " \"Expected list for '", arg_name,
649 "' argument to \"\n", indentation, " \"'", op_name_,
650 "' Op, not %r.\" % ", arg_name, ")\n");
651 }
652
GetEagerFunctionSetup(const string & indentation,string * function_setup)653 bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation,
654 string* function_setup) {
655 // Validate list inputs, infer length attrs.
656 for (int i = 0; i < op_def_.attr_size(); ++i) {
657 const auto& attr(op_def_.attr(i));
658 if (attr.type() == "int") {
659 auto arg_list = attr_to_args_.find(attr.name());
660 if (arg_list != attr_to_args_.end()) {
661 // Inferred int attrs are the lengths of inputs. Validate those
662 // inputs are lists and have the same length.
663 for (auto iter = arg_list->second.begin();
664 iter != arg_list->second.end(); ++iter) {
665 const string& arg_api_name = param_names_[*iter].GetRenameTo();
666 ExpectListArg(indentation, arg_api_name, function_setup);
667 if (iter == arg_list->second.begin()) {
668 AddInferredAttr(indentation, attr.name(),
669 strings::StrCat("len(", arg_api_name, ")"),
670 function_setup, &attr_expressions_);
671 } else {
672 const auto& attr_var = attr_expressions_[attr.name()];
673 strings::StrAppend(
674 function_setup, indentation, "if len(", arg_api_name,
675 ") != ", attr_var, ":\n", indentation, " raise ValueError(\n",
676 indentation, " \"List argument '", arg_api_name, "' to '",
677 op_name_, "' Op with length %d \"\n", indentation,
678 " \"must match length %d of argument '",
679 inferred_attrs_[attr.name()], "'.\" %\n", indentation,
680 " (len(", arg_api_name, "), ", attr_var, "))\n");
681 }
682 }
683 }
684 }
685 }
686
687 for (int i = 0, end = attrs_.size(); i < end; ++i) {
688 const string& attr_name = attrs_[i];
689 const auto& param = param_names_[i + op_def_.input_arg_size()];
690 const auto& attr = *FindAttr(attr_name, op_def_);
691 const string& attr_api_name = param.GetRenameTo();
692 StringPiece attr_type = attr.type();
693 attr_expressions_[attr_name] = attr_api_name;
694 const int default_index = i - (attrs_.size() - params_with_default_.size());
695 if (default_index >= 0) {
696 const string& default_value = params_with_default_[default_index].second;
697 strings::StrAppend(function_setup, indentation, "if ", attr_api_name,
698 " is None:\n");
699 strings::StrAppend(function_setup, indentation, " ", attr_api_name,
700 " = ", default_value, "\n");
701 }
702 if (absl::StartsWith(attr_type, "list(")) {
703 ExpectListArg(indentation, attr_api_name, function_setup);
704 }
705
706 if (attr_type == "string") {
707 strings::StrAppend(function_setup, indentation, attr_api_name,
708 " = _execute.make_str(", attr_api_name, ", \"",
709 attr_api_name, "\")\n");
710 } else if (attr_type == "list(string)") {
711 strings::StrAppend(function_setup, indentation, attr_api_name,
712 " = [_execute.make_str(_s, \"", attr_api_name,
713 "\") for _s in ", attr_api_name, "]\n");
714 } else if (attr_type == "int") {
715 strings::StrAppend(function_setup, indentation, attr_api_name,
716 " = _execute.make_int(", attr_api_name, ", \"",
717 attr_api_name, "\")\n");
718 } else if (attr_type == "list(int)") {
719 strings::StrAppend(function_setup, indentation, attr_api_name,
720 " = [_execute.make_int(_i, \"", attr_api_name,
721 "\") for _i in ", attr_api_name, "]\n");
722 } else if (attr_type == "float") {
723 strings::StrAppend(function_setup, indentation, attr_api_name,
724 " = _execute.make_float(", attr_api_name, ", \"",
725 attr_api_name, "\")\n");
726 } else if (attr_type == "list(float)") {
727 strings::StrAppend(function_setup, indentation, attr_api_name,
728 " = [_execute.make_float(_f, \"", attr_api_name,
729 "\") for _f in ", attr_api_name, "]\n");
730 } else if (attr_type == "bool") {
731 strings::StrAppend(function_setup, indentation, attr_api_name,
732 " = _execute.make_bool(", attr_api_name, ", \"",
733 attr_api_name, "\")\n");
734 } else if (attr_type == "list(bool)") {
735 strings::StrAppend(function_setup, indentation, attr_api_name,
736 " = [_execute.make_bool(_b, \"", attr_api_name,
737 "\") for _b in ", attr_api_name, "]\n");
738 } else if (attr_type == "type") {
739 strings::StrAppend(function_setup, indentation, attr_api_name,
740 " = _execute.make_type(", attr_api_name, ", \"",
741 attr_api_name, "\")\n");
742 } else if (attr_type == "list(type)") {
743 strings::StrAppend(function_setup, indentation, attr_api_name,
744 " = [_execute.make_type(_t, \"", attr_api_name,
745 "\") for _t in ", attr_api_name, "]\n");
746 } else if (attr_type == "shape") {
747 strings::StrAppend(function_setup, indentation, attr_api_name,
748 " = _execute.make_shape(", attr_api_name, ", \"",
749 attr_api_name, "\")\n");
750 } else if (attr_type == "list(shape)") {
751 strings::StrAppend(function_setup, indentation, attr_api_name,
752 " = [_execute.make_shape(_s, \"", attr_api_name,
753 "\") for _s in ", attr_api_name, "]\n");
754 } else if (attr_type == "tensor") {
755 strings::StrAppend(function_setup, indentation, attr_api_name,
756 " = _execute.make_tensor(", attr_api_name, ", \"",
757 attr_api_name, "\")\n");
758 } else if (attr_type == "list(tensor)") {
759 strings::StrAppend(function_setup, indentation, attr_api_name,
760 " = [_execute.make_tensor(_t, \"", attr_api_name,
761 "\") for _t in ", attr_api_name, "]\n");
762 } else if (attr_type != "func" && attr_type != "list(func)") {
763 *function_setup =
764 strings::StrCat("# No definition for ", function_name_,
765 " since we don't support attrs with type\n"
766 "# '",
767 attr_type, "' right now.\n\n");
768 return false;
769 }
770 }
771 return true;
772 }
773
774 // If output i is list output, output_sizes[i] will be set to a
775 // string with the python expression that will evaluate to its
776 // length. output_sizes[i] is empty for non-list outputs.
GetOutputSizesAndNumOutputsExpr(std::vector<string> * output_sizes,string * num_outputs_expr)777 void GenEagerPythonOp::GetOutputSizesAndNumOutputsExpr(
778 std::vector<string>* output_sizes, string* num_outputs_expr) {
779 // Expression representing the number of outputs.
780 int num_fixed_outputs = 0;
781 for (int i = 0; i < num_outs_; ++i) {
782 const auto& arg(op_def_.output_arg(i));
783 if (!arg.number_attr().empty()) {
784 if (!num_outputs_expr->empty()) {
785 strings::StrAppend(num_outputs_expr, " + ");
786 }
787 (*output_sizes)[i] = attr_expressions_[arg.number_attr()];
788 strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
789 } else if (!arg.type_list_attr().empty()) {
790 if (!num_outputs_expr->empty()) {
791 strings::StrAppend(num_outputs_expr, " + ");
792 }
793 // Have to be careful to use an expression that works in both
794 // graph and eager paths here.
795 const auto iter = inferred_attrs_.find(arg.type_list_attr());
796 if (iter == inferred_attrs_.end()) {
797 (*output_sizes)[i] = strings::StrCat(
798 "len(", attr_expressions_[arg.type_list_attr()], ")");
799 } else {
800 (*output_sizes)[i] = strings::StrCat("len(", iter->second, ")");
801 }
802 strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
803 } else {
804 ++num_fixed_outputs;
805 }
806 }
807 if (num_fixed_outputs > 0) {
808 if (!num_outputs_expr->empty()) {
809 strings::StrAppend(num_outputs_expr, " + ");
810 }
811 strings::StrAppend(num_outputs_expr, num_fixed_outputs);
812 } else if (num_outputs_expr->empty()) {
813 *num_outputs_expr = "0";
814 }
815 }
816
AddEagerFunctionTeardown(const string & indentation,const std::vector<string> & output_sizes,bool execute_record_gradient)817 void GenEagerPythonOp::AddEagerFunctionTeardown(
818 const string& indentation, const std::vector<string>& output_sizes,
819 bool execute_record_gradient) {
820 if (num_outs_ > 0) {
821 if (execute_record_gradient) {
822 strings::StrAppend(&result_, indentation,
823 "if _execute.must_record_gradient():\n");
824 strings::StrAppend(&result_, indentation, " _execute.record_gradient(\n",
825 " \"", op_def_.name(),
826 "\", _inputs_flat, _attrs, _result)\n");
827 }
828 if (num_outs_ == 1 && !output_sizes[0].empty()) {
829 // Single list result.
830 } else if (num_outs_ == 1) {
831 // Execute returns a single-element list which we need to destructure.
832 strings::StrAppend(&result_, indentation, "_result, = _result\n");
833 } else {
834 // Have multiple outputs, so we will need to reformat the return
835 // value of execute() to be a list with one entry per op output
836 // (that entry will be a list of tensors if that output is of list
837 // type).
838 // For list outputs, convert the right subrange of _result into a list.
839 Unflatten(indentation, output_sizes, "_result", &result_);
840 // Convert to a named tuple.
841 strings::StrAppend(
842 &result_, indentation, "_result = _",
843 python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
844 "Output._make(_result)\n");
845 }
846 } else {
847 strings::StrAppend(&result_, indentation, "_result = None\n");
848 }
849 strings::StrAppend(&result_, indentation, "return _result\n\n");
850 }
851
AddEagerFastPathAndGraphCode(const string & parameters,const std::vector<string> & output_sizes,const string & eager_not_allowed_error,const std::unordered_map<string,string> & type_annotations)852 bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
853 const string& parameters, const std::vector<string>& output_sizes,
854 const string& eager_not_allowed_error,
855 const std::unordered_map<string, string>& type_annotations) {
856 if (add_type_annotations_) {
857 GenerateTypeVars(type_annotations);
858 }
859 if (api_def_.visibility() == ApiDef::VISIBLE) {
860 strings::StrAppend(&result_, "@_dispatch.add_fallback_dispatch_list\n");
861 strings::StrAppend(&result_, "@_dispatch.add_type_based_api_dispatcher\n");
862 }
863
864 AddExport();
865 AddDefLine(function_name_, parameters);
866 if (add_type_annotations_) {
867 AddReturnTypeAnnotation(type_annotations);
868 }
869 AddDocStringDescription();
870 AddDocStringArgs();
871 AddDocStringInputs();
872 AddDocStringAttrs();
873 AddDocStringNameArg();
874 AddOutputGlobals(); // Added to prelude_
875 AddDocStringOutputs();
876 strings::StrAppend(&result_, " \"\"\"\n");
877
878 strings::StrAppend(&result_,
879 " _ctx = _context._context or _context.context()\n"
880 " tld = _ctx._thread_local_data\n",
881 " if tld.is_eager:", "\n");
882 if (eager_not_allowed_error.empty()) {
883 AddEagerFastPathExecute();
884 } else {
885 strings::StrAppend(&result_, " ", eager_not_allowed_error);
886 }
887
888 // Handle graph-mode case
889 string function_setup;
890 if (!GetEagerFunctionSetup(" ", &function_setup)) {
891 result_ = function_setup;
892 return false;
893 }
894 HandleGraphMode(function_setup, output_sizes);
895
896 AddRawOpExport(parameters);
897 AddTypeBasedDispatcherAlias();
898 strings::StrAppend(&result_, "\n\n");
899 return true;
900 }
901
AddEagerFallbackCode(const string & parameters,const std::vector<string> & output_sizes,const string & num_outputs_expr,const string & eager_not_allowed_error,const std::unordered_map<string,string> & type_annotations)902 bool GenEagerPythonOp::AddEagerFallbackCode(
903 const string& parameters, const std::vector<string>& output_sizes,
904 const string& num_outputs_expr, const string& eager_not_allowed_error,
905 const std::unordered_map<string, string>& type_annotations) {
906 AddDefLine(
907 strings::StrCat(function_name_, kEagerFallbackSuffix),
908 strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx"));
909 if (add_type_annotations_) {
910 AddReturnTypeAnnotation(type_annotations);
911 }
912 if (!eager_not_allowed_error.empty()) {
913 strings::StrAppend(&result_, " ", eager_not_allowed_error);
914 return true;
915 }
916
917 string function_setup;
918 if (!GetEagerFunctionSetup(" ", &function_setup)) {
919 result_ = function_setup;
920 return false;
921 }
922 strings::StrAppend(&result_, function_setup);
923
924 AddEagerInferredAttrs(" ");
925 AddEagerInputCasts(" ");
926 strings::StrAppend(
927 &result_, " _inputs_flat = ", FlattenInputs(nullptr, nullptr), "\n");
928 AddEagerAttrs(" ");
929 AddEagerExecute(" ", num_outputs_expr);
930
931 AddEagerFunctionTeardown(" ", output_sizes,
932 true /* execute_record_gradient */);
933
934 return true;
935 }
936
AddEagerFastPathExecute()937 void GenEagerPythonOp::AddEagerFastPathExecute() {
938 string fastpath_execute_params =
939 strings::StrCat("_ctx, \"", op_def_.name(), "\", ", "name");
940 string fallback_params;
941
942 for (int i = 0; i < api_def_.in_arg_size(); i++) {
943 const string param_name = param_names_[i].GetRenameTo();
944 strings::StrAppend(&fastpath_execute_params, ", ", param_name);
945 if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
946 strings::StrAppend(&fallback_params, param_name);
947 }
948
949 for (const auto& attr : api_def_.attr()) {
950 if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
951 strings::StrAppend(&fastpath_execute_params, ", \"", attr.name(), "\", ",
952 attr.rename_to());
953
954 if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
955 strings::StrAppend(&fallback_params, attr.rename_to(), "=",
956 attr.rename_to());
957 }
958 }
959
960 if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
961 strings::StrAppend(&fallback_params, "name=name");
962
963 strings::StrAppend(&result_, " try:\n");
964 strings::StrAppend(
965 &result_, " ", "_result = pywrap_tfe.TFE_Py_FastPathExecute(\n",
966 WordWrap(strings::StrCat(" "),
967 strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
968 "\n");
969
970 if (op_def_.output_arg_size() > 1) {
971 const string output_tuple_name = strings::StrCat(
972 "_", python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
973 "Output");
974 strings::StrAppend(&result_, " ", "_result = ", output_tuple_name,
975 "._make(_result)\n");
976 }
977 strings::StrAppend(&result_, " ", "return _result\n");
978
979 // Handle fallback.
980 if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
981 strings::StrAppend(&fallback_params, "ctx=_ctx");
982
983 // Any errors thrown from execute need to be unwrapped from
984 // _NotOkStatusException.
985 strings::StrAppend(&result_, " ",
986 "except _core._NotOkStatusException as e:\n");
987 strings::StrAppend(&result_, " ",
988 "_ops.raise_from_not_ok_status(e, name)\n");
989
990 strings::StrAppend(&result_, " ", "except _core._FallbackException:\n");
991 strings::StrAppend(&result_, " pass\n");
992 strings::StrAppend(&result_, " try:\n");
993 AddTypeBasedDispatch(" ");
994 strings::StrAppend(
995 &result_, " ", "return ", function_name_, kEagerFallbackSuffix,
996 "(\n",
997 WordWrap(strings::StrCat(" "),
998 strings::StrCat(fallback_params, ")"), kRightMargin),
999 "\n");
1000 strings::StrAppend(&result_, " except _core._SymbolicException:\n");
1001 strings::StrAppend(&result_,
1002 " pass # Add nodes to the TensorFlow graph.\n");
1003 AddFallbackDispatch(" ");
1004 }
1005
AddEagerInferredAttrs(const string & indentation)1006 void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {
1007 // Figure out values for inferred attrs, and cast to eager tensors.
1008 for (int i = 0; i < op_def_.attr_size(); ++i) {
1009 const auto& attr(op_def_.attr(i));
1010 const auto& api_def_attr(api_def_.attr(i));
1011 auto arg_list = attr_to_args_.find(attr.name());
1012 if (arg_list != attr_to_args_.end()) {
1013 if (attr.type() == "type") {
1014 std::vector<string> output_sizes;
1015 const string flattened =
1016 FlattenInputs(&arg_list->second, &output_sizes);
1017 string conversion = strings::StrCat("_execute.args_to_matching_eager(",
1018 flattened, ", ctx");
1019
1020 strings::StrAppend(&conversion, ", [");
1021 for (int t : attr.allowed_values().list().type()) {
1022 DataType dtype = static_cast<DataType>(t);
1023 const string py_dtype =
1024 python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
1025 strings::StrAppend(&conversion, py_dtype, ", ");
1026 }
1027 strings::StrAppend(&conversion, "]");
1028
1029 if (attr.has_default_value()) {
1030 strings::StrAppend(
1031 &conversion, ", ",
1032 python_op_gen_internal::AttrValueToPython(
1033 attr.type(), api_def_attr.default_value(), "_dtypes."));
1034 }
1035 strings::StrAppend(&conversion, ")");
1036 const string var_name = AttrVarName(attr.name(), &attr_expressions_);
1037 if (output_sizes.size() == 1) {
1038 // Avoid creating a temporary variable in the case where
1039 // we can easily assign to the right value directly.
1040 const string inputs_var =
1041 param_names_[arg_list->second.front()].GetRenameTo();
1042 if (output_sizes.front().empty()) {
1043 strings::StrAppend(&result_, indentation, var_name, ", (",
1044 inputs_var, ",) = ", conversion, "\n");
1045 } else {
1046 strings::StrAppend(&result_, indentation, var_name, ", ",
1047 inputs_var, " = ", conversion, "\n");
1048 }
1049 } else {
1050 const string inputs_var = strings::StrCat("_inputs_", attr.name());
1051 strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
1052 " = ", conversion, "\n");
1053 // Convert from a flat list of eager tensors back to the
1054 // parameter variables.
1055 Unflatten(indentation, output_sizes, inputs_var, &result_);
1056 std::vector<string> p;
1057 for (int j : arg_list->second) {
1058 p.emplace_back(param_names_[j].GetRenameTo());
1059 }
1060 strings::StrAppend(&result_, indentation, VectorToTuple(p), " = ",
1061 inputs_var, "\n");
1062 }
1063 } else if (attr.type() == "list(type)") {
1064 // NOTE: We ignore default values for these attrs, since it is
1065 // unclear how you would use it, and the one use case is
1066 // parse_single_sequence_example which only needs it for
1067 // backwards compatibility.
1068 const string var_name = AttrVarName(attr.name(), &attr_expressions_);
1069 string inputs_var;
1070 string conversion;
1071 if (arg_list->second.size() > 1) {
1072 // If you have more than one list(tensor) argument, their types
1073 // have to match.
1074 std::vector<string> lists;
1075 for (auto iter = arg_list->second.begin();
1076 iter != arg_list->second.end(); ++iter) {
1077 lists.push_back(param_names_[*iter].GetRenameTo());
1078 }
1079 inputs_var = VectorToTuple(lists);
1080 conversion = "_execute.args_to_mixed_eager_tensors";
1081 } else {
1082 // For one list(tensor) argument, we just convert every
1083 // element of the list to an eager tensor.
1084 inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
1085 conversion = "_execute.convert_to_mixed_eager_tensors";
1086 }
1087 strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
1088 " = ", conversion, "(", inputs_var, ", ctx)\n");
1089 }
1090 }
1091 }
1092 }
1093
AddEagerInputCasts(const string & indentation)1094 void GenEagerPythonOp::AddEagerInputCasts(const string& indentation) {
1095 // Cast remaining args to eager tensors
1096 for (int i = 0; i < op_def_.input_arg_size(); ++i) {
1097 const auto& arg(op_def_.input_arg(i));
1098 if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
1099 const string& param = param_names_[i].GetRenameTo();
1100 const string fn = arg.number_attr().empty() ? "" : "n_";
1101 const string dtype =
1102 python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
1103 strings::StrAppend(&result_, indentation, param, " = _ops.convert_", fn,
1104 "to_tensor(", param, ", ", dtype, ")\n");
1105 }
1106 }
1107
AddEagerAttrs(const string & indentation)1108 void GenEagerPythonOp::AddEagerAttrs(const string& indentation) {
1109 // Compute eager attrs
1110 if (op_def_.attr_size() > 0) {
1111 string attr_values;
1112 for (int i = 0; i < op_def_.attr_size(); ++i) {
1113 if (i > 0) strings::StrAppend(&attr_values, ", ");
1114 const auto& attr_name(op_def_.attr(i).name());
1115 strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
1116 attr_expressions_[attr_name]);
1117 }
1118 strings::StrAppend(&attr_values, ")");
1119 strings::StrAppend(
1120 &result_,
1121 WordWrap(indentation, strings::StrCat("_attrs = (", attr_values),
1122 kRightMargin),
1123 "\n");
1124 } else {
1125 strings::StrAppend(&result_, indentation, "_attrs = None\n");
1126 }
1127 }
1128
AddEagerExecute(const string & indentation,const string & num_outputs_expr)1129 void GenEagerPythonOp::AddEagerExecute(const string& indentation,
1130 const string& num_outputs_expr) {
1131 const string return_prefix =
1132 strings::StrCat(indentation, "_result = _execute.execute(");
1133 const string return_args = strings::StrCat(
1134 "b\"", op_def_.name(), "\", ", num_outputs_expr,
1135 ", inputs=_inputs_flat, attrs=_attrs, ctx=ctx, name=name)");
1136 strings::StrAppend(&result_,
1137 // Wrap the arguments, and indent to the (.
1138 WordWrap(return_prefix, return_args, kRightMargin), "\n");
1139 }
1140
AddFallbackDispatch(const string & prefix)1141 void GenEagerPythonOp::AddFallbackDispatch(const string& prefix) {
1142 if (api_def_.visibility() != ApiDef::VISIBLE) return;
1143
1144 strings::StrAppend(&result_, prefix, "except (TypeError, ValueError):\n");
1145 strings::StrAppend(&result_, prefix, " _result = _dispatch.dispatch(\n");
1146 AddBodyNoReturn(strings::StrCat(prefix, " ", function_name_,
1147 ", "
1148 "(), dict("));
1149 strings::StrAppend(&result_, prefix, " )\n");
1150 strings::StrAppend(&result_, prefix,
1151 " if _result is not "
1152 "_dispatch.OpDispatcher.NOT_SUPPORTED:\n");
1153 strings::StrAppend(&result_, prefix, " return _result\n");
1154 strings::StrAppend(&result_, prefix, " raise\n");
1155 }
1156
AddTypeBasedDispatcherAlias()1157 void GenEagerPythonOp::AddTypeBasedDispatcherAlias() {
1158 // It's possible for the name of a parameter to be the same as the name of
1159 // an op, in which case the parameter shadows the op's function. To avoid
1160 // this, we add a private variable with the dispatcher, and access that
1161 // directly.
1162 if (api_def_.visibility() == ApiDef::VISIBLE) {
1163 strings::StrAppend(&result_, "_dispatcher_for_", function_name_,
1164 " = ", function_name_,
1165 "._tf_type_based_dispatcher.Dispatch\n");
1166 }
1167 }
AddTypeBasedDispatch(const string & prefix)1168 void GenEagerPythonOp::AddTypeBasedDispatch(const string& prefix) {
1169 if (api_def_.visibility() != ApiDef::VISIBLE) return;
1170 std::string args("(");
1171 for (const auto& name : param_names_) {
1172 strings::StrAppend(&args, name.GetRenameTo(), ", ");
1173 }
1174 strings::StrAppend(&args, "name,), None");
1175
1176 strings::StrAppend(
1177 &result_, prefix, "_result = ", "_dispatcher_for_", function_name_, "(\n",
1178 WordWrap(strings::StrCat(prefix, " "), args, kRightMargin), ")\n");
1179 strings::StrAppend(&result_, prefix, "if _result is not NotImplemented:\n",
1180 prefix, " return _result\n");
1181 }
1182
AddRawOpExport(const string & parameters)1183 void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
1184 // Example:
1185 //
1186 // Identity = tf_export("raw_ops.Identity")(_ops._to_raw_op(identity))
1187 const string raw_function_name =
1188 python_op_gen_internal::AvoidPythonReserved(op_def_.name());
1189 strings::StrAppend(&result_, raw_function_name, " = tf_export(\"raw_ops.",
1190 raw_function_name, "\")", "(_ops.to_raw_op(",
1191 function_name_, "))\n");
1192 }
1193
// Generates the complete Python wrapper module for `ops`. The module holds:
//  * a docstring, which names `source_file_name` when it is non-empty;
//  * the shared imports;
//  * the code from GetEagerPythonOp for every emitted op.
//
// Each op's Python function name is the lower-case form of the op name,
// adjusted as follows:
//  * ApiDef::SKIP ops are not emitted.
//  * Ops listed in `hidden_ops` get a leading underscore.
//  * Ops marked HIDDEN in their ApiDef get a leading underscore only if the
//    name is Python-reserved or IsOpWithUnderscorePrefix() accepts it.
//  * Non-hidden ops whose name is Python-reserved are not emitted.
// Ops named in `type_annotate_ops` are generated with type annotations.
//
// `type_annotate_ops` is taken by const reference so the set is not copied
// on every call.
string GetPythonOpsImpl(
    const OpList& ops, const ApiDefMap& api_defs,
    const std::vector<string>& hidden_ops, const string& source_file_name = "",
    const std::unordered_set<string>& type_annotate_ops = {}) {
  string result;
  // Header
  // TODO(josh11b): Mention the library for which wrappers are being generated.
  strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
)");

  // Mention the original source file so someone tracing back through
  // generated Python code will know where to look next.
  if (!source_file_name.empty()) {
    strings::StrAppend(&result, "Original C++ source file: ",
                       source_file_name, "\n");
  }

  strings::StrAppend(&result, R"("""

import collections

from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes

from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export

from typing import TypeVar
)");

  for (const auto& op_def : ops.op()) {
    // NOTE(review): the result of GetApiDef is dereferenced without a null
    // check, so `api_defs` is assumed to cover every op in `ops`.
    const auto* api_def = api_defs.GetApiDef(op_def.name());
    if (api_def->visibility() == ApiDef::SKIP) {
      continue;
    }

    // An op is hidden if either its ApiDef visibility is HIDDEN or it is in
    // the hidden_ops list.
    const bool hidden_by_api_def = api_def->visibility() == ApiDef::HIDDEN;
    bool is_hidden = hidden_by_api_def;
    if (!is_hidden) {
      for (const string& hidden : hidden_ops) {
        if (op_def.name() == hidden) {
          is_hidden = true;
          break;
        }
      }
    }

    string function_name;
    python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
                                                    &function_name);
    const bool is_reserved =
        python_op_gen_internal::IsPythonReserved(function_name);

    // Prefix an op with underscore if the op is listed in hidden_ops or
    // name is reserved or it is of the exceptions in IsOpWithUnderscorePrefix.
    // Do not add underscores to ops set to HIDDEN in ApiDef otherwise.
    // TODO(annarev): don't prefix with underscores even if op is in hidden_ops.
    if (is_hidden) {
      if (!hidden_by_api_def || is_reserved ||
          python_op_gen_internal::IsOpWithUnderscorePrefix(function_name)) {
        function_name = strings::StrCat("_", function_name);
      }
    } else if (is_reserved) {
      // When users create custom python wrappers, they may link in the
      // default op registry by accident, and because they can't
      // enumerate all 'hidden' symbols, this guard is to prevent
      // instantiating a python reserved word in their wrapper.
      continue;
    }

    const bool add_type_annotations =
        type_annotate_ops.count(op_def.name()) > 0;
    strings::StrAppend(&result,
                       GetEagerPythonOp(op_def, *api_def, function_name,
                                        add_type_annotations));
  }

  return result;
}
1285
1286 } // namespace
1287
// Public entry point: returns the generated Python wrapper module for `ops`
// as a string. All work is delegated to GetPythonOpsImpl.
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                    const std::vector<string>& hidden_ops,
                    const string& source_file_name,
                    const std::unordered_set<string> type_annotate_ops) {
  string module_source = GetPythonOpsImpl(
      ops, api_defs, hidden_ops, source_file_name, type_annotate_ops);
  return module_source;
}
1295
PrintPythonOps(const OpList & ops,const ApiDefMap & api_defs,const std::vector<string> & hidden_ops,const string & source_file_name,const std::unordered_set<string> type_annotate_ops)1296 void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
1297 const std::vector<string>& hidden_ops,
1298 const string& source_file_name,
1299 const std::unordered_set<string> type_annotate_ops) {
1300 printf("%s", GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name,
1301 type_annotate_ops)
1302 .c_str());
1303 }
1304
GetPythonWrappers(const char * op_list_buf,size_t op_list_len)1305 string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
1306 OpList ops;
1307 ops.ParseFromArray(op_list_buf, op_list_len);
1308
1309 ApiDefMap api_def_map(ops);
1310 return GetPythonOpsImpl(ops, api_def_map, {});
1311 }
1312
GetArgAnnotation(const OpDef::ArgDef & arg,const std::unordered_map<string,string> & type_annotations)1313 string GetArgAnnotation(
1314 const OpDef::ArgDef& arg,
1315 const std::unordered_map<string, string>& type_annotations) {
1316 if (!arg.type_attr().empty()) {
1317 // Get the correct TypeVar if arg maps to an attr
1318 return "_ops.Tensor[" + type_annotations.at(arg.type_attr()) + "]";
1319 } else {
1320 // Get the dtype of the Tensor
1321 const string py_dtype =
1322 python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
1323 return "_ops.Tensor[" + dtype_type.at(py_dtype) + "]";
1324 }
1325
1326 return "Any";
1327 }
1328
1329 } // namespace tensorflow
1330