/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/serialization/writer_lib.h"

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/schema/reflection/schema_generated.h"
#include "tensorflow/lite/schema/schema_conversion_utils.h"
#include "tensorflow/lite/tools/serialization/enum_mapping.h"
#include "tensorflow/lite/tools/versioning/op_version.h"
#include "tensorflow/lite/version.h"

namespace tflite {
namespace {

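// Serializes the accumulated list of OpCodes into a flatbuffer vector of
// OperatorCode tables. Shared by SubgraphWriter and ModelWriter below.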
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
CreateOpCodeTableImpl(flatbuffers::FlatBufferBuilder* fbb,
                      std::vector<OpCode>* opcodes) {
  std::vector<flatbuffers::Offset<OperatorCode>> codes;
  for (const auto& it : *opcodes) {
    const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
    // Use version 0 for builtin ops. Since the schema default is 1, a 0 forces
    // the version field to actually be serialized into the flatbuffer, where
    // UpdateOpVersion() later overwrites it in place with the correct version.
    int32_t op_version = it.builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1;
    codes.push_back(
        CreateOperatorCodeDirect(*fbb, static_cast<BuiltinOperator>(it.builtin),
                                 custom_name, op_version));
  }
  return fbb->template CreateVector<flatbuffers::Offset<OperatorCode>>(codes);
}

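// Serializes the collected (data pointer, size) pairs into flatbuffer Buffer
// tables; entry 0 is typically the reserved empty buffer.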
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
ExportBuffersImpl(flatbuffers::FlatBufferBuilder* fbb,
                  std::vector<std::pair<const uint8_t*, size_t>>* buffers) {
  std::vector<flatbuffers::Offset<Buffer>> buffer_vector;
  for (auto buffer : *buffers) {
    auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
    buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
  }
  return fbb->template CreateVector<flatbuffers::Offset<Buffer>>(buffer_vector);
}

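// Writes `size` bytes from `data` to `filename`, returning kTfLiteError on
// any I/O failure.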
TfLiteStatus WriteImpl(const std::string& filename, void* data, size_t size) {
  FILE* fp = fopen(filename.c_str(), "wb");
  if (!fp) return kTfLiteError;

  const size_t result_size = fwrite(data, 1, size, fp);
  fclose(fp);
  if (result_size != size) return kTfLiteError;

  return kTfLiteOk;
}

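// Maps a builtin operator and its TfLiteNode data to the corresponding
// BuiltinOptions union member. The switch cases are generated from the schema
// into option_writer_generated.h.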
std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
    flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op,
    void* builtin_op_data, const TfLiteNode& node) {
  switch (op) {
#include "tensorflow/lite/tools/serialization/option_writer_generated.h"
  }
  return std::make_pair(BuiltinOptions_NONE, flatbuffers::Offset<void>());
}

}  // namespace

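// Copies an iterable of indices (e.g. TfLiteIntArrayView or std::vector<int>)
// into a flatbuffer vector of type T_OUTPUT.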
template <class T_OUTPUT, class T_INPUT>
flatbuffers::Offset<flatbuffers::Vector<T_OUTPUT>> SubgraphWriter::ExportVector(
    flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v) {
  std::vector<T_OUTPUT> inputs(v.begin(), v.end());
  return fbb->template CreateVector<T_OUTPUT>(inputs);
}

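// Serializes every node in the execution plan into an Operator table: a first
// pass resolves each node's opcode index, then a second pass emits its
// (remapped) inputs, outputs, and builtin or custom options.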
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Operator>>>
SubgraphWriter::ExportOperators(flatbuffers::FlatBufferBuilder* fbb) {
  std::vector<flatbuffers::Offset<Operator>> operators;

  std::vector<int> operator_to_opcode;
  // TODO(aselle): Augment this once we put execution plan in schema.
  operator_to_opcode.resize(subgraph_->nodes_size(), -1);
  // First pass: assign an opcode index to every node in the execution plan.
  for (int op_index : execution_plan_) {
    const auto* node_and_registration =
        subgraph_->node_and_registration(op_index);
    const TfLiteRegistration* registration = &node_and_registration->second;
    if (!registration->custom_name) {
      operator_to_opcode[op_index] =
          GetOpCodeForBuiltin(registration->builtin_code);
    } else {
      operator_to_opcode[op_index] =
          GetOpCodeForCustom(registration->custom_name);
    }
  }
  // Second pass: serialize the operators.
  for (int op_index : execution_plan_) {
    const auto* node_and_registration =
        subgraph_->node_and_registration(op_index);
    const TfLiteNode& node = node_and_registration->first;
    const TfLiteRegistration& registration = node_and_registration->second;
    flatbuffers::Offset<void> builtin_options;
    BuiltinOptions builtin_options_type = BuiltinOptions_NONE;
    // Custom data.
    // TODO(aselle): The custom options format is not known by default; assume
    // FLEXBUFFERS for now.
    auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS;
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> custom_options = 0;

    if (!registration.custom_name) {
      // Builtin op: serialize its options union.
      auto builtin_options_and_type = CreateBuiltinUnion(
          fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
          node.builtin_data, node);
      builtin_options = builtin_options_and_type.second;
      builtin_options_type = builtin_options_and_type.first;
    } else {
      auto custom_writer = custom_op_to_writer_.find(registration.custom_name);
      if (custom_writer != custom_op_to_writer_.end() &&
          custom_writer->second) {
        // Delegate to the custom writer if one is registered.
        custom_writer->second(fbb, subgraph_, op_index, &custom_options,
                              &custom_options_format);
      } else {
        // Otherwise, serialize the node's custom initial data verbatim.
        custom_options = fbb->CreateVector(
            reinterpret_cast<const uint8_t*>(node.custom_initial_data),
            node.custom_initial_data_size);
      }
    }

    int opcode_index = operator_to_opcode[op_index];
    std::vector<int> written_inputs =
        RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs));
    std::vector<int> written_outputs =
        RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs));
    auto inputs = ExportVector<int32_t>(fbb, written_inputs);
    auto outputs = ExportVector<int32_t>(fbb, written_outputs);
    operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs,
                                       builtin_options_type, builtin_options,
                                       custom_options, custom_options_format));
  }

  return fbb->template CreateVector<flatbuffers::Offset<Operator>>(operators);
}

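// Serializes all exportable tensors into Tensor tables. Builds the
// tensor_to_written_tensor_ remapping, skipping node temporaries and tensors
// listed in unused_tensors_, and copies kTfLiteMmapRo (constant) tensor data
// into the buffer list.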
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Tensor>>>
SubgraphWriter::ExportTensors(flatbuffers::FlatBufferBuilder* fbb) {
  // Initialized to -1.
  // A value of -1 means this tensor will not be exported.
  tensor_to_written_tensor_.resize(subgraph_->tensors_size(), -1);

  std::vector<flatbuffers::Offset<Tensor>> tensors;

  // Make a map from tensor index to whether the tensor is a temporary.
  std::vector<bool> tensor_is_temporary(subgraph_->tensors_size(), false);
  for (int op_index = 0; op_index < subgraph_->nodes_size(); ++op_index) {
    const auto* node_and_registration =
        subgraph_->node_and_registration(op_index);
    for (auto tensor_index :
         TfLiteIntArrayView(node_and_registration->first.temporaries))
      tensor_is_temporary[tensor_index] = true;
  }

  // Now we need to remap all used tensor indices.
  int curr_output_index = 0;
  for (int tensor_index = 0; tensor_index < subgraph_->tensors_size();
       tensor_index++) {
    // Temporary tensors and unused tensors will not be written.
    if (!tensor_is_temporary[tensor_index] &&
        unused_tensors_.find(tensor_index) == unused_tensors_.end()) {
      tensor_to_written_tensor_[tensor_index] = curr_output_index++;
    }
  }

  for (int tensor_index = 0; tensor_index < subgraph_->tensors_size();
       ++tensor_index) {
    // Tensor not exported.
    if (tensor_to_written_tensor_[tensor_index] == -1) continue;

    if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
      // Allocate a buffer index; index 0 is the reserved empty buffer.
      int buffer_index = 0;
      if (tensor->allocation_type == kTfLiteMmapRo) {
        buffer_index = buffers_->size();
        buffers_->push_back(std::make_pair(
            reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
      }
      // Primitive type.
      TensorType type = TfLiteTypeToSchemaType(tensor->type);
      // Handle quantization.
      flatbuffers::Offset<QuantizationParameters> quantization_params;

      const flatbuffers::Offset<flatbuffers::Vector<float>> null_array;
      flatbuffers::Offset<flatbuffers::Vector<float>> scale_array;
      flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point_array;

      if (tensor->quantization.type == kTfLiteAffineQuantization) {
        if (tensor->params.scale != 0.f) {
          // Quantization with a single argument array.
          scale_array = fbb->CreateVector<float>({tensor->params.scale});
          zero_point_array =
              fbb->CreateVector<int64_t>({tensor->params.zero_point});
          quantization_params = CreateQuantizationParameters(
              *fbb, null_array, null_array, scale_array, zero_point_array);
        } else {  // Multi-channel quantization.
          const TfLiteAffineQuantization* params =
              reinterpret_cast<TfLiteAffineQuantization*>(
                  tensor->quantization.params);
          const size_t num_scales = params->scale->size;

          std::vector<float> scale_vector(params->scale->data,
                                          params->scale->data + num_scales);
          std::vector<int64_t> zero_point_vector(
              params->zero_point->data, params->zero_point->data + num_scales);
          scale_array = fbb->CreateVector<float>(scale_vector);
          zero_point_array = fbb->CreateVector<int64_t>(zero_point_vector);
          quantization_params = CreateQuantizationParameters(
              *fbb, null_array, null_array, scale_array, zero_point_array,
              QuantizationDetails_NONE, 0, params->quantized_dimension);
        }
      }

      // Shape.
      // Some tensors added during op init are not registered formally as
      // node temporaries, and some never had memory allocated for them;
      // tensors without dims fall in this category and must not be serialized.
      if (tensor->dims) {
        TfLiteIntArrayView shape_view(tensor->dims);
        std::vector<int> shape =
            std::vector<int>(shape_view.begin(), shape_view.end());

        Offset<flatbuffers::String> tensor_name_offset = 0;
        if (tensor->name != nullptr) {
          tensor_name_offset = fbb->CreateString(tensor->name);
        }

        flatbuffers::Offset<flatbuffers::Vector<int32_t>>
            shape_signature_offset = 0;
        if (tensor->dims_signature != nullptr) {
          TfLiteIntArrayView shape_signature_view(tensor->dims_signature);
          std::vector<int32_t> shape_signature(shape_signature_view.begin(),
                                               shape_signature_view.end());
          shape_signature_offset = ExportVector<int32_t>(fbb, shape_signature);
        }

        tensors.push_back(CreateTensor(*fbb, ExportVector<int32_t>(fbb, shape),
                                       type, buffer_index, tensor_name_offset,
                                       quantization_params, tensor->is_variable,
                                       /*sparsity=*/0, shape_signature_offset));
      }
    }
  }
  return fbb->template CreateVector<flatbuffers::Offset<Tensor>>(tensors);
}

flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
SubgraphWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
  return ExportBuffersImpl(fbb, buffers_);
}

flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
SubgraphWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
  return CreateOpCodeTableImpl(fbb, opcodes_);
}

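// Rewrites tensor indices from the interpreter's numbering to the numbering
// of the written model, preserving -1 entries for absent optional tensors and
// dropping tensors that were not exported.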
template <class T>
std::vector<int> SubgraphWriter::RemapTensorIndicesToWritten(const T& input) {
  std::vector<int> output;
  output.reserve(input.size());
  for (int x : input) {
    // Special value representing an optional tensor which is not present.
    if (x == -1) {
      output.push_back(x);
      continue;
    }
    if (tensor_to_written_tensor_[x] != -1) {
      output.push_back(tensor_to_written_tensor_[x]);
    }
  }
  return output;
}

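// Builds a complete single-subgraph model flatbuffer and copies it into a
// newly allocated buffer owned by the caller.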
TfLiteStatus SubgraphWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
                                       size_t* size) {
  if (!out || !size) return kTfLiteError;
  flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
  std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
  subgraphs_as_vector.push_back(
      PopulateAndGetOffset(&builder, subgraph_->GetName()));

  flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
      buffers = ExportBuffers(&builder);

  auto description = builder.CreateString("Exported from Subgraph.");

  auto op_codes = CreateOpCodeTable(&builder);
  auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
                           builder.CreateVector(subgraphs_as_vector),
                           description, buffers);
  ::tflite::FinishModelBuffer(builder, model);
  // Fix up the op versions that were serialized as 0 in the opcode table.
  ::tflite::UpdateOpVersion(builder.GetBufferPointer());
  const uint8_t* buffer = builder.GetBufferPointer();
  *size = builder.GetSize();
  (*out).reset(new uint8_t[*size]);
  memcpy(out->get(), buffer, *size);
  return kTfLiteOk;
}

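// Serializes this writer's subgraph (tensors, remapped I/O indices, and
// operators) into a SubGraph table under the given name.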
flatbuffers::Offset<SubGraph> SubgraphWriter::PopulateAndGetOffset(
    flatbuffers::FlatBufferBuilder* builder,
    const std::string& subgraph_name) {
  auto tensors = ExportTensors(builder);
  std::vector<int> written_inputs = RemapTensorIndicesToWritten(inputs_);
  std::vector<int> written_outputs = RemapTensorIndicesToWritten(outputs_);
  auto inputs = ExportVector<int32_t>(builder, written_inputs);
  auto outputs = ExportVector<int32_t>(builder, written_outputs);

  auto ops = ExportOperators(builder);
  auto name = builder->CreateString(subgraph_name);
  return CreateSubGraph(*builder, tensors, inputs, outputs, ops, name);
}

TfLiteStatus SubgraphWriter::Write(const std::string& filename) {
  std::unique_ptr<uint8_t[]> buffer;
  size_t size;
  TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
  return WriteImpl(filename, buffer.get(), size);
}

TfLiteStatus SubgraphWriter::RegisterCustomWriter(
    const std::string& custom_name, CustomWriter custom_writer) {
  if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) {
    return kTfLiteError;
  }
  custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer));
  return kTfLiteOk;
}

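// Validates a custom (inputs, outputs, execution_plan) triple: every tensor a
// node reads must be an input, a variable, a constant, or the output of an
// earlier node, and every requested output must be produced or constant.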
TfLiteStatus SubgraphWriter::CheckInputOutput(
    const std::vector<int>& inputs, const std::vector<int>& outputs,
    const std::vector<int>& execution_plan) {
  std::unordered_set<int> known_tensors(inputs.begin(), inputs.end());
  known_tensors.insert(subgraph_->variables().begin(),
                       subgraph_->variables().end());
  // Scan the execution plan and confirm a node's input tensors are known
  // before the node executes, then add its output tensors to the known set.
  for (int op_index : execution_plan) {
    const auto* node_and_registration =
        subgraph_->node_and_registration(op_index);
    const TfLiteNode& node = node_and_registration->first;
    for (int tensor_index : TfLiteIntArrayView(node.inputs)) {
      if (tensor_index < 0) {
        // Skip if an optional input is not present.
        if (tensor_index == kTfLiteOptionalTensor) {
          continue;
        } else {
          return kTfLiteError;
        }
      }
      if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
        // Skip constant tensors.
        if (tensor->allocation_type == kTfLiteMmapRo) {
          continue;
        }
      }

      if (known_tensors.find(tensor_index) == known_tensors.end()) {
        subgraph_->context()->ReportError(
            subgraph_->context(),
            "Node (%d) uses an input (%d) that is not provided.", op_index,
            tensor_index);
        return kTfLiteError;
      }
    }
    TfLiteIntArrayView outputs(node.outputs);
    known_tensors.insert(outputs.begin(), outputs.end());
  }

  // Check that the requested outputs are known tensors or constants.
  for (int tensor_index : outputs) {
    if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
      // Skip constant tensors.
      if (tensor->allocation_type == kTfLiteMmapRo) {
        continue;
      }
    }

    if (known_tensors.find(tensor_index) == known_tensors.end()) {
      subgraph_->context()->ReportError(
          subgraph_->context(),
          "Output (%d) is not produced by the execution plan.", tensor_index);
      return kTfLiteError;
    }
  }
  return kTfLiteOk;
}

TfLiteStatus SubgraphWriter::SetCustomInputOutput(
    const std::vector<int>& inputs, const std::vector<int>& outputs,
    const std::vector<int>& execution_plan) {
  TF_LITE_ENSURE_STATUS(CheckInputOutput(inputs, outputs, execution_plan));
  inputs_ = inputs;
  outputs_ = outputs;
  execution_plan_ = execution_plan;
  return kTfLiteOk;
}

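// ModelWriter serializes a whole multi-subgraph model by delegating each
// subgraph to a SubgraphWriter that shares common buffer and opcode tables.
//
// Minimal usage sketch (assumes `interpreter` is a fully constructed
// tflite::Interpreter and the output path is writable):
//
//   tflite::ModelWriter writer(&interpreter);
//   if (writer.Write("/tmp/model.tflite") != kTfLiteOk) {
//     // Handle serialization failure.
//   }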
ModelWriter::ModelWriter(Interpreter* interpreter) {
  std::vector<Subgraph*> subgraphs;

  // Retrieve the list of subgraphs from the interpreter in order to construct
  // a list of SubgraphWriters.
  subgraphs.reserve(interpreter->subgraphs_size());
  for (int i = 0; i < interpreter->subgraphs_size(); ++i) {
    subgraphs.push_back(interpreter->subgraph(i));
  }

  Init(subgraphs);
}

ModelWriter::ModelWriter(const std::vector<Subgraph*>& subgraphs) {
  Init(subgraphs);
}

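// Creates one SubgraphWriter per subgraph, all sharing this ModelWriter's
// buffer and opcode tables.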
void ModelWriter::Init(const std::vector<Subgraph*>& subgraphs) {
  // Reserve buffer 0 as the conventional empty buffer of the TFLite schema.
  buffers_.push_back(std::make_pair(nullptr, 0));
  subgraph_writers_.reserve(subgraphs.size());
  for (auto* subgraph : subgraphs) {
    SubgraphWriter writer(subgraph, &buffers_, &opcodes_,
                          &builtin_op_to_opcode_);
    subgraph_writers_.push_back(writer);
  }
}

flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
ModelWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
  return ExportBuffersImpl(fbb, &buffers_);
}

flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
ModelWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
  return CreateOpCodeTableImpl(fbb, &opcodes_);
}

TfLiteStatus ModelWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
                                    size_t* size) {
  if (!out || !size) return kTfLiteError;
  flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);

  std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
  subgraphs_as_vector.reserve(subgraph_writers_.size());
  for (auto& subgraph_writer : subgraph_writers_) {
    subgraphs_as_vector.push_back(subgraph_writer.PopulateAndGetOffset(
        &builder, subgraph_writer.subgraph_->GetName()));
  }

  flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
      buffers = ExportBuffers(&builder);

  auto description = builder.CreateString("Exported from Subgraph.");

  auto op_codes = CreateOpCodeTable(&builder);
  auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
                           builder.CreateVector(subgraphs_as_vector),
                           description, buffers);
  ::tflite::FinishModelBuffer(builder, model);
  // Fix up the op versions that were serialized as 0 in the opcode table.
  ::tflite::UpdateOpVersion(builder.GetBufferPointer());
  const uint8_t* buffer = builder.GetBufferPointer();
  *size = builder.GetSize();
  (*out).reset(new uint8_t[*size]);
  memcpy(out->get(), buffer, *size);
  return kTfLiteOk;
}

TfLiteStatus ModelWriter::Write(const std::string& filename) {
  std::unique_ptr<uint8_t[]> buffer;
  size_t size;
  TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
  return WriteImpl(filename, buffer.get(), size);
}

void ModelWriter::SetUnusedTensors(int subgraph_index,
                                   const std::set<int>& unused_tensors) {
  subgraph_writers_[subgraph_index].SetUnusedTensors(unused_tensors);
}

TfLiteStatus ModelWriter::SetCustomInputOutput(
    int subgraph_index, const std::vector<int>& inputs,
    const std::vector<int>& outputs, const std::vector<int>& execution_plan) {
  return subgraph_writers_[subgraph_index].SetCustomInputOutput(inputs, outputs,
                                                                execution_plan);
}

TfLiteStatus ModelWriter::RegisterCustomWriter(const std::string& custom_name,
                                               CustomWriter custom_writer) {
  for (auto& subgraph_writer : subgraph_writers_) {
    subgraph_writer.RegisterCustomWriter(custom_name, custom_writer);
  }
  return kTfLiteOk;
}

}  // namespace tflite