1# mypy: allow-untyped-defs 2import operator 3import types 4from typing import Any, Callable, Dict, List, Optional, Tuple, Union 5 6import torch 7import torch.nn.functional as F 8from torch._export import capture_pre_autograd_graph 9 10# Makes sure that quantized_decomposed ops are registered 11from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 12from torch.ao.quantization.quantizer import QuantizationAnnotation 13from torch.export.unflatten import _assign_attr, _AttrKind 14from torch.fx import GraphModule, Node 15from torch.nn.utils.fusion import fuse_conv_bn_weights 16from torch.utils._pytree import LeafSpec 17 18 19__all__ = [ 20 "fold_bn_weights_into_conv_node", 21 "remove_tensor_overload_for_qdq_ops", 22] 23 24_QUANTIZE_OPS = [ 25 torch.ops.quantized_decomposed.quantize_per_tensor.default, 26 torch.ops.quantized_decomposed.quantize_per_tensor.tensor, 27 torch.ops.quantized_decomposed.quantize_per_channel.default, 28] 29 30 31_DEQUANTIZE_OPS = [ 32 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 33 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, 34 torch.ops.quantized_decomposed.dequantize_per_channel.default, 35] 36 37# Example inputs for conv-bn1d patterns 38_conv1d_bn_example_inputs = ( 39 torch.randn(1, 1, 3), # x 40 torch.randn(1, 1, 1), # conv_weight 41 torch.randn(1), # conv_bias 42 torch.randn(1), # bn_weight 43 torch.randn(1), # bn_bias 44 torch.randn(1), # bn_running_mean 45 torch.randn(1), # bn_running_var 46) 47 48# Example inputs for conv-bn2d patterns 49_conv2d_bn_example_inputs = ( 50 torch.randn(1, 1, 3, 3), # x 51 torch.randn(1, 1, 1, 1), # conv_weight 52 torch.randn(1), # conv_bias 53 torch.randn(1), # bn_weight 54 torch.randn(1), # bn_bias 55 torch.randn(1), # bn_running_mean 56 torch.randn(1), # bn_running_var 57) 58 59 60def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool: 61 """ 62 Assuming dest is one of the ops inserted by quant workflow, this function 63 finds if source and dest are connected. Assumption is that only quant workflow 64 inserted ops exist between source and dest 65 """ 66 quant_workflow_ops = _QUANTIZE_OPS + _DEQUANTIZE_OPS 67 quant_workflow_ops.append(torch.ops.quantized_decomposed.choose_qparams.tensor) 68 while dest.target in quant_workflow_ops: 69 if not isinstance(dest.args[0], torch.fx.Node): 70 raise ValueError( 71 f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}" 72 ) 73 dest = dest.args[0] 74 return dest == source 75 76 77def _find_q_dq_node_for_user( 78 produer: torch.fx.Node, user: torch.fx.Node 79) -> Tuple[Any, Any]: 80 """ 81 Find q, dq pair corresponding to [producer -> q -> dq -> user] 82 Utils works by finding dq arg of user and ensuring it is connected to 83 producer 84 """ 85 dq_node = None 86 for n in user.args: 87 if ( 88 isinstance(n, torch.fx.Node) 89 and n.op == "call_function" 90 and n.target in _DEQUANTIZE_OPS 91 ): 92 if _is_connected(produer, n): 93 dq_node = n 94 break 95 if dq_node is None: 96 for n in user.kwargs: 97 if ( 98 isinstance(n, torch.fx.Node) 99 and n.op == "call_function" 100 and n.target in _DEQUANTIZE_OPS 101 ): 102 if _is_connected(produer, n): 103 dq_node = n 104 break 105 if dq_node is None: 106 return (None, None) 107 108 q_node = None 109 if ( 110 dq_node.args[0].op == "call_function" # type: ignore[union-attr] 111 and dq_node.args[0].target in _QUANTIZE_OPS # type: ignore[union-attr] 112 ): 113 q_node = dq_node.args[0] 114 return (q_node, dq_node) 115 116 117def _is_sym_size_node(node: Node): 118 return ( 119 node.op == "call_function" 120 and node.target == torch.ops.aten.sym_size.default 121 or node.target == torch.ops.aten.sym_numel.default 122 or node.target == torch.ops.aten.sym_numel 123 or node.target == torch.ops.aten.sym_size 124 ) 125 126 127def _filter_sym_size_users(node: torch.fx.Node) -> List[torch.fx.Node]: 128 node_users = list(filter((lambda x: (_is_sym_size_node(x) is False)), node.users)) 129 return node_users 130 131 132def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool: 133 if annotation is None: 134 return False 135 input_qspec_map = annotation.input_qspec_map 136 output_qspec = annotation.output_qspec 137 if len(input_qspec_map) == 0 and output_qspec is None: 138 return False 139 return True 140 141 142def _get_tensor_constant_from_node(node, m): 143 if node is None: 144 return None 145 assert node.op == "get_attr" 146 target_atoms = node.target.split(".") 147 attr_itr = m 148 for i, atom in enumerate(target_atoms): 149 if not hasattr(attr_itr, atom): 150 raise RuntimeError( 151 f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" 152 ) 153 attr_itr = getattr(attr_itr, atom) 154 return attr_itr 155 156 157def _get_all_arguments(orig_args, orig_kwargs, args_schema): 158 all_args = [] 159 for i, schema in enumerate(args_schema): 160 if schema.name in orig_kwargs: 161 all_args.append(orig_kwargs[schema.name]) 162 elif not schema.kwarg_only and i < len(orig_args): 163 all_args.append(orig_args[i]) 164 else: 165 all_args.append(schema.default_value) 166 return all_args 167 168 169def _is_supported_batch_norm_for_training(node: Node): 170 """ 171 Return True if the given node refers to an aten batch norm op QAT supports. 172 """ 173 supported_ops = [ 174 torch.ops.aten.batch_norm.default, 175 torch.ops.aten._native_batch_norm_legit.default, 176 # Note: we won't need this op anymore after batch norm consolidation 177 # For now, we need to continue to support it because it gives better 178 # training numerics than `_native_batch_norm_legit` 179 torch.ops.aten.cudnn_batch_norm.default, 180 torch.ops.aten.miopen_batch_norm.default, 181 ] 182 return node.target in supported_ops 183 184 185# TODO: move this to torch/ao/quantization/utils.py 186def _is_conv_node(n: Node): 187 """ 188 Return whether the node refers to an aten conv op. 189 """ 190 return n.op == "call_function" and n.target in [ 191 torch.ops.aten.conv1d.default, 192 torch.ops.aten.conv2d.default, 193 ] 194 195 196def _is_conv_transpose_node(n: Node): 197 """ 198 Return whether the node refers to an aten conv_transpose op. 199 """ 200 return n.op == "call_function" and n.target in [ 201 torch.ops.aten.conv_transpose1d, 202 torch.ops.aten.conv_transpose1d.default, 203 torch.ops.aten.conv_transpose2d, 204 torch.ops.aten.conv_transpose2d.input, 205 ] 206 207 208def _is_conv_or_conv_transpose_node(n: Node): 209 """ 210 Return whether the node refers to an aten conv or conv transpose op. 211 """ 212 return _is_conv_node(n) or _is_conv_transpose_node(n) 213 214 215def _is_conv_transpose_fn(conv_fn: Callable): 216 return conv_fn in [F.conv_transpose1d, F.conv_transpose2d] 217 218 219def _is_bn_node(n: Node): 220 return ( 221 _is_supported_batch_norm_for_training(n) 222 or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default 223 ) 224 225 226def fold_bn_weights_into_conv_node( 227 conv_node: Node, 228 conv_weight_node: Node, 229 conv_bias_node: Optional[Node], 230 bn_node: Node, 231 m: GraphModule, 232) -> None: 233 # conv args: input, weight, bias, stride, padding, dilation, ... 234 conv_w = _get_tensor_constant_from_node(conv_weight_node, m) 235 conv_b = _get_tensor_constant_from_node(conv_bias_node, m) 236 transpose = _is_conv_transpose_node(conv_node) 237 238 # eval bn args: input, weight, bias, running mean, running var, momentum, eps 239 # train bn args: input, weight, bias, running mean, running var, training, momentum, eps 240 bn_args_schema = bn_node.target._schema.arguments # type: ignore[union-attr] 241 bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema) 242 bn_w = _get_tensor_constant_from_node(bn_args[1], m) 243 bn_b = _get_tensor_constant_from_node(bn_args[2], m) 244 bn_rm = _get_tensor_constant_from_node(bn_args[3], m) 245 bn_rv = _get_tensor_constant_from_node(bn_args[4], m) 246 if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default: 247 eps_arg_index = 6 248 elif _is_supported_batch_norm_for_training(bn_node): 249 eps_arg_index = 7 250 else: 251 raise ValueError("BN node target is unexpected ", bn_node.target) 252 bn_eps = bn_args[eps_arg_index] 253 254 fused_weight, fused_bias = fuse_conv_bn_weights( 255 conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose 256 ) 257 258 # update the weight and bias for conv 259 conv_args = list(conv_node.args) 260 # filling in the default bias argument 261 if len(conv_args) == 2: 262 conv_args.append(None) 263 264 # calling data since the fused_weight and fused_bias are nn.Parameter 265 weight_attr_name = conv_weight_node.target 266 assert isinstance(weight_attr_name, str) 267 _assign_attr(fused_weight, m, weight_attr_name, _AttrKind.PARAMETER) 268 if conv_bias_node is not None: 269 bias_attr_name = conv_bias_node.target 270 _assign_attr(fused_bias, m, str(bias_attr_name), _AttrKind.PARAMETER) 271 else: 272 bias_attr_name = weight_attr_name + "_bias" 273 _assign_attr(fused_bias, m, bias_attr_name, _AttrKind.PARAMETER) 274 with m.graph.inserting_before(conv_node): 275 get_bias_node = m.graph.get_attr(bias_attr_name) 276 # NOTE: here we assume the bias of conv is not quantized! 277 conv_args[2] = get_bias_node 278 conv_node.args = tuple(conv_args) 279 280 # native_batch_norm has 3 outputs, we expect getitem calls on the output 281 # and we want to replace the uses of getitem 0 with the output of conv 282 # 283 if bn_node.target == torch.ops.aten.batch_norm.default: 284 # With the new training ir, instead of batch_norm + getitem, 285 # we only have the batch_norm node. 286 # 287 # Before: 288 # conv -> bn -> users 289 # After: 290 # conv -> users 291 # bn has no users now 292 bn_node.replace_all_uses_with(conv_node) 293 else: 294 # Before: 295 # conv -> bn - (first output) -> users1 296 # \ - (second output) -> users2 297 # \ - (third output) -> users3 298 # After: 299 # conv -> (first output) -> users1 300 # bn - 301 # \ - (second output) -> users2 302 # \ - (third output) -> users3 303 # if users2 and users3 are empty then bn will be removed through dead code elimination 304 for user in bn_node.users: 305 if ( 306 user.op != "call_function" 307 or user.target != operator.getitem 308 or user.args[1] != 0 309 ): 310 continue 311 user.replace_all_uses_with(conv_node) 312 313 # If the BN node does not have users, erase it from the graph 314 # Note: we need to do this manually because the model can still be in train 315 # mode at this point, in which case DCE won't erase the BN node automatically 316 # since the node refers to a mutating op. Here we still need to call DCE first 317 # to get rid of the unused getitem nodes that consume the BN node. 318 m.graph.eliminate_dead_code() 319 if len(bn_node.users) == 0: 320 m.graph.erase_node(bn_node) 321 322 323# fuse conv bn weights, inplace modification of the graph_module and graph 324def _fuse_conv_bn_(m: GraphModule) -> None: 325 has_bn = any(_is_bn_node(n) for n in m.graph.nodes) 326 if not has_bn: 327 return 328 for n in m.graph.nodes: 329 if n.op != "call_function" or n.target not in ( 330 torch.ops.aten._native_batch_norm_legit_no_training.default, 331 torch.ops.aten.batch_norm.default, 332 ): 333 continue 334 bn_node = n 335 n = bn_node.args[0] 336 if not _is_conv_or_conv_transpose_node(n): 337 continue 338 conv_node = n 339 conv_weight_node = conv_node.args[1] 340 conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None 341 fold_bn_weights_into_conv_node( 342 conv_node, conv_weight_node, conv_bias_node, bn_node, m 343 ) 344 345 m.graph.eliminate_dead_code() 346 m.recompile() 347 348 349def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]: 350 # TODO: move this information to fx node itself 351 node_name_to_scope: Dict[str, Tuple[str, type]] = {} 352 for n in model.graph.nodes: 353 nn_module_stack = n.meta.get("nn_module_stack", None) 354 current_scope = ("", type(None)) 355 if nn_module_stack: 356 bt = list(nn_module_stack.values())[-1] 357 current_scope = (bt[0].split(".")[-1], bt[1]) 358 node_name_to_scope[n.name] = current_scope 359 return node_name_to_scope 360 361 362def _get_aten_graph_module_for_pattern( 363 pattern: Callable, 364 example_inputs: Tuple[Any, ...], 365 is_cuda: bool = False, 366 **kwargs, 367) -> GraphModule: 368 """ 369 Convert the pattern to an FX graph with decomposed aten ops. 370 """ 371 if is_cuda: 372 example_inputs = tuple( 373 [x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs] 374 ) 375 aten_pattern = capture_pre_autograd_graph( 376 pattern, # type: ignore[arg-type] 377 example_inputs, 378 kwargs, 379 ) 380 aten_pattern.graph.eliminate_dead_code() 381 aten_pattern.recompile() 382 383 # ep.module() adds copy_ nodes for the mutated inputs. 384 # For patterns, it doesn't matter 385 for node in aten_pattern.graph.nodes: 386 if ( 387 node.op == "call_function" 388 and node.target == torch.ops.aten.copy_.default 389 and len(node.users) == 0 390 ): 391 aten_pattern.graph.erase_node(node) 392 393 aten_pattern.graph.eliminate_dead_code() 394 aten_pattern.recompile() 395 396 return aten_pattern # type: ignore[return-value] 397 398 399def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None: 400 """Remove .tensor overload for quantize/dequantize ops so that we can 401 use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e 402 """ 403 _MAP = { 404 torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor, 405 torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor, 406 torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor, 407 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor, 408 torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor, 409 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor, 410 torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel, 411 torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel, 412 torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp, 413 } 414 for n in match_pattern.graph.nodes: 415 if n.op != "call_function": 416 continue 417 if n.target in _MAP: 418 n.target = _MAP[n.target] 419 420 421def _is_literal(arg): 422 if isinstance(arg, (int, float)): 423 return True 424 if isinstance(arg, (tuple, list)): 425 return all(map(_is_literal, arg)) 426 return False 427 428 429def _replace_literals_with_new_placeholders( 430 gm: torch.fx.GraphModule, 431 merge_dup: bool = False, 432 exclude_literals: Optional[List[Any]] = None, 433): 434 """Replace the literals in the graph with placeholder nodes that's created on the fly while we 435 traverse the graph, so that the literal arguments in the graph can be matched and replaced 436 437 To use this, the pattern and replacement graph should have the exact same number of literal args 438 and they should be used in the exact same order in the pattern and replacement graph. 439 440 If the literal arguments are not used in the same order in pattern and replacement graph, please 441 use `_replace_literals_with_existing_placeholders` instead 442 443 Args: 444 `gm`: input GraphModule that we'll transform 445 `merge_dup`: boolean flag to indicate that if the same literal appears multiple times in 446 the graph, whether they should correspond to the same placeholder or not 447 `exclude_literals`: a list of literals that will not be replaced with placeholders 448 449 Example: 450 451 # 1. Original Graph 452 def pattern(self, x): 453 return x + 3 454 455 def replacement(self, x): 456 return x - 3 457 458 example_inputs = (torch.randn(1, 3, 3, 3),) 459 pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs) 460 replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus) 461 462 # 2. Before calling replace literals we'll see the following graph: 463 def pattern(self, x): 464 return x + 3 465 466 def replacement(self, x): 467 return x - 3 468 469 pattern_gm = _replace_literals_with_new_placeholders(pattern_gm) 470 replacement_gm = _replace_literals_with_new_placeholders(replacement_gm) 471 472 # 3. After replacing literals with new placeholder nodes 473 474 def pattern(self, x, new_ph): 475 return x + new_ph 476 477 def pattern(self, x, new_ph): 478 return x - new_ph 479 480 """ 481 last_ph = None 482 cnt = 0 483 literal_to_ph: Dict[Union[float, bool, int, torch.dtype], Node] = {} 484 if exclude_literals is None: 485 exclude_literals = [] 486 487 in_spec = gm._in_spec 488 args_spec = in_spec.children_specs[0] 489 for node in gm.graph.nodes: 490 if node.op == "placeholder": 491 last_ph = node 492 cnt += 1 493 continue 494 with gm.graph.inserting_after(last_ph): 495 new_args = [] 496 for arg in node.args: 497 if _is_literal(arg) and arg not in exclude_literals: 498 if merge_dup and arg in literal_to_ph: 499 new_args.append(literal_to_ph[arg]) 500 else: 501 ph_node = gm.graph.placeholder("arg" + str(cnt)) 502 new_args.append(ph_node) 503 args_spec.children_specs.append(LeafSpec()) 504 cnt += 1 505 if merge_dup: 506 literal_to_ph[arg] = ph_node 507 else: 508 new_args.append(arg) 509 new_args = tuple(new_args) 510 511 node.args = new_args 512 513 # Update `num_nodes`, `num_leaves`, `num_children`. 514 args_spec.__post_init__() 515 in_spec.__post_init__() 516 return gm 517 518 519def _replace_literals_with_existing_placeholders( 520 gm: torch.fx.GraphModule, 521 exclude_literals: Optional[List[Any]] = None, 522 literal_to_ph_idx: Optional[Dict[Union[float, int, bool, torch.dtype], int]] = None, 523): 524 """Replace the literals in the graph with **existing** placeholder nodes, so that the literal arguments 525 in the graph can be matched and replaced 526 527 To use this, all literal args in the graph should be unique and each of them should correspond 528 to exactly one placeholder node 529 530 # 1. Original Graph 531 def pattern(self, x_i8, scale, zero_point, quant_min, quant_max): 532 return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max) 533 534 def replacement(x_i8, scale, zero_point, quant_min, quant_max): 535 x_i8 = torch.clamp(x_i8, quant_min, quant_max) 536 return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32) 537 538 example_inputs = ( 539 torch.randn(1, 3, 3, 3), 540 1.0, 541 0, 542 -128, 543 127, 544 ) 545 pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs) 546 replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus) 547 548 # 2. Before calling replace literals we'll see the following graph: 549 def pattern(self, x_i8, scale, zero_point, quant_min, quant_max): 550 # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values 551 return torch.dequantize_per_tensor(x_i8, 1.0, 0, -128, 127) 552 553 def replacement(x_i8, scale, zero_point, quant_min, quant_max): 554 # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values 555 x_i8 = torch.clamp(x_i8, -128, 127) 556 return ((x_i8.to(torch.float32) - 0) * 1.0).to(dtype=torch.float32) 557 558 # Note that literal args appear in different order in pattern and replacement graph, so 559 # we can't use _replace_literals_with_new_placeholders 560 561 literal_to_ph_idx = {1.0: 1, 0: 2, -128: 3, 127: 4} 562 pattern_gm = _replace_literals_with_existing_placeholders(pattern_gm, literal_to_ph_idx) 563 replacement_gm = _replace_literals_with_existing_placeholders(replacement_gm, literal_to_ph_idx) 564 565 # 3. After replacing literals with existing placeholder nodes 566 567 def pattern(self, x_i8, scale, zero_point, quant_min, quant_max): 568 # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values 569 return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max) 570 571 def replacement(x_i8, scale, zero_point, quant_min, quant_max): 572 # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values 573 x_i8 = torch.clamp(x_i8, quant_min, quant_max) 574 return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32) 575 """ 576 if exclude_literals is None: 577 exclude_literals = [] 578 579 if literal_to_ph_idx is None: 580 literal_to_ph_idx = {} 581 582 phs = [node for node in gm.graph.nodes if node.op == "placeholder"] 583 584 for node in gm.graph.nodes: 585 if node.op != "call_function": 586 continue 587 new_args = [] 588 for arg in node.args: 589 if ( 590 _is_literal(arg) 591 and arg not in exclude_literals 592 and arg in literal_to_ph_idx 593 ): 594 ph_idx = literal_to_ph_idx[arg] 595 ph_node = phs[ph_idx] 596 new_args.append(ph_node) 597 else: 598 new_args.append(arg) 599 new_args = tuple(new_args) 600 node.args = new_args 601 return gm 602 603 604# TODO: Handle this in export itself and don't wrap the model in another GraphModule 605# in prepare and convert 606def _disallow_eval_train(model: GraphModule): 607 """ 608 Disallow calling `model.train()` or `model.eval()` on the given GraphModule. 609 This is useful for exported models, where these methods don't actually behave as expected. 610 """ 611 error_message = """ 612 Calling train() or eval() is not supported for exported models. 613 Please call `torch.ao.quantization.move_exported_model_to_train(model)` (or eval) instead. 614 615 If you cannot replace the calls to `model.train()` and `model.eval()`, you may override 616 the behavior for these methods by calling `torch.ao.quantization.allow_exported_model_train_eval(model)`, 617 which does the above automatically for you. Note that this has limited effect on switching 618 behavior between train and eval modes, and should be used only for special ops such as dropout 619 and batchnorm. 620 """ 621 622 def _train(self, mode: bool = True): 623 raise NotImplementedError(error_message) 624 625 def _eval(self, mode: bool = True): 626 raise NotImplementedError(error_message) 627 628 model.train = types.MethodType(_train, model) # type: ignore[method-assign] 629 model.eval = types.MethodType(_eval, model) # type: ignore[method-assign] 630 return model 631