# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import enum
import logging
import operator
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from executorch.exir import ExecutorchProgramManager, memory
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from tabulate import tabulate

from torch.ao.quantization.quantize_pt2e import _QUANT_OPS as quant_ops


# Check if the model is quantized, by looking at the graph and finding
# quant/dequant ops. Returns True as soon as any op from `quant_ops` is found
# as a call_function target in the graph.
def model_is_quantized(model: torch.nn.Module) -> bool:
    # Quantized models have to be GraphModules already, from prepare/convert
    # calls. Return False if the model is not a GraphModule.
    if not isinstance(model, torch.fx.GraphModule):
        return False

    # Look for at least one quant/dequant op among the graph's nodes.
    return any(
        model.graph.find_nodes(op="call_function", target=op) for op in quant_ops
    )


# Get the output size of a 1D convolution given the input size and parameters.
# `in_size` must have exactly 3 dims: (N, L, C) when `channel_last` is True,
# (N, C, L) otherwise. The returned size uses the same layout, with the
# channel dim replaced by `out_channels`.
def get_conv1d_output_size(
    in_size: torch.Size,
    out_channels: int,
    stride: int,
    padding: int,
    dilation: int,
    kernel_size: int,
    channel_last: bool,
) -> torch.Size:
    assert len(in_size) == 3
    # The input channel count does not affect the output size.
    if channel_last:
        N, L, _ = in_size
    else:
        N, _, L = in_size

    # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
    lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

    if channel_last:
        return torch.Size((N, lout, out_channels))
    return torch.Size((N, out_channels, lout))


# Get the output size of a 2D convolution given the input size and parameters.
# `in_size` must have exactly 4 dims: (N, H, W, C) when `channel_last` is
# True, (N, C, H, W) otherwise. `stride`, `padding`, `dilation` and
# `kernel_size` are indexed at [0] for the height and [1] for the width, so
# each needs at least two entries.
def get_conv2d_output_size(
    in_size: torch.Size,
    out_channels: int,
    stride: Tuple[int, ...],
    padding: Tuple[int, ...],
    dilation: Tuple[int, ...],
    kernel_size: List[int],
    channel_last: bool,
) -> torch.Size:
    assert len(in_size) == 4
    # The input channel count does not affect the output size.
    if channel_last:
        N, H, W, _ = in_size
    else:
        N, _, H, W = in_size

    # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    hout = (
        H + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
    ) // stride[0] + 1
    wout = (
        W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
    ) // stride[1] + 1

    if channel_last:
        return torch.Size((N, hout, wout, out_channels))
    return torch.Size((N, out_channels, hout, wout))


# Return the overload packet for the edge op. The packet is looked up on
# `exir_ops.edge` using the op's namespace and the part of its schema name
# that follows "::".
def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket:
    edge_op_name = edge_op._schema.name.split("::")[1]
    edge_namespace = getattr(exir_ops.edge, edge_op.namespace)
    return getattr(edge_namespace, edge_op_name)


# Get the frequency list of ops in a graph module: a dict mapping each
# call_function target's `_name` to the number of nodes that call it.
def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]:
    freq: Dict[str, int] = {}
    # Loop over nodes to count the number of times each op occurs
    for node in graph_module.graph.nodes:
        if node.op != "call_function":
            continue
        target = node.target
        # Ignore getitem, alloc and view cases, we only want actual operations
        if (
            target == operator.getitem
            or target.__name__ == "alloc"
            or target == memory.view
        ):
            continue
        # Start at 1 for a new op, otherwise increment the existing count
        freq[target._name] = freq.get(target._name, 0) + 1
    return freq


