xref: /aosp_15_r20/external/executorch/backends/cadence/aot/reorder_ops.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2
3# pyre-unsafe
4
5
6# This file contains all the functions that reorder ops in the graph module.
7
import copy
from collections import defaultdict, deque
from math import prod
from typing import cast, DefaultDict, List, Set, Tuple

import torch
import torch.fx
from executorch.backends.cadence.aot.compiler_utils import get_placeholders, get_shape
from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    get_overload_packet,
    register_cadence_pass,
)
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.tensor import num_bytes_from_shape_and_dtype
26
# The set of op overload packets that can be trivially quantized: the passes
# in this file move quantize/dequantize ops across them (see the pass
# docstrings below). Both ATen (torch.ops) and Edge dialect (exir_ops.edge)
# packets are listed so that membership checks match either dialect.
trivially_quantizable_ops_overloadpkt = {
    # ATen dialect
    torch.ops.aten.chunk,
    torch.ops.aten.clone,
    torch.ops.aten.contiguous,
    torch.ops.aten.permute,
    torch.ops.aten.permute_copy,
    torch.ops.aten.select_copy,
    torch.ops.aten.slice,
    torch.ops.aten.slice_copy,
    torch.ops.aten.squeeze,
    torch.ops.aten.squeeze_copy,
    torch.ops.aten.transpose,
    torch.ops.aten.transpose_copy,
    torch.ops.aten.unsqueeze,
    torch.ops.aten.unsqueeze_copy,
    torch.ops.aten.view,
    torch.ops.aten.view_copy,
    # Edge dialect
    exir_ops.edge.aten.chunk,
    exir_ops.edge.aten.clone,
    exir_ops.edge.aten.contiguous,
    exir_ops.edge.aten.permute_copy,
    exir_ops.edge.aten.select_copy,
    exir_ops.edge.aten.slice_copy,
    exir_ops.edge.aten.squeeze_copy,
    exir_ops.edge.aten.transpose_copy,
    exir_ops.edge.aten.unfold_copy,
    exir_ops.edge.aten.unsqueeze_copy,
    exir_ops.edge.aten.view_copy,
}
57
# Slice-equivalent ops: slice_copy extracts a sub-range and select_copy indexes
# away one dimension. The passes below treat both as slices when estimating
# whether moving a (de)quantize across them is profitable.
slice_or_select_overloadpkt = {
    # ATen dialect
    torch.ops.aten.select_copy,
    torch.ops.aten.slice_copy,
    # Edge dialect
    exir_ops.edge.aten.select_copy,
    exir_ops.edge.aten.slice_copy,
}
65
66
@register_cadence_pass(CadencePassAttribute(opt_level=2))
class AdvanceQuantizeOpAboveDefInBranchPass(ExportPass):
    """
    If the graph is branched with the following pattern:
    I = ...
    S1 = slice(I)
    Q1 = quantize(S1)
    S2 = slice(I)
    Q2 = quantize(S2)
    S3 = slice(I)
    Q3 = quantize(S3)
    ...
    we can advance the quantize above their defs (i.e., all the slice nodes),
    and reorder the pattern to the following:
    I = ...
    Q = quantize(I)
    S1 = slice(Q)
    Q1 = requantize(S1)
    S2 = slice(Q)
    Q2 = requantize(S2)
    S3 = slice(Q)
    Q3 = requantize(S3)
    ...
    More generally, the ops between I and the quantize ops may be any trivially
    quantizable ops (see trivially_quantizable_ops_overloadpkt). When all of
    them are slice/select ops, the rewrite is only applied if the slices
    together hold at least half as many elements as I, so that quantizing all
    of I remains profitable.

    Q uses the quantization arguments (scale, zero_point, ...) shared by the
    largest number of the original quantize ops. Each requantize is currently
    emitted as a dequantize(Q's args) -> quantize(original args) pair; for the
    quantize ops whose arguments match Q's, that pair is an identity.

    If I has exactly one such quantize descendant, that quantize is instead
    re-created right after I (after the last placeholder if I is a graph
    input) and no requantize is needed.

    Note that the other passes won't do this transformation because they expect
    a linear chain of def-use, which is not true here; the uses of I are
    branched.
    """

    def __init__(self):
        super().__init__()
        # Set by call(); used for shape queries in get_descendent_quant_ops.
        self.graph_module = None

    def get_descendent_quant_ops(self, node: torch.fx.Node) -> List[torch.fx.Node]:
        """
        Return the quantize ops reachable from node through trivially
        quantizable ops only.

        Starting at node, walk its successors breadth-first, looking through any
        trivially quantizable op. If every path ends in a quantize_per_tensor
        op, return those quantize ops. Return an empty list if any path reaches
        another kind of op, or if the slice-only profitability check fails.
        """
        quant_packets = {
            torch.ops.quantized_decomposed.quantize_per_tensor,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor,
        }
        # The quant ops that are descendents of node, such that the only nodes
        # on the path node --> quant are trivially quantizable ops.
        descendent_quant_ops = []
        # The trivially quantizable ops on the paths from node --> quant op.
        trivial_quantized_ops = []

        # deque gives O(1) pops from the front (list.pop(0) is O(n)) while
        # keeping the same breadth-first visiting order.
        users = deque(node.users.keys())
        while users:
            user = users.popleft()
            user_target = get_overload_packet(user.target)
            if user_target in quant_packets:
                # Record a quant op successor.
                descendent_quant_ops.append(user)
            elif user_target in trivially_quantizable_ops_overloadpkt:
                # Look through the trivially quantizable op to its users.
                trivial_quantized_ops.append(user)
                users.extend(user.users.keys())
            else:
                # Some path ends in a non-quant op, so nothing can be advanced.
                descendent_quant_ops.clear()
                break

        # If all the trivially quantizable ops were slice/select ops, ensure
        # that the advance is still profitable. NOTE: when there are no
        # intermediate ops at all (the quant ops consume node directly), all()
        # is vacuously true and slice_sizes is empty, so the advance is
        # rejected whenever node's size is known and non-zero.
        if descendent_quant_ops and all(
            get_overload_packet(x.target) in slice_or_select_overloadpkt
            for x in trivial_quantized_ops
        ):
            # Profitability metric: the sum of all the output slices must be at
            # least half the input node slice.
            slice_sizes = [
                prod(list(shape))
                for x in trivial_quantized_ops
                if (shape := get_shape(self.graph_module, x)) is not None
            ]
            node_shape = get_shape(self.graph_module, node)
            node_size = prod(list(node_shape)) if node_shape is not None else 0
            if node_size > 2 * sum(slice_sizes):
                descendent_quant_ops.clear()

        return descendent_quant_ops

    def advance_quantize_op(self, graph_module: torch.fx.GraphModule):
        """Apply the rewrite described in the class docstring, in place."""
        graph = graph_module.graph
        for node in graph.nodes:
            # We are only interested in call functions and placeholders
            if node.op not in {"placeholder", "call_function"}:
                continue
            # If the node is trivially quantizable, skip it
            if (
                get_overload_packet(node.target)
                in trivially_quantizable_ops_overloadpkt
            ):
                continue
            # Get the descendent quant ops that are connected to the current
            # node via trivially quantizable ops.
            descendent_quant_ops = self.get_descendent_quant_ops(node)
            if not descendent_quant_ops:
                continue

            # Get the insertion point below which we need to insert anything.
            # if node is a placeholder, we will only insert a new node after
            # all the placeholders in the graph.
            insertion_pt = (
                get_placeholders(graph)[-1] if node.op == "placeholder" else node
            )

            # If the node only has a single quant op as descendent, we can
            # simply hoist the quant op below the current node as its single
            # child.
            if len(descendent_quant_ops) == 1:
                quant_node = descendent_quant_ops.pop()
                # Replace the uses of quant node with its predecessor
                quant_node.replace_all_uses_with(quant_node.args[0])  # pyre-fixme[6]
                # Hoist the quant node after the current node. Make sure that
                # the insertion is after placeholders
                with graph.inserting_after(insertion_pt):
                    dom_quant_args = (node,) + quant_node.args[1:]
                    dom_quant_node = graph.call_function(
                        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
                    )
                    dom_quant_node.meta = node.meta
                    # Redirect node's users first and only then wire
                    # dom_quant_node to node, so that its own input is not
                    # rewritten by replace_all_uses_with.
                    node.replace_all_uses_with(dom_quant_node)
                    dom_quant_node.args = dom_quant_args
                graph.erase_node(quant_node)
                continue

            # Otherwise we have several quant descendents. Count how many of
            # them share each (scale, zero_point, ..., dtype) argument tuple.
            quant_dict: DefaultDict[Tuple, int] = defaultdict(int)
            for quant_node in descendent_quant_ops:
                quant_dict[quant_node.args[1:]] += 1
            # Use the most common argument tuple for the dominating quant node,
            # so that the largest number of dequant/quant pairs created below
            # have identical arguments on both sides.
            rep_args = max(quant_dict, key=lambda args: quant_dict[args])

            # Create a new quant node that dominates all the nodes in
            # descendent_quant_ops. Make sure that the insertion is after
            # all the placeholders.
            with graph.inserting_after(insertion_pt):
                dom_quant_args = (node,) + rep_args
                dom_quant_node = graph.call_function(
                    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
                )
                dom_quant_node.meta = node.meta
                node.replace_all_uses_with(dom_quant_node)
                dom_quant_node.args = dom_quant_args

            # Finally, convert each of the quant node to a dequant/quant pair that
            # requantizes the data flowing through dom_quant_node.
            # TODO: Once requantize is implemented for PT2, replace the
            # dequant/quant pair here with a single requantize node
            for quant_node in descendent_quant_ops:
                with graph.inserting_before(quant_node):
                    dequant_node = graph.call_function(
                        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
                    )
                    dequant_node.args = (quant_node.args[0],) + rep_args
                    quant_node.args = (dequant_node,) + quant_node.args[1:]

        # Drop dead nodes first so that the regenerated code matches the graph.
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        self.graph_module = graph_module
        self.advance_quantize_op(graph_module)
        result = super().call(graph_module)
        return result
