// xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/serialization/writer_lib.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/serialization/writer_lib.h"

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/schema/reflection/schema_generated.h"
#include "tensorflow/lite/schema/schema_conversion_utils.h"
#include "tensorflow/lite/tools/serialization/enum_mapping.h"
#include "tensorflow/lite/tools/versioning/op_version.h"
#include "tensorflow/lite/version.h"

33 namespace tflite {
34 namespace {
35 
36 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
CreateOpCodeTableImpl(flatbuffers::FlatBufferBuilder * fbb,std::vector<OpCode> * opcodes)37 CreateOpCodeTableImpl(flatbuffers::FlatBufferBuilder* fbb,
38                       std::vector<OpCode>* opcodes) {
39   std::vector<flatbuffers::Offset<OperatorCode>> codes;
40   for (const auto& it : *opcodes) {
41     const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
42     // Use version 0 for builtin op. This is a way to serialize version field to
43     // flatbuffer (since 0 is non default) and it will be corrected later.
44     int32_t op_version = it.builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1;
45     codes.push_back(
46         CreateOperatorCodeDirect(*fbb, static_cast<BuiltinOperator>(it.builtin),
47                                  custom_name, op_version));
48   }
49   return fbb->template CreateVector<flatbuffers::Offset<OperatorCode>>(codes);
50 }
51 
52 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
ExportBuffersImpl(flatbuffers::FlatBufferBuilder * fbb,std::vector<std::pair<const uint8_t *,size_t>> * buffers)53 ExportBuffersImpl(flatbuffers::FlatBufferBuilder* fbb,
54                   std::vector<std::pair<const uint8_t*, size_t>>* buffers) {
55   std::vector<flatbuffers::Offset<Buffer>> buffer_vector;
56   for (auto buffer : *buffers) {
57     auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
58     buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
59   }
60   return fbb->template CreateVector<flatbuffers::Offset<Buffer>>(buffer_vector);
61 }
62 
WriteImpl(const std::string & filename,void * data,size_t size)63 TfLiteStatus WriteImpl(const std::string& filename, void* data, size_t size) {
64   FILE* fp = fopen(filename.c_str(), "wb");
65   if (!fp) return kTfLiteError;
66 
67   const int result_size = fwrite(data, 1, size, fp);
68   fclose(fp);
69   if (result_size != size) return kTfLiteError;
70 
71   return kTfLiteOk;
72 }
73 
CreateBuiltinUnion(flatbuffers::FlatBufferBuilder * fbb,enum BuiltinOperator op,void * builtin_op_data,const TfLiteNode & node)74 std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
75     flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op,
76     void* builtin_op_data, const TfLiteNode& node) {
77   switch (op) {
78 #include "tensorflow/lite/tools/serialization/option_writer_generated.h"
79   }
80   return std::make_pair(BuiltinOptions_NONE, flatbuffers::Offset<void>());
81 }
82 
83 }  // namespace
84 
85 template <class T_OUTPUT, class T_INPUT>
ExportVector(flatbuffers::FlatBufferBuilder * fbb,const T_INPUT & v)86 flatbuffers::Offset<flatbuffers::Vector<T_OUTPUT>> SubgraphWriter::ExportVector(
87     flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v) {
88   std::vector<T_OUTPUT> inputs(v.begin(), v.end());
89   return fbb->template CreateVector<T_OUTPUT>(inputs);
90 }
91 
92 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Operator>>>
ExportOperators(flatbuffers::FlatBufferBuilder * fbb)93 SubgraphWriter::ExportOperators(flatbuffers::FlatBufferBuilder* fbb) {
94   std::vector<flatbuffers::Offset<Operator>> operators;
95 
96   std::vector<int> operator_to_opcode;
97   // TODO(aselle): Augment this once we put execution plan in schema.
98   operator_to_opcode.resize(subgraph_->nodes_size(), -1);
99   for (int op_index : execution_plan_) {
100     const auto* node_and_registration =
101         subgraph_->node_and_registration(op_index);
102     const TfLiteRegistration* registration = &node_and_registration->second;
103     if (!registration->custom_name) {
104       operator_to_opcode[op_index] =
105           GetOpCodeForBuiltin(registration->builtin_code);
106     } else {
107       operator_to_opcode[op_index] =
108           GetOpCodeForCustom(registration->custom_name);
109     }
110   }
111   // second pass serialize operators
112   for (int op_index : execution_plan_) {
113     const auto* node_and_registration =
114         subgraph_->node_and_registration(op_index);
115     const TfLiteNode& node = node_and_registration->first;
116     const TfLiteRegistration& registration = node_and_registration->second;
117     flatbuffers::Offset<void> builtin_options;
118     BuiltinOptions builtin_options_type = BuiltinOptions_NONE;
119     // Custom data
120     // TODO(aselle): Custom options format is not known by default. Just assume
121     // for now.
122     auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS;
123     flatbuffers::Offset<flatbuffers::Vector<uint8_t>> custom_options = 0;
124 
125     if (!registration.custom_name) {
126       // builtin
127       auto builtin_options_and_type = CreateBuiltinUnion(
128           fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
129           node.builtin_data, node);
130       builtin_options = builtin_options_and_type.second;
131       builtin_options_type = builtin_options_and_type.first;
132     } else {
133       auto custom_writer = custom_op_to_writer_.find(registration.custom_name);
134       if (custom_writer != custom_op_to_writer_.end() &&
135           custom_writer->second) {
136         // delegate to custom writer if it exists
137         custom_writer->second(fbb, subgraph_, op_index, &custom_options,
138                               &custom_options_format);
139       } else {
140         // use the custom data as fact
141         custom_options = fbb->CreateVector(
142             reinterpret_cast<const uint8_t*>(node.custom_initial_data),
143             node.custom_initial_data_size);
144       }
145     }
146 
147     int opcode_index = operator_to_opcode[op_index];
148     std::vector<int> written_inputs =
149         RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs));
150     std::vector<int> written_outputs =
151         RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs));
152     auto inputs = ExportVector<int32_t>(fbb, written_inputs);
153     auto outputs = ExportVector<int32_t>(fbb, written_outputs);
154     operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs,
155                                        builtin_options_type, builtin_options,
156                                        custom_options, custom_options_format));
157   }
158 
159   return fbb->template CreateVector<flatbuffers::Offset<Operator>>(operators);
160 }
161 
162 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Tensor>>>
ExportTensors(flatbuffers::FlatBufferBuilder * fbb)163 SubgraphWriter::ExportTensors(flatbuffers::FlatBufferBuilder* fbb) {
164   // Initialized to -1.
165   // A value of -1 means this tensor will not be exported.
166   tensor_to_written_tensor_.resize(subgraph_->tensors_size(), -1);
167 
168   std::vector<flatbuffers::Offset<Tensor>> tensors;
169 
170   // Make a map from tensor index to whether the tensor is a temporary.
171   std::vector<bool> tensor_is_temporary(subgraph_->tensors_size(), false);
172   for (int op_index = 0; op_index < subgraph_->nodes_size(); ++op_index) {
173     const auto* node_and_registration =
174         subgraph_->node_and_registration(op_index);
175     for (auto tensor_index :
176          TfLiteIntArrayView(node_and_registration->first.temporaries))
177       tensor_is_temporary[tensor_index] = true;
178   }
179 
180   // Now we need to remap all used tensor indices
181   int curr_output_index = 0;
182   for (int tensor_index = 0; tensor_index < subgraph_->tensors_size();
183        tensor_index++) {
184     // Temporary tensors and unused tensors will not be written.
185     if (!tensor_is_temporary[tensor_index] &&
186         unused_tensors_.find(tensor_index) == unused_tensors_.end()) {
187       tensor_to_written_tensor_[tensor_index] = curr_output_index++;
188     }
189   }
190 
191   for (int tensor_index = 0; tensor_index < subgraph_->tensors_size();
192        ++tensor_index) {
193     // Tensor not exported.
194     if (tensor_to_written_tensor_[tensor_index] == -1) continue;
195 
196     if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
197       // Allocate a buffer index
198       int buffer_index = 0;  // This is null
199       if (tensor->allocation_type == kTfLiteMmapRo) {
200         buffer_index = buffers_->size();
201         buffers_->push_back(std::make_pair(
202             reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
203       }
204       // Primitive type.
205       TensorType type = TfLiteTypeToSchemaType(tensor->type);
206       // Handle quantization
207       flatbuffers::Offset<QuantizationParameters> quantization_params;
208 
209       const flatbuffers::Offset<flatbuffers::Vector<float>> null_array;
210       flatbuffers::Offset<flatbuffers::Vector<float>> scale_array;
211       flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point_array;
212 
213       if (tensor->quantization.type == kTfLiteAffineQuantization) {
214         if (tensor->params.scale != 0.f) {
215           // Quantization with a single argument array.
216           scale_array = fbb->CreateVector<float>({tensor->params.scale});
217           zero_point_array =
218               fbb->CreateVector<int64_t>({tensor->params.zero_point});
219           quantization_params = CreateQuantizationParameters(
220               *fbb, null_array, null_array, scale_array, zero_point_array);
221         } else {  // Multi channel quantization.
222           const TfLiteAffineQuantization* params =
223               reinterpret_cast<TfLiteAffineQuantization*>(
224                   tensor->quantization.params);
225           const size_t num_scales = params->scale->size;
226 
227           std::vector<float> scale_vector(params->scale->data,
228                                           params->scale->data + num_scales);
229           std::vector<int64_t> zero_point_vector(
230               params->zero_point->data, params->zero_point->data + num_scales);
231           scale_array = fbb->CreateVector<float>(scale_vector);
232           zero_point_array = fbb->CreateVector<int64_t>(zero_point_vector);
233           quantization_params = CreateQuantizationParameters(
234               *fbb, null_array, null_array, scale_array, zero_point_array,
235               QuantizationDetails_NONE, 0, params->quantized_dimension);
236         }
237       }
238 
239       // Shape
240       // Some tensors added during op init are not registered formally as
241       // node temporaries. Some didn't get memory allocated for them, and we
242       // should avoid serializing those tensors.
243       if (tensor->dims) {
244         TfLiteIntArrayView shape_view(tensor->dims);
245         std::vector<int> shape =
246             std::vector<int>(shape_view.begin(), shape_view.end());
247 
248         Offset<flatbuffers::String> tensor_name_offset = 0;
249         if (tensor->name != nullptr) {
250           tensor_name_offset = fbb->CreateString(tensor->name);
251         }
252 
253         flatbuffers::Offset<flatbuffers::Vector<int32_t>>
254             shape_signature_offset = 0;
255         if (tensor->dims_signature != nullptr) {
256           TfLiteIntArrayView shape_signature_view(tensor->dims_signature);
257           std::vector<int32_t> shape_signature(shape_signature_view.begin(),
258                                                shape_signature_view.end());
259           shape_signature_offset = ExportVector<int32_t>(fbb, shape_signature);
260         }
261 
262         tensors.push_back(CreateTensor(*fbb, ExportVector<int32_t>(fbb, shape),
263                                        type, buffer_index, tensor_name_offset,
264                                        quantization_params, tensor->is_variable,
265                                        /*sparsity=*/0, shape_signature_offset));
266       }
267     }
268   }
269   return fbb->template CreateVector<flatbuffers::Offset<Tensor>>(tensors);
270 }
271 
272 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
ExportBuffers(flatbuffers::FlatBufferBuilder * fbb)273 SubgraphWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
274   return ExportBuffersImpl(fbb, buffers_);
275 }
276 
277 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
CreateOpCodeTable(flatbuffers::FlatBufferBuilder * fbb)278 SubgraphWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
279   return CreateOpCodeTableImpl(fbb, opcodes_);
280 }
281 
282 template <class T>
RemapTensorIndicesToWritten(const T & input)283 std::vector<int> SubgraphWriter::RemapTensorIndicesToWritten(const T& input) {
284   std::vector<int> output;
285   output.reserve(input.size());
286   for (int x : input) {
287     // Special value representing an optional tensor which is not present.
288     if (x == -1) {
289       output.push_back(x);
290       continue;
291     }
292     if (tensor_to_written_tensor_[x] != -1) {
293       output.push_back(tensor_to_written_tensor_[x]);
294     }
295   }
296   return output;
297 }
298 
GetBuffer(std::unique_ptr<uint8_t[]> * out,size_t * size)299 TfLiteStatus SubgraphWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
300                                        size_t* size) {
301   if (!out || !size) return kTfLiteError;
302   flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
303   std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
304   subgraphs_as_vector.push_back(
305       PopulateAndGetOffset(&builder, subgraph_->GetName()));
306 
307   flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
308       buffers = ExportBuffers(&builder);
309 
310   auto description = builder.CreateString("Exported from Subgraph.");
311 
312   auto op_codes = CreateOpCodeTable(&builder);
313   auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
314                            builder.CreateVector(subgraphs_as_vector),
315                            description, buffers);
316   ::tflite::FinishModelBuffer(builder, model);
317   ::tflite::UpdateOpVersion(builder.GetBufferPointer());
318   const uint8_t* buffer = builder.GetBufferPointer();
319   *size = builder.GetSize();
320   (*out).reset(new uint8_t[*size]);
321   memcpy(out->get(), buffer, *size);
322   return kTfLiteOk;
323 }
324 
PopulateAndGetOffset(flatbuffers::FlatBufferBuilder * builder,const std::string & subgraph_name)325 flatbuffers::Offset<SubGraph> SubgraphWriter::PopulateAndGetOffset(
326     flatbuffers::FlatBufferBuilder* builder, const std::string& subgraph_name) {
327   auto tensors = ExportTensors(builder);
328   std::vector<int> written_inputs = RemapTensorIndicesToWritten(inputs_);
329   std::vector<int> written_outputs = RemapTensorIndicesToWritten(outputs_);
330   auto inputs = ExportVector<int32_t>(builder, written_inputs);
331   auto outputs = ExportVector<int32_t>(builder, written_outputs);
332 
333   auto ops = ExportOperators(builder);
334   auto name = builder->CreateString(subgraph_name);
335   return CreateSubGraph(*builder, tensors, inputs, outputs, ops, name);
336 }
337 
Write(const std::string & filename)338 TfLiteStatus SubgraphWriter::Write(const std::string& filename) {
339   std::unique_ptr<uint8_t[]> buffer;
340   size_t size;
341   TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
342   return WriteImpl(filename, buffer.get(), size);
343 }
344 
RegisterCustomWriter(const std::string & custom_name,CustomWriter custom_writer)345 TfLiteStatus SubgraphWriter::RegisterCustomWriter(
346     const std::string& custom_name, CustomWriter custom_writer) {
347   if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) {
348     return kTfLiteError;
349   }
350   custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer));
351   return kTfLiteOk;
352 }
353 
CheckInputOutput(const std::vector<int> & inputs,const std::vector<int> & outputs,const std::vector<int> & execution_plan)354 TfLiteStatus SubgraphWriter::CheckInputOutput(
355     const std::vector<int>& inputs, const std::vector<int>& outputs,
356     const std::vector<int>& execution_plan) {
357   std::unordered_set<int> known_tensors(inputs.begin(), inputs.end());
358   known_tensors.insert(subgraph_->variables().begin(),
359                        subgraph_->variables().end());
360   // Scan execution plan and confirm input tensors are known before each node
361   // executes. Then append output tensors to known tensors.
362   for (int op_index : execution_plan) {
363     const auto* node_and_registration =
364         subgraph_->node_and_registration(op_index);
365     const TfLiteNode& node = node_and_registration->first;
366     for (int tensor_index : TfLiteIntArrayView(node.inputs)) {
367       if (tensor_index < 0) {
368         // Skip if optional input not present.
369         if (tensor_index == kTfLiteOptionalTensor) {
370           continue;
371         } else {
372           return kTfLiteError;
373         }
374       }
375       if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
376         // Skip constant tensors.
377         if (tensor->allocation_type == kTfLiteMmapRo) {
378           continue;
379         }
380       }
381 
382       if (known_tensors.find(tensor_index) == known_tensors.end()) {
383         subgraph_->context()->ReportError(
384             subgraph_->context(),
385             "Node (%d) uses an input (%d) that is not provided.", op_index,
386             tensor_index);
387         return kTfLiteError;
388       }
389     }
390     TfLiteIntArrayView outputs(node.outputs);
391     known_tensors.insert(outputs.begin(), outputs.end());
392   }
393 
394   // Check if outputs are known tensors or constants.
395   for (int tensor_index : outputs) {
396     if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
397       // Skip constant tensors.
398       if (tensor->allocation_type == kTfLiteMmapRo) {
399         continue;
400       }
401     }
402 
403     if (known_tensors.find(tensor_index) == known_tensors.end()) {
404       subgraph_->context()->ReportError(
405           subgraph_->context(),
406           "Output (%d) is not produced by the execution plan.", tensor_index);
407       return kTfLiteError;
408     }
409   }
410   return kTfLiteOk;
411 }
412 
SetCustomInputOutput(const std::vector<int> & inputs,const std::vector<int> & outputs,const std::vector<int> & execution_plan)413 TfLiteStatus SubgraphWriter::SetCustomInputOutput(
414     const std::vector<int>& inputs, const std::vector<int>& outputs,
415     const std::vector<int>& execution_plan) {
416   TF_LITE_ENSURE_STATUS(CheckInputOutput(inputs, outputs, execution_plan));
417   inputs_ = inputs;
418   outputs_ = outputs;
419   execution_plan_ = execution_plan;
420   return kTfLiteOk;
421 }
422 
ModelWriter(Interpreter * interpreter)423 ModelWriter::ModelWriter(Interpreter* interpreter) {
424   std::vector<Subgraph*> subgraphs;
425 
426   // Retrieves the list of the subgraphs from the interpreter for constructing
427   // a list of SubgraphWriters.
428   subgraphs.reserve(interpreter->subgraphs_size());
429   for (int i = 0; i < interpreter->subgraphs_size(); ++i) {
430     subgraphs.push_back(interpreter->subgraph(i));
431   }
432 
433   Init(subgraphs);
434 }
435 
ModelWriter(const std::vector<Subgraph * > & subgraphs)436 ModelWriter::ModelWriter(const std::vector<Subgraph*>& subgraphs) {
437   Init(subgraphs);
438 }
439 
Init(const std::vector<Subgraph * > & subgraphs)440 void ModelWriter::Init(const std::vector<Subgraph*>& subgraphs) {
441   buffers_.push_back(std::make_pair(nullptr, 0));
442   subgraph_writers_.reserve(subgraphs.size());
443   for (auto* subgraph : subgraphs) {
444     SubgraphWriter writer(subgraph, &buffers_, &opcodes_,
445                           &builtin_op_to_opcode_);
446     subgraph_writers_.push_back(writer);
447   }
448 }
449 
450 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
ExportBuffers(flatbuffers::FlatBufferBuilder * fbb)451 ModelWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
452   return ExportBuffersImpl(fbb, &buffers_);
453 }
454 
455 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
CreateOpCodeTable(flatbuffers::FlatBufferBuilder * fbb)456 ModelWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
457   return CreateOpCodeTableImpl(fbb, &opcodes_);
458 }
459 
GetBuffer(std::unique_ptr<uint8_t[]> * out,size_t * size)460 TfLiteStatus ModelWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
461                                     size_t* size) {
462   if (!out || !size) return kTfLiteError;
463   flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
464 
465   std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
466   subgraphs_as_vector.reserve(subgraph_writers_.size());
467   for (auto& subgraph_writer : subgraph_writers_) {
468     subgraphs_as_vector.push_back(subgraph_writer.PopulateAndGetOffset(
469         &builder, subgraph_writer.subgraph_->GetName()));
470   }
471 
472   flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
473       buffers = ExportBuffers(&builder);
474 
475   auto description = builder.CreateString("Exported from Subgraph.");
476 
477   auto op_codes = CreateOpCodeTable(&builder);
478   auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
479                            builder.CreateVector(subgraphs_as_vector),
480                            description, buffers);
481   ::tflite::FinishModelBuffer(builder, model);
482   ::tflite::UpdateOpVersion(builder.GetBufferPointer());
483   const uint8_t* buffer = builder.GetBufferPointer();
484   *size = builder.GetSize();
485   (*out).reset(new uint8_t[*size]);
486   memcpy(out->get(), buffer, *size);
487   return kTfLiteOk;
488 }
489 
Write(const std::string & filename)490 TfLiteStatus ModelWriter::Write(const std::string& filename) {
491   std::unique_ptr<uint8_t[]> buffer;
492   size_t size;
493   TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
494   return WriteImpl(filename, buffer.get(), size);
495 }
496 
SetUnusedTensors(int subgraph_index,const std::set<int> & unused_tensors)497 void ModelWriter::SetUnusedTensors(int subgraph_index,
498                                    const std::set<int>& unused_tensors) {
499   subgraph_writers_[subgraph_index].SetUnusedTensors(unused_tensors);
500 }
501 
SetCustomInputOutput(int subgraph_index,const std::vector<int> & inputs,const std::vector<int> & outputs,const std::vector<int> & execution_plan)502 TfLiteStatus ModelWriter::SetCustomInputOutput(
503     int subgraph_index, const std::vector<int>& inputs,
504     const std::vector<int>& outputs, const std::vector<int>& execution_plan) {
505   return subgraph_writers_[subgraph_index].SetCustomInputOutput(inputs, outputs,
506                                                                 execution_plan);
507 }
508 
RegisterCustomWriter(const std::string & custom_name,CustomWriter custom_writer)509 TfLiteStatus ModelWriter::RegisterCustomWriter(const std::string& custom_name,
510                                                CustomWriter custom_writer) {
511   for (auto& subgraph_writer : subgraph_writers_) {
512     subgraph_writer.RegisterCustomWriter(custom_name, custom_writer);
513   }
514   return kTfLiteOk;
515 }
516 
517 }  // namespace tflite
518