// xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/serialization/writer_lib.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Library to write a flatbuffer of a currently loaded TFLite model/subgraph.

#ifndef TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_
#define TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/schema/reflection/schema_generated.h"
#include "tensorflow/lite/tools/serialization/enum_mapping.h"
#include "tensorflow/lite/version.h"

namespace tflite {

// Identifies one entry of the serialized model's operator-code table.
// `builtin` holds a tflite::BuiltinOperator value; when it equals
// BuiltinOperator_CUSTOM, `custom` carries the custom operator's name
// (otherwise `custom` is empty).
struct OpCode {
  int builtin;
  std::string custom;
};

39 // Forward declaration.
40 class SubgraphWriter;
41 
42 // Handles writing a full TFLite model (with 1 or more subgraphs) to a
43 // serialized TF lite file format.
44 // TODO(b/174708523): Support custom I/O or unused tensors later.
45 class ModelWriter {
46  public:
47   // CustomWriter allows the delegate to customize the write to the flatbuffer.
48   typedef flatbuffers::Offset<Operator> (*CustomWriter)(
49       flatbuffers::FlatBufferBuilder* fbb, Subgraph* subgraph, int node_index,
50       flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
51       CustomOptionsFormat* custom_options_format);
52 
53   // Construct a writer for the specified `interpreter`. Then, use
54   // .Write() or .GetBuffer(...) to extract the data.
55   explicit ModelWriter(Interpreter* interpreter);
56 
57   // Same as above, except takes subgraphs as input.
58   explicit ModelWriter(const std::vector<Subgraph*>& subgraphs);
59 
60   // For initializing the ModelWriter internal data.
61   void Init(const std::vector<Subgraph*>& subgraphs);
62 
63   // Get a buffer and size of a serialized flatbuffer.
64   TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
65   // Write the serialized flatbuffer to the prescribed `filename`.
66   TfLiteStatus Write(const std::string& filename);
67 
68   // Specifies unused tensors on the target subgraph.
69   void SetUnusedTensors(int subgraph_index,
70                         const std::set<int>& unused_tensors);
71 
72   // Specifies custom inputs, outputs, and execution_plan to target subgraph.
73   TfLiteStatus SetCustomInputOutput(int subgraph_index,
74                                     const std::vector<int>& inputs,
75                                     const std::vector<int>& outputs,
76                                     const std::vector<int>& execution_plan);
77 
78   // Registers a custom writer for a custom op. The customization allows the
79   // caller to change the custom data.
80   TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
81                                     CustomWriter custom_writer);
82 
83  private:
84   template <class T>
85   using Offset = flatbuffers::Offset<T>;
86   Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
87       flatbuffers::FlatBufferBuilder* fbb);
88   Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
89       flatbuffers::FlatBufferBuilder* fbb);
90 
91   // List of subgraph writers owned by this model writer.
92   // There is one subgraph writer for each subgraph in the model.
93   std::vector<SubgraphWriter> subgraph_writers_;
94 
95   // This data corresponds to the overall model (rather than individual
96   // subgraphs), so we define common fields. Keep track of byte buffers
97   std::vector<std::pair<const uint8_t*, size_t>> buffers_;
98   // List of used opcodes
99   std::vector<OpCode> opcodes_;
100   std::unordered_map<int, int> builtin_op_to_opcode_;
101 };
102 
103 // Handles writing TensorFlow Lite running subgraph to a serialized TF lite
104 // file format.
105 // TODO(b/174708523): Reconcile into ModelWriter?
106 class SubgraphWriter {
107  public:
108   friend class ModelWriter;
109 
110   typedef flatbuffers::Offset<Operator> (*CustomWriter)(
111       flatbuffers::FlatBufferBuilder* fbb, Subgraph* subgraph, int node_index,
112       flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
113       CustomOptionsFormat* custom_options_format);
114 
115   // Construct a subgraph writer for the specified `subgraph`. Then, use
116   // .Write() or .GetBuffer(...) to extract the data.
SubgraphWriter(Subgraph * subgraph)117   explicit SubgraphWriter(Subgraph* subgraph)
118       : subgraph_(subgraph),
119         inputs_(subgraph->inputs()),
120         outputs_(subgraph->outputs()),
121         execution_plan_(subgraph->execution_plan()) {
122     buffers_ = &buffers_data_;
123     opcodes_ = &opcodes_data_;
124     builtin_op_to_opcode_ = &builtin_op_to_opcode_data_;
125     buffers_->push_back(std::make_pair(nullptr, 0));
126   }
127 
128   // Get a buffer and size of a serialized flatbuffer.
129   TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
130   // Write the serialized flatbuffer to the prescribed `filename`.
131   TfLiteStatus Write(const std::string& filename);
132   // Registers a custom writer for a custom op. The customization allows the
133   // caller to change the custom data.
134   TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
135                                     CustomWriter custom_writer);
136   // Tensors that are unused and shouldn't be written.
SetUnusedTensors(const std::set<int> & unused_tensors)137   void SetUnusedTensors(const std::set<int>& unused_tensors) {
138     unused_tensors_ = unused_tensors;
139   }
140   // Sets custom inputs, outputs, and execution_plan so that a portion of the
141   // subgraph is written to the buffer instead of the whole subgraph.
142   TfLiteStatus SetCustomInputOutput(const std::vector<int>& inputs,
143                                     const std::vector<int>& outputs,
144                                     const std::vector<int>& execution_plan);
145 
146  private:
147   // Used by ModelWriter.
SubgraphWriter(Subgraph * subgraph,std::vector<std::pair<const uint8_t *,size_t>> * external_buffers,std::vector<OpCode> * external_opcodes,std::unordered_map<int,int> * external_builtin_op_to_opcode)148   explicit SubgraphWriter(
149       Subgraph* subgraph,
150       std::vector<std::pair<const uint8_t*, size_t>>* external_buffers,
151       std::vector<OpCode>* external_opcodes,
152       std::unordered_map<int, int>* external_builtin_op_to_opcode)
153       : subgraph_(subgraph),
154         inputs_(subgraph->inputs()),
155         outputs_(subgraph->outputs()),
156         execution_plan_(subgraph->execution_plan()) {
157     buffers_ = external_buffers;
158     opcodes_ = external_opcodes;
159     builtin_op_to_opcode_ = external_builtin_op_to_opcode;
160     buffers_->push_back(std::make_pair(nullptr, 0));
161   }
162 
163   // Used by ModelWriter to populate data specific to this subgraph.
164   // Global stuff (like opcodes & buffers) is populated into buffers_, opcodes_,
165   // etc. & populated in the Flatbuffer by ModelWriter.
166   flatbuffers::Offset<SubGraph> PopulateAndGetOffset(
167       flatbuffers::FlatBufferBuilder* builder,
168       const std::string& subgraph_name);
169 
170   template <class T>
171   using Offset = flatbuffers::Offset<T>;
172   template <class T_OUTPUT, class T_INPUT>
173   Offset<flatbuffers::Vector<T_OUTPUT>> ExportVector(
174       flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v);
175   Offset<flatbuffers::Vector<Offset<Tensor>>> ExportTensors(
176       flatbuffers::FlatBufferBuilder* fbb);
177   Offset<flatbuffers::Vector<Offset<Operator>>> ExportOperators(
178       flatbuffers::FlatBufferBuilder* fbb);
179   Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
180       flatbuffers::FlatBufferBuilder* fbb);
181   Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
182       flatbuffers::FlatBufferBuilder* fbb);
183 
184   template <class T>
185   std::vector<int> RemapTensorIndicesToWritten(const T& input);
186 
187   // Checks if given `input`, `output`, and `execution_plan` represents a valid
188   // model within the Subgraph.
189   TfLiteStatus CheckInputOutput(const std::vector<int>& inputs,
190                                 const std::vector<int>& outputs,
191                                 const std::vector<int>& execution_plan);
192 
GetOpCodeForBuiltin(int builtin_op_index)193   int GetOpCodeForBuiltin(int builtin_op_index) {
194     // auto it = builtin_op_to_opcode_.find(builtin_op_index);
195     std::pair<decltype(builtin_op_to_opcode_data_)::iterator, bool> result =
196         builtin_op_to_opcode_->insert(
197             std::make_pair(builtin_op_index, opcodes_->size()));
198     if (result.second) {
199       opcodes_->push_back({builtin_op_index, ""});
200     }
201     return result.first->second;
202   }
203 
GetOpCodeForCustom(const std::string & custom_name)204   int GetOpCodeForCustom(const std::string& custom_name) {
205     std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result =
206         custom_op_to_opcode_.insert(
207             std::make_pair(custom_name, opcodes_->size()));
208     if (result.second) {
209       opcodes_->push_back({BuiltinOperator_CUSTOM, custom_name});
210     }
211     return result.first->second;
212   }
213 
214   // The subgraph we are writing
215   Subgraph* subgraph_;
216   // Input tensor indices to be written.
217   std::vector<int> inputs_;
218   // Output tensor indices to be written.
219   std::vector<int> outputs_;
220   // Order of nodes to be written.
221   std::vector<int> execution_plan_;
222   // List of op codes and mappings from builtin or custom op to opcode
223   std::set<int> unused_tensors_;
224   // For every tensor index in the subgraph, the index in the written.
225   // This is different due to temporary and unused tensors not being written.
226   std::vector<int> tensor_to_written_tensor_;
227   std::unordered_map<std::string, int> custom_op_to_opcode_;
228   std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;
229 
230   // We use pointers for these, since they may be provided by ModelWriter.
231   // Keep track of byte buffers
232   std::vector<std::pair<const uint8_t*, size_t>>* buffers_;
233   // List of used opcodes
234   std::vector<OpCode>* opcodes_;
235   std::unordered_map<int, int>* builtin_op_to_opcode_;
236 
237   // These are used if SubgraphWriter is being used directly.
238   std::vector<std::pair<const uint8_t*, size_t>> buffers_data_;
239   // List of used opcodes
240   std::vector<OpCode> opcodes_data_;
241   std::unordered_map<int, int> builtin_op_to_opcode_data_;
242 };

}  // namespace tflite

#endif  // TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_