xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/python_tensor_converter_wrapper.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Note: This library is only used by python_tensor_converter_test.  It is
16 // not meant to be used in other circumstances.
17 
18 #include "pybind11/pybind11.h"
19 #include "pybind11/pytypes.h"
20 #include "pybind11/stl.h"
21 #include "tensorflow/python/eager/pywrap_tfe.h"
22 #include "tensorflow/python/framework/python_tensor_converter.h"
23 
#if PY_MAJOR_VERSION >= 3
// Python 3.x: integers are PyLong objects and interned names are unicode.
#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
#define PY_INT_FROM_LONG(x) (PyLong_FromLong(x))
#else
// Python 2.x: integers are PyInt objects and interned names are byte strings.
#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
#define PY_INT_FROM_LONG(x) (PyInt_FromLong(x))
#endif
35 
36 namespace py = pybind11;
37 
38 namespace tensorflow {
39 namespace {
40 
GetAttr_ThreadLocalData(PyObject * eager_context)41 Safe_PyObjectPtr GetAttr_ThreadLocalData(PyObject* eager_context) {
42   static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_thread_local_data");
43   return Safe_PyObjectPtr(PyObject_GetAttr(eager_context, attr));
44 }
45 
GetAttr_ContextHandle(PyObject * eager_context)46 Safe_PyObjectPtr GetAttr_ContextHandle(PyObject* eager_context) {
47   static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_context_handle");
48   return Safe_PyObjectPtr(PyObject_GetAttr(eager_context, attr));
49 }
50 
GetAttr_IsEager(PyObject * tld)51 Safe_PyObjectPtr GetAttr_IsEager(PyObject* tld) {
52   static PyObject* attr = PY_STRING_INTERN_FROM_STRING("is_eager");
53   return Safe_PyObjectPtr(PyObject_GetAttr(tld, attr));
54 }
55 
GetAttr_DeviceName(PyObject * tld)56 Safe_PyObjectPtr GetAttr_DeviceName(PyObject* tld) {
57   static PyObject* attr = PY_STRING_INTERN_FROM_STRING("device_name");
58   return Safe_PyObjectPtr(PyObject_GetAttr(tld, attr));
59 }
60 
GetAttr_TypeEnum(PyObject * dtype)61 Safe_PyObjectPtr GetAttr_TypeEnum(PyObject* dtype) {
62   static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_type_enum");
63   return Safe_PyObjectPtr(PyObject_GetAttr(dtype, attr));
64 }
65 
MakePythonTensorConverter(py::handle py_eager_context)66 PythonTensorConverter MakePythonTensorConverter(py::handle py_eager_context) {
67   Safe_PyObjectPtr tld = GetAttr_ThreadLocalData(py_eager_context.ptr());
68   if (!tld) throw py::error_already_set();
69 
70   Safe_PyObjectPtr py_is_eager = GetAttr_IsEager(tld.get());
71   if (!py_is_eager) throw py::error_already_set();
72   bool is_eager = PyObject_IsTrue(py_is_eager.get());
73 
74   // Initialize the eager context, if necessary.
75   TFE_Context* ctx = nullptr;
76   const char* device_name = nullptr;
77   if (is_eager) {
78     Safe_PyObjectPtr context_handle =
79         GetAttr_ContextHandle(py_eager_context.ptr());
80     if (!context_handle) throw py::error_already_set();
81     if (context_handle.get() == Py_None) {
82       throw std::runtime_error("Error retrieving context handle.");
83     }
84     Safe_PyObjectPtr py_device_name = GetAttr_DeviceName(tld.get());
85     if (!py_device_name) {
86       throw std::runtime_error("Error retrieving device name.");
87     }
88     device_name = TFE_GetPythonString(py_device_name.get());
89     ctx = reinterpret_cast<TFE_Context*>(
90         PyCapsule_GetPointer(context_handle.get(), nullptr));
91   }
92 
93   return PythonTensorConverter(py_eager_context.ptr(), ctx, device_name);
94 }
95 
Convert(tensorflow::PythonTensorConverter * self,py::handle obj,py::handle dtype)96 py::handle Convert(tensorflow::PythonTensorConverter* self, py::handle obj,
97                    py::handle dtype) {
98   DataType dtype_enum = static_cast<DataType>(PY_INT_AS_LONG(dtype.ptr()));
99   bool used_fallback = false;
100   Safe_PyObjectPtr converted =
101       self->Convert(obj.ptr(), dtype_enum, &used_fallback);
102   if (!converted) throw py::error_already_set();
103 
104   PyObject* result = PyTuple_New(3);
105   PyTuple_SET_ITEM(result, 0, converted.release());
106   PyTuple_SET_ITEM(result, 1, PY_INT_FROM_LONG(dtype_enum));
107   PyTuple_SET_ITEM(result, 2, used_fallback ? Py_True : Py_False);
108   Py_INCREF(PyTuple_GET_ITEM(result, 1));
109   Py_INCREF(PyTuple_GET_ITEM(result, 2));
110   return result;
111 }
112 
113 }  // namespace
114 }  // namespace tensorflow
115 
// Module `_pywrap_python_tensor_converter`, used by
// python_tensor_converter_test. It exposes a `PythonTensorConverter` class:
// the constructor takes the Python eager context and is built by
// tensorflow::MakePythonTensorConverter, and `Convert(obj, dtype)` forwards
// to tensorflow::Convert.
PYBIND11_MODULE(_pywrap_python_tensor_converter, m) {
  py::class_<tensorflow::PythonTensorConverter> converter_class(
      m, "PythonTensorConverter");
  converter_class.def(py::init(&tensorflow::MakePythonTensorConverter));
  converter_class.def("Convert", &tensorflow::Convert);
}
121