234
235
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class AdvanceQuantizeOpAboveDefChainPass(ExportPass):
    """
    If the input to quantize op is linear chain of view, transpose, permute, or
    slice ops that are trivially quantized, we can convert the pattern
    view/transpose/permute/slice(fp32) -> quantize(int8/uint8) to
    quantize(int8/uint8) -> view/transpose/permute/slice(int8/uint8).
    The benefit of such reordering is that the view/transpose/permute/slice
    will move far less data.
    """

    def __init__(self):
        super().__init__()
        # Set by call(); used for shape queries in advancing_feasible.
        self.graph_module = None

    def advancing_feasible(self, quant_node: torch.fx.Node):
        """
        Return True if quant_node can be moved above the op producing its input.

        That producer must be a trivially quantizable op whose only user is
        quant_node. Slice/select producers are only bypassed when they keep
        every element of their input (a nop slice); otherwise the quantize
        would run on the larger, un-sliced tensor.
        """
        assert quant_node.op == "call_function" and len(quant_node.args) >= 1
        # Get the input of the quant node. Only proceed if it's an fx Node.
        inp = quant_node.args[0]
        if not isinstance(inp, torch.fx.Node):
            return False

        # Return false if the input to the quantize node is (1) not trivially
        # quantizable, or (2) has more than one user.
        if isinstance(inp.target, EdgeOpOverload):
            inp_overloadpkt = get_edge_overload_packet(inp.target)
        else:
            inp_overloadpkt = get_overload_packet(inp.target)

        if (
            inp_overloadpkt not in trivially_quantizable_ops_overloadpkt
            or len(inp.users) != 1
        ):
            return False

        # Advancing quantize op above slice nodes is tricky. If we advance the
        # quantize node above slice, then we will quantize the input to the slice
        # op, which can be expensive. We only bypass nop slice at present.
        if inp_overloadpkt in slice_or_select_overloadpkt:
            sliced_tensor = inp.args[0]
            assert isinstance(sliced_tensor, torch.fx.Node)
            slice_input_shape = get_shape(self.graph_module, sliced_tensor)
            slice_output_shape = get_shape(self.graph_module, inp)
            # If we could not glean the shapes, or the slice drops elements
            # (i.e., it is NOT a nop), bail.
            if (
                slice_output_shape is None
                or slice_input_shape is None
                or prod(list(slice_output_shape)) < prod(list(slice_input_shape))
            ):
                return False

        # All the conditions satisfied, we advance.
        return True

    def advance_quantize_op(self, graph_module: torch.fx.GraphModule):
        """Swap each feasible trivially-quantizable-op -> quantize pair, in place."""
        graph = graph_module.graph
        quant_packets = (
            exir_ops.edge.quantized_decomposed.quantize_per_tensor,
            torch.ops.quantized_decomposed.quantize_per_tensor,
        )
        # Walk the graph backwards. After a swap, the erased quant node's
        # predecessor is the new quant node, so (presumably by design, relying
        # on FX list iteration tolerating erasure of the current node) it is
        # visited next and can keep climbing a longer chain.
        for node in reversed(graph.nodes):
            if get_overload_packet(node.target) not in quant_packets:
                continue

            if not self.advancing_feasible(node):
                continue

            trivially_quantizable_op = node.args[0]
            # The input to the quant node must now be the input to the trivially
            # quantizable op.
            quant_args = list(node.args)
            quant_args[0] = trivially_quantizable_op.args[0]

            # Insert the new quant node with updated args before the current
            # quant node. Keep any kwargs of the original quant node.
            with graph.inserting_before(node):
                quant_node = graph.call_function(
                    node.target, args=tuple(quant_args), kwargs=node.kwargs
                )
                quant_node.meta = node.meta
            # Move the trivially quantizable node after the quant node
            with graph.inserting_after(node):
                tq_args = list(trivially_quantizable_op.args)
                tq_args[0] = quant_node
                tq_node = graph.call_function(
                    trivially_quantizable_op.target,
                    args=tuple(tq_args),
                    kwargs=trivially_quantizable_op.kwargs,
                )
                tq_node.meta = trivially_quantizable_op.meta
            # Replace all uses of node with newly created tq_node
            node.replace_all_uses_with(tq_node)
            # We can safely remove the quant node and trivially quantizable op;
            # the latter's only user was node (checked in advancing_feasible).
            graph.erase_node(node)
            graph.erase_node(trivially_quantizable_op)

        # Drop dead nodes first so that the regenerated code matches the graph.
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        self.graph_module = graph_module
        self.advance_quantize_op(graph_module)
        result = super().call(graph_module)
        return result
