xref: /aosp_15_r20/external/executorch/backends/cadence/aot/remove_ops.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict


# This file contains passes to remove operators from the graph. The removed
# ops should belong to either of the following categories:
# 1. The op should be redundant for inference (e.g., dropout). Such ops are grouped
# together in 'RemoveRedundantOps'. Anyone running inference can add this class
# in their pass list, and it should be a semantics-preserving transformation.
# 2. The op should be redundant for Jarvis (e.g., contiguous). Such ops are grouped
# together in 'CadenceRemoveNops'. The ops removed in this class might not be nops
# in a context outside of Jarvis, so exercise caution while invoking this in a
# pass list outside of Jarvis.

import itertools
import logging
from dataclasses import dataclass, field
from typing import Callable, cast, Dict, List, Optional, Sequence

import torch
import torch.fx
from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    register_cadence_pass,
)

from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
from executorch.exir.pass_manager import PassManager, PassType
from executorch.exir.passes import dead_code_elimination_pass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from torch.fx.node import Argument


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveCloneOpsTransformImported(ExportPass):
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        finalize_passes: List[PassType] = [
            RemoveCloneOpsTransform(),
        ]
        result = PassManager(passes=finalize_passes)(graph_module)
        dead_code_elimination_pass(result.graph_module)
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveDetachCopyPass(ExportPass):
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.detach_copy.default:
            return super().call_operator(op, args, kwargs, meta)

        assert len(args) == 1
        return cast(ProxyValue, args[0])


# The following class consolidates passes to remove ops that are redundant,
# either by virtue of the operation they perform or in the context of
# inference.
class RemoveRedundantOps:
    passes = [
        RemoveDetachCopyPass,
    ]


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveZeroSizedCatArgsPass(ExportPass):
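    # A short sketch of the rewrite (hypothetical tensor names; the shapes in the
    # names are assumptions for illustration):
    #   cat((empty_0x4, x_2x4, empty_0x4), dim=0)  ->  x_2x4
    #   cat((empty_0x4, empty_0x4), dim=0)         ->  full(output_shape, 0)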
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.cat.default:
            return super().call_operator(op, args, kwargs, meta)

        # Remove any zero-sized tensor arg to form a new args list.
        cat_inputs: list[ProxyValue] = []
        for arg in cast(Sequence[ProxyValue], args[0]):
            if arg.to_tensor().numel() > 0:
                cat_inputs.append(arg)

        # If all the tensors were empty, we just return an empty tensor with
        # the right shape.
        if not cat_inputs:
            empty_shape = meta["val"].shape
            dtype = meta["val"].dtype
            return super().call_operator(
                exir_ops.edge.aten.full.default,
                (tuple(empty_shape), 0),
                {"dtype": dtype},
                meta,
            )

        # If there was only one tensor in the cat_inputs list,
        # we can safely erase this cat op.
        if len(cat_inputs) == 1:
            return cat_inputs[0]

        # Otherwise, we replace args[0] with cat_inputs.
        new_args = list(args)
        new_args[0] = cat_inputs
        return super().call_operator(op, tuple(new_args), kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveNopExpandOpPass(ExportPass):
    """
    For an expand op, if the input tensor's shape already matches the expand
    shape, then the expand is a nop.
    """

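    # Illustrative example (hypothetical name): if x has shape [2, 3, 4], then
    # expand_copy(x, [2, 3, 4]) leaves the shape unchanged and is replaced by x.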
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if get_edge_overload_packet(op) not in {
            exir_ops.edge.aten.expand_copy,
            exir_ops.edge.aten.expand,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # Parse the args, and check for nop condition
        arg0 = cast(ProxyValue, args[0])
        arg1 = cast(Sequence[int], args[1])
        in_tensor = arg0.to_tensor()
        if list(in_tensor.shape) == list(arg1):
            return arg0

        return super().call_operator(op, args, kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveToOpsPass(ExportPass):
    # As of now, all aten.to.* variants are nops for Jarvis.
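    # e.g., aten.to.dtype(x, torch.float32) is rewritten to just x, regardless of
    # the dtype x started with (illustrative example).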
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op not in (
            exir_ops.edge.aten.to.dtype,
            exir_ops.edge.aten.to.dtype_layout,
        ):
            return super().call_operator(op, args, kwargs, meta)

        logging.debug(f"Erasing to.dtype node (target = {op})")
        return cast(ProxyValue, args[0])


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveZeroSizedConstantPadNd(ExportPass):
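    # e.g., constant_pad_nd(x, [0, 0, 0, 0], 0.0) pads nothing and is replaced by x
    # (illustrative example; any padding list that is all zeros qualifies).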
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[ProxyValue, tuple[int, ...], Argument],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.constant_pad_nd.default:
            return super().call_operator(op, args, kwargs, meta)

        input_tensor = args[0]
        padding = args[1]

        if any(x != 0 for x in padding):
            return super().call_operator(op, args, kwargs, meta)

        logging.debug(f"Erasing 0 sized constant pad nd node with {input_tensor}")
        return input_tensor


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopSliceOrViewOpPass(ExportPass):
    """
    Remove slice ops that are more like views, and view ops that do not change the shape
    """

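    # Illustrative examples (hypothetical shapes): slice_copy(x, dim=0, start=0, end=8)
    # on a tensor whose dim 0 already has size 8, or view_copy(x, [2, 3]) on a tensor
    # of shape [2, 3], leave the shape unchanged and are replaced by x.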
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op not in {
            exir_ops.edge.aten.slice_copy.Tensor,
            exir_ops.edge.aten.view_copy.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        arg0 = cast(ProxyValue, args[0])
        out_shape = meta["val"].shape

        # If the input shape and out_shape are the same, this slice/view is a nop
        return (
            arg0
            if arg0.to_tensor().shape == out_shape
            else super().call_operator(op, args, kwargs, meta)
        )


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopLinalgVectorNormOpPass(ExportPass):
    """
    If the norm is applied over a dimension that is size 1, it can be eliminated.
    """

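    # Illustrative example (hypothetical shape): for x of shape [4, 1, 8],
    # linalg_vector_norm(x, 2, [1], True) keeps the shape [4, 1, 8], so this pass
    # treats it as a nop and replaces it with x.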
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op not in {
            exir_ops.edge.aten.linalg_vector_norm.default,
            exir_ops.edge.cadence.linalg_vector_norm.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # If the op has three args or fewer, keepdim is absent (and defaults to
        # False), so it can't be a nop.
        if len(args) <= 3:
            return super().call_operator(op, args, kwargs, meta)
        # If dim is None, or keepdim is False, it is not a nop
        dim = cast(Optional[tuple[int, ...]], args[2])
        keepdim = cast(bool, args[3])
        if dim is None or not keepdim:
            return super().call_operator(op, args, kwargs, meta)

        # At this point dim is not None and keepdim is True. The norm is a nop
        # only if every dimension in dim has size 1; otherwise bail out.
        t = cast(ProxyValue, args[0])
        shape = t.to_tensor().shape
        for d in dim:
            if shape[d] != 1:
                return super().call_operator(op, args, kwargs, meta)

        return t


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopSelectOpPass(ExportPass):
    """
    A select op that selects from a dimension that is size 1 can be eliminated
    in a few cases. For example,
    ```
    x = view(x, [1, 3, 16])
    y = select(x, 0, 0)
    z = add(m, y)
    ```
    The special thing about this pattern is the add op, which allows
    broadcasting. So adding an operand with shape [3, 16] is the same as
    adding an operand with shape [1, 3, 16]. Therefore, if m has the same
    shape as x, then this select op is a nop, and can be eliminated:
    ```
    x = view(x, [1, 3, 16])
    z = add(m, x)
    ```
    """

    # A set of binary operators that could require broadcasting, and are
    # critical to this transformation if one of their operands is a select op.
    binary_broadcast_ops: set[EdgeOpOverload] = {
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.div.Tensor,
    }

    def __init__(self) -> None:
        super().__init__()
        self.op_sizes: dict[str, tuple[torch.Size, torch.Size]] = {}

    # For select, view, or any op in binary_broadcast_ops, record the shapes of
    # input and output tensors.
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        res = super().call_operator(op, args, kwargs, meta)
        # Unary ops: input and output
        if op in {
            exir_ops.edge.aten.select_copy.int,
            exir_ops.edge.aten.view_copy.default,
        }:
            arg0 = cast(ProxyValue, args[0])
            self.op_sizes[res.node.name] = (arg0.to_tensor().shape, meta["val"].shape)
        # Binary ops: two inputs, output shape can be inferred
        elif op in self.binary_broadcast_ops:
            arg0 = cast(ProxyValue, args[0])
            arg1 = cast(ProxyValue, args[1])
            self.op_sizes[res.node.name] = (
                arg0.to_tensor().shape,
                arg1.to_tensor().shape,
            )
        return res

    # Eliminate nop select ops. We iterate over the select ops, and check whether
    # their users (view ops or binary_broadcast_ops) can consume the select's
    # input directly.
    def eliminate_nop_select_op(self, graph_module: torch.fx.GraphModule) -> None:
        for sel_node in graph_module.graph.nodes:
            # We are only interested in select ops
            if sel_node.target != exir_ops.edge.aten.select_copy.int:
                continue
            # The shape of the input/output operands for this select op should
            # have been precomputed.
            assert sel_node.name in self.op_sizes
            (sel_in_shape, sel_out_shape) = self.op_sizes[sel_node.name]
            # Get the select dimension
            sel_dim = (
                sel_node.args[1]
                if sel_node.args[1] >= 0
                else sel_node.args[1] + len(sel_in_shape)
            )
            # If the input size along select dimension is not 1, bail.
            if sel_in_shape[sel_dim] != 1:
                continue

            # Get all the users of the select op that are either view, or
            # binary_broadcast_ops.
            users = [x for x in list(sel_node.users.keys()) if x.name in self.op_sizes]
            sel_in = sel_node.args[0]

            # Iterate over the users of select op, and remove the use of the
            # select op in the user if feasible.
            for node in users:
                args = list(node.args)
                for idx, sel_arg in enumerate(args):
                    # Check if the arg is the select op
                    if sel_arg != sel_node:
                        continue
                    # If the input of select has the same shape as the other arg
                    # of the binary op, the select op can be bypassed.
                    if sel_in_shape == self.op_sizes[node.name][(idx + 1) % 2]:
                        args[idx] = sel_in
                # update the node's args
                node.args = tuple(args)

        graph_module.recompile()
        graph_module.graph.eliminate_dead_code()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        result = SpecPropPass()(graph_module)
        assert result is not None
        result = super().call(result.graph_module)
        self.eliminate_nop_select_op(result.graph_module)
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveCloneOpPass(ExportPass):
    # If the op is a clone op, return the input and eliminate the op
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[ProxyValue],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.clone.default:
            return super().call_operator(op, args, kwargs, meta)

        return args[0]


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveContiguousOpPass(ExportPass):
    """
    This is based on the assumption that all tensors are contiguous in ExecuTorch,
    and remain so after the Cadence passes; we should revisit this if that
    assumption is no longer true. Note that this causes the model to not be
    runnable with the arguments given to the original graph module.
    """

    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.contiguous.default:
            return super().call_operator(op, args, kwargs, meta)

        assert len(args) == 1
        return cast(ProxyValue, args[0])


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveAliasCopyOpPass(ExportPass):
    """
    alias_copy is a no-op for Jarvis and can be removed.
    """

    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.alias_copy.default:
            return super().call_operator(op, args, kwargs, meta)

        assert len(args) == 1
        return cast(ProxyValue, args[0])


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopRequantizeOpPass(ExportPass):
    """
    For a requantize op, if the following three conditions are satisfied:
    1. the in_scale matches the out_scale
    2. the in_zero_point matches the out_zero_point
    3. the dtypes of the input and output tensors are the same
    then the requantize op is redundant, and can be eliminated
    """

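    # Illustrative example, assuming scalar scale/zero_point args as in the cast
    # below: requantize(x, 0.05, 0, 0.05, 0, <dtype of x>) meets all three
    # conditions and is replaced by x.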
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.cadence.requantize.default:
            return super().call_operator(op, args, kwargs, meta)

        # Parse the args
        (X, in_scale, in_zero_point, out_scale, out_zero_point, out_dtype) = cast(
            tuple[ProxyValue, int, float, int, float, torch.dtype], args
        )
        in_dtype = X.to_tensor().dtype
        # Check the three conditions
        if (
            in_scale == out_scale
            and in_zero_point == out_zero_point
            and in_dtype == out_dtype
        ):
            return cast(ProxyValue, args[0])

        return super().call_operator(op, args, kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopMulOpPass(ExportPass):
    """
    If a mul op is multiplying two tensors with the same shape and one
    of those tensors is all zeros, return the zero tensor instead.
    """

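    # Illustrative example (hypothetical names): with z = full([2, 4], 0) and x of
    # shape [2, 4], mul(z, x) is replaced by z.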
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.mul.Tensor:
            return super().call_operator(op, args, kwargs, meta)

        # Parse the args
        (input1, input2) = cast(tuple[ProxyValue, ProxyValue], args)

        # Check if both inputs have the same shape
        if input1.to_tensor().shape != input2.to_tensor().shape:
            return super().call_operator(op, args, kwargs, meta)

        # Check if one of the inputs is a zero tensor
        if input1.node.target == exir_ops.edge.aten.full.default:
            if input1.node.args[1] == 0:
                return input1
        elif input2.node.target == exir_ops.edge.aten.full.default:
            if input2.node.args[1] == 0:
                return input2

        return super().call_operator(op, args, kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopAddOpPass(ExportPass):
    """
    If an add op is adding two tensors with the same shape and one
    of those tensors is all zeros, return the other tensor instead.
    """

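    # Illustrative example (hypothetical names): with z = full([2, 4], 0) and x of
    # shape [2, 4], add(z, x) is replaced by x.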
    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.add.Tensor:
            return super().call_operator(op, args, kwargs, meta)

        # Parse the args
        (input1, input2) = cast(tuple[ProxyValue, ProxyValue], args)

        # Check if both inputs have the same shape
        if input1.to_tensor().shape != input2.to_tensor().shape:
            return super().call_operator(op, args, kwargs, meta)

        # Check if one of the inputs is a zero tensor
        if input1.node.target == exir_ops.edge.aten.full.default:
            if input1.node.args[1] == 0:
                return input2
        elif input2.node.target == exir_ops.edge.aten.full.default:
            if input2.node.args[1] == 0:
                return input1

        return super().call_operator(op, args, kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemovePermutesAroundElementwiseOps(ExportPass):
    """
    Looks for subgraphs of elementwise ops sandwiched between permutes and removes
    those permutes if possible. This pass is targeted at models where delegated
    subgraphs must be in NHWC format, so there is usually a to_NHWC permute before
    each delegate and a to_NCHW permute after it. If all the ops between two
    delegates are elementwise ops, then these permutes can be safely removed.
    The pass also allows special handling for certain non-elementwise ops, such as
    mean and cat, that can be easily updated based on the permute's dims parameter.
    """

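    # Illustrative subgraph (hypothetical node names), before this pass:
    #   a = permute_copy(nhwc_x, [0, 3, 1, 2])   # starting permute: NHWC -> NCHW
    #   b = permute_copy(nhwc_y, [0, 3, 1, 2])   # starting permute: NHWC -> NCHW
    #   c = add(a, b)
    #   d = permute_copy(c, [0, 2, 3, 1])        # ending permute: NCHW -> NHWC
    # After the pass, c = add(nhwc_x, nhwc_y) and users of d consume c directly.
    # For mean.dim and cat intermediate nodes, the dim argument is additionally
    # remapped through to_NCHW.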
    @dataclass()
    class Subgraph:
        """
        Keeps track of nodes grouped as a subgraph between two sets of permutes
        """

        start_permutes: set[torch.fx.Node] = field(default_factory=set)
        end_permutes: set[torch.fx.Node] = field(default_factory=set)
        intermediate_nodes: set[torch.fx.Node] = field(default_factory=set)
        is_valid: bool = True

    elementwise_ops: set[EdgeOpOverload] = {
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    }

    # must be initialized in the constructor
    special_handling: Dict[EdgeOpOverload, Callable[[torch.fx.Node], None]] = {}

    to_NCHW = [0, 3, 1, 2]
    to_NHWC = [0, 2, 3, 1]

    def __init__(self) -> None:
        super().__init__()
        self.visited: set[object] = set()
        self.special_handling = {
            exir_ops.edge.aten.mean.dim: self.handle_mean_dim,
            exir_ops.edge.aten.cat.default: self.handle_cat,
        }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        self.visited = set()
        for node in graph_module.graph.nodes:
            sg = self.Subgraph()
            self.start_search(node, sg)
            if self.is_valid_subgraph(sg):
                logging.debug(f"Found valid subgraph: {sg}")
                self.handle_subgraph(graph_module, sg)

        result = super().call(graph_module)
        return result

    def handle_mean_dim(self, mean_dim: torch.fx.Node) -> None:
        assert mean_dim.target == exir_ops.edge.aten.mean.dim
        args = list(mean_dim.args)
        args[1] = [self.to_NCHW[dim] for dim in cast(list[int], args[1])]
        mean_dim.args = tuple(args)

    def handle_cat(self, cat: torch.fx.Node) -> None:
        assert cat.target == exir_ops.edge.aten.cat.default
        args = list(cat.args)
        args[1] = self.to_NCHW[cast(int, args[1])]
        cat.args = tuple(args)

    def is_valid_subgraph(self, sg: Subgraph) -> bool:
        return (
            sg.is_valid
            and len(sg.start_permutes) > 0
            and len(sg.end_permutes) > 0
            and len(sg.intermediate_nodes) > 0
        )

    def handle_subgraph(self, graph_module: torch.fx.GraphModule, sg: Subgraph) -> None:
        for permute in itertools.chain(sg.start_permutes, sg.end_permutes):
            permute.replace_all_uses_with(permute.args[0])  # pyre-fixme[6]

        for node in sg.intermediate_nodes:
            if node.target in self.special_handling:
                self.special_handling[node.target](node)

        graph_module.recompile()
        graph_module.graph.eliminate_dead_code()

    def start_search(self, node: torch.fx.Node, sg: Subgraph) -> None:
        if node in self.visited:
            return

        if self.is_starting_permute(node):
            sg.start_permutes.add(node)
            self.visited.add(node)
            for user in node.users:
                self.search_down(user, sg)

    def search_up(self, node: object, sg: Subgraph) -> None:
        # non-nodes can be ignored. These would be arguments like integers or lists
        # of integers, which don't affect the subgraph validity or inclusion set.
        if not isinstance(node, torch.fx.Node):
            return

        if node.op == "placeholder":
            # If we reach a placeholder or other terminal node without encountering
            # a start permute, then the subgraph is invalid.
            # For example, in the add(x, y) case where x is permuted and y is a
            # graph input, we can't remove the permute on x, since the two
            # operands could end up with shapes that don't broadcast together.
            # TODO: Adding a permute on y could be the more optimal solution,
            # but perhaps not in all cases, say if x is small and y is very large.
            # This transform prefers to be safe over optimal for now.
            sg.is_valid = False
            return

        if node in self.visited:
            return

        self.visited.add(node)

        if self.is_starting_permute(node):
            sg.start_permutes.add(node)
            for user in node.users:
                self.search_down(user, sg)
        else:
            self.traverse_intermediate_node(node, sg)

    def search_down(self, node: torch.fx.Node, sg: Subgraph) -> None:
        if node in self.visited or self.is_starting_permute(node):
            return

        self.visited.add(node)

        if self.is_ending_permute(node):
            sg.end_permutes.add(node)
            for arg in node.args:
                if isinstance(arg, list):
                    for elem in arg:
                        self.search_up(elem, sg)
                else:
                    self.search_up(arg, sg)
        else:
            self.traverse_intermediate_node(node, sg)

    def traverse_intermediate_node(self, node: torch.fx.Node, sg: Subgraph) -> None:
        if node.target in self.elementwise_ops:
            sg.intermediate_nodes.add(node)
            for arg in node.args:
                if isinstance(arg, list):
                    for elem in arg:
                        self.search_up(elem, sg)
                else:
                    self.search_up(arg, sg)

            for user in node.users:
                self.search_down(user, sg)

        else:
            sg.is_valid = False

    def is_starting_permute(self, node: torch.fx.Node) -> bool:
        return (
            node.target == exir_ops.edge.aten.permute_copy.default
            and cast(list[int], node.args[1]) == self.to_NCHW
        )

    def is_ending_permute(self, node: torch.fx.Node) -> bool:
        return (
            node.target == exir_ops.edge.aten.permute_copy.default
            and cast(list[int], node.args[1]) == self.to_NHWC
        )


# The following class consolidates passes to remove ops that are redundant
# in Jarvis. Currently, each pass in this class iterates over the nodes of
# the graph module once. In the future, we could consolidate them into a
# monolithic pass.
class CadenceRemoveNops:
    passes = [
        SimplifySliceOpPass,
        RemoveCloneOpsTransformImported,
        RemoveToOpsPass,
        RemoveNopRequantizeOpPass,
        RemoveZeroSizedCatArgsPass,
        RemoveNopSliceOrViewOpPass,
        RemoveNopExpandOpPass,
        RemoveZeroSizedConstantPadNd,
        RemoveCloneOpPass,
        RemoveContiguousOpPass,
        RemoveAliasCopyOpPass,
        RemoveNopMulOpPass,
        RemoveNopAddOpPass,
        RemoveNopLinalgVectorNormOpPass,
    ]

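
# Illustrative usage sketch (not part of this module's API; everything below other
# than the two pass-list classes above is an assumption): each entry in
# RemoveRedundantOps.passes and CadenceRemoveNops.passes is a pass class, so a
# caller might instantiate and run them with a PassManager, mirroring the pattern
# used in RemoveCloneOpsTransformImported above:
#
#   nop_removal_passes: List[PassType] = [cls() for cls in CadenceRemoveNops.passes]
#   result = PassManager(passes=nop_removal_passes)(graph_module)
#   graph_module = result.graph_module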