1# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 2 3# pyre-unsafe 4 5 6# This file contains all the functions that reorder ops in the graph module. 7 8import copy 9from collections import defaultdict 10from math import prod 11from typing import cast, DefaultDict, List, Set, Tuple 12 13import torch 14import torch.fx 15from executorch.backends.cadence.aot.compiler_utils import get_placeholders, get_shape 16from executorch.backends.cadence.aot.pass_utils import ( 17 CadencePassAttribute, 18 get_overload_packet, 19 register_cadence_pass, 20) 21from executorch.backends.cadence.aot.utils import get_edge_overload_packet 22from executorch.exir.dialects._ops import ops as exir_ops 23from executorch.exir.dialects.edge._ops import EdgeOpOverload 24from executorch.exir.pass_base import ExportPass, PassResult 25from executorch.exir.tensor import num_bytes_from_shape_and_dtype 26 27# A list of ops that can be trivially quantized 28trivially_quantizable_ops_overloadpkt = { 29 torch.ops.aten.slice_copy, 30 torch.ops.aten.slice, 31 torch.ops.aten.view_copy, 32 torch.ops.aten.view, 33 torch.ops.aten.clone, 34 torch.ops.aten.transpose_copy, 35 torch.ops.aten.transpose, 36 torch.ops.aten.permute_copy, 37 torch.ops.aten.permute, 38 torch.ops.aten.squeeze_copy, 39 torch.ops.aten.squeeze, 40 torch.ops.aten.unsqueeze_copy, 41 torch.ops.aten.unsqueeze, 42 torch.ops.aten.chunk, 43 torch.ops.aten.contiguous, 44 torch.ops.aten.select_copy, 45 exir_ops.edge.aten.slice_copy, 46 exir_ops.edge.aten.view_copy, 47 exir_ops.edge.aten.clone, 48 exir_ops.edge.aten.transpose_copy, 49 exir_ops.edge.aten.permute_copy, 50 exir_ops.edge.aten.squeeze_copy, 51 exir_ops.edge.aten.unsqueeze_copy, 52 exir_ops.edge.aten.unfold_copy, 53 exir_ops.edge.aten.chunk, 54 exir_ops.edge.aten.contiguous, 55 exir_ops.edge.aten.select_copy, 56} 57 58# slice-equivalent ops 59slice_or_select_overloadpkt = { 60 torch.ops.aten.slice_copy, 61 torch.ops.aten.select_copy, 62 exir_ops.edge.aten.slice_copy, 63 exir_ops.edge.aten.select_copy, 64} 65 66 67@register_cadence_pass(CadencePassAttribute(opt_level=2)) 68class AdvanceQuantizeOpAboveDefInBranchPass(ExportPass): 69 """ 70 If the graph is branched with the following pattern: 71 I = ... 72 S1 = slice(I) 73 Q1 = quantize(S1) 74 S2 = slice(I) 75 Q2 = quantize(S2) 76 S3 = slice(I) 77 Q3 = quantize(S3) 78 ... 79 such that the elements in the slices S1 + S2 + S3 is greater than I, 80 we can advance the quantize above their defs (i.e., all the slice nodes), 81 and reorder the pattern to the following: 82 I = ... 83 Q1 = quantize(I) 84 S1 = slice(Q1) 85 Q1 = requantize(S1) 86 S2 = slice(Q1) 87 Q2 = requantize(S2) 88 S3 = slice(Q1) 89 Q3 = requantize(S3) 90 ... 91 Note that the other passes won't do this transformation because they expect 92 a linear chain of def-use, which is not true here; the uses of I are 93 branched. 94 """ 95 96 def __init__(self): 97 super().__init__() 98 self.graph_module = None 99 100 # Starting at node, iterate through its successors, bypassing any trivially 101 # quantizable op. If all the descendents are quantize ops, return them. 102 def get_descendent_quant_ops(self, node: torch.fx.Node) -> List[torch.fx.Node]: 103 # The list of quant ops that are descendents of node, such that the only 104 # nodes in the path from node --> quant are trivially quantizable ops. 105 descendent_quant_ops = [] 106 # The list of trivially quantizable ops in the path from node --> quant op. 107 trivial_quantized_ops = [] 108 109 users = list(node.users.keys()) 110 while users: 111 user = users.pop(0) 112 user_target = get_overload_packet(user.target) 113 # Record a quant op successor 114 if user_target in { 115 torch.ops.quantized_decomposed.quantize_per_tensor, 116 exir_ops.edge.quantized_decomposed.quantize_per_tensor, 117 }: 118 descendent_quant_ops.append(user) 119 # If the successor is a trivially quantizable op, consider its users 120 # instead. 121 elif user_target in trivially_quantizable_ops_overloadpkt: 122 trivial_quantized_ops.append(user) 123 users.extend(list(user.users.keys())) 124 # Otherwise all successors of node are not quant op, so break the loop. 125 else: 126 descendent_quant_ops.clear() 127 break 128 129 # If all the nodes in trivial_quantize_ops of the node were slice ops, 130 # ensure that the advance is still profitable. 131 if descendent_quant_ops and all( 132 get_overload_packet(x.target) in slice_or_select_overloadpkt 133 for x in trivial_quantized_ops 134 ): 135 # Profitability metric: the sum of all the output slices must be at 136 # least half the input node slice. 137 slice_sizes = [ 138 prod(list(y)) 139 for x in trivial_quantized_ops 140 if (y := get_shape(self.graph_module, x)) is not None 141 ] 142 node_shape = get_shape(self.graph_module, node) 143 node_size = prod(list(node_shape)) if node_shape is not None else 0 144 if node_size > 2 * sum(slice_sizes): 145 descendent_quant_ops.clear() 146 147 return descendent_quant_ops 148 149 def advance_quantize_op(self, graph_module: torch.fx.GraphModule): 150 graph = graph_module.graph 151 for node in graph.nodes: 152 # We are only interested in call functions and placeholders 153 if node.op not in {"placeholder", "call_function"}: 154 continue 155 # If the node is trivially quantizable, skip it 156 if ( 157 get_overload_packet(node.target) 158 in trivially_quantizable_ops_overloadpkt 159 ): 160 continue 161 # Get the descendent quant ops that are connected to the current 162 # node via trivially quantizable ops. 163 descendent_quant_ops = self.get_descendent_quant_ops(node) 164 if not descendent_quant_ops: 165 continue 166 167 # Get the insertion point below which we need to insert anything. 168 # if node is a placeholder, we will only insert a new node after 169 # all the placeholders in the graph. 170 insertion_pt = ( 171 get_placeholders(graph)[-1] if node.op == "placeholder" else node 172 ) 173 174 # If the node only has a single quant op as descendent, we can 175 # simply hoist the quant op below the current node as its single 176 # child. 177 if len(descendent_quant_ops) == 1: 178 quant_node = descendent_quant_ops.pop() 179 # Replace the uses of quant node with its predecessor 180 quant_node.replace_all_uses_with(quant_node.args[0]) # pyre-fixme[6] 181 # Hoist the quant node after the current node. Make sure that 182 # the insertion is after placeholders 183 with graph.inserting_after(insertion_pt): 184 dom_quant_args = (node,) + quant_node.args[1:] 185 dom_quant_node = graph.call_function( 186 exir_ops.edge.quantized_decomposed.quantize_per_tensor.default 187 ) 188 dom_quant_node.meta = node.meta 189 node.replace_all_uses_with(dom_quant_node) 190 dom_quant_node.args = dom_quant_args 191 graph.erase_node(quant_node) 192 continue 193 194 # Otherwise we have the quant descendents. Cluster them into sets 195 # that have the same scale, zero_point, and dtype. We use quant_dict 196 # for the clustering 197 quant_dict: DefaultDict[Tuple, int] = defaultdict(int) 198 for quant_node in descendent_quant_ops: 199 quant_dict[quant_node.args[1:]] += 1 200 rep_args = sorted(quant_dict.keys(), key=lambda x: x[1]).pop() 201 202 # Create a new quant node that dominates all the nodes in 203 # descendent_quant_ops. Make sure that the insertion is after 204 # all the placeholders. 205 with graph.inserting_after(insertion_pt): 206 dom_quant_args = (node,) + rep_args 207 dom_quant_node = graph.call_function( 208 exir_ops.edge.quantized_decomposed.quantize_per_tensor.default 209 ) 210 dom_quant_node.meta = node.meta 211 node.replace_all_uses_with(dom_quant_node) 212 dom_quant_node.args = dom_quant_args 213 214 # Finally, convert each of the quant node to a dequant/quant pair that 215 # requantizes the data flowing through dom_quant_node. 216 # TODO: Once requantize is implemented for PT2, replace the 217 # dequant/quant pair here with a single requantize node 218 for quant_node in descendent_quant_ops: 219 with graph.inserting_before(quant_node): 220 dequant_node = graph.call_function( 221 exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default 222 ) 223 dequant_node.args = (quant_node.args[0],) + rep_args 224 quant_node.args = (dequant_node,) + quant_node.args[1:] 225 226 graph_module.recompile() 227 graph_module.graph.eliminate_dead_code() 228 229 def call(self, graph_module: torch.fx.GraphModule) -> PassResult: 230 self.graph_module = graph_module 231 self.advance_quantize_op(graph_module) 232 result = super().call(graph_module) 233 return result 234 235 236@register_cadence_pass(CadencePassAttribute(opt_level=1)) 237class AdvanceQuantizeOpAboveDefChainPass(ExportPass): 238 """ 239 If the input to quantize op is linear chain of view, transpose, permute, or 240 slice ops that are trivially quantized, we can convert the pattern 241 view/transpose/permute/slice(fp32) -> quantize(int8/uint8) to 242 quantize(int8/uint8) -> view/transpose/permute/slice(int8/uint8). 243 The benefit of such reordering is that the view/transpose/permute/slice 244 will move far less data. 245 """ 246 247 def __init__(self): 248 super().__init__() 249 self.graph_module = None 250 251 # Return true if advancing the quantize node is feasible 252 def advancing_feasible(self, quant_node: torch.fx.Node): 253 assert quant_node.op == "call_function" and len(quant_node.args) >= 1 254 # Get the input of the quant node. Only proceed if it's a torch node. 255 inp = quant_node.args[0] 256 if not isinstance(inp, torch.fx.Node): 257 return False 258 259 # Return false if the input to the quantize node is (1) not trivially 260 # quantizable, or (2) has more than one user. 261 inp_users = list(inp.users.keys()) 262 inp_overloadpkt = None 263 if isinstance(inp.target, EdgeOpOverload): 264 inp_overloadpkt = get_edge_overload_packet(inp.target) 265 else: 266 inp_overloadpkt = get_overload_packet(inp.target) 267 268 if ( 269 inp_overloadpkt not in trivially_quantizable_ops_overloadpkt 270 or len(inp_users) != 1 271 ): 272 return False 273 274 # Advancing quantize op above slice nodes is tricky. If we advance the 275 # quantize node above slice, then we will quantize the input to the slice 276 # op, which can be expensive. We only bypass nop slice at present. 277 if inp_overloadpkt in slice_or_select_overloadpkt: 278 sliced_tensor = inp.args[0] 279 assert isinstance(sliced_tensor, torch.fx.Node) 280 slice_input_shape = get_shape(self.graph_module, sliced_tensor) 281 slice_output_shape = get_shape(self.graph_module, inp) 282 # If we could not glean the shapes, or the slice op is a nop, bail 283 if ( 284 slice_output_shape is None 285 or slice_input_shape is None 286 or prod(list(slice_output_shape)) < prod(list(slice_input_shape)) 287 ): 288 return False 289 290 # All the conditions satisfied, we advance. 291 return True 292 293 def advance_quantize_op(self, graph_module: torch.fx.GraphModule): 294 graph = graph_module.graph 295 for node in reversed(graph.nodes): 296 if get_overload_packet(node.target) not in ( 297 exir_ops.edge.quantized_decomposed.quantize_per_tensor, 298 torch.ops.quantized_decomposed.quantize_per_tensor, 299 ): 300 continue 301 302 if not self.advancing_feasible(node): 303 continue 304 305 trivially_quantizable_op = node.args[0] 306 # The input to the quant node must now be the input to the trivially 307 # quantizable op. 308 quant_args = list(node.args) 309 quant_args[0] = trivially_quantizable_op.args[0] 310 311 # Insert the new quant node with updated args before the current 312 # quant node. 313 with graph.inserting_before(node): 314 quant_node = graph.call_function(node.target, args=tuple(quant_args)) 315 quant_node.meta = node.meta 316 # Move the trivially quantizable node after the quant node 317 with graph.inserting_after(node): 318 tq_args = list(trivially_quantizable_op.args) 319 tq_args[0] = quant_node 320 tq_node = graph.call_function( 321 trivially_quantizable_op.target, 322 args=tuple(tq_args), 323 kwargs=trivially_quantizable_op.kwargs, 324 ) 325 tq_node.meta = trivially_quantizable_op.meta 326 # Replace all uses of node with newly created tq_node 327 node.replace_all_uses_with(tq_node) 328 # We can safely remove the quant node and trivially quantizable op 329 graph.erase_node(node) 330 graph.erase_node(trivially_quantizable_op) 331 332 graph_module.recompile() 333 graph_module.graph.eliminate_dead_code() 334 335 def call(self, graph_module: torch.fx.GraphModule) -> PassResult: 336 self.graph_module = graph_module 337 self.advance_quantize_op(graph_module) 338 result = super().call(graph_module) 339 return result 340 341 342@register_cadence_pass(CadencePassAttribute(opt_level=1)) 343class PostponeDequantizeOpBelowUseChainPass(ExportPass): 344 """ 345 If the consumer of dequantize is a linear chain of view, transpose, permute, 346 or slice ops that are trivially quantized, we can convert the pattern 347 dequantize(int8/uint8) -> view/transpose/permute/slice(fp32) to 348 view/transpose/permute/slice(int8/uint8) -> dequantize(int8/uint8) 349 The benefit of such reordering is that the view/transpose/permute/slice 350 will move far less data. 351 """ 352 353 def __init__(self): 354 super().__init__() 355 self.graph_module = None 356 357 # Return true if postponing the dequantize node is feasible 358 def postponing_feasible(self, dequant_node: torch.fx.Node): 359 users = list(dequant_node.users.keys()) 360 # Check if the dequantize op has a single user, and that user is 361 # trivially quantizable. 362 trivially_quantizable_users = all( 363 get_overload_packet(user.target) in trivially_quantizable_ops_overloadpkt 364 for user in users 365 ) 366 if len(users) == 1: 367 return trivially_quantizable_users 368 369 # Otherwise check if all the users are slice op 370 if not all( 371 get_overload_packet(user.target) in slice_or_select_overloadpkt 372 for user in users 373 ): 374 return False 375 376 dequant_shape = get_shape(self.graph_module, dequant_node) 377 slice_shapes = [ 378 shape 379 for user in users 380 if (shape := get_shape(self.graph_module, user)) 381 and ( 382 # skip slices that are the size of the sliced tensor itself. 383 # They should technically get removed in the later passes as nop. 384 shape is None 385 or dequant_shape is None 386 or prod(list(shape)) != prod(list(dequant_shape)) 387 ) 388 ] 389 390 if dequant_shape is not None and all( 391 shape is not None for shape in slice_shapes 392 ): 393 dequant_bytes = num_bytes_from_shape_and_dtype(dequant_shape, torch.float32) 394 slice_bytes = sum( 395 [ 396 num_bytes_from_shape_and_dtype(shape, torch.float32) 397 for shape in slice_shapes 398 ] 399 ) 400 if slice_bytes <= dequant_bytes: 401 return True 402 403 # If the users of each slice op is quantize op, then we can postpone 404 # dequantize, and convert slice -> dequantize -> quantize to 405 # slice -> requantize. 406 users = [x for y in users for x in y.users if x.op != "output"] 407 return all( 408 get_overload_packet(x.target) 409 in { 410 exir_ops.edge.quantized_decomposed.quantize_per_tensor, 411 exir_ops.edge.quantized_decomposed.quantize_per_channel, 412 } 413 for x in users 414 ) 415 416 def postpone_dequantize_op(self, graph_module: torch.fx.GraphModule) -> bool: 417 # Different supported dequant ops have their own default variants 418 packet_to_overload_map = { 419 exir_ops.edge.quantized_decomposed.dequantize_per_tensor: "default", 420 exir_ops.edge.quantized_decomposed.dequantize_per_channel: "default", 421 } 422 graph = graph_module.graph 423 modified = False 424 for node in graph.nodes: 425 overload_packet = get_overload_packet(node.target) 426 if ( 427 overload_packet not in packet_to_overload_map.keys() 428 or not self.postponing_feasible(node) 429 ): 430 continue 431 432 for user in node.users: 433 with graph.inserting_after(user): 434 dequant_node = graph.call_function( 435 getattr( 436 overload_packet, packet_to_overload_map[overload_packet] 437 ), 438 args=(user, *node.args[1:]), 439 ) 440 dequant_node.meta = user.meta.copy() 441 # Remove meta["debug_handle"] on new node. Reassign it at the 442 # caller level by calling generate_missing_debug_handles 443 dequant_node.meta.pop("debug_handle") 444 user.replace_all_uses_with(dequant_node) 445 dequant_node.args = (user, *node.args[1:]) 446 447 pred = node.args[0] 448 node.replace_all_uses_with(pred) 449 graph.erase_node(node) 450 modified = True 451 452 graph_module.recompile() 453 return modified 454 455 def call(self, graph_module: torch.fx.GraphModule) -> PassResult: 456 # The logic in postpone_dequantize_op that handles branching checks the shape 457 # of the dequant node, which isn't available if that node was already postponed 458 # in the same pass invokation. The shape information is recreated by tracing in 459 # super().call(), meaning that every branch in the graph that we wish to postpone 460 # dequant past requires retracing. We iterate the pass until it no longer modifies 461 # the graph (up to 3 times max, to avoid potential infinite loops) 462 self.graph_module = graph_module 463 iter_count = 0 464 modified = True 465 466 while modified and iter_count < 3: 467 modified = self.postpone_dequantize_op(self.graph_module) 468 self.graph_module = super().call(self.graph_module).graph_module 469 iter_count += 1 470 471 return super().call(self.graph_module) 472 473 474@register_cadence_pass(CadencePassAttribute(opt_level=1)) 475class SinkOpsCloserToUsePass(ExportPass): 476 """ 477 Assume that the dequantize op D = dequantize(I) has only a single user. 478 If the current graph looks like 479 I = ...; 480 D = dequantize(I); 481 ... 482 Y = use(D); 483 then we can postpone the dequantize op closer to its use, and convert the 484 graph to: 485 I = ...; 486 ... 487 D = dequantize(I); 488 Y = use(D); 489 490 The transformation is valid since D had a single user. The benfit comes from 491 the fact that now we have I in the live range instead of D, which has a 492 much smaller size. 493 """ 494 495 sinkable_ops: Set[EdgeOpOverload] = { 496 exir_ops.edge.aten.dequantize, 497 exir_ops.edge.quantized_decomposed.dequantize_per_tensor, 498 exir_ops.edge.quantized_decomposed.dequantize_per_channel, 499 } 500 501 def sink_ops_closer_to_use(self, graph_module: torch.fx.GraphModule): 502 graph = graph_module.graph 503 # We are only interested in sinkable nodes 504 sinkable_nodes = [ 505 node 506 for node in graph.nodes 507 if isinstance(node.target, EdgeOpOverload) 508 and get_edge_overload_packet(node.target) in self.sinkable_ops 509 ] 510 for node in sinkable_nodes: 511 # The sinkable node must have a single user 512 users = list(node.users.keys()) 513 if len(users) != 1: 514 continue 515 516 # Insert the dequant node just before its user 517 with graph.inserting_before(users[0]): 518 new_node = graph.call_function( 519 node.target, args=node.args, kwargs=node.kwargs 520 ) 521 new_node.meta = node.meta 522 node.replace_all_uses_with(new_node) 523 graph.erase_node(node) 524 525 graph_module.recompile() 526 527 def call(self, graph_module: torch.fx.GraphModule) -> PassResult: 528 self.sink_ops_closer_to_use(graph_module) 529 result = super().call(graph_module) 530 return result 531 532 533@register_cadence_pass(CadencePassAttribute(opt_level=1)) 534class HoistOpsCloserToDefPass(ExportPass): 535 """ 536 Assume that the input I to a quantize op Q = quantize(I) has only a single 537 use, the quantize node itself. 538 If the current graph looks like 539 I = ...; 540 ... 541 Q = quantize(I); 542 X = use(Q); 543 then we can hoist the quantize op closer to its def, and convert the 544 graph to: 545 I = ...; 546 Q = quantize(I); 547 ... 548 X = use(Q); 549 550 The transformation is valid since I had a single user. The benefit comes from 551 the fact that now we have Q in the live range instead of I, which has a 552 much smaller size. The same transformation also applies to slice/select op. 553 """ 554 555 hoistable_ops: Set[EdgeOpOverload] = { 556 exir_ops.edge.quantized_decomposed.quantize_per_tensor, 557 exir_ops.edge.aten.slice_copy, 558 exir_ops.edge.aten.select_copy, 559 } 560 561 def hoist_ops_closer_to_def(self, graph_module: torch.fx.GraphModule): 562 graph = graph_module.graph 563 # We are only interested in hoistable nodes 564 hoistable_nodes = [ 565 node 566 for node in graph.nodes 567 if isinstance(node.target, EdgeOpOverload) 568 and get_edge_overload_packet(node.target) in self.hoistable_ops 569 ] 570 for node in hoistable_nodes: 571 def_node = node.args[0] 572 if not isinstance(def_node, torch.fx.Node): 573 continue 574 # The def node must have a single user 575 users = list(def_node.users.keys()) 576 if len(users) != 1: 577 continue 578 579 # Get the node args as list 580 args = list(node.args) 581 582 # If the graph has placeholders, we do not want to hoist above the 583 # last placeholder. Otherwise we will shrink the live range of the 584 # def_node considerably, which could lead to reuse of input memory. 585 def_node = ( 586 get_placeholders(graph)[-1] 587 if def_node.op == "placeholder" 588 else def_node 589 ) 590 591 # If the node is quantize_per_channel, we need to hoist the scale 592 # and zero_point tensors as well. 593 if ( 594 node.target 595 == exir_ops.edge.quantized_decomposed.quantize_per_channel.default 596 ): 597 scale, zero_point = args[1], args[2] 598 with graph.inserting_after(def_node): 599 zero_point_copy = graph.node_copy(zero_point) 600 scale_copy = graph.node_copy(scale) 601 args[1], args[2] = scale_copy, zero_point_copy 602 def_node = zero_point_copy 603 604 # Insert the quant node just after def_node 605 with graph.inserting_after(def_node): 606 new_node = graph.call_function( 607 node.target, args=tuple(args), kwargs=node.kwargs 608 ) 609 new_node.meta = node.meta 610 node.replace_all_uses_with(new_node) 611 graph.erase_node(node) 612 613 # Eliminate dead code 614 graph_module.recompile() 615 graph_module.graph.eliminate_dead_code() 616 617 def call(self, graph_module: torch.fx.GraphModule) -> PassResult: 618 self.hoist_ops_closer_to_def(graph_module) 619 result = super().call(graph_module) 620 return result 621 622 623@register_cadence_pass(CadencePassAttribute(opt_level=1)) 624class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(ExportPass): 625 """ 626 A common pattern seen in transformer models. If the consumer of permute 627 is a view op, swap their order so permute is below view. 628 Change "permute -> view" to "view -> permute" 629 This is to optimize a chain of view->permute->view->permute... 630 so that the chain will be become view->v...->view->permute->p...->permute. 631 The chain can be optimized by FuseCascadedTransposeOrPermuteOps() and 632 FuseCascadedViewOps(). 633 Notice the class name has ViewSqueeze to indicate the View is 634 functionally the same as a squeeze or unsqueeze. It does not necessarily 635 mean the view_copy is normalized from squeeze or unsqueeze. 636 """ 637 638 def __init__(self): 639 super().__init__() 640 self.graph_module = None 641 642 # If list1 and list2 are same (same values and in same order) except 643 # list1 has one more element with value of 1. Return index of the extra 1. 644 # Otherwise return -1. 645 def check_if_shapes_differ_in_single_dim_of_size_1(self, list1, list2) -> int: 646 if len(list1) != len(list2) + 1: 647 return -1 648 for i in range(len(list2)): 649 if list1[i] != list2[i]: 650 # Return index of the extra 1 if the remaining parts are the same 651 if list1[i] == 1 and list2[i:] == list1[i + 1 :]: 652 return i 653 else: 654 return -1 655 # If no difference was found, the extra element is at the end 656 if list1[-1] == 1: 657 return len(list2) 658 else: 659 return -1 660 661 def insert_nodes( 662 self, 663 graph: torch.fx.Graph, 664 pred: torch.fx.Node, 665 permute_node: torch.fx.Node, 666 view_node: torch.fx.Node, 667 new_view_shape: List, 668 new_permute_dims: List, 669 ): 670 with graph.inserting_after(view_node): 671 new_view_node = graph.call_function( 672 view_node.target, # pyre-fixme[6] 673 args=(pred, new_view_shape), 674 ) 675 676 with graph.inserting_after(new_view_node): 677 new_permute_node = graph.call_function( 678 permute_node.target, # pyre-fixme[6] 679 args=(new_view_node, new_permute_dims), 680 ) 681 new_permute_node.meta = view_node.meta 682 view_node.replace_all_uses_with(new_permute_node) 683 684 # view_node is user of permute_node, so must erase view_node first 685 graph.erase_node(view_node) 686 graph.erase_node(permute_node) 687 688 # flake8: noqa 'PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView.postpone_permute_op' is too complex (13) 689 def postpone_permute_op(self, graph_module: torch.fx.GraphModule): 690 packet_to_overload_map = { 691 exir_ops.edge.aten.permute_copy: "default", 692 } 693 graph = graph_module.graph 694 changed = True 695 modified = False 696 # Loop iteratively until no more changes are made 697 while changed: 698 changed = False 699 for permute_node in graph.nodes: 700 permute_overload_packet = get_overload_packet(permute_node.target) 701 if permute_overload_packet not in packet_to_overload_map.keys(): 702 continue 703 704 users = list(permute_node.users.keys()) 705 # Transform only for pattern permute_copy->view_copy, and 706 # view_copy op is the only user of permute_copy. 707 if len(users) == 1 and users[0].target in ( 708 exir_ops.edge.aten.view_copy.default, 709 exir_ops.edge.aten.view.default, 710 ): 711 # If the permute_node/view_node was newly added to the 712 # graph, it may not have the meta["val"] FakeTensor. 713 # Skip in this case. 714 if permute_node.meta.get("val") is None: 715 continue 716 permute_node_shape = [ 717 *cast(list, get_shape(graph_module, permute_node)) 718 ] 719 permute_dims = permute_node.args[1] 720 view_node = users[0] 721 if view_node.meta.get("val") is None: 722 continue 723 view_node_shape = [*cast(list, get_shape(graph_module, view_node))] 724 pred = permute_node.args[0] 725 if pred.meta.get("val") is None: 726 continue 727 pred_shape = [*cast(list, get_shape(graph_module, pred))] 728 # Handle two cases 729 # 1. view_node_shape is almost same as permute_node_shape 730 # except the view_node has one more dim somewhere 731 # and the extra dim has value of 1. 732 # 2. view_node_shape is almost same as permute_node_shape 733 # except permute_node_shape has one more dim somewhere 734 # and the extra dim has value of 1. 735 # 3. view_node_shape is the same as permute_node_shape. 736 if len(permute_node_shape) + 1 == len(view_node_shape): 737 index = self.check_if_shapes_differ_in_single_dim_of_size_1( 738 view_node_shape, permute_node_shape 739 ) 740 if index != -1: 741 # view_node_shape is almost same as permute_node_shape 742 # except it has one more dim somewhere 743 # and the extra dim has value of 1. 744 new_view_shape = copy.deepcopy(pred_shape) 745 new_view_shape.insert(index, 1) 746 new_permute_dims = [ 747 x + 1 if x >= index else x for x in permute_dims 748 ] 749 new_permute_dims.insert(index, index) 750 self.insert_nodes( 751 graph, 752 pred, 753 permute_node, 754 view_node, 755 new_view_shape, 756 new_permute_dims, 757 ) 758 changed = True 759 modified = True 760 elif len(view_node_shape) + 1 == len(permute_node_shape): 761 index = self.check_if_shapes_differ_in_single_dim_of_size_1( 762 permute_node_shape, view_node_shape 763 ) 764 if index != -1: 765 # view_node_shape is almost same as permute_node_shape 766 # except permute_node_shape has one more dim somewhere 767 # and the extra dim has value of 1. 768 index_to_remove = permute_dims[index] 769 new_view_shape = copy.deepcopy(pred_shape) 770 del new_view_shape[index_to_remove] 771 new_permute_dims = [ 772 x - 1 if x > index_to_remove else x 773 for x in permute_dims 774 ] 775 del new_permute_dims[index] 776 self.insert_nodes( 777 graph, 778 pred, 779 permute_node, 780 view_node, 781 new_view_shape, 782 new_permute_dims, 783 ) 784 changed = True 785 modified = True 786 elif permute_node_shape == view_node_shape: 787 # view_node_shape is the same as permute_node_shape 788 # Replace the uses of view_node with permute_node 789 view_node.replace_all_uses_with(permute_node) 790 changed = True 791 modified = True 792 793 graph_module.recompile() 794 return modified 795 796 def call(self, graph_module: torch.fx.GraphModule) -> PassResult: 797 self.graph_module = graph_module 798 iter_count = 0 799 modified = True 800 801 while modified and iter_count <= 3: 802 modified = self.postpone_permute_op(self.graph_module) 803 self.graph_module = super().call(self.graph_module).graph_module 804 iter_count += 1 805 806 return super().call(self.graph_module) 807 808 809# The following class consolidates functions to reoder ops (i.e., either hoist 810# or sink some ops in the graph). 811class CadenceReorderOpsInGraph: 812 passes = [ 813 # Hoist/sink nodes closer to their SSA def/use 814 HoistOpsCloserToDefPass, 815 SinkOpsCloserToUsePass, 816 # For quantize/dequantize ops, move them above/below their def chain. 817 # This is a more aggressive optimization than just hoisting/sinking 818 # nodes closer to their def/use. 819 AdvanceQuantizeOpAboveDefChainPass, 820 PostponeDequantizeOpBelowUseChainPass, 821 # These passes work on branches instead of linear chains to advance 822 # quantize op beyond their def. 823 AdvanceQuantizeOpAboveDefInBranchPass, 824 ] 825