1# mypy: allow-untyped-defs 2import logging 3import operator 4 5import torch 6from torch.ao.quantization.pt2e.utils import ( 7 _filter_sym_size_users, 8 _is_valid_annotation, 9) 10from torch.fx.node import map_arg 11from torch.fx.passes.infra.pass_base import PassBase, PassResult 12 13 14logger = logging.getLogger(__name__) 15logger.setLevel(logging.WARNING) 16 17__all__ = ["DuplicateDQPass"] 18 19_QUANTIZE_OPS = [ 20 torch.ops.quantized_decomposed.quantize_per_tensor.default, 21 torch.ops.quantized_decomposed.quantize_per_tensor.tensor, 22 torch.ops.quantized_decomposed.quantize_per_channel.default, 23] 24 25_DEQUANTIZE_OPS = [ 26 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 27 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, 28 torch.ops.quantized_decomposed.dequantize_per_channel.default, 29] 30 31 32def _maybe_duplicate_dq( 33 gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node 34): 35 annotation = user.meta.get("quantization_annotation", None) 36 if not _is_valid_annotation(annotation): 37 return 38 with gm.graph.inserting_after(dq_node): 39 new_node = gm.graph.node_copy(dq_node) 40 41 def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: 42 if n == dq_node: 43 return new_node 44 else: 45 return n 46 47 new_args = map_arg(user.args, maybe_replace_node) 48 new_kwargs = map_arg(user.kwargs, maybe_replace_node) 49 user.args = new_args # type: ignore[assignment] 50 user.kwargs = new_kwargs # type: ignore[assignment] 51 52 53class DuplicateDQPass(PassBase): 54 def call(self, graph_module: torch.fx.GraphModule) -> PassResult: 55 for node in graph_module.graph.nodes: 56 if node.op == "call_function" and node.target in _DEQUANTIZE_OPS: 57 dq_users = _filter_sym_size_users(node) 58 if len(dq_users) <= 1: 59 continue 60 # Do not duplicate dq for dynamic quantization 61 # Pattern: choose_qparam - getitem - q - dq 62 q_node = node.args[0] 63 if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS: 64 getitem_node = q_node.args[1] 65 if ( 66 isinstance(getitem_node, torch.fx.node.Node) 67 and getitem_node.op == "call_function" 68 and getitem_node.target == operator.getitem 69 ): 70 choose_qparam_node = getitem_node.args[0] 71 if ( 72 isinstance(choose_qparam_node, torch.fx.node.Node) 73 and choose_qparam_node.op == "call_function" 74 and choose_qparam_node.target 75 == torch.ops.quantized_decomposed.choose_qparams.tensor 76 ): 77 continue 78 for user in dq_users: 79 _maybe_duplicate_dq(graph_module, node, user) 80 graph_module.graph.eliminate_dead_code() 81 graph_module.recompile() 82 return PassResult(graph_module, True) 83