xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/python_api_info.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/framework/python_api_info.h"
16 
17 #include <Python.h>
18 
19 #include "absl/strings/str_cat.h"
20 #include "tensorflow/core/framework/op.h"
21 #include "tensorflow/core/lib/gtl/map_util.h"
22 #include "tensorflow/python/eager/pywrap_tensor.h"
23 #include "tensorflow/python/eager/pywrap_tfe.h"
24 #include "tensorflow/python/framework/op_def_util.h"
25 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
26 #include "tensorflow/python/util/util.h"
27 
28 namespace tensorflow {
29 
// Version-dependent spellings of the Python C API string/int helpers used by
// this file, so the same code compiles against Python 3 and Python 2 headers.
#if PY_MAJOR_VERSION >= 3
// Python 3.x: text is a unicode object; PY_STRING_CHECK also accepts bytes.
#define PY_STRING_CHECK(x) (PyBytes_Check(x) || PyUnicode_Check(x))
#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
#define PY_STRING_FROMSTRING(x) (PyUnicode_FromString(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
#define PY_STRING_AS_CSTR(x) (PyUnicode_AsUTF8AndSize((x), nullptr))
#else
// Python 2.x: both `str` and `unicode` objects count as strings.
#define PY_STRING_CHECK(x) (PyString_Check(x) || PyUnicode_Check(x))
#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
#define PY_STRING_FROMSTRING(x) (PyString_FromString(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
#define PY_STRING_AS_CSTR(x) (PyString_AsString(x))
#endif
45 
46 namespace {
47 
48 // Converts the given object to an interned Python string, and returns its
49 // data pointer.  (This means we don't need to worry about ownership for
50 // this string.)
InternPyString(const std::string & s)51 const char* InternPyString(const std::string& s) {
52   Safe_PyObjectPtr interned(PY_STRING_INTERN_FROM_STRING(s.c_str()));
53   return PY_STRING_AS_CSTR(interned.get());
54 }
55 
// Erases, in place, every element of `*vec` for which `p` returns true.
// Surviving elements keep their relative order (erase-remove idiom).
template <typename T, typename UnaryPredicate>
void RemoveIf(UnaryPredicate p, std::vector<T>* vec) {
  const auto first_removed = std::remove_if(vec->begin(), vec->end(), p);
  vec->erase(first_removed, vec->end());
}
60 
61 struct DataTypeFormatter {
operator ()tensorflow::__anon2a9b44f30111::DataTypeFormatter62   void operator()(std::string* out, DataType dtype) const {
63     out->append(DataType_Name(dtype));
64   }
65 };
66 
67 // Populates `param_names` and `defaults_tuple` based on the given OpDef.
GetOpDefNamesAndDefaults(const tensorflow::OpDef & op_def,std::vector<string> & param_names,Safe_PyObjectPtr & defaults_tuple)68 void GetOpDefNamesAndDefaults(const tensorflow::OpDef& op_def,
69                               std::vector<string>& param_names,
70                               Safe_PyObjectPtr& defaults_tuple) {
71   param_names.reserve(op_def.input_arg_size() + op_def.attr_size());
72   std::set<std::string> inferred_attrs;
73 
74   // Input parameters come first, in the order they occur in the OpDef.
75   for (const auto& input : op_def.input_arg()) {
76     param_names.push_back(input.name());
77     if (!input.type_attr().empty()) {
78       inferred_attrs.insert(input.type_attr());
79     }
80     if (!input.type_list_attr().empty()) {
81       inferred_attrs.insert(input.type_list_attr());
82     }
83     if (!input.number_attr().empty()) {
84       inferred_attrs.insert(input.number_attr());
85     }
86   }
87 
88   // Next come attribute params without defaults, followed by attributes with
89   // defaults (but inferred attributes are not included).
90   std::vector<std::string> param_names_with_default;
91   std::vector<Safe_PyObjectPtr> defaults;
92   for (const auto& attr : op_def.attr()) {
93     if (inferred_attrs.count(attr.name()) == 0) {
94       if (attr.has_default_value()) {
95         param_names_with_default.push_back(attr.name());
96         defaults.push_back(AttrValueToPyObject(attr.default_value()));
97       } else {
98         param_names.push_back(attr.name());
99       }
100     }
101   }
102   param_names.insert(param_names.end(), param_names_with_default.begin(),
103                      param_names_with_default.end());
104 
105   // Finally, the 'name' parameter comes at the end, and its default value
106   // is the operation's name.
107   param_names.push_back("name");
108   defaults.emplace_back(PY_STRING_FROMSTRING(op_def.name().c_str()));
109 
110   defaults_tuple.reset(PyTuple_New(defaults.size()));
111   for (int i = 0; i < defaults.size(); ++i) {
112     PyTuple_SET_ITEM(defaults_tuple.get(), i, defaults[i].release());
113   }
114 }
115 
116 }  // namespace
117 
PythonAPIInfo(const std::string & api_name)118 PythonAPIInfo::PythonAPIInfo(const std::string& api_name)
119     : api_name_(InternPyString(api_name)) {}
120 
Initialize(const OpDef & op_def,const std::vector<string> param_names,PyObject * defaults_tuple)121 Status PythonAPIInfo::Initialize(const OpDef& op_def,
122                                  const std::vector<string> param_names,
123                                  PyObject* defaults_tuple) {
124   // Intern the parameter names.
125   param_names_.reserve(param_names.size());
126   for (const auto& param_name : param_names) {
127     param_names_.push_back(InternPyString(param_name));
128   }
129 
130   Py_INCREF(defaults_tuple);
131   defaults_tuple_.reset(defaults_tuple);
132 
133   // Build an index to look up parameter index by name.  (Does not include
134   // inferred attributes.)
135   std::map<std::string, int> param_name_to_index;
136   for (int i = 0; i < param_names_.size(); ++i) {
137     param_name_to_index[param_names_[i]] = i;
138   }
139 
140   // Initialize each attribute & input parameter.
141   attributes_.reserve(op_def.attr_size());
142   for (const auto& attr_def : op_def.attr()) {
143     TF_RETURN_IF_ERROR(InitializeAttribute(attr_def, param_name_to_index));
144   }
145 
146   inputs_.reserve(op_def.input_arg_size());
147   for (const auto& arg_def : op_def.input_arg()) {
148     TF_RETURN_IF_ERROR(InitializeInput(arg_def, param_name_to_index));
149   }
150 
151   TF_RETURN_IF_ERROR(CheckParamNames());
152 
153   // Filter out any unused entries from inputs_with_*_attrs_.
154   RemoveIf(
155       [](const InputsWithTypeAttr& input) {
156         return input.tensor_params.empty() && input.tensor_list_params.empty();
157       },
158       &inputs_with_type_attrs_);
159   RemoveIf(
160       [](const InputsWithTypeListAttr& input) {
161         return input.tensor_list_params.empty();
162       },
163       &inputs_with_type_list_attrs_);
164   RemoveIf(
165       [](const InputsWithNumberAttr& input) {
166         return input.tensor_list_params.empty();
167       },
168       &inputs_with_number_attrs_);
169 
170   return OkStatus();
171 }
172 
CheckParamNames() const173 Status PythonAPIInfo::CheckParamNames() const {
174   std::vector<bool> param_found(param_names_.size());
175   for (const auto& attr : attributes_) {
176     if (attr.index != -1) {
177       param_found[attr.index] = true;
178     }
179   }
180   for (const auto& input : inputs_) {
181     param_found[input.index] = true;
182   }
183 
184   for (int i = 0; i < param_names_.size(); ++i) {
185     if (param_names_[i] == std::string("name")) {
186       continue;
187     }
188     if (!param_found[i]) {
189       return errors::InvalidArgument(
190           api_name_, ": missing specification for parameter ", param_names_[i]);
191     }
192   }
193   return OkStatus();
194 }
195 
InitializeFromRegisteredOp(const std::string & op_name)196 Status PythonAPIInfo::InitializeFromRegisteredOp(const std::string& op_name) {
197   const tensorflow::OpDef* op_def = nullptr;
198   TF_RETURN_IF_ERROR(
199       tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def));
200   std::vector<std::string> param_names;
201   Safe_PyObjectPtr defaults_tuple;
202   GetOpDefNamesAndDefaults(*op_def, param_names, defaults_tuple);
203   TF_RETURN_IF_ERROR(Initialize(*op_def, param_names, defaults_tuple.get()));
204   return OkStatus();
205 }
206 
InitializeFromParamSpecs(const std::map<std::string,std::string> & input_specs,const std::map<std::string,std::string> & attr_specs,const std::vector<string> param_names,PyObject * defaults_tuple)207 Status PythonAPIInfo::InitializeFromParamSpecs(
208     const std::map<std::string, std::string>& input_specs,
209     const std::map<std::string, std::string>& attr_specs,
210     const std::vector<string> param_names, PyObject* defaults_tuple) {
211   OpDefBuilder op_def_builder(api_name_);
212   op_def_builder.AllowAttrTypeAny();
213   for (const auto& attr_spec : attr_specs) {
214     op_def_builder.Attr(absl::StrCat(attr_spec.first, ": ", attr_spec.second));
215   }
216   for (const auto& input_spec : input_specs) {
217     op_def_builder.Input(
218         absl::StrCat(input_spec.first, ": ", input_spec.second));
219   }
220   OpRegistrationData op_reg_data;
221   TF_RETURN_IF_ERROR(op_def_builder.Finalize(&op_reg_data));
222 
223   TF_RETURN_IF_ERROR(
224       Initialize(op_reg_data.op_def, param_names, defaults_tuple));
225 
226   return OkStatus();
227 }
228 
InitializeAttribute(const OpDef::AttrDef & attr_def,const std::map<std::string,int> & param_name_to_index)229 Status PythonAPIInfo::InitializeAttribute(
230     const OpDef::AttrDef& attr_def,
231     const std::map<std::string, int>& param_name_to_index) {
232   if (attr_def.name() == "name") {
233     return errors::InvalidArgument(
234         api_name_, ": Reserved parameter `name` was used as an attribute.");
235   }
236   const char* name = InternPyString(attr_def.name());
237 
238   const int param_index =
239       gtl::FindWithDefault(param_name_to_index, attr_def.name(), -1);
240   const AttributeType dtype = AttributeTypeFromName(attr_def.type());
241   const int inferred_index = -1;
242   attributes_.push_back({param_index, dtype, name, inferred_index});
243   Attribute& attr = attributes_.back();
244   if (attr.type == AttributeType::UNKNOWN) {
245     return errors::InvalidArgument(api_name_, ": Bad attribute type for ",
246                                    attr_def.name(), ": '", attr_def.type(),
247                                    "'");
248   }
249   std::vector<DataType>* ok_dtypes = nullptr;
250 
251   if (attr.type == AttributeType::DTYPE) {
252     DataType default_dtype = attr_def.has_default_value()
253                                  ? attr_def.default_value().type()
254                                  : DT_INVALID;
255     inputs_with_type_attrs_.push_back({&attr, default_dtype});
256     ok_dtypes = &inputs_with_type_attrs_.back().ok_dtypes;
257 
258   } else if (attr.type == AttributeType::LIST_DTYPE) {
259     inputs_with_type_list_attrs_.push_back({&attr});
260     for (int d : attr_def.default_value().list().type()) {
261       inputs_with_type_list_attrs_.back().default_dtypes.push_back(
262           static_cast<DataType>(d));
263     }
264     ok_dtypes = &inputs_with_type_list_attrs_.back().ok_dtypes;
265   }
266 
267   if (attr_def.has_allowed_values() && ok_dtypes) {
268     const auto& dtypes = attr_def.allowed_values().list();
269     for (int i = 0; i < dtypes.type_size(); ++i) {
270       ok_dtypes->push_back(dtypes.type(i));
271     }
272   }
273 
274   if (attr.type == AttributeType::INT) {
275     int64_t default_len =
276         attr_def.has_default_value() ? attr_def.default_value().i() : -1;
277     inputs_with_number_attrs_.push_back({&attr, default_len});
278   }
279 
280   // If this is an inferred attribute, then record its name and index.
281   if (attr.index == -1) {
282     std::vector<const char*>* inferred_attr_names =
283         attr.type == AttributeType::DTYPE        ? &inferred_type_attrs_
284         : attr.type == AttributeType::LIST_DTYPE ? &inferred_type_list_attrs_
285         : attr.type == AttributeType::INT        ? &inferred_length_attrs_
286                                                  : nullptr;
287     if (inferred_attr_names == nullptr) {
288       return errors::InvalidArgument(
289           api_name_, ": Missing specification for parameter ", attr_def.name());
290     } else {
291       attr.inferred_index = inferred_attr_names->size();
292       inferred_attr_names->push_back(attr.name);
293     }
294   }
295 
296   return OkStatus();
297 }
298 
InitializeInput(const OpDef::ArgDef & arg_def,const std::map<std::string,ParamIndex> & param_name_to_index)299 Status PythonAPIInfo::InitializeInput(
300     const OpDef::ArgDef& arg_def,
301     const std::map<std::string, ParamIndex>& param_name_to_index) {
302   if (arg_def.name() == "name") {
303     return errors::InvalidArgument(
304         api_name_, ": Reserved parameter `name` was used as a tensor input.");
305   }
306   const ParamIndex param_index =
307       gtl::FindWithDefault(param_name_to_index, arg_def.name(), -1);
308   if (param_index == -1) {
309     return errors::InvalidArgument(
310         api_name_, ": Missing specification for parameter ", arg_def.name());
311   }
312   if (arg_def.is_ref()) {
313     // TODO(b/164980194): Support reference parameters.
314     //   - Pass as_ref to convert_to_tensor
315     //   - Check that values for ref inputs have ref types.
316     return errors::InvalidArgument(api_name_,
317                                    ": PythonAPIInfo doesn't support reference "
318                                    "parameters yet.");
319   }
320   bool is_list =
321       !arg_def.number_attr().empty() || !arg_def.type_list_attr().empty();
322   inputs_.push_back({param_index, is_list});
323 
324   if (!arg_def.type_list_attr().empty()) {
325     // list(input) with dtypes specified by a `list(type)` attribute.
326     InputsWithTypeListAttr* input =
327         FindInputsWithTypeListAttr(arg_def.type_list_attr());
328     if (!input) {
329       return errors::InvalidArgument(
330           api_name_, ": Type attribute ", arg_def.type_list_attr(),
331           " for parameter ", arg_def.name(), " not found.");
332     }
333     input->tensor_list_params.push_back(param_index);
334   } else if (!arg_def.type_attr().empty()) {
335     InputsWithTypeAttr* input = FindInputsWithTypeAttr(arg_def.type_attr());
336     // input or list(input) with dtype specified by a `type` attribute.
337     if (!input) {
338       return errors::InvalidArgument(api_name_, ": Type attribute ",
339                                      arg_def.type_attr(), " for parameter ",
340                                      arg_def.name(), " not found.");
341     }
342     if (arg_def.number_attr().empty()) {
343       input->tensor_params.push_back(param_index);
344     } else {
345       input->tensor_list_params.push_back(param_index);
346     }
347   } else {
348     // input or list(input) with fixed dtype
349     inputs_with_fixed_dtype_.push_back({arg_def.type(), param_index, is_list});
350   }
351 
352   if (!arg_def.number_attr().empty()) {
353     InputsWithNumberAttr* input =
354         FindInputsWithNumberAttr(arg_def.number_attr());
355     if (!input) {
356       return errors::InvalidArgument(api_name_, ": Length attribute ",
357                                      arg_def.number_attr(), " for parameter ",
358                                      arg_def.name(), " not found.");
359     }
360     input->tensor_list_params.push_back(param_index);
361   }
362 
363   return OkStatus();
364 }
365 
FindInputsWithTypeAttr(const string & name)366 PythonAPIInfo::InputsWithTypeAttr* PythonAPIInfo::FindInputsWithTypeAttr(
367     const string& name) {
368   for (auto& input : inputs_with_type_attrs_) {
369     if (name == input.type_attr->name) {
370       return &input;
371     }
372   }
373   return nullptr;
374 }
375 
376 PythonAPIInfo::InputsWithTypeListAttr*
FindInputsWithTypeListAttr(const string & name)377 PythonAPIInfo::FindInputsWithTypeListAttr(const string& name) {
378   for (auto& input : inputs_with_type_list_attrs_) {
379     if (name == input.type_list_attr->name) {
380       return &input;
381     }
382   }
383   return nullptr;
384 }
385 
FindInputsWithNumberAttr(const string & name)386 PythonAPIInfo::InputsWithNumberAttr* PythonAPIInfo::FindInputsWithNumberAttr(
387     const string& name) {
388   for (auto& input : inputs_with_number_attrs_) {
389     if (name == input.number_attr->name) {
390       return &input;
391     }
392   }
393   return nullptr;
394 }
395 
DebugInfo() const396 string PythonAPIInfo::DebugInfo() const {
397   string s = absl::StrCat("DebugInfo for ", api_name_, ":\n");
398   absl::StrAppend(&s, "  param_names=[", absl::StrJoin(param_names_, ", "),
399                   "]\n");
400   Safe_PyObjectPtr defaults_repr(PyObject_Repr(defaults_tuple_.get()));
401   absl::StrAppend(
402       &s, "  defaults_tuple=", TFE_GetPythonString(defaults_repr.get()), "\n");
403   if (!attributes_.empty()) {
404     absl::StrAppend(&s, "  attributes=[");
405     for (const auto& attrib : attributes_) {
406       if (attrib.index != -1) {
407         absl::StrAppend(&s, "\n    {index=", attrib.index);
408         DCHECK_EQ(attrib.inferred_index, -1);
409       } else {
410         absl::StrAppend(&s, "\n    {inferred_index=", attrib.inferred_index);
411       }
412       absl::StrAppend(&s, ", name=", attrib.name,
413                       ", type=", AttributeTypeToName(attrib.type), "},");
414     }
415     absl::StrAppend(&s, "]\n");
416   }
417   if (!inputs_.empty()) {
418     absl::StrAppend(&s, "  inputs=[");
419     for (const auto& input : inputs_) {
420       absl::StrAppend(&s, "\n    {index=", input.index,
421                       ", name=", param_names_[input.index],
422                       ", is_list=", input.is_list, "},");
423     }
424     absl::StrAppend(&s, "]\n");
425   }
426   if (!inputs_with_fixed_dtype_.empty()) {
427     absl::StrAppend(&s, "  inputs_with_fixed_dtype=[");
428     for (const auto& input : inputs_with_fixed_dtype_) {
429       absl::StrAppend(&s, "\n    {index=", input.index,
430                       ", dtype=", DataType_Name(input.dtype),
431                       ", is_list=", input.is_list, "},");
432     }
433     absl::StrAppend(&s, "]\n");
434   }
435   if (!inputs_with_type_attrs_.empty()) {
436     absl::StrAppend(&s, "  inputs_with_type_attr=[");
437     for (const auto& input : inputs_with_type_attrs_) {
438       absl::StrAppend(&s, "\n    {type_attr=", input.type_attr->name);
439       if (input.default_dtype != DT_INVALID) {
440         absl::StrAppend(&s,
441                         ", default_dtype=", DataType_Name(input.default_dtype));
442       }
443       if (!input.tensor_params.empty()) {
444         absl::StrAppend(&s, ", tensor_params=[",
445                         absl::StrJoin(input.tensor_params, ", "), "]");
446       }
447       if (!input.tensor_list_params.empty()) {
448         absl::StrAppend(&s, ", tensor_list_params=[",
449                         absl::StrJoin(input.tensor_list_params, ", "), "]");
450       }
451       if (!input.ok_dtypes.empty()) {
452         absl::StrAppend(
453             &s, ", ok_dtypes=[",
454             absl::StrJoin(input.ok_dtypes, ", ", DataTypeFormatter()), "]");
455       }
456       absl::StrAppend(&s, "},");
457     }
458     absl::StrAppend(&s, "]\n");
459   }
460   if (!inputs_with_type_list_attrs_.empty()) {
461     absl::StrAppend(&s, "  inputs_with_type_list_attrs=[");
462     for (const auto& input : inputs_with_type_list_attrs_) {
463       absl::StrAppend(&s, "\n    {type_list_attr=", input.type_list_attr->name);
464       if (!input.default_dtypes.empty()) {
465         absl::StrAppend(
466             &s, ", default_dtypes=[",
467             absl::StrJoin(input.default_dtypes, ", ", DataTypeFormatter()),
468             "]");
469       }
470       if (!input.tensor_list_params.empty()) {
471         absl::StrAppend(&s, ", tensor_list_params=[",
472                         absl::StrJoin(input.tensor_list_params, ", "), "]");
473       }
474       if (!input.ok_dtypes.empty()) {
475         absl::StrAppend(
476             &s, ", ok_dtypes=[",
477             absl::StrJoin(input.ok_dtypes, ", ", DataTypeFormatter()), "]");
478       }
479       absl::StrAppend(&s, "},");
480     }
481     absl::StrAppend(&s, "]\n");
482   }
483   if (!inputs_with_number_attrs_.empty()) {
484     absl::StrAppend(&s, "  inputs_with_number_attrs=[");
485     for (const auto& input : inputs_with_number_attrs_) {
486       absl::StrAppend(&s, "\n    {number_attr=", input.number_attr->name,
487                       ", default_length=", input.default_length,
488                       ", tensor_list_params=[",
489                       absl::StrJoin(input.tensor_list_params, ", "), "],\n");
490     }
491     absl::StrAppend(&s, "]\n");
492   }
493   if (!inferred_type_attrs_.empty()) {
494     absl::StrAppend(&s, "  inferred_type_attrs=[",
495                     absl::StrJoin(inferred_type_attrs_, ", "), "]\n");
496   }
497   if (!inferred_type_list_attrs_.empty()) {
498     absl::StrAppend(&s, "  inferred_type_list_attrs=[",
499                     absl::StrJoin(inferred_type_list_attrs_, ", "), "]\n");
500   }
501   if (!inferred_length_attrs_.empty()) {
502     absl::StrAppend(&s, "  inferred_length_attrs=[",
503                     absl::StrJoin(inferred_length_attrs_, ", "), "]\n");
504   }
505   return s;
506 }
507 
508 }  // namespace tensorflow
509