# mypy: ignore-errors

import copy
import operator
import warnings
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

import torch
from torch.ao.quantization import CUSTOM_KEY, NUMERIC_DEBUG_HANDLE_KEY
from torch.ao.quantization.backend_config import (
    BackendConfig,
    get_native_backend_config,
)
from torch.ao.quantization.backend_config.utils import (
    get_fused_module_classes,
    get_pattern_to_dtype_configs,
    get_qat_module_classes,
    get_root_module_to_quantized_reference_module,
)
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.qconfig import qconfig_equals, QConfigAny
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.quant_type import QuantType
from torch.ao.quantization.quantize import _remove_qconfig
from torch.ao.quantization.stubs import DeQuantStub
from torch.ao.quantization.utils import (
    _parent_name,
    activation_is_statically_quantized,
    get_qparam_dict,
    get_swapped_custom_module_class,
    is_per_channel,
    to_underlying_dtype,
    weight_is_quantized,
)
from torch.fx import GraphModule
from torch.fx.graph import Argument, Graph, Node
from torch.nn.utils.parametrize import type_before_parametrizations

# importing the lib so that the quantized_decomposed ops are registered
from ._decomposed import quantized_decomposed_lib  # noqa: F401
from ._equalize import convert_eq_obs, update_obs_for_equalization
from .custom_config import ConvertCustomConfig, PrepareCustomConfig
from .graph_module import _is_observed_module, _is_observed_standalone_module
from .lower_to_fbgemm import lower_to_fbgemm
from .qconfig_mapping_utils import (
    _compare_prepare_convert_qconfig_mappings,
    _generate_node_name_to_qconfig,
    _is_qconfig_supported_by_dtype_configs,
    _update_qconfig_for_fusion,
    _update_qconfig_for_qat,
)
from .utils import (
    _get_module,
    _is_custom_module_lstm,
    _is_custom_module_mha,
    assert_and_get_unique_device,
    collect_producer_nodes,
    create_getattr_from_value,
    get_custom_module_class_keys,
    graph_module_from_producer_nodes,
    node_arg_is_weight,
)


__all__ = [
    "convert",
    "convert_custom_module",
    "convert_standalone_module",
    "convert_weighted_module",
]

SUPPORTED_QDTYPES = [
    torch.quint8,
    torch.qint8,
    torch.qint32,
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.float8_e5m2,
    torch.float8_e4m3fn,
]

_QSCHEME_TO_CHOOSE_QPARAMS_OP = {
    torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor,
    torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
}


def _replace_observer_with_quantize_dequantize_node_decomposed(
    model: torch.fx.GraphModule,
    node: Node,
    modules: Dict[str, torch.nn.Module],
    node_name_to_scope: Dict[str, Tuple[str, type]],
    node_name_to_qconfig: Dict[str, QConfigAny],
) -> None:
    """Replace an activation_post_process module call node with quantize and
    dequantize nodes working with decomposed Tensors

    Before:
        ... -> observer_0(x) -> ...
    After:
        ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) ->
        torch.ops.quantized_decomposed.dequantize_per_tensor() -> ...

    or quantize_per_channel and dequantize_per_channel
    """
    graph = model.graph
    assert modules is not None
    assert isinstance(node.target, str)
    module_path, prefix = _get_module_path_and_prefix(
        node, node_name_to_scope, node_name_to_qconfig
    )
    activation_post_process = modules[node.target]
    if hasattr(activation_post_process, "convert"):
        activation_post_process.convert(model, node)
        return
    # skip replacing observers with quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
    skip_replacement = all(
        _has_none_qconfig(n, node_name_to_qconfig)
        for n in list(node.args) + list(node.users.keys())
    )
    if skip_replacement or not _is_conversion_supported(activation_post_process):
        # didn't find the corresponding quantize op and info for the activation_post_process
        # so we just remove the observer
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    # otherwise, we can convert the activation_post_process module call to quantize/dequantize node

    # 1. extract the information from activation_post_process module for generating
    # the quantize and dequantize operator
    dtype = activation_post_process.dtype  # type: ignore[attr-defined]

    is_dynamic = False
    if hasattr(activation_post_process, "is_dynamic"):
        is_dynamic = activation_post_process.is_dynamic  # type: ignore[assignment]

    def add_dequantize_op_kwargs(dequantize_op, input_node):
        dequantize_op_kwargs = {}
        if "val" in input_node.meta:
            dq_out_dtype = input_node.meta["val"].dtype
            if dq_out_dtype != torch.float32:
                dequantize_op_kwargs = {"out_dtype": dq_out_dtype}
        return dequantize_op_kwargs

    if dtype in SUPPORTED_QDTYPES and (not is_dynamic):
        # TODO: probably should clean up this condition check, it's hard
        # to reason about this if and the following elif

        # uint8/int8/int32 static quantization branch

        # 1. extract information for inserting q/dq node from activation_post_process
        node_type = "call_function"
        quantize_op: Optional[Callable] = None
        scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
        if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
            ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined, arg-type]
            quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
            dequantize_op = (
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            )
            quant_min = activation_post_process.quant_min
            quant_max = activation_post_process.quant_max
            dtype_ = to_underlying_dtype(dtype)
            qparams = {
                "_scale_": scale,
                "_zero_point_": zero_point,
                "_axis_": ch_axis,
                "_quant_min_": quant_min,
                "_quant_max_": quant_max,
                "_dtype_": dtype_,
            }
        else:
            quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
            scale = float(scale)
            zero_point = int(zero_point)
            quant_min = activation_post_process.quant_min  # type: ignore[attr-defined]
            quant_max = activation_post_process.quant_max  # type: ignore[attr-defined]
            dtype_ = to_underlying_dtype(dtype)
            qparams = {
                "_scale_": scale,
                "_zero_point_": zero_point,
                "_quant_min_": quant_min,
                "_quant_max_": quant_max,
                "_dtype_": dtype_,
            }

        # 2. replace activation_post_process node with quantize and dequantize
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value_or_node in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in the qparams dict itself
                if key in ["_scale_", "_zero_point_"] and (
                    not isinstance(value_or_node, (float, int))
                ):
                    # For scale and zero_point values we register them as buffers in the root module.
                    # However, note that when the values are not tensors, as in the case of
                    # per_tensor quantization, they will be treated as literals.
                    # However, registering them as a node seems to cause issues with dynamo
                    # tracing, where it may consider the tensor overload as opposed to the default.
                    # With the extra check of scale and zero_point being scalars, it makes
                    # sure that the default overload can be used.
                    # TODO: maybe need a more complex attr name here
                    qparam_node = create_getattr_from_value(
                        model, graph, module_path + prefix + key, value_or_node
                    )
                    quantize_op_inputs.append(qparam_node)
                else:
                    # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                    quantize_op_inputs.append(value_or_node)

            quantized_node = graph.create_node(
                node_type, quantize_op, tuple(quantize_op_inputs), {}
            )
            # use the same qparams from quantize op
            dq_inputs = [quantized_node] + quantize_op_inputs[1:]
            dequantized_node = graph.call_function(
                dequantize_op,
                tuple(dq_inputs),
                add_dequantize_op_kwargs(dequantize_op, input_node),
            )

            node.replace_all_uses_with(dequantized_node)
            # propagate numeric debug handle from observer/fake_quant node to dequantize node
            if (
                CUSTOM_KEY in node.meta
                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
            ):
                if CUSTOM_KEY not in dequantized_node.meta:
                    dequantized_node.meta[CUSTOM_KEY] = {}
                dequantized_node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[
                    CUSTOM_KEY
                ][NUMERIC_DEBUG_HANDLE_KEY]
            graph.erase_node(node)
    elif is_dynamic:
        # uint8/int8/fp16 dynamic quantization

        # 1. extract information for inserting q/dq node from activation_post_process
        node_type = "call_function"
        quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor
        # we only use choose_qparams for is_decomposed now,
        # but we should probably align the non-decomposed path with this as well,
        # and that can be done after we remove the reduce_range flag
        # 1. extract qparams from activation_post_process module
        dtype_ = to_underlying_dtype(dtype)
        assert dtype_ in [torch.uint8, torch.int8], (
            "only uint8 and int8 are supported in reference flow for "
            "dynamic quantization right now"
        )
        quant_min = activation_post_process.quant_min  # type: ignore[attr-defined]
        quant_max = activation_post_process.quant_max  # type: ignore[attr-defined]
        qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine)  # type: ignore[attr-defined]
        eps = getattr(activation_post_process, "eps", torch.finfo(torch.float32).eps)  # type: ignore[attr-defined]
        # note: scale and zero_point are missing for the quantize_per_tensor op,
        # we'll need to get them from the choose_qparams op, which we'll add after
        # this step
        qparams = {
            "_quant_min_": quant_min,
            "_quant_max_": quant_max,
            "_eps_": eps,
            "_dtype_": dtype_,
        }

        choose_qparams_op = _QSCHEME_TO_CHOOSE_QPARAMS_OP[qscheme]
        # 2. insert choose_qparams op and update the qparams list
        with graph.inserting_before(node):
            input_node = node.args[0]
            choose_qparams_op_inputs = [node.args[0]]
            for key, value in qparams.items():
                # we have quant_min, quant_max and dtype, all should be stored
                # as literals
                choose_qparams_op_inputs.append(value)
            choose_qparams_node = graph.create_node(
                "call_function", choose_qparams_op, tuple(choose_qparams_op_inputs), {}
            )
            # choose_qparams returns (scale, zero_point)
            scale_node = graph.create_node(
                "call_function", operator.getitem, (choose_qparams_node, 0), {}
            )
            zero_point_node = graph.create_node(
                "call_function", operator.getitem, (choose_qparams_node, 1), {}
            )
            quant_min = qparams["_quant_min_"]
            quant_max = qparams["_quant_max_"]
            dtype = qparams["_dtype_"]
            qparams = {
                "_scale_": scale_node,
                "_zero_point_": zero_point_node,
                "_quant_min_": quant_min,
                "_quant_max_": quant_max,
                "_dtype_": dtype,
            }

        # 3. replace activation_post_process node with quantize and dequantize nodes
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value_or_node in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in the qparams dict itself
                if key in ["_scale_", "_zero_point_"]:
                    # in this case we have a node in the graph since it's dynamically
                    # computed from the input, with the choose_qparams op
                    qparam_node = value_or_node
                    quantize_op_inputs.append(qparam_node)
                else:
                    # for qparams that are not scale/zero_point (like axis, dtype) we
                    # store them as literals in the graph.
                    quantize_op_inputs.append(value_or_node)

            quantized_node = graph.create_node(
                node_type, quantize_op, tuple(quantize_op_inputs), {}
            )
            # use the same qparams from quantize op
            dq_inputs = [quantized_node] + quantize_op_inputs[1:]
            # need to use the tensor variant of this op, since scale and zero_point
            # from choose_qparams are Tensors, instead of float/int, this is to
            # prevent these nodes being traced away by downstream systems
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
            dequantized_node = graph.call_function(
                dequantize_op,
                tuple(dq_inputs),
                add_dequantize_op_kwargs(dequantize_op, input_node),
            )

            def remap_fn(x):
                return dequantized_node if x is node else x

            node.replace_all_uses_with(dequantized_node)
            # propagate numeric debug handle from observer/fake_quant node to dequantize node
            if NUMERIC_DEBUG_HANDLE_KEY in node.meta:
                dequantized_node.meta[NUMERIC_DEBUG_HANDLE_KEY] = node.meta[
                    NUMERIC_DEBUG_HANDLE_KEY
                ]
            graph.erase_node(node)
    elif dtype == torch.float16:
        raise NotImplementedError("decomposed to float16 op not implemented yet")

    # should not reach here since we have checks at the beginning to make sure the
    # activation_post_process is supported


def _replace_observer_with_quantize_dequantize_node(
    model: torch.fx.GraphModule,
    node: Node,
    modules: Dict[str, torch.nn.Module],
    node_name_to_scope: Dict[str, Tuple[str, type]],
    node_name_to_qconfig: Dict[str, QConfigAny],
) -> None:
    """Replace an activation_post_process module call node with quantize and
    dequantize nodes

    Before:
        ... -> observer_0(x) -> ...
    After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
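
    For example, a per tensor affine statically quantized activation roughly becomes
        ... -> torch.quantize_per_tensor(x, scale, zero_point, dtype) -> x.dequantize() -> ...
    where scale and zero_point are read through get_attr nodes registered on the
    model (via create_getattr_from_value) and dtype is kept as a literal; per channel
    quantization uses torch.quantize_per_channel instead.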
    """
    assert modules is not None
    assert isinstance(node.target, str)
    graph = model.graph
    module_path, prefix = _get_module_path_and_prefix(
        node, node_name_to_scope, node_name_to_qconfig
    )
    activation_post_process = modules[node.target]
    # skip replacing observers with quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
    skip_replacement = all(
        _has_none_qconfig(n, node_name_to_qconfig)
        for n in list(node.args) + list(node.users.keys())
    )
    if skip_replacement or not _is_conversion_supported(activation_post_process):
        # didn't find the corresponding quantize op and info for the activation_post_process
        # so we just remove the observer
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    # otherwise, we can convert the activation_post_process module call to quantize/dequantize node
    dtype = activation_post_process.dtype  # type: ignore[attr-defined]

    is_dynamic = False
    if hasattr(activation_post_process, "is_dynamic"):
        is_dynamic = activation_post_process.is_dynamic  # type: ignore[attr-defined, assignment]

    if dtype in [
        torch.quint8,
        torch.qint8,
        torch.qint32,
        torch.float8_e5m2,
        torch.float8_e4m3fn,
    ] and (not is_dynamic):
        # TODO: probably should clean up this condition check, it's hard
        # to reason about this if and the following elif

        # uint8/int8/int32 static quantization branch

        # 1. extract the information from activation_post_process module for generating
        # the quantize and dequantize operator
        node_type = "call_function"
        quantize_op: Optional[Callable] = None
        scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
        if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
            ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined, arg-type]
            qparams = {
                "_scale_": scale,
                "_zero_point_": zero_point,
                "_axis_": ch_axis,
                "_dtype_": dtype,
            }
            quantize_op = torch.quantize_per_channel
        else:
            scale = float(scale)
            zero_point = int(zero_point)
            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
            quantize_op = torch.quantize_per_tensor

        # 2. replace activation_post_process node with quantize and dequantize
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value_or_node in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in the qparams dict itself
                if key in ["_scale_", "_zero_point_"]:
                    # For scale and zero_point values we register them as buffers in the root module.
                    # TODO: maybe need a more complex attr name here
                    qparam_node = create_getattr_from_value(
                        model, graph, module_path + prefix + key, value_or_node
                    )
                    quantize_op_inputs.append(qparam_node)
                else:
                    # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                    quantize_op_inputs.append(value_or_node)

            quantized_node = graph.create_node(
                node_type, quantize_op, tuple(quantize_op_inputs), {}
            )
            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)
    elif is_dynamic:
        # uint8/int8/fp16 dynamic quantization branch

        node_type = "call_function"
        quantize_op = torch.quantize_per_tensor_dynamic
        # TODO: get reduce range from observer
        # reduce_range = activation_post_process.reduce_range
        reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86")
        qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range}

        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value in qparams.items():
                quantize_op_inputs.append(value)

            quantized_node = graph.create_node(
                node_type, quantize_op, tuple(quantize_op_inputs), {}
            )
            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)
    elif dtype == torch.float16:
        node_type = "call_method"
        quantize_op = "to"  # type: ignore[assignment]
        qparams = {"_dtype_": dtype}
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in the qparams dict itself
                quantize_op_inputs.append(value)

            quantized_node = graph.create_node(
                node_type, quantize_op, tuple(quantize_op_inputs), {}
            )
            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)

    # should not reach here since we have checks at the beginning to make sure the
    # activation_post_process is supported


# this is a temporary hack for custom module, we may want to implement
# this properly after the custom module class design is finalized
# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted
# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs
# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively.
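# For example (a sketch of the rewrite this helper performs): given
#     ... -> custom_module -> observer_or_dequant_stub -> user ...
# the users are rerouted to consume custom_module directly, the observer/DeQuantStub
# node is erased, and a fresh `dequantize` call is inserted after custom_module:
#     ... -> custom_module -> dequantize -> user ...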
def _replace_observer_or_dequant_stub_with_dequantize_node(
    node: Node, graph: Graph
) -> None:
    call_custom_module_node = node.args[0]
    assert isinstance(
        call_custom_module_node, Node
    ), f"Expecting the argument of the call custom module node to be a Node, but got {call_custom_module_node}"
    node.replace_all_uses_with(call_custom_module_node)
    graph.erase_node(node)
    _insert_dequantize_node(call_custom_module_node, graph)


def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool:
    dtype = activation_post_process.dtype  # type: ignore[attr-defined]

    is_dynamic = False
    if hasattr(activation_post_process, "is_dynamic"):
        is_dynamic = activation_post_process.is_dynamic  # type: ignore[attr-defined, assignment]

    return (
        (dtype in SUPPORTED_QDTYPES and (not is_dynamic))
        or is_dynamic  # type: ignore[return-value]
        or dtype == torch.float16
    )


def _has_none_qconfig(
    node: Argument, node_name_to_qconfig: Dict[str, QConfigAny]
) -> bool:
    """Check if a node has a qconfig of None, i.e. the user requested not to quantize
    the node
    """
    return (
        isinstance(node, Node)
        and node.name in node_name_to_qconfig
        and node_name_to_qconfig[node.name] is None
    )


def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None:
    """Extract the subgraph that produces the weight for a dynamic quant
    or weight only quant node and run the subgraph to observe the weight.
    Note that the observers of dynamic quant or weight only quant ops are
    run during the convert step.
    """
    for node in observed.graph.nodes:
        if node.op != "call_function":
            continue
        for node_arg in node.args:
            # node_arg is weight
            if node_arg and node_arg_is_weight(node, node_arg):
                weight_observer_nodes = collect_producer_nodes(node_arg)
                if weight_observer_nodes is None:
                    continue
                weight_observer_module = graph_module_from_producer_nodes(
                    observed, weight_observer_nodes
                )
                # run the weight observer
                weight_observer_module()


def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None:
    """If the arg is a dequantize Node, or a list/tuple/dict of dequantize Nodes,
    we'll recursively remove the dequantize Nodes
    """
    if (
        isinstance(arg, Node)
        and arg.op == "call_method"
        and arg.target == "dequantize"
    ):
        quantize_node = arg.args[0]
        # we only replace the specific use since dequantize could be used by other nodes
        # as well
        node.replace_input_with(arg, quantize_node)
    elif isinstance(arg, (list, tuple)):
        for arg_element in arg:
            _maybe_recursive_remove_dequantize(arg_element, node, graph)
    elif isinstance(arg, dict):
        for arg_element in arg.values():
            _maybe_recursive_remove_dequantize(arg_element, node, graph)
    else:
        warnings.warn(
            f"Unsupported node type in recursive remove dequantize: {type(arg)}"
        )


def _get_module_path_and_prefix(
    obs_node: Node,
    node_name_to_scope: Dict[str, Tuple[str, type]],
    node_name_to_qconfig: Dict[str, QConfigAny],
) -> Tuple[str, str]:
    """Given an observer node, get the `Scope` or the fully qualified name for
    the submodule containing the observed node, and also return a prefix of "_input"
    when the observed node is an input of an F.linear op and not the output of another
    quantized op.
    TODO: this logic is hacky, we should think about how to remove it or make it more
    general
    """
    observed_node = obs_node.args[0]
    # an observer can be inserted for both the input of the next operator and the output of the previous
    # operator (they can be the same)
    # this flag identifies if the observer is inserted only because the observed node is
    # the input of the next operator
    assert isinstance(
        observed_node, Node
    ), f"Expecting observed node to be a Node, but got {observed_node}"
    is_input_observer_only = (
        node_name_to_qconfig[observed_node.name] is None
        if observed_node.name in node_name_to_qconfig
        else None
    )
    if is_input_observer_only:
        # if the quantize function is at the input of the op, then we find the first user of the observer_node
        # to get the path. If a linear call_function is in the user list, we return the first instance
        # of the linear node to get the FQN.
        users = list(obs_node.users)
        first_linear_use_or_first_use = users[0] if users else None
        linear_node = None
        for n in users:
            if n.op == "call_function" and n.target == torch.nn.functional.linear:
                linear_node = n
                break
        if linear_node:
            first_linear_use_or_first_use = linear_node
        prefix = "_input"
    else:
        # if the quantize function is at the output of the op, we use the observer input node to get the path
        first_linear_use_or_first_use = observed_node
        prefix = ""

    if (
        first_linear_use_or_first_use
        and first_linear_use_or_first_use.name in node_name_to_scope
    ):
        module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name]
    else:
        # TODO: it's not used, so actually we can skip quantization
        # but this requires changing the return type of quantize_node
        # we can fix it later if needed
        module_path = ""
    return module_path, prefix


def _insert_dequantize_node(node: Node, graph: Graph) -> None:
    """Inserts a dequantize node for `node` in `graph`"""
    with graph.inserting_after(node):
        dequantize_node = graph.call_method("dequantize", (node,))
        for user_node in dict(node.users):
            if user_node is not dequantize_node:
                user_node.replace_input_with(node, dequantize_node)


def _maybe_get_observer_for_node(
    node: Node, modules: Dict[str, torch.nn.Module]
) -> Optional[torch.nn.Module]:
    """
    If the node is observed, return the observer
    instance. Otherwise, return None.
    """
    for maybe_obs_node in node.users.keys():
        if maybe_obs_node.op == "call_module":
            maybe_obs = modules[str(maybe_obs_node.target)]
            if _is_activation_post_process(maybe_obs):
                return maybe_obs
    return None


def convert_standalone_module(
    node: Node,
    modules: Dict[str, torch.nn.Module],
    model: torch.fx.GraphModule,
    is_reference: bool,
    backend_config: Optional[BackendConfig],
) -> None:
    """Converts an observed standalone module to a quantized standalone module by calling
    the fx convert api, currently using the same `is_reference` flag as the parent, but we
    may change this behavior in the future (e.g. separating quantization and lowering for
    the standalone module as well)

    Args:
        - node: The call_module node of the observed standalone module
        - modules: named_module of original model
        - model: original model
        - is_reference: a flag from the parent provided by the user to decide if we want to
          produce a reference model or a fbgemm/qnnpack model
        - backend_config: backend configuration of the target backend of quantization
    """
    # TODO: remove is_reference flag
    if is_reference:
        convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx
    else:
        convert_fn = torch.ao.quantization.quantize_fx.convert_fx  # type: ignore[attr-defined]
    # We know that the observed standalone module is a GraphModule since
    # it's produced by us
    observed_standalone_module: GraphModule = modules[str(node.target)]  # type: ignore[assignment]
    sm_input_quantized_idxs = observed_standalone_module.meta[
        "_observed_graph_module_attrs"
    ].standalone_module_input_quantized_idxs
    # remove the dequantize nodes for inputs
    args = list(node.args)
    for idx in range(len(args)):
        if idx in sm_input_quantized_idxs:
            arg = args[idx]
            if arg.op == "call_method" and arg.target == "dequantize":  # type: ignore[union-attr]
                quantize_node = arg.args[0]  # type: ignore[union-attr]
                node.replace_input_with(arg, quantize_node)
                if len(arg.users) == 0:  # type: ignore[union-attr]
                    model.graph.erase_node(arg)
    # add dequantize node for output
    sm_output_quantized_idxs = observed_standalone_module.meta[
        "_observed_graph_module_attrs"
    ].standalone_module_output_quantized_idxs
    if len(sm_output_quantized_idxs) > 0:
        assert (
            sm_output_quantized_idxs[0] == 0
        ), "Currently only quantized output idxs = [0] is supported"

        # if it's non-empty, then it means the output is kept in quantized form
        # we'll just add a dequantize node after this node
        _insert_dequantize_node(node, model.graph)

    # TODO: allow convert_custom_config to override backend_config
    # for standalone module
    quantized_standalone_module = convert_fn(
        observed_standalone_module, backend_config=backend_config
    )
    parent_name, name = _parent_name(node.target)
    # update the modules dict
    setattr(modules[parent_name], name, quantized_standalone_module)
    modules[str(node.target)] = quantized_standalone_module


def convert_weighted_module(
    node: Node,
    modules: Dict[str, torch.nn.Module],
    observed_node_names: Set[str],
    node_name_to_qconfig: Dict[str, QConfigAny],
    backend_config: BackendConfig,
    is_decomposed: bool = False,
    is_reference: bool = False,
) -> None:
    """Convert a weighted module to a reference quantized module in the model.
    If the QConfig of a QAT module is not set, the module will still be converted to
    a float module.

    Args:
        - node: The call_module node of the observed standalone module
        - modules: named_module of original model
        - observed_node_names: names of the set of observed fx nodes, we can skip
          this conversion if the node is not observed
    """
    original_module = modules[str(node.target)]
    qconfig: QConfigAny = original_module.qconfig  # type: ignore[assignment]
    weight_post_process = None
    qat_module_classes = get_qat_module_classes(backend_config)

    if isinstance(original_module, qat_module_classes):
        # Converting a qat module to a float module, we need to attach
        # weight fake_quant to the module, weight fake_quant is assumed to be run during
        # QAT so we don't need to run it again here
        weight_post_process = original_module.weight_fake_quant
        original_module = original_module.to_float()  # type: ignore[operator]
        # change qat module to float module
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, original_module)

    is_observed = node.name in observed_node_names
    # If a qconfig is not defined for this node, then skip converting to a reference module
    if (
        qconfig is None
        or _has_none_qconfig(node, node_name_to_qconfig)
        or not is_observed
    ):
        return

    # skip converting to a reference quantized module if the qconfig is not supported
    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
    dtype_configs = pattern_to_dtype_configs.get(type(original_module), [])
    if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
        return

    # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
    is_weight_quantized = weight_is_quantized(qconfig)

    # the condition for swapping the module to a reference quantized module is:
    # weights need to be quantized
    if not is_weight_quantized:
        return

    fused_module = None
    float_module = original_module
    # extract the individual float_module and fused module
    if isinstance(original_module, torch.ao.nn.intrinsic._FusedModule):
        fused_module = float_module
        float_module = fused_module[0]  # type: ignore[index]

    # TODO: move this to the reference quantized module
    # weight_qparams or weight_qparams dict
    wq_or_wq_dict = {"is_decomposed": is_decomposed}
    if isinstance(float_module, torch.nn.RNNCellBase):
        weight_post_process_ih = qconfig.weight()  # type: ignore[union-attr, operator]
        weight_post_process_hh = qconfig.weight()  # type: ignore[union-attr, operator]
        weight_post_process_ih(float_module.weight_ih)
        weight_post_process_hh(float_module.weight_hh)
        weight_qparams_ih = get_qparam_dict(weight_post_process_ih)
        weight_qparams_hh = get_qparam_dict(weight_post_process_hh)
        wq_or_wq_dict.update(
            {
                "weight_ih": weight_qparams_ih,
                "weight_hh": weight_qparams_hh,
            }
        )
    elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)):
        # format for wq_or_wq_dict (flattened attributes):
        # {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...}
        for wn in float_module._flat_weights_names:
            if hasattr(float_module, wn) and wn.startswith("weight"):
                weight = getattr(float_module, wn)
                weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
                if weight_post_process.dtype == torch.qint8:  # type: ignore[union-attr]
                    weight_post_process(weight)  # type: ignore[operator, misc]
                wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process)
    else:
        # weight_post_process being None means the original module is not a QAT module,
        # so we need to get weight_post_process from qconfig in this case
        is_ptq = weight_post_process is None
        if is_ptq:
            weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
            device = assert_and_get_unique_device(float_module)
            if device:
                weight_post_process.to(device)

        # Call weight observer/fake_quant at least once to ensure the scales and zero points
        # have the right shapes. Note: there are two cases where we don't have to do this:
        #
        # (1) QAT: The model's forward method already calls the weight observer/fake_quant,
        # and this typically happens during training, so we don't need to do it here.
        #
        # (2) Non-reference (lowered) case: The quantized module's from_float method already
        # calls the weight observer/fake_quant, so we don't have to do it here.
        #
        # Currently we ignore both cases and call the weight observer/fake_quant here
        # regardless, which is technically incorrect. For (1), this is mainly to preserve BC
        # in test code, which may not always train before convert. In the future, we should
        # break BC for these two cases. See https://github.com/pytorch/pytorch/issues/73941.
        #
        # For PT2, however, we don't need to preserve BC here, so we can skip this hack
        # for QAT. We identify this case as (is_decomposed + is_reference + is_qat).
        # Note that we still need it for PTQ in the PT2 flow since the model's forward
        # method doesn't call the weight observer.
        is_qat = not is_ptq
        if not (is_decomposed and is_reference and is_qat):
            weight_post_process(float_module.weight)  # type: ignore[operator]

        wq_or_wq_dict.update(get_qparam_dict(weight_post_process))

    # We use the same reference module for all modes of quantization: static, dynamic, weight_only
    # root_module_to_quantized_reference_module: module mapping from root (floating point) module class
    # to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d
    root_module_to_quantized_reference_module = (
        get_root_module_to_quantized_reference_module(backend_config)
    )
    ref_qmodule_cls = root_module_to_quantized_reference_module.get(
        type_before_parametrizations(float_module), None
    )
    assert (
        ref_qmodule_cls is not None
    ), f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
    ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict)  # type: ignore[attr-defined]
    if fused_module is not None:
        fused_module[0] = ref_qmodule  # type: ignore[operator]
    else:
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, ref_qmodule)


def _remove_previous_dequantize_in_custom_module(
    node: Node, prev_node: Node, graph: Graph
) -> None:
    """
    Given a custom module `node`, if the previous node is a dequantize, reroute the custom module as follows:

    Before: quantize - dequantize - custom_module
    After: quantize - custom_module
                 \\ - dequantize
    """
    # expecting the input node for a custom module node to be a Node
    assert isinstance(
        prev_node, Node
    ), f"Expecting the argument for the custom module node to be a Node, but got {prev_node}"
    if prev_node.op == "call_method" and prev_node.target == "dequantize":
        node.replace_input_with(prev_node, prev_node.args[0])
        # Remove the dequantize node if it doesn't have other users
        if len(prev_node.users) == 0:
            graph.erase_node(prev_node)


def convert_custom_module(
    node: Node,
    graph: Graph,
    modules: Dict[str, torch.nn.Module],
    custom_module_class_mapping: Dict[QuantType, Dict[Type, Type]],
    statically_quantized_custom_module_nodes: Set[Node],
) -> None:
    """Converts an observed custom module to a quantized custom module based on
    `custom_module_class_mapping`.
    For static quantization, we'll also remove the previous `dequantize` node and
    attach the observer node for the output to the module, and the observer for that node
    will be converted to a dequantize node instead of a quantize-dequantize pair
    later in the graph. In the end we would have a quantized custom module that
    has the same interface as a default quantized module in the nn.quantized namespace,
    i.e. quantized input and quantized output.

    Args:
        - node: The call_module node of the observed standalone module
        - graph: The graph containing the node
        - modules: named_module of original model
        - custom_module_class_mapping: mapping from observed custom module class to
          quantized custom module class, used to swap custom modules
        - statically_quantized_custom_module_nodes: we'll add the custom module node
          to this set if we find it is statically quantized. This will be used later when
          converting observers to quant/dequant node pairs: if the observed node is a
          statically quantized custom module node, we'll convert the observer to a
          dequantize node; this is to keep the interface the same as the default
          quantized module.
          TODO: maybe we want to redesign this part to align with the reference model design
          as well, but there have been some discussions around the interface, so we can do
          it later.
    """
    observed_custom_module = modules[str(node.target)]
    maybe_obs = _maybe_get_observer_for_node(node, modules)
    qconfig = observed_custom_module.qconfig
    if activation_is_statically_quantized(qconfig):
        statically_quantized_custom_module_nodes.add(node)
        if _is_custom_module_lstm(node, modules):
            # The inputs are tuples in the form (input, (hidden0, hidden1))
            # Ensure all three input nodes are quantized
            assert (
                len(node.args) == 2
                and isinstance(node.args[1], tuple)
                and len(node.args[1]) == 2
            )
            (inputs, (hidden0, hidden1)) = node.args  # type: ignore[misc]
            assert isinstance(inputs, Node)
            assert isinstance(hidden0, Node)
            assert isinstance(hidden1, Node)
            _remove_previous_dequantize_in_custom_module(node, inputs, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden0, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden1, graph)
        elif _is_custom_module_mha(node, modules):
            # Inputs are in the form (query, key, value)
            # TODO: This is the first step in enabling the full fx custom module
            # quantization path for MultiheadAttention, and only covers the inputs
            # to the module.
            # Additional handling is yet to be implemented for the outputs, similar
            # to the LSTM custom module
            assert len(node.args) == 3
            query, key, value = node.args
            assert isinstance(query, Node)
            assert isinstance(key, Node)
            assert isinstance(value, Node)
            _remove_previous_dequantize_in_custom_module(node, query, graph)
            _remove_previous_dequantize_in_custom_module(node, key, graph)
            _remove_previous_dequantize_in_custom_module(node, value, graph)
        else:
            # remove the previous dequant node to ensure the inputs are quantized
            arg = node.args[0]
            assert isinstance(arg, Node)
            _remove_previous_dequantize_in_custom_module(node, arg, graph)
            # absorb the following observer into the module conversion
            activation_post_process = _maybe_get_observer_for_node(node, modules)
            assert activation_post_process is not None
            observed_custom_module.activation_post_process = activation_post_process

    # swap the observed custom module with the quantized custom module
    quantized_custom_module_class = get_swapped_custom_module_class(
        observed_custom_module, custom_module_class_mapping, qconfig
    )
    quantized_custom_module = quantized_custom_module_class.from_observed(
        observed_custom_module
    )
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, quantized_custom_module)


def convert(
    model: GraphModule,
    is_reference: bool = False,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    is_standalone_module: bool = False,
    _remove_qconfig_flag: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
    is_decomposed: bool = False,
) -> GraphModule:
    """
    We will convert an observed model (a module with observer calls) to a reference
    quantized model, the rules are simple:
    1. for each observer module call in the graph, we'll convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we need to convert them to reference
       quantized modules; this requires us to know whether the dtype configured for the
       weight is supported in the backend. This is done in the prepare step and the result
       is stored in observed_node_names, so we can decide whether we need to swap the
       module based on this set

    Args:
        * `is_standalone_module`: when this flag is True, it means we are quantizing
          a submodule that is not inlined in the parent module, and will be quantized
          separately as one unit.

        * `is_decomposed`: a boolean flag to indicate whether we want to use the
          quantize operator for decomposed quantized tensor
          (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone
          quantized tensor (torch.quantize_per_tensor)

    Returns:
        a quantized standalone module, whether input/output is quantized is
        specified by prepare_custom_config, with
        input_quantized_idxs, output_quantized_idxs, please
        see docs for :func:`~torch.ao.quantization.prepare_fx` for details
    """
    if convert_custom_config is None:
        convert_custom_config = ConvertCustomConfig()

    if isinstance(convert_custom_config, dict):
        warnings.warn(
            "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
            "in a future version. Please pass in a ConvertCustomConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)

    if isinstance(qconfig_mapping, dict):
        warnings.warn(
            "Passing a QConfig dictionary to convert is deprecated and will not be supported "
            "in a future version. Please pass in a QConfigMapping instead.",
            FutureWarning,
            stacklevel=2,
        )
        qconfig_mapping = (
            QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
        )
    qconfig_mapping = copy.deepcopy(qconfig_mapping)
    assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)

    if isinstance(backend_config, dict):
        warnings.warn(
            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        backend_config = BackendConfig.from_dict(backend_config)

    if backend_config is None:
        backend_config = get_native_backend_config()

    assert _is_observed_module(model), "incoming model must be produced by prepare_fx"
    observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
    node_name_to_scope: Dict[
        str, Tuple[str, type]
    ] = observed_graph_module_attrs.node_name_to_scope
    prepare_custom_config: PrepareCustomConfig = (
        observed_graph_module_attrs.prepare_custom_config
    )
    observed_node_names: Set[str] = observed_graph_module_attrs.observed_node_names
    node_name_to_qconfig: Dict[str, QConfigAny] = observed_graph_module_attrs.node_name_to_qconfig  # type: ignore[assignment]

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    # TODO refactor this code once we update the prepare logic to have additional information on
    # which graph nodes have been observed and share that with convert to decide which observers to ignore.
    if qconfig_mapping:
        prepare_qconfig_mapping: QConfigMapping = observed_graph_module_attrs.qconfig_mapping  # type: ignore[assignment]
        modules_copy = copy.deepcopy(modules)

        if observed_graph_module_attrs.is_qat:
            _update_qconfig_for_qat(qconfig_mapping, backend_config)
        _update_qconfig_for_fusion(model, qconfig_mapping)

        _compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping)  # type: ignore[arg-type]
        convert_node_name_to_qconfig = _generate_node_name_to_qconfig(
            model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope
        )
        # check the convert_node_name_to_qconfig generated and ensure that
        # all the values either match what was set in prepare node_name_to_qconfig
        # or are set to None in the convert_node_name_to_qconfig.
        for k, v in node_name_to_qconfig.items():
            assert (
                k in convert_node_name_to_qconfig
            ), f"Expected key {k} in convert node_name_to_qconfig"
            if convert_node_name_to_qconfig[k] is not None:
                assert qconfig_equals(v, convert_node_name_to_qconfig[k]), (
                    f"Expected k {k} to have the same value in prepare and convert QConfigMappings, "
                    f"but {v} was updated to {convert_node_name_to_qconfig[k]}"
                )
        node_name_to_qconfig = convert_node_name_to_qconfig

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config.observed_to_quantized_mapping
    )
    custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping

    if observed_graph_module_attrs.equalization_node_name_to_qconfig is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    _run_weight_observers(model, backend_config)

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == "placeholder":
            graph_inputs.append(node.name)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
    output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes

    root_module_to_quantized_reference_module = (
        get_root_module_to_quantized_reference_module(backend_config)
    )
    # convert to a tuple so that it can work with isinstance(module, tuple_of_classes)
    root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
    qat_module_classes = get_qat_module_classes(backend_config)
    fused_module_classes = get_fused_module_classes(backend_config)
    statically_quantized_custom_module_nodes: Set[Node] = set()

    for node in list(model.graph.nodes):
        if node.op == "placeholder":
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # we need to dequantize the inputs since all operators take
                # floating point inputs in reference quantized models
                _insert_dequantize_node(node, model.graph)
        elif node.op == "output":
            # If the argument is empty we don't need to do anything
            if len(output_quantized_idxs) == 0:
                continue
            # Results are kept quantized if the user specified the
            # output_quantized_idxs override.
            # Remove the dequantize operator for the node in the end if any
            return_node = node
            output = node.args[0]
            # outputs can be Node, list, tuple, dict, other cases are not supported yet
            if isinstance(output, (list, tuple)):
                for idx in output_quantized_idxs:
                    _maybe_recursive_remove_dequantize(
                        output[idx], return_node, model.graph
                    )
            elif isinstance(output, (Node, dict)):
                # we treat dict as a single argument currently, but it can be extended
                # to support {"key": dtype} after we change output_quantized_idxs to
                # dict
                if 0 in output_quantized_idxs:
                    _maybe_recursive_remove_dequantize(output, return_node, model.graph)
            else:
                warnings.warn(
                    f"Unsupported node type for output_quantized_idxs: {type(output)}"
                )
        elif node.op == "call_module":
            mod = _get_module(node, modules)
            assert mod is not None
            if _is_activation_post_process(mod):
                observed_node = node.args[0]
                if observed_node in statically_quantized_custom_module_nodes:
                    _replace_observer_or_dequant_stub_with_dequantize_node(
                        node, model.graph
                    )
                else:
                    if is_decomposed:
                        _replace_observer_with_quantize_dequantize_node_decomposed(
                            model,
                            node,
                            modules,
                            node_name_to_scope,
                            node_name_to_qconfig,
                        )
                    else:
                        _replace_observer_with_quantize_dequantize_node(
                            model,
                            node,
                            modules,
                            node_name_to_scope,
                            node_name_to_qconfig,
                        )
            elif isinstance(mod, DeQuantStub):
                _replace_observer_or_dequant_stub_with_dequantize_node(
                    node, model.graph
                )
            elif _is_observed_standalone_module(mod):
                convert_standalone_module(
                    node, modules, model, is_reference, backend_config
                )
            # below this point `type_before_parametrizations` is used
            # instead of `type` to handle situations with fx quant + sparsity
            elif type_before_parametrizations(mod) in set(root_module_classes).union(
                qat_module_classes
            ).union(fused_module_classes):
                # extra check for fused module classes to make sure they are fused module classes
                # of target modules
                if (
                    type_before_parametrizations(mod) in fused_module_classes
                    and type_before_parametrizations(mod[0]) not in root_module_classes
                ):  # type: ignore[index]
                    continue
                convert_weighted_module(
                    node,
                    modules,
                    observed_node_names,
                    node_name_to_qconfig,
                    backend_config,
                    is_decomposed,
                    is_reference,
                )
            elif type_before_parametrizations(mod) in custom_module_classes:
                convert_custom_module(
                    node,
                    model.graph,
                    modules,
                    custom_module_class_mapping,
                    statically_quantized_custom_module_nodes,
                )

    # remove dead code after converting observers to quant/dequant ops
    model.graph.eliminate_dead_code()
    model = GraphModule(model, model.graph)

    # TODO: maybe move this to quantize_fx.py
    if not is_reference:
        model = lower_to_fbgemm(model, node_name_to_qconfig, node_name_to_scope)

    # TODO: this looks hacky, we want to check why we need this and see if we can
    # remove this
    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    model.delete_all_unused_submodules()
    model.meta.pop("_observed_graph_module_attrs", None)
    return model
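

# ---------------------------------------------------------------------------
# Illustrative usage (a sketch, not executed as part of this module): `convert`
# above is normally reached through the public FX quantization API rather than
# called directly. Assuming a float `model`, an `example_inputs` tuple and a
# `calibration_data` iterable (all hypothetical), a minimal PTQ flow looks
# roughly like:
#
#     import torch
#     from torch.ao.quantization import get_default_qconfig_mapping
#     from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
#
#     qconfig_mapping = get_default_qconfig_mapping("x86")
#     prepared = prepare_fx(model, qconfig_mapping, example_inputs)
#     with torch.no_grad():
#         for data in calibration_data:
#             prepared(data)
#     quantized = convert_fx(prepared)  # ends up calling convert() defined above
# ---------------------------------------------------------------------------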