# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.backends.qualcomm.utils.constants import (
    QCOM_QUANT_ATTRS,
    QCOM_QUANTIZED_IO,
    QCOM_REQUANTIZE,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class InsertRequantize(ExportPass):
    """
    Insert a convert op after operators whose input and output (activation)
    quantization specs differ.

    Every node tagged with ``QCOM_REQUANTIZE`` in its meta gets an
    ``aten._to_copy`` node inserted right after it. The requantize
    attributes are moved onto that new node as ``QCOM_QUANT_ATTRS``, and all
    former consumers of the tagged node are rerouted through it. This
    inserted op acts as the Convert op that the QNN backend uses to
    requantize.
    """

    # Ops that have multiple outputs but are still handled by
    # _single_output_annotation rather than _multi_output_annotation. An op
    # may be listed here because its extra outputs are unused, or because
    # they are integers that need no quantization attributes, etc.
    multi_output_op_ignore_set = {
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
        exir_ops.edge.aten.topk.default,
    }

    def __init__(
        self,
        edge_program: torch.export.ExportedProgram,
    ):
        super().__init__()
        self.edge_program = edge_program

    # TODO: Implement this function when we have an op with
    # multiple outputs that requires quant attributes.
    def _multi_output_annotation(self) -> None:
        """Placeholder for requantizing nodes that produce multiple outputs.

        Raises:
            NotImplementedError: always; multi-output requantization is not
                supported yet.
        """
        raise NotImplementedError("requant is not implemented for multi output yet")

    def _single_output_annotation(
        self, gm: torch.fx.GraphModule, n: torch.fx.Node
    ) -> None:
        """Insert a requantizing ``_to_copy`` node after ``n`` and reroute its users.

        Args:
            gm: Graph module that owns ``n``; modified in place.
            n: Node whose meta holds ``QCOM_REQUANTIZE``. That entry is
                popped from ``n.meta`` and stored on the new node as
                ``QCOM_QUANT_ATTRS``.
        """
        with gm.graph.inserting_after(n):
            # Snapshot the users BEFORE creating the new node: once created,
            # the new node is itself a user of ``n``, and rewiring it would
            # make it take itself as input.
            users = list(n.users.keys())
            inserted_n = gm.graph.create_node(
                "call_function",
                exir_ops.edge.aten._to_copy.default,
                (n,),
            )

            inserted_n.meta["val"] = n.meta["val"]
            inserted_n.meta[QCOM_QUANT_ATTRS] = n.meta.pop(QCOM_REQUANTIZE)
            # Propagate the quantized-IO marker so the new node is treated
            # the same way as the node it follows.
            if n.meta.get(QCOM_QUANTIZED_IO):
                inserted_n.meta[QCOM_QUANTIZED_IO] = n.meta[QCOM_QUANTIZED_IO]

            for user in users:
                user.replace_input_with(n, inserted_n)

    def _insert(self, graph_module: torch.fx.GraphModule) -> None:
        """Insert a requantize node after every node tagged with ``QCOM_REQUANTIZE``."""
        # Iterate over a snapshot, because the graph's node list is mutated
        # while we walk it.
        for n in list(graph_module.graph.nodes):
            if QCOM_REQUANTIZE not in n.meta:
                continue
            # A single FakeTensor "val" means a single output. Ops in the
            # ignore set have multiple outputs but are deliberately handled
            # as single-output.
            is_single_output = isinstance(
                n.meta["val"], torch._subclasses.fake_tensor.FakeTensor
            )
            if is_single_output or n.target in self.multi_output_op_ignore_set:
                self._single_output_annotation(graph_module, n)
            else:
                self._multi_output_annotation()

    def call(self, graph_module: torch.fx.GraphModule):
        """Run the pass: insert requantize nodes, drop dead code, recompile.

        Returns:
            PassResult wrapping the same (mutated) ``graph_module``, always
            reported as modified.
        """
        self._insert(graph_module)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)