# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict


# This file contains functions to remove operators from the graph. The removed
# ops should belong to either of the following categories:
# 1. The op should be redundant for inference (e.g., dropout). Such ops are grouped
# together in 'RemoveRedundantOps'. Anyone running inference can add this class
# in their pass list, and it should be a semantics-preserving transformation.
# 2. The op should be redundant for Jarvis (e.g., contiguous). Such ops are grouped
# together in 'CadenceRemoveNops'. The ops removed in this class might not be nops
# in a context outside of Jarvis, so exercise caution when invoking this in a
# pass list outside of Jarvis.

import itertools
import logging
from dataclasses import dataclass, field
from typing import Callable, cast, Dict, List, Optional, Sequence

import torch
import torch.fx
from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    register_cadence_pass,
)

from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
from executorch.exir.pass_manager import PassManager, PassType
from executorch.exir.passes import dead_code_elimination_pass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from torch.fx.node import Argument


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveCloneOpsTransformImported(ExportPass):
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        finalize_passes: List[PassType] = [
            RemoveCloneOpsTransform(),
        ]
        result = PassManager(passes=finalize_passes)(graph_module)
        dead_code_elimination_pass(result.graph_module)
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveDetachCopyPass(ExportPass):
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.detach_copy.default:
            return super().call_operator(op, args, kwargs, meta)

        assert len(args) == 1
        return cast(ProxyValue, args[0])


# The following class consolidates passes to remove ops that are redundant:
# either by virtue of the operation they perform, or redundant in the
# context of inference.
class RemoveRedundantOps:
    passes = [
        RemoveDetachCopyPass,
    ]
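

# A minimal usage sketch, assuming `gm` is a torch.fx.GraphModule produced by
# the Edge export flow (the variable name is hypothetical): anyone running
# inference can fold the passes above into their own pass list, e.g.
#
#   result = PassManager(
#       passes=[pass_cls() for pass_cls in RemoveRedundantOps.passes]
#   )(gm)
#   gm = result.graph_module
#
# mirroring how RemoveCloneOpsTransformImported.call drives its own pass list.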


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveZeroSizedCatArgsPass(ExportPass):
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.cat.default:
            return super().call_operator(op, args, kwargs, meta)

        # Remove any zero-sized tensor arg to form a new args list.
        cat_inputs: list[ProxyValue] = []
        for arg in cast(Sequence[ProxyValue], args[0]):
            if arg.to_tensor().numel() > 0:
                cat_inputs.append(arg)

        # If all the input tensors were empty, we just return an empty tensor
        # with the right shape.
        if not cat_inputs:
            empty_shape = meta["val"].shape
            dtype = meta["val"].dtype
            return super().call_operator(
                exir_ops.edge.aten.full.default,
                (tuple(empty_shape), 0),
                {"dtype": dtype},
                meta,
            )

        # If there was only one tensor in the cat_inputs list,
        # we can safely erase this cat op.
        if len(cat_inputs) == 1:
            return cat_inputs[0]

        # Otherwise, we replace args[0] with cat_inputs.
        new_args = list(args)
        new_args[0] = cat_inputs
        return super().call_operator(op, tuple(new_args), kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveNopExpandOpPass(ExportPass):
    """
    For an expand op, if the input tensor's shape already matches the expand
    shape, then the expand is a nop.
    """

    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if get_edge_overload_packet(op) not in {
            exir_ops.edge.aten.expand_copy,
            exir_ops.edge.aten.expand,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # Parse the args, and check for the nop condition.
        arg0 = cast(ProxyValue, args[0])
        arg1 = cast(Sequence[int], args[1])
        in_tensor = arg0.to_tensor()
        if list(in_tensor.shape) == list(arg1):
            return arg0

        return super().call_operator(op, args, kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveToOpsPass(ExportPass):
    # aten.to.* ops are, as of now, all nops for Jarvis.
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op not in (
            exir_ops.edge.aten.to.dtype,
            exir_ops.edge.aten.to.dtype_layout,
        ):
            return super().call_operator(op, args, kwargs, meta)

        logging.debug(f"Erasing to.dtype node (target = {op})")
        return cast(ProxyValue, args[0])


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveZeroSizedConstantPadNd(ExportPass):
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[ProxyValue, tuple[int, ...], Argument],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.constant_pad_nd.default:
            return super().call_operator(op, args, kwargs, meta)

        input_tensor = args[0]
        padding = args[1]

        if any(x != 0 for x in padding):
            return super().call_operator(op, args, kwargs, meta)

        logging.debug(f"Erasing zero-sized constant_pad_nd node with {input_tensor}")
        return input_tensor


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopSliceOrViewOpPass(ExportPass):
    """
    Remove slice ops that are more like views, and view ops that do not change
    the shape.
    """

    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op not in {
            exir_ops.edge.aten.slice_copy.Tensor,
            exir_ops.edge.aten.view_copy.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        arg0 = cast(ProxyValue, args[0])
        out_shape = meta["val"].shape

        # If the input and output shapes are the same, this slice/view is a nop.
        return (
            arg0
            if arg0.to_tensor().shape == out_shape
            else super().call_operator(op, args, kwargs, meta)
        )
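

# Illustrative sketch of the patterns RemoveNopSliceOrViewOpPass removes,
# assuming a hypothetical input tensor x of shape [2, 4, 8]:
#
#   slice_copy(x, dim=1, start=0, end=4)  # output shape [2, 4, 8] == input -> removable
#   view_copy(x, [2, 4, 8])               # output shape == input shape    -> removable
#
# A slice that actually drops elements (e.g. end=2 above) changes the output
# shape and is left untouched.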


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopLinalgVectorNormOpPass(ExportPass):
    """
    If the norm is applied over a dimension that is size 1, it can be eliminated.
    """

    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op not in {
            exir_ops.edge.aten.linalg_vector_norm.default,
            exir_ops.edge.cadence.linalg_vector_norm.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # If the op has three args or fewer, it can't be a nop.
        if len(args) <= 3:
            return super().call_operator(op, args, kwargs, meta)
        # If dim is None, or keepdim is False, it is not a nop.
        dim = cast(Optional[tuple[int, ...]], args[2])
        keepdim = cast(bool, args[3])
        if dim is None or not keepdim:
            return super().call_operator(op, args, kwargs, meta)

        # At this point dim is not None and keepdim is True. The norm is a nop
        # only if every dimension in dim has size 1; otherwise bail out.
        t = cast(ProxyValue, args[0])
        shape = t.to_tensor().shape
        for d in dim:
            if shape[d] != 1:
                return super().call_operator(op, args, kwargs, meta)

        return t


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopSelectOpPass(ExportPass):
    """
    A select op that selects from a dimension that is size 1 can be eliminated
    in a few cases. For example,
    ```
    x = view(x, [1, 3, 16])
    y = select(x, 0, 0)
    z = add(m, y)
    ```
    The special thing about this pattern is the add op, which allows
    broadcasting. So adding an operand with shape [3, 16] is the same as
    adding an operand with shape [1, 3, 16]. Therefore, if m has the same
    shape as x, then this select op is a nop, and can be eliminated:
    ```
    x = view(x, [1, 3, 16])
    z = add(x, m)
    ```
    """

    # A set of binary operators that could require broadcasting, and are
    # critical to this transformation if their operand is a select op.
    binary_broadcast_ops: set[EdgeOpOverload] = {
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.div.Tensor,
    }

    def __init__(self) -> None:
        super().__init__()
        self.op_sizes: dict[str, tuple[torch.Size, torch.Size]] = {}

    # For select, view, or any op in binary_broadcast_ops, record the shapes of
    # the input and output tensors.
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        res = super().call_operator(op, args, kwargs, meta)
        # Unary ops: record the input and output shapes.
        if op in {
            exir_ops.edge.aten.select_copy.int,
            exir_ops.edge.aten.view_copy.default,
        }:
            arg0 = cast(ProxyValue, args[0])
            self.op_sizes[res.node.name] = (arg0.to_tensor().shape, meta["val"].shape)
        # Binary ops: record the two input shapes; the output shape can be inferred.
        elif op in self.binary_broadcast_ops:
            arg0 = cast(ProxyValue, args[0])
            arg1 = cast(ProxyValue, args[1])
            self.op_sizes[res.node.name] = (
                arg0.to_tensor().shape,
                arg1.to_tensor().shape,
            )
        return res

    # Eliminate nop select ops. We begin by inspecting the binary_broadcast_ops,
    # and check if their arg is a select op.
    def eliminate_nop_select_op(self, graph_module: torch.fx.GraphModule) -> None:
        for sel_node in graph_module.graph.nodes:
            # We are only interested in select ops.
            if sel_node.target != exir_ops.edge.aten.select_copy.int:
                continue
            # The shapes of the input/output operands for this select op should
            # have been precomputed.
            assert sel_node.name in self.op_sizes
            (sel_in_shape, sel_out_shape) = self.op_sizes[sel_node.name]
            # Get the select dimension, normalized to be non-negative.
            sel_dim = (
                sel_node.args[1]
                if sel_node.args[1] >= 0
                else sel_node.args[1] + len(sel_in_shape)
            )
            # If the input size along the select dimension is not 1, bail.
            if sel_in_shape[sel_dim] != 1:
                continue

            # Get all the users of the select op that are either view, or
            # binary_broadcast_ops.
            users = [x for x in list(sel_node.users.keys()) if x.name in self.op_sizes]
            sel_in = sel_node.args[0]

            # Iterate over the users of the select op, and remove the use of
            # the select op in the user if feasible.
            for node in users:
                args = list(node.args)
                for idx, sel_arg in enumerate(args):
                    # Check if the arg is the select op.
                    if sel_arg != sel_node:
                        continue
                    # If the input of the select has the same shape as the other
                    # arg of the binary op, the select op can be bypassed.
                    if sel_in_shape == self.op_sizes[node.name][(idx + 1) % 2]:
                        args[idx] = sel_in
                # Update the node's args.
                node.args = tuple(args)

        graph_module.recompile()
        graph_module.graph.eliminate_dead_code()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        result = SpecPropPass()(graph_module)
        assert result is not None
        result = super().call(result.graph_module)
        self.eliminate_nop_select_op(result.graph_module)
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveCloneOpPass(ExportPass):
    # If the op is a clone op, return the input and eliminate the op.
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[ProxyValue],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.clone.default:
            return super().call_operator(op, args, kwargs, meta)

        return args[0]


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveContiguousOpPass(ExportPass):
    """
    This pass relies on the assumption that all tensors are contiguous in
    ExecuTorch and after the Cadence passes; revisit it if that assumption no
    longer holds. Note that removing contiguous() means the resulting model is
    not guaranteed to be runnable with the (possibly non-contiguous) arguments
    given to the original graph module.
    """

    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.contiguous.default:
            return super().call_operator(op, args, kwargs, meta)

        assert len(args) == 1
        return cast(ProxyValue, args[0])


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveAliasCopyOpPass(ExportPass):
    """
    alias_copy is a no-op for Jarvis and can be removed.
    """

    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.alias_copy.default:
            return super().call_operator(op, args, kwargs, meta)

        assert len(args) == 1
        return cast(ProxyValue, args[0])


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopRequantizeOpPass(ExportPass):
    """
    For a requantize op, if the following three conditions are satisfied:
    1. the in_scale matches the out_scale,
    2. the in_zero_point matches the out_zero_point,
    3. the dtypes of the input and output tensors are the same,
    then the requantize op is redundant, and can be eliminated.
    """

    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.cadence.requantize.default:
            return super().call_operator(op, args, kwargs, meta)

        # Parse the args (scales are floats, zero points are ints).
        (X, in_scale, in_zero_point, out_scale, out_zero_point, out_dtype) = cast(
            tuple[ProxyValue, float, int, float, int, torch.dtype], args
        )
        in_dtype = X.to_tensor().dtype
        # Check the three conditions.
        if (
            in_scale == out_scale
            and in_zero_point == out_zero_point
            and in_dtype == out_dtype
        ):
            return cast(ProxyValue, args[0])

        return super().call_operator(op, args, kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopMulOpPass(ExportPass):
    """
    If a mul op is multiplying two tensors with the same shape and one
    of those tensors is all zeros, return the zero tensor instead.
    """

    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.mul.Tensor:
            return super().call_operator(op, args, kwargs, meta)

        # Parse the args.
        (input1, input2) = cast(tuple[ProxyValue, ProxyValue], args)

        # Check if both inputs have the same shape.
        if input1.to_tensor().shape != input2.to_tensor().shape:
            return super().call_operator(op, args, kwargs, meta)

        # Check if one of the inputs is a zero tensor created by aten.full.
        if input1.node.target == exir_ops.edge.aten.full.default:
            if input1.node.args[1] == 0:
                return input1
        elif input2.node.target == exir_ops.edge.aten.full.default:
            if input2.node.args[1] == 0:
                return input2

        return super().call_operator(op, args, kwargs, meta)
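

# Illustrative sketch of the pattern RemoveNopMulOpPass matches (and, in an
# analogous way, RemoveNopAddOpPass below), using hypothetical node names:
#
#   zeros = aten.full.default([2, 3], 0)
#   y     = aten.mul.Tensor(x, zeros)   # x also has shape [2, 3]
#
# Since every element of the product is zero, uses of y are redirected to the
# `zeros` tensor itself; in the add case, uses of the sum are redirected to the
# non-zero operand instead.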


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopAddOpPass(ExportPass):
    """
    If an add op is adding two tensors with the same shape and one
    of those tensors is all zeros, return the other tensor instead.
    """

    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.add.Tensor:
            return super().call_operator(op, args, kwargs, meta)

        # Parse the args.
        (input1, input2) = cast(tuple[ProxyValue, ProxyValue], args)

        # Check if both inputs have the same shape.
        if input1.to_tensor().shape != input2.to_tensor().shape:
            return super().call_operator(op, args, kwargs, meta)

        # Check if one of the inputs is a zero tensor created by aten.full.
        if input1.node.target == exir_ops.edge.aten.full.default:
            if input1.node.args[1] == 0:
                return input2
        elif input2.node.target == exir_ops.edge.aten.full.default:
            if input2.node.args[1] == 0:
                return input1

        return super().call_operator(op, args, kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemovePermutesAroundElementwiseOps(ExportPass):
    """
    Looks for subgraphs of elementwise ops sandwiched between permutes and
    removes those permutes if possible. This pass is targeted at models where
    delegated subgraphs must be in NHWC format, so there's usually a to_NHWC
    permute before each delegate and a to_NCHW permute after it. If all the ops
    between two delegates are elementwise ops, then these permutes can be
    safely removed.
    Allows special handling for certain non-elementwise ops that can be easily
    updated based on the permute's parameter, such as mean and cat.
    """

    @dataclass()
    class Subgraph:
        """
        Keeps track of nodes grouped as a subgraph between two sets of permutes.
        """

        start_permutes: set[torch.fx.Node] = field(default_factory=set)
        end_permutes: set[torch.fx.Node] = field(default_factory=set)
        intermediate_nodes: set[torch.fx.Node] = field(default_factory=set)
        is_valid: bool = True

    elementwise_ops: set[EdgeOpOverload] = {
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    }

    # Must be initialized in the constructor.
    special_handling: Dict[EdgeOpOverload, Callable[[torch.fx.Node], None]] = {}

    to_NCHW = [0, 3, 1, 2]
    to_NHWC = [0, 2, 3, 1]

    def __init__(self) -> None:
        super().__init__()
        self.visited: set[object] = set()
        self.special_handling = {
            exir_ops.edge.aten.mean.dim: self.handle_mean_dim,
            exir_ops.edge.aten.cat.default: self.handle_cat,
        }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        self.visited = set()
        for node in graph_module.graph.nodes:
            sg = self.Subgraph()
            self.start_search(node, sg)
            if self.is_valid_subgraph(sg):
                logging.debug(f"Found valid subgraph: {sg}")
                self.handle_subgraph(graph_module, sg)

        result = super().call(graph_module)
        return result

    def handle_mean_dim(self, mean_dim: torch.fx.Node) -> None:
        assert mean_dim.target == exir_ops.edge.aten.mean.dim
        args = list(mean_dim.args)
        args[1] = [self.to_NCHW[dim] for dim in cast(list[int], args[1])]
        mean_dim.args = tuple(args)

    def handle_cat(self, cat: torch.fx.Node) -> None:
        assert cat.target == exir_ops.edge.aten.cat.default
        args = list(cat.args)
        args[1] = self.to_NCHW[cast(int, args[1])]
        cat.args = tuple(args)

    def is_valid_subgraph(self, sg: Subgraph) -> bool:
        return (
            sg.is_valid
            and len(sg.start_permutes) > 0
            and len(sg.end_permutes) > 0
            and len(sg.intermediate_nodes) > 0
        )

    def handle_subgraph(self, graph_module: torch.fx.GraphModule, sg: Subgraph) -> None:
        for permute in itertools.chain(sg.start_permutes, sg.end_permutes):
            permute.replace_all_uses_with(permute.args[0])  # pyre-fixme[6]

        for node in sg.intermediate_nodes:
            if node.target in self.special_handling:
                self.special_handling[node.target](node)

        graph_module.recompile()
        graph_module.graph.eliminate_dead_code()

    def start_search(self, node: torch.fx.Node, sg: Subgraph) -> None:
        if node in self.visited:
            return

        if self.is_starting_permute(node):
            sg.start_permutes.add(node)
            self.visited.add(node)
            for user in node.users:
                self.search_down(user, sg)

    def search_up(self, node: object, sg: Subgraph) -> None:
        # Non-nodes can be ignored. These would be arguments like integers or
        # lists of integers, which don't affect the subgraph validity or
        # inclusion set.
        if not isinstance(node, torch.fx.Node):
            return

        if node.op == "placeholder":
            # If we reach a placeholder or other terminal node without
            # encountering a starting permute, then the subgraph is invalid.
            # For example, in the add(x, y) case where x is permuted and y is a
            # graph input, we can't remove the permute on x because the two
            # operands could end up with shapes that don't broadcast together.
            # TODO: Adding a permute on y could be the more optimal solution,
            # but perhaps not in all cases, say if x is small and y is very
            # large. This transform prefers to be safe over optimal for now.
            sg.is_valid = False
            return

        if node in self.visited:
            return

        self.visited.add(node)

        if self.is_starting_permute(node):
            sg.start_permutes.add(node)
            for user in node.users:
                self.search_down(user, sg)
        else:
            self.traverse_intermediate_node(node, sg)

    def search_down(self, node: torch.fx.Node, sg: Subgraph) -> None:
        if node in self.visited or self.is_starting_permute(node):
            return

        self.visited.add(node)

        if self.is_ending_permute(node):
            sg.end_permutes.add(node)
            for arg in node.args:
                if isinstance(arg, list):
                    for elem in arg:
                        self.search_up(elem, sg)
                else:
                    self.search_up(arg, sg)
        else:
            self.traverse_intermediate_node(node, sg)

    def traverse_intermediate_node(self, node: torch.fx.Node, sg: Subgraph) -> None:
        if node.target in self.elementwise_ops:
            sg.intermediate_nodes.add(node)
            for arg in node.args:
                if isinstance(arg, list):
                    for elem in arg:
                        self.search_up(elem, sg)
                else:
                    self.search_up(arg, sg)

            for user in node.users:
                self.search_down(user, sg)

        else:
            sg.is_valid = False

    def is_starting_permute(self, node: torch.fx.Node) -> bool:
        return (
            node.target == exir_ops.edge.aten.permute_copy.default
            and cast(list[int], node.args[1]) == self.to_NCHW
        )

    def is_ending_permute(self, node: torch.fx.Node) -> bool:
        return (
            node.target == exir_ops.edge.aten.permute_copy.default
            and cast(list[int], node.args[1]) == self.to_NHWC
        )
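

# Illustrative sketch of a subgraph RemovePermutesAroundElementwiseOps targets,
# using hypothetical node names:
#
#   a = permute_copy(x, [0, 3, 1, 2])   # to_NCHW, starting permute
#   b = permute_copy(y, [0, 3, 1, 2])   # to_NCHW, starting permute
#   c = aten.add.Tensor(a, b)           # elementwise intermediate node
#   d = permute_copy(c, [0, 2, 3, 1])   # to_NHWC, ending permute
#
# Since add is elementwise, the permutes cancel out and the subgraph can be
# rewritten as d = aten.add.Tensor(x, y); mean/cat intermediates additionally
# get their dim arguments remapped via the special_handling table above.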


# The following class consolidates functions to remove ops that are redundant
# in Jarvis. Currently, each function in this class iterates over each node of
# the graph module once. In the future, we could consolidate them into a
# monolithic function.
class CadenceRemoveNops:
    passes = [
        SimplifySliceOpPass,
        RemoveCloneOpsTransformImported,
        RemoveToOpsPass,
        RemoveNopRequantizeOpPass,
        RemoveZeroSizedCatArgsPass,
        RemoveNopSliceOrViewOpPass,
        RemoveNopExpandOpPass,
        RemoveZeroSizedConstantPadNd,
        RemoveCloneOpPass,
        RemoveContiguousOpPass,
        RemoveAliasCopyOpPass,
        RemoveNopMulOpPass,
        RemoveNopAddOpPass,
        RemoveNopLinalgVectorNormOpPass,
    ]
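

# Illustrative sketch (not invoked anywhere in this file): one way the Jarvis
# nop-removal passes above could be chained on a graph module. The helper name
# and its argument are hypothetical; callers normally schedule these passes
# through their own pass lists.
def _apply_cadence_remove_nops_example(
    graph_module: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
    # Instantiate each pass class and run them in order through a PassManager,
    # mirroring the pattern used by RemoveCloneOpsTransformImported.call above.
    pass_manager = PassManager(
        passes=[pass_cls() for pass_cls in CadenceRemoveNops.passes]
    )
    result = pass_manager(graph_module)
    # Drop any nodes left without users after the rewrites.
    dead_code_elimination_pass(result.graph_module)
    return result.graph_module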