/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "flatbuffers/minireflect.h"  // from @flatbuffers
#include "tensorflow/lite/schema/reflection/schema_generated.h"

namespace tflite {
namespace {
// This is generated by grepping:
// cat third_party/tensorflow/lite/c/builtin_op_data.h | grep "^} TfLite" |
// sed 's/^} \(TfLite.*\)Params;/\1Params/g' | grep -v "^}" | sed
// 's/\(.*\)/"\1",/g' | sort
static const char* param_structs[] = {"TfLiteAddParams",
                                      "TfLiteArgMaxParams",
                                      "TfLiteArgMinParams",
                                      "TfLiteBatchMatMulParams",
                                      "TfLiteBatchToSpaceNDParams",
                                      "TfLiteBidirectionalSequenceLSTMParams",
                                      "TfLiteBidirectionalSequenceRNNParams",
                                      "TfLiteBucketizeParams",
                                      "TfLiteCastParams",
                                      "TfLiteConcatenationParams",
                                      "TfLiteConvParams",
                                      "TfLiteDepthwiseConvParams",
                                      "TfLiteDivParams",
                                      "TfLiteDynamicUpdateSliceParams",
                                      "TfLiteEmbeddingLookupSparseParams",
                                      "TfLiteFakeQuantParams",
                                      "TfLiteFullyConnectedParams",
                                      "TfLiteGatherParams",
                                      "TfLiteGeluParams",
                                      "TfLiteIfParams",
                                      "TfLiteL2NormParams",
                                      "TfLiteLeakyReluParams",
                                      "TfLiteLocalResponseNormParams",
                                      "TfLiteLSHProjectionParams",
                                      "TfLiteLSTMParams",
                                      "TfLiteMirrorPaddingParams",
                                      "TfLiteMulParams",
                                      "TfLiteOneHotParams",
                                      "TfLitePackParams",
                                      "TfLitePadParams",
                                      "TfLitePadV2Params",
                                      "TfLitePoolParams",
                                      "TfLiteRandomParams",
                                      "TfLiteReducerParams",
                                      "TfLiteReshapeParams",
                                      "TfLiteResizeBilinearParams",
                                      "TfLiteResizeNearestNeighborParams",
                                      "TfLiteRNNParams",
                                      "TfLiteSequenceRNNParams",
                                      "TfLiteShapeParams",
                                      "TfLiteSkipGramParams",
                                      "TfLiteSoftmaxParams",
                                      "TfLiteSpaceToBatchNDParams",
                                      "TfLiteSpaceToDepthParams",
                                      "TfLiteDepthToSpaceParams",
                                      "TfLiteSparseToDenseParams",
                                      "TfLiteSplitParams",
                                      "TfLiteSplitVParams",
                                      "TfLiteSqueezeParams",
                                      "TfLiteStridedSliceParams",
                                      "TfLiteSubParams",
                                      "TfLiteSVDFParams",
                                      "TfLiteTransposeConvParams",
                                      "TfLiteTransposeParams",
                                      "TfLiteUnidirectionalSequenceLSTMParams",
                                      "TfLiteUniqueParams",
                                      "TfLiteUnpackParams",
                                      "TfLiteReverseSequenceParams",
                                      "TfLiteWhileParams",
                                      "TfLiteCumsumParams",
                                      "TfLiteCallOnceParams",
                                      "TfLiteConv3DParams",
                                      "TfLiteHashtableParams",
                                      "TfLiteHashtableFindParams",
                                      "TfLiteHashtableImportParams",
                                      "TfLiteHashtableSizeParams",
                                      "TfLiteConv3DTransposeParams",
                                      "TfLiteVarHandleParams",
                                      "TfLiteUnsortedSegmentSumParams",
                                      "TfLiteUnsortedSegmentMinParams",
                                      nullptr};
}  // namespace

// Get rid of all underscores and make everything lower case to make name
// matching work for stuff like 3D vs 3d or RNN vs Rnn.
std::string ToCollapsed(const std::string& in) {
  std::string out;
  for (char c : in) {
    if (c == '_') continue;
    out.push_back(tolower(static_cast<unsigned char>(c)));
  }
  return out;
}
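
// Illustrative examples of the collapsing above (values computed by hand, not
// part of the original source):
//   ToCollapsed("CONV_3D")       -> "conv3d"
//   ToCollapsed("Conv3DOptions") -> "conv3doptions"
// This is what lets BuildOpToOptionMap() below guess an op's option by
// appending "options" to the collapsed op name.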

// A collection of information about builtin ops.
class OpOptionData {
 public:
  OpOptionData() {
    BuildOpList();
    BuildOptionToTypeFunctionMap();
    BuildOpToOptionMap();
  }

  // A list of builtin operations.
  const std::vector<std::string>& ops() const { return ops_; }
  // Maps from operation name to option name (i.e. 'ADD' to 'AddOptions').
  const std::unordered_map<std::string, std::string>& op_to_option() {
    return op_to_option_;
  }
  // Maps from option to C struct, i.e. 'AddOptions' -> 'TfLiteAddParams'.
  const std::unordered_map<std::string, std::string>& option_to_struct() {
    return option_to_struct_;
  }
  // Maps from option to a flatbuffer type function that describes that option.
  const std::unordered_map<std::string, flatbuffers::TypeFunction>&
  option_to_type_function() {
    return option_to_type_function_;
  }

 private:
  void BuildOpList() {
    for (const char* const* curr = EnumNamesBuiltinOperator();
         *curr != nullptr; ++curr) {
      if (strlen(*curr) != 0) ops_.push_back(*curr);
    }
  }

  void BuildOptionToTypeFunctionMap() {
    auto d = tflite::BuiltinOptionsTypeTable();
    for (size_t i = 0; i < d->num_elems; i++) {
      flatbuffers::TypeCode code = d->type_codes[i];
      if (code.sequence_ref != -1) {
        option_to_type_function_.insert(
            std::make_pair(d->names[i], d->type_refs[code.sequence_ref]));
      }
    }
  }
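
  // Illustrative example of the map built above (symbol name assumed from the
  // generated schema reflection code): the BuiltinOptions union member
  // "Conv2DOptions" has a valid sequence_ref, so the map gains an entry
  // {"Conv2DOptions" -> Conv2DOptionsTypeTable}.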

  void BuildOpToOptionMap() {
    // Manually specified mappings between ops and options.
    op_to_option_["REDUCE_MAX"] = "ReducerOptions";
    op_to_option_["REDUCE_MIN"] = "ReducerOptions";
    op_to_option_["REDUCE_ANY"] = "ReducerOptions";
    op_to_option_["REDUCE_ALL"] = "ReducerOptions";
    op_to_option_["SUM"] = "ReducerOptions";
    op_to_option_["REDUCE_PROD"] = "ReducerOptions";
    op_to_option_["MEAN"] = "ReducerOptions";
    op_to_option_["L2_POOL_2D"] = "Pool2DOptions";
    op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions";
    op_to_option_["MAX_POOL_2D"] = "Pool2DOptions";
    op_to_option_["L2_NORMALIZATION"] = "L2NormOptions";
    op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
    op_to_option_["MAXIMUM"] = "MaximumMinimumOptions";
    op_to_option_["MINIMUM"] = "MaximumMinimumOptions";
    op_to_option_["CONV_3D_TRANSPOSE"] = "Conv3DOptions";
    op_to_option_["RANDOM_STANDARD_NORMAL"] = "RandomOptions";
    op_to_option_["RANDOM_UNIFORM"] = "RandomOptions";
    op_to_option_["MULTINOMIAL"] = "RandomOptions";

    // These operators are not real ones.
    op_to_option_["CUSTOM"] = "";    // TODO(aselle): maybe something else.
    op_to_option_["DELEGATE"] = "";  // TODO(aselle): maybe something else.
    op_to_option_["PLACEHOLDER_FOR_GREATER_OP_CODES"] = "";

    // Manually specified mappings from ops to "none" options -- these are
    // ops that do not yet have a corresponding Options message in the schema.
    // If these ops are assigned an Options message in the future, they need
    // to be updated here as well.
    op_to_option_["EMBEDDING_LOOKUP"] = "";
    op_to_option_["FLOOR"] = "";
    op_to_option_["CEIL"] = "";
    op_to_option_["HASHTABLE_LOOKUP"] = "";
    op_to_option_["LOGISTIC"] = "";
    op_to_option_["RELU"] = "";
    op_to_option_["RELU_N1_TO_1"] = "";
    op_to_option_["RELU_0_TO_1"] = "";
    op_to_option_["RELU6"] = "";
    op_to_option_["ROUND"] = "";
    op_to_option_["TANH"] = "";
    op_to_option_["PRELU"] = "";
    op_to_option_["SIN"] = "";
    op_to_option_["LOG"] = "";
    op_to_option_["SQRT"] = "";
    op_to_option_["RSQRT"] = "";
    op_to_option_["ELU"] = "";
    op_to_option_["REVERSE_SEQUENCE"] = "";
    op_to_option_["REAL"] = "";
    op_to_option_["IMAG"] = "";
    op_to_option_["COMPLEX_ABS"] = "";
    op_to_option_["BROADCAST_ARGS"] = "";
    op_to_option_["GELU"] = "";
    op_to_option_["DYNAMIC_UPDATE_SLICE"] = "";

    // TODO(aselle): These are undesirable hacks. Consider changing C structs.
    option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
    option_to_struct_["Conv2DOptions"] = "TfLiteConvParams";
    option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams";
    option_to_struct_["LocalResponseNormalizationOptions"] =
        "TfLiteLocalResponseNormParams";
    option_to_struct_["MirrorPadOptions"] = "TfLiteMirrorPaddingParams";
    // Now, for every op, try to find an option.
    bool fatal = false;
    for (const auto& op_name : ops_) {
      auto d = tflite::BuiltinOptionsTypeTable();
      std::string collapsed_option_name_guess =
          ToCollapsed(op_name) + "options";
      // O(n^2), but n is not that big.
      for (size_t i = 0; i < d->num_elems; i++) {
        std::string option_name = d->names[i];
        std::string collapsed_option_name = ToCollapsed(option_name);
        if (collapsed_option_name_guess == collapsed_option_name) {
          op_to_option_.insert(std::make_pair(op_name, option_name));
          break;
        }
      }
      auto it = op_to_option_.find(op_name);
      if (it == op_to_option_.end()) {
        std::cerr << "Didn't find option for " << op_name << std::endl;
        fatal = true;
      } else if (!it->second.empty()) {
        std::string option_name = it->second;

        if (option_to_struct_.find(option_name) == option_to_struct_.end()) {
          bool param_struct_found = false;
          std::string params_guess = std::string("TfLite") + option_name;
          size_t start = params_guess.find("Options");
          size_t len = strlen("Options");
          params_guess.replace(start, len, "Params");
          for (auto* param = param_structs; *param != nullptr; param++) {
            if (*param == params_guess) {
              param_struct_found = true;
              break;
            }
          }
          if (!param_struct_found) {
            std::cerr << "Failed to get param struct for option "
                      << option_name << std::endl;
          } else {
            option_to_struct_.insert(
                std::make_pair(option_name, params_guess));
          }
        }
      }
    }
    if (fatal) {
      exit(1);
    }
  }

 private:
  std::vector<std::string> ops_;
  std::unordered_map<std::string, std::string> op_to_option_;
  std::unordered_map<std::string, std::string> option_to_struct_;
  std::unordered_map<std::string, flatbuffers::TypeFunction>
      option_to_type_function_;
};
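
// Illustrative walk through the maps above (values traced by hand from the
// tables in this file): for the op "CONV_2D", op_to_option() yields
// "Conv2DOptions", option_to_struct() yields "TfLiteConvParams" (one of the
// manual mappings), and option_to_type_function() yields the flatbuffers type
// function that describes Conv2DOptions.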

void GenerateImportForResizeBilinearOp(FILE* fp) {
  fprintf(fp,
          " case BuiltinOperator_RESIZE_BILINEAR: {\n"
          " const auto* params = reinterpret_cast<const "
          "TfLiteResizeBilinearParams*>(builtin_op_data);\n"
          " auto union_type = CreateResizeBilinearOptions(*fbb, "
          "params->align_corners, params->half_pixel_centers).Union();\n"
          " return std::make_pair(BuiltinOptions_ResizeBilinearOptions, "
          "union_type);\n"
          " }\n break;\n");
}

void GenerateImportForVarHandleOp(FILE* fp) {
  fprintf(fp,
          " case BuiltinOperator_VAR_HANDLE: {\n"
          " const auto* params = reinterpret_cast<const "
          "TfLiteVarHandleParams*>(builtin_op_data);\n"
          " auto union_type = CreateVarHandleOptions(*fbb, "
          "fbb->CreateString(params->container), "
          "fbb->CreateString(params->shared_name)).Union();\n"
          " return std::make_pair(BuiltinOptions_VarHandleOptions, "
          "union_type);\n"
          " }\n break;\n");
}

// The Reshape op infers its output shape either from its parameters or from
// a shape tensor passed as an additional input. When that shape tensor is
// present, the shape parameter may be absent from this layer. So, when the op
// has more than one input and the shape parameter does not hold a valid
// value, we import an empty vector for the parameters.
void GenerateImportForReshapeOp(FILE* fp) {
  fprintf(fp,
          " case BuiltinOperator_RESHAPE: {\n"
          " const auto* params = reinterpret_cast<const "
          "TfLiteReshapeParams*>(builtin_op_data);\n"
          " flatbuffers::Offset<void> union_type;\n"
          " if (node.inputs->size > 1 && (params->num_dimensions <= 0 || "
          "params->num_dimensions > TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT))"
          " {\n"
          " union_type = CreateReshapeOptions(*fbb).Union();\n"
          " } else {\n"
          " auto val0 = fbb->CreateVector(std::vector<int>(params->shape, "
          "params->shape + params->num_dimensions));\n"
          " union_type = CreateReshapeOptions(*fbb, "
          "val0).Union();\n"
          " }\n"
          " return std::make_pair(BuiltinOptions_ReshapeOptions, "
          "union_type);\n"
          " }\n break;\n");
}

void GenerateImportForOp(FILE* fp, const std::string& op_name,
                         const std::string& option_name,
                         const std::string& option_type,
                         const flatbuffers::TypeTable* options,
                         const std::string& struct_name) {
  // Special-case ResizeBilinear, which has some deprecated fields.
  if (struct_name == "TfLiteResizeBilinearParams") {
    GenerateImportForResizeBilinearOp(fp);
    return;
  }

  if (struct_name == "TfLiteVarHandleParams") {
    GenerateImportForVarHandleOp(fp);
    return;
  }

  // Special-case Reshape, whose 'new_shape' field may be missing from the
  // parameters.
  if (struct_name == "TfLiteReshapeParams") {
    GenerateImportForReshapeOp(fp);
    return;
  }

  fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str());
  if (options->num_elems != 0) {
    fprintf(fp,
            " const auto* params = reinterpret_cast<const "
            "%s*>(builtin_op_data);\n",
            struct_name.c_str());
  }

  for (size_t i = 0; i < options->num_elems; i++) {
    std::string elem_name = options->names[i];
    bool is_int_vector = false;
    bool is_float_vector = false;
    std::string vector_name = elem_name;
    std::string vector_size;
    // TODO(aselle): Irregular naming in builtins.
    if (elem_name == "fused_activation_function")
      elem_name = "activation";
    else if (elem_name == "stride_w")
      elem_name = "stride_width";
    else if (elem_name == "stride_h")
      elem_name = "stride_height";
    else if (elem_name == "stride_d")
      elem_name = "stride_depth";
    else if (elem_name == "dilation_h_factor")
      elem_name = "dilation_height_factor";
    else if (elem_name == "dilation_w_factor")
      elem_name = "dilation_width_factor";
    else if (elem_name == "dilation_d_factor")
      elem_name = "dilation_depth_factor";
    else if (elem_name == "idx_out_type")
      elem_name = "index_out_type";

    // Vector fields are treated specially.
    if (elem_name == "new_shape") {
      is_int_vector = true;
      vector_name = "shape";
      vector_size = "num_dimensions";
    } else if (elem_name == "squeeze_dims") {
      is_int_vector = true;
      vector_size = "num_squeeze_dims";
    } else if (elem_name == "boundaries") {
      is_float_vector = true;
      vector_size = "num_boundaries";
    }

    if (is_int_vector) {
      fprintf(fp,
              " auto val%zu = fbb->CreateVector("
              "std::vector<int>(params->%s, params->%s + params->%s));\n",
              i, vector_name.c_str(), vector_name.c_str(),
              vector_size.c_str());
      continue;
    }

    if (is_float_vector) {
      fprintf(fp,
              " auto val%zu = fbb->CreateVector("
              "std::vector<float>(params->%s, params->%s + params->%s));\n",
              i, vector_name.c_str(), vector_name.c_str(),
              vector_size.c_str());
      continue;
    }

    flatbuffers::TypeCode code = options->type_codes[i];
    auto contained_type = code.sequence_ref != -1
                              ? options->type_refs[code.sequence_ref]
                              : nullptr;
    std::string mapper = "";
    if (contained_type == TensorTypeTypeTable) {
      mapper = "TfLiteTypeToSchemaType";
    } else if (contained_type == ActivationFunctionTypeTypeTable) {
      mapper = "TfLiteActivationToSchemaActivation";
    } else if (contained_type == PaddingTypeTable) {
      mapper = "TfLitePaddingToSchemaPadding";
    } else if (contained_type == FullyConnectedOptionsWeightsFormatTypeTable) {
      mapper = "FullyConnectedOptionsWeightsFormatToSchema";
    } else if (contained_type == LSTMKernelTypeTypeTable) {
      mapper = "LSTMKernelTypeToSchema";
    } else if (contained_type == LSHProjectionTypeTypeTable) {
      mapper = "LSHProjectionTypeToSchema";
    } else if (contained_type == MirrorPadModeTypeTable) {
      mapper = "MirrorPaddingModeToSchema";
    } else if (contained_type == CombinerTypeTypeTable) {
      mapper = "CombinerTypeToSchema";
    }

    fprintf(fp,
            " auto val%zu = "
            "%s(params->%s);\n",
            i, mapper.c_str(), elem_name.c_str());
  }
  fprintf(fp, " auto union_type = Create%s(*fbb", option_name.c_str());
  for (size_t i = 0; i < options->num_elems; i++) {
    fprintf(fp, ", val%zu", i);
  }
  fprintf(fp, ").Union();\n");
  fprintf(fp, " return std::make_pair(%s, union_type);\n",
          option_type.c_str());
  fprintf(fp, " }\n break;\n");
}
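
// Sketch of what the code emitted above looks like for a simple op, assuming
// an AddOptions whose only schema field is fused_activation_function (the
// field layout is an assumption for illustration, not taken from the schema):
//   case BuiltinOperator_ADD: {
//     const auto* params =
//         reinterpret_cast<const TfLiteAddParams*>(builtin_op_data);
//     auto val0 = TfLiteActivationToSchemaActivation(params->activation);
//     auto union_type = CreateAddOptions(*fbb, val0).Union();
//     return std::make_pair(BuiltinOptions_AddOptions, union_type);
//   }
//   break;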

void GenerateImport(OpOptionData* option, FILE* fp) {
  std::unordered_set<std::string> ignores;
  ignores.insert("CONCAT_EMBEDDINGS");
  ignores.insert("CALL");

  // Group together all ops that either have no option struct or are ignored,
  // so they can share a single "no options" return.
  for (const auto& op_name : option->ops()) {
    auto option_it = option->op_to_option().find(op_name);
    if (!option_it->second.empty() && ignores.find(op_name) == ignores.end())
      continue;
    fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
  }
  fprintf(fp,
          " return std::make_pair(BuiltinOptions_NONE, "
          "flatbuffers::Offset<void>());\n break;\n");

  // Iterate over each op.
  for (const auto& op_name : option->ops()) {
    if (ignores.find(op_name) != ignores.end()) continue;
    // Look up the option and struct names, continuing if not found.
    auto option_it = option->op_to_option().find(op_name);
    if (option_it->second.empty()) continue;
    std::string option_name = option_it->second;
    std::string option_type = "BuiltinOptions_" + option_name;
    auto option_func_it = option->option_to_type_function().find(option_name);
    if (option_func_it == option->option_to_type_function().end()) continue;
    auto struct_name_it = option->option_to_struct().find(option_name);
    if (struct_name_it == option->option_to_struct().end()) {
      // If there is no C struct, the option had better take no arguments.
      auto type_info = option_func_it->second();
      if (type_info->num_elems != 0) {
        // The schema option has arguments, so there should be a
        // corresponding C struct.
        fprintf(stderr,
                "Op %s uses option struct %s which has no builtin struct\n",
                op_name.c_str(), option_name.c_str());
        exit(1);
      }
      fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
      fprintf(fp, " return std::make_pair(%s, Create%s(*fbb).Union());",
              option_type.c_str(), option_name.c_str());
    } else {
      // If there is a C struct, we need to assign all its fields.
      auto struct_name = struct_name_it->second;
      GenerateImportForOp(fp, op_name, option_name, option_type,
                          option_func_it->second(), struct_name);
    }
  }
  // TODO(aselle): Handle unhandled cases more gracefully.
  fprintf(fp,
          "default: return std::make_pair(BuiltinOptions_NONE, "
          "flatbuffers::Offset<void>());\n break;\n");
}

}  // namespace tflite

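// Generates the body of a switch over BuiltinOperator that converts
// TfLite*Params structs back into flatbuffer option unions, writing it to the
// file named by argv[1]. Intended to be run as a build step; the output file
// name below is only an illustrative example:
//   option_writer_generator option_writer_generated.h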
int main(int argc, char* argv[]) {
  tflite::OpOptionData option;
  if (argc != 2) {
    fprintf(stderr, "Usage: %s <fname out>\n", argv[0]);
    return 1;
  }
  FILE* fp = fopen(argv[1], "w");
  if (fp == nullptr) {
    fprintf(stderr, "Could not open '%s' for writing\n", argv[1]);
    return 1;
  }
  tflite::GenerateImport(&option, fp);
  fclose(fp);
  return 0;
}