1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import logging 8import operator 9 10import torch 11 12from torch.ao.quantization.pt2e.utils import ( 13 _filter_sym_size_users, 14 _is_valid_annotation, 15) 16 17from torch.fx.node import map_arg 18from torch.fx.passes.infra.pass_base import PassBase, PassResult 19 20 21logger = logging.getLogger(__name__) 22logger.setLevel(logging.WARNING) 23 24__all__ = ["DuplicateDynamicQuantChainPass"] 25 26_QUANTIZE_OPS = [ 27 torch.ops.quantized_decomposed.quantize_per_tensor.default, 28 torch.ops.quantized_decomposed.quantize_per_tensor.tensor, 29 torch.ops.quantized_decomposed.quantize_per_channel.default, 30] 31 32_DEQUANTIZE_OPS = [ 33 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 34 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, 35 torch.ops.quantized_decomposed.dequantize_per_channel.default, 36] 37 38 39def _replace_input_node_with_new_node(node, input_node, new_node): 40 def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: 41 if n == input_node: 42 return new_node 43 else: 44 return n 45 46 new_args = map_arg(node.args, maybe_replace_node) 47 new_kwargs = map_arg(node.kwargs, maybe_replace_node) 48 node.args = new_args 49 node.kwargs = new_kwargs 50 51 52def _replicate_chose_qparam_nodes_for_q_dq( 53 gm: torch.fx.GraphModule, chose_qparams_node, get_item_node_1, get_item_node_2 54): 55 if ( 56 ( 57 chose_qparams_node.target 58 != torch.ops.quantized_decomposed.choose_qparams.tensor 59 ) 60 or (get_item_node_1.target != operator.getitem) 61 or (get_item_node_2.target != operator.getitem) 62 ): 63 raise RuntimeError( 64 f"Expecting choose_qparams.tensor and getitem nodes but got {chose_qparams_node}, {get_item_node_1}, {get_item_node_2}" 65 ) 66 67 users = list(get_item_node_1.users.copy()) 68 q_dq_pair = [] 69 for user in users: 70 if user.target in _QUANTIZE_OPS: 71 if len(user.users) != 1: 72 raise RuntimeError(f"Node {user} has more than one user") 73 dq_node = list(user.users)[0] 74 if dq_node.target not in _DEQUANTIZE_OPS: 75 raise RuntimeError( 76 f"Node {user}'s use must be a dequantize op but got {dq_node}:{dq_node.target}" 77 ) 78 q_dq_pair.append((user, dq_node)) 79 80 for q_node, dq_node in q_dq_pair: 81 with gm.graph.inserting_after(get_item_node_1): 82 new_get_item_node_1 = gm.graph.node_copy(get_item_node_1) 83 new_get_item_node_2 = gm.graph.node_copy(get_item_node_2) 84 new_chose_qparams_node = gm.graph.node_copy(chose_qparams_node) 85 _replace_input_node_with_new_node( 86 new_get_item_node_1, chose_qparams_node, new_chose_qparams_node 87 ) 88 _replace_input_node_with_new_node( 89 new_get_item_node_2, chose_qparams_node, new_chose_qparams_node 90 ) 91 92 _replace_input_node_with_new_node( 93 q_node, get_item_node_1, new_get_item_node_1 94 ) 95 _replace_input_node_with_new_node( 96 dq_node, get_item_node_1, new_get_item_node_1 97 ) 98 _replace_input_node_with_new_node( 99 q_node, get_item_node_2, new_get_item_node_2 100 ) 101 _replace_input_node_with_new_node( 102 dq_node, get_item_node_2, new_get_item_node_2 103 ) 104 105 gm.graph.eliminate_dead_code() 106 gm.recompile() 107 108 109def _replicate_node_for_each_user(gm: torch.fx.GraphModule, node: torch.fx.Node): 110 users = list(node.users.copy()) 111 for user in users: 112 with gm.graph.inserting_after(node): 113 new_node = gm.graph.node_copy(node) 114 _replace_input_node_with_new_node(user, node, new_node) 115 116 gm.graph.eliminate_dead_code() 117 gm.recompile() 118 119 120def _maybe_duplicate_dynamic_quantize_chain( 121 gm: torch.fx.GraphModule, 122 chose_qparams_node, 123 get_item_node_1, 124 get_item_node_2, 125 q_node, 126 dq_node: torch.fx.Node, 127): 128 num_dq_users = len(dq_node.users) 129 dq_node_users = list(dq_node.users.copy()) 130 for user in dq_node_users: 131 annotation = user.meta.get("quantization_annotation", None) 132 if not _is_valid_annotation(annotation): 133 return 134 with gm.graph.inserting_after(dq_node): 135 new_node = gm.graph.node_copy(dq_node) 136 _replace_input_node_with_new_node(user, dq_node, new_node) 137 138 gm.graph.eliminate_dead_code() 139 gm.recompile() 140 if len(q_node.users) != num_dq_users: 141 raise RuntimeError( 142 f"Expected {num_dq_users} users of {q_node}, but got {len(q_node.users)}" 143 ) 144 _replicate_node_for_each_user(gm, q_node) 145 146 # *2 because scale/zp are used both with q and dq nodes 147 if len(get_item_node_1.users) != num_dq_users * 2: 148 raise RuntimeError( 149 f"Expected {num_dq_users} users of {get_item_node_1}, but got {len(get_item_node_1.users)}" 150 ) 151 if len(get_item_node_2.users) != num_dq_users * 2: 152 raise RuntimeError( 153 f"Expected {num_dq_users} users of {get_item_node_2}, but got {len(get_item_node_2.users)}" 154 ) 155 _replicate_chose_qparam_nodes_for_q_dq( 156 gm, chose_qparams_node, get_item_node_1, get_item_node_2 157 ) 158 159 160class DuplicateDynamicQuantChainPass(PassBase): 161 def call(self, graph_module: torch.fx.GraphModule) -> PassResult: 162 for node in graph_module.graph.nodes: 163 if node.op == "call_function" and node.target in _DEQUANTIZE_OPS: 164 dq_users = _filter_sym_size_users(node) 165 if len(dq_users) <= 1: 166 continue 167 # Do not duplicate dq for dynamic quantization 168 # Pattern: choose_qparam - getitem - q - dq 169 q_node = node.args[0] 170 if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS: 171 getitem_1_node = q_node.args[1] 172 getitem_2_node = q_node.args[2] 173 if ( 174 isinstance(getitem_1_node, torch.fx.node.Node) 175 and getitem_1_node.op == "call_function" 176 and getitem_1_node.target == operator.getitem 177 ): 178 choose_qparam_node = getitem_1_node.args[0] 179 if ( 180 isinstance(choose_qparam_node, torch.fx.node.Node) 181 and choose_qparam_node.op == "call_function" 182 and choose_qparam_node.target 183 == torch.ops.quantized_decomposed.choose_qparams.tensor 184 ): 185 _maybe_duplicate_dynamic_quantize_chain( 186 graph_module, 187 choose_qparam_node, 188 getitem_1_node, 189 getitem_2_node, 190 q_node, 191 node, 192 ) 193 continue 194 graph_module.graph.eliminate_dead_code() 195 graph_module.recompile() 196 return PassResult(graph_module, True) 197