340
341
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class PostponeDequantizeOpBelowUseChainPass(ExportPass):
    """
    If the consumer of dequantize is a linear chain of view, transpose, permute,
    or slice ops that are trivially quantized, we can convert the pattern
    dequantize(int8/uint8) -> view/transpose/permute/slice(fp32) to
    view/transpose/permute/slice(int8/uint8) -> dequantize(int8/uint8)
    The benefit of such reordering is that the view/transpose/permute/slice
    will move far less data.
    """

    def __init__(self):
        super().__init__()
        # Set by call(); used for shape queries in postponing_feasible.
        self.graph_module = None

    def postponing_feasible(self, dequant_node: torch.fx.Node):
        """
        Return True if dequant_node can be moved below its users.

        Feasible when one of the following holds:
        * it has a single user, and that user is trivially quantizable;
        * all users are slice/select ops and, ignoring full-size (nop) slices,
          the fp32 bytes they produce do not exceed the fp32 bytes of the
          dequantize output (all shapes must be known for this check);
        * all users are slice/select ops and every non-output user of those
          slices is a quantize op, so slice -> dequantize -> quantize can be
          turned into slice -> requantize.
        """
        users = list(dequant_node.users.keys())
        # A single user must itself be trivially quantizable.
        if len(users) == 1:
            return (
                get_overload_packet(users[0].target)
                in trivially_quantizable_ops_overloadpkt
            )

        # Otherwise check if all the users are slice op
        if not all(
            get_overload_packet(user.target) in slice_or_select_overloadpkt
            for user in users
        ):
            return False

        dequant_shape = get_shape(self.graph_module, dequant_node)
        slice_shapes = []
        for user in users:
            shape = get_shape(self.graph_module, user)
            # Skip slices that are the size of the sliced tensor itself; they
            # should technically get removed in the later passes as nop.
            # Unknown shapes are kept so that the check below rejects the
            # byte comparison instead of silently undercounting.
            if (
                shape is not None
                and dequant_shape is not None
                and prod(list(shape)) == prod(list(dequant_shape))
            ):
                continue
            slice_shapes.append(shape)

        if dequant_shape is not None and all(
            shape is not None for shape in slice_shapes
        ):
            dequant_bytes = num_bytes_from_shape_and_dtype(dequant_shape, torch.float32)
            slice_bytes = sum(
                num_bytes_from_shape_and_dtype(shape, torch.float32)
                for shape in slice_shapes
            )
            if slice_bytes <= dequant_bytes:
                return True

        # If the users of each slice op is quantize op, then we can postpone
        # dequantize, and convert slice -> dequantize -> quantize to
        # slice -> requantize.
        slice_users = [x for y in users for x in y.users if x.op != "output"]
        return all(
            get_overload_packet(x.target)
            in {
                exir_ops.edge.quantized_decomposed.quantize_per_tensor,
                exir_ops.edge.quantized_decomposed.quantize_per_channel,
            }
            for x in slice_users
        )

    def postpone_dequantize_op(self, graph_module: torch.fx.GraphModule) -> bool:
        """
        Move each feasible dequantize op below its users, in place.

        Returns True if at least one dequantize op was moved.
        """
        # Different supported dequant ops have their own default variants
        packet_to_overload_map = {
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor: "default",
            exir_ops.edge.quantized_decomposed.dequantize_per_channel: "default",
        }
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            overload_packet = get_overload_packet(node.target)
            if (
                overload_packet not in packet_to_overload_map
                or not self.postponing_feasible(node)
            ):
                continue

            dequant_op = getattr(
                overload_packet, packet_to_overload_map[overload_packet]
            )
            for user in node.users:
                with graph.inserting_after(user):
                    dequant_node = graph.call_function(dequant_op)
                dequant_node.meta = user.meta.copy()
                # Remove meta["debug_handle"] on new node (if present). Reassign
                # it at the caller level by calling generate_missing_debug_handles
                dequant_node.meta.pop("debug_handle", None)
                # Redirect user's consumers first and only then wire
                # dequant_node to user, so that its own input is not rewritten.
                user.replace_all_uses_with(dequant_node)
                dequant_node.args = (user, *node.args[1:])
                # Preserve keyword arguments of the original dequant op.
                dequant_node.kwargs = dict(node.kwargs)

            # The users now consume the quantized input directly.
            pred = node.args[0]
            node.replace_all_uses_with(pred)
            graph.erase_node(node)
            modified = True

        graph_module.recompile()
        return modified

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # The logic in postpone_dequantize_op that handles branching checks the shape
        # of the dequant node, which isn't available if that node was already postponed
        # in the same pass invocation. The shape information is recreated by tracing in
        # super().call(), meaning that every branch in the graph that we wish to postpone
        # dequant past requires retracing. We iterate the pass until it no longer modifies
        # the graph (up to 3 times max, to avoid potential infinite loops)
        max_iterations = 3
        self.graph_module = graph_module
        result = None
        for _ in range(max_iterations):
            modified = self.postpone_dequantize_op(self.graph_module)
            result = super().call(self.graph_module)
            self.graph_module = result.graph_module
            if not modified:
                break

        # The last retrace already reflects the final graph; tracing it once
        # more would only repeat the same work.
        return result
