1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ 16 #define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ 17 18 #include <functional> 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 // Place `<locale>` before <Python.h> to avoid build failures in macOS. 24 #include <locale> 25 26 // The empty line above is on purpose as otherwise clang-format will 27 // automatically move <Python.h> before <locale>. 28 #include <Python.h> 29 30 #include "tensorflow/lite/interpreter.h" 31 32 struct TfLiteDelegate; 33 34 // We forward declare TFLite classes here to avoid exposing them to SWIG. 35 namespace tflite { 36 class MutableOpResolver; 37 class FlatBufferModel; 38 39 namespace interpreter_wrapper { 40 41 class PythonErrorReporter; 42 43 class InterpreterWrapper { 44 public: 45 using Model = FlatBufferModel; 46 47 // SWIG caller takes ownership of pointer. 48 static InterpreterWrapper* CreateWrapperCPPFromFile( 49 const char* model_path, int op_resolver_id, 50 const std::vector<std::string>& registerers, std::string* error_msg, 51 bool preserve_all_tensors); 52 static InterpreterWrapper* CreateWrapperCPPFromFile( 53 const char* model_path, int op_resolver_id, 54 const std::vector<std::string>& registerers_by_name, 55 const std::vector<std::function<void(uintptr_t)>>& registerers_by_func, 56 std::string* error_msg, bool preserve_all_tensors); 57 58 // SWIG caller takes ownership of pointer. 59 static InterpreterWrapper* CreateWrapperCPPFromBuffer( 60 PyObject* data, int op_resolver_id, 61 const std::vector<std::string>& registerers, std::string* error_msg, 62 bool preserve_all_tensors); 63 static InterpreterWrapper* CreateWrapperCPPFromBuffer( 64 PyObject* data, int op_resolver_id, 65 const std::vector<std::string>& registerers_by_name, 66 const std::vector<std::function<void(uintptr_t)>>& registerers_by_func, 67 std::string* error_msg, bool preserve_all_tensors); 68 69 ~InterpreterWrapper(); 70 PyObject* AllocateTensors(int subgraph_index); 71 PyObject* Invoke(int subgraph_index); 72 73 PyObject* InputIndices() const; 74 PyObject* OutputIndices() const; 75 PyObject* ResizeInputTensor(int i, PyObject* value, bool strict, 76 int subgraph_index); 77 78 int NumTensors() const; 79 std::string TensorName(int i) const; 80 PyObject* TensorType(int i) const; 81 PyObject* TensorSize(int i) const; 82 PyObject* TensorSizeSignature(int i) const; 83 PyObject* TensorSparsityParameters(int i) const; 84 // Deprecated in favor of TensorQuantizationScales, below. 85 PyObject* TensorQuantization(int i) const; 86 PyObject* TensorQuantizationParameters(int i) const; 87 PyObject* SetTensor(int i, PyObject* value, int subgraph_index); 88 PyObject* GetTensor(int i, int subgraph_index) const; 89 PyObject* GetSubgraphIndexFromSignature(const char* signature_key); 90 PyObject* GetSignatureDefs() const; 91 PyObject* ResetVariableTensors(); 92 93 int NumNodes() const; 94 std::string NodeName(int i) const; 95 PyObject* NodeInputs(int i) const; 96 PyObject* NodeOutputs(int i) const; 97 98 // Returns a reference to tensor index as a numpy array from subgraph. The 99 // base_object should be the interpreter object providing the memory. 100 PyObject* tensor(PyObject* base_object, int tensor_index, int subgraph_index); 101 102 PyObject* SetNumThreads(int num_threads); 103 104 // Adds a delegate to the interpreter. 105 PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate); 106 107 // Experimental and subject to change. 108 // 109 // Returns a pointer to the underlying interpreter. interpreter()110 Interpreter* interpreter() { return interpreter_.get(); } 111 112 private: 113 // Helper function to construct an `InterpreterWrapper` object. 114 // It only returns InterpreterWrapper if it can construct an `Interpreter`. 115 // Otherwise it returns `nullptr`. 116 static InterpreterWrapper* CreateInterpreterWrapper( 117 std::unique_ptr<Model> model, int op_resolver_id, 118 std::unique_ptr<PythonErrorReporter> error_reporter, 119 const std::vector<std::string>& registerers_by_name, 120 const std::vector<std::function<void(uintptr_t)>>& registerers_by_func, 121 std::string* error_msg, bool preserve_all_tensors); 122 123 InterpreterWrapper(std::unique_ptr<Model> model, 124 std::unique_ptr<PythonErrorReporter> error_reporter, 125 std::unique_ptr<tflite::MutableOpResolver> resolver, 126 std::unique_ptr<Interpreter> interpreter); 127 128 // InterpreterWrapper is not copyable or assignable. We avoid the use of 129 // InterpreterWrapper() = delete here for SWIG compatibility. 130 InterpreterWrapper(); 131 InterpreterWrapper(const InterpreterWrapper& rhs); 132 133 // Helper function to resize an input tensor. 134 PyObject* ResizeInputTensorImpl(int i, PyObject* value); 135 136 // The public functions which creates `InterpreterWrapper` should ensure all 137 // these member variables are initialized successfully. Otherwise it should 138 // report the error and return `nullptr`. 139 const std::unique_ptr<Model> model_; 140 const std::unique_ptr<PythonErrorReporter> error_reporter_; 141 const std::unique_ptr<tflite::MutableOpResolver> resolver_; 142 const std::unique_ptr<Interpreter> interpreter_; 143 }; 144 145 } // namespace interpreter_wrapper 146 } // namespace tflite 147 148 #endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ 149