# Log the ops and how many times they occur in two graph modules: the one
# from to_edge and the final one. Ops present in the final graph are logged
# in one table (sorted by final count, descending); ops that are present in
# the to_edge graph but absent from the final graph are logged in a second
# table, if there are any.
def print_ops_info(
    to_edge_gm: torch.fx.GraphModule,
    final_gm: torch.fx.GraphModule,
) -> None:
    to_edge_ops_count = get_ops_count(to_edge_gm)
    final_ops_count = get_ops_count(final_gm)

    # Ops that are removed from the final graph
    removed_ops = [op for op in to_edge_ops_count if op not in final_ops_count]

    # Rows of [op, final count, to_edge count] to pass to tabulate
    ops_count = [
        [op, final_count, to_edge_ops_count.get(op, 0)]
        for op, final_count in final_ops_count.items()
    ]
    sorted_ops_count = sorted(ops_count, key=lambda x: x[1], reverse=True)

    # Rows for the deleted ops; their final count is 0 by construction
    removed_ops_count = [[op, 0, to_edge_ops_count[op]] for op in removed_ops]

    # NOTE(review): both header lists below have four entries ("Export Graph"
    # included) while every row has three columns -- confirm whether the
    # "Export Graph" header is stale and should be dropped.

    # Print the final ops and their counts in a tabular format
    logging.info(
        tabulate(
            sorted_ops_count,
            headers=[
                "Final Operators ",  # one character longer than the longest op name
                "Final Graph",
                "To_edge Graph",
                "Export Graph",
            ],
            tablefmt="outline",
        )
    )

    # Print the removed ops and their counts in a tabular format (if any)
    if removed_ops:
        logging.info(
            tabulate(
                removed_ops_count,
                headers=[
                    "Deleted Operators ",  # one character longer than the longest op name
                    "Final Graph",
                    "To_edge Graph",
                    "Export Graph",
                ],
                tablefmt="outline",
            )
        )


# Return True if any call_function node in the graph targets
# torch.ops.aten.scaled_dot_product_attention.default.
def model_gm_has_SDPA(model_gm: torch.fx.GraphModule) -> bool:
    sdpa_op = torch.ops.aten.scaled_dot_product_attention.default
    return any(
        node.op == "call_function" and node.target == sdpa_op
        for node in model_gm.graph.nodes
    )


# Write the program to a .pte file. If `model_name` already ends with ".pte"
# it is used as the path as-is (and `output_dir` is ignored); otherwise the
# file is `<output_dir>/<model_name>.pte`. Failures are logged, not raised.
def save_pte_program(
    prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
) -> None:
    if model_name.endswith(".pte"):
        filename = model_name
    else:
        filename = os.path.join(output_dir, f"{model_name}.pte")

    try:
        with open(filename, "wb") as file:
            prog.write_to_file(file)
            logging.info(f"Saved exported program to {filename}")
    except Exception as e:
        # Best-effort save: report the failure instead of propagating it.
        logging.error(f"Error while saving to {filename}: {e}")


# Write the buffer to a .bpte file. If `model_name` already ends with ".bpte"
# it is used as the path as-is (and `output_dir` is ignored); otherwise the
# file is `<output_dir>/<model_name>.bpte`. Failures are logged, not raised.
def save_bpte_program(
    buffer: bytes,
    model_name: str,
    output_dir: str = "",
) -> None:
    if model_name.endswith(".bpte"):
        filename = model_name
    else:
        filename = os.path.join(output_dir, f"{model_name}.bpte")

    try:
        with open(filename, "wb") as f:
            f.write(buffer)
            logging.info(f"Saved exported program to {filename}")
    except Exception as e:
        # Best-effort save: report the failure instead of propagating it.
        # Log the actual target path; `output_dir` may be empty or unused.
        logging.error(f"Error while saving to {filename}: {e}")


@dataclass
class MemoryConfig:
    # Size of each memory; entry i corresponds to EXIR memory id i + 1.
    memory_sizes: List[int]

    # Optional fields for logs
    memory_names: Optional[List[str]] = None
    base_addrs: Optional[List[int]] = None
    memory_xml_path: Optional[str] = None
    MemorySpace: Optional[enum.Enum] = None

    # get num memories indexed from 1..N, compatible with EXIR's spec.mem_id.
    # The result is len(memory_sizes) + 1, i.e. an exclusive upper bound for
    # the 1-based ids.
    def get_num_memories(self) -> int:
        return len(self.memory_sizes) + 1

    # memory_space module provides num_memories indexed 0..num_memories-1,
    # so the 1-based `exir_id` is shifted down by one before indexing.
    # NOTE(review): exir_id == 0 indexes memory_sizes[-1] (the last entry)
    # rather than failing -- confirm callers never pass 0.
    def get_size(self, exir_id: int) -> int:
        return self.memory_sizes[exir_id - 1]


# Return default memory config for the backend: a single memory of
# 0x1000000000 bytes.
def get_default_memory_config() -> MemoryConfig:
    return MemoryConfig(memory_sizes=[0x1000000000])