472
473
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class SinkOpsCloserToUsePass(ExportPass):
    """
    Move each single-user dequantize op down to just before its only user.

    Given D = dequantize(I) where D has exactly one user, a graph like
    I = ...;
    D = dequantize(I);
    ...
    Y = use(D);
    is rewritten to:
    I = ...;
    ...
    D = dequantize(I);
    Y = use(D);

    This is valid because D has a single user. The benefit is that I, rather
    than D, is now live across the intermediate ops, and I is much smaller.
    """

    sinkable_ops: Set[EdgeOpOverload] = {
        exir_ops.edge.aten.dequantize,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_channel,
    }

    def sink_ops_closer_to_use(self, graph_module: torch.fx.GraphModule):
        """Re-create each single-user sinkable node right before its user."""
        graph = graph_module.graph

        def _is_sinkable(candidate: torch.fx.Node) -> bool:
            # Only Edge dialect ops whose packet is listed in sinkable_ops.
            if not isinstance(candidate.target, EdgeOpOverload):
                return False
            return get_edge_overload_packet(candidate.target) in self.sinkable_ops

        # Snapshot the candidates up front, because the loop edits the graph.
        for sinkable in [n for n in graph.nodes if _is_sinkable(n)]:
            if len(sinkable.users) != 1:
                continue
            (consumer,) = sinkable.users

            with graph.inserting_before(consumer):
                moved = graph.call_function(
                    sinkable.target, args=sinkable.args, kwargs=sinkable.kwargs
                )
                moved.meta = sinkable.meta
            sinkable.replace_all_uses_with(moved)
            graph.erase_node(sinkable)

        graph_module.recompile()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        self.sink_ops_closer_to_use(graph_module)
        return super().call(graph_module)
