# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass

from .utils import dq_ops, q_ops


class FoldQDQ(ExportPass):
    """
    Erase QDQ pattern.

    Every node whose target is listed in ``dq_ops`` or ``q_ops`` is bypassed:
    its users are rewired to consume the node's first argument directly, and
    the node itself is erased from the graph.
    """

    def __init__(self):
        super().__init__()

    def _fold(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """
        Remove dequantize nodes, then quantize nodes, from ``graph_module``.

        The graph is mutated in place. Node arguments of a quantize node past
        the first one are erased as well, provided nothing else uses them.

        Args:
            graph_module: module whose graph is rewritten in place.

        Returns:
            The same ``graph_module`` instance that was passed in.
        """
        graph = graph_module.graph

        # remove dq
        for node in graph.nodes:
            if node.target not in dq_ops:
                continue
            # Snapshot the users: replace_input_with mutates node.users.
            for user in list(node.users):
                user.replace_input_with(node, node.args[0])
            graph.erase_node(node)

        # remove q
        for node in graph.nodes:
            if node.target not in q_ops:
                continue
            source = node.args[0]
            users = list(node.users)

            # TODO: remove this hack as source_fn_stack is internal implementation detail of torch.export.
            # To make constant value/tensor be tagged as delegatable during partition
            # A quantize node without users has nothing to copy the tag from,
            # so the tagging is skipped instead of failing on an empty list.
            if source.op == "get_attr" and users:
                source.meta["source_fn_stack"] = users[0].meta.get(
                    "source_fn_stack"
                )

            # collecting quant nodes to be removed
            removable = [node]
            for arg in node.args[1:]:
                if not isinstance(arg, torch.fx.node.Node):
                    continue
                removable.append(arg)
                # could be a commonly shared attribute between q & dq
                if arg.target == exir_ops.edge.aten._to_copy.default:
                    removable.append(arg.args[0])

            # connect source node to quant users and remove quant node
            for user in users:
                user.replace_input_with(node, source)

            # The order matters: the quantize node goes first so that its
            # argument nodes lose that user before they are examined.
            for stale in removable:
                # erase_node refuses to drop a node that is still used, so a
                # node shared with another consumer is left in place.
                if stale.users:
                    continue
                graph.erase_node(stale)

        return graph_module

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Fold the QDQ nodes of ``graph_module``, recompile it and run dead code
        elimination.

        Args:
            graph_module: module to transform in place.

        Returns:
            A ``PassResult`` wrapping the same module, always flagged as
            modified.
        """
        self._fold(graph_module)
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
        return PassResult(graph_module, True)