1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_ 16 #define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_ 17 18 #include <Python.h> 19 20 #include "tensorflow/c/eager/c_api.h" 21 #include "tensorflow/core/framework/types.pb.h" 22 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" 23 24 namespace tensorflow { 25 26 // Converts PyObject* values to Tensors. 27 // 28 // This converter attempts to convert values as efficiently as possible; but 29 // it has fallback paths to handle any PyObject* value for which tensor 30 // conversion is defined. 31 class PythonTensorConverter { 32 public: 33 // Constructs a new PythonTensorConverter. 34 // 35 // Note: the arguments to this constructor may change in the future, as 36 // we move more of python tensor conversion from the Python layer to the 37 // c++ layer. 38 // 39 // Args: 40 // py_eager_context: the value of context.context() from eager/context.py. 41 // ctx: The c++ eager context, or nullptr in graph mode. 42 // device_name: The current device name. 43 // 44 // All three argument values must remain alive until `this` is deleted. PythonTensorConverter(PyObject * py_eager_context,TFE_Context * ctx,const char * device_name)45 PythonTensorConverter(PyObject* py_eager_context, TFE_Context* ctx, 46 const char* device_name) 47 : py_eager_context_(py_eager_context), 48 ctx_(ctx), 49 device_name_(device_name) {} 50 51 // Converts `src` to a tensor (if it's not already one), and returns a new 52 // reference to the converted value. 53 // 54 // Args: 55 // src: The object that should be converted to a Tensor. 56 // dtype: The requested dtype. Use `DT_INVALID` if the dtype should be 57 // inferred from the `src` value (in which case `dtype` will be updated 58 // in-place to be the actual dtype of the converted value). 59 // used_fallback: Output parameter used to record whether the conversion 60 // was done by falling back to the Python `tf.convert_to_tensor()` 61 // function. This is for testing/logging purposes only. May be null. 62 // 63 // If `src` can't be converted to a tensor with the requested dtype, sets a 64 // Python exception and returns nullptr. 65 Safe_PyObjectPtr Convert(PyObject* src, DataType& dtype, 66 bool* used_fallback = nullptr) const; 67 68 private: 69 PyObject* py_eager_context_; 70 TFE_Context* ctx_; 71 const char* device_name_; 72 }; 73 74 } // namespace tensorflow 75 76 #endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_ 77