# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import deque
from typing import Iterator

from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.pass_base import ExportPass
from torch.export import ExportedProgram
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult


def _iter_graph_modules_bfs(root: GraphModule) -> Iterator[GraphModule]:
    """Yield ``root`` and every control flow submodule in breadth-first order.

    The submodules of a graph module are looked up only after the consumer has
    finished with that module (i.e. when the generator is resumed), which
    matches the "process nodes, then enqueue submodules" order the passes in
    this file rely on.
    """
    # deque gives O(1) pops from the left; list.pop(0) is O(n).
    pending = deque([root])
    while pending:
        current = pending.popleft()
        yield current
        pending.extend(
            submodule for _, submodule, _ in get_control_flow_submodules(current)
        )


class DebugHandleGeneratorPass(ExportPass):
    def call(self, graph_module: GraphModule) -> PassResult:
        """Assign a fresh, sequential debug handle to every node.

        Walks ``graph_module`` and all of its control flow submodules
        breadth-first and writes ``node.meta["debug_handle"]`` for each node,
        numbering from 1. Any handle already stored on a node is overwritten.

        Returns:
            A ``PassResult`` wrapping the same ``graph_module`` with the
            modified flag set to ``True``.
        """
        index = 1
        for current_graph_module in _iter_graph_modules_bfs(graph_module):
            for node in current_graph_module.graph.nodes:
                node.meta["debug_handle"] = index
                index += 1
        return PassResult(graph_module, True)


def generate_missing_debug_handles(ep: ExportedProgram) -> None:
    """
    This pass is used to generate missing debug handles for the graph module and its submodules.

    A node counts as missing a handle when ``node.meta`` has no
    ``"debug_handle"`` entry, or when the entry is ``0`` or ``None``. Such
    nodes receive consecutive handles starting right after the largest handle
    already present anywhere in ``ep.graph_module`` or its control flow
    submodules; nodes that already carry a handle are left untouched. The
    metadata is updated in place and nothing is returned.
    """
    # First traversal: find the largest handle that is already assigned.
    max_handle = 0
    for current_graph_module in _iter_graph_modules_bfs(ep.graph_module):
        for node in current_graph_module.graph.nodes:
            handle = node.meta.get("debug_handle")
            # A present-but-None handle is treated as missing by the second
            # traversal; skip it here as well so max() never compares an int
            # with None (which raises TypeError).
            if handle is not None:
                max_handle = max(max_handle, handle)

    # Second traversal: hand out new handles to nodes that lack one.
    for current_graph_module in _iter_graph_modules_bfs(ep.graph_module):
        for node in current_graph_module.graph.nodes:
            if node.meta.get("debug_handle", 0) in (0, None):
                max_handle += 1
                node.meta["debug_handle"] = max_handle