# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import copy
import os
import tempfile
import unittest
from typing import List, Optional, Tuple

import executorch.exir as exir

# Import passes
import executorch.exir.memory_planning  # noqa
import torch
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge
from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.emit import emit_program
from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import (
    dead_code_elimination_pass,
    DebugPass,
    HintBasedSymShapeEvalPass,
    MemoryPlanningPass,
    propagate_dynamic_shape,
    RemoveNoopPass,
    ReplaceSymSizeOpPass,
    ToOutVarPass,
)
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.debug_handle_generator_pass import (
    DebugHandleGeneratorPass,
    generate_missing_debug_handles,
)
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
    insert_write_back_for_buffers_pass,
)

from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
from executorch.exir.passes.normalize_view_copy_base_pass import (
    NormalizeViewCopyBasePass,
)
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass
from executorch.exir.passes.replace_view_copy_with_view_pass import (
    ReplaceViewCopyWithViewPass,
)
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
from executorch.exir.program._program import lift_constant_tensor_pass
from executorch.exir.schema import TensorShapeDynamism
from executorch.exir.tensor import TensorSpec
from executorch.exir.tests.common import register_additional_test_aten_ops
from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic
from executorch.exir.tests.models import MLP, Mul
from functorch.experimental import control_flow

from torch import nn

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import export
from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
from torch.fx import GraphModule, subgraph_rewriter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.library import impl, Library
from torch.testing import FileCheck
from torch.utils import _pytree as pytree


# pyre-ignore
def collect_ops(gm: torch.fx.GraphModule):
    """
    Collect all targets for call_function nodes from the graph module recursively.
    """
    ops = set()
    for subgm in gm.modules():
        if not isinstance(subgm, torch.fx.GraphModule):
            continue
        for node in subgm.graph.nodes:
            if node.op == "call_function":
                ops.add(node.target)
    return ops


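# Test-only operator library: the custom ops defined below (foo, foo.out, add_relu)
# are exercised by the out-variant and backend-op tests in this file.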
lib = Library("DO_NOT_USE_TEST_ONLY", "DEF")

lib.define("foo(Tensor self) -> (Tensor, Tensor)")
lib.define("add_relu(Tensor self, Tensor other) -> Tensor")


@impl(lib, "foo", "CompositeExplicitAutograd")
def foo(a: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    return a + 1, None


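# Out-variant schema for foo; the CompositeExplicitAutograd impl below is a test
# stub that ignores the out arguments and just mirrors the functional foo.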
lib.define(
    "foo.out(Tensor self, *, Tensor(a!) out1, Tensor(b!) out2) -> (Tensor(a!), Tensor(b!))"
)


@impl(lib, "foo.out", "CompositeExplicitAutograd")
def foo_out(
    a: torch.Tensor, out1: torch.Tensor, out2: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    return a + 1, None


class TestPasses(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        register_additional_test_aten_ops()

    def test_remove_mixed_type_operators(self) -> None:
        class Add(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return (x + y) + x

        add = Add()

        int_tensor = torch.tensor([[1, 2, 3]])
        float_tensor = torch.tensor([[1.0, 2.0, 3.0]])
        edge_prog = to_edge(
            export(
                add,
                (int_tensor, float_tensor),
            )
        )

        new_prog = edge_prog.transform([RemoveMixedTypeOperators()])
        new_graph_module = new_prog.exported_program().graph_module
        self.assertIsNotNone(new_graph_module)

        add_count = 0

        for node in new_graph_module.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten.add.Tensor
            ):
                add_count += 1
                node_args = node.args
                for arg in node_args:
                    self.assertEqual(arg.meta["val"].dtype, torch.float)

        self.assertEqual(add_count, 2)

        double_tensor = torch.tensor([[1.0, 2.0, 3.0]])
        double_tensor = double_tensor.to(torch.double)

        double_prog = to_edge(export(add, (int_tensor, double_tensor)))

        double_prog.transform([RemoveMixedTypeOperators()])
        new_graph_module_double = double_prog.exported_program().graph_module
        self.assertIsNotNone(new_graph_module_double)

        add_count_double = 0

        for node in new_graph_module_double.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten.add.Tensor
            ):
                add_count_double += 1
                node_args = node.args
                for arg in node_args:
                    self.assertEqual(arg.meta["val"].dtype, torch.double)

        self.assertEqual(add_count_double, 2)

        class Mult(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x * y

        mult = Mult()

        float_tensor_vert = float_tensor.T
        mult_prog = to_edge(
            export(
                mult,
                (int_tensor, float_tensor_vert),
            )
        )

        # graph_module_mult.graph.print_tabular()

        mult_prog = mult_prog.transform([RemoveMixedTypeOperators()])
        new_graph_module_mult = mult_prog.exported_program().graph_module
        self.assertIsNotNone(new_graph_module_mult)

        mult_count = 0

        for node in new_graph_module_mult.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten.mul.Tensor
            ):
                mult_count += 1
                node_args = node.args
                for arg in node_args:
                    self.assertEqual(arg.meta["val"].dtype, torch.float)

        self.assertEqual(mult_count, 1)

    def test_remove_noop_pass(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x.to(dtype=torch.float32)

        foo = Foo()

        # Turn off functionalization so that we can get the actual to.dtype op
        edge_prog = to_edge(
            export(
                foo,
                (torch.ones(1, dtype=torch.float32),),
            )
        )
        edge_prog = edge_prog.transform([RemoveNoopPass()])
        self.assertIsNotNone(edge_prog.exported_program().graph_module)
        new_graph_module = edge_prog.exported_program().graph_module
        for node in new_graph_module.graph.nodes:
            if node.op == "call_function":
                self.assertNotEqual(node.target, torch.ops.aten.to.dtype)

    def test_redundant_slice_copy_removal(self) -> None:
        class FooWithNoSlice(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x[:, :, :]

        foo_with_no_slice = FooWithNoSlice()

        class FooWithOneSlice(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x[:1, :, :]

        foo_with_one_slice = FooWithOneSlice()

        class FooWithAllSlices(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x[:1, :2, 2:4]

        foo_with_all_slices = FooWithAllSlices()

        # Export each module and check how many slice_copy ops survive RemoveNoopPass.
        x = torch.ones((3, 8, 8))
        prog = to_edge(
            export(
                foo_with_no_slice,
                (x,),
            )
        )
        prog = prog.transform([RemoveNoopPass()])
        new_graph_module = prog.exported_program().graph_module
        FileCheck().check_count(
            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 0, exactly=True
        ).run(new_graph_module.code)

        prog = to_edge(
            export(
                foo_with_one_slice,
                (x,),
            )
        )
        prog = prog.transform([RemoveNoopPass()])
        new_graph_module = prog.exported_program().graph_module
        FileCheck().check_count(
            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 1, exactly=True
        ).run(new_graph_module.code)

        prog = to_edge(
            export(
                foo_with_all_slices,
                (x,),
            )
        )
        prog = prog.transform([RemoveNoopPass()])
        new_graph_module = prog.exported_program().graph_module
        FileCheck().check_count(
            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 3, exactly=True
        ).run(new_graph_module.code)

    def test_compile_to_edge(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x * 2

        f = Foo()

        x = (torch.randn(2, 3),)

        to_edge(
            export(
                f,
                x,
            )
        ).exported_program().graph_module
        # TODO(angelayi): Add a utility function that verifies a model is in
        # the edge dialect

    def test_to_out_variant_none_output(self) -> None:
        class CompositeModel(torch.nn.Module):
            def __init__(self, _weight):
                super().__init__()
                self.weight = _weight
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                return output

        # Prepare input and trace it
        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])
        inputs = (input_x, input_h, input_c)

        composite_m = CompositeModel(3)

        edge_prog = to_edge(
            export(
                composite_m,
                inputs,
            )
            # torch._ops.aten.t.default
            ,
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )

        new_prog = edge_prog.transform([SpecPropPass()])

        new_gm_res = ToOutVarPass()(new_prog.exported_program().graph_module)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module
        for node in new_gm.graph.nodes:
            if node.op == "call_function" and node.target in [
                torch.ops.DO_NOT_USE_TEST_ONLY.foo.out,
                torch.ops.my_awesome_3rdparty_ns.awesome_op.out,
            ]:
                self.assertEqual(len(node.kwargs), 2)
                out1_node = node.kwargs["out1"]
                self.assertEqual(out1_node.op, "call_function")
                self.assertIs(out1_node.target, memory.alloc)
                self.assertIs(node.kwargs["out2"], None)

        new_gm_res = MemoryPlanningPass()(new_gm)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module
        new_prog.exported_program().graph_module.graph = new_gm.graph
        emit_program(new_prog.exported_program())

    def test_to_out_variant_singleon_tensor_list(self) -> None:
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.split(x, 10)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        prog = to_edge(
            export(
                model,
                inputs,
            ),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )  # TODO(larryliu): fix split_copy
        new_gm_res = ToOutVarPass()(prog.exported_program().graph_module)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module

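        # Locate the first split_copy out-variant node (nd keeps that node after the break).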
        for nd in new_gm.graph.nodes:
            if nd.target is exir_ops.edge.aten.split_copy.Tensor_out:
                break

        val = nd.meta["val"]

        # We must return a spec that is a list containing a single TensorSpec item.
        # Returning the TensorSpec item directly causes later getitem ops to fail.
        self.assertTrue(isinstance(val, (tuple, list)))
        self.assertEqual(1, len(val))

    def test_to_out_variant_multiple_out(self) -> None:
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.topk(x, 5)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        prog = to_edge(
            export(
                model,
                inputs,
            ),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )  # TODO(larryliu): fix topk
        new_gm_res = ToOutVarPass()(prog.exported_program().graph_module)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module

        for nd in new_gm.graph.nodes:
            if nd.target is torch.ops.aten.topk.values:
                break

        val = nd.meta["val"]

        # We must return a spec that is a list of TensorSpec items.
        # Returning a TensorSpec item directly causes later getitem ops to fail.
        self.assertTrue(isinstance(val, (tuple, list)))
        self.assertEqual(2, len(val))

    def test_to_out_variant_to_copy(self) -> None:
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.to(torch.int32)

        model = Module()

        inputs = torch.tensor(1.0, dtype=torch.float)
        model_res = model(inputs)

        edge_dialect = to_edge(
            export(
                model,
                (inputs,),
            )
        )
        edge_res = edge_dialect.exported_program().module()(inputs)
        self.assertTrue(torch.allclose(model_res, edge_res))

    def test_export_pass(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
                y = torch.cat([x, x])
                return torch.ops.aten.tensor_split.sections(y, 2)

        f = Foo()

        class NullPass(ExportPass):
            pass

        prog = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
            ),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )  # TODO(larryliu): fix cat
        new_prog = prog.transform([NullPass()])
        new_nodes = new_prog.exported_program().graph_module.graph.nodes
        for node in new_nodes:
            if node.op != "call_function":
                continue
            self.assertTrue(hasattr(node, "stack_trace"))
            self.assertIsNotNone(node.stack_trace)

        old_nodes = prog.exported_program().graph_module.graph.nodes
        self.assertEqual(len(new_nodes), len(old_nodes))
        for new_node, old_node in zip(new_nodes, old_nodes):
            self.assertEqual(new_node.op, old_node.op)
            self.assertEqual(new_node.target, old_node.target)

    def test_export_pass_pt2(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
                y = torch.cat([x, x])
                return torch.ops.aten.tensor_split.sections(y, 2)

        f = Foo()

        class NullPass(ExportPass):
            pass

        prog = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        new_prog = prog.transform([NullPass()])
        new_nodes = new_prog.exported_program().graph_module.graph.nodes
        for node in new_nodes:
            if node.op != "call_function":
                continue
            self.assertTrue(hasattr(node, "stack_trace"))
            self.assertIsNotNone(node.stack_trace)

        old_nodes = prog.exported_program().graph_module.graph.nodes
        self.assertEqual(len(new_nodes), len(old_nodes))
        for new_node, old_node in zip(new_nodes, old_nodes):
            self.assertEqual(new_node.op, old_node.op)
            self.assertEqual(new_node.target, old_node.target)

    def test_export_scalar_to_tensor_pass(self) -> None:
        class Mul(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x * 3.14

        mul = Mul()

        expo_prog = to_edge(export(mul, (torch.ones(1),)))
        new_prog = expo_prog.transform([ScalarToTensorPass()])
        self.assertIsNotNone(new_prog.exported_program().graph_module)
        new_graph_module = new_prog.exported_program().graph_module

        inp = torch.zeros(1)
        self.assertTrue(
            torch.allclose(
                expo_prog.exported_program().module()(inp),
                new_prog.exported_program().module()(inp),
            )
        )
        for node in new_graph_module.graph.nodes:
            if node.op == "call_function":
                for arg in node.args + tuple(node.kwargs.values()):
                    self.assertFalse(isinstance(arg, float))

    def test_remove_mixed_types_symfloats(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.interpolate(
                    x,
                    size=(x.shape[2] * 2, x.shape[3] * 3),
                    mode="bilinear",
                    align_corners=False,
                    antialias=False,
                )

        f = Foo()

        example_inputs = (torch.randn(2, 3, 4, 5),)

        gm = to_edge(
            export(
                f,
                example_inputs,
            )
        )
        new_gm = gm.transform(
            [ReplaceSymSizeOpPass(), ScalarToTensorPass(), RemoveMixedTypeOperators()]
        )
        self.assertIsNotNone(new_gm.exported_program().graph_module)

        self.assertTrue(
            torch.allclose(
                gm.exported_program().module()(*example_inputs),
                new_gm.exported_program().module()(*example_inputs),
            )
        )

    def test_spec_prop_pass(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + x

        f = Foo()

        gm = (
            to_edge(
                export(
                    f,
                    (torch.ones(3, 2),),
                )
            )
            .exported_program()
            .graph_module
        )
        new_gm = SpecPropPass()(gm)
        self.assertIsNotNone(new_gm)
        new_nodes = new_gm.graph_module.graph.nodes
        counter = 0
        for node in new_nodes:
            if node.op != "output":
                continue
            counter += 1
            self.assertIs(node.meta["spec"][0], node.args[0][0].meta["spec"])

        self.assertEqual(counter, 1)

    def test_spec_prop_pass_tuple_output(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
                return (x + x,)

        f = Foo()

        gm = (
            to_edge(
                export(
                    f,
                    (torch.ones(3, 2),),
                )
            )
            .exported_program()
            .graph_module
        )
        new_gm = SpecPropPass()(gm)
        self.assertIsNotNone(new_gm)
        new_nodes = new_gm.graph_module.graph.nodes
        counter = 0
        for node in new_nodes:
            if node.op != "output":
                continue
            counter += 1
            self.assertIs(node.meta["spec"][0], node.args[0][0].meta["spec"])

        self.assertEqual(counter, 1)

    def test_compile_fix_broken_ops(self) -> None:
        # When a higher-dimensional input (here 4-D) is passed to Linear,
        # aten._unsafe_view is used under the hood
        x = torch.randn([2, 3, 4, 5])
        model: torch.nn.Linear = torch.nn.Linear(5, 5)

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.model = model

            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                return self.model(inp)

        f = Foo()

        # ReplaceBrokenOpsWithFunctionalOpsPass is used in to_edge()
        prog = to_edge(
            export(
                f,
                (x,),
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        gm = prog.exported_program().graph_module
        count_after = 0
        for node in gm.graph.nodes:
            if node.target == torch.ops.aten._unsafe_view.default:
                count_after += 1
        self.assertEqual(count_after, 0)
        self.assertTrue(torch.allclose(prog.exported_program().module()(x), f(x)))

    def test_convert_symb_ops(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.add(x, x.shape[0] - 1)

        f = Foo()

        # Mark the 0th dimension of X as dynamic with a max value of 3.
        dim_x = torch.export.Dim("dim_x", max=3)

        prog = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
                dynamic_shapes={"x": {0: dim_x}},
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        new_prog = prog.transform([EdgeToBackendOpsPass()])
        self.assertIsNotNone(new_prog.exported_program().graph_module)
        converted_gm = new_prog.exported_program().graph_module

        FileCheck().check("torch.ops.aten.sym_size.int").check(
            "executorch_exir_dialects_backend__ops_executorch_prim_sub_Scalar"
        ).check_not("operator.sub").run(converted_gm.code)

    def test_alloc_node_spec(self) -> None:
        """
        Make sure every memory.alloc node, including those in subgraph modules,
        has a TensorSpec.
699        """
700        eager_model = FTMapBasic()
701        inputs = eager_model.get_random_inputs()
702        prog = to_edge(
703            export(
704                eager_model,
705                inputs,
706            ),
707            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
708        )
709        passes = [
710            SpecPropPass(),
711            HintBasedSymShapeEvalPass(),
712        ]
713        new_prog = prog.transform(passes)
714
715        new_gm_res = ToOutVarPass()(new_prog.exported_program().graph_module)
716        self.assertIsNotNone(new_gm_res)
717        new_gm = new_gm_res.graph_module
718
719        new_gm_res = MemoryPlanningPass()(new_gm)
720        self.assertIsNotNone(new_gm_res)
721        new_gm = new_gm_res.graph_module
722
723        alloc_nodes = []
724        for subgm in new_gm.modules():
725            if isinstance(subgm, torch.fx.GraphModule):
726                for node in subgm.graph.nodes:
727                    if node.target == memory.alloc:
728                        alloc_nodes.append(node)
729        self.assertTrue(len(alloc_nodes) > 0)
730        for node in alloc_nodes:
731            self.assertTrue(isinstance(node.meta.get("spec", None), TensorSpec))
732
733    def test_debug_pass_file_log(self) -> None:
734        eager_model = Mul()
735        inputs = eager_model.get_random_inputs()
736
737        # the debug pass works with a graph generated with make_fx directly
738        gm = make_fx(eager_model)(*inputs)
739
740        try:
741            fd, path = tempfile.mkstemp()
742
743            print(f"Write DebugPass output to {path}")
744            DebugPass(log_filename=path)(gm)
745            with open(path) as f:
746                file_cont = f.read()
747            self.assertTrue("torch.ops.aten.mul" in file_cont)
748        finally:
749            os.close(fd)
750            os.unlink(path)
751
752    def test_dce_recursive(self) -> None:
753        eager_model = FTCondDeadCode()
754        inputs = eager_model.get_random_inputs()
755        gm = export(
756            eager_model,
757            inputs,
758        ).graph_module
759
760        self.assertTrue(torch.ops.aten.sub.Tensor in collect_ops(gm))
761        dead_code_elimination_pass(gm)
762        gm.print_readable()
763        self.assertFalse(torch.ops.aten.sub.Tensor in collect_ops(gm))
764
765    def test_propagate_dynamic_shape(self) -> None:
766        class Foo(torch.nn.Module):
767            def forward(self, x: torch.Tensor) -> torch.Tensor:
768                y = x
769                for _ in range(2):
770                    y = y + x
771                return y
772
773        f = Foo()
774
775        prog = to_edge(
776            export(
777                f,
778                (torch.rand(5),),
779            ),
780            # missing dispatch key
781            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
782        ).transform(propagate_dynamic_shape())
783        gm = prog.exported_program().graph_module
784        nspec = 0
785        for n in gm.graph.nodes:
786            for spec in pytree.tree_flatten(n.meta["spec"])[0]:
787                self.assertTrue(all(isinstance(x, int) for x in spec.shape))
788                nspec += 1
789
790        self.assertTrue(nspec > 0)
791
792    def test_losing_symbolic_info(self) -> None:
793        """
794        Guard against an issue that after calling ConvertSymbolicOpsPass(),
795        future ExportPass will encounter symbolic information loss.
796        """
797
798        class Foo(torch.nn.Module):
799            def forward(self, x: torch.Tensor) -> torch.Tensor:
800                return torch.add(x, x.shape[0] - 1)
801
802        f = Foo()
803
804        dim_x = torch.export.Dim("dim_x", max=3)
805        prog = to_edge(
806            export(
807                f,
808                (torch.ones(3, 2),),
809                dynamic_shapes={"x": {0: dim_x}},
810            ),
811            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
812        )
813
814        new_prog = prog.transform([EdgeToBackendOpsPass()])
815        gm = new_prog.exported_program().graph_module
816        gm.print_readable()
817        *_, ones, out = gm.graph.nodes
818        print(f"Before ExportPass: {ones.format_node()}")
819        self.assertTrue(isinstance(ones.meta["val"].shape[0], torch.SymInt))
820        self.assertTrue(len(ones.meta["val"].shape[0].node.expr.free_symbols) > 0)
821
822        new_prog = new_prog.transform([ExportPass()])
823        gm = new_prog.exported_program().graph_module
824        gm.print_readable()
825        *_, ones, out = gm.graph.nodes
826        print(f"After ExportPass: {ones.format_node()}")
827        self.assertTrue(isinstance(ones.meta["val"].shape[0], torch.SymInt))
828        self.assertTrue(len(ones.meta["val"].shape[0].node.expr.free_symbols) > 0)
829
830    def test_to_edge_with_edge_ops(self) -> None:
831        x = torch.randn([2, 3, 4, 5])
832
833        class Foo(torch.nn.Module):
834            def forward(self, x: torch.Tensor) -> torch.Tensor:
835                return x + x
836
837        f = Foo()
838
839        gm = (
840            to_edge(
841                export(
842                    f,
843                    (x,),
844                )
845            )
846            .exported_program()
847            .graph_module
848        )
849        for node in gm.graph.nodes:
850            if node.op == "call_function":
851                self.assertEqual(type(node.target), EdgeOpOverload)
852
853    # TODO(T143084047)
854    @unittest.expectedFailure
855    def test_backend_fused_op_retraceable(self) -> None:
856        """This test makes sure the backend op is still retraceable, with the pattern being registered as kernel."""

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                z = x + y
                return torch.ops.aten.relu.default(z)

        f = Foo()

        gm = export(
            f,
            (
                torch.randn(2, 2),
                torch.randn(2, 2),
            ),
        )
        # should look like:
        # graph():
        #     %ph_0 : [#users=1] = placeholder[target=ph_0]
        #     %ph_1 : [#users=1] = placeholder[target=ph_1]
        #     %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, %ph_1), kwargs = {})
        #     %relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%add_tensor,), kwargs = {})
        #     return [relu_default]
        FileCheck().check("torch.ops.aten.add.Tensor").check(
            "torch.ops.aten.relu.default"
        ).run(gm.graph_module.code)

        class AddReluFusionPass(ExportPass):
            def call(self, graph_module: GraphModule) -> PassResult:
                # The decorator registers this pattern as a CompositeExplicitAutograd kernel, since no kernel was registered for it before.
                @bind_pattern_to_op(lib, "add_relu")
                def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    z = torch.ops.aten.add.Tensor(x, y)
                    out = torch.ops.aten.relu.default(z)
                    return out

                def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return ops.backend.DO_NOT_USE_TEST_ONLY.add_relu.default(x, y)

                subgraph_rewriter.replace_pattern(graph_module, pattern, replacement)
                return PassResult(graph_module, True)

        # TODO(larryliu): this pass needs to be in to_executorch()
        class OpReplacePass(ExportPass):
            def call_operator(self, op, args, kwargs, meta):
                if op == torch.ops.DO_NOT_USE_TEST_ONLY.add_relu.default:
                    return super().call_operator(
                        ops.backend.DO_NOT_USE_TEST_ONLY.add_relu.default,
                        args,
                        kwargs,
                        meta,
                    )
                return super().call_operator(op, args, kwargs, meta)

        gm_lowered = to_edge(
            gm,
            compile_config=EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        ).transform([AddReluFusionPass(), OpReplacePass()])

        FileCheck().check(
            "executorch_exir_dialects_backend__ops_DO_NOT_USE_TEST_ONLY_add_relu_default"
        ).run(gm_lowered.exported_program().graph_module.code)
        # lowered module:
        # def forward(self, ph_0, ph_1):
        #     do_not_use_test_only_add_relu_default = executorch_exir_dialects_backend__ops_DO_NOT_USE_TEST_ONLY_add_relu_default(ph_0, ph_1);  ph_0 = ph_1 = None
        #     return [do_not_use_test_only_add_relu_default]

        # Retrace:
        # If this were not a backend op, retracing would error out because no CPU/CompositeExplicitAutograd kernel is registered.
        gm_retraced = to_edge(
            export(
                gm_lowered.exported_program().module(),
                (
                    torch.randn(2, 2),
                    torch.randn(2, 2),
                ),
            )
        )
        # The graph is retraceable: it is promoted back to the ATen dialect, so add and relu show up again, as expected.
        FileCheck().check("torch.ops.aten.add.Tensor").check(
            "torch.ops.aten.relu.default"
        ).run(gm_retraced.exported_program().graph_module.code)

    def test_debug_handle_generator_pass(self) -> None:
        eager_model = MLP(2, output_size=4)
        inputs = eager_model.get_random_inputs()

        graph_module = (
            to_edge(
                export(
                    eager_model,
                    inputs,
                )
            )
            .exported_program()
            .graph_module
        )
        for node in graph_module.graph.nodes:
            self.assertIn("debug_handle", node.meta)
        ScalarToTensorPass()(graph_module)
        for node in graph_module.graph.nodes:
            self.assertIn("debug_handle", node.meta)

    def test_generate_missing_debug_handles(self) -> None:
        eager_model = MLP(2, output_size=4)
        inputs = eager_model.get_random_inputs()

        ep = to_edge(
            export(
                eager_model,
                inputs,
            )
        ).exported_program()

        list(ep.graph.nodes)[0].meta.pop("debug_handle")
        self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is None)
        generate_missing_debug_handles(ep)
        self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is not None)

    def test_debug_handle_generator_pass_with_control_flow(self) -> None:
        def true_nested(y: torch.Tensor) -> torch.Tensor:
            y = y + y
            y = torch.mm(y, y)
            return y

        def false_nested(y: torch.Tensor) -> torch.Tensor:
            return torch.mm(y, y)

        def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
            z = control_flow.cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x: torch.Tensor, _) -> torch.Tensor:
            return x.cos()

        def map_fn(
            x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor
        ) -> torch.Tensor:
            x = x.cos()
            y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
            x = x + y
            return x.sin()

        class Foo(torch.nn.Module):
            def forward(
                self,
                xs: torch.Tensor,
                pred1: torch.Tensor,
                pred2: torch.Tensor,
                y: torch.Tensor,
            ) -> torch.Tensor:
                y = torch.mm(y, y)
                return control_flow.map(map_fn, xs, pred1, pred2, y)

        f = Foo()

        inputs = (
            torch.ones(2, 2),
            torch.tensor([False]),
            torch.tensor([False]),
            torch.ones(2, 2),
        )

        ep = to_edge(
            export(
                f,
                inputs,
            )
        ).exported_program()
        graph_module = ep.graph_module

        def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:
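            # BFS over the top-level module and all control-flow submodules,
            # asserting that every node carries a debug handle.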
            queue = [graph_module]
            while queue:
                current_graph_module = queue.pop(0)
                for node in current_graph_module.graph.nodes:
                    self.assertIn("debug_handle", node.meta)
                control_flow_submodules = [
                    submodule
                    for _, submodule, _ in get_control_flow_submodules(
                        current_graph_module
                    )
                ]
                queue.extend(control_flow_submodules)

        DebugHandleGeneratorPass()(graph_module)
        check_debug_handle_metadata(graph_module)
        generate_missing_debug_handles(ep)

        # Check debug handle still preserved after ScalarToTensorPass
        ScalarToTensorPass()(graph_module)
        check_debug_handle_metadata(graph_module)

    def test_symint_conversion(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.add(x, x.shape[0] - 1)

        f = Foo()

        dim_x = torch.export.Dim("dim_x", max=3)
        prog = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
                dynamic_shapes={"x": {0: dim_x}},
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        prog = prog.transform([SymToTensorPass()])

        FileCheck().check("torch.ops.aten.scalar_tensor.default").run(
            prog.exported_program().graph_module.code
        )
        self.assertTrue(
            torch.allclose(
                f(torch.ones(3, 2)), prog.exported_program().module()(torch.ones(3, 2))
            )
        )
        self.assertTrue(
            torch.allclose(
                f(torch.zeros(3, 2)),
                prog.exported_program().module()(torch.zeros(3, 2)),
            )
        )

    def test_remove_assert_pass(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                assert x.shape[0] == 5
                return x * x

        f = Foo()

        gm = to_edge(
            export(
                f,
                (torch.randn(5),),
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        new_gm = gm.transform([RemoveGraphAssertsPass()])
        num_asserts = [
            node
            for node in new_gm.exported_program().graph.nodes
            if node.op == "call_function"
            and node.target == torch.ops.aten._assert_async.msg
        ]
        self.assertEqual(len(num_asserts), 0)

    def test_arange(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.ones(2)

            def forward(self, x):
                return torch.arange(start=0, end=2) + x

        _ = to_edge(
            export(
                M(),
                (torch.randn(2),),
            )
        ).to_executorch()

    def test_replace_slice(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.ones(10)

            def forward(self, x):
                return self.a[:2] + x

        gm = (
            to_edge(
                export(
                    M(),
                    (torch.randn(2),),
                )
            )
            .exported_program()
            .graph_module
        )
        FileCheck().check(
            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"
        ).run(gm.code)

    def test_constant_prop_pass_for_add(self) -> None:
        class Add(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + 3

        add = Add()

        edge = to_edge(
            export(add, (torch.ones(1),)),
            compile_config=EdgeCompileConfig(_skip_dim_order=False),
        )
        edge = edge.transform([ScalarToTensorPass(), RemoveMixedTypeOperators()])
        exported_program = lift_constant_tensor_pass(edge.exported_program())

        # Check there is a lifted tensor followed by a to_copy node
        FileCheck().check("_lifted_tensor_constant0").check(
            "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
        ).run(exported_program.graph_module.code)

        new_ep = constant_prop_pass(exported_program)

        # Check (_lifted_tensor_constant + to_copy) node is replaced by prop tensor
        FileCheck().check_not("_lifted_tensor_constant").check(
            "_prop_tensor_constant0"
        ).check_not(
            "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
        ).run(
            new_ep.graph_module.code
        )

    def test_constant_prop_pass_for_parameter(self) -> None:
        def count_additions(gm: torch.fx.GraphModule) -> int:
            return sum(
                (node.target == torch.ops.aten.add.Tensor) for node in gm.graph.nodes
            )

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.Parameter(torch.ones(1, 2, 3))

            def forward(self, x):
                b = self.a + self.a
                c = torch.cat([self.a, b])
                return (c + c) + x

        aten = export(
            M(),
            (torch.zeros(2, 2, 3),),
        )
        self.assertEqual(count_additions(aten.graph_module), 3)
        new_ep = constant_prop_pass(aten)
        self.assertEqual(count_additions(new_ep.graph_module), 1)

    def test_constant_prop_pass_graph_signature(self) -> None:
        def count_additions(gm: torch.fx.GraphModule) -> int:
            return sum(
                (node.target == torch.ops.aten.add.Tensor) for node in gm.graph.nodes
            )

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.Parameter(torch.ones(1, 2, 3))

            def forward(self, x):
                b = self.a + self.a
                c = torch.cat([self.a, b])
                return (c + c) + x

        aten = export(
            M(),
            (torch.zeros(2, 2, 3),),
        )
        # Input signature will have two entries:
        # (1) parameter `a` and (2) user input `x`.
        self.assertEqual(len(aten.graph_signature.input_specs), 2)
        new_ep = constant_prop_pass(aten)
        # Check that there are exactly two input specs: (1) the propagated
        # constant tensor and (2) the user input.
        self.assertEqual(
            new_ep.graph_signature.input_specs,
            [
                InputSpec(
                    kind=InputKind.CONSTANT_TENSOR,
                    arg=TensorArgument(name="_prop_tensor_constant0"),
                    target="_prop_tensor_constant0",
                    persistent=True,
                ),
                # User input graph signature.
                aten.graph_signature.input_specs[-1],
            ],
        )

    def test_constant_prop_pass_for_parameter_slice(self) -> None:
        def count_slice(gm: torch.fx.GraphModule) -> int:
            return sum(
                (node.target == torch.ops.aten.slice_copy.Tensor)
                for node in gm.graph.nodes
            )

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.Parameter(torch.ones(3, 2, 2))

            def forward(self, x):
                # Create slice of shape (1, 2, 2)
                slice_tensor = torch.slice_copy(self.a, dim=0, start=0, end=1)
                return torch.cat([x, slice_tensor])

        aten = export(
            M(),
            (torch.zeros(2, 2, 2),),
        )
        self.assertIn("a", aten.state_dict)
        self.assertEqual(count_slice(aten.graph_module), 1)

        new_ep = constant_prop_pass(aten)
        # Check there is a propagated tensor.
        FileCheck().check("_prop_tensor_constant0").run(aten.graph_module.code)
        self.assertIn("_prop_tensor_constant0", new_ep.constants)
        self.assertNotIn("a", new_ep.state_dict)
        # No more slice copy.
        self.assertEqual(count_slice(new_ep.graph_module), 0)

    def test_constant_prop_pass_no_propagate(self) -> None:
        def count_placeholder(gm: torch.fx.GraphModule) -> int:
            return sum((node.op == "placeholder") for node in gm.graph.nodes)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.Parameter(torch.ones(3, 2, 4))

            def forward(self, x, y):
                # y is unused.
                return x + self.a

        aten = export(
            M(),
            (torch.zeros(3, 2, 4), torch.zeros(3, 2, 4)),
        )
        self.assertIn("a", aten.state_dict)
        self.assertEqual(count_placeholder(aten.graph_module), 3)

        new_ep = constant_prop_pass(aten)
        # Check there is no propagated tensor.
        FileCheck().check("p_a").check("x").check("y").run(aten.graph_module.code)
        self.assertNotIn("_prop_tensor_constant0", new_ep.constants)
        self.assertIn("a", new_ep.state_dict)
        self.assertEqual(count_placeholder(new_ep.graph_module), 3)

    def test_constant_prop_pass_for_control_flow(self) -> None:
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def t(self, val):
                return val + 1

            def f(self, val):
                return val - 1

            def true_fn(self, val):
                return self.linear(val) + self.t(val)

            def false_fn(self, val):
                return self.linear(val) - self.f(val)

            def forward(self, pred, x):
                return torch.ops.higher_order.cond(
                    pred, self.true_fn, self.false_fn, [x]
                )

        mod = Module()
        x = torch.randn([3, 3])
        pred = torch.tensor(x[0][0].item() < 0)
        edge = to_edge(
            export(mod, (pred, x)),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        error_msg = r"constant_prop_pass for control flow is not supported yet."

        # TODO(chenlai): enable constant prop pass for control flow
        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = constant_prop_pass(edge.exported_program())

    def test_mutable_buffers(self) -> None:
        def count_copies(gm: torch.fx.GraphModule) -> int:
            return sum(
                (node.target == torch.ops.aten.copy_.default) for node in gm.graph.nodes
            )

        class MutableStateModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("state", torch.zeros(1))

            def forward(self, x):
                y = x + self.state
                self.state.add_(1)
                return y

        model = to_edge(
            export(
                MutableStateModule(),
                (torch.zeros(1),),
            )
        )
        self.assertEqual(count_copies(model.exported_program().graph_module), 0)
        # Before
        # graph():
        #     %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
        #     %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
        #     %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
        #     %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
        #     %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
        #     %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
        #     return (aten_add_tensor_1, aten_add_tensor)
        gm, _ = insert_write_back_for_buffers_pass(model.exported_program())

        # After
        # graph():
        #     %arg0_1 : [num_users=3] = placeholder[target=arg0_1]
        #     %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
        #     %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
        #     %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
        #     %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
        #     %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
        #     %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
        #     return (copy__default, aten_add_tensor)
        self.assertEqual(count_copies(gm), 1)

    def test_remove_quantized_op_noop_pass(self) -> None:
        class TestAddSliceNoop(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x + x
                x = x + x[:]
                return x

        class TestAddSliceNotNoop(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x + x
                x = x + x[:1]
                return x

        def count_dq_nodes(gm: torch.fx.GraphModule) -> int:
            return sum(
                (
                    node.target
                    in (
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                    )
                )
                for node in gm.graph.nodes
            )

        def count_q_nodes(gm: torch.fx.GraphModule) -> int:
            return sum(
                (
                    node.target
                    in (
                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
                        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
                    )
                )
                for node in gm.graph.nodes
            )

        def quantize_model(
            m_eager: torch.nn.Module, example_inputs: Tuple[torch.Tensor]
        ) -> Tuple[EdgeProgramManager, int, int]:
            # program capture
            m = torch.export.export_for_training(
                m_eager,
                example_inputs,
            ).module()

            quantizer = XNNPACKQuantizer()
            quantization_config = get_symmetric_quantization_config()
            quantizer.set_global(quantization_config)
            m = prepare_pt2e(m, quantizer)  # pyre-fixme[6]
            m = convert_pt2e(m, fold_quantize=True)
            ep = torch.export.export(m, example_inputs)
            dq_nodes_pre = count_dq_nodes(ep.graph_module)
            q_nodes_pre = count_q_nodes(ep.graph_module)
            edge = to_edge(
                ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)
            )
            return edge, dq_nodes_pre, q_nodes_pre

        example_inputs = (torch.randn(9, 8),)
        model = TestAddSliceNoop()
        m_eager = model.eval()
        edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs)

        dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
        q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
        # One dq and one q node around the slice copy should have been removed.
        self.assertEqual(dq_nodes_pre - dq_nodes_post, 1)
        self.assertEqual(q_nodes_pre - q_nodes_post, 1)

        # Check that the slice_copy is removed by the RemoveNoopPass.
        for node in edge.exported_program().graph_module.graph.nodes:
            self.assertFalse("slice" in str(node.target))

        model = TestAddSliceNotNoop()
        m_eager = model.eval()
        edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs)

        dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
        q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
1471        # Nothing should have been removed: the slice is not a no-op, so the dq/q nodes around it must stay.
1472        self.assertEqual(dq_nodes_pre, dq_nodes_post)
1473        self.assertEqual(q_nodes_pre, q_nodes_post)
1474
1475        # Check that the slice_copy is not removed by the RemoveNoopPass.
1476        self.assertTrue(
1477            any(
1478                "slice" in str(node.target)
1479                for node in edge.exported_program().graph_module.graph.nodes
1480            )
1481        )
1482
1483    def test_dq_q_no_op_pass(self) -> None:
1484        class TestDqQ(torch.nn.Module):
1485            def __init__(self):
1486                super().__init__()
1487
1488            def forward(self, x):
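                # A bare dq -> q round trip with identical qparams: RemoveNoopPass should
                # leave both nodes in place.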
1489                dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
1490                    x, 1.0, 0, -128, 127, torch.int8
1491                )
1492                q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
1493                    dq, 1.0, 0, -128, 127, torch.int8
1494                )
1495                return q
1496
1497        model = TestDqQ()
1498        m_eager = model.eval()
1499        ep = torch.export.export(m_eager, (torch.randn(9, 8),))
1500        edge = to_edge(ep)
1501        # Check that the dq and q nodes are not touched by the RemoveNoopPass.
1502        self.assertTrue(
1503            any(
1504                "dequantize" in str(node.target)
1505                for node in edge.exported_program().graph_module.graph.nodes
1506            )
1507        )
1508        self.assertTrue(
1509            any(
1510                "quantize" in str(node.target)
1511                for node in edge.exported_program().graph_module.graph.nodes
1512            )
1513        )
1514
1515    def test_dq_q_different_qparams(self) -> None:
1516        class TestDqQDifferentQParam(torch.nn.Module):
1517            def __init__(self):
1518                super().__init__()
1519
1520            def forward(self, x):
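                # dq -> slice_copy -> q where the q uses a different quant_min (-127 vs -128):
                # only the no-op slice should be removable, not the dq/q pair.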
1521                dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
1522                    x, 1.0, 0, -128, 127, torch.int8
1523                )
1524                slice_copy_output = torch.ops.aten.slice_copy.Tensor(dq, 0, 0)
1525                q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
1526                    slice_copy_output, 1.0, 0, -127, 127, torch.int8
1527                )
1528                return q
1529
1530        model = TestDqQDifferentQParam()
1531        m_eager = model.eval()
1532        ep = torch.export.export(m_eager, (torch.randn(9, 8),))
1533        edge = to_edge(ep)
1535        # Check that the dq and q nodes are not touched by the RemoveNoopPass (their qparams differ), while the no-op slice_copy between them is still removed.
1536        self.assertTrue(
1537            any(
1538                "dequantize" in str(node.target)
1539                for node in edge.exported_program().graph_module.graph.nodes
1540            )
1541        )
1542        self.assertTrue(
1543            any(
1544                "quantize" in str(node.target)
1545                for node in edge.exported_program().graph_module.graph.nodes
1546            )
1547        )
1548        self.assertFalse(
1549            any(
1550                "slice" in str(node.target)
1551                for node in edge.exported_program().graph_module.graph.nodes
1552            )
1553        )
1554
1555    def test_normalize_view_copy_base_pass(self) -> None:
1556
1557        class ViewChain(torch.nn.Module):
1558            def forward(self, x):
1559                x = torch.ops.aten.view_copy.default(x, [30, 1])
1560                x = torch.ops.aten.view_copy.default(x, [5, 6])
1561                x = torch.ops.aten.view_copy.default(x, [2, 15])
1562                x = torch.ops.aten.view_copy.default(x, [3, -1])
1563                return x
1564
1565        def is_view_copy(node: torch.fx.Node) -> bool:
1566            return (
1567                node.op == "call_function"
1568                and node.target == torch.ops.aten.view_copy.default
1569            )
1570
1571        gm = export(ViewChain(), (torch.ones(30),)).graph_module
1572
1573        # Check before transformation
1574        n_view_copy_before = 0
1575        n_view_copy_bases_before = 0
1576        for node in gm.graph.nodes:
1577            if is_view_copy(node):
1578                n_view_copy_before += 1
1579                base = node.args[0]
1580                if is_view_copy(base):
1581                    n_view_copy_bases_before += 1
1582
1583        self.assertEqual(n_view_copy_before, 4)
1584        self.assertEqual(n_view_copy_bases_before, 3)
1585
1586        # Do transformation
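        # NormalizeViewCopyBasePass repoints each view_copy at the first non-view_copy base in
        # the chain, so afterwards no view_copy should take another view_copy as its base.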
1587        p = NormalizeViewCopyBasePass()
1588        gm_res = p(gm)
1589        assert gm_res is not None
1590        gm = gm_res.graph_module
1591
1592        # Check after transformation
1593        n_view_copy_after = 0
1594        n_view_copy_bases_after = 0
1595        for node in gm.graph.nodes:
1596            if is_view_copy(node):
1597                n_view_copy_after += 1
1598                base = node.args[0]
1599                if is_view_copy(base):
1600                    n_view_copy_bases_after += 1
1601
1602        self.assertEqual(n_view_copy_after, 4)
1603        self.assertEqual(n_view_copy_bases_after, 0)
1604
1605    def test_replace_view_copy_with_view_pass(self) -> None:  # noqa: C901
1606
1607        # Helper functions
1608        def is_view_copy(node: torch.fx.Node) -> bool:
1609            return (
1610                node.op == "call_function"
1611                and node.target == torch.ops.aten.view_copy.default
1612            )
1613
1614        def is_memory_view(node: torch.fx.Node) -> bool:
1615            return node.op == "call_function" and node.target == memory.view
1616
1617        # Test example set up
1618        class TestViewCopies(torch.nn.Module):
1619            def __init__(self):
1620                super().__init__()
1621                self.parameter = torch.nn.Parameter(torch.ones(1))
1622
1623            def forward(self, x):
1624                o1 = torch.ops.aten.view_copy.default(x, [1])
1625                o2 = torch.ops.aten.view_copy.default(self.parameter, [1])
1626                # view_copy ops at the end of the graph are not replaced, so add
1627                # a computation before the end of the graph.
1628                return torch.ops.aten.add.Tensor(o1, o2)
1629
1630        ep = torch.export.export(
1631            TestViewCopies(),
1632            args=(torch.ones(1),),
1633        )
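        # Stamp static TensorSpecs onto the placeholders so the pass can reason about the
        # (static) shapes involved.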
1634        for node in ep.graph.nodes:
1635            if node.op == "placeholder":
1636                node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1))
1637                node.meta["spec"].shape_dynamism = TensorShapeDynamism.STATIC
1638
1639        # Run tests
1640        gm = ep.graph_module
1641
1642        # Check before transformation
1643        FileCheck().check_count(
1644            "torch.ops.aten.view_copy.default", 2, exactly=True
1645        ).run(gm.code)
1646        FileCheck().check_count("executorch_exir_memory_view", 0, exactly=True).run(
1647            gm.code
1648        )
1649
1650        # Do transformation
1651        p = ReplaceViewCopyWithViewPass()
1652        gm_res = p(gm)
1653        assert gm_res is not None
1654        gm = gm_res.graph_module
1655
1656        # Check after transformation
1657        FileCheck().check_count(
1658            "torch.ops.aten.view_copy.default", 0, exactly=True
1659        ).run(gm.code)
1660        FileCheck().check_count("executorch_exir_memory_view", 2, exactly=True).run(
1661            gm.code
1662        )
1663
1664    def test_constant_prop_pass_for_no_grad(self) -> None:
1665        class LSTM(torch.nn.Module):
1666            def __init__(self, input_size, hidden_size, num_layers):
1667                super(LSTM, self).__init__()
1668                self.hidden_size = hidden_size
1669                self.num_layers = num_layers
1670                self.lstm = torch.nn.LSTM(
1671                    input_size, hidden_size, num_layers, batch_first=True
1672                )
1673
1674            def forward(self, text_tokens):
1675                # input: (batch, seq_len, input_size) since batch_first=True
1676                lstm_out, (new_hidden_state, new_cell_state) = self.lstm(
1677                    input=text_tokens, hx=None
1678                )
1679                return lstm_out
1680
1681        lstm = LSTM(input_size=200, hidden_size=203, num_layers=2)
1682        example_input = (torch.rand(2, 10, 200),)
1683
1684        aten = torch.export.export(lstm, example_input, strict=False)
1685        _EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
1686            _check_ir_validity=True,
1687            _skip_dim_order=True,  # TODO(T189114319): Reuse dim order op after solving the iOS OSS issue
1688        )
1689
1690        edge_manager: EdgeProgramManager = to_edge(
1691            aten,
1692            compile_config=_EDGE_COMPILE_CONFIG,
1693        )
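        # constant_prop_pass should run cleanly on this non-strict LSTM export, and the
        # resulting program's module_call_graph should still be deep-copyable.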
1694        new_ep = constant_prop_pass(edge_manager._edge_programs["forward"])
1695        _ = copy.deepcopy(new_ep.module_call_graph)
1696
1697    def test_dim_order_revert_pass(self) -> None:
1698        aten_op_str = "torch.ops.aten._to_copy.default"
1699        edge_aten_op_str = "executorch_exir_dialects_edge__ops_aten__to_copy_default"
1700        edge_dim_order_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
1701
1702        class Module(torch.nn.Module):
1703            """
1704            A simple module whose forward converts its input to channels last and back to contiguous twice,
1705            producing four _to_copy ops in total. Assumes a contiguous input.
1706            """
1707
1708            def __init__(self):
1709                super().__init__()
1710
1711            def forward(self, x: torch.Tensor) -> torch.Tensor:
1712                return x.to(memory_format=torch.channels_last).to(
1713                    memory_format=torch.contiguous_format
1714                ) + x.to(memory_format=torch.channels_last).to(
1715                    memory_format=torch.contiguous_format
1716                )
1717
1718            @staticmethod
1719            def to_copy_count():
1720                return 4
1721
1722        def _do_checks(
1723            test_str: str, allowed: str, allowed_count: int, not_allowed_list: List[str]
1724        ) -> None:
1725            for not_allowed in not_allowed_list:
1726                FileCheck().check_count(allowed, allowed_count, exactly=True).check_not(
1727                    not_allowed
1728                ).run(test_str)
1729
1730        m = Module()
1731        n = m.to_copy_count()
1732        input = torch.randn([2, 3, 4, 5]).to(memory_format=torch.contiguous_format)
1733
1734        # 1. vanilla export, no edge ops
1735        ep = export(
1736            m,
1737            (input,),
1738        ).run_decompositions({})
1739        _do_checks(
1740            ep.graph_module.code,
1741            aten_op_str,
1742            n,
1743            [edge_aten_op_str, edge_dim_order_op_str],
1744        )
1745
1746        # 2a. to edge without dim orders, we should see edge aten ops but not dim order ops
1747        edge_prog = to_edge(
1748            ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=True)
1749        )._edge_programs["forward"]
1750        _do_checks(
1751            edge_prog.graph_module.code,
1752            edge_aten_op_str,
1753            n,
1754            [aten_op_str, edge_dim_order_op_str],
1755        )
1756
1757        # 3a. expect no change after the pass, we should see edge aten ops but not dim order ops
1758        new_res = DimOrderOpsRevertPass()(edge_prog.graph_module)
1759        self.assertIsNotNone(new_res)
1760        _do_checks(
1761            new_res.graph_module.code,
1762            edge_aten_op_str,
1763            n,
1764            [aten_op_str, edge_dim_order_op_str],
1765        )
1766
1767        # 2b. let's try with dim order enabled, we should see edge dim order ops but not edge aten ops
1768        edge_prog_dim_order = to_edge(
1769            ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=False)
1770        )._edge_programs["forward"]
1771        _do_checks(
1772            edge_prog_dim_order.graph_module.code,
1773            edge_dim_order_op_str,
1774            n,
1775            [aten_op_str, edge_aten_op_str],
1776        )
1777
1778        # 3b. expect edge aten ops after the pass; we should not see the edge dim order ops
1779        new_res_dim_order = DimOrderOpsRevertPass()(edge_prog_dim_order.graph_module)
1780        self.assertIsNotNone(new_res_dim_order)
1781        _do_checks(
1782            new_res_dim_order.graph_module.code,
1783            edge_aten_op_str,
1784            n,
1785            [aten_op_str, edge_dim_order_op_str],
1786        )
1787
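        # 4. both lowerings should produce numerically identical results on the same input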
1788        output_no_dim_order = new_res.graph_module(input)
1789        output_no_dim_order_revert = new_res_dim_order.graph_module(input)
1790        self.assertTrue(
1791            torch.allclose(output_no_dim_order[0], output_no_dim_order_revert[0])
1792        )
1793