531
532
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class HoistOpsCloserToDefPass(ExportPass):
    """
    Assume that the input I to a quantize op Q = quantize(I) has only a single
    use, the quantize node itself.
    If the current graph looks like
    I = ...;
    ...
    Q = quantize(I);
    X = use(Q);
    then we can hoist the quantize op closer to its def, and convert the
    graph to:
    I = ...;
    Q = quantize(I);
    ...
    X = use(Q);

    The transformation is valid since I had a single user. The benefit comes from
    the fact that now we have Q in the live range instead of I, which has a
    much smaller size. The same transformation also applies to slice/select op.
    If I is a graph input (placeholder), the op is placed right after the last
    placeholder instead of right after I.
    """

    hoistable_ops: Set[EdgeOpOverload] = {
        exir_ops.edge.quantized_decomposed.quantize_per_tensor,
        exir_ops.edge.aten.slice_copy,
        exir_ops.edge.aten.select_copy,
    }

    def hoist_ops_closer_to_def(self, graph_module: torch.fx.GraphModule):
        """Move each hoistable op whose first input has no other user to just after that input."""
        graph = graph_module.graph
        # We are only interested in hoistable nodes
        hoistable_nodes = [
            node
            for node in graph.nodes
            if isinstance(node.target, EdgeOpOverload)
            and get_edge_overload_packet(node.target) in self.hoistable_ops
        ]
        for node in hoistable_nodes:
            def_node = node.args[0]
            if not isinstance(def_node, torch.fx.Node):
                continue
            # The def node must have a single user
            if len(def_node.users) != 1:
                continue

            # Get the node args as list
            args = list(node.args)

            # If def_node is a placeholder, only hoist to just after the last
            # placeholder. Hoisting right after def_node itself would shrink
            # the live range of def_node considerably, which could lead to
            # reuse of input memory.
            def_node = (
                get_placeholders(graph)[-1]
                if def_node.op == "placeholder"
                else def_node
            )

            # If the node is quantize_per_channel, we need to hoist the scale
            # and zero_point tensors as well.
            # NOTE(review): quantize_per_channel is not in hoistable_ops, so
            # this branch is currently unreachable.
            if (
                node.target
                == exir_ops.edge.quantized_decomposed.quantize_per_channel.default
            ):
                scale, zero_point = args[1], args[2]
                # Each node inserted in this context goes directly after
                # def_node, so the final order is def_node, scale_copy,
                # zero_point_copy; the hoisted node then goes after
                # zero_point_copy.
                with graph.inserting_after(def_node):
                    zero_point_copy = graph.node_copy(zero_point)
                    scale_copy = graph.node_copy(scale)
                    args[1], args[2] = scale_copy, zero_point_copy
                    def_node = zero_point_copy

            # NOTE(review): only args[0] is checked. Any other Node argument
            # (e.g. tensor-valued scale/zero_point or symbolic slice bounds) is
            # assumed to be defined before def_node already.
            # Insert the new node just after def_node
            with graph.inserting_after(def_node):
                new_node = graph.call_function(
                    node.target, args=tuple(args), kwargs=node.kwargs
                )
                new_node.meta = node.meta
            node.replace_all_uses_with(new_node)
            graph.erase_node(node)

        # Eliminate dead code, then regenerate the module code so that it
        # matches the final graph.
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        self.hoist_ops_closer_to_def(graph_module)
        result = super().call(graph_module)
        return result
