xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/op_def_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/framework/op_def_util.h"
16 
17 #include <map>
18 
19 #include "absl/strings/str_cat.h"
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/tensor_shape.pb.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
24 #include "tensorflow/python/util/util.h"
25 
26 using ::tensorflow::swig::GetRegisteredPyObject;
27 
// Compatibility shims so the conversion code below can be written once for
// both the Python 2 and Python 3 C APIs.
#if PY_MAJOR_VERSION >= 3
// Python 3.x: both bytes and str count as strings; ints are PyLong objects.
#define PY_STRING_CHECK(x) (PyBytes_Check(x) || PyUnicode_Check(x))
#define PY_STRING_FROMSTRING(x) (PyUnicode_FromString(x))
#define PY_INT_CHECK(x) (PyLong_Check(x))
#define PY_INT_TYPE PyLong_Type
#define PY_INT_FROM_LONG(x) (PyLong_FromLong(x))
#else
// Python 2.x: both str and unicode count as strings; ints are PyInt objects.
#define PY_STRING_CHECK(x) (PyString_Check(x) || PyUnicode_Check(x))
#define PY_STRING_FROMSTRING(x) (PyString_FromString(x))
#define PY_INT_CHECK(x) (PyInt_Check(x))
#define PY_INT_TYPE PyInt_Type
#define PY_INT_FROM_LONG(x) (PyInt_FromLong(x))
#endif  // PY_MAJOR_VERSION >= 3
43 
44 namespace tensorflow {
45 
46 namespace {
47 
AttributeTypeNameMap()48 const std::map<std::string, AttributeType>* AttributeTypeNameMap() {
49   static auto* type_map = new std::map<std::string, AttributeType>(
50       {{"any", AttributeType::ANY},
51        {"float", AttributeType::FLOAT},
52        {"int", AttributeType::INT},
53        {"string", AttributeType::STRING},
54        {"bool", AttributeType::BOOL},
55        {"shape", AttributeType::SHAPE},
56        {"type", AttributeType::DTYPE},
57        {"tensor", AttributeType::TENSOR},
58        {"list(any)", AttributeType::LIST_ANY},
59        {"list(float)", AttributeType::LIST_FLOAT},
60        {"list(int)", AttributeType::LIST_INT},
61        {"list(string)", AttributeType::LIST_STRING},
62        {"list(bool)", AttributeType::LIST_BOOL},
63        {"list(type)", AttributeType::LIST_DTYPE},
64        {"list(shape)", AttributeType::LIST_SHAPE},
65        {"list(tensor)", AttributeType::LIST_TENSOR}});
66   return type_map;
67 }
68 
// Note: we define functors for converting value types (rather than simple
// functions) so that a single generic ConvertListAttr function template can
// convert list attributes with any of them.  These functors all return a new
// reference on success, or nullptr on failure.  On failure they may or may not
// have set a Python error (they do not necessarily call PyErr_SetString).
73 
74 struct ConvertAnyFunctor {
operator ()tensorflow::__anon22fe53b70111::ConvertAnyFunctor75   Safe_PyObjectPtr operator()(PyObject* value) {
76     Py_INCREF(value);
77     return Safe_PyObjectPtr(value);
78   }
79 };
80 
81 struct ConvertFloatFunctor {
operator ()tensorflow::__anon22fe53b70111::ConvertFloatFunctor82   Safe_PyObjectPtr operator()(PyObject* value) {
83     Safe_PyObjectPtr result;
84     if (PyFloat_Check(value)) {
85       Py_INCREF(value);
86       result.reset(value);
87     } else if (!PY_STRING_CHECK(value)) {
88       result.reset(PyObject_CallFunctionObjArgs(
89           reinterpret_cast<PyObject*>(&PyFloat_Type), value, nullptr));
90     }
91     return result;
92   }
93 };
94 
95 struct ConvertIntFunctor {
operator ()tensorflow::__anon22fe53b70111::ConvertIntFunctor96   Safe_PyObjectPtr operator()(PyObject* value) {
97     Safe_PyObjectPtr result;
98     if (PY_INT_CHECK(value)) {
99       Py_INCREF(value);
100       result.reset(value);
101     } else if (!PY_STRING_CHECK(value)) {
102       result.reset(PyObject_CallFunctionObjArgs(
103           reinterpret_cast<PyObject*>(&PY_INT_TYPE), value, nullptr));
104     }
105     return result;
106   }
107 };
108 
109 struct ConvertStringFunctor {
operator ()tensorflow::__anon22fe53b70111::ConvertStringFunctor110   Safe_PyObjectPtr operator()(PyObject* value) {
111     Safe_PyObjectPtr result;
112     if (PY_STRING_CHECK(value)) {
113       Py_INCREF(value);
114       result.reset(value);
115     }
116     return result;
117   }
118 };
119 
120 // TODO(edloper): Should we allow ints (or any other values) to be converted
121 // to booleans?  Currently, TensorFlow does not do this conversion for attribute
122 // values in _MakeBool or make_bool.
123 struct ConvertBoolFunctor {
operator ()tensorflow::__anon22fe53b70111::ConvertBoolFunctor124   Safe_PyObjectPtr operator()(PyObject* value) {
125     Safe_PyObjectPtr result;
126     if (PyBool_Check(value)) {
127       Py_INCREF(value);
128       result.reset(value);
129     }
130     return result;
131   }
132 };
133 
134 struct ConvertDTypeFunctor {
operator ()tensorflow::__anon22fe53b70111::ConvertDTypeFunctor135   Safe_PyObjectPtr operator()(PyObject* value) {
136     Safe_PyObjectPtr result;
137     // The following symbols are registered in op_def_library.py
138     static PyObject* dtype = GetRegisteredPyObject("tf.dtypes.DType");
139     static PyObject* as_dtype = GetRegisteredPyObject("tf.dtypes.as_dtype");
140     if (reinterpret_cast<PyObject*>(value->ob_type) == dtype) {
141       Py_INCREF(value);
142       result.reset(value);
143     } else {
144       result.reset(PyObject_CallFunctionObjArgs(as_dtype, value, nullptr));
145     }
146     return result;
147   }
148 };
149 
150 struct ConvertTensorShapeFunctor {
operator ()tensorflow::__anon22fe53b70111::ConvertTensorShapeFunctor151   Safe_PyObjectPtr operator()(PyObject* value) {
152     Safe_PyObjectPtr result;
153     // The following symbols are registered in op_def_library.py
154     static PyObject* shape = GetRegisteredPyObject("tf.TensorShape");
155     static PyObject* as_shape = GetRegisteredPyObject("tf.as_shape");
156     if (reinterpret_cast<PyObject*>(value->ob_type) == shape) {
157       Py_INCREF(value);
158       result.reset(value);
159     } else {
160       result.reset(PyObject_CallFunctionObjArgs(as_shape, value, nullptr));
161     }
162     return result;
163   }
164 };
165 
166 struct ConvertTensorProtoFunctor {
operator ()tensorflow::__anon22fe53b70111::ConvertTensorProtoFunctor167   Safe_PyObjectPtr operator()(PyObject* value) {
168     Safe_PyObjectPtr result;
169     // The following symbols are registered in op_def_library.py
170     static PyObject* tensor_proto = GetRegisteredPyObject("tf.TensorProto");
171     static PyObject* text_format_parse =
172         GetRegisteredPyObject("text_format.Parse");
173     if (reinterpret_cast<PyObject*>(value->ob_type) == tensor_proto) {
174       Py_INCREF(value);
175       result.reset(value);
176     } else if (PY_STRING_CHECK(value)) {
177       result.reset(PyObject_CallObject(tensor_proto, nullptr));
178       if (result) {
179         if (!PyObject_CallFunctionObjArgs(text_format_parse, value,
180                                           result.get(), nullptr)) {
181           return nullptr;
182         }
183       }
184     }
185     return result;
186   }
187 };
188 
189 // Converts `value` to a list of elements with the same type, using
190 // `convert_functor` to convert each element.
191 template <typename T>
ConvertListAttr(PyObject * value,T convert_functor)192 Safe_PyObjectPtr ConvertListAttr(PyObject* value, T convert_functor) {
193   // Copy the list.
194   Safe_PyObjectPtr result(PySequence_List(value));
195   if (!result) return nullptr;
196 
197   // Check the type of each item in the list.
198   Py_ssize_t len = PySequence_Fast_GET_SIZE(result.get());
199   PyObject** items = PySequence_Fast_ITEMS(result.get());
200   for (Py_ssize_t i = 0; i < len; ++i) {
201     if (!PyFloat_Check(value)) {
202       Safe_PyObjectPtr item = convert_functor(items[i]);
203       if (!item) return nullptr;
204       PySequence_SetItem(result.get(), i, item.get());
205     }
206   }
207   return result;
208 }
209 
210 // Returns the given `value` value, converted to the indicated type.
211 // Returns nullptr if `value` is not convertible.
ConvertAttrOrNull(PyObject * value,AttributeType attr_type)212 Safe_PyObjectPtr ConvertAttrOrNull(PyObject* value, AttributeType attr_type) {
213   switch (attr_type) {
214     case AttributeType::ANY:
215       return ConvertAnyFunctor()(value);
216     case AttributeType::FLOAT:
217       return ConvertFloatFunctor()(value);
218     case AttributeType::INT:
219       return ConvertIntFunctor()(value);
220     case AttributeType::STRING:
221       return ConvertStringFunctor()(value);
222     case AttributeType::BOOL:
223       return ConvertBoolFunctor()(value);
224     case AttributeType::DTYPE:
225       return ConvertDTypeFunctor()(value);
226     case AttributeType::SHAPE:
227       return ConvertTensorShapeFunctor()(value);
228     case AttributeType::TENSOR:
229       return ConvertTensorProtoFunctor()(value);
230     case AttributeType::LIST_ANY:
231       return ConvertListAttr(value, ConvertAnyFunctor());
232     case AttributeType::LIST_FLOAT:
233       return ConvertListAttr(value, ConvertFloatFunctor());
234     case AttributeType::LIST_INT:
235       return ConvertListAttr(value, ConvertIntFunctor());
236     case AttributeType::LIST_STRING:
237       return ConvertListAttr(value, ConvertStringFunctor());
238     case AttributeType::LIST_BOOL:
239       return ConvertListAttr(value, ConvertBoolFunctor());
240     case AttributeType::LIST_DTYPE:
241       return ConvertListAttr(value, ConvertDTypeFunctor());
242     case AttributeType::LIST_SHAPE:
243       return ConvertListAttr(value, ConvertTensorShapeFunctor());
244     case AttributeType::LIST_TENSOR:
245       return ConvertListAttr(value, ConvertTensorProtoFunctor());
246     default:
247       return nullptr;
248   }
249 }
250 
251 // Returns a new reference to Py_True or Py_False depending on b.
PyBool_FromBool(bool b)252 PyObject* PyBool_FromBool(bool b) {
253   PyObject* result = b ? Py_True : Py_False;
254   Py_INCREF(result);
255   return result;
256 }
257 
AttrValueListToPyObject(AttrValue::ListValue list)258 Safe_PyObjectPtr AttrValueListToPyObject(AttrValue::ListValue list) {
259   if (list.s_size()) {
260     Safe_PyObjectPtr result(PyList_New(list.s_size()));
261     for (int i = 0; i < list.s_size(); ++i) {
262       PyList_SET_ITEM(result.get(), i, PY_STRING_FROMSTRING(list.s(i).c_str()));
263     }
264     return result;
265   } else if (list.i_size()) {
266     Safe_PyObjectPtr result(PyList_New(list.i_size()));
267     for (int i = 0; i < list.i_size(); ++i) {
268       PyList_SET_ITEM(result.get(), i, PY_INT_FROM_LONG(list.i(i)));
269     }
270     return result;
271   } else if (list.f_size()) {
272     Safe_PyObjectPtr result(PyList_New(list.f_size()));
273     for (int i = 0; i < list.f_size(); ++i) {
274       PyList_SET_ITEM(result.get(), i, PyFloat_FromDouble(list.f(i)));
275     }
276     return result;
277   } else if (list.b_size()) {
278     Safe_PyObjectPtr result(PyList_New(list.b_size()));
279     for (int i = 0; i < list.b_size(); ++i) {
280       PyList_SET_ITEM(result.get(), i, PyBool_FromBool(list.b(i)));
281     }
282     return result;
283   } else if (list.type_size()) {
284     Safe_PyObjectPtr result(PyList_New(list.type_size()));
285     for (int i = 0; i < list.type_size(); ++i) {
286       Safe_PyObjectPtr item(DataTypeToPyObject(list.type(i)));
287       Py_INCREF(item.get());
288       PyList_SET_ITEM(result.get(), i, item.get());
289     }
290     return result;
291   } else if (list.shape_size()) {
292     Safe_PyObjectPtr result(PyList_New(list.shape_size()));
293     for (int i = 0; i < list.shape_size(); ++i) {
294       Safe_PyObjectPtr item(TensorShapeProtoToPyObject(list.shape(i)));
295       Py_INCREF(item.get());
296       PyList_SET_ITEM(result.get(), i, item.get());
297     }
298     return result;
299   } else if (list.tensor_size() || list.func_size()) {
300     // TODO(edloper): Add support for tensorflow::AttrValue::kTensor.
301     PyErr_SetString(PyExc_TypeError, "Unsupported AttrValue type");
302     return nullptr;
303   } else {
304     // Empty list
305     return Safe_PyObjectPtr(PyList_New(0));
306   }
307 }
308 
309 }  // namespace
310 
AttributeTypeFromName(const std::string & type_name)311 AttributeType AttributeTypeFromName(const std::string& type_name) {
312   const auto* type_map = AttributeTypeNameMap();
313   auto it = type_map->find(type_name);
314   return it != type_map->end() ? it->second : AttributeType::UNKNOWN;
315 }
316 
AttributeTypeToName(AttributeType attr_type)317 std::string AttributeTypeToName(AttributeType attr_type) {
318   for (const auto& pair : *AttributeTypeNameMap()) {
319     if (pair.second == attr_type) {
320       return pair.first;
321     }
322   }
323   return "<unknown>";
324 }
325 
ConvertPyObjectToAttributeType(PyObject * value,AttributeType type)326 Safe_PyObjectPtr ConvertPyObjectToAttributeType(PyObject* value,
327                                                 AttributeType type) {
328   Safe_PyObjectPtr result = ConvertAttrOrNull(value, type);
329   if (!result) {
330     auto err = absl::StrCat("Failed to convert value of type '",
331                             value->ob_type->tp_name, "' to type '",
332                             AttributeTypeToName(type), "'.");
333     PyErr_SetString(PyExc_TypeError, err.c_str());
334   }
335 
336   return result;
337 }
338 
AttrValueToPyObject(const AttrValue & attr_value)339 Safe_PyObjectPtr AttrValueToPyObject(const AttrValue& attr_value) {
340   switch (attr_value.value_case()) {
341     case tensorflow::AttrValue::kS:
342       return Safe_PyObjectPtr(PY_STRING_FROMSTRING(attr_value.s().c_str()));
343     case tensorflow::AttrValue::kI:
344       return Safe_PyObjectPtr(PY_INT_FROM_LONG(attr_value.i()));
345     case tensorflow::AttrValue::kF:
346       return Safe_PyObjectPtr(PyFloat_FromDouble(attr_value.f()));
347     case tensorflow::AttrValue::kB:
348       return Safe_PyObjectPtr(PyBool_FromBool(attr_value.b()));
349     case tensorflow::AttrValue::kType:
350       return DataTypeToPyObject(attr_value.type());
351     case tensorflow::AttrValue::kShape:
352       return TensorShapeProtoToPyObject(attr_value.shape());
353     case tensorflow::AttrValue::kList:
354       return AttrValueListToPyObject(attr_value.list());
355     default:
356       // TODO(edloper): Add support for tensorflow::AttrValue::kTensor.
357       PyErr_SetString(PyExc_ValueError, "Unsupported AttrValue type");
358       return nullptr;
359   }
360 }
361 
DataTypeToPyObject(const DataType & data_type)362 Safe_PyObjectPtr DataTypeToPyObject(const DataType& data_type) {
363   Safe_PyObjectPtr enum_value(PY_INT_FROM_LONG(data_type));
364   return ConvertDTypeFunctor()(enum_value.get());
365 }
366 
TensorShapeProtoToPyObject(const TensorShapeProto & tensor_shape)367 Safe_PyObjectPtr TensorShapeProtoToPyObject(
368     const TensorShapeProto& tensor_shape) {
369   if (tensor_shape.unknown_rank()) {
370     return ConvertTensorShapeFunctor()(Py_None);
371   } else {
372     Safe_PyObjectPtr dims(PyTuple_New(tensor_shape.dim_size()));
373     for (int i = 0; i < tensor_shape.dim_size(); ++i) {
374       PyTuple_SET_ITEM(dims.get(), i,
375                        PY_INT_FROM_LONG(tensor_shape.dim(i).size()));
376     }
377     return ConvertTensorShapeFunctor()(dims.get());
378   }
379 }
380 
381 }  // namespace tensorflow
382