xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/python_tensor_converter.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_
16 #define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_
17 
18 #include <Python.h>
19 
20 #include "tensorflow/c/eager/c_api.h"
21 #include "tensorflow/core/framework/types.pb.h"
22 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
23 
24 namespace tensorflow {
25 
26 // Converts PyObject* values to Tensors.
27 //
28 // This converter attempts to convert values as efficiently as possible; but
29 // it has fallback paths to handle any PyObject* value for which tensor
30 // conversion is defined.
31 class PythonTensorConverter {
32  public:
33   // Constructs a new PythonTensorConverter.
34   //
35   // Note: the arguments to this constructor may change in the future, as
36   // we move more of python tensor conversion from the Python layer to the
37   // c++ layer.
38   //
39   // Args:
40   //   py_eager_context: the value of context.context() from eager/context.py.
41   //   ctx: The c++ eager context, or nullptr in graph mode.
42   //   device_name: The current device name.
43   //
44   // All three argument values must remain alive until `this` is deleted.
PythonTensorConverter(PyObject * py_eager_context,TFE_Context * ctx,const char * device_name)45   PythonTensorConverter(PyObject* py_eager_context, TFE_Context* ctx,
46                         const char* device_name)
47       : py_eager_context_(py_eager_context),
48         ctx_(ctx),
49         device_name_(device_name) {}
50 
51   // Converts `src` to a tensor (if it's not already one), and returns a new
52   // reference to the converted value.
53   //
54   // Args:
55   //   src: The object that should be converted to a Tensor.
56   //   dtype: The requested dtype.  Use `DT_INVALID` if the dtype should be
57   //     inferred from the `src` value (in which case `dtype` will be updated
58   //     in-place to be the actual dtype of the converted value).
59   //   used_fallback: Output parameter used to record whether the conversion
60   //     was done by falling back to the Python `tf.convert_to_tensor()`
61   //     function.  This is for testing/logging purposes only.  May be null.
62   //
63   // If `src` can't be converted to a tensor with the requested dtype, sets a
64   // Python exception and returns nullptr.
65   Safe_PyObjectPtr Convert(PyObject* src, DataType& dtype,
66                            bool* used_fallback = nullptr) const;
67 
68  private:
69   PyObject* py_eager_context_;
70   TFE_Context* ctx_;
71   const char* device_name_;
72 };
73 
74 }  // namespace tensorflow
75 
76 #endif  // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_
77