# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import enum
import logging
import operator
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from executorch.exir import ExecutorchProgramManager, memory
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from tabulate import tabulate

from torch.ao.quantization.quantize_pt2e import _QUANT_OPS as quant_ops


# Check if the model is quantized by looking for quant/dequant ops in the graph
def model_is_quantized(model: torch.nn.Module) -> bool:
    # Quantized models have to be GraphModules already, from prepare/convert calls.
    # Return False if the model is not a GraphModule.
    if not isinstance(model, torch.fx.GraphModule):
        return False

    # Walk through the graph and look for quant/dequant ops
    for op in quant_ops:
        if model.graph.find_nodes(op="call_function", target=op):
            return True
    return False


# Get the output size of a 1D convolution given the input size and parameters
def get_conv1d_output_size(
    in_size: torch.Size,
    out_channels: int,
    stride: int,
    padding: int,
    dilation: int,
    kernel_size: int,
    channel_last: bool,
) -> torch.Size:
    assert len(in_size) == 3
    if channel_last:
        N, L, C = in_size
    else:
        N, C, L = in_size

    # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
    lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

    if channel_last:
        return torch.Size((N, lout, out_channels))
    return torch.Size((N, out_channels, lout))


# Get the output size of a 2D convolution given the input size and parameters
def get_conv2d_output_size(
    in_size: torch.Size,
    out_channels: int,
    stride: Tuple[int, int],
    padding: Tuple[int, int],
    dilation: Tuple[int, int],
    kernel_size: List[int],
    channel_last: bool,
) -> torch.Size:
    assert len(in_size) == 4
    if channel_last:
        N, H, W, C = in_size
    else:
        N, C, H, W = in_size

    # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    hout = (
        H + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
    ) // stride[0] + 1
    wout = (
        W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
    ) // stride[1] + 1
    if channel_last:
        return torch.Size((N, hout, wout, out_channels))
    return torch.Size((N, out_channels, hout, wout))


# Return the overload packet for the edge op
def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket:
    edge_op_namespace, edge_op_name = (
        edge_op.namespace,
        edge_op._schema.name.split("::")[1],
    )
    edge_op_overload_packet = getattr(
        getattr(exir_ops.edge, edge_op_namespace), edge_op_name
    )
    return edge_op_overload_packet


# Get the frequency of each op in a graph module
def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]:
    freq: Dict[str, int] = {}
    # Loop over nodes to count the number of times each op occurs
    for node in graph_module.graph.nodes:
        if node.op == "call_function":
            # Ignore getitem, alloc and view nodes; we only want actual operations
            if (
                node.target == operator.getitem
                or node.target.__name__ == "alloc"
                or node.target == memory.view
            ):
                continue
            # If the op is already present, increment the count
            if node.target._name in freq:
                freq[node.target._name] += 1
            # else, add a new entry
            else:
                freq[node.target._name] = 1
    return freq

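# Usage sketch for get_ops_count (illustrative only; `model` and `inputs` are
# hypothetical placeholders, and the op names in the result depend entirely on
# the model being exported):
#
#   from executorch.exir import to_edge
#
#   edge_manager = to_edge(torch.export.export(model, inputs))
#   counts = get_ops_count(edge_manager.exported_program().graph_module)
#   # e.g. {"aten::convolution": 2, "aten::relu": 2, ...}
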
# Print the ops and how many times they occur in the to_edge and final graph
# modules, plus the ops that were removed between the two.
def print_ops_info(
    to_edge_gm: torch.fx.GraphModule,
    final_gm: torch.fx.GraphModule,
) -> None:
    to_edge_ops_count = get_ops_count(to_edge_gm)
    final_ops_count = get_ops_count(final_gm)

    # Collect the ops that are present in the to_edge graph but removed
    # from the final graph
    removed_ops = []
    for k in to_edge_ops_count:
        if k not in final_ops_count:
            removed_ops.append(k)

    # Create a list of [op, final count, to_edge count] rows to pass to tabulate
    ops_count = [
        [op, final_ops_count[op], to_edge_ops_count.get(op, 0)]
        for op in final_ops_count
    ]
    sorted_ops_count = sorted(ops_count, key=lambda x: x[1], reverse=True)

    # Create the same rows for the removed ops (their final count is always 0)
    removed_ops_count = [
        [op, 0, to_edge_ops_count.get(op, 0)] for op in removed_ops
    ]

    # Print the final ops and their counts in a tabular format
    logging.info(
        tabulate(
            sorted_ops_count,
            headers=[
                "Final Operators ",  # one character longer than the longest op name
                "Final Graph",
                "To_edge Graph",
            ],
            tablefmt="outline",
        )
    )

    # Print the removed ops and their counts in a tabular format (if any)
    if removed_ops:
        logging.info(
            tabulate(
                removed_ops_count,
                headers=[
                    "Deleted Operators ",  # one character longer than the longest op name
                    "Final Graph",
                    "To_edge Graph",
                ],
                tablefmt="outline",
            )
        )


# Check whether the graph module contains a scaled_dot_product_attention op
def model_gm_has_SDPA(model_gm: torch.fx.GraphModule) -> bool:
    for node in model_gm.graph.nodes:
        if node.op == "call_function":
            if node.target == torch.ops.aten.scaled_dot_product_attention.default:
                return True
    return False


# Serialize the ExecuTorch program to a .pte file
def save_pte_program(
    prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
) -> None:
    if model_name.endswith(".pte"):
        filename = model_name
    else:
        filename = os.path.join(output_dir, f"{model_name}.pte")

    try:
        with open(filename, "wb") as file:
            prog.write_to_file(file)
        logging.info(f"Saved exported program to {filename}")
    except Exception as e:
        logging.error(f"Error while saving to {filename}: {e}")


# Write an already-serialized bundled program buffer to a .bpte file
def save_bpte_program(
    buffer: bytes,
    model_name: str,
    output_dir: str = "",
) -> None:
    if model_name.endswith(".bpte"):
        filename = model_name
    else:
        filename = os.path.join(output_dir, f"{model_name}.bpte")
    try:
        with open(filename, "wb") as f:
            f.write(buffer)
        logging.info(f"Saved exported program to {filename}")
    except Exception as e:
        logging.error(f"Error while saving to {filename}: {e}")

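# Usage sketch for save_pte_program (illustrative only; `edge_manager` is a
# hypothetical EdgeProgramManager, e.g. the result of executorch.exir.to_edge):
#
#   exec_prog = edge_manager.to_executorch()
#   save_pte_program(exec_prog, "mobilenet", output_dir="/tmp")
#   # -> writes /tmp/mobilenet.pte and logs the destination path
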
@dataclass
class MemoryConfig:
    memory_sizes: List[int]

    # Optional fields for logs
    memory_names: Optional[List[str]] = None
    base_addrs: Optional[List[int]] = None
    memory_xml_path: Optional[str] = None
    MemorySpace: Optional[enum.Enum] = None

    # Number of memories indexed from 1..N, compatible with EXIR's spec.mem_id;
    # the +1 accounts for the unused index 0.
    def get_num_memories(self) -> int:
        return len(self.memory_sizes) + 1

    # Map a 1-based EXIR mem_id to its size (memory_sizes itself is 0-based).
    def get_size(self, exir_id: int) -> int:
        return self.memory_sizes[exir_id - 1]


# Return the default memory config for the backend
def get_default_memory_config() -> MemoryConfig:
    return MemoryConfig(memory_sizes=[0x1000000000])
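
# Minimal sanity check for MemoryConfig (an illustrative sketch; the bank
# sizes and names below are made-up values, not a real memory map for any
# target):
if __name__ == "__main__":
    _cfg = MemoryConfig(
        memory_sizes=[0x80000, 0x400000],
        memory_names=["bank0", "bank1"],
    )
    # Two banks are addressable via mem_ids 1 and 2; index 0 is unused.
    assert _cfg.get_num_memories() == 3
    assert _cfg.get_size(1) == 0x80000
    assert _cfg.get_size(2) == 0x400000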