/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Library to write a flatbuffer of a currently loaded TFLite model/subgraph.

#ifndef TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_
#define TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_

#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/schema/reflection/schema_generated.h"
#include "tensorflow/lite/tools/serialization/enum_mapping.h"
#include "tensorflow/lite/version.h"

namespace tflite {

struct OpCode {
  int builtin;
  std::string custom;
};

// Forward declaration.
class SubgraphWriter;

// Handles writing a full TFLite model (with one or more subgraphs) to the
// serialized TFLite flatbuffer file format.
// TODO(b/174708523): Support custom I/O or unused tensors later.
class ModelWriter {
 public:
  // CustomWriter allows the delegate to customize the write to the flatbuffer.
  typedef flatbuffers::Offset<Operator> (*CustomWriter)(
      flatbuffers::FlatBufferBuilder* fbb, Subgraph* subgraph, int node_index,
      flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
      CustomOptionsFormat* custom_options_format);

  // Constructs a writer for the specified `interpreter`. Then, use
  // .Write() or .GetBuffer(...) to extract the data.
  explicit ModelWriter(Interpreter* interpreter);

  // Same as above, except it takes subgraphs as input.
  explicit ModelWriter(const std::vector<Subgraph*>& subgraphs);

  // Initializes the ModelWriter's internal data.
  void Init(const std::vector<Subgraph*>& subgraphs);

  // Gets a buffer and the size of the serialized flatbuffer.
  TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
  // Writes the serialized flatbuffer to the prescribed `filename`.
  TfLiteStatus Write(const std::string& filename);
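
  // Example usage (a minimal sketch; the interpreter and the file path are
  // assumptions for illustration, not part of this API):
  //
  //   tflite::Interpreter* interpreter = ...;  // a fully built interpreter
  //   tflite::ModelWriter writer(interpreter);
  //   std::unique_ptr<uint8_t[]> buffer;
  //   size_t size = 0;
  //   if (writer.GetBuffer(&buffer, &size) != kTfLiteOk) { /* handle error */ }
  //   // Alternatively, write the serialized model directly to disk:
  //   writer.Write("/tmp/serialized_model.tflite");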

  // Specifies unused tensors on the target subgraph.
  void SetUnusedTensors(int subgraph_index,
                        const std::set<int>& unused_tensors);

  // Specifies custom inputs, outputs, and execution_plan for the target
  // subgraph.
  TfLiteStatus SetCustomInputOutput(int subgraph_index,
                                    const std::vector<int>& inputs,
                                    const std::vector<int>& outputs,
                                    const std::vector<int>& execution_plan);

  // Registers a custom writer for a custom op. The customization allows the
  // caller to change the custom data.
  TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
                                    CustomWriter custom_writer);

 private:
  template <class T>
  using Offset = flatbuffers::Offset<T>;
  Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
      flatbuffers::FlatBufferBuilder* fbb);
  Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
      flatbuffers::FlatBufferBuilder* fbb);

  // List of subgraph writers owned by this model writer.
  // There is one subgraph writer for each subgraph in the model.
  std::vector<SubgraphWriter> subgraph_writers_;

  // The remaining fields correspond to the overall model (rather than
  // individual subgraphs), so they are defined here once.
  // Byte buffers referenced by the model's tensors.
  std::vector<std::pair<const uint8_t*, size_t>> buffers_;
  // List of used opcodes.
  std::vector<OpCode> opcodes_;
  std::unordered_map<int, int> builtin_op_to_opcode_;
};

// Handles writing a currently running TensorFlow Lite subgraph to the
// serialized TFLite flatbuffer file format.
// TODO(b/174708523): Reconcile into ModelWriter?
class SubgraphWriter {
 public:
  friend class ModelWriter;

  typedef flatbuffers::Offset<Operator> (*CustomWriter)(
      flatbuffers::FlatBufferBuilder* fbb, Subgraph* subgraph, int node_index,
      flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
      CustomOptionsFormat* custom_options_format);

  // Constructs a subgraph writer for the specified `subgraph`. Then, use
  // .Write() or .GetBuffer(...) to extract the data.
  explicit SubgraphWriter(Subgraph* subgraph)
      : subgraph_(subgraph),
        inputs_(subgraph->inputs()),
        outputs_(subgraph->outputs()),
        execution_plan_(subgraph->execution_plan()) {
    buffers_ = &buffers_data_;
    opcodes_ = &opcodes_data_;
    builtin_op_to_opcode_ = &builtin_op_to_opcode_data_;
    buffers_->push_back(std::make_pair(nullptr, 0));
  }

  // Gets a buffer and the size of the serialized flatbuffer.
  TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
  // Writes the serialized flatbuffer to the prescribed `filename`.
  TfLiteStatus Write(const std::string& filename);
  // Registers a custom writer for a custom op. The customization allows the
  // caller to change the custom data.
  TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
                                    CustomWriter custom_writer);
  // Tensors that are unused and shouldn't be written.
  void SetUnusedTensors(const std::set<int>& unused_tensors) {
    unused_tensors_ = unused_tensors;
  }
  // Sets custom inputs, outputs, and execution_plan so that a portion of the
  // subgraph is written to the buffer instead of the whole subgraph.
  TfLiteStatus SetCustomInputOutput(const std::vector<int>& inputs,
                                    const std::vector<int>& outputs,
                                    const std::vector<int>& execution_plan);
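
  // Example usage (a minimal sketch; `interpreter` and the unused-tensor
  // indices are assumptions for illustration, not part of this API):
  //
  //   tflite::Interpreter interpreter;
  //   ... build and prepare the interpreter ...
  //   tflite::SubgraphWriter writer(&interpreter.primary_subgraph());
  //   writer.SetUnusedTensors({/* temporary tensor indices */});
  //   writer.Write("/tmp/serialized_subgraph.tflite");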

 private:
  // Used by ModelWriter.
  explicit SubgraphWriter(
      Subgraph* subgraph,
      std::vector<std::pair<const uint8_t*, size_t>>* external_buffers,
      std::vector<OpCode>* external_opcodes,
      std::unordered_map<int, int>* external_builtin_op_to_opcode)
      : subgraph_(subgraph),
        inputs_(subgraph->inputs()),
        outputs_(subgraph->outputs()),
        execution_plan_(subgraph->execution_plan()) {
    buffers_ = external_buffers;
    opcodes_ = external_opcodes;
    builtin_op_to_opcode_ = external_builtin_op_to_opcode;
    buffers_->push_back(std::make_pair(nullptr, 0));
  }

  // Used by ModelWriter to populate data specific to this subgraph.
  // Model-level data (such as opcodes and buffers) is accumulated into
  // buffers_, opcodes_, etc., and written into the flatbuffer by ModelWriter.
  flatbuffers::Offset<SubGraph> PopulateAndGetOffset(
      flatbuffers::FlatBufferBuilder* builder,
      const std::string& subgraph_name);

  template <class T>
  using Offset = flatbuffers::Offset<T>;
  template <class T_OUTPUT, class T_INPUT>
  Offset<flatbuffers::Vector<T_OUTPUT>> ExportVector(
      flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v);
  Offset<flatbuffers::Vector<Offset<Tensor>>> ExportTensors(
      flatbuffers::FlatBufferBuilder* fbb);
  Offset<flatbuffers::Vector<Offset<Operator>>> ExportOperators(
      flatbuffers::FlatBufferBuilder* fbb);
  Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
      flatbuffers::FlatBufferBuilder* fbb);
  Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
      flatbuffers::FlatBufferBuilder* fbb);

  template <class T>
  std::vector<int> RemapTensorIndicesToWritten(const T& input);

  // Checks if the given `inputs`, `outputs`, and `execution_plan` represent a
  // valid model within the Subgraph.
  TfLiteStatus CheckInputOutput(const std::vector<int>& inputs,
                                const std::vector<int>& outputs,
                                const std::vector<int>& execution_plan);

  // Returns the opcode index for the given builtin op, appending a new opcode
  // entry if this builtin has not been seen before.
  int GetOpCodeForBuiltin(int builtin_op_index) {
    std::pair<decltype(builtin_op_to_opcode_data_)::iterator, bool> result =
        builtin_op_to_opcode_->insert(
            std::make_pair(builtin_op_index, opcodes_->size()));
    if (result.second) {
      opcodes_->push_back({builtin_op_index, ""});
    }
    return result.first->second;
  }

  // Returns the opcode index for the given custom op, appending a new opcode
  // entry if this custom op has not been seen before.
  int GetOpCodeForCustom(const std::string& custom_name) {
    std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result =
        custom_op_to_opcode_.insert(
            std::make_pair(custom_name, opcodes_->size()));
    if (result.second) {
      opcodes_->push_back({BuiltinOperator_CUSTOM, custom_name});
    }
    return result.first->second;
  }
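
  // For example (illustrative ops only): if the execution plan contains two
  // ADD nodes and one "MyCustomOp" node, both ADD nodes map to the same
  // opcode index, and a single CUSTOM opcode entry is created for
  // "MyCustomOp".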

  // The subgraph we are writing.
  Subgraph* subgraph_;
  // Input tensor indices to be written.
  std::vector<int> inputs_;
  // Output tensor indices to be written.
  std::vector<int> outputs_;
  // Order of nodes to be written.
  std::vector<int> execution_plan_;
  // Tensor indices that are unused and should not be written.
  std::set<int> unused_tensors_;
  // For every tensor index in the subgraph, the corresponding index in the
  // written model. The two differ because temporary and unused tensors are
  // not written.
  std::vector<int> tensor_to_written_tensor_;
  // Maps from custom op name to opcode index.
  std::unordered_map<std::string, int> custom_op_to_opcode_;
  std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;

  // We use pointers for these, since they may be provided by ModelWriter.
  // Byte buffers referenced by the written tensors.
  std::vector<std::pair<const uint8_t*, size_t>>* buffers_;
  // List of used opcodes.
  std::vector<OpCode>* opcodes_;
  std::unordered_map<int, int>* builtin_op_to_opcode_;

  // These are used if SubgraphWriter is being used directly.
  std::vector<std::pair<const uint8_t*, size_t>> buffers_data_;
  // List of used opcodes.
  std::vector<OpCode> opcodes_data_;
  std::unordered_map<int, int> builtin_op_to_opcode_data_;
};

}  // namespace tflite

#endif  // TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_