xref: /aosp_15_r20/external/executorch/backends/cadence/aot/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport enum
10*523fa7a6SAndroid Build Coastguard Workerimport logging
11*523fa7a6SAndroid Build Coastguard Workerimport operator
12*523fa7a6SAndroid Build Coastguard Workerimport os
13*523fa7a6SAndroid Build Coastguard Workerfrom dataclasses import dataclass
14*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, List, Optional, Tuple
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Workerimport torch
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import ExecutorchProgramManager, memory
19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects._ops import ops as exir_ops
20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
21*523fa7a6SAndroid Build Coastguard Workerfrom tabulate import tabulate
22*523fa7a6SAndroid Build Coastguard Worker
23*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.quantize_pt2e import _QUANT_OPS as quant_ops
24*523fa7a6SAndroid Build Coastguard Worker
25*523fa7a6SAndroid Build Coastguard Worker
def model_is_quantized(model: torch.nn.Module) -> bool:
    """Report whether ``model`` contains a quantize/dequantize op in its graph.

    Args:
        model: The module to inspect.

    Returns:
        ``True`` when ``model`` is a ``torch.fx.GraphModule`` whose graph holds
        at least one ``call_function`` node targeting an op in ``quant_ops``.
        ``False`` otherwise, including for any module that is not a
        ``GraphModule``.
    """
    # Quantized models have to be GraphModules already (they come out of the
    # prepare/convert calls), so anything else cannot be quantized.
    if isinstance(model, torch.fx.GraphModule):
        # Stop at the first quant/dequant op that has a matching node.
        return any(
            model.graph.find_nodes(op="call_function", target=quant_op)
            for quant_op in quant_ops
        )
    return False
38*523fa7a6SAndroid Build Coastguard Worker
39*523fa7a6SAndroid Build Coastguard Worker
def get_conv1d_output_size(
    in_size: torch.Size,
    out_channels: int,
    stride: int,
    padding: int,
    dilation: int,
    kernel_size: int,
    channel_last: bool,
) -> torch.Size:
    """Compute the output size of a 1D convolution.

    Args:
        in_size: Rank-3 input size, read as ``(N, L, C)`` when ``channel_last``
            is true and as ``(N, C, L)`` otherwise.
        out_channels: Number of output channels.
        stride: Convolution stride.
        padding: Padding added to both ends of the length axis.
        dilation: Spacing between kernel elements.
        kernel_size: Size of the convolution kernel.
        channel_last: Whether the channel axis is the last axis of the input
            and of the returned size.

    Returns:
        ``(N, Lout, out_channels)`` when ``channel_last`` is true, otherwise
        ``(N, out_channels, Lout)``.
    """
    assert len(in_size) == 3
    batch = in_size[0]
    # The length axis is index 1 for (N, L, C) and index 2 for (N, C, L).
    length = in_size[1] if channel_last else in_size[2]

    # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
    kernel_extent = dilation * (kernel_size - 1) + 1
    out_length = (length + 2 * padding - kernel_extent) // stride + 1

    if channel_last:
        out_shape = (batch, out_length, out_channels)
    else:
        out_shape = (batch, out_channels, out_length)
    return torch.Size(out_shape)
63*523fa7a6SAndroid Build Coastguard Worker
def get_conv2d_output_size(
    in_size: torch.Size,
    out_channels: int,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    kernel_size: List[int],
    channel_last: bool,
) -> torch.Size:
    """Compute the output size of a 2D convolution.

    ``stride``, ``padding``, ``dilation`` and ``kernel_size`` are each indexed
    at ``[0]`` for the height axis and ``[1]`` for the width axis.

    Args:
        in_size: Rank-4 input size, read as ``(N, H, W, C)`` when
            ``channel_last`` is true and as ``(N, C, H, W)`` otherwise.
        out_channels: Number of output channels.
        stride: Per-axis convolution stride.
        padding: Per-axis padding added to both sides of the axis.
        dilation: Per-axis spacing between kernel elements.
        kernel_size: Per-axis kernel size.
        channel_last: Whether the channel axis is the last axis of the input
            and of the returned size.

    Returns:
        ``(N, Hout, Wout, out_channels)`` when ``channel_last`` is true,
        otherwise ``(N, out_channels, Hout, Wout)``.
    """
    assert len(in_size) == 4
    batch = in_size[0]
    # (H, W) occupy axes 1-2 in channel-last layout and axes 2-3 otherwise.
    spatial = in_size[1:3] if channel_last else in_size[2:4]

    # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    # The same formula is applied to the height axis (0) and the width axis (1).
    hout, wout = (
        (spatial[axis] + 2 * padding[axis] - dilation[axis] * (kernel_size[axis] - 1) - 1)
        // stride[axis]
        + 1
        for axis in (0, 1)
    )

    if not channel_last:
        return torch.Size((batch, out_channels, hout, wout))
    return torch.Size((batch, hout, wout, out_channels))
