from typing import Callable, Dict, List, Optional

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.qat as nnqat
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import GraphModule
from torch.fx.graph import Node

from .ns_types import NSSingleResultType, NSSingleResultValuesType
from .utils import (
    get_target_type_str,
    getattr_from_fqn,
    return_first_non_observer_node,
)


toq = torch.ops.quantized


def mod_weight_detach(mod: nn.Module) -> torch.Tensor:
    return mod.weight.detach()  # type: ignore[operator]


def mod_0_weight_detach(mod: nn.Module) -> torch.Tensor:
    return mod[0].weight.detach()  # type: ignore[index]


def mod_weight_bias_0(mod: nn.Module) -> torch.Tensor:
    return mod._weight_bias()[0]  # type: ignore[operator]


def get_lstm_weight(mod: nn.Module) -> List[torch.Tensor]:
    res = []
    for idx, param_name in enumerate(mod._flat_weights_names):  # type: ignore[arg-type]
        if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
            param_value = mod._flat_weights[idx].detach()  # type: ignore[index]
            res.append(param_value)
    return res


def get_qlstm_weight(mod: nn.Module) -> List[torch.Tensor]:
    res = []
    for weight_value in mod._all_weight_values:  # type: ignore[union-attr]
        res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
        res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
    return res


def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor:
    if isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return mod.weight.detach()
    elif isinstance(mod, (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)):
        return mod[0].weight.detach()
    else:
        return mod._weight_bias()[0]  # type: ignore[operator]


def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor:
    if isinstance(mod, nn.Linear):
        return mod.weight.detach()
    elif isinstance(mod, nni.LinearReLU):
        return mod[0].weight.detach()
    else:
        return mod._weight_bias()[0]  # type: ignore[operator]


def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
    # TODO(future PR): make more generic, handle everything
    if isinstance(mod, nn.LSTM):
        res = []
        for idx, param_name in enumerate(mod._flat_weights_names):
            if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
                param_value = mod._flat_weights[idx].detach()
                res.append(param_value)
        return res
    else:
        assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
        res = []
        for weight_value in mod._all_weight_values:
            res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
            res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
        return res


def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # traverse backwards from the weight arg, accounting for any observers
    weight_arg_node = node.args[1]
    assert isinstance(weight_arg_node, Node)
    weight_node = return_first_non_observer_node(weight_arg_node, gm)
    assert isinstance(weight_node, Node)
    assert weight_node.op == "get_attr"
    weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
    return weight.detach()
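
# Illustrative sketch (not from the original file; attribute names such as
# `conv_weight` and `weight_fake_quant` are hypothetical): get_conv_fun_weight
# above handles traced graphs where an observer wraps the weight before it
# reaches the functional conv, e.g.:
#
#   weight = self.conv_weight                            # get_attr
#   weight_obs = self.weight_fake_quant(weight)          # call_module (observer)
#   out = F.conv2d(x, weight_obs, self.conv_bias, ...)   # call_function
#
# return_first_non_observer_node walks backwards past the observer to the
# get_attr node, and the weight tensor is then fetched from the module
# hierarchy by its fully qualified name.
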
def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # qconv state is arg 1
    qconv_state_node = node.args[1]
    assert isinstance(qconv_state_node, Node)
    assert qconv_state_node.op == "get_attr"
    qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target)  # type: ignore[arg-type]
    return qconv_state_obj.weight()


def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # traverse backwards from the weight arg, accounting for any observers
    # supported patterns:
    # weight -> obs -> linear
    # weight -> to(torch.float16) -> dequantize -> linear
    linear_second_arg = node.args[1]
    assert isinstance(linear_second_arg, Node)

    if linear_second_arg.op == "call_module":
        # weight -> obs -> linear
        weight_arg_node = node.args[1]
        assert isinstance(weight_arg_node, Node)
        weight_node = weight_arg_node.args[0]
        assert isinstance(weight_node, Node)
        assert weight_node.op == "get_attr"
        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
        return weight.detach()
    elif linear_second_arg.op == "call_method":
        # weight -> to(torch.float16) -> dequantize -> linear
        dequant_node = node.args[1]
        assert isinstance(dequant_node, Node)
        to_fp16_node = dequant_node.args[0]
        assert isinstance(to_fp16_node, Node)
        # extract the dtype, so we can cast to it before returning
        target_dtype = to_fp16_node.args[1]
        weight_node = to_fp16_node.args[0]
        assert isinstance(weight_node, Node)
        assert weight_node.op == "get_attr"
        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
        # return the weight with fp16 cast
        return weight.detach().to(target_dtype)
    else:
        assert linear_second_arg.op == "get_attr"
        weight = getattr_from_fqn(gm, linear_second_arg.target)  # type: ignore[arg-type]
        return weight.detach()


def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # packed weight is arg 1
    packed_weight_node = node.args[1]
    assert isinstance(packed_weight_node, Node)
    assert packed_weight_node.op == "get_attr"
    packed_weight = getattr_from_fqn(gm, packed_weight_node.target)  # type: ignore[arg-type]
    # TODO(future PR): why does packed_weight.unpack() not work?
    (weight, _bias), _name = packed_weight.__getstate__()
    return weight
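
# Minimal sketch (an assumption for illustration, not from the original file)
# of the destructuring used in get_qlinear_fun_weight above, assuming a
# packed param produced by torch.ops.quantized.linear_prepack with arbitrary
# example values:
#
#   qw = torch.quantize_per_tensor(torch.randn(4, 4), 0.1, 0, torch.qint8)
#   packed = torch.ops.quantized.linear_prepack(qw, None)
#   (w, b), _name = packed.__getstate__()  # w: quantized weight, b: bias or None
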
def get_op_to_type_to_weight_extraction_fn() -> Dict[str, Dict[Callable, Callable]]:
    op_to_type_to_weight_extraction_fn: Dict[str, Dict[Callable, Callable]] = {
        "call_module": {
            # Conv1d
            nn.Conv1d: mod_weight_detach,
            nni.ConvReLU1d: mod_0_weight_detach,
            nnq.Conv1d: mod_weight_bias_0,
            nnqat.Conv1d: mod_weight_detach,
            nniqat.ConvBn1d: mod_weight_detach,
            nniqat.ConvBnReLU1d: mod_weight_detach,
            nniqat.ConvReLU1d: mod_weight_detach,
            nniq.ConvReLU1d: mod_weight_bias_0,
            # Conv2d
            nn.Conv2d: mod_weight_detach,
            nni.ConvReLU2d: mod_0_weight_detach,
            nnq.Conv2d: mod_weight_bias_0,
            nnqat.Conv2d: mod_weight_detach,
            nniqat.ConvBn2d: mod_weight_detach,
            nniqat.ConvBnReLU2d: mod_weight_detach,
            nniqat.ConvReLU2d: mod_weight_detach,
            nniq.ConvReLU2d: mod_weight_bias_0,
            # Conv3d
            nn.Conv3d: mod_weight_detach,
            nni.ConvReLU3d: mod_0_weight_detach,
            nnq.Conv3d: mod_weight_bias_0,
            nnqat.Conv3d: mod_weight_detach,
            nniqat.ConvBn3d: mod_weight_detach,
            nniqat.ConvBnReLU3d: mod_weight_detach,
            nniqat.ConvReLU3d: mod_weight_detach,
            nniq.ConvReLU3d: mod_weight_bias_0,
            # Linear
            nn.Linear: mod_weight_detach,
            nnq.Linear: mod_weight_bias_0,
            nni.LinearReLU: mod_0_weight_detach,
            nniq.LinearReLU: mod_weight_bias_0,
            nnqat.Linear: mod_weight_detach,
            nnqd.Linear: mod_weight_bias_0,
            nniqat.LinearReLU: mod_weight_detach,
            nniqat.LinearBn1d: mod_weight_detach,
            nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach,
            # LSTM
            nn.LSTM: get_lstm_weight,
            nnqd.LSTM: get_qlstm_weight,
        },
        "call_function": {
            # Conv
            F.conv1d: get_conv_fun_weight,
            F.conv2d: get_conv_fun_weight,
            F.conv3d: get_conv_fun_weight,
            toq.conv1d: get_qconv_fun_weight,
            toq.conv2d: get_qconv_fun_weight,
            toq.conv3d: get_qconv_fun_weight,
            toq.conv1d_relu: get_qconv_fun_weight,
            toq.conv2d_relu: get_qconv_fun_weight,
            toq.conv3d_relu: get_qconv_fun_weight,
            # Linear
            F.linear: get_linear_fun_weight,
            toq.linear: get_qlinear_fun_weight,
            toq.linear_relu: get_qlinear_fun_weight,
        },
    }

    return op_to_type_to_weight_extraction_fn
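
# Hedged example of extending the default mapping; `MyConv2d` and
# `my_conv2d_weight_fn` are hypothetical names used only for illustration:
#
#   custom_map = get_op_to_type_to_weight_extraction_fn()
#   custom_map["call_module"][MyConv2d] = my_conv2d_weight_fn
#   result = extract_weight_from_node(node, gm, custom_map)
#
# extract_weight_from_node below accepts such a map via its optional third
# argument, and falls back to the default mapping when it is None.
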
def extract_weight_from_node(
    node: Node,
    gm: GraphModule,
    op_to_type_to_weight_extraction_fn: Optional[
        Dict[str, Dict[Callable, Callable]]
    ] = None,
) -> Optional[NSSingleResultType]:
    res_type = NSSingleResultValuesType.WEIGHT.value

    # Not all graphmodules have _node_name_to_scope, so only fill it
    # out if it exists.
    fqn = None
    if hasattr(gm, "_node_name_to_scope"):
        fqn = gm._node_name_to_scope[node.name][0]  # type: ignore[index]

    if op_to_type_to_weight_extraction_fn is None:
        op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn()

    ref_node_type = get_target_type_str(node, gm)
    # for extracting weights, these are always the same
    prev_node_type = ref_node_type

    if node.op == "call_function":
        function_mapping = op_to_type_to_weight_extraction_fn["call_function"]
        for target_fn_type, weight_extraction_fn in function_mapping.items():
            if node.target == target_fn_type:
                weight = weight_extraction_fn(node, gm)
                return {
                    "type": res_type,
                    "values": [weight],
                    "prev_node_name": node.name,
                    "prev_node_target_type": prev_node_type,
                    "ref_node_name": node.name,
                    "ref_node_target_type": ref_node_type,
                    "index_within_arg": 0,
                    "index_of_arg": 0,
                    "fqn": fqn,
                }

    elif node.op == "call_module":
        # for call_module, we need to look up the modules to do the type check
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        module_mapping = op_to_type_to_weight_extraction_fn["call_module"]
        for target_mod_type, weight_extraction_fn in module_mapping.items():
            if type(mod) == target_mod_type:
                weight = weight_extraction_fn(mod)
                return {
                    "type": res_type,
                    "values": [weight],
                    "prev_node_name": node.name,
                    "prev_node_target_type": prev_node_type,
                    "ref_node_name": node.name,
                    "ref_node_target_type": ref_node_type,
                    "index_within_arg": 0,
                    "index_of_arg": 0,
                    "fqn": fqn,
                }

    return None
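
# Usage sketch (illustrative, not part of the original module): trace a simple
# module and collect a weight result for each node this file recognizes. The
# __main__ guard keeps importing this module side-effect free; run it with
# `python -m` so the relative imports above resolve.
if __name__ == "__main__":
    import torch.fx

    traced = torch.fx.symbolic_trace(nn.Linear(4, 4))
    for n in traced.graph.nodes:
        result = extract_weight_from_node(n, traced)
        if result is not None:
            # expect one hit on the F.linear node, with a (4, 4) weight
            print(n.name, result["values"][0].shape)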