# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import cast, Optional

import torch
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
    format_target_name,
)

_Q_OPS = {
    "quantize_per_tensor.tensor",
    "quantize_per_tensor.default",
    "quantize_per_channel.default",
    "quantize_per_channel_group.default",
    "quantize_per_token.default",
    "quantize_affine.default",
}

_DQ_OPS = {
    "dequantize_per_tensor.tensor",
    "dequantize_per_tensor.default",
    "dequantize_per_channel.default",
    "dequantize_per_channel_group.default",
    "dequantize_per_token.default",
    "dequantize_affine.default",
}

_QPARAM_OPS = {
    "choose_qparams.tensor",
    "choose_qparams_per_token_asymmetric.default",
    "choose_qparams_affine.default",
}

_DYNAMIC_OPS = {
    "quantize_per_tensor.tensor",
    "quantize_per_token.default",
    "dequantize_per_tensor.tensor",
    "dequantize_per_token.default",
}


def is_dynamic_qdq(node: torch.fx.Node) -> bool:
    if node.op != "call_function":
        return False
    node_name = format_target_name(node.target.__name__)  # pyre-ignore
    # Affine q/dq is dynamic when it is per-token but not per-channel-group.
    is_dynamic_affine = is_per_token(node) and not is_per_channel_group(node)

    return node_name in _DYNAMIC_OPS or is_dynamic_affine


def is_qparam(node: torch.fx.Node) -> bool:
    if node.op != "call_function":
        return False
    node_name = format_target_name(node.target.__name__)  # pyre-ignore

    return node_name in _QPARAM_OPS


def is_quant(node: torch.fx.Node) -> bool:
    if node.op != "call_function":
        return False
    node_name = format_target_name(node.target.__name__)  # pyre-ignore

    return node_name in _Q_OPS


def is_dequant(node: torch.fx.Node) -> bool:
    if node.op != "call_function":
        return False
    node_name = format_target_name(node.target.__name__)  # pyre-ignore

    return node_name in _DQ_OPS


def is_per_channel(node: torch.fx.Node) -> bool:
    if not (is_quant(node) or is_dequant(node)):
        return False

    is_per_channel_op = "per_channel" in node.target.__name__  # pyre-ignore

    return is_per_channel_op or is_per_channel_group(node)


def is_affine_qdq(node: torch.fx.Node) -> bool:
    if not (is_quant(node) or is_dequant(node)):
        return False

    return "quantize_affine" in node.target.__name__  # pyre-ignore


def _get_block_size_input_scale(node: torch.fx.Node):
    assert is_affine_qdq(node)
    block_size = node.args[1]
    input_val = node.all_input_nodes[0].meta["val"]
    scale_val = node.all_input_nodes[1].meta["val"]
    return block_size, input_val, scale_val


def is_per_token(node: torch.fx.Node) -> bool:
    if not (is_quant(node) or is_dequant(node)):
        return False

    if "per_token" in node.target.__name__:  # pyre-ignore
        return True
    elif is_affine_qdq(node):
        # Affine q/dq is per-token when block_size is 1 on every dimension
        # except the last, the last block spans the whole innermost
        # dimension, and there is exactly one scale value per token.
        block_size, input_val, scale_val = _get_block_size_input_scale(node)
        flag = True
        scale_numel_expected = 1
        for i in range(len(block_size) - 1):
            flag &= block_size[i] == 1
            scale_numel_expected *= input_val.shape[i]

        flag &= block_size[-1] == input_val.shape[-1]
        flag &= scale_val.numel() == scale_numel_expected
        return flag

    return False
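

# Illustrative note (not from the original source): under the affine block
# size conventions checked above and below, an activation of shape (B, T, C)
# with block_size (1, 1, C) carries one scale per token (B * T scales), so
# is_per_token returns True; a weight of shape (8, 64) with block_size
# (1, 32) carries 64 / 32 = 2 scales per row (16 scales total, and
# 8 * 64 == 32 * 16), so is_per_channel_group returns True.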


def is_per_channel_group(node: torch.fx.Node) -> bool:
    if not (is_quant(node) or is_dequant(node)):
        return False

    if "per_channel_group" in node.target.__name__:  # pyre-ignore
        return True
    elif is_affine_qdq(node):
        # Affine q/dq is per-channel-group when a 2-D weight is blocked as
        # (1, group_size): each row is split into groups of group_size
        # elements, and every group shares one scale value.
        block_size, input_val, scale_val = _get_block_size_input_scale(node)
        flag = True
        flag &= len(block_size) == 2
        flag &= block_size[0] == 1
        group_size = block_size[1]
        scale_numel = math.prod(scale_val.shape)
        input_numel = math.prod(input_val.shape)
        flag &= input_numel == group_size * scale_numel
        return flag

    return False


def extract_qdq_affine_op_args_for_decomposed_ops(
    node: torch.fx.Node,
) -> Optional[list]:
    if not is_affine_qdq(node):
        return None
    # make sure input_dtype and zero_point_domain have expected values
    input_node = node.args[0]
    scale_node = node.args[2]
    zero_point_node = node.args[3]
    args = [input_node, scale_node, zero_point_node]
    assert (
        len(node.args) > 4
    ), f"expecting at least 5 args, got node: {node.format_node()}"

    if node.args[4] != torch.int8:
        return None
    target_dtype = cast(torch.dtype, node.args[4])

    if len(node.args) > 6:
        # explicit quant_min / quant_max
        args.append(node.args[5])
        args.append(node.args[6])
    else:
        # fall back to the full range of the target dtype
        dtype_info = torch.iinfo(target_dtype)
        args.append(dtype_info.min)  # quant_min
        args.append(dtype_info.max)  # quant_max

    # add target_dtype after quant_min/quant_max
    args.append(target_dtype)
    # only the integer zero_point_domain is supported
    if len(node.args) > 7 and node.args[7] != "INT":
        return None

    if is_per_channel_group(node):
        # append the group size for the decomposed per-channel-group ops
        block_sizes = cast(list[int], node.args[1])
        args.append(block_sizes[-1])

    args.append(node.args[-1])

    return args
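

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module's API). Assuming the
# decomposed ops from torch.ao.quantization.fx._decomposed are available, we
# can build a small FX graph by hand and classify its nodes with the
# predicates above; the scale/zero-point values are arbitrary.
if __name__ == "__main__":
    # Importing this module registers the quantized_decomposed ops.
    import torch.ao.quantization.fx._decomposed  # noqa: F401

    graph = torch.fx.Graph()
    x = graph.placeholder("x")
    q = graph.call_function(
        torch.ops.quantized_decomposed.quantize_per_tensor.default,
        (x, 0.1, 0, -128, 127, torch.int8),
    )
    dq = graph.call_function(
        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        (q, 0.1, 0, -128, 127, torch.int8),
    )
    graph.output(dq)

    for n in graph.nodes:
        if n.op == "call_function":
            print(f"{n.name}: is_quant={is_quant(n)}, is_dequant={is_dequant(n)}")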