xref: /aosp_15_r20/external/ComputeLibrary/python/scripts/report-model-ops/report_model_ops.py (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1#!/usr/bin/env python3
2# Copyright (c) 2021 Arm Limited.
3#
4# SPDX-License-Identifier: MIT
5#
6# Permission is hereby granted, free of charge, to any person obtaining a copy
7# of this software and associated documentation files (the "Software"), to
8# deal in the Software without restriction, including without limitation the
9# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10# sell copies of the Software, and to permit persons to whom the Software is
11# furnished to do so, subject to the following conditions:
12#
13# The above copyright notice and this permission notice shall be included in all
14# copies or substantial portions of the Software.
15#
16# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22# SOFTWARE.
23import json
24import logging
25import os
26import sys
27from argparse import ArgumentParser
28
29import tflite
30
31sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
32
33from utils.model_identification import identify_model_type
34from utils.tflite_helpers import tflite_op2acl, tflite_typecode2name, tflite_typecode2aclname
35
36SUPPORTED_MODEL_TYPES = ["tflite"]
37logger = logging.getLogger("report_model_ops")
38
39
def get_ops_types_from_tflite_graph(model):
    """
    Helper function that extracts operator related meta-data from a TFLite model

    Parameters
    ----------
    model: str
        Path to the TFLite model file to analyse

    Returns
    ----------
    supported_ops, unsupported_ops, supported_data_types, unsupported_data_types: tuple
        A tuple of four sets:
        the ComputeLibrary operators the model's operators map to,
        the TFLite operator names for which the mapping raised ValueError,
        the ComputeLibrary data-type names the tensor types map to,
        the TFLite data-type names for which the mapping raised ValueError
    """

    logger.debug(f"Analysing TFLite mode '{model}'!")

    with open(model, "rb") as f:
        buf = f.read()
    # Keep the parsed graph in its own name instead of rebinding the path argument
    graph = tflite.Model.GetRootAsModel(buf, 0)

    # Extract unique operators
    unique_ops = {
        tflite.opcode2name(graph.OperatorCodes(op_id).BuiltinCode())
        for op_id in range(graph.OperatorCodesLength())
    }

    # Collect the distinct tensor type codes first, so that each type is
    # translated (and, when unsupported, reported) once rather than per tensor
    type_codes = set()
    for subgraph_id in range(graph.SubgraphsLength()):
        subgraph = graph.Subgraphs(subgraph_id)
        for tensor_id in range(subgraph.TensorsLength()):
            type_codes.add(subgraph.Tensors(tensor_id).Type())

    # Extract IO data-types
    supported_data_types = set()
    unsupported_data_types = set()
    for type_code in sorted(type_codes):
        try:
            supported_data_types.add(tflite_typecode2aclname(type_code))
        except ValueError:
            type_name = tflite_typecode2name(type_code)
            unsupported_data_types.add(type_name)
            logger.warning(f"Data type {type_name} is not supported by ComputeLibrary")

    # Perform mapping between TfLite ops to ComputeLibrary ones
    supported_ops = set()
    unsupported_ops = set()
    for top in unique_ops:
        try:
            supported_ops.add(tflite_op2acl(top))
        except ValueError:
            unsupported_ops.add(top)
            logger.warning(f"Operator {top} does not have ComputeLibrary mapping")

    return (supported_ops, unsupported_ops, supported_data_types, unsupported_data_types)
88
89
def extract_model_meta(model, model_type):
    """
    Function that calls the appropriate model parser to extract model related meta-data
    Supported parsers: TFLite

    Parameters
    ----------
    model: str
        Path to model that we want to analyze
    model_type: str
        Type of the model; only "tflite" is dispatched to a parser

    Returns
    ----------
    supported_ops, unsupported_ops, supported_data_types, unsupported_data_types: tuple
        A tuple with the sets of unique operator types and data-types that are present in the model.
        For any other model type a warning is logged and four empty sets are returned.
    """

    if model_type == "tflite":
        return get_ops_types_from_tflite_graph(model)

    logger.warning(f"Model type '{model_type}' is unsupported!")
    # Keep the 4-tuple shape so that callers unpacking the result do not fail
    # with "not enough values to unpack" on an unsupported model
    return (set(), set(), set(), set())
113
114
def generate_build_config(ops, data_types, data_layouts):
    """
    Function that generates a compatible ComputeLibrary operator-based build configuration

    Parameters
    ----------
    ops: set
        Set with the operators to add in the build configuration
    data_types: set
        Set with the data types to add in the build configuration
    data_layouts: set
        Set with the data layouts to add in the build configuration

    Returns
    ----------
    config_data: dict
        Dictionary with the keys "operators", "data_types" and "data_layouts",
        each holding a sorted list of the corresponding input entries
    """
    # Set iteration order is not stable across interpreter runs (string hashing
    # is randomised), so sort to make the generated configuration reproducible
    return {
        "operators": sorted(ops),
        "data_types": sorted(data_types),
        "data_layouts": sorted(data_layouts),
    }
139
140
if __name__ == "__main__":
    # Command line: one or more model paths, optional debug logging and an
    # optional output path for the generated JSON build configuration
    parser = ArgumentParser(
        description="""Report map of operations in a list of models.
            The script consumes deep learning models and reports the type of operations and data-types used
            Supported model types: TFLite """
    )

    parser.add_argument(
        "-m",
        "--models",
        nargs="+",
        required=True,
        type=str,
        help=f"List of models; supported model types: {SUPPORTED_MODEL_TYPES}",
    )
    parser.add_argument("-D", "--debug", action="store_true", help="Enable script debugging output")
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        help="JSON configuration file used that can be used for custom ComputeLibrary builds",
    )
    args = parser.parse_args()

    # Setup Logger
    logging_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=logging_level)

    # Extract operator mapping, accumulating the results over all models
    final_supported_ops = set()
    final_unsupported_ops = set()
    final_supported_dts = set()
    final_unsupported_dts = set()
    final_layouts = {"nhwc"}  # Data layout for TFLite is always NHWC
    for model in args.models:
        logger.debug(f"Starting analyzing {model} model")

        model_type = identify_model_type(model)
        model_meta = extract_model_meta(model, model_type)
        if len(model_meta) != 4:
            # No usable meta-data for this model (anything other than the
            # expected 4-tuple): skip it instead of failing on the unpack below
            continue

        supported_model_ops, unsupported_model_ops, supported_model_dts, unsupported_model_dts = model_meta
        final_supported_ops.update(supported_model_ops)
        final_unsupported_ops.update(unsupported_model_ops)
        final_supported_dts.update(supported_model_dts)
        final_unsupported_dts.update(unsupported_model_dts)

    # Report the aggregated results; the "unsupported" sections are only
    # printed when there is something to show
    logger.info("=== Supported Operators")
    logger.info(final_supported_ops)
    if final_unsupported_ops:
        logger.info("=== Unsupported Operators")
        logger.info(final_unsupported_ops)
    logger.info("=== Data Types")
    logger.info(final_supported_dts)
    if final_unsupported_dts:
        logger.info("=== Unsupported Data Types")
        logger.info(final_unsupported_dts)
    logger.info("=== Data Layouts")
    logger.info(final_layouts)

    # Generate JSON file
    if args.config:
        logger.debug("Generating JSON build configuration file")
        config_data = generate_build_config(final_supported_ops, final_supported_dts, final_layouts)
        with open(args.config, "w") as f:
            json.dump(config_data, f)
206