1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Note: This library is only used by python_tensor_converter_test. It is
16 // not meant to be used in other circumstances.
17
18 #include "pybind11/pybind11.h"
19 #include "pybind11/pytypes.h"
20 #include "pybind11/stl.h"
21 #include "tensorflow/python/eager/pywrap_tfe.h"
22 #include "tensorflow/python/framework/python_tensor_converter.h"
23
24 #if PY_MAJOR_VERSION < 3
25 // Python 2.x:
26 #define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
27 #define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
28 #define PY_INT_FROM_LONG(x) (PyInt_FromLong(x))
29 #else
30 // Python 3.x:
31 #define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
32 #define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
33 #define PY_INT_FROM_LONG(x) (PyLong_FromLong(x))
34 #endif
35
36 namespace py = pybind11;
37
38 namespace tensorflow {
39 namespace {
40
GetAttr_ThreadLocalData(PyObject * eager_context)41 Safe_PyObjectPtr GetAttr_ThreadLocalData(PyObject* eager_context) {
42 static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_thread_local_data");
43 return Safe_PyObjectPtr(PyObject_GetAttr(eager_context, attr));
44 }
45
GetAttr_ContextHandle(PyObject * eager_context)46 Safe_PyObjectPtr GetAttr_ContextHandle(PyObject* eager_context) {
47 static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_context_handle");
48 return Safe_PyObjectPtr(PyObject_GetAttr(eager_context, attr));
49 }
50
GetAttr_IsEager(PyObject * tld)51 Safe_PyObjectPtr GetAttr_IsEager(PyObject* tld) {
52 static PyObject* attr = PY_STRING_INTERN_FROM_STRING("is_eager");
53 return Safe_PyObjectPtr(PyObject_GetAttr(tld, attr));
54 }
55
GetAttr_DeviceName(PyObject * tld)56 Safe_PyObjectPtr GetAttr_DeviceName(PyObject* tld) {
57 static PyObject* attr = PY_STRING_INTERN_FROM_STRING("device_name");
58 return Safe_PyObjectPtr(PyObject_GetAttr(tld, attr));
59 }
60
GetAttr_TypeEnum(PyObject * dtype)61 Safe_PyObjectPtr GetAttr_TypeEnum(PyObject* dtype) {
62 static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_type_enum");
63 return Safe_PyObjectPtr(PyObject_GetAttr(dtype, attr));
64 }
65
MakePythonTensorConverter(py::handle py_eager_context)66 PythonTensorConverter MakePythonTensorConverter(py::handle py_eager_context) {
67 Safe_PyObjectPtr tld = GetAttr_ThreadLocalData(py_eager_context.ptr());
68 if (!tld) throw py::error_already_set();
69
70 Safe_PyObjectPtr py_is_eager = GetAttr_IsEager(tld.get());
71 if (!py_is_eager) throw py::error_already_set();
72 bool is_eager = PyObject_IsTrue(py_is_eager.get());
73
74 // Initialize the eager context, if necessary.
75 TFE_Context* ctx = nullptr;
76 const char* device_name = nullptr;
77 if (is_eager) {
78 Safe_PyObjectPtr context_handle =
79 GetAttr_ContextHandle(py_eager_context.ptr());
80 if (!context_handle) throw py::error_already_set();
81 if (context_handle.get() == Py_None) {
82 throw std::runtime_error("Error retrieving context handle.");
83 }
84 Safe_PyObjectPtr py_device_name = GetAttr_DeviceName(tld.get());
85 if (!py_device_name) {
86 throw std::runtime_error("Error retrieving device name.");
87 }
88 device_name = TFE_GetPythonString(py_device_name.get());
89 ctx = reinterpret_cast<TFE_Context*>(
90 PyCapsule_GetPointer(context_handle.get(), nullptr));
91 }
92
93 return PythonTensorConverter(py_eager_context.ptr(), ctx, device_name);
94 }
95
Convert(tensorflow::PythonTensorConverter * self,py::handle obj,py::handle dtype)96 py::handle Convert(tensorflow::PythonTensorConverter* self, py::handle obj,
97 py::handle dtype) {
98 DataType dtype_enum = static_cast<DataType>(PY_INT_AS_LONG(dtype.ptr()));
99 bool used_fallback = false;
100 Safe_PyObjectPtr converted =
101 self->Convert(obj.ptr(), dtype_enum, &used_fallback);
102 if (!converted) throw py::error_already_set();
103
104 PyObject* result = PyTuple_New(3);
105 PyTuple_SET_ITEM(result, 0, converted.release());
106 PyTuple_SET_ITEM(result, 1, PY_INT_FROM_LONG(dtype_enum));
107 PyTuple_SET_ITEM(result, 2, used_fallback ? Py_True : Py_False);
108 Py_INCREF(PyTuple_GET_ITEM(result, 1));
109 Py_INCREF(PyTuple_GET_ITEM(result, 2));
110 return result;
111 }
112
113 } // namespace
114 } // namespace tensorflow
115
// Extension module `_pywrap_python_tensor_converter`. It exposes
// PythonTensorConverter, constructed from a Python eager-context object, with
// a single `Convert` method.
PYBIND11_MODULE(_pywrap_python_tensor_converter, m) {
  py::class_<tensorflow::PythonTensorConverter> converter_class(
      m, "PythonTensorConverter");
  converter_class.def(py::init(&tensorflow::MakePythonTensorConverter));
  converter_class.def("Convert", tensorflow::Convert);
}
121