xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/python_api_parameter_converter.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/framework/python_api_parameter_converter.h"
16 
17 #include "absl/strings/str_cat.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/lib/gtl/map_util.h"
20 #include "tensorflow/python/eager/pywrap_tensor.h"
21 #include "tensorflow/python/framework/op_def_util.h"
22 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
23 #include "tensorflow/python/util/util.h"
24 
25 #if PY_MAJOR_VERSION < 3
26 // Python 2.x:
27 #define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
28 #define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
29 #else
30 // Python 3.x:
31 #define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
32 #define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
33 #endif
34 
35 // Evaluate `condition`, and if it returns false then return false.
36 #define RETURN_IF_FALSE(condition)  \
37   do {                              \
38     if (!(condition)) return false; \
39   } while (0)
40 
41 #define PyList_ITEMS(o) (((PyListObject*)(o))->ob_item)
42 
43 namespace tensorflow {
44 
45 using InferredAttributes = PythonAPIInfo::InferredAttributes;
46 using ParamIndex = PythonAPIInfo::ParamIndex;
47 using Attribute = PythonAPIInfo::Attribute;
48 using InputWithFixedDType = PythonAPIInfo::InputWithFixedDType;
49 using InputsWithTypeAttr = PythonAPIInfo::InputsWithTypeAttr;
50 using InputsWithTypeListAttr = PythonAPIInfo::InputsWithTypeListAttr;
51 
52 namespace {
53 
54 // Returns `dtype._type_enum`.
GetAttr_TypeEnum(PyObject * dtype)55 Safe_PyObjectPtr GetAttr_TypeEnum(PyObject* dtype) {
56   static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_type_enum");
57   return Safe_PyObjectPtr(PyObject_GetAttr(dtype, attr));
58 }
59 
60 // Returns `tensor.dtype`.
GetAttr_DType(PyObject * tensor)61 Safe_PyObjectPtr GetAttr_DType(PyObject* tensor) {
62   static PyObject* attr = PY_STRING_INTERN_FROM_STRING("dtype");
63   return Safe_PyObjectPtr(PyObject_GetAttr(tensor, attr));
64 }
65 
66 // Raises a TypeError with a message constructed by applying StrCat to the
67 // specified strings.  If an exception has already been set when this function
68 // is called, then add its message as a suffix to the message string.
69 template <typename... Args>
RaiseTypeError(Args...args)70 void RaiseTypeError(Args... args) {
71   string message = absl::StrCat(args...);
72   if (!PyErr_Occurred()) {
73     PyErr_SetString(PyExc_TypeError, message.c_str());
74   } else {
75     PyObject* exc_type;
76     PyObject* exc_value;
77     PyObject* exc_traceback;
78     PyErr_Fetch(&exc_type, &exc_value, &exc_traceback);
79     PyErr_Format(PyExc_TypeError, "%s: %S", message.c_str(), exc_value);
80     Py_XDECREF(exc_type);
81     Py_XDECREF(exc_value);
82     Py_XDECREF(exc_traceback);
83   }
84 }
85 
86 // Returns the DataType for a `tf.dtypes.DType` object (or DT_INVALID if it
87 // is not a valid DType object).
88 ABSL_MUST_USE_RESULT
DataTypeFromPyDType(PyObject * dtype)89 DataType DataTypeFromPyDType(PyObject* dtype) {
90   if (!dtype) {
91     return DT_INVALID;
92   }
93   Safe_PyObjectPtr enum_field = GetAttr_TypeEnum(dtype);
94   if (!enum_field) {
95     return DT_INVALID;
96   }
97   DataType result = static_cast<DataType>(PY_INT_AS_LONG(enum_field.get()));
98   return result;
99 }
100 
101 // Update `dtype` with an inferred dtype from `value`.  In particular, if
102 // `dtype == DT_INVALID` and `value` is a `Tensor`, then set `dtype` to
103 // `value.dtype`.  (If `dtype` is not `DT_INVALID`, or `value` is not a
104 // tensor, then do nothing.)  Returns false on exception.
105 ABSL_MUST_USE_RESULT
InferDType(PyObject * value,DataType & dtype)106 bool InferDType(PyObject* value, DataType& dtype) {
107   if (dtype != DT_INVALID) return true;  // Already have dtype.
108 
109   if (EagerTensor_CheckExact(value)) {
110     dtype = PyEagerTensor_Dtype(value);
111     return true;
112   }
113 
114   if (swig::IsTensor(value)) {
115     Safe_PyObjectPtr py_dtype = GetAttr_DType(value);
116     if (!py_dtype) return false;
117     dtype = DataTypeFromPyDType(py_dtype.get());  // set output parameter
118     return true;
119   }
120   return true;
121 }
122 
123 // Returns true if `dtype` is in `ok_dtypes`, or `ok_dtypes` is null or empty.
124 ABSL_MUST_USE_RESULT
IsOkDType(DataType dtype,const std::vector<DataType> * ok_dtypes)125 bool IsOkDType(DataType dtype, const std::vector<DataType>* ok_dtypes) {
126   return (ok_dtypes == nullptr || ok_dtypes->empty() ||
127           std::find(ok_dtypes->begin(), ok_dtypes->end(), dtype) !=
128               ok_dtypes->end());
129 }
130 
131 // Formatter for DataTypes for absl::StrJoin.
132 struct DataTypeFormatter {
operator ()tensorflow::__anon0a01641f0111::DataTypeFormatter133   void operator()(std::string* out, DataType dtype) const {
134     out->append(DataType_Name(dtype));
135   }
136 };
137 
138 // Converts `src` to a tensor using `tensor_converter.Convert`.  If `src` is
139 // replaced by a new value then decref the replaced value.  If an error
140 // occurs, then re-raise it as a TypeError with a prefix indicating the API
141 // name and the parameter name.
142 //
143 // Args:
144 //   src: The value that should be converted (in-place).
145 //   dtype: The dtype to convert `src` to, or DT_INVALID for unconstraned.
146 //     If DT_INVALID, then `dtype` will be set to the actual dtype of the
147 //     converted value.
148 //   tensor_converter: Class used to convert python values to tensors.
149 //   api_info: Information about the API we're converting this value for
150 //     (for error messages).
151 //   param_index: Index of the parameter we're converting (for error messages).
152 //   ok_dtypes: List of valid dtypes for conversion (optional).
153 //   default_dtype: Default dtype -- used if converting the value to a tensor
154 //     with unconstrained dtype returns a value not in ok_dtypes.
155 ABSL_MUST_USE_RESULT
ConvertToTensorInPlace(PyObject * & src,DataType & dtype,const PythonTensorConverter & tensor_converter,const PythonAPIInfo & api_info,int param_index,const std::vector<DataType> * ok_dtypes=nullptr,DataType default_dtype=DT_INVALID)156 bool ConvertToTensorInPlace(PyObject*& src, DataType& dtype,
157                             const PythonTensorConverter& tensor_converter,
158                             const PythonAPIInfo& api_info, int param_index,
159                             const std::vector<DataType>* ok_dtypes = nullptr,
160                             DataType default_dtype = DT_INVALID) {
161   bool inferred_dtype = (dtype == DT_INVALID);
162   Safe_PyObjectPtr converted = tensor_converter.Convert(src, dtype);
163   if (!converted) {
164     RaiseTypeError(api_info.api_name(), " argument ",
165                    api_info.param_names()[param_index]);
166     return false;
167   }
168 
169   if (inferred_dtype && !IsOkDType(dtype, ok_dtypes)) {
170     // Converting `src` to a tensor gave us a disallowed dtype; try again
171     // with `default_dtype`.
172     if (default_dtype == DT_INVALID) {
173       RaiseTypeError(api_info.api_name(), " argument ",
174                      api_info.param_names()[param_index], ": Expected one of {",
175                      absl::StrJoin(*ok_dtypes, ", ", DataTypeFormatter()),
176                      "}, but got ", DataType_Name(dtype));
177       return false;
178     } else {
179       dtype = default_dtype;
180       converted = tensor_converter.Convert(src, dtype);
181       if (!converted) {
182         RaiseTypeError(api_info.api_name(), " argument ",
183                        api_info.param_names()[param_index]);
184         return false;
185       }
186     }
187   }
188 
189   Py_DECREF(src);
190   src = converted.release();
191   return true;
192 }
193 
194 // Converts the specified attribute parameter to the expected type.  Modifies
195 // `params` in-place.  Returns true on success, or sets an exception and
196 // returns false on failure.
197 ABSL_MUST_USE_RESULT
ConvertAttribute(const Attribute & attr,const PythonAPIInfo & api_info,absl::Span<PyObject * > params)198 bool ConvertAttribute(const Attribute& attr, const PythonAPIInfo& api_info,
199                       absl::Span<PyObject*> params) {
200   if (attr.index == -1) return true;  // Inferred attribute.
201   PyObject* src = params[attr.index];
202   Safe_PyObjectPtr converted = ConvertPyObjectToAttributeType(src, attr.type);
203   if (!converted) {
204     RaiseTypeError(api_info.api_name(), " argument ",
205                    api_info.param_names()[attr.index]);
206     return false;
207   }
208   if (converted.get() != src) {
209     Py_DECREF(src);
210     params[attr.index] = converted.release();
211   }
212   return true;
213 }
214 
215 // Converts the specified fixed-dtype input parameter to a Tensor with the
216 // expected dtype.  Modifies `params` in-place.  Returns true on success, or
217 // sets an exception and returns false on failure.
218 ABSL_MUST_USE_RESULT
ConvertInputWithFixedDType(const InputWithFixedDType & input,const PythonTensorConverter & tensor_converter,const PythonAPIInfo & api_info,absl::Span<PyObject * > params)219 bool ConvertInputWithFixedDType(const InputWithFixedDType& input,
220                                 const PythonTensorConverter& tensor_converter,
221                                 const PythonAPIInfo& api_info,
222                                 absl::Span<PyObject*> params) {
223   DataType dtype = input.dtype;
224   PyObject*& src = params[input.index];
225   if (!input.is_list) {
226     RETURN_IF_FALSE(ConvertToTensorInPlace(src, dtype, tensor_converter,
227                                            api_info, input.index));
228   } else {
229     DCHECK(PyList_CheckExact(src));
230     PyObject** items = PyList_ITEMS(src);
231     Py_ssize_t len = PyList_GET_SIZE(src);
232     for (Py_ssize_t i = 0; i < len; ++i) {
233       RETURN_IF_FALSE(ConvertToTensorInPlace(items[i], dtype, tensor_converter,
234                                              api_info, input.index));
235     }
236   }
237   return true;
238 }
239 
240 // Infers a consistent dtype for the specified collection of homogeneous-dtype
241 // input parameters, and converts those parameters to Tensors (or lists of
242 // Tensors) with that dtype. Modifies `params` in-place, and updates
243 // `inferred_attrs` with the inferred dtype (if it's not null).  Returns true
244 // on success, or sets an exception and returns false on failure.
245 ABSL_MUST_USE_RESULT
ConvertInputsWithTypeAttr(const InputsWithTypeAttr & input,const PythonTensorConverter & tensor_converter,const PythonAPIInfo & api_info,absl::Span<PyObject * > params,InferredAttributes * inferred_attrs)246 bool ConvertInputsWithTypeAttr(const InputsWithTypeAttr& input,
247                                const PythonTensorConverter& tensor_converter,
248                                const PythonAPIInfo& api_info,
249                                absl::Span<PyObject*> params,
250                                InferredAttributes* inferred_attrs) {
251   DataType dtype = DT_INVALID;
252   if (input.type_attr->index != -1) {
253     // explicit type attribute
254     PyObject* py_dtype = params[input.type_attr->index];
255     dtype = DataTypeFromPyDType(py_dtype);
256   } else {
257     // implicit type attribute: infer the dtype.
258     // First, check the single-tensor inputs.
259     for (ParamIndex index : input.tensor_params) {
260       RETURN_IF_FALSE(InferDType(params[index], dtype));
261       if (dtype != DT_INVALID) break;
262     }
263     // Next, check the list-of-tensor inputs.
264     if (dtype == DT_INVALID) {
265       for (ParamIndex index : input.tensor_list_params) {
266         PyObject* tensor_list = params[index];
267         DCHECK(PyList_CheckExact(tensor_list));
268         Py_ssize_t num_tensors = PyList_GET_SIZE(tensor_list);
269         PyObject** tensors = PyList_ITEMS(tensor_list);
270         for (Py_ssize_t i = 0; i < num_tensors; ++i) {
271           RETURN_IF_FALSE(InferDType(tensors[i], dtype));
272           if (dtype != DT_INVALID) break;
273         }
274         if (dtype != DT_INVALID) break;
275       }
276     }
277   }
278 
279   // Convert the single-tensor inputs to tensors.
280   for (ParamIndex index : input.tensor_params) {
281     RETURN_IF_FALSE(
282         ConvertToTensorInPlace(params[index], dtype, tensor_converter, api_info,
283                                index, &input.ok_dtypes, input.default_dtype));
284   }
285 
286   // Convert the list-of-tensor inputs to tensors.
287   for (ParamIndex index : input.tensor_list_params) {
288     PyObject* tensor_list = params[index];
289     DCHECK(PyList_CheckExact(tensor_list));
290     Py_ssize_t num_tensors = PyList_GET_SIZE(tensor_list);
291     PyObject** items = PyList_ITEMS(tensor_list);
292     for (Py_ssize_t i = 0; i < num_tensors; ++i) {
293       RETURN_IF_FALSE(ConvertToTensorInPlace(items[i], dtype, tensor_converter,
294                                              api_info, index, &input.ok_dtypes,
295                                              input.default_dtype));
296     }
297   }
298 
299   if (inferred_attrs) {
300     if (dtype == DT_INVALID) {
301       dtype = input.default_dtype;
302     }
303     // TODO(b/164980194) Should we raise an exception here if we didn't manage
304     // to infer a dtype?  (I.e., if there were no single-tensor inputs and all
305     // list-of-tensor inputs were empty, and there's no default dtype.)
306     int inferred_index = input.type_attr->inferred_index;
307     if (inferred_index != -1) {
308       inferred_attrs->types[inferred_index] = dtype;
309     }
310   }
311 
312   return true;
313 }
314 
315 // Infers a consistent list of dtypes for the specified collection of
316 // heterogeneous-dtype input parameters, and converts those parameters to lists
317 // of Tensors with those dtypes. Modifies `params` in-place, and updates
318 // `inferred_attrs` with the inferred dtypes (if it's not null).  Returns true
319 // on success, or sets an exception and returns false on failure.
320 ABSL_MUST_USE_RESULT
ConvertInputsWithTypeListAttr(const InputsWithTypeListAttr & input,const PythonTensorConverter & tensor_converter,const PythonAPIInfo & api_info,absl::Span<PyObject * > params,InferredAttributes * inferred_attrs)321 bool ConvertInputsWithTypeListAttr(
322     const InputsWithTypeListAttr& input,
323     const PythonTensorConverter& tensor_converter,
324     const PythonAPIInfo& api_info, absl::Span<PyObject*> params,
325     InferredAttributes* inferred_attrs) {
326   DCHECK(!input.tensor_list_params.empty());
327 
328   // Get the number of tensors from the first input list; and check that the
329   // remaining lists have the same size.
330   DCHECK(PyList_CheckExact(params[input.tensor_list_params[0]]));
331   Py_ssize_t num_tensors = PyList_GET_SIZE(params[input.tensor_list_params[0]]);
332   for (int i = 1; i < input.tensor_list_params.size(); ++i) {
333     DCHECK(PyList_CheckExact(params[input.tensor_list_params[i]]));
334     if (num_tensors != PyList_GET_SIZE(params[input.tensor_list_params[i]])) {
335       RaiseTypeError(api_info.api_name(), " expected parameters ",
336                      api_info.param_names()[0], " and ",
337                      api_info.param_names()[i],
338                      " to be lists of the same length.");
339       return false;
340     }
341   }
342 
343   // Get the list of dtypes.
344   std::vector<DataType> dtypes(num_tensors, DT_INVALID);
345   if (input.type_list_attr->index != -1) {
346     // Dtypes are specified by an explicit attribute.
347     PyObject* py_dtypes = params[input.type_list_attr->index];
348     if (PyList_GET_SIZE(py_dtypes) != num_tensors) {
349       RaiseTypeError(api_info.api_name(), " expected parameters ",
350                      api_info.param_names()[0], " and ",
351                      api_info.param_names()[input.type_list_attr->index],
352                      "to be lists of the same length.");
353       return false;
354     }
355     for (Py_ssize_t i = 0; i < PyList_GET_SIZE(py_dtypes); ++i) {
356       dtypes[i] = DataTypeFromPyDType(PyList_GetItem(py_dtypes, i));
357     }
358   } else {
359     // Dtypes are implicit: infer them.
360     for (Py_ssize_t i = 0; i < num_tensors; ++i) {
361       for (ParamIndex index : input.tensor_list_params) {
362         PyObject* tensor_list = params[index];
363         DCHECK(PyList_CheckExact(tensor_list));
364         PyObject* item = PyList_GET_ITEM(tensor_list, i);
365         RETURN_IF_FALSE(InferDType(item, dtypes[i]));
366         if (dtypes[i] != DT_INVALID) break;
367       }
368     }
369   }
370 
371   // Convert tensors.
372   for (ParamIndex index : input.tensor_list_params) {
373     PyObject* tensor_list = params[index];
374     PyObject** items = PyList_ITEMS(tensor_list);
375     for (Py_ssize_t i = 0; i < num_tensors; ++i) {
376       DataType default_dtype = i < input.default_dtypes.size()
377                                    ? input.default_dtypes[i]
378                                    : DT_INVALID;
379       RETURN_IF_FALSE(ConvertToTensorInPlace(items[i], dtypes[i],
380                                              tensor_converter, api_info, index,
381                                              &input.ok_dtypes, default_dtype));
382     }
383   }
384 
385   if (inferred_attrs) {
386     int inferred_index = input.type_list_attr->inferred_index;
387     if (inferred_index != -1) {
388       inferred_attrs->type_lists[inferred_index].swap(dtypes);
389     }
390   }
391 
392   return true;
393 }
394 
395 // Infers length attributes for Tensor-list parameters from their values, and
396 // updates `inferred_length_attrs` with the inferred length.  Sets an exception
397 // if multiple Tensor-list parameters have the same length attribute but
398 // different lengths. Returns true on success, or sets an exception and returns
399 // false on failure.
400 ABSL_MUST_USE_RESULT
InferLengthAttributes(const absl::Span<PyObject * > params,const PythonAPIInfo & api_info,std::vector<int64_t> & inferred_length_attrs)401 bool InferLengthAttributes(const absl::Span<PyObject*> params,
402                            const PythonAPIInfo& api_info,
403                            std::vector<int64_t>& inferred_length_attrs) {
404   for (int i = 0; i < api_info.inputs_with_number_attrs().size(); ++i) {
405     const auto& inputs = api_info.inputs_with_number_attrs()[i];
406     DCHECK(!inputs.tensor_list_params.empty());
407 
408     // Use the first tensor_list parameter to infer the length attribute.
409     PyObject* tensors = params[inputs.tensor_list_params[0]];
410     DCHECK(PyList_CheckExact(tensors));
411     int inferred_length = PyList_GET_SIZE(tensors);
412 
413     // Check that any other tensor_list parameters have matching length.
414     for (int j = 1; j < inputs.tensor_list_params.size(); ++j) {
415       int num_tensors = PyList_GET_SIZE(params[inputs.tensor_list_params[j]]);
416       if (num_tensors != inferred_length) {
417         RaiseTypeError(api_info.api_name(), " expected parameters ",
418                        api_info.param_names()[inputs.tensor_list_params[0]],
419                        " and ",
420                        api_info.param_names()[inputs.tensor_list_params[j]],
421                        " to be lists with the same length.");
422       }
423     }
424 
425     int inferred_index = inputs.number_attr->inferred_index;
426     if (inferred_index != -1) {
427       inferred_length_attrs[inferred_index] = inferred_length;
428     }
429   }
430   return true;
431 }
432 
433 }  // namespace
434 
ConvertPythonAPIParameters(const PythonAPIInfo & api_info,const PythonTensorConverter & tensor_converter,absl::Span<PyObject * > params,InferredAttributes * inferred_attrs)435 bool ConvertPythonAPIParameters(const PythonAPIInfo& api_info,
436                                 const PythonTensorConverter& tensor_converter,
437                                 absl::Span<PyObject*> params,
438                                 InferredAttributes* inferred_attrs) {
439   // Make room for inferred attributes.
440   if (inferred_attrs) {
441     inferred_attrs->types.resize(api_info.inferred_type_attrs().size());
442     inferred_attrs->type_lists.resize(
443         api_info.inferred_type_list_attrs().size());
444     inferred_attrs->lengths.resize(api_info.inferred_length_attrs().size());
445   }
446 
447   for (const auto& attr : api_info.attributes()) {
448     RETURN_IF_FALSE(ConvertAttribute(attr, api_info, params));
449   }
450 
451   for (const auto& input : api_info.inputs_with_fixed_dtype()) {
452     RETURN_IF_FALSE(
453         ConvertInputWithFixedDType(input, tensor_converter, api_info, params));
454   }
455 
456   for (int i = 0; i < api_info.inputs_with_type_attrs().size(); ++i) {
457     RETURN_IF_FALSE(ConvertInputsWithTypeAttr(
458         api_info.inputs_with_type_attrs()[i], tensor_converter, api_info,
459         params, inferred_attrs));
460   }
461 
462   for (int i = 0; i < api_info.inputs_with_type_list_attrs().size(); ++i) {
463     RETURN_IF_FALSE(ConvertInputsWithTypeListAttr(
464         api_info.inputs_with_type_list_attrs()[i], tensor_converter, api_info,
465         params, inferred_attrs));
466   }
467 
468   if (inferred_attrs) {
469     RETURN_IF_FALSE(
470         InferLengthAttributes(params, api_info, inferred_attrs->lengths));
471   }
472 
473   return true;
474 }
475 
CopyPythonAPITensorLists(const PythonAPIInfo & api_info,absl::Span<PyObject * > params)476 bool CopyPythonAPITensorLists(const PythonAPIInfo& api_info,
477                               absl::Span<PyObject*> params) {
478   for (const auto& input : api_info.inputs()) {
479     if (input.is_list) {
480       PyObject* src = params[input.index];
481       PyObject* copy = PySequence_List(src);
482       if (!copy) {
483         RaiseTypeError(api_info.api_name(), " expected a list of Tensors for '",
484                        api_info.param_names()[input.index], "'; got ",
485                        src->ob_type->tp_name, ".");
486         return false;
487       }
488       Py_DECREF(params[input.index]);
489       params[input.index] = copy;
490     }
491   }
492   return true;
493 }
494 
495 }  // namespace tensorflow
496