xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/serialization/option_writer_generator.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <ctype.h>

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "flatbuffers/minireflect.h"  // from @flatbuffers
#include "tensorflow/lite/schema/reflection/schema_generated.h"
24 
25 namespace tflite {
26 namespace {
27 // This is generated by grepping
28 //  cat  third_party/tensorflow/lite/c/builtin_op_data.h | grep "^} TfLite" |
29 //  sed 's/^} \(TfLite.*\)Params;/\1Params/g' | grep -v "^}" | sed
30 //  's/\(.*\)/"\1",/g' | sort
31 static const char* param_structs[] = {"TfLiteAddParams",
32                                       "TfLiteArgMaxParams",
33                                       "TfLiteArgMinParams",
34                                       "TfLiteBatchMatMulParams",
35                                       "TfLiteBatchToSpaceNDParams",
36                                       "TfLiteBidirectionalSequenceLSTMParams",
37                                       "TfLiteBidirectionalSequenceRNNParams",
38                                       "TfLiteBucketizeParams",
39                                       "TfLiteCastParams",
40                                       "TfLiteConcatenationParams",
41                                       "TfLiteConvParams",
42                                       "TfLiteDepthwiseConvParams",
43                                       "TfLiteDivParams",
44                                       "TfLiteDynamicUpdateSliceParams",
45                                       "TfLiteEmbeddingLookupSparseParams",
46                                       "TfLiteFakeQuantParams",
47                                       "TfLiteFullyConnectedParams",
48                                       "TfLiteGatherParams",
49                                       "TfLiteGeluParams",
50                                       "TfLiteIfParams",
51                                       "TfLiteL2NormParams",
52                                       "TfLiteLeakyReluParams",
53                                       "TfLiteLocalResponseNormParams",
54                                       "TfLiteLSHProjectionParams",
55                                       "TfLiteLSTMParams",
56                                       "TfLiteMirrorPaddingParams",
57                                       "TfLiteMulParams",
58                                       "TfLiteOneHotParams",
59                                       "TfLitePackParams",
60                                       "TfLitePadParams",
61                                       "TfLitePadV2Params",
62                                       "TfLitePoolParams",
63                                       "TfLiteRandomParams",
64                                       "TfLiteReducerParams",
65                                       "TfLiteReshapeParams",
66                                       "TfLiteResizeBilinearParams",
67                                       "TfLiteResizeNearestNeighborParams",
68                                       "TfLiteRNNParams",
69                                       "TfLiteSequenceRNNParams",
70                                       "TfLiteShapeParams",
71                                       "TfLiteSkipGramParams",
72                                       "TfLiteSoftmaxParams",
73                                       "TfLiteSpaceToBatchNDParams",
74                                       "TfLiteSpaceToDepthParams",
75                                       "TfLiteDepthToSpaceParams",
76                                       "TfLiteSparseToDenseParams",
77                                       "TfLiteSplitParams",
78                                       "TfLiteSplitVParams",
79                                       "TfLiteSqueezeParams",
80                                       "TfLiteStridedSliceParams",
81                                       "TfLiteSubParams",
82                                       "TfLiteSVDFParams",
83                                       "TfLiteTransposeConvParams",
84                                       "TfLiteTransposeParams",
85                                       "TfLiteUnidirectionalSequenceLSTMParams",
86                                       "TfLiteUniqueParams",
87                                       "TfLiteUnpackParams",
88                                       "TfLiteReverseSequenceParams",
89                                       "TfLiteWhileParams",
90                                       "TfLiteCumsumParams",
91                                       "TfLiteCallOnceParams",
92                                       "TfLiteConv3DParams",
93                                       "TfLiteHashtableParams",
94                                       "TfLiteHashtableFindParams",
95                                       "TfLiteHashtableImportParams",
96                                       "TfLiteHashtableSizeParams",
97                                       "TfLiteConv3DTransposeParams",
98                                       "TfLiteVarHandleParams",
99                                       "TfLiteUnsortedSegmentSumParams",
100                                       "TfLiteUnsortedSegmentMinParams",
101                                       nullptr};
102 }  // namespace
103 
// Get rid of all underscores and make everything lower case to make name
// matching work for stuff like 3D vs 3d or RNN vs Rnn, e.g.
// "CONV_3D_TRANSPOSE" -> "conv3dtranspose".
std::string ToCollapsed(const std::string& in) {
  std::string out;
  out.reserve(in.size());
  for (const char c : in) {
    if (c == '_') continue;
    // tolower() requires a value representable as unsigned char (or EOF);
    // passing a negative plain char is undefined behavior.
    out.push_back(static_cast<char>(tolower(static_cast<unsigned char>(c))));
  }
  return out;
}
123 
124 // A collection of information about builtin ops.
125 class OpOptionData {
126  public:
OpOptionData()127   OpOptionData() {
128     BuildOpList();
129     BuildOptionToTypeFunctionMap();
130     BuildOpToOptionMap();
131   }
132 
133   // A list of builtin operations
ops() const134   const std::vector<std::string>& ops() const { return ops_; }
135   // Maps from operation name to option name (i.e. 'ADD' to 'AddOptions')
op_to_option()136   const std::unordered_map<std::string, std::string>& op_to_option() {
137     return op_to_option_;
138   }
139   // Maps from option to C struct i.e. 'AddOptions' -> 'TfLiteAddOptions'
option_to_struct()140   const std::unordered_map<std::string, std::string>& option_to_struct() {
141     return option_to_struct_;
142   }
143   // Maps from option to a flatbuffer type function that describes that option.
144   const std::unordered_map<std::string, flatbuffers::TypeFunction>&
option_to_type_function()145   option_to_type_function() {
146     return option_to_type_function_;
147   }
148 
149  private:
BuildOpList()150   void BuildOpList() {
151     for (const char* const* curr = EnumNamesBuiltinOperator(); *curr != nullptr;
152          ++curr) {
153       if (strlen(*curr) != 0) ops_.push_back(*curr);
154     }
155   }
156 
BuildOptionToTypeFunctionMap()157   void BuildOptionToTypeFunctionMap() {
158     auto d = tflite::BuiltinOptionsTypeTable();
159     for (int i = 0; i < d->num_elems; i++) {
160       flatbuffers::TypeCode code = d->type_codes[i];
161       if (code.sequence_ref != -1) {
162         option_to_type_function_.insert(
163             std::make_pair(d->names[i], d->type_refs[code.sequence_ref]));
164       }
165     }
166   }
167 
BuildOpToOptionMap()168   void BuildOpToOptionMap() {
169     // Manually specified mappings between ops and options
170     op_to_option_["REDUCE_MAX"] = "ReducerOptions";
171     op_to_option_["REDUCE_MIN"] = "ReducerOptions";
172     op_to_option_["REDUCE_ANY"] = "ReducerOptions";
173     op_to_option_["REDUCE_ALL"] = "ReducerOptions";
174     op_to_option_["SUM"] = "ReducerOptions";
175     op_to_option_["REDUCE_MAX"] = "ReducerOptions";
176     op_to_option_["REDUCE_PROD"] = "ReducerOptions";
177     op_to_option_["MEAN"] = "ReducerOptions";
178     op_to_option_["L2_POOL_2D"] = "Pool2DOptions";
179     op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions";
180     op_to_option_["MAX_POOL_2D"] = "Pool2DOptions";
181     op_to_option_["L2_NORMALIZATION"] = "L2NormOptions";
182     op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
183     op_to_option_["MAXIMUM"] = "MaximumMinimumOptions";
184     op_to_option_["MINIMUM"] = "MaximumMinimumOptions";
185     op_to_option_["CONV_3D_TRANSPOSE"] = "Conv3DOptions";
186     op_to_option_["RANDOM_STANDARD_NORMAL"] = "RandomOptions";
187     op_to_option_["RANDOM_UNIFORM"] = "RandomOptions";
188     op_to_option_["MULTINOMIAL"] = "RandomOptions";
189 
190     // These operators are not real ones.
191     op_to_option_["CUSTOM"] = "";    // TODO(aselle): maybe something else.
192     op_to_option_["DELEGATE"] = "";  // TODO(aselle): maybe something else.
193     op_to_option_["PLACEHOLDER_FOR_GREATER_OP_CODES"] = "";
194 
195     // Manually specified mappings between ops to "none" options -- these are
196     // ops without a corresponding Options message in schema as yet. If these
197     // options do get assigned an Options message in future, they need to be
198     // updated here as well.
199     op_to_option_["EMBEDDING_LOOKUP"] = "";
200     op_to_option_["FLOOR"] = "";
201     op_to_option_["CEIL"] = "";
202     op_to_option_["HASHTABLE_LOOKUP"] = "";
203     op_to_option_["LOGISTIC"] = "";
204     op_to_option_["RELU"] = "";
205     op_to_option_["RELU_N1_TO_1"] = "";
206     op_to_option_["RELU_0_TO_1"] = "";
207     op_to_option_["RELU6"] = "";
208     op_to_option_["ROUND"] = "";
209     op_to_option_["TANH"] = "";
210     op_to_option_["PRELU"] = "";
211     op_to_option_["SIN"] = "";
212     op_to_option_["LOG"] = "";
213     op_to_option_["SQRT"] = "";
214     op_to_option_["RSQRT"] = "";
215     op_to_option_["ELU"] = "";
216     op_to_option_["REVERSE_SEQUENCE"] = "";
217     op_to_option_["REAL"] = "";
218     op_to_option_["IMAG"] = "";
219     op_to_option_["COMPLEX_ABS"] = "";
220     op_to_option_["BROADCAST_ARGS"] = "";
221     op_to_option_["GELU"] = "";
222     op_to_option_["DYNAMIC_UPDATE_SLICE"] = "";
223 
224     // TODO(aselle): These are undesirable hacks. Consider changing C structs
225     option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
226     option_to_struct_["Conv2DOptions"] = "TfLiteConvParams";
227     option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams";
228     option_to_struct_["LocalResponseNormalizationOptions"] =
229         "TfLiteLocalResponseNormParams";
230     option_to_struct_["MirrorPadOptions"] = "TfLiteMirrorPaddingParams";
231     // Now for every op, try to find an option.
232     bool fatal = false;
233     for (const auto& op_name : ops_) {
234       auto d = tflite::BuiltinOptionsTypeTable();
235       std::string collapsed_option_name_guess =
236           ToCollapsed(op_name) + "options";
237       // O(n^2) but not that big of n.
238       for (int i = 0; i < d->num_elems; i++) {
239         std::string option_name = d->names[i];
240         std::string collapsed_option_name = ToCollapsed(option_name);
241         if (collapsed_option_name_guess == collapsed_option_name) {
242           op_to_option_.insert(std::make_pair(op_name, option_name));
243           break;
244         }
245       }
246       auto it = op_to_option_.find(op_name);
247       if (it == op_to_option_.end()) {
248         std::cerr << "Didn't find option for  " << op_name << std::endl;
249         fatal = true;
250       } else if (!it->second.empty()) {
251         std::string option_name = it->second;
252 
253         if (option_to_struct_.find(option_name) == option_to_struct_.end()) {
254           bool param_struct_found = false;
255           std::string params_guess = std::string("TfLite") + option_name;
256           size_t start = params_guess.find("Options");
257           size_t len = strlen("Options");
258           params_guess.replace(start, len, "Params");
259           for (auto* param = param_structs; *param != nullptr; param++) {
260             if (*param == params_guess) {
261               param_struct_found = true;
262               break;
263             }
264           }
265           if (!param_struct_found) {
266             std::cerr << "Failed to get param struct for option " << option_name
267                       << std::endl;
268           } else {
269             option_to_struct_.insert(std::make_pair(option_name, params_guess));
270           }
271         }
272       }
273     }
274     if (fatal) {
275       exit(1);
276     }
277   }
278 
279  private:
280   std::vector<std::string> ops_;
281   std::unordered_map<std::string, std::string> op_to_option_;
282   std::unordered_map<std::string, std::string> option_to_struct_;
283   std::unordered_map<std::string, flatbuffers::TypeFunction>
284       option_to_type_function_;
285 };
286 
// Emits the hand-written import case for RESIZE_BILINEAR. Its options table
// has deprecated fields, so only align_corners and half_pixel_centers are
// forwarded to CreateResizeBilinearOptions().
void GenerateImportForResizeBilinearOp(FILE* fp) {
  static const char kResizeBilinearCase[] =
      "  case BuiltinOperator_RESIZE_BILINEAR:  {\n"
      "    const auto* params = "
      "reinterpret_cast<const TfLiteResizeBilinearParams*>(builtin_op_data);\n"
      "    auto union_type = "
      "CreateResizeBilinearOptions(*fbb, params->align_corners, "
      "params->half_pixel_centers).Union();\n"
      "    return "
      "std::make_pair(BuiltinOptions_ResizeBilinearOptions, union_type);\n"
      "  }\n"
      "  break;\n";
  // The text contains no conversion specifiers, so it is written verbatim.
  fputs(kResizeBilinearCase, fp);
}
298 
// Emits the hand-written import case for VAR_HANDLE, which serializes the
// `container` and `shared_name` strings of the params struct.
// NOTE(review): those fields are assumed to be `const char*` that may be null
// when the model did not set them. FlatBufferBuilder::CreateString(const char*)
// would strlen() a null pointer, so the generated code emits a null (absent)
// string offset instead. The two CreateString calls are also sequenced as
// separate statements, because argument evaluation order is unspecified.
void GenerateImportForVarHandleOp(FILE* fp) {
  fprintf(fp,
          "  case BuiltinOperator_VAR_HANDLE:  {\n"
          "    const auto* params = reinterpret_cast<const "
          "TfLiteVarHandleParams*>(builtin_op_data);\n"
          "    auto container_offset = params->container ? "
          "fbb->CreateString(params->container) : "
          "flatbuffers::Offset<flatbuffers::String>();\n"
          "    auto shared_name_offset = params->shared_name ? "
          "fbb->CreateString(params->shared_name) : "
          "flatbuffers::Offset<flatbuffers::String>();\n"
          "    auto union_type = CreateVarHandleOptions(*fbb, "
          "container_offset, shared_name_offset).Union();\n"
          "    return std::make_pair(BuiltinOptions_VarHandleOptions, "
          "union_type);\n"
          "  }\n  break;\n");
}
311 
// Emits the hand-written import case for RESHAPE. The output shape is taken
// either from the op's parameters or from an additional shape input tensor;
// in the latter case the parameters need not hold a valid shape. When the node
// has more than one input and num_dimensions lies outside
// (0, TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT], the generated code writes
// ReshapeOptions without a shape. Otherwise it copies params->shape.
void GenerateImportForReshapeOp(FILE* fp) {
  static const char kReshapeCase[] =
      "  case BuiltinOperator_RESHAPE:  {\n"
      "    const auto* params = "
      "reinterpret_cast<const TfLiteReshapeParams*>(builtin_op_data);\n"
      "    flatbuffers::Offset<void> union_type;\n"
      "    if (node.inputs->size > 1 && "
      "(params->num_dimensions <= 0 || params->num_dimensions > "
      "TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT)) {\n"
      "      union_type = CreateReshapeOptions(*fbb).Union();\n"
      "    } else {\n"
      "      auto val0 = fbb->CreateVector(std::vector<int>("
      "params->shape, params->shape + params->num_dimensions));\n"
      "      union_type = CreateReshapeOptions(*fbb, val0).Union();\n"
      "    }\n"
      "    return std::make_pair(BuiltinOptions_ReshapeOptions, union_type);\n"
      "  }\n"
      "  break;\n";
  // The text contains no conversion specifiers, so it is written verbatim.
  fputs(kReshapeCase, fp);
}
337 
GenerateImportForOp(FILE * fp,const std::string & op_name,const std::string & option_name,const std::string & option_type,const flatbuffers::TypeTable * options,const std::string & struct_name)338 void GenerateImportForOp(FILE* fp, const std::string& op_name,
339                          const std::string& option_name,
340                          const std::string& option_type,
341                          const flatbuffers::TypeTable* options,
342                          const std::string& struct_name) {
343   // Special-case ResizeBilinear which has some deprecated fields.
344   if (struct_name == "TfLiteResizeBilinearParams") {
345     GenerateImportForResizeBilinearOp(fp);
346     return;
347   }
348 
349   if (struct_name == "TfLiteVarHandleParams") {
350     GenerateImportForVarHandleOp(fp);
351     return;
352   }
353 
354   // Special case Reshape that may have 'new_shape' field missing from the
355   // parameters.
356   if (struct_name == "TfLiteReshapeParams") {
357     GenerateImportForReshapeOp(fp);
358     return;
359   }
360 
361   fprintf(fp, "  case BuiltinOperator_%s:  {\n", op_name.c_str());
362   if (options->num_elems != 0) {
363     fprintf(fp,
364             "    const auto* params = reinterpret_cast<const "
365             "%s*>(builtin_op_data);\n",
366             struct_name.c_str());
367   }
368 
369   for (size_t i = 0; i < options->num_elems; i++) {
370     std::string elem_name = options->names[i];
371     bool is_int_vector = false;
372     bool is_float_vector = false;
373     std::string vector_name = elem_name;
374     std::string vector_size;
375     // TODO(aselle): Irregular naming in builtins
376     if (elem_name == "fused_activation_function")
377       elem_name = "activation";
378     else if (elem_name == "stride_w")
379       elem_name = "stride_width";
380     else if (elem_name == "stride_h")
381       elem_name = "stride_height";
382     else if (elem_name == "stride_d")
383       elem_name = "stride_depth";
384     else if (elem_name == "dilation_h_factor")
385       elem_name = "dilation_height_factor";
386     else if (elem_name == "dilation_w_factor")
387       elem_name = "dilation_width_factor";
388     else if (elem_name == "dilation_d_factor")
389       elem_name = "dilation_depth_factor";
390     else if (elem_name == "idx_out_type")
391       elem_name = "index_out_type";
392 
393     // Vector fields treated specially.
394     if (elem_name == "new_shape") {
395       is_int_vector = true;
396       vector_name = "shape";
397       vector_size = "num_dimensions";
398     } else if (elem_name == "squeeze_dims") {
399       is_int_vector = true;
400       vector_size = "num_squeeze_dims";
401     } else if (elem_name == "boundaries") {
402       is_float_vector = true;
403       vector_size = "num_boundaries";
404     }
405 
406     if (is_int_vector) {
407       fprintf(fp,
408               "    auto val%zu = fbb->CreateVector("
409               "std::vector<int>(params->%s, params->%s + params->%s));\n",
410               i, vector_name.c_str(), vector_name.c_str(), vector_size.c_str());
411       continue;
412     }
413 
414     if (is_float_vector) {
415       fprintf(fp,
416               "    auto val%zu = fbb->CreateVector("
417               "std::vector<float>(params->%s, params->%s + params->%s));\n",
418               i, vector_name.c_str(), vector_name.c_str(), vector_size.c_str());
419       continue;
420     }
421 
422     flatbuffers::TypeCode code = options->type_codes[i];
423     auto contained_type = code.sequence_ref != -1
424                               ? options->type_refs[code.sequence_ref]
425                               : nullptr;
426     std::string mapper = "";
427     if (contained_type == TensorTypeTypeTable) {
428       mapper = "TfLiteTypeToSchemaType";
429     } else if (contained_type == ActivationFunctionTypeTypeTable) {
430       mapper = "TfLiteActivationToSchemaActivation";
431     } else if (contained_type == PaddingTypeTable) {
432       mapper = "TfLitePaddingToSchemaPadding";
433     } else if (contained_type == FullyConnectedOptionsWeightsFormatTypeTable) {
434       mapper = "FullyConnectedOptionsWeightsFormatToSchema";
435     } else if (contained_type == LSTMKernelTypeTypeTable) {
436       mapper = "LSTMKernelTypeToSchema";
437     } else if (contained_type == LSHProjectionTypeTypeTable) {
438       mapper = "LSHProjectionTypeToSchema";
439     } else if (contained_type == MirrorPadModeTypeTable) {
440       mapper = "MirrorPaddingModeToSchema";
441     } else if (contained_type == CombinerTypeTypeTable) {
442       mapper = "CombinerTypeToSchema";
443     }
444 
445     fprintf(fp,
446             "    auto val%zu = "
447             "%s(params->%s);\n",
448             i, mapper.c_str(), elem_name.c_str());
449   }
450   fprintf(fp, "    auto union_type = Create%s(*fbb", option_name.c_str());
451   for (size_t i = 0; i < options->num_elems; i++) {
452     fprintf(fp, ", val%zu", i);
453   }
454   fprintf(fp, ").Union();\n");
455   fprintf(fp, "    return std::make_pair(%s, union_type);\n",
456           option_type.c_str());
457   fprintf(fp, "  }\n  break;\n");
458 }
459 
GenerateImport(OpOptionData * option,FILE * fp)460 void GenerateImport(OpOptionData* option, FILE* fp) {
461   std::unordered_set<std::string> ignores;
462   ignores.insert("CONCAT_EMBEDDINGS");
463   ignores.insert("CALL");
464 
465   // Allow any op that doesn't have an options struct to be blocked
466   // together
467   for (const auto& op_name : option->ops()) {
468     auto option_it = option->op_to_option().find(op_name);
469     if (!option_it->second.empty() && ignores.find(op_name) == ignores.end())
470       continue;
471     fprintf(fp, "  case BuiltinOperator_%s:\n", op_name.c_str());
472   }
473   fprintf(fp,
474           "    return std::make_pair(BuiltinOptions_NONE, "
475           "flatbuffers::Offset<void>());\n    break;\n");
476 
477   // Iterate over each ops
478   for (const auto& op_name : option->ops()) {
479     if (ignores.find(op_name) != ignores.end()) continue;
480     // Get to the option and struct names, continuing if not found.
481     auto option_it = option->op_to_option().find(op_name);
482     if (option_it->second.empty()) continue;
483     std::string option_name = option_it->second;
484     std::string option_type = "BuiltinOptions_" + option_name;
485     auto option_func_it = option->option_to_type_function().find(option_name);
486     if (option_func_it == option->option_to_type_function().end()) continue;
487     auto struct_name_it = option->option_to_struct().find(option_name);
488     if (struct_name_it == option->option_to_struct().end()) {
489       // If no C struct, then it better have no arguments.
490       auto type_info = option_func_it->second();
491       if (type_info->num_elems != 0) {
492         // We have non-zero arguments in the schema, this means there
493         // should be a struct.
494         fprintf(stderr,
495                 "Op %s uses option struct %s which has no builtin struct\n",
496                 op_name.c_str(), option_name.c_str());
497         exit(1);
498       }
499       fprintf(fp, "  case BuiltinOperator_%s:\n", op_name.c_str());
500       fprintf(fp, "    return std::make_pair(%s, Create%s(*fbb).Union());",
501               option_type.c_str(), option_name.c_str());
502     } else {
503       // If C struct, then we need to assign all properties
504       auto struct_name = struct_name_it->second;
505       GenerateImportForOp(fp, op_name, option_name, option_type,
506                           option_func_it->second(), struct_name);
507     }
508   }
509   // TODO(aselle): Handle unhandled cases more gracefully.
510   fprintf(fp,
511           "default:    return std::make_pair(BuiltinOptions_NONE, "
512           "flatbuffers::Offset<void>());\n    break;\n");
513 }
514 
515 }  // namespace tflite
516 
main(int argc,char * argv[])517 int main(int argc, char* argv[]) {
518   tflite::OpOptionData option;
519   if (argc != 2) {
520     fprintf(stderr, "Usage: %s <fname out>\n", argv[0]);
521     return 1;
522   }
523   FILE* fp = fopen(argv[1], "w");
524   tflite::GenerateImport(&option, fp);
525   fclose(fp);
526 }
527