xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_
16 #define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_
17 
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>
22 
// Place `<locale>` before `<Python.h>` to avoid build failures on macOS.
24 #include <locale>
25 
// The empty line above is intentional: without it, clang-format would
// automatically move `<Python.h>` before `<locale>`.
28 #include <Python.h>
29 
30 #include "tensorflow/lite/interpreter.h"
31 
32 struct TfLiteDelegate;
33 
34 // We forward declare TFLite classes here to avoid exposing them to SWIG.
35 namespace tflite {
36 class MutableOpResolver;
37 class FlatBufferModel;
38 
39 namespace interpreter_wrapper {
40 
41 class PythonErrorReporter;
42 
43 class InterpreterWrapper {
44  public:
45   using Model = FlatBufferModel;
46 
47   // SWIG caller takes ownership of pointer.
48   static InterpreterWrapper* CreateWrapperCPPFromFile(
49       const char* model_path, int op_resolver_id,
50       const std::vector<std::string>& registerers, std::string* error_msg,
51       bool preserve_all_tensors);
52   static InterpreterWrapper* CreateWrapperCPPFromFile(
53       const char* model_path, int op_resolver_id,
54       const std::vector<std::string>& registerers_by_name,
55       const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
56       std::string* error_msg, bool preserve_all_tensors);
57 
58   // SWIG caller takes ownership of pointer.
59   static InterpreterWrapper* CreateWrapperCPPFromBuffer(
60       PyObject* data, int op_resolver_id,
61       const std::vector<std::string>& registerers, std::string* error_msg,
62       bool preserve_all_tensors);
63   static InterpreterWrapper* CreateWrapperCPPFromBuffer(
64       PyObject* data, int op_resolver_id,
65       const std::vector<std::string>& registerers_by_name,
66       const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
67       std::string* error_msg, bool preserve_all_tensors);
68 
69   ~InterpreterWrapper();
70   PyObject* AllocateTensors(int subgraph_index);
71   PyObject* Invoke(int subgraph_index);
72 
73   PyObject* InputIndices() const;
74   PyObject* OutputIndices() const;
75   PyObject* ResizeInputTensor(int i, PyObject* value, bool strict,
76                               int subgraph_index);
77 
78   int NumTensors() const;
79   std::string TensorName(int i) const;
80   PyObject* TensorType(int i) const;
81   PyObject* TensorSize(int i) const;
82   PyObject* TensorSizeSignature(int i) const;
83   PyObject* TensorSparsityParameters(int i) const;
84   // Deprecated in favor of TensorQuantizationScales, below.
85   PyObject* TensorQuantization(int i) const;
86   PyObject* TensorQuantizationParameters(int i) const;
87   PyObject* SetTensor(int i, PyObject* value, int subgraph_index);
88   PyObject* GetTensor(int i, int subgraph_index) const;
89   PyObject* GetSubgraphIndexFromSignature(const char* signature_key);
90   PyObject* GetSignatureDefs() const;
91   PyObject* ResetVariableTensors();
92 
93   int NumNodes() const;
94   std::string NodeName(int i) const;
95   PyObject* NodeInputs(int i) const;
96   PyObject* NodeOutputs(int i) const;
97 
98   // Returns a reference to tensor index as a numpy array from subgraph. The
99   // base_object should be the interpreter object providing the memory.
100   PyObject* tensor(PyObject* base_object, int tensor_index, int subgraph_index);
101 
102   PyObject* SetNumThreads(int num_threads);
103 
104   // Adds a delegate to the interpreter.
105   PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate);
106 
107   // Experimental and subject to change.
108   //
109   // Returns a pointer to the underlying interpreter.
interpreter()110   Interpreter* interpreter() { return interpreter_.get(); }
111 
112  private:
113   // Helper function to construct an `InterpreterWrapper` object.
114   // It only returns InterpreterWrapper if it can construct an `Interpreter`.
115   // Otherwise it returns `nullptr`.
116   static InterpreterWrapper* CreateInterpreterWrapper(
117       std::unique_ptr<Model> model, int op_resolver_id,
118       std::unique_ptr<PythonErrorReporter> error_reporter,
119       const std::vector<std::string>& registerers_by_name,
120       const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
121       std::string* error_msg, bool preserve_all_tensors);
122 
123   InterpreterWrapper(std::unique_ptr<Model> model,
124                      std::unique_ptr<PythonErrorReporter> error_reporter,
125                      std::unique_ptr<tflite::MutableOpResolver> resolver,
126                      std::unique_ptr<Interpreter> interpreter);
127 
128   // InterpreterWrapper is not copyable or assignable. We avoid the use of
129   // InterpreterWrapper() = delete here for SWIG compatibility.
130   InterpreterWrapper();
131   InterpreterWrapper(const InterpreterWrapper& rhs);
132 
133   // Helper function to resize an input tensor.
134   PyObject* ResizeInputTensorImpl(int i, PyObject* value);
135 
136   // The public functions which creates `InterpreterWrapper` should ensure all
137   // these member variables are initialized successfully. Otherwise it should
138   // report the error and return `nullptr`.
139   const std::unique_ptr<Model> model_;
140   const std::unique_ptr<PythonErrorReporter> error_reporter_;
141   const std::unique_ptr<tflite::MutableOpResolver> resolver_;
142   const std::unique_ptr<Interpreter> interpreter_;
143 };
144 
145 }  // namespace interpreter_wrapper
146 }  // namespace tflite
147 
148 #endif  // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_
149