90*523fa7a6SAndroid Build Coastguard Worker
91*523fa7a6SAndroid Build Coastguard Worker
def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket:
    """Return the overload packet that the given edge op belongs to.

    Args:
        edge_op: The edge op overload to look up.

    Returns:
        The attribute of ``exir_ops.edge.<namespace>`` named after the op,
        where the namespace is ``edge_op.namespace`` and the op name is the
        part of ``edge_op._schema.name`` that follows ``"::"``.
    """
    namespace = edge_op.namespace
    # The schema name is split on "::" and the second piece is the op name.
    op_name = edge_op._schema.name.split("::")[1]
    namespace_ops = getattr(exir_ops.edge, namespace)
    return getattr(namespace_ops, op_name)
102*523fa7a6SAndroid Build Coastguard Worker
103*523fa7a6SAndroid Build Coastguard Worker
def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]:
    """Count how many times each op occurs in a graph module.

    Only ``call_function`` nodes are counted. Nodes whose target is
    ``operator.getitem``, whose target is named ``"alloc"``, or whose target is
    ``memory.view`` are skipped, since only actual operations are of interest.

    Args:
        graph_module: The graph module whose nodes are scanned.

    Returns:
        A dict mapping each counted target's ``_name`` to its number of
        occurrences, in order of first appearance in the graph.
    """
    counts = {}
    for node in graph_module.graph.nodes:
        if node.op != "call_function":
            continue
        target = node.target
        # Ignore getitem, alloc and view cases: they are not actual operations.
        if (
            target == operator.getitem
            or target.__name__ == "alloc"
            or target == memory.view
        ):
            continue
        op_name = target._name
        # Start from zero the first time an op is seen.
        counts[op_name] = counts.get(op_name, 0) + 1
    return counts
124*523fa7a6SAndroid Build Coastguard Worker
125*523fa7a6SAndroid Build Coastguard Worker
def print_ops_info(
    to_edge_gm: torch.fx.GraphModule,
    final_gm: torch.fx.GraphModule,
) -> None:
    """Log how often each op occurs in the to_edge and final graph modules.

    Two tables are emitted through ``logging.info``:

    * every op present in the final graph, sorted by its final-graph count in
      descending order, with its count in the to_edge graph (0 if absent);
    * every op present in the to_edge graph but missing from the final graph.
      This table is only logged when at least one such op exists.

    Args:
        to_edge_gm: Graph module obtained after ``to_edge``.
        final_gm: Final graph module.
    """
    to_edge_ops_count = get_ops_count(to_edge_gm)
    final_ops_count = get_ops_count(final_gm)

    # One row per op in the final graph: [op, final count, to_edge count].
    # sorted() is stable, so ties keep their first-appearance order.
    final_rows = sorted(
        (
            [op, count, to_edge_ops_count.get(op, 0)]
            for op, count in final_ops_count.items()
        ),
        key=lambda row: row[1],
        reverse=True,
    )

    # One row per op that was dropped between to_edge and the final graph.
    removed_rows = [
        [op, 0, count]
        for op, count in to_edge_ops_count.items()
        if op not in final_ops_count
    ]

    # Each row carries three values, so exactly three headers are supplied.
    logging.info(
        tabulate(
            final_rows,
            headers=[
                "Final Operators                                    ",  # one character longer than the longest op name
                "Final Graph",
                "To_edge Graph",
            ],
            tablefmt="outline",
        )
    )

    # Print the removed ops and their counts in a tabular format (if any)
    if removed_rows:
        logging.info(
            tabulate(
                removed_rows,
                headers=[
                    "Deleted Operators                                  ",  # one character longer than the longest op name
                    "Final Graph",
                    "To_edge Graph",
                ],
                tablefmt="outline",
            )
        )
