1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <functional>
17 #include <string>
18 
19 #include "pybind11/functional.h"
20 #include "pybind11/pybind11.h"
21 #include "pybind11/pytypes.h"
22 #include "pybind11/stl.h"
23 #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
24 #include "tensorflow/python/lib/core/pybind11_lib.h"
25 
26 namespace py = pybind11;
27 using tflite::interpreter_wrapper::InterpreterWrapper;
28 
PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper,m)29 PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
30   m.doc() = R"pbdoc(
31     _pywrap_tensorflow_interpreter_wrapper
32     -----
33   )pbdoc";
34 
35   // pybind11 suggests to convert factory functions into constructors, but
36   // when bytes are provided the wrapper will be confused which
37   // constructor to call.
38   m.def("CreateWrapperFromFile",
39         [](const std::string& model_path, int op_resolver_id,
40            const std::vector<std::string>& registerers,
41            bool preserve_all_tensors) {
42           std::string error;
43           auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromFile(
44               model_path.c_str(), op_resolver_id, registerers, &error,
45               preserve_all_tensors);
46           if (!wrapper) {
47             throw std::invalid_argument(error);
48           }
49           return wrapper;
50         });
51   m.def(
52       "CreateWrapperFromFile",
53       [](const std::string& model_path, int op_resolver_id,
54          const std::vector<std::string>& registerers_by_name,
55          const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
56          bool preserve_all_tensors) {
57         std::string error;
58         auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromFile(
59             model_path.c_str(), op_resolver_id, registerers_by_name,
60             registerers_by_func, &error, preserve_all_tensors);
61         if (!wrapper) {
62           throw std::invalid_argument(error);
63         }
64         return wrapper;
65       });
66   m.def("CreateWrapperFromBuffer",
67         [](const py::bytes& data, int op_resolver_id,
68            const std::vector<std::string>& registerers,
69            bool preserve_all_tensors) {
70           std::string error;
71           auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromBuffer(
72               data.ptr(), op_resolver_id, registerers, &error,
73               preserve_all_tensors);
74           if (!wrapper) {
75             throw std::invalid_argument(error);
76           }
77           return wrapper;
78         });
79   m.def(
80       "CreateWrapperFromBuffer",
81       [](const py::bytes& data, int op_resolver_id,
82          const std::vector<std::string>& registerers_by_name,
83          const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
84          bool preserve_all_tensors) {
85         std::string error;
86         auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromBuffer(
87             data.ptr(), op_resolver_id, registerers_by_name,
88             registerers_by_func, &error, preserve_all_tensors);
89         if (!wrapper) {
90           throw std::invalid_argument(error);
91         }
92         return wrapper;
93       });
94   py::class_<InterpreterWrapper>(m, "InterpreterWrapper")
95       .def(
96           "AllocateTensors",
97           [](InterpreterWrapper& self, int subgraph_index) {
98             return tensorflow::PyoOrThrow(self.AllocateTensors(subgraph_index));
99           },
100           // LINT.IfChange
101           py::arg("subgraph_index") = -1)
102           // LINT.ThenChange(//tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc)
103       .def(
104           "Invoke",
105           [](InterpreterWrapper& self, int subgraph_index) {
106             return tensorflow::PyoOrThrow(self.Invoke(subgraph_index));
107           },
108           py::arg("subgraph_index") = 0)
109       .def("InputIndices",
110            [](const InterpreterWrapper& self) {
111              return tensorflow::PyoOrThrow(self.InputIndices());
112            })
113       .def("OutputIndices",
114            [](InterpreterWrapper& self) {
115              return tensorflow::PyoOrThrow(self.OutputIndices());
116            })
117       .def(
118           "ResizeInputTensor",
119           [](InterpreterWrapper& self, int i, py::handle& value, bool strict,
120              int subgraph_index) {
121             return tensorflow::PyoOrThrow(
122                 self.ResizeInputTensor(i, value.ptr(), strict, subgraph_index));
123           },
124           py::arg("i"), py::arg("value"), py::arg("strict"),
125           py::arg("subgraph_index") = 0)
126       .def("NumTensors", &InterpreterWrapper::NumTensors)
127       .def("TensorName", &InterpreterWrapper::TensorName)
128       .def("TensorType",
129            [](const InterpreterWrapper& self, int i) {
130              return tensorflow::PyoOrThrow(self.TensorType(i));
131            })
132       .def("TensorSize",
133            [](const InterpreterWrapper& self, int i) {
134              return tensorflow::PyoOrThrow(self.TensorSize(i));
135            })
136       .def("TensorSizeSignature",
137            [](const InterpreterWrapper& self, int i) {
138              return tensorflow::PyoOrThrow(self.TensorSizeSignature(i));
139            })
140       .def("TensorSparsityParameters",
141            [](const InterpreterWrapper& self, int i) {
142              return tensorflow::PyoOrThrow(self.TensorSparsityParameters(i));
143            })
144       .def(
145           "TensorQuantization",
146           [](const InterpreterWrapper& self, int i) {
147             return tensorflow::PyoOrThrow(self.TensorQuantization(i));
148           },
149           R"pbdoc(
150             Deprecated in favor of TensorQuantizationParameters.
151           )pbdoc")
152       .def(
153           "TensorQuantizationParameters",
154           [](InterpreterWrapper& self, int i) {
155             return tensorflow::PyoOrThrow(self.TensorQuantizationParameters(i));
156           })
157       .def(
158           "SetTensor",
159           [](InterpreterWrapper& self, int i, py::handle& value,
160              int subgraph_index) {
161             return tensorflow::PyoOrThrow(
162                 self.SetTensor(i, value.ptr(), subgraph_index));
163           },
164           py::arg("i"), py::arg("value"), py::arg("subgraph_index") = 0)
165       .def(
166           "GetTensor",
167           [](const InterpreterWrapper& self, int tensor_index,
168              int subgraph_index) {
169             return tensorflow::PyoOrThrow(
170                 self.GetTensor(tensor_index, subgraph_index));
171           },
172           py::arg("tensor_index"), py::arg("subgraph_index") = 0)
173       .def("GetSubgraphIndexFromSignature",
174            [](InterpreterWrapper& self, const char* signature_key) {
175              return tensorflow::PyoOrThrow(
176                  self.GetSubgraphIndexFromSignature(signature_key));
177            })
178       .def("GetSignatureDefs",
179            [](InterpreterWrapper& self) {
180              return tensorflow::PyoOrThrow(self.GetSignatureDefs());
181            })
182       .def("ResetVariableTensors",
183            [](InterpreterWrapper& self) {
184              return tensorflow::PyoOrThrow(self.ResetVariableTensors());
185            })
186       .def("NumNodes", &InterpreterWrapper::NumNodes)
187       .def("NodeName", &InterpreterWrapper::NodeName)
188       .def("NodeInputs",
189            [](const InterpreterWrapper& self, int i) {
190              return tensorflow::PyoOrThrow(self.NodeInputs(i));
191            })
192       .def("NodeOutputs",
193            [](const InterpreterWrapper& self, int i) {
194              return tensorflow::PyoOrThrow(self.NodeOutputs(i));
195            })
196       .def(
197           "tensor",
198           [](InterpreterWrapper& self, py::handle& base_object,
199              int tensor_index, int subgraph_index) {
200             return tensorflow::PyoOrThrow(
201                 self.tensor(base_object.ptr(), tensor_index, subgraph_index));
202           },
203           R"pbdoc(
204             Returns a reference to tensor index as a numpy array from subgraph.
205             The base_object should be the interpreter object providing the
206             memory.
207           )pbdoc",
208           py::arg("base_object"), py::arg("tensor_index"),
209           py::arg("subgraph_index") = 0)
210       .def(
211           "ModifyGraphWithDelegate",
212           // Address of the delegate is passed as an argument.
213           [](InterpreterWrapper& self, uintptr_t delegate_ptr) {
214             return tensorflow::PyoOrThrow(self.ModifyGraphWithDelegate(
215                 reinterpret_cast<TfLiteDelegate*>(delegate_ptr)));
216           },
217           R"pbdoc(
218             Adds a delegate to the interpreter.
219           )pbdoc")
220       .def(
221           "SetNumThreads",
222           [](InterpreterWrapper& self, int num_threads) {
223             return tensorflow::PyoOrThrow(self.SetNumThreads(num_threads));
224           },
225           R"pbdoc(
226              ask the interpreter to set the number of threads to use.
227           )pbdoc")
228       .def("interpreter", [](InterpreterWrapper& self) {
229         return reinterpret_cast<intptr_t>(self.interpreter());
230       });
231 }
232