#!/usr/bin/env python3
# Copyright (c) 2021 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import json
import logging
import os
import sys
from argparse import ArgumentParser

import tflite

# Make the sibling ``utils`` package importable when the script is run directly
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")

from utils.model_identification import identify_model_type
from utils.tflite_helpers import tflite_op2acl, tflite_typecode2name, tflite_typecode2aclname

SUPPORTED_MODEL_TYPES = ["tflite"]
logger = logging.getLogger("report_model_ops")


def _tflite_builtin_code(op_code):
    """
    Return the builtin operator code stored in a TFLite ``OperatorCode`` table entry.

    Mirrors TFLite's own ``GetBuiltinCode``: the operator code may live in either the
    ``builtin_code`` or the ``deprecated_builtin_code`` field depending on the schema version
    the model was written with, so the larger of the two is used. Older releases of the
    ``tflite`` package do not expose ``DeprecatedBuiltinCode``; then ``BuiltinCode`` is used as-is.
    """
    code = op_code.BuiltinCode()
    deprecated_getter = getattr(op_code, "DeprecatedBuiltinCode", None)
    if deprecated_getter is not None:
        code = max(code, deprecated_getter())
    return code


def get_ops_types_from_tflite_graph(model):
    """
    Helper function that extracts operator related meta-data from a TFLite model

    Parameters
    ----------
    model: str
        Path to the TFLite model to analyse

    Returns
    ----------
    supported_ops, unsupported_ops, supported_data_types, unsupported_data_types: tuple
        A tuple of four sets:
        - ComputeLibrary operators that the model's TFLite operators map to
        - TFLite operator names without a ComputeLibrary mapping
        - ComputeLibrary names of the tensor data types used by the model
        - TFLite names of tensor data types that ComputeLibrary does not support
    """

    logger.debug(f"Analysing TFLite model '{model}'!")

    with open(model, "rb") as f:
        buf = f.read()
    tflite_model = tflite.Model.GetRootAsModel(buf, 0)

    # Extract unique operator names from the model's operator-code table
    unique_ops = {
        tflite.opcode2name(_tflite_builtin_code(tflite_model.OperatorCodes(op_id)))
        for op_id in range(tflite_model.OperatorCodesLength())
    }

    # Extract the data-types of every tensor in every subgraph
    supported_data_types = set()
    unsupported_data_types = set()
    for subgraph_id in range(tflite_model.SubgraphsLength()):
        subgraph = tflite_model.Subgraphs(subgraph_id)
        for tensor_id in range(subgraph.TensorsLength()):
            type_code = subgraph.Tensors(tensor_id).Type()
            try:
                supported_data_types.add(tflite_typecode2aclname(type_code))
            except ValueError:
                type_name = tflite_typecode2name(type_code)
                # Warn once per unsupported type instead of once per tensor
                if type_name not in unsupported_data_types:
                    logger.warning(f"Data type {type_name} is not supported by ComputeLibrary")
                unsupported_data_types.add(type_name)

    # Perform mapping between TfLite ops to ComputeLibrary ones
    supported_ops = set()
    unsupported_ops = set()
    # Iterate in sorted order so warnings are emitted deterministically
    for top in sorted(unique_ops):
        try:
            supported_ops.add(tflite_op2acl(top))
        except ValueError:
            unsupported_ops.add(top)
            logger.warning(f"Operator {top} does not have ComputeLibrary mapping")

    return (supported_ops, unsupported_ops, supported_data_types, unsupported_data_types)


def extract_model_meta(model, model_type):
    """
    Function that calls the appropriate model parser to extract model related meta-data
    Supported parsers: TFLite

    Parameters
    ----------
    model: str
        Path to model that we want to analyze
    model_type: str
        Type of the model (as reported by ``identify_model_type``)

    Returns
    ----------
    supported_ops, unsupported_ops, supported_data_types, unsupported_data_types: tuple
        Four sets as produced by the model-specific parser. All four are empty when
        ``model_type`` is not supported.
    """

    if model_type == "tflite":
        return get_ops_types_from_tflite_graph(model)

    logger.warning(f"Model type '{model_type}' is unsupported!")
    # Return four empty sets (not an empty tuple) so callers can always unpack the result
    return (set(), set(), set(), set())


def generate_build_config(ops, data_types, data_layouts):
    """
    Function that generates a compatible ComputeLibrary operator-based build configuration

    Parameters
    ----------
    ops: set
        Set with the operators to add in the build configuration
    data_types: set
        Set with the data types to add in the build configuration
    data_layouts: set
        Set with the data layouts to add in the build configuration

    Returns
    ----------
    config_data: dict
        Dictionary compatible with ComputeLibrary, with sorted lists under the
        "operators", "data_types" and "data_layouts" keys
    """
    # Sort the entries: set iteration order varies between runs, sorting keeps the JSON reproducible
    config_data = {}
    config_data["operators"] = sorted(ops)
    config_data["data_types"] = sorted(data_types)
    config_data["data_layouts"] = sorted(data_layouts)

    return config_data


def main(argv=None):
    """
    Command-line entry point: analyse the given models, log the operators, data types and
    data layouts they use and optionally write a ComputeLibrary build configuration file.

    Parameters
    ----------
    argv: list of str, optional
        Command-line arguments; ``sys.argv[1:]`` is used when None
    """
    parser = ArgumentParser(
        description="""Report map of operations in a list of models.
        The script consumes deep learning models and reports the type of operations and data-types used
        Supported model types: TFLite """
    )

    parser.add_argument(
        "-m",
        "--models",
        nargs="+",
        required=True,
        type=str,
        help=f"List of models; supported model types: {SUPPORTED_MODEL_TYPES}",
    )
    parser.add_argument("-D", "--debug", action="store_true", help="Enable script debugging output")
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        help="JSON configuration file that can be used for custom ComputeLibrary builds",
    )
    args = parser.parse_args(argv)

    # Setup Logger
    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    # Extract operator mapping
    final_supported_ops = set()
    final_unsupported_ops = set()
    final_supported_dts = set()
    final_unsupported_dts = set()
    final_layouts = {"nhwc"}  # Data layout for TFLite is always NHWC
    for model in args.models:
        logger.debug(f"Starting analyzing {model} model")

        model_type = identify_model_type(model)
        (
            supported_model_ops,
            unsupported_model_ops,
            supported_model_dts,
            unsupported_model_dts,
        ) = extract_model_meta(model, model_type)
        final_supported_ops.update(supported_model_ops)
        final_unsupported_ops.update(unsupported_model_ops)
        final_supported_dts.update(supported_model_dts)
        final_unsupported_dts.update(unsupported_model_dts)

    logger.info("=== Supported Operators")
    logger.info(final_supported_ops)
    if final_unsupported_ops:
        logger.info("=== Unsupported Operators")
        logger.info(final_unsupported_ops)
    logger.info("=== Data Types")
    logger.info(final_supported_dts)
    if final_unsupported_dts:
        logger.info("=== Unsupported Data Types")
        logger.info(final_unsupported_dts)
    logger.info("=== Data Layouts")
    logger.info(final_layouts)

    # Generate JSON file
    if args.config:
        logger.debug("Generating JSON build configuration file")
        config_data = generate_build_config(final_supported_ops, final_supported_dts, final_layouts)
        with open(args.config, "w") as f:
            json.dump(config_data, f)


if __name__ == "__main__":
    main()