# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions used by multiple tflite test files."""

from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import schema_util
from tensorflow.lite.tools import visualize


def get_ops_list(model_data):
  """Returns a set of ops in the tflite model data.

  Args:
    model_data: Serialized TFLite flatbuffer model, parsed with
      `schema_fb.Model.GetRootAsModel`.

  Returns:
    A set of op names used by operators across all subgraphs. Custom ops
    contribute their custom code decoded as UTF-8; builtin ops contribute the
    name returned by `visualize.BuiltinCodeToName`.
  """
  model = schema_fb.Model.GetRootAsModel(model_data, 0)
  op_set = set()

  for subgraph_idx in range(model.SubgraphsLength()):
    subgraph = model.Subgraphs(subgraph_idx)
    for op_idx in range(subgraph.OperatorsLength()):
      op = subgraph.Operators(op_idx)
      opcode = model.OperatorCodes(op.OpcodeIndex())
      # Builtin-code lookup is delegated to schema_util rather than reading a
      # field of the operator code directly.
      builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
      if builtin_code == schema_fb.BuiltinOperator.CUSTOM:
        # Custom ops are identified by their custom code string, not by a
        # builtin enum value.
        op_set.add(opcode.CustomCode().decode("utf-8"))
      else:
        op_set.add(visualize.BuiltinCodeToName(builtin_code))
  return op_set


def get_output_shapes(model_data):
  """Returns a list of output shapes in the tflite model data.

  Args:
    model_data: Serialized TFLite flatbuffer model, parsed with
      `schema_fb.Model.GetRootAsModel`.

  Returns:
    A flat list with one entry per subgraph output, ordered by subgraph and
    then by output position. Each entry is that output tensor's shape as a
    list of Python ints; a tensor whose shape vector is absent yields `[]`.
  """
  model = schema_fb.Model.GetRootAsModel(model_data, 0)

  output_shapes = []
  for subgraph_idx in range(model.SubgraphsLength()):
    subgraph = model.Subgraphs(subgraph_idx)
    for output_idx in range(subgraph.OutputsLength()):
      output_tensor_idx = subgraph.Outputs(output_idx)
      output_tensor = subgraph.Tensors(output_tensor_idx)
      # Read dimensions element-wise instead of ShapeAsNumpy().tolist(): the
      # flatbuffers-generated ShapeAsNumpy() returns the int 0 (not an array)
      # when the shape vector is absent, and `0.tolist()` would raise
      # AttributeError. ShapeLength() is 0 in that case, producing [].
      num_dims = output_tensor.ShapeLength()
      output_shapes.append([output_tensor.Shape(d) for d in range(num_dims)])

  return output_shapes