# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

from .utils import dq_ops, get_quant_attrs, q_ops


class AnnotateDecomposed(ExportPass):
    """
    Add "quant_attrs" to graph nodes' meta from the QDQ information
    generated after quantization process.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram):
        """Store the exported program that is passed to ``get_quant_attrs``."""
        super().__init__()
        self.edge_program = edge_program

    def _annotate_unbind(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Annotate every node of each ``unbind`` source partition.

        When the first input node of a partition is one of ``dq_ops``, the
        quant attributes are read from that node's first argument and a copy
        is stored under ``QCOM_QUANT_ATTRS`` in the meta of every partition
        node. Partitions without input nodes are skipped.
        """
        partitions = get_source_partitions(graph_module.graph, [torch.unbind, "unbind"])
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                # Nothing to inspect if the partition has no external inputs.
                if not src_partition.input_nodes:
                    continue
                first_input = src_partition.input_nodes[0]
                if first_input.target not in dq_ops:
                    continue
                # The first argument of the dequantize node is the node the
                # quant attributes are extracted from.
                q_node = first_input.args[0]
                quant_attrs = get_quant_attrs(self.edge_program, q_node)
                for n in src_partition.nodes:
                    # Each node gets its own copy so later edits stay local.
                    n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()

    def _annotate_stack(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Annotate every node of each ``torch.stack`` source partition.

        When the first user of the partition's first output node is one of
        ``q_ops``, the quant attributes are read from that user and a copy is
        stored under ``QCOM_QUANT_ATTRS`` in the meta of every partition node.
        Partitions without output nodes, or whose output has no users, are
        skipped.
        """
        partitions = get_source_partitions(graph_module.graph, [torch.stack])
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                if not src_partition.output_nodes:
                    continue
                output = src_partition.output_nodes[0]
                # Take the first user without materializing the whole mapping;
                # None means the output is not consumed by any node.
                first_user = next(iter(output.users), None)
                if first_user is None or first_user.target not in q_ops:
                    continue
                quant_attrs = get_quant_attrs(self.edge_program, first_user)
                for n in src_partition.nodes:
                    n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run the unbind and stack annotations, recompile the graph module and
        return it wrapped in a ``PassResult`` whose second argument is ``True``.
        """
        self._annotate_unbind(graph_module)
        self._annotate_stack(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)