xref: /aosp_15_r20/external/executorch/backends/cadence/aot/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import enum
import logging
import operator
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from executorch.exir import ExecutorchProgramManager, memory
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from tabulate import tabulate

from torch.ao.quantization.quantize_pt2e import _QUANT_OPS as quant_ops


# Check if the model is quantized, by looking at the graph and finding quant/dequant ops
def model_is_quantized(model: torch.nn.Module) -> bool:
    # Quantized models have to be GraphModules already, from prepare/convert calls.
    # Return False if the model is not a GraphModule.
    if not isinstance(model, torch.fx.GraphModule):
        return False

    # Walk through the graph and look for quant/dequant ops
    for op in quant_ops:
        if model.graph.find_nodes(op="call_function", target=op):
            return True
    return False


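# Illustrative use of model_is_quantized (added sketch, not part of the original module):
# an eager nn.Module is not a GraphModule, so it reports False, while a graph produced by
# prepare_pt2e/convert_pt2e contains quant/dequant call_function nodes and reports True.
#
#   model_is_quantized(torch.nn.Linear(4, 4))      # False
#   model_is_quantized(convert_pt2e(prepared_gm))  # True; `prepared_gm` assumed to come from prepare_pt2e

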
# Get the output size of a 1D convolution given the input size and parameters
def get_conv1d_output_size(
    in_size: torch.Size,
    out_channels: int,
    stride: int,
    padding: int,
    dilation: int,
    kernel_size: int,
    channel_last: bool,
) -> torch.Size:
    assert len(in_size) == 3
    if channel_last:
        N, L, C = in_size
    else:
        N, C, L = in_size

    # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
    lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

    if channel_last:
        return torch.Size((N, lout, out_channels))
    return torch.Size((N, out_channels, lout))


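# Worked example for get_conv1d_output_size (added illustration, not part of the original
# module): an NCL input of shape (1, 16, 100) with out_channels=32, stride=2, padding=1,
# dilation=1, kernel_size=3 gives lout = (100 + 2*1 - 1*(3 - 1) - 1) // 2 + 1 = 50, so
#
#   get_conv1d_output_size(torch.Size((1, 16, 100)), 32, 2, 1, 1, 3, False)
#
# returns torch.Size([1, 32, 50]), matching torch.nn.Conv1d(16, 32, 3, 2, 1) on that input.

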
# Get the output size of a 2D convolution given the input size and parameters
def get_conv2d_output_size(
    in_size: torch.Size,
    out_channels: int,
    stride: Tuple[int, int],
    padding: Tuple[int, int],
    dilation: Tuple[int, int],
    kernel_size: List[int],
    channel_last: bool,
) -> torch.Size:
    assert len(in_size) == 4
    if channel_last:
        N, H, W, C = in_size
    else:
        N, C, H, W = in_size

    # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    hout = (
        H + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
    ) // stride[0] + 1
    wout = (
        W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
    ) // stride[1] + 1
    if channel_last:
        return torch.Size((N, hout, wout, out_channels))
    return torch.Size((N, out_channels, hout, wout))


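# Worked example for get_conv2d_output_size (added illustration, not part of the original
# module): an NCHW input of shape (1, 3, 224, 224) with out_channels=8, stride=(2, 2),
# padding=(1, 1), dilation=(1, 1), kernel_size=[3, 3] gives
# hout = wout = (224 + 2*1 - 1*(3 - 1) - 1) // 2 + 1 = 112, so
#
#   get_conv2d_output_size(torch.Size((1, 3, 224, 224)), 8, (2, 2), (1, 1), (1, 1), [3, 3], False)
#
# returns torch.Size([1, 8, 112, 112]).

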
# Return the overload packet for the edge op
def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket:
    edge_op_namespace, edge_op_name = (
        edge_op.namespace,
        edge_op._schema.name.split("::")[1],
    )
    edge_op_overload_packet = getattr(
        getattr(exir_ops.edge, edge_op_namespace), edge_op_name
    )
    return edge_op_overload_packet


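# Illustrative use of get_edge_overload_packet (added sketch, not part of the original
# module): for an overload such as exir_ops.edge.aten.add.Tensor, the namespace is
# "aten" and the schema name is "aten::add", so the function resolves and returns the
# packet exir_ops.edge.aten.add, which groups all overloads of that op.

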
# Return a dict mapping each op in a graph module to the number of times it occurs
def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]:
    freq = {}
    # Loop over nodes to count the number of times each op occurs
    for node in graph_module.graph.nodes:
        if node.op == "call_function":
            # Ignore getitem, alloc and view cases, we only want actual operations
            if (
                node.target == operator.getitem
                or node.target.__name__ == "alloc"
                or node.target == memory.view
            ):
                continue
            # If the op is already present, increment the count
            if node.target._name in freq:
                freq[node.target._name] += 1
            # else, add a new entry
            else:
                freq[node.target._name] = 1
    return freq


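# Illustrative result of get_ops_count (added sketch, not part of the original module):
# the keys come from node.target._name, so the returned dict looks something like
#
#   {"aten::add.Tensor": 3, "aten::convolution": 2, "cadence::quantized_linear": 1}
#
# The exact key format depends on the dialect of the graph being inspected.

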
# Print how many times each op occurs in the to_edge graph module and in the final
# graph module, and list the ops that were removed between the two.
def print_ops_info(
    to_edge_gm: torch.fx.GraphModule,
    final_gm: torch.fx.GraphModule,
) -> None:
    to_edge_ops_count = get_ops_count(to_edge_gm)
    final_ops_count = get_ops_count(final_gm)

    removed_ops = []
    # Get the ops that are present in the to_edge graph but not in the final graph
    for k in to_edge_ops_count:
        if k not in final_ops_count:
            removed_ops.append(k)

    # Create a list of [op, final count, to_edge count] rows to pass to tabulate
    ops_count = [
        [
            op,
            final_ops_count[op],
            to_edge_ops_count.get(op, 0),
        ]
        for op in final_ops_count
    ]
    sorted_ops_count = sorted(ops_count, key=lambda x: x[1], reverse=True)

    # Create the same rows for the ops deleted between the to_edge and final graphs
    removed_ops_count = [
        [
            op,
            0,
            to_edge_ops_count.get(op, 0),
        ]
        for op in removed_ops
    ]

    # Print the final ops and their counts in a tabular format
    logging.info(
        tabulate(
            sorted_ops_count,
            headers=[
                "Final Operators                                    ",  # one character longer than the longest op name
                "Final Graph",
                "To_edge Graph",
            ],
            tablefmt="outline",
        )
    )

    # Print the removed ops and their counts in a tabular format (if any)
    if removed_ops:
        logging.info(
            tabulate(
                removed_ops_count,
                headers=[
                    "Deleted Operators                                  ",  # one character longer than the longest op name
                    "Final Graph",
                    "To_edge Graph",
                ],
                tablefmt="outline",
            )
        )


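# Illustrative call site for print_ops_info (added sketch, not part of the original
# module); the variable names are placeholders for whatever the export pipeline produces:
#
#   edge_prog = to_edge(torch.export.export(model, example_inputs))
#   exec_prog = edge_prog.to_executorch()
#   print_ops_info(edge_prog.exported_program().graph_module,
#                  exec_prog.exported_program().graph_module)

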
def model_gm_has_SDPA(model_gm: torch.fx.GraphModule) -> bool:
    for node in model_gm.graph.nodes:
        if node.op == "call_function":
            if node.target == torch.ops.aten.scaled_dot_product_attention.default:
                return True
    return False


def save_pte_program(
    prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
) -> None:
    if model_name.endswith(".pte"):
        filename = model_name
    else:
        filename = os.path.join(output_dir, f"{model_name}.pte")

    try:
        with open(filename, "wb") as file:
            prog.write_to_file(file)
            logging.info(f"Saved exported program to {filename}")
    except Exception as e:
        logging.error(f"Error while saving to {filename}: {e}")


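# Illustrative use of save_pte_program (added sketch, not part of the original module):
# given an ExecutorchProgramManager `exec_prog` returned by to_executorch(),
#
#   save_pte_program(exec_prog, "add_model", "/tmp")
#
# writes the serialized program to /tmp/add_model.pte; a model_name that already ends in
# ".pte" is treated as a full path and output_dir is ignored.

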
def save_bpte_program(
    buffer: bytes,
    model_name: str,
    output_dir: str = "",
) -> None:
    if model_name.endswith(".bpte"):
        filename = model_name
    else:
        filename = os.path.join(output_dir, f"{model_name}.bpte")
    try:
        with open(filename, "wb") as f:
            f.write(buffer)
        logging.info(f"Saved exported program to {filename}")
    except Exception as e:
        logging.error(f"Error while saving to {filename}: {e}")


@dataclass
class MemoryConfig:
    memory_sizes: List[int]

    # Optional fields for logs
    memory_names: Optional[List[str]] = None
    base_addrs: Optional[List[int]] = None
    memory_xml_path: Optional[str] = None
    MemorySpace: Optional[enum.Enum] = None

    # get num memories indexed from 1..N, compatible with EXIR's spec.mem_id
    def get_num_memories(self) -> int:
        return len(self.memory_sizes) + 1

    # memory_space module provides num_memories indexed 0..num_memories-1.
    def get_size(self, exir_id: int) -> int:
        return self.memory_sizes[exir_id - 1]


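# Worked example for MemoryConfig (added illustration, not part of the original module):
# with two memories of 512 KiB and 2 MiB,
#
#   config = MemoryConfig(memory_sizes=[0x80000, 0x200000])
#   config.get_num_memories()  # 3, so EXIR mem_ids 1 and 2 index the two memories
#   config.get_size(1)         # 0x80000
#   config.get_size(2)         # 0x200000

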
# Return default memory config for the backend
def get_default_memory_config() -> MemoryConfig:
    return MemoryConfig(memory_sizes=[0x1000000000])