191*523fa7a6SAndroid Build Coastguard Worker
192*523fa7a6SAndroid Build Coastguard Worker
def model_gm_has_SDPA(model_gm: torch.fx.GraphModule) -> bool:
    """Report whether the graph contains a scaled-dot-product-attention call.

    Args:
        model_gm: The graph module whose nodes are scanned.

    Returns:
        ``True`` if any ``call_function`` node targets
        ``torch.ops.aten.scaled_dot_product_attention.default``, else ``False``.
    """
    return any(
        node.op == "call_function"
        and node.target == torch.ops.aten.scaled_dot_product_attention.default
        for node in model_gm.graph.nodes
    )
199*523fa7a6SAndroid Build Coastguard Worker
200*523fa7a6SAndroid Build Coastguard Worker
def save_pte_program(
    prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
) -> None:
    """Write an ExecuTorch program to a ``.pte`` file.

    Failures are logged with ``logging.error`` and are not re-raised.

    Args:
        prog: Program manager whose ``write_to_file`` is called with the open
            binary file.
        model_name: Used verbatim as the output path when it already ends in
            ``.pte``. Otherwise ``<model_name>.pte`` is created inside
            ``output_dir``.
        output_dir: Directory for the output file. It is ignored when
            ``model_name`` already ends in ``.pte``.
    """
    filename = (
        model_name
        if model_name.endswith(".pte")
        else os.path.join(output_dir, f"{model_name}.pte")
    )

    try:
        with open(filename, "wb") as pte_file:
            prog.write_to_file(pte_file)
            logging.info(f"Saved exported program to {filename}")
    except Exception as err:
        # Best effort: report the failure instead of propagating it.
        logging.error(f"Error while saving to {filename}: {err}")
215*523fa7a6SAndroid Build Coastguard Worker
216*523fa7a6SAndroid Build Coastguard Worker
def save_bpte_program(
    buffer: bytes,
    model_name: str,
    output_dir: str = "",
) -> None:
    """Write a serialized program buffer to a ``.bpte`` file.

    Failures are logged with ``logging.error`` and are not re-raised.

    Args:
        buffer: Bytes to write to the file.
        model_name: Used verbatim as the output path when it already ends in
            ``.bpte``. Otherwise ``<model_name>.bpte`` is created inside
            ``output_dir``.
        output_dir: Directory for the output file. It is ignored when
            ``model_name`` already ends in ``.bpte``.
    """
    if model_name.endswith(".bpte"):
        filename = model_name
    else:
        filename = os.path.join(output_dir, f"{model_name}.bpte")
    try:
        with open(filename, "wb") as f:
            f.write(buffer)
        logging.info(f"Saved exported program to {filename}")
    except Exception as e:
        # Report the path that was actually opened. ``output_dir`` may be
        # empty, or unused when ``model_name`` is already a full path.
        logging.error(f"Error while saving to {filename}: {e}")
232*523fa7a6SAndroid Build Coastguard Worker
233*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class MemoryConfig:
    """Description of the memories available to the backend.

    Only ``memory_sizes`` is required; the remaining fields default to
    ``None`` and are optional information for logs.
    """

    # One size entry per memory.
    memory_sizes: List[int]

    # Optional fields for logs
    memory_names: Optional[List[str]] = None
    base_addrs: Optional[List[int]] = None
    memory_xml_path: Optional[str] = None
    MemorySpace: Optional[enum.Enum] = None

    def get_num_memories(self) -> int:
        """Return the number of entries in ``memory_sizes`` plus one.

        Memories are indexed from 1..N, compatible with EXIR's spec.mem_id.
        """
        return 1 + len(self.memory_sizes)

    def get_size(self, exir_id: int) -> int:
        """Return ``memory_sizes[exir_id - 1]``.

        ``memory_sizes`` is indexed from 0, so ``exir_id`` is shifted down by
        one before the lookup.
        """
        zero_based_index = exir_id - 1
        return self.memory_sizes[zero_based_index]
251*523fa7a6SAndroid Build Coastguard Worker
252*523fa7a6SAndroid Build Coastguard Worker
def get_default_memory_config() -> MemoryConfig:
    """Return the default memory config for the backend.

    Returns:
        A ``MemoryConfig`` with a single memory whose size entry is
        ``0x1000000000`` and with all optional fields left at their defaults.
    """
    default_sizes = [0x1000000000]
    return MemoryConfig(memory_sizes=default_sizes)
256