1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <cstdint>
#include <functional>
#include <stdexcept>
#include <string>
#include <vector>

#include "pybind11/functional.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
25
26 namespace py = pybind11;
27 using tflite::interpreter_wrapper::InterpreterWrapper;
28
PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper,m)29 PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
30 m.doc() = R"pbdoc(
31 _pywrap_tensorflow_interpreter_wrapper
32 -----
33 )pbdoc";
34
35 // pybind11 suggests to convert factory functions into constructors, but
36 // when bytes are provided the wrapper will be confused which
37 // constructor to call.
38 m.def("CreateWrapperFromFile",
39 [](const std::string& model_path, int op_resolver_id,
40 const std::vector<std::string>& registerers,
41 bool preserve_all_tensors) {
42 std::string error;
43 auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromFile(
44 model_path.c_str(), op_resolver_id, registerers, &error,
45 preserve_all_tensors);
46 if (!wrapper) {
47 throw std::invalid_argument(error);
48 }
49 return wrapper;
50 });
51 m.def(
52 "CreateWrapperFromFile",
53 [](const std::string& model_path, int op_resolver_id,
54 const std::vector<std::string>& registerers_by_name,
55 const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
56 bool preserve_all_tensors) {
57 std::string error;
58 auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromFile(
59 model_path.c_str(), op_resolver_id, registerers_by_name,
60 registerers_by_func, &error, preserve_all_tensors);
61 if (!wrapper) {
62 throw std::invalid_argument(error);
63 }
64 return wrapper;
65 });
66 m.def("CreateWrapperFromBuffer",
67 [](const py::bytes& data, int op_resolver_id,
68 const std::vector<std::string>& registerers,
69 bool preserve_all_tensors) {
70 std::string error;
71 auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromBuffer(
72 data.ptr(), op_resolver_id, registerers, &error,
73 preserve_all_tensors);
74 if (!wrapper) {
75 throw std::invalid_argument(error);
76 }
77 return wrapper;
78 });
79 m.def(
80 "CreateWrapperFromBuffer",
81 [](const py::bytes& data, int op_resolver_id,
82 const std::vector<std::string>& registerers_by_name,
83 const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
84 bool preserve_all_tensors) {
85 std::string error;
86 auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromBuffer(
87 data.ptr(), op_resolver_id, registerers_by_name,
88 registerers_by_func, &error, preserve_all_tensors);
89 if (!wrapper) {
90 throw std::invalid_argument(error);
91 }
92 return wrapper;
93 });
94 py::class_<InterpreterWrapper>(m, "InterpreterWrapper")
95 .def(
96 "AllocateTensors",
97 [](InterpreterWrapper& self, int subgraph_index) {
98 return tensorflow::PyoOrThrow(self.AllocateTensors(subgraph_index));
99 },
100 // LINT.IfChange
101 py::arg("subgraph_index") = -1)
102 // LINT.ThenChange(//tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc)
103 .def(
104 "Invoke",
105 [](InterpreterWrapper& self, int subgraph_index) {
106 return tensorflow::PyoOrThrow(self.Invoke(subgraph_index));
107 },
108 py::arg("subgraph_index") = 0)
109 .def("InputIndices",
110 [](const InterpreterWrapper& self) {
111 return tensorflow::PyoOrThrow(self.InputIndices());
112 })
113 .def("OutputIndices",
114 [](InterpreterWrapper& self) {
115 return tensorflow::PyoOrThrow(self.OutputIndices());
116 })
117 .def(
118 "ResizeInputTensor",
119 [](InterpreterWrapper& self, int i, py::handle& value, bool strict,
120 int subgraph_index) {
121 return tensorflow::PyoOrThrow(
122 self.ResizeInputTensor(i, value.ptr(), strict, subgraph_index));
123 },
124 py::arg("i"), py::arg("value"), py::arg("strict"),
125 py::arg("subgraph_index") = 0)
126 .def("NumTensors", &InterpreterWrapper::NumTensors)
127 .def("TensorName", &InterpreterWrapper::TensorName)
128 .def("TensorType",
129 [](const InterpreterWrapper& self, int i) {
130 return tensorflow::PyoOrThrow(self.TensorType(i));
131 })
132 .def("TensorSize",
133 [](const InterpreterWrapper& self, int i) {
134 return tensorflow::PyoOrThrow(self.TensorSize(i));
135 })
136 .def("TensorSizeSignature",
137 [](const InterpreterWrapper& self, int i) {
138 return tensorflow::PyoOrThrow(self.TensorSizeSignature(i));
139 })
140 .def("TensorSparsityParameters",
141 [](const InterpreterWrapper& self, int i) {
142 return tensorflow::PyoOrThrow(self.TensorSparsityParameters(i));
143 })
144 .def(
145 "TensorQuantization",
146 [](const InterpreterWrapper& self, int i) {
147 return tensorflow::PyoOrThrow(self.TensorQuantization(i));
148 },
149 R"pbdoc(
150 Deprecated in favor of TensorQuantizationParameters.
151 )pbdoc")
152 .def(
153 "TensorQuantizationParameters",
154 [](InterpreterWrapper& self, int i) {
155 return tensorflow::PyoOrThrow(self.TensorQuantizationParameters(i));
156 })
157 .def(
158 "SetTensor",
159 [](InterpreterWrapper& self, int i, py::handle& value,
160 int subgraph_index) {
161 return tensorflow::PyoOrThrow(
162 self.SetTensor(i, value.ptr(), subgraph_index));
163 },
164 py::arg("i"), py::arg("value"), py::arg("subgraph_index") = 0)
165 .def(
166 "GetTensor",
167 [](const InterpreterWrapper& self, int tensor_index,
168 int subgraph_index) {
169 return tensorflow::PyoOrThrow(
170 self.GetTensor(tensor_index, subgraph_index));
171 },
172 py::arg("tensor_index"), py::arg("subgraph_index") = 0)
173 .def("GetSubgraphIndexFromSignature",
174 [](InterpreterWrapper& self, const char* signature_key) {
175 return tensorflow::PyoOrThrow(
176 self.GetSubgraphIndexFromSignature(signature_key));
177 })
178 .def("GetSignatureDefs",
179 [](InterpreterWrapper& self) {
180 return tensorflow::PyoOrThrow(self.GetSignatureDefs());
181 })
182 .def("ResetVariableTensors",
183 [](InterpreterWrapper& self) {
184 return tensorflow::PyoOrThrow(self.ResetVariableTensors());
185 })
186 .def("NumNodes", &InterpreterWrapper::NumNodes)
187 .def("NodeName", &InterpreterWrapper::NodeName)
188 .def("NodeInputs",
189 [](const InterpreterWrapper& self, int i) {
190 return tensorflow::PyoOrThrow(self.NodeInputs(i));
191 })
192 .def("NodeOutputs",
193 [](const InterpreterWrapper& self, int i) {
194 return tensorflow::PyoOrThrow(self.NodeOutputs(i));
195 })
196 .def(
197 "tensor",
198 [](InterpreterWrapper& self, py::handle& base_object,
199 int tensor_index, int subgraph_index) {
200 return tensorflow::PyoOrThrow(
201 self.tensor(base_object.ptr(), tensor_index, subgraph_index));
202 },
203 R"pbdoc(
204 Returns a reference to tensor index as a numpy array from subgraph.
205 The base_object should be the interpreter object providing the
206 memory.
207 )pbdoc",
208 py::arg("base_object"), py::arg("tensor_index"),
209 py::arg("subgraph_index") = 0)
210 .def(
211 "ModifyGraphWithDelegate",
212 // Address of the delegate is passed as an argument.
213 [](InterpreterWrapper& self, uintptr_t delegate_ptr) {
214 return tensorflow::PyoOrThrow(self.ModifyGraphWithDelegate(
215 reinterpret_cast<TfLiteDelegate*>(delegate_ptr)));
216 },
217 R"pbdoc(
218 Adds a delegate to the interpreter.
219 )pbdoc")
220 .def(
221 "SetNumThreads",
222 [](InterpreterWrapper& self, int num_threads) {
223 return tensorflow::PyoOrThrow(self.SetNumThreads(num_threads));
224 },
225 R"pbdoc(
226 ask the interpreter to set the number of threads to use.
227 )pbdoc")
228 .def("interpreter", [](InterpreterWrapper& self) {
229 return reinterpret_cast<intptr_t>(self.interpreter());
230 });
231 }
232