621
622
623@register_cadence_pass(CadencePassAttribute(opt_level=1))
624class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(ExportPass):
625    """
626    A common pattern seen in transformer models.  If the consumer of permute
627    is a view op, swap their order so permute is below view.
628    Change "permute -> view" to "view -> permute"
629    This is to optimize a chain of view->permute->view->permute...
630    so that the chain will be become view->v...->view->permute->p...->permute.
631    The chain can be optimized by FuseCascadedTransposeOrPermuteOps() and
632    FuseCascadedViewOps().
633    Notice the class name has ViewSqueeze to indicate the View is
634    functionally the same as a squeeze or unsqueeze. It does not necessarily
635    mean the view_copy is normalized from squeeze or unsqueeze.
636    """
637
638    def __init__(self):
639        super().__init__()
640        self.graph_module = None
641
642    # If list1 and list2 are same (same values and in same order) except
643    # list1 has one more element with value of 1. Return index of the extra 1.
644    # Otherwise return -1.
645    def check_if_shapes_differ_in_single_dim_of_size_1(self, list1, list2) -> int:
646        if len(list1) != len(list2) + 1:
647            return -1
648        for i in range(len(list2)):
649            if list1[i] != list2[i]:
650                # Return index of the extra 1 if the remaining parts are the same
651                if list1[i] == 1 and list2[i:] == list1[i + 1 :]:
652                    return i
653                else:
654                    return -1
655        # If no difference was found, the extra element is at the end
656        if list1[-1] == 1:
657            return len(list2)
658        else:
659            return -1
660
661    def insert_nodes(
662        self,
663        graph: torch.fx.Graph,
664        pred: torch.fx.Node,
665        permute_node: torch.fx.Node,
666        view_node: torch.fx.Node,
667        new_view_shape: List,
668        new_permute_dims: List,
669    ):
670        with graph.inserting_after(view_node):
671            new_view_node = graph.call_function(
672                view_node.target,  # pyre-fixme[6]
673                args=(pred, new_view_shape),
674            )
675
676        with graph.inserting_after(new_view_node):
677            new_permute_node = graph.call_function(
678                permute_node.target,  # pyre-fixme[6]
679                args=(new_view_node, new_permute_dims),
680            )
681            new_permute_node.meta = view_node.meta
682            view_node.replace_all_uses_with(new_permute_node)
683
684        # view_node is user of permute_node, so must erase view_node first
685        graph.erase_node(view_node)
686        graph.erase_node(permute_node)
687
688    # flake8: noqa 'PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView.postpone_permute_op' is too complex (13)
689    def postpone_permute_op(self, graph_module: torch.fx.GraphModule):
690        packet_to_overload_map = {
691            exir_ops.edge.aten.permute_copy: "default",
692        }
693        graph = graph_module.graph
694        changed = True
695        modified = False
696        # Loop iteratively until no more changes are made
697        while changed:
698            changed = False
699            for permute_node in graph.nodes:
700                permute_overload_packet = get_overload_packet(permute_node.target)
701                if permute_overload_packet not in packet_to_overload_map.keys():
702                    continue
703
704                users = list(permute_node.users.keys())
705                # Transform only for pattern permute_copy->view_copy, and
706                # view_copy op is the only user of permute_copy.
707                if len(users) == 1 and users[0].target in (
708                    exir_ops.edge.aten.view_copy.default,
709                    exir_ops.edge.aten.view.default,
710                ):
711                    # If the permute_node/view_node was newly added to the
712                    # graph, it may not have the meta["val"] FakeTensor.
713                    # Skip in this case.
714                    if permute_node.meta.get("val") is None:
715                        continue
716                    permute_node_shape = [
717                        *cast(list, get_shape(graph_module, permute_node))
718                    ]
719                    permute_dims = permute_node.args[1]
720                    view_node = users[0]
721                    if view_node.meta.get("val") is None:
722                        continue
723                    view_node_shape = [*cast(list, get_shape(graph_module, view_node))]
724                    pred = permute_node.args[0]
725                    if pred.meta.get("val") is None:
726                        continue
727                    pred_shape = [*cast(list, get_shape(graph_module, pred))]
728                    # Handle two cases
729                    # 1. view_node_shape is almost same as permute_node_shape
730                    #    except the view_node has one more dim somewhere
731                    #    and the extra dim has value of 1.
732                    # 2. view_node_shape is almost same as permute_node_shape
733                    #    except permute_node_shape has one more dim somewhere
734                    #    and the extra dim has value of 1.
735                    # 3. view_node_shape is the same as permute_node_shape.
736                    if len(permute_node_shape) + 1 == len(view_node_shape):
737                        index = self.check_if_shapes_differ_in_single_dim_of_size_1(
738                            view_node_shape, permute_node_shape
739                        )
740                        if index != -1:
741                            # view_node_shape is almost same as permute_node_shape
742                            # except it has one more dim somewhere
743                            # and the extra dim has value of 1.
744                            new_view_shape = copy.deepcopy(pred_shape)
745                            new_view_shape.insert(index, 1)
746                            new_permute_dims = [
747                                x + 1 if x >= index else x for x in permute_dims
748                            ]
749                            new_permute_dims.insert(index, index)
750                            self.insert_nodes(
751                                graph,
752                                pred,
753                                permute_node,
754                                view_node,
755                                new_view_shape,
756                                new_permute_dims,
757                            )
758                            changed = True
759                            modified = True
760                    elif len(view_node_shape) + 1 == len(permute_node_shape):
761                        index = self.check_if_shapes_differ_in_single_dim_of_size_1(
762                            permute_node_shape, view_node_shape
763                        )
764                        if index != -1:
765                            # view_node_shape is almost same as permute_node_shape
766                            # except permute_node_shape has one more dim somewhere
767                            # and the extra dim has value of 1.
768                            index_to_remove = permute_dims[index]
769                            new_view_shape = copy.deepcopy(pred_shape)
770                            del new_view_shape[index_to_remove]
771                            new_permute_dims = [
772                                x - 1 if x > index_to_remove else x
773                                for x in permute_dims
774                            ]
775                            del new_permute_dims[index]
776                            self.insert_nodes(
777                                graph,
778                                pred,
779                                permute_node,
780                                view_node,
781                                new_view_shape,
782                                new_permute_dims,
783                            )
784                            changed = True
785                            modified = True
786                    elif permute_node_shape == view_node_shape:
787                        # view_node_shape is the same as permute_node_shape
788                        # Replace the uses of view_node with permute_node
789                        view_node.replace_all_uses_with(permute_node)
790                        changed = True
791                        modified = True
792
793        graph_module.recompile()
794        return modified
795
796    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
797        self.graph_module = graph_module
798        iter_count = 0
799        modified = True
800
801        while modified and iter_count <= 3:
802            modified = self.postpone_permute_op(self.graph_module)
803            self.graph_module = super().call(self.graph_module).graph_module
804            iter_count += 1
805
806        return super().call(self.graph_module)
807
808
809# The following class consolidates functions to reoder ops (i.e., either hoist
810# or sink some ops in the graph).
811class CadenceReorderOpsInGraph:
812    passes = [
813        # Hoist/sink nodes closer to their SSA def/use
814        HoistOpsCloserToDefPass,
815        SinkOpsCloserToUsePass,
816        # For quantize/dequantize ops, move them above/below their def chain.
817        # This is a more aggressive optimization than just hoisting/sinking
818        # nodes closer to their def/use.
819        AdvanceQuantizeOpAboveDefChainPass,
820        PostponeDequantizeOpBelowUseChainPass,
821        # These passes work on branches instead of linear chains to advance
822        # quantize op beyond their def.
823        AdvanceQuantizeOpAboveDefInBranchPass,
824    ]
825