xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/op_def_library_pybind.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <cmath>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "Python.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "pybind11/cast.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/python/framework/op_def_util.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
39 
40 namespace py = pybind11;
41 
42 namespace {
43 
44 using ::tensorflow::AttributeType;
45 using ::tensorflow::AttributeTypeFromName;
46 using ::tensorflow::AttrValue;
47 using ::tensorflow::CheckOpDeprecation;
48 using ::tensorflow::ConvertPyObjectToAttributeType;
49 using ::tensorflow::DataType;
50 using ::tensorflow::DataTypeToPyObject;
51 using ::tensorflow::MaybeRaiseFromStatus;
52 using ::tensorflow::OpDef;
53 using ::tensorflow::OpRegistry;
54 using ::tensorflow::protobuf::RepeatedField;
55 using ::tensorflow::protobuf::RepeatedPtrField;
56 using AttrDef = ::tensorflow::OpDef::AttrDef;
57 using ArgDef = ::tensorflow::OpDef::ArgDef;
58 // Keys: attr.name(); Values: attr_def.allowed_values().list().type()
59 using AllowedAttrMap =
60     absl::flat_hash_map<std::string, absl::flat_hash_set<int>>;
// Keys: attr.name(); Values: the Python DType for attr_def.default_value().type()
62 using DefaultAttrMap = absl::flat_hash_map<std::string, py::object>;
63 // Keys: attr.name(); Values: corresponding attr serialized as an AttrValue
64 using AttrProtosMap = absl::flat_hash_map<std::string, AttrValue>;
65 
// Attr-type strings as they appear in OpDef.AttrDef.type.
constexpr char kType[] = "type";
constexpr char kListPrefix[] = "list(";

// Python attribute names looked up on tensors / DType objects.
constexpr char kDType[] = "dtype";
constexpr char kBaseDType[] = "base_dtype";
constexpr char kTypeEnum[] = "_type_enum";  // Currently unreferenced here.

// Python method names invoked on converted values and on the kwargs dict.
constexpr char kAsProto[] = "as_proto";
constexpr char kSerialize[] = "SerializeToString";
constexpr char kPop[] = "pop";
74 
PyTypeError(const std::string & error_msg)75 inline py::error_already_set PyTypeError(const std::string& error_msg) {
76   PyErr_SetString(PyExc_TypeError, error_msg.c_str());
77   return pybind11::error_already_set();
78 }
79 
PyValueError(const std::string & error_msg)80 inline py::error_already_set PyValueError(const std::string& error_msg) {
81   PyErr_SetString(PyExc_ValueError, error_msg.c_str());
82   return pybind11::error_already_set();
83 }
84 
PyRepr(const py::handle & value)85 inline std::string PyRepr(const py::handle& value) {
86   return value.attr("__repr__")().cast<std::string>();
87 }
88 
DataTypeToPybindObject(const DataType & data_type)89 py::object DataTypeToPybindObject(const DataType& data_type) {
90   return py::reinterpret_borrow<py::object>(
91       DataTypeToPyObject(data_type).release());
92 }
93 
94 // Converts the py:object to the AttributeType.
95 // ToAttributeType corrupts the value's representation when it fails. So this
96 // should be stored before hand if it is needed for error msgs.
ToAttributeType(const py::handle & value,const AttributeType type)97 py::object ToAttributeType(const py::handle& value, const AttributeType type) {
98   auto result = ConvertPyObjectToAttributeType(value.ptr(), type);
99   if (result == nullptr) {
100     throw std::runtime_error("Failed to perform conversion.");
101   }
102   return py::reinterpret_borrow<py::object>(result.release());
103 }
104 
MakeBool(const py::handle & value,const std::string & arg_name)105 inline bool MakeBool(const py::handle& value, const std::string& arg_name) {
106   if (!py::isinstance<py::bool_>(value)) {
107     throw PyTypeError(
108         absl::StrCat("Expected bool for argument '", arg_name, "' not ",
109                      value.attr("__repr__")().cast<std::string>(), "."));
110   }
111   return value.cast<py::bool_>();
112 }
113 
MakeInt(const py::handle & value)114 inline int MakeInt(const py::handle& value) {
115   try {
116     // Needed for TF1 compatibility where a tf.Dimension may be passed in.
117     return value.attr("value").cast<float>();
118   } catch (...) {
119     return value.cast<float>();  // Cast to float to match Python's behaviour.
120   }
121 }
122 
MakeType(const py::handle & value,const std::string & arg_name)123 inline DataType MakeType(const py::handle& value, const std::string& arg_name) {
124   std::string repr_v = PyRepr(value);
125   try {
126     return ToAttributeType(value, AttributeType::DTYPE)
127         .attr(kBaseDType)
128         .cast<DataType>();
129   } catch (...) {
130     throw PyTypeError(absl::StrCat("Expected DataType for argument '", arg_name,
131                                    "' not ", repr_v, "."));
132   }
133 }
134 
MakeShape(const py::handle & value,const std::string & arg_name)135 inline std::string MakeShape(const py::handle& value,
136                              const std::string& arg_name) {
137   std::string repr_v = PyRepr(value);
138   try {
139     return ToAttributeType(value, AttributeType::SHAPE)
140         .attr(kAsProto)()
141         .attr(kSerialize)()
142         .cast<std::string>();
143   } catch (...) {
144     throw PyTypeError(absl::StrCat("Error converting ", repr_v, " (arg name = ",
145                                    arg_name, ") to a TensorShape"));
146   }
147 }
148 
ValueToAttrValue(const py::object & value,const std::string & attr_type,const std::string & arg_name)149 AttrValue ValueToAttrValue(const py::object& value,
150                            const std::string& attr_type,
151                            const std::string& arg_name) {
152   AttrValue attr_value;
153   if (absl::StartsWith(attr_type, kListPrefix)) {
154     if (!py::isinstance<py::list>(value) && !py::isinstance<py::tuple>(value)) {
155       throw PyTypeError(absl::StrCat(
156           "Expected list for attr ", arg_name, ", obtained ",
157           py::type::handle_of(value).attr("__name__").cast<std::string>(),
158           " instead."));
159     }
160   }
161 
162   try {
163     const AttributeType type_enum = AttributeTypeFromName(attr_type);
164     switch (type_enum) {
165       case AttributeType::STRING:
166         attr_value.set_s(value.cast<std::string>());
167         break;
168       case AttributeType::LIST_STRING: {
169         auto* list = attr_value.mutable_list();
170         for (const auto& v : value) {
171           list->add_s(v.cast<std::string>());
172         }
173         break;
174       }
175       case AttributeType::INT:
176         attr_value.set_i(MakeInt(value));
177         break;
178       case AttributeType::LIST_INT: {
179         auto* list = attr_value.mutable_list();
180         for (const auto& v : value) {
181           list->add_i(MakeInt(v));
182         }
183         break;
184       }
185       case AttributeType::FLOAT:
186         attr_value.set_f(value.cast<float>());
187         break;
188       case AttributeType::LIST_FLOAT: {
189         auto* list = attr_value.mutable_list();
190         for (const auto& v : value) {
191           list->add_f(v.cast<float>());
192         }
193         break;
194       }
195       case AttributeType::BOOL:
196         attr_value.set_b(MakeBool(value, arg_name));
197         break;
198       case AttributeType::LIST_BOOL: {
199         auto* list = attr_value.mutable_list();
200         for (const auto& v : value) {
201           list->add_b(MakeBool(v, arg_name));
202         }
203         break;
204       }
205       case AttributeType::DTYPE: {
206         attr_value.set_type(MakeType(value, arg_name));
207         break;
208       }
209       case AttributeType::LIST_DTYPE: {
210         auto* list = attr_value.mutable_list();
211         for (const auto& v : value) {
212           list->add_type(MakeType(v, arg_name));
213         }
214         break;
215       }
216       case AttributeType::SHAPE:
217         attr_value.mutable_shape()->ParseFromString(MakeShape(value, arg_name));
218         break;
219       case AttributeType::LIST_SHAPE: {
220         auto* list = attr_value.mutable_list();
221         for (const auto& v : value) {
222           list->add_shape()->ParseFromString(MakeShape(v, arg_name));
223         }
224         break;
225       }
226       case AttributeType::TENSOR:
227         attr_value.mutable_tensor()->ParseFromString(
228             ToAttributeType(value, type_enum)
229                 .attr(kSerialize)()
230                 .cast<std::string>());
231         break;
232       case AttributeType::LIST_TENSOR: {
233         auto* list = attr_value.mutable_list();
234         for (const auto& v : value) {
235           list->add_tensor()->ParseFromString(
236               ToAttributeType(v, AttributeType::TENSOR)
237                   .attr(kSerialize)()
238                   .cast<std::string>());
239         }
240         break;
241       }
242       default:
243         throw PyTypeError(absl::StrCat("Unrecognized Attr type ", attr_type,
244                                        " for ", arg_name, "."));
245     }
246   } catch (const py::error_already_set& e) {
247     throw e;
248   } catch (...) {
249     throw PyTypeError(absl::StrCat(
250         "Expected ", attr_type, " for argument '", arg_name, "' not ",
251         value.attr("__repr__")().cast<std::string>(), "."));
252   }
253 
254   return attr_value;
255 }
256 
AttrValueToSerializedBytesPyObject(const AttrValue & attr_value)257 py::object AttrValueToSerializedBytesPyObject(const AttrValue& attr_value) {
258   std::string serialized_attr_value;
259   if (!attr_value.SerializeToString(&serialized_attr_value)) {
260     throw std::runtime_error("Failed to serialized AttrValue to string");
261   }
262   return py::reinterpret_borrow<py::object>(py::bytes(serialized_attr_value));
263 }
264 
AssertSatisfiesLengthConstraint(const py::object & attr,const AttrDef & attr_def,const std::string & attr_name,const std::string & op_type_name)265 void AssertSatisfiesLengthConstraint(const py::object& attr,
266                                      const AttrDef& attr_def,
267                                      const std::string& attr_name,
268                                      const std::string& op_type_name) {
269   if (!absl::StartsWith(attr_def.type(), kListPrefix)) return;
270   int attr_size = attr.cast<py::list>().size();
271   if (attr_def.has_minimum() && attr_size < attr_def.minimum()) {
272     throw PyValueError(absl::StrCat("Attr '", attr_name, "' of '", op_type_name,
273                                     "' Op passed list of length ", attr_size,
274                                     " less than minimum ", attr_def.minimum(),
275                                     "."));
276   }
277 }
278 
AssertSatisfiesAllowedStringConstraint(const std::string & attr,const RepeatedPtrField<std::string> & allowed_values,const std::string & attr_name,const std::string & op_type_name)279 void AssertSatisfiesAllowedStringConstraint(
280     const std::string& attr,
281     const RepeatedPtrField<std::string>& allowed_values,
282     const std::string& attr_name, const std::string& op_type_name) {
283   if (!absl::c_linear_search(allowed_values, attr)) {
284     const std::string allowed_values_str =
285         absl::StrJoin(allowed_values, "\", \"");
286     throw PyValueError(absl::StrCat("Attr '", attr_name, "' of '", op_type_name,
287                                     "' Op passed string '", attr,
288                                     "' not in: \"", allowed_values_str, "\"."));
289   }
290 }
291 
AssertSatisfiesAllowedStringsConstraint(const AttrValue & attr,const AttrDef & attr_def,const std::string & attr_name,const AttributeType attr_type,const std::string & op_type_name)292 void AssertSatisfiesAllowedStringsConstraint(const AttrValue& attr,
293                                              const AttrDef& attr_def,
294                                              const std::string& attr_name,
295                                              const AttributeType attr_type,
296                                              const std::string& op_type_name) {
297   if (!attr_def.has_allowed_values()) return;
298   const auto& allowed_values = attr_def.allowed_values().list().s();
299   if (attr_type == AttributeType::STRING) {
300     AssertSatisfiesAllowedStringConstraint(attr.s(), allowed_values, attr_name,
301                                            op_type_name);
302   } else if (attr_type == AttributeType::LIST_STRING) {
303     for (const std::string& v : attr.list().s()) {
304       AssertSatisfiesAllowedStringConstraint(v, allowed_values, attr_name,
305                                              op_type_name);
306     }
307   }
308 }
309 
AssertSatisfiesIntMinimumConstraint(const AttrValue & attr,const AttrDef & attr_def,const std::string & attr_name,const AttributeType attr_type,const std::string & op_type_name)310 void AssertSatisfiesIntMinimumConstraint(const AttrValue& attr,
311                                          const AttrDef& attr_def,
312                                          const std::string& attr_name,
313                                          const AttributeType attr_type,
314                                          const std::string& op_type_name) {
315   if (attr_def.has_minimum() && attr_type == AttributeType::INT &&
316       attr.i() < attr_def.minimum()) {
317     throw PyValueError(absl::StrCat(
318         "Attr '", attr_name, "' of '", op_type_name, "' Op passed ", attr.i(),
319         " less than minimum ", attr_def.minimum(), "."));
320   }
321 }
322 
AssertSatisfiesAllowedListAttrTypeConstraint(const std::string & type_attr,const AllowedAttrMap & allowed_list_attr_map,const py::object & dtype,const std::string & input_name)323 void AssertSatisfiesAllowedListAttrTypeConstraint(
324     const std::string& type_attr, const AllowedAttrMap& allowed_list_attr_map,
325     const py::object& dtype, const std::string& input_name) {
326   auto it = allowed_list_attr_map.find(type_attr);
327   if (it != allowed_list_attr_map.end() &&
328       !it->second.contains(dtype.cast<DataType>())) {
329     std::vector<std::string> allowed_values;
330     for (const auto& allowed_value : it->second) {
331       allowed_values.emplace_back(
332           DataTypeToPybindObject(static_cast<DataType>(allowed_value))
333               .attr("name")
334               .cast<std::string>());
335     }
336     throw PyTypeError(absl::StrCat("Value passed to parameter '", input_name,
337                                    "' has DataType ",
338                                    dtype.attr("name").cast<std::string>(),
339                                    " not in list of allowed values: ",
340                                    absl::StrJoin(allowed_values, ", ")));
341   }
342 }
343 
AssertSatisfiesDTypeConstraint(const int attr,const RepeatedField<int> & allowed_values,const std::string & attr_name,const std::string & op_type_name)344 void AssertSatisfiesDTypeConstraint(const int attr,
345                                     const RepeatedField<int>& allowed_values,
346                                     const std::string& attr_name,
347                                     const std::string& op_type_name) {
348   if (!absl::c_linear_search(allowed_values, attr)) {
349     std::string allowed_vals_str;
350     for (const auto& v : allowed_values) {
351       if (!allowed_vals_str.empty()) absl::StrAppend(&allowed_vals_str, ", ");
352       absl::StrAppend(&allowed_vals_str,
353                       DataTypeToPybindObject(static_cast<DataType>(v))
354                           .attr("name")
355                           .cast<std::string>());
356     }
357     throw PyTypeError(absl::StrCat(
358         "Value passed to parameter '", attr_name, "' has DataType ",
359         DataTypeToPybindObject(static_cast<DataType>(attr))
360             .attr("name")
361             .cast<std::string>(),
362         " not in list of allowed values: ", allowed_vals_str));
363   }
364 }
365 
AssertSatisfiesTypeConstraint(const AttrValue & attr,const AttrDef & attr_def,const std::string & attr_name,const AttributeType attr_type,const std::string & op_type_name)366 void AssertSatisfiesTypeConstraint(const AttrValue& attr,
367                                    const AttrDef& attr_def,
368                                    const std::string& attr_name,
369                                    const AttributeType attr_type,
370                                    const std::string& op_type_name) {
371   if (!attr_def.has_allowed_values()) return;
372   const auto& allowed_values = attr_def.allowed_values().list().type();
373   if (attr_type == AttributeType::DTYPE) {
374     AssertSatisfiesDTypeConstraint(attr.type(), allowed_values, attr_name,
375                                    op_type_name);
376   } else if (attr_type == AttributeType::LIST_DTYPE) {
377     for (const auto& v : attr.list().type()) {
378       AssertSatisfiesDTypeConstraint(v, allowed_values, attr_name,
379                                      op_type_name);
380     }
381   }
382 }
383 
384 // Returns the OpDef from the global registry. Raises runtime_error if the
385 // OpDef is not found.
GetOpDef(const std::string & op_type_name,int producer_version)386 const OpDef* GetOpDef(const std::string& op_type_name, int producer_version) {
387   const OpDef* op_def = nullptr;
388   auto status = OpRegistry::Global()->LookUpOpDef(op_type_name, &op_def);
389   if (!status.ok() || op_def == nullptr) {
390     throw std::runtime_error(
391         absl::StrCat("Unrecognized Op name ", op_type_name));
392   }
393   return op_def;
394 }
395 
396 // Extracts the default_type_attr_map and the allowed_list_attr_map from the
397 // OpDef.
ExtractDefaultTypesAndAllowedTypes(const OpDef & op_def,DefaultAttrMap & default_type_attr_map,AllowedAttrMap & allowed_list_attr_map)398 void ExtractDefaultTypesAndAllowedTypes(const OpDef& op_def,
399                                         DefaultAttrMap& default_type_attr_map,
400                                         AllowedAttrMap& allowed_list_attr_map) {
401   for (const AttrDef& attr_def : op_def.attr()) {
402     if (attr_def.type() != kType) continue;
403     const std::string& attr_name = attr_def.name();
404     if (attr_def.has_default_value()) {
405       default_type_attr_map[attr_name] =
406           DataTypeToPybindObject(attr_def.default_value().type());
407     }
408     if (attr_def.has_allowed_values()) {
409       const auto& types = attr_def.allowed_values().list().type();
410       absl::flat_hash_set<int> allowed_values(types.begin(), types.end());
411       allowed_list_attr_map[attr_name] = std::move(allowed_values);
412     }
413   }
414 }
415 
416 // Returns the input Tensor corresponding to `input_name` from `keywords`.
417 // Updates `input_name` if it is a Python keyword or built-in.
GetInputTensor(std::string & input_name,const py::dict & keywords,const OpDef & op_def)418 py::object GetInputTensor(std::string& input_name, const py::dict& keywords,
419                           const OpDef& op_def) {
420   if (keywords.contains(input_name)) {
421     return py::reinterpret_borrow<py::object>(
422         keywords.attr(kPop)(input_name.c_str()));
423   } else if (keywords.contains(absl::StrCat(input_name, "_"))) {
424     absl::StrAppend(&input_name, "_");
425     return py::reinterpret_borrow<py::object>(
426         keywords.attr(kPop)(input_name.c_str()));
427   } else {
428     throw PyTypeError(absl::StrCat("No argument for input ", input_name,
429                                    " found in ", op_def.DebugString()));
430   }
431 }
432 
433 // Returns the input Tensor's DType.
GetInputType(const py::object & input_tensor,const ArgDef & input_arg,const AllowedAttrMap & allowed_list_attr_map,const std::string & op_type_name,const std::string & input_name,py::dict & attrs,absl::flat_hash_map<std::string,std::string> & inferred_from)434 py::object GetInputType(
435     const py::object& input_tensor, const ArgDef& input_arg,
436     const AllowedAttrMap& allowed_list_attr_map,
437     const std::string& op_type_name, const std::string& input_name,
438     py::dict& attrs,
439     absl::flat_hash_map<std::string, std::string>& inferred_from) {
440   py::object dtype = input_tensor.attr(kDType);
441   py::object base_type = dtype.attr(kBaseDType);
442 
443   // Check that the input_arg and the input are compatible.
444   if (input_arg.type() != DataType::DT_INVALID &&
445       input_arg.type() != dtype.cast<DataType>() &&
446       input_arg.type() != base_type.cast<DataType>()) {
447     throw PyTypeError(absl::StrCat("Input '", input_name, "' of '",
448                                    op_type_name, "' Op has type ",
449                                    base_type.attr("name").cast<std::string>(),
450                                    " that does not match expected type of ",
451                                    DataTypeToPybindObject(input_arg.type())
452                                        .attr("name")
453                                        .cast<std::string>(),
454                                    "."));
455   }
456 
457   const std::string& type_attr = input_arg.type_attr();
458   if (!type_attr.empty()) {
459     if (attrs.contains(type_attr) &&
460         attrs[type_attr.c_str()].cast<py::object>() != base_type) {
461       throw PyTypeError(absl::StrCat(
462           "Input '", input_name, "' of '", op_type_name, "' Op has type ",
463           base_type.attr("name").cast<std::string>(),
464           " that does not match type ",
465           attrs[type_attr.c_str()].attr("name").cast<std::string>(),
466           " of argument '", inferred_from.at(type_attr), "'."));
467     } else {
468       AssertSatisfiesAllowedListAttrTypeConstraint(
469           type_attr, allowed_list_attr_map, base_type, input_name);
470       attrs[type_attr.c_str()] = base_type;
471       inferred_from[input_arg.type_attr()] = input_name;
472     }
473   } else if (base_type.cast<DataType>() != input_arg.type()) {
474     // Added to match the python behaviour.
475     throw PyTypeError("Unreachable");
476   }
477   if (input_arg.is_ref()) return dtype;
478   return base_type;
479 }
480 
481 // Extracts `inputs`, `input_types` and `attrs`.
ExtractInputsAndAttrs(const std::string & op_type_name,const OpDef & op_def,const AllowedAttrMap & allowed_list_attr_map,py::dict & keywords,py::dict & attrs,py::list & inputs,py::list & input_types)482 void ExtractInputsAndAttrs(const std::string& op_type_name, const OpDef& op_def,
483                            const AllowedAttrMap& allowed_list_attr_map,
484                            py::dict& keywords, py::dict& attrs,
485                            py::list& inputs, py::list& input_types) {
486   absl::flat_hash_map<std::string, std::string> inferred_from;
487   for (const ArgDef& input_arg : op_def.input_arg()) {
488     std::string input_name = input_arg.name();
489     py::object input_tensor = GetInputTensor(input_name, keywords, op_def);
490     inputs.append(input_tensor);
491     py::object dtype =
492         GetInputType(input_tensor, input_arg, allowed_list_attr_map,
493                      op_type_name, input_name, attrs, inferred_from);
494     input_types.append(dtype);
495   }
496 }
497 
498 // Extracts the remaining attributes from the OpDef to `attrs`.
ExtractRemainingAttrs(const std::string & op_type_name,const OpDef & op_def,const py::dict & keywords,const DefaultAttrMap & default_type_attr_map,py::dict & attrs)499 void ExtractRemainingAttrs(const std::string& op_type_name, const OpDef& op_def,
500                            const py::dict& keywords,
501                            const DefaultAttrMap& default_type_attr_map,
502                            py::dict& attrs) {
503   for (const AttrDef& attr : op_def.attr()) {
504     const std::string& attr_name = attr.name();
505     if (attrs.contains(attr_name)) {
506       if (keywords.contains(attr_name)) {
507         throw PyTypeError(
508             absl::StrCat("Should not specify value for inferred attr '",
509                          attr_name, "' for ", op_type_name, "."));
510       }
511       continue;
512     }
513     if (keywords.contains(attr_name)) {
514       attrs[attr_name.c_str()] =
515           keywords.attr(kPop)(attr_name.c_str()).cast<py::object>();
516     } else if (keywords.contains(absl::StrCat(attr_name, "_"))) {
517       attrs[attr_name.c_str()] =
518           keywords.attr(kPop)(absl::StrCat(attr_name, "_").c_str())
519               .cast<py::object>();
520     } else if (default_type_attr_map.contains(attr_name)) {
521       attrs[attr_name.c_str()] = default_type_attr_map.at(attr_name);
522     } else {
523       throw PyTypeError(absl::StrCat("No argument found for attr ", attr_name,
524                                      " for ", op_type_name));
525     }
526   }
527 }
528 
SetAttrProto(const std::string & key,const AttrValue & value,py::dict & attr_protos,AttrProtosMap & attr_protos_map)529 void SetAttrProto(const std::string& key, const AttrValue& value,
530                   py::dict& attr_protos, AttrProtosMap& attr_protos_map) {
531   attr_protos_map[key] = value;
532   attr_protos[key.c_str()] = AttrValueToSerializedBytesPyObject(value);
533 }
534 
535 // Converts attr values to AttrValues.
ExtractAttrProto(const std::string & op_type_name,const OpDef & op_def,const py::dict & attrs,py::dict & attr_protos,AttrProtosMap & attr_protos_map)536 void ExtractAttrProto(const std::string& op_type_name, const OpDef& op_def,
537                       const py::dict& attrs, py::dict& attr_protos,
538                       AttrProtosMap& attr_protos_map) {
539   for (const AttrDef& attr_def : op_def.attr()) {
540     const std::string& attr_name = attr_def.name();
541     const py::object attr = attrs[attr_name.c_str()].cast<py::object>();
542 
543     if (attr_def.has_default_value() && attr.is_none()) {
544       SetAttrProto(attr_name, attr_def.default_value(), attr_protos,
545                    attr_protos_map);
546       continue;
547     }
548 
549     const AttrValue attr_value =
550         ValueToAttrValue(attr, attr_def.type(), attr_name);
551     const AttributeType attr_type = AttributeTypeFromName(attr_def.type());
552     AssertSatisfiesLengthConstraint(attr, attr_def, attr_name, op_type_name);
553     AssertSatisfiesAllowedStringsConstraint(attr_value, attr_def, attr_name,
554                                             attr_type, op_type_name);
555     AssertSatisfiesIntMinimumConstraint(attr_value, attr_def, attr_name,
556                                         attr_type, op_type_name);
557     AssertSatisfiesTypeConstraint(attr_value, attr_def, attr_name, attr_type,
558                                   op_type_name);
559     SetAttrProto(attr_name, attr_value, attr_protos, attr_protos_map);
560   }
561 }
562 
MaybeGetAttrValue(const py::dict & attr_protos,const AttrProtosMap & attr_protos_map,const std::string & attr_name,const std::string & op_type_name)563 inline const AttrValue& MaybeGetAttrValue(const py::dict& attr_protos,
564                                           const AttrProtosMap& attr_protos_map,
565                                           const std::string& attr_name,
566                                           const std::string& op_type_name) {
567   auto it = attr_protos_map.find(attr_name);
568   if (it != attr_protos_map.end()) return it->second;
569   throw PyTypeError(absl::StrCat(
570       "Inconsistent OpDef for '", op_type_name, "', missing attr '", attr_name,
571       "' from '", attr_protos.attr("__repr__")().cast<std::string>(), "'."));
572 }
573 
ExtractOutputStructure(const std::string & op_type_name,const OpDef & op_def,const py::dict & attr_protos,const AttrProtosMap & attr_protos_map,py::list & output_structure)574 void ExtractOutputStructure(const std::string& op_type_name,
575                             const OpDef& op_def, const py::dict& attr_protos,
576                             const AttrProtosMap& attr_protos_map,
577                             py::list& output_structure) {
578   for (const ArgDef& arg : op_def.output_arg()) {
579     if (!arg.number_attr().empty()) {
580       const auto& value = MaybeGetAttrValue(attr_protos, attr_protos_map,
581                                             arg.number_attr(), op_type_name);
582       output_structure.append(value.i());
583     } else if (!arg.type_attr().empty()) {
584       const auto& _ = MaybeGetAttrValue(attr_protos, attr_protos_map,
585                                         arg.type_attr(), op_type_name);
586       output_structure.append(py::none());
587     } else if (!arg.type_list_attr().empty()) {
588       const auto& value = MaybeGetAttrValue(attr_protos, attr_protos_map,
589                                             arg.type_list_attr(), op_type_name);
590       output_structure.append(value.list().type_size());
591     } else {
592       output_structure.append(py::none());
593     }
594   }
595 }
596 
CheckAllInputsUsed(const std::string & op_type_name,const py::dict & keywords)597 void CheckAllInputsUsed(const std::string& op_type_name,
598                         const py::dict& keywords) {
599   if (!keywords.empty()) {
600     std::string all_keywords;
601     for (const auto& item : keywords) {
602       if (!all_keywords.empty()) absl::StrAppend(&all_keywords, ", ");
603       absl::StrAppend(&all_keywords, item.first.cast<std::string>());
604     }
605     throw PyTypeError(absl::StrCat(
606         op_type_name, " got unexpected keyword arguments: ", all_keywords));
607   }
608 }
609 
610 }  // namespace
611 
612 // This module provides a subset of the functionality from op_def_library.py
613 // and relies on op_def_library_test.py for test coverage.
PYBIND11_MODULE(_op_def_library_pybind, m) {
  // process_inputs(op_type_name, producer_version, keywords) returns the tuple
  // (attr_protos, inputs, input_types, output_structure). Every entry in
  // `keywords` that is matched to an input arg is treated as a tf.Tensor.
  m.def("process_inputs", [](std::string& op_type_name, int producer_version,
                             py::dict& keywords) {
    const OpDef* const op_def = GetOpDef(op_type_name, producer_version);
    MaybeRaiseFromStatus(CheckOpDeprecation(*op_def, producer_version));

    DefaultAttrMap default_type_attr_map;
    AllowedAttrMap allowed_list_attr_map;
    ExtractDefaultTypesAndAllowedTypes(*op_def, default_type_attr_map,
                                       allowed_list_attr_map);

    py::dict attrs;
    py::list inputs;
    py::list input_types;
    ExtractInputsAndAttrs(op_type_name, *op_def, allowed_list_attr_map,
                          keywords, attrs, inputs, input_types);
    ExtractRemainingAttrs(op_type_name, *op_def, keywords,
                          default_type_attr_map, attrs);

    py::dict attr_protos;
    AttrProtosMap attr_protos_map;
    ExtractAttrProto(op_type_name, *op_def, attrs, attr_protos,
                     attr_protos_map);

    py::list output_structure;
    ExtractOutputStructure(op_type_name, *op_def, attr_protos, attr_protos_map,
                           output_structure);
    CheckAllInputsUsed(op_type_name, keywords);

    return py::make_tuple(attr_protos, inputs, input_types, output_structure);
  });
}
642