xref: /aosp_15_r20/external/executorch/exir/tests/test_passes.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Workerimport copy
9*523fa7a6SAndroid Build Coastguard Workerimport os
10*523fa7a6SAndroid Build Coastguard Workerimport tempfile
11*523fa7a6SAndroid Build Coastguard Workerimport unittest
12*523fa7a6SAndroid Build Coastguard Workerfrom typing import List, Optional, Tuple
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir as exir
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Worker# Import passes
17*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir.memory_planning  # noqa
18*523fa7a6SAndroid Build Coastguard Workerimport torch
19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge
20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects.edge._ops import EdgeOpOverload
22*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.emit import emit_program
23*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.graph_module import get_control_flow_submodules
24*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.pass_base import ExportPass, PassResult
25*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes import (
26*523fa7a6SAndroid Build Coastguard Worker    dead_code_elimination_pass,
27*523fa7a6SAndroid Build Coastguard Worker    DebugPass,
28*523fa7a6SAndroid Build Coastguard Worker    HintBasedSymShapeEvalPass,
29*523fa7a6SAndroid Build Coastguard Worker    MemoryPlanningPass,
30*523fa7a6SAndroid Build Coastguard Worker    propagate_dynamic_shape,
31*523fa7a6SAndroid Build Coastguard Worker    RemoveNoopPass,
32*523fa7a6SAndroid Build Coastguard Worker    ReplaceSymSizeOpPass,
33*523fa7a6SAndroid Build Coastguard Worker    ToOutVarPass,
34*523fa7a6SAndroid Build Coastguard Worker)
35*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.constant_prop_pass import constant_prop_pass
36*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.debug_handle_generator_pass import (
37*523fa7a6SAndroid Build Coastguard Worker    DebugHandleGeneratorPass,
38*523fa7a6SAndroid Build Coastguard Worker    generate_missing_debug_handles,
39*523fa7a6SAndroid Build Coastguard Worker)
40*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.insert_write_back_for_buffers_pass import (
41*523fa7a6SAndroid Build Coastguard Worker    insert_write_back_for_buffers_pass,
42*523fa7a6SAndroid Build Coastguard Worker)
43*523fa7a6SAndroid Build Coastguard Worker
44*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
45*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.normalize_view_copy_base_pass import (
46*523fa7a6SAndroid Build Coastguard Worker    NormalizeViewCopyBasePass,
47*523fa7a6SAndroid Build Coastguard Worker)
48*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
49*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
50*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass
51*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.replace_view_copy_with_view_pass import (
52*523fa7a6SAndroid Build Coastguard Worker    ReplaceViewCopyWithViewPass,
53*523fa7a6SAndroid Build Coastguard Worker)
54*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
55*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.spec_prop_pass import SpecPropPass
56*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
57*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.program._program import lift_constant_tensor_pass
58*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.schema import TensorShapeDynamism
59*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tensor import TensorSpec
60*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tests.common import register_additional_test_aten_ops
61*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic
62*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tests.models import MLP, Mul
63*523fa7a6SAndroid Build Coastguard Workerfrom functorch.experimental import control_flow
64*523fa7a6SAndroid Build Coastguard Worker
65*523fa7a6SAndroid Build Coastguard Workerfrom torch import nn
66*523fa7a6SAndroid Build Coastguard Worker
67*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
68*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.quantizer.xnnpack_quantizer import (
69*523fa7a6SAndroid Build Coastguard Worker    get_symmetric_quantization_config,
70*523fa7a6SAndroid Build Coastguard Worker    XNNPACKQuantizer,
71*523fa7a6SAndroid Build Coastguard Worker)
72*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import export
73*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.graph_signature import InputKind, InputSpec, TensorArgument
74*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx import GraphModule, subgraph_rewriter
75*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.experimental.proxy_tensor import make_fx
76*523fa7a6SAndroid Build Coastguard Workerfrom torch.library import impl, Library
77*523fa7a6SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
78*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree
79*523fa7a6SAndroid Build Coastguard Worker
80*523fa7a6SAndroid Build Coastguard Worker
# pyre-ignore
def collect_ops(gm: torch.fx.GraphModule):
    """
    Collect all targets for call_function nodes from the graph module recursively.

    Args:
        gm: Root graph module. ``gm.modules()`` also yields registered
            submodules, so control-flow submodules are visited too.

    Returns:
        The set of every ``call_function`` target found in ``gm`` and in any
        ``GraphModule`` submodule.
    """
    # Named ``targets`` rather than ``ops`` to avoid shadowing the
    # module-level ``ops`` imported from executorch.exir.dialects._ops.
    targets = set()
    for subgm in gm.modules():
        # ``modules()`` can yield non-GraphModule children; skip those.
        if not isinstance(subgm, torch.fx.GraphModule):
            continue
        targets.update(
            node.target for node in subgm.graph.nodes if node.op == "call_function"
        )
    return targets
94*523fa7a6SAndroid Build Coastguard Worker
95*523fa7a6SAndroid Build Coastguard Worker
# Test-only operator library; schemas must be defined before their impls below.
lib = Library("DO_NOT_USE_TEST_ONLY", "DEF")

# Functional op with two outputs; the second is optional in practice (the
# impl below always returns None for it).
lib.define("foo(Tensor self) -> (Tensor, Tensor)")
# Schema only at this point in the file; implementation/use not shown here.
lib.define("add_relu(Tensor self, Tensor other) -> Tensor")


@impl(lib, "foo", "CompositeExplicitAutograd")
def foo(a: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Second result is deliberately None so out-variant passes must handle
    # an absent optional output (exercised by test_to_out_variant_none_output).
    return a + 1, None


# Out-variant schema for foo: results are written into caller-provided
# buffers out1/out2 (both mutable, aliased to the returns).
lib.define(
    "foo.out(Tensor self, *, Tensor(a!) out1, Tensor(b!) out2) -> (Tensor(a!), Tensor(b!))"
)


@impl(lib, "foo.out", "CompositeExplicitAutograd")
def foo_out(
    a: torch.Tensor, out1: torch.Tensor, out2: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # NOTE(review): ignores out1/out2 and returns fresh tensors — presumably
    # sufficient for graph-level tests; confirm if the op is ever executed.
    return a + 1, None
117*523fa7a6SAndroid Build Coastguard Worker
118*523fa7a6SAndroid Build Coastguard Worker
119*523fa7a6SAndroid Build Coastguard Workerclass TestPasses(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        # One-time registration of extra ATen ops before any test in this
        # class runs.
        register_additional_test_aten_ops()
123*523fa7a6SAndroid Build Coastguard Worker
124*523fa7a6SAndroid Build Coastguard Worker    def test_remove_mixed_type_operators(self) -> None:
125*523fa7a6SAndroid Build Coastguard Worker        class Add(torch.nn.Module):
126*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
127*523fa7a6SAndroid Build Coastguard Worker                return (x + y) + x
128*523fa7a6SAndroid Build Coastguard Worker
129*523fa7a6SAndroid Build Coastguard Worker        add = Add()
130*523fa7a6SAndroid Build Coastguard Worker
131*523fa7a6SAndroid Build Coastguard Worker        int_tensor = torch.tensor([[1, 2, 3]])
132*523fa7a6SAndroid Build Coastguard Worker        float_tensor = torch.tensor([[1.0, 2.0, 3.0]])
133*523fa7a6SAndroid Build Coastguard Worker        edge_prog = to_edge(
134*523fa7a6SAndroid Build Coastguard Worker            export(
135*523fa7a6SAndroid Build Coastguard Worker                add,
136*523fa7a6SAndroid Build Coastguard Worker                (int_tensor, float_tensor),
137*523fa7a6SAndroid Build Coastguard Worker            )
138*523fa7a6SAndroid Build Coastguard Worker        )
139*523fa7a6SAndroid Build Coastguard Worker
140*523fa7a6SAndroid Build Coastguard Worker        new_prog = edge_prog.transform([RemoveMixedTypeOperators()])
141*523fa7a6SAndroid Build Coastguard Worker        new_graph_module = new_prog.exported_program().graph_module
142*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_graph_module)
143*523fa7a6SAndroid Build Coastguard Worker
144*523fa7a6SAndroid Build Coastguard Worker        add_count = 0
145*523fa7a6SAndroid Build Coastguard Worker
146*523fa7a6SAndroid Build Coastguard Worker        for node in new_graph_module.graph.nodes:
147*523fa7a6SAndroid Build Coastguard Worker            if (
148*523fa7a6SAndroid Build Coastguard Worker                node.op == "call_function"
149*523fa7a6SAndroid Build Coastguard Worker                and node.target == exir_ops.edge.aten.add.Tensor
150*523fa7a6SAndroid Build Coastguard Worker            ):
151*523fa7a6SAndroid Build Coastguard Worker                add_count += 1
152*523fa7a6SAndroid Build Coastguard Worker                node_args = node.args
153*523fa7a6SAndroid Build Coastguard Worker                for arg in node_args:
154*523fa7a6SAndroid Build Coastguard Worker                    self.assertEqual(arg.meta["val"].dtype, torch.float)
155*523fa7a6SAndroid Build Coastguard Worker
156*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(add_count, 2)
157*523fa7a6SAndroid Build Coastguard Worker
158*523fa7a6SAndroid Build Coastguard Worker        double_tensor = torch.tensor([[1.0, 2.0, 3.0]])
159*523fa7a6SAndroid Build Coastguard Worker        double_tensor = double_tensor.to(torch.double)
160*523fa7a6SAndroid Build Coastguard Worker
161*523fa7a6SAndroid Build Coastguard Worker        double_prog = to_edge(export(add, (int_tensor, double_tensor)))
162*523fa7a6SAndroid Build Coastguard Worker
163*523fa7a6SAndroid Build Coastguard Worker        double_prog.transform([RemoveMixedTypeOperators()])
164*523fa7a6SAndroid Build Coastguard Worker        new_graph_module_double = double_prog.exported_program().graph_module
165*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_graph_module_double)
166*523fa7a6SAndroid Build Coastguard Worker
167*523fa7a6SAndroid Build Coastguard Worker        add_count_double = 0
168*523fa7a6SAndroid Build Coastguard Worker
169*523fa7a6SAndroid Build Coastguard Worker        for node in new_graph_module_double.graph.nodes:
170*523fa7a6SAndroid Build Coastguard Worker            if (
171*523fa7a6SAndroid Build Coastguard Worker                node.op == "call_function"
172*523fa7a6SAndroid Build Coastguard Worker                and node.target == exir_ops.edge.aten.add.Tensor
173*523fa7a6SAndroid Build Coastguard Worker            ):
174*523fa7a6SAndroid Build Coastguard Worker                add_count_double += 1
175*523fa7a6SAndroid Build Coastguard Worker                node_args = node.args
176*523fa7a6SAndroid Build Coastguard Worker                for arg in node_args:
177*523fa7a6SAndroid Build Coastguard Worker                    self.assertEqual(arg.meta["val"].dtype, torch.double)
178*523fa7a6SAndroid Build Coastguard Worker
179*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(add_count_double, 2)
180*523fa7a6SAndroid Build Coastguard Worker
181*523fa7a6SAndroid Build Coastguard Worker        class Mult(torch.nn.Module):
182*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
183*523fa7a6SAndroid Build Coastguard Worker                return x * y
184*523fa7a6SAndroid Build Coastguard Worker
185*523fa7a6SAndroid Build Coastguard Worker        mult = Mult()
186*523fa7a6SAndroid Build Coastguard Worker
187*523fa7a6SAndroid Build Coastguard Worker        float_tensor_vert = float_tensor.T
188*523fa7a6SAndroid Build Coastguard Worker        mult_prog = to_edge(
189*523fa7a6SAndroid Build Coastguard Worker            export(
190*523fa7a6SAndroid Build Coastguard Worker                mult,
191*523fa7a6SAndroid Build Coastguard Worker                (int_tensor, float_tensor_vert),
192*523fa7a6SAndroid Build Coastguard Worker            )
193*523fa7a6SAndroid Build Coastguard Worker        )
194*523fa7a6SAndroid Build Coastguard Worker
195*523fa7a6SAndroid Build Coastguard Worker        # graph_module_mult.graph.print_tabular()
196*523fa7a6SAndroid Build Coastguard Worker
197*523fa7a6SAndroid Build Coastguard Worker        mult_prog = mult_prog.transform([RemoveMixedTypeOperators()])
198*523fa7a6SAndroid Build Coastguard Worker        new_graph_module_mult = mult_prog.exported_program().graph_module
199*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_graph_module_mult)
200*523fa7a6SAndroid Build Coastguard Worker
201*523fa7a6SAndroid Build Coastguard Worker        mult_count = 0
202*523fa7a6SAndroid Build Coastguard Worker
203*523fa7a6SAndroid Build Coastguard Worker        for node in new_graph_module_mult.graph.nodes:
204*523fa7a6SAndroid Build Coastguard Worker            if (
205*523fa7a6SAndroid Build Coastguard Worker                node.op == "call_function"
206*523fa7a6SAndroid Build Coastguard Worker                and node.target == exir_ops.edge.aten.mul.Tensor
207*523fa7a6SAndroid Build Coastguard Worker            ):
208*523fa7a6SAndroid Build Coastguard Worker                mult_count += 1
209*523fa7a6SAndroid Build Coastguard Worker                node_args = node.args
210*523fa7a6SAndroid Build Coastguard Worker                for arg in node_args:
211*523fa7a6SAndroid Build Coastguard Worker                    self.assertEqual(arg.meta["val"].dtype, torch.float)
212*523fa7a6SAndroid Build Coastguard Worker
213*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(mult_count, 1)
214*523fa7a6SAndroid Build Coastguard Worker
215*523fa7a6SAndroid Build Coastguard Worker    def test_remove_noop_pass(self) -> None:
216*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
217*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
218*523fa7a6SAndroid Build Coastguard Worker                return x.to(dtype=torch.float32)
219*523fa7a6SAndroid Build Coastguard Worker
220*523fa7a6SAndroid Build Coastguard Worker        foo = Foo()
221*523fa7a6SAndroid Build Coastguard Worker
222*523fa7a6SAndroid Build Coastguard Worker        # Turn off functionalization so that we can get the actual to.dtype op
223*523fa7a6SAndroid Build Coastguard Worker        edge_prog = to_edge(
224*523fa7a6SAndroid Build Coastguard Worker            export(
225*523fa7a6SAndroid Build Coastguard Worker                foo,
226*523fa7a6SAndroid Build Coastguard Worker                (torch.ones(1, dtype=torch.float32),),
227*523fa7a6SAndroid Build Coastguard Worker            )
228*523fa7a6SAndroid Build Coastguard Worker        )
229*523fa7a6SAndroid Build Coastguard Worker        edge_prog = edge_prog.transform([RemoveNoopPass()])
230*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(edge_prog.exported_program().graph_module)
231*523fa7a6SAndroid Build Coastguard Worker        new_graph_module = edge_prog.exported_program().graph_module
232*523fa7a6SAndroid Build Coastguard Worker        for node in new_graph_module.graph.nodes:
233*523fa7a6SAndroid Build Coastguard Worker            if node.op == "call_function":
234*523fa7a6SAndroid Build Coastguard Worker                self.assertNotEqual(node.target, torch.ops.aten.to.dtype)
235*523fa7a6SAndroid Build Coastguard Worker
236*523fa7a6SAndroid Build Coastguard Worker    def test_redundant_slice_copy_removal(self) -> None:
237*523fa7a6SAndroid Build Coastguard Worker        class FooWithNoSlice(torch.nn.Module):
238*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
239*523fa7a6SAndroid Build Coastguard Worker                return x[:, :, :]
240*523fa7a6SAndroid Build Coastguard Worker
241*523fa7a6SAndroid Build Coastguard Worker        foo_with_no_slice = FooWithNoSlice()
242*523fa7a6SAndroid Build Coastguard Worker
243*523fa7a6SAndroid Build Coastguard Worker        class FooWithOneSlice(torch.nn.Module):
244*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
245*523fa7a6SAndroid Build Coastguard Worker                return x[:1, :, :]
246*523fa7a6SAndroid Build Coastguard Worker
247*523fa7a6SAndroid Build Coastguard Worker        foo_with_one_slice = FooWithOneSlice()
248*523fa7a6SAndroid Build Coastguard Worker
249*523fa7a6SAndroid Build Coastguard Worker        class FooWithAllSlices(torch.nn.Module):
250*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
251*523fa7a6SAndroid Build Coastguard Worker                return x[:1, :2, 2:4]
252*523fa7a6SAndroid Build Coastguard Worker
253*523fa7a6SAndroid Build Coastguard Worker        foo_with_all_slices = FooWithAllSlices()
254*523fa7a6SAndroid Build Coastguard Worker
255*523fa7a6SAndroid Build Coastguard Worker        # Turn off functionalization so that we can get the actual to.dtype op
256*523fa7a6SAndroid Build Coastguard Worker        x = torch.ones((3, 8, 8))
257*523fa7a6SAndroid Build Coastguard Worker        prog = to_edge(
258*523fa7a6SAndroid Build Coastguard Worker            export(
259*523fa7a6SAndroid Build Coastguard Worker                foo_with_no_slice,
260*523fa7a6SAndroid Build Coastguard Worker                (x,),
261*523fa7a6SAndroid Build Coastguard Worker            )
262*523fa7a6SAndroid Build Coastguard Worker        )
263*523fa7a6SAndroid Build Coastguard Worker        prog = prog.transform([RemoveNoopPass()])
264*523fa7a6SAndroid Build Coastguard Worker        new_graph_module = prog.exported_program().graph_module
265*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check_count(
266*523fa7a6SAndroid Build Coastguard Worker            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 0, exactly=True
267*523fa7a6SAndroid Build Coastguard Worker        ).run(new_graph_module.code)
268*523fa7a6SAndroid Build Coastguard Worker
269*523fa7a6SAndroid Build Coastguard Worker        prog = to_edge(
270*523fa7a6SAndroid Build Coastguard Worker            export(
271*523fa7a6SAndroid Build Coastguard Worker                foo_with_one_slice,
272*523fa7a6SAndroid Build Coastguard Worker                (x,),
273*523fa7a6SAndroid Build Coastguard Worker            )
274*523fa7a6SAndroid Build Coastguard Worker        )
275*523fa7a6SAndroid Build Coastguard Worker        prog = prog.transform([RemoveNoopPass()])
276*523fa7a6SAndroid Build Coastguard Worker        new_graph_module = prog.exported_program().graph_module
277*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check_count(
278*523fa7a6SAndroid Build Coastguard Worker            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 1, exactly=True
279*523fa7a6SAndroid Build Coastguard Worker        ).run(new_graph_module.code)
280*523fa7a6SAndroid Build Coastguard Worker
281*523fa7a6SAndroid Build Coastguard Worker        prog = to_edge(
282*523fa7a6SAndroid Build Coastguard Worker            export(
283*523fa7a6SAndroid Build Coastguard Worker                foo_with_all_slices,
284*523fa7a6SAndroid Build Coastguard Worker                (x,),
285*523fa7a6SAndroid Build Coastguard Worker            )
286*523fa7a6SAndroid Build Coastguard Worker        )
287*523fa7a6SAndroid Build Coastguard Worker        prog = prog.transform([RemoveNoopPass()])
288*523fa7a6SAndroid Build Coastguard Worker        new_graph_module = prog.exported_program().graph_module
289*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check_count(
290*523fa7a6SAndroid Build Coastguard Worker            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 3, exactly=True
291*523fa7a6SAndroid Build Coastguard Worker        ).run(new_graph_module.code)
292*523fa7a6SAndroid Build Coastguard Worker
293*523fa7a6SAndroid Build Coastguard Worker    def test_compile_to_edge(self) -> None:
294*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
295*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
296*523fa7a6SAndroid Build Coastguard Worker                return x * 2
297*523fa7a6SAndroid Build Coastguard Worker
298*523fa7a6SAndroid Build Coastguard Worker        f = Foo()
299*523fa7a6SAndroid Build Coastguard Worker
300*523fa7a6SAndroid Build Coastguard Worker        x = (torch.randn(2, 3),)
301*523fa7a6SAndroid Build Coastguard Worker
302*523fa7a6SAndroid Build Coastguard Worker        to_edge(
303*523fa7a6SAndroid Build Coastguard Worker            export(
304*523fa7a6SAndroid Build Coastguard Worker                f,
305*523fa7a6SAndroid Build Coastguard Worker                x,
306*523fa7a6SAndroid Build Coastguard Worker            )
307*523fa7a6SAndroid Build Coastguard Worker        ).exported_program().graph_module
308*523fa7a6SAndroid Build Coastguard Worker        # TODO(angelayi): Add a utility function that verifies a model is in
309*523fa7a6SAndroid Build Coastguard Worker        # the edge dialect
310*523fa7a6SAndroid Build Coastguard Worker
311*523fa7a6SAndroid Build Coastguard Worker    def test_to_out_variant_none_output(self) -> None:
312*523fa7a6SAndroid Build Coastguard Worker        class CompositeModel(torch.nn.Module):
313*523fa7a6SAndroid Build Coastguard Worker            def __init__(self, _weight):
314*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
315*523fa7a6SAndroid Build Coastguard Worker                self.weight = _weight
316*523fa7a6SAndroid Build Coastguard Worker                self.lstm = torch.nn.LSTM(
317*523fa7a6SAndroid Build Coastguard Worker                    input_size=32,
318*523fa7a6SAndroid Build Coastguard Worker                    hidden_size=32,
319*523fa7a6SAndroid Build Coastguard Worker                    num_layers=1,
320*523fa7a6SAndroid Build Coastguard Worker                )
321*523fa7a6SAndroid Build Coastguard Worker
322*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x_raw, h, c):
323*523fa7a6SAndroid Build Coastguard Worker                output, (hn, cn) = self.lstm(x_raw, (h, c))
324*523fa7a6SAndroid Build Coastguard Worker                return output
325*523fa7a6SAndroid Build Coastguard Worker
326*523fa7a6SAndroid Build Coastguard Worker        # Prepare input and trace it
327*523fa7a6SAndroid Build Coastguard Worker        input_x = torch.ones([1, 32])
328*523fa7a6SAndroid Build Coastguard Worker        input_h = torch.ones([1, 32])
329*523fa7a6SAndroid Build Coastguard Worker        input_c = torch.ones([1, 32])
330*523fa7a6SAndroid Build Coastguard Worker        inputs = (input_x, input_h, input_c)
331*523fa7a6SAndroid Build Coastguard Worker
332*523fa7a6SAndroid Build Coastguard Worker        composite_m = CompositeModel(3)
333*523fa7a6SAndroid Build Coastguard Worker
334*523fa7a6SAndroid Build Coastguard Worker        edge_prog = to_edge(
335*523fa7a6SAndroid Build Coastguard Worker            export(
336*523fa7a6SAndroid Build Coastguard Worker                composite_m,
337*523fa7a6SAndroid Build Coastguard Worker                inputs,
338*523fa7a6SAndroid Build Coastguard Worker            )
339*523fa7a6SAndroid Build Coastguard Worker            # torch._ops.aten.t.default
340*523fa7a6SAndroid Build Coastguard Worker            ,
341*523fa7a6SAndroid Build Coastguard Worker            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
342*523fa7a6SAndroid Build Coastguard Worker        )
343*523fa7a6SAndroid Build Coastguard Worker
344*523fa7a6SAndroid Build Coastguard Worker        new_prog = edge_prog.transform([SpecPropPass()])
345*523fa7a6SAndroid Build Coastguard Worker
346*523fa7a6SAndroid Build Coastguard Worker        new_gm_res = ToOutVarPass()(new_prog.exported_program().graph_module)
347*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_gm_res)
348*523fa7a6SAndroid Build Coastguard Worker        new_gm = new_gm_res.graph_module
349*523fa7a6SAndroid Build Coastguard Worker        for node in new_gm.graph.nodes:
350*523fa7a6SAndroid Build Coastguard Worker            if node.op == "call_function" and node.target in [
351*523fa7a6SAndroid Build Coastguard Worker                torch.ops.DO_NOT_USE_TEST_ONLY.foo.out,
352*523fa7a6SAndroid Build Coastguard Worker                torch.ops.my_awesome_3rdparty_ns.awesome_op.out,
353*523fa7a6SAndroid Build Coastguard Worker            ]:
354*523fa7a6SAndroid Build Coastguard Worker                self.assertEqual(len(node.kwargs), 2)
355*523fa7a6SAndroid Build Coastguard Worker                out1_node = node.kwargs["out1"]
356*523fa7a6SAndroid Build Coastguard Worker                self.assertEqual(out1_node.op, "call_function")
357*523fa7a6SAndroid Build Coastguard Worker                self.assertIs(out1_node.target, memory.alloc)
358*523fa7a6SAndroid Build Coastguard Worker                self.assertIs(node.kwargs["out2"], None)
359*523fa7a6SAndroid Build Coastguard Worker
360*523fa7a6SAndroid Build Coastguard Worker        new_gm_res = MemoryPlanningPass()(new_gm)
361*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_gm_res)
362*523fa7a6SAndroid Build Coastguard Worker        new_gm = new_gm_res.graph_module
363*523fa7a6SAndroid Build Coastguard Worker        new_prog.exported_program().graph_module.graph = new_gm.graph
364*523fa7a6SAndroid Build Coastguard Worker        emit_program(new_prog.exported_program())
365*523fa7a6SAndroid Build Coastguard Worker
366*523fa7a6SAndroid Build Coastguard Worker    def test_to_out_variant_singleon_tensor_list(self) -> None:
367*523fa7a6SAndroid Build Coastguard Worker        class MyModel(nn.Module):
368*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
369*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
370*523fa7a6SAndroid Build Coastguard Worker
371*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
372*523fa7a6SAndroid Build Coastguard Worker                return torch.split(x, 10)
373*523fa7a6SAndroid Build Coastguard Worker
374*523fa7a6SAndroid Build Coastguard Worker            def get_random_inputs(self):
375*523fa7a6SAndroid Build Coastguard Worker                return (torch.randn(10),)
376*523fa7a6SAndroid Build Coastguard Worker
377*523fa7a6SAndroid Build Coastguard Worker        model = MyModel()
378*523fa7a6SAndroid Build Coastguard Worker        inputs = model.get_random_inputs()
379*523fa7a6SAndroid Build Coastguard Worker        prog = to_edge(
380*523fa7a6SAndroid Build Coastguard Worker            export(
381*523fa7a6SAndroid Build Coastguard Worker                model,
382*523fa7a6SAndroid Build Coastguard Worker                inputs,
383*523fa7a6SAndroid Build Coastguard Worker            ),
384*523fa7a6SAndroid Build Coastguard Worker            compile_config=EdgeCompileConfig(_check_ir_validity=False),
385*523fa7a6SAndroid Build Coastguard Worker        )  # TODO(larryliu): fix split_copy
386*523fa7a6SAndroid Build Coastguard Worker        new_gm_res = ToOutVarPass()(prog.exported_program().graph_module)
387*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_gm_res)
388*523fa7a6SAndroid Build Coastguard Worker        new_gm = new_gm_res.graph_module
389*523fa7a6SAndroid Build Coastguard Worker
390*523fa7a6SAndroid Build Coastguard Worker        for nd in new_gm.graph.nodes:
391*523fa7a6SAndroid Build Coastguard Worker            if nd.target is exir_ops.edge.aten.split_copy.Tensor_out:
392*523fa7a6SAndroid Build Coastguard Worker                break
393*523fa7a6SAndroid Build Coastguard Worker
394*523fa7a6SAndroid Build Coastguard Worker        val = nd.meta["val"]
395*523fa7a6SAndroid Build Coastguard Worker
396*523fa7a6SAndroid Build Coastguard Worker        # We must return a spec which is a list of a signle TensorSpec item.
397*523fa7a6SAndroid Build Coastguard Worker        # Returning the TensorSpec item directly cause future getitem op fails.
398*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(isinstance(val, (tuple, list)))
399*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(1, len(val))
400*523fa7a6SAndroid Build Coastguard Worker
401*523fa7a6SAndroid Build Coastguard Worker    def test_to_out_variant_multiple_out(self) -> None:
402*523fa7a6SAndroid Build Coastguard Worker        class MyModel(nn.Module):
403*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
404*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
405*523fa7a6SAndroid Build Coastguard Worker
406*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
407*523fa7a6SAndroid Build Coastguard Worker                return torch.topk(x, 5)
408*523fa7a6SAndroid Build Coastguard Worker
409*523fa7a6SAndroid Build Coastguard Worker            def get_random_inputs(self):
410*523fa7a6SAndroid Build Coastguard Worker                return (torch.randn(10),)
411*523fa7a6SAndroid Build Coastguard Worker
412*523fa7a6SAndroid Build Coastguard Worker        model = MyModel()
413*523fa7a6SAndroid Build Coastguard Worker        inputs = model.get_random_inputs()
414*523fa7a6SAndroid Build Coastguard Worker        prog = to_edge(
415*523fa7a6SAndroid Build Coastguard Worker            export(
416*523fa7a6SAndroid Build Coastguard Worker                model,
417*523fa7a6SAndroid Build Coastguard Worker                inputs,
418*523fa7a6SAndroid Build Coastguard Worker            ),
419*523fa7a6SAndroid Build Coastguard Worker            compile_config=EdgeCompileConfig(_check_ir_validity=False),
420*523fa7a6SAndroid Build Coastguard Worker        )  # TODO(larryliu): fix topk
421*523fa7a6SAndroid Build Coastguard Worker        new_gm_res = ToOutVarPass()(prog.exported_program().graph_module)
422*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_gm_res)
423*523fa7a6SAndroid Build Coastguard Worker        new_gm = new_gm_res.graph_module
424*523fa7a6SAndroid Build Coastguard Worker
425*523fa7a6SAndroid Build Coastguard Worker        for nd in new_gm.graph.nodes:
426*523fa7a6SAndroid Build Coastguard Worker            if nd.target is torch.ops.aten.topk.values:
427*523fa7a6SAndroid Build Coastguard Worker                break
428*523fa7a6SAndroid Build Coastguard Worker
429*523fa7a6SAndroid Build Coastguard Worker        val = nd.meta["val"]
430*523fa7a6SAndroid Build Coastguard Worker
431*523fa7a6SAndroid Build Coastguard Worker        # We must return a spec which is a list of a signle TensorSpec item.
432*523fa7a6SAndroid Build Coastguard Worker        # Returning the TensorSpec item directly cause future getitem op fails.
433*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(isinstance(val, (tuple, list)))
434*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(2, len(val))
435*523fa7a6SAndroid Build Coastguard Worker
436*523fa7a6SAndroid Build Coastguard Worker    def test_to_out_variant_to_copy(self) -> None:
437*523fa7a6SAndroid Build Coastguard Worker        class Module(torch.nn.Module):
438*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
439*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
440*523fa7a6SAndroid Build Coastguard Worker
441*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
442*523fa7a6SAndroid Build Coastguard Worker                return x.to(torch.int32)
443*523fa7a6SAndroid Build Coastguard Worker
444*523fa7a6SAndroid Build Coastguard Worker        model = Module()
445*523fa7a6SAndroid Build Coastguard Worker
446*523fa7a6SAndroid Build Coastguard Worker        inputs = torch.tensor(1.0, dtype=torch.float)
447*523fa7a6SAndroid Build Coastguard Worker        model_res = model(inputs)
448*523fa7a6SAndroid Build Coastguard Worker
449*523fa7a6SAndroid Build Coastguard Worker        edge_dialect = to_edge(
450*523fa7a6SAndroid Build Coastguard Worker            export(
451*523fa7a6SAndroid Build Coastguard Worker                model,
452*523fa7a6SAndroid Build Coastguard Worker                (inputs,),
453*523fa7a6SAndroid Build Coastguard Worker            )
454*523fa7a6SAndroid Build Coastguard Worker        )
455*523fa7a6SAndroid Build Coastguard Worker        edge_res = edge_dialect.exported_program().module()(inputs)
456*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(model_res, edge_res))
457*523fa7a6SAndroid Build Coastguard Worker
458*523fa7a6SAndroid Build Coastguard Worker    def test_export_pass(self) -> None:
459*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
460*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
461*523fa7a6SAndroid Build Coastguard Worker                y = torch.cat([x, x])
462*523fa7a6SAndroid Build Coastguard Worker                return torch.ops.aten.tensor_split.sections(y, 2)
463*523fa7a6SAndroid Build Coastguard Worker
464*523fa7a6SAndroid Build Coastguard Worker        f = Foo()
465*523fa7a6SAndroid Build Coastguard Worker
466*523fa7a6SAndroid Build Coastguard Worker        class NullPass(ExportPass):
467*523fa7a6SAndroid Build Coastguard Worker            pass
468*523fa7a6SAndroid Build Coastguard Worker
469*523fa7a6SAndroid Build Coastguard Worker        prog = to_edge(
470*523fa7a6SAndroid Build Coastguard Worker            export(
471*523fa7a6SAndroid Build Coastguard Worker                f,
472*523fa7a6SAndroid Build Coastguard Worker                (torch.ones(3, 2),),
473*523fa7a6SAndroid Build Coastguard Worker            ),
474*523fa7a6SAndroid Build Coastguard Worker            compile_config=EdgeCompileConfig(_check_ir_validity=False),
475*523fa7a6SAndroid Build Coastguard Worker        )  # TODO(larryliu): fix cat
476*523fa7a6SAndroid Build Coastguard Worker        new_prog = prog.transform([NullPass()])
477*523fa7a6SAndroid Build Coastguard Worker        new_nodes = new_prog.exported_program().graph_module.graph.nodes
478*523fa7a6SAndroid Build Coastguard Worker        for node in new_nodes:
479*523fa7a6SAndroid Build Coastguard Worker            if node.op != "call_function":
480*523fa7a6SAndroid Build Coastguard Worker                continue
481*523fa7a6SAndroid Build Coastguard Worker            self.assertTrue(hasattr(node, "stack_trace"))
482*523fa7a6SAndroid Build Coastguard Worker            self.assertIsNotNone(node.stack_trace)
483*523fa7a6SAndroid Build Coastguard Worker
484*523fa7a6SAndroid Build Coastguard Worker        old_nodes = prog.exported_program().graph_module.graph.nodes
485*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(len(new_nodes), len(old_nodes))
486*523fa7a6SAndroid Build Coastguard Worker        for new_node, old_node in zip(new_nodes, old_nodes):
487*523fa7a6SAndroid Build Coastguard Worker            self.assertEqual(new_node.op, old_node.op)
488*523fa7a6SAndroid Build Coastguard Worker            self.assertEqual(new_node.target, old_node.target)
489*523fa7a6SAndroid Build Coastguard Worker
490*523fa7a6SAndroid Build Coastguard Worker    def test_export_pass_pt2(self) -> None:
491*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
492*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
493*523fa7a6SAndroid Build Coastguard Worker                y = torch.cat([x, x])
494*523fa7a6SAndroid Build Coastguard Worker                return torch.ops.aten.tensor_split.sections(y, 2)
495*523fa7a6SAndroid Build Coastguard Worker
496*523fa7a6SAndroid Build Coastguard Worker        f = Foo()
497*523fa7a6SAndroid Build Coastguard Worker
498*523fa7a6SAndroid Build Coastguard Worker        class NullPass(ExportPass):
499*523fa7a6SAndroid Build Coastguard Worker            pass
500*523fa7a6SAndroid Build Coastguard Worker
501*523fa7a6SAndroid Build Coastguard Worker        prog = to_edge(
502*523fa7a6SAndroid Build Coastguard Worker            export(
503*523fa7a6SAndroid Build Coastguard Worker                f,
504*523fa7a6SAndroid Build Coastguard Worker                (torch.ones(3, 2),),
505*523fa7a6SAndroid Build Coastguard Worker            ),
506*523fa7a6SAndroid Build Coastguard Worker            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
507*523fa7a6SAndroid Build Coastguard Worker        )
508*523fa7a6SAndroid Build Coastguard Worker        new_prog = prog.transform([NullPass()])
509*523fa7a6SAndroid Build Coastguard Worker        new_nodes = new_prog.exported_program().graph_module.graph.nodes
510*523fa7a6SAndroid Build Coastguard Worker        for node in new_nodes:
511*523fa7a6SAndroid Build Coastguard Worker            if node.op != "call_function":
512*523fa7a6SAndroid Build Coastguard Worker                continue
513*523fa7a6SAndroid Build Coastguard Worker            self.assertTrue(hasattr(node, "stack_trace"))
514*523fa7a6SAndroid Build Coastguard Worker            self.assertIsNotNone(node.stack_trace)
515*523fa7a6SAndroid Build Coastguard Worker
516*523fa7a6SAndroid Build Coastguard Worker        old_nodes = prog.exported_program().graph_module.graph.nodes
517*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(len(new_nodes), len(old_nodes))
518*523fa7a6SAndroid Build Coastguard Worker        for new_node, old_node in zip(new_nodes, old_nodes):
519*523fa7a6SAndroid Build Coastguard Worker            self.assertEqual(new_node.op, old_node.op)
520*523fa7a6SAndroid Build Coastguard Worker            self.assertEqual(new_node.target, old_node.target)
521*523fa7a6SAndroid Build Coastguard Worker
522*523fa7a6SAndroid Build Coastguard Worker    def test_export_scalar_to_tensor_pass(self) -> None:
523*523fa7a6SAndroid Build Coastguard Worker        class Mul(torch.nn.Module):
524*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
525*523fa7a6SAndroid Build Coastguard Worker                return x * 3.14
526*523fa7a6SAndroid Build Coastguard Worker
527*523fa7a6SAndroid Build Coastguard Worker        mul = Mul()
528*523fa7a6SAndroid Build Coastguard Worker
529*523fa7a6SAndroid Build Coastguard Worker        expo_prog = to_edge(export(mul, (torch.ones(1),)))
530*523fa7a6SAndroid Build Coastguard Worker        new_prog = expo_prog.transform([ScalarToTensorPass()])
531*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_prog.exported_program().graph_module)
532*523fa7a6SAndroid Build Coastguard Worker        new_graph_module = new_prog.exported_program().graph_module
533*523fa7a6SAndroid Build Coastguard Worker
534*523fa7a6SAndroid Build Coastguard Worker        inp = torch.zeros(1)
535*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(
536*523fa7a6SAndroid Build Coastguard Worker            torch.allclose(
537*523fa7a6SAndroid Build Coastguard Worker                expo_prog.exported_program().module()(inp),
538*523fa7a6SAndroid Build Coastguard Worker                new_prog.exported_program().module()(inp),
539*523fa7a6SAndroid Build Coastguard Worker            )
540*523fa7a6SAndroid Build Coastguard Worker        )
541*523fa7a6SAndroid Build Coastguard Worker        for node in new_graph_module.graph.nodes:
542*523fa7a6SAndroid Build Coastguard Worker            if node.op == "call_function":
543*523fa7a6SAndroid Build Coastguard Worker                for arg in node.args + tuple(node.kwargs.values()):
544*523fa7a6SAndroid Build Coastguard Worker                    self.assertFalse(isinstance(arg, float))
545*523fa7a6SAndroid Build Coastguard Worker
546*523fa7a6SAndroid Build Coastguard Worker    def test_remove_mixed_types_symfloats(self) -> None:
547*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
548*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
549*523fa7a6SAndroid Build Coastguard Worker                return torch.nn.functional.interpolate(
550*523fa7a6SAndroid Build Coastguard Worker                    x,
551*523fa7a6SAndroid Build Coastguard Worker                    size=(x.shape[2] * 2, x.shape[3] * 3),
552*523fa7a6SAndroid Build Coastguard Worker                    mode="bilinear",
553*523fa7a6SAndroid Build Coastguard Worker                    align_corners=False,
554*523fa7a6SAndroid Build Coastguard Worker                    antialias=False,
555*523fa7a6SAndroid Build Coastguard Worker                )
556*523fa7a6SAndroid Build Coastguard Worker
557*523fa7a6SAndroid Build Coastguard Worker        f = Foo()
558*523fa7a6SAndroid Build Coastguard Worker
559*523fa7a6SAndroid Build Coastguard Worker        example_inputs = (torch.randn(2, 3, 4, 5),)
560*523fa7a6SAndroid Build Coastguard Worker
561*523fa7a6SAndroid Build Coastguard Worker        gm = to_edge(
562*523fa7a6SAndroid Build Coastguard Worker            export(
563*523fa7a6SAndroid Build Coastguard Worker                f,
564*523fa7a6SAndroid Build Coastguard Worker                example_inputs,
565*523fa7a6SAndroid Build Coastguard Worker            )
566*523fa7a6SAndroid Build Coastguard Worker        )
567*523fa7a6SAndroid Build Coastguard Worker        new_gm = gm.transform(
568*523fa7a6SAndroid Build Coastguard Worker            [ReplaceSymSizeOpPass(), ScalarToTensorPass(), RemoveMixedTypeOperators()]
569*523fa7a6SAndroid Build Coastguard Worker        )
570*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_gm.exported_program().graph_module)
571*523fa7a6SAndroid Build Coastguard Worker
572*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(
573*523fa7a6SAndroid Build Coastguard Worker            torch.allclose(
574*523fa7a6SAndroid Build Coastguard Worker                gm.exported_program().module()(*example_inputs),
575*523fa7a6SAndroid Build Coastguard Worker                new_gm.exported_program().module()(*example_inputs),
576*523fa7a6SAndroid Build Coastguard Worker            )
577*523fa7a6SAndroid Build Coastguard Worker        )
578*523fa7a6SAndroid Build Coastguard Worker
579*523fa7a6SAndroid Build Coastguard Worker    def test_spec_prop_pass(self) -> None:
580*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
581*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
582*523fa7a6SAndroid Build Coastguard Worker                return x + x
583*523fa7a6SAndroid Build Coastguard Worker
584*523fa7a6SAndroid Build Coastguard Worker        f = Foo()
585*523fa7a6SAndroid Build Coastguard Worker
586*523fa7a6SAndroid Build Coastguard Worker        gm = (
587*523fa7a6SAndroid Build Coastguard Worker            to_edge(
588*523fa7a6SAndroid Build Coastguard Worker                export(
589*523fa7a6SAndroid Build Coastguard Worker                    f,
590*523fa7a6SAndroid Build Coastguard Worker                    (torch.ones(3, 2),),
591*523fa7a6SAndroid Build Coastguard Worker                )
592*523fa7a6SAndroid Build Coastguard Worker            )
593*523fa7a6SAndroid Build Coastguard Worker            .exported_program()
594*523fa7a6SAndroid Build Coastguard Worker            .graph_module
595*523fa7a6SAndroid Build Coastguard Worker        )
596*523fa7a6SAndroid Build Coastguard Worker        new_gm = SpecPropPass()(gm)
597*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_gm)
598*523fa7a6SAndroid Build Coastguard Worker        new_nodes = new_gm.graph_module.graph.nodes
599*523fa7a6SAndroid Build Coastguard Worker        counter = 0
600*523fa7a6SAndroid Build Coastguard Worker        for node in new_nodes:
601*523fa7a6SAndroid Build Coastguard Worker            if node.op != "output":
602*523fa7a6SAndroid Build Coastguard Worker                continue
603*523fa7a6SAndroid Build Coastguard Worker            counter += 1
604*523fa7a6SAndroid Build Coastguard Worker            self.assertIs(node.meta["spec"][0], node.args[0][0].meta["spec"])
605*523fa7a6SAndroid Build Coastguard Worker
606*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(counter, 1)
607*523fa7a6SAndroid Build Coastguard Worker
608*523fa7a6SAndroid Build Coastguard Worker    def test_spec_prop_pass_tuple_output(self) -> None:
609*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
610*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
611*523fa7a6SAndroid Build Coastguard Worker                return (x + x,)
612*523fa7a6SAndroid Build Coastguard Worker
613*523fa7a6SAndroid Build Coastguard Worker        f = Foo()
614*523fa7a6SAndroid Build Coastguard Worker
615*523fa7a6SAndroid Build Coastguard Worker        gm = (
616*523fa7a6SAndroid Build Coastguard Worker            to_edge(
617*523fa7a6SAndroid Build Coastguard Worker                export(
618*523fa7a6SAndroid Build Coastguard Worker                    f,
619*523fa7a6SAndroid Build Coastguard Worker                    (torch.ones(3, 2),),
620*523fa7a6SAndroid Build Coastguard Worker                )
621*523fa7a6SAndroid Build Coastguard Worker            )
622*523fa7a6SAndroid Build Coastguard Worker            .exported_program()
623*523fa7a6SAndroid Build Coastguard Worker            .graph_module
624*523fa7a6SAndroid Build Coastguard Worker        )
625*523fa7a6SAndroid Build Coastguard Worker        new_gm = SpecPropPass()(gm)
626*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_gm)
627*523fa7a6SAndroid Build Coastguard Worker        new_nodes = new_gm.graph_module.graph.nodes
628*523fa7a6SAndroid Build Coastguard Worker        counter = 0
629*523fa7a6SAndroid Build Coastguard Worker        for node in new_nodes:
630*523fa7a6SAndroid Build Coastguard Worker            if node.op != "output":
631*523fa7a6SAndroid Build Coastguard Worker                continue
632*523fa7a6SAndroid Build Coastguard Worker            counter += 1
633*523fa7a6SAndroid Build Coastguard Worker            self.assertIs(node.meta["spec"][0], node.args[0][0].meta["spec"])
634*523fa7a6SAndroid Build Coastguard Worker
635*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(counter, 1)
636*523fa7a6SAndroid Build Coastguard Worker
637*523fa7a6SAndroid Build Coastguard Worker    def test_compile_fix_broken_ops(self) -> None:
638*523fa7a6SAndroid Build Coastguard Worker        # When pass an input of more than 4 dimensions to Linear
639*523fa7a6SAndroid Build Coastguard Worker        # aten._unsafe_view is used under the hood
640*523fa7a6SAndroid Build Coastguard Worker        x = torch.randn([2, 3, 4, 5])
641*523fa7a6SAndroid Build Coastguard Worker        model: torch.nn.Linear = torch.nn.Linear(5, 5)
642*523fa7a6SAndroid Build Coastguard Worker
643*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
644*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
645*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
646*523fa7a6SAndroid Build Coastguard Worker                self.model = model
647*523fa7a6SAndroid Build Coastguard Worker
648*523fa7a6SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
649*523fa7a6SAndroid Build Coastguard Worker                return self.model(inp)
650*523fa7a6SAndroid Build Coastguard Worker
651*523fa7a6SAndroid Build Coastguard Worker        f = Foo()
652*523fa7a6SAndroid Build Coastguard Worker
653*523fa7a6SAndroid Build Coastguard Worker        # ReplaceBrokenOpsWithFunctionalOpsPass is used in to_edge()
654*523fa7a6SAndroid Build Coastguard Worker        prog = to_edge(
655*523fa7a6SAndroid Build Coastguard Worker            export(
656*523fa7a6SAndroid Build Coastguard Worker                f,
657*523fa7a6SAndroid Build Coastguard Worker                (x,),
658*523fa7a6SAndroid Build Coastguard Worker            ),
659*523fa7a6SAndroid Build Coastguard Worker            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
660*523fa7a6SAndroid Build Coastguard Worker        )
661*523fa7a6SAndroid Build Coastguard Worker        gm = prog.exported_program().graph_module
662*523fa7a6SAndroid Build Coastguard Worker        count_after = 0
663*523fa7a6SAndroid Build Coastguard Worker        for node in gm.graph.nodes:
664*523fa7a6SAndroid Build Coastguard Worker            if node.target == torch.ops.aten._unsafe_view.default:
665*523fa7a6SAndroid Build Coastguard Worker                count_after += 1
666*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(count_after, 0)
667*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(prog.exported_program().module()(x), f(x)))
668*523fa7a6SAndroid Build Coastguard Worker
669*523fa7a6SAndroid Build Coastguard Worker    def test_convert_symb_ops(self) -> None:
670*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
671*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
672*523fa7a6SAndroid Build Coastguard Worker                return torch.add(x, x.shape[0] - 1)
673*523fa7a6SAndroid Build Coastguard Worker
674*523fa7a6SAndroid Build Coastguard Worker        f = Foo()
675*523fa7a6SAndroid Build Coastguard Worker
676*523fa7a6SAndroid Build Coastguard Worker        # Mark the 0th dimension of X as dynamic with a max value of 3.
677*523fa7a6SAndroid Build Coastguard Worker        dim_x = torch.export.Dim("dim_x", max=3)
678*523fa7a6SAndroid Build Coastguard Worker
679*523fa7a6SAndroid Build Coastguard Worker        prog = to_edge(
680*523fa7a6SAndroid Build Coastguard Worker            export(
681*523fa7a6SAndroid Build Coastguard Worker                f,
682*523fa7a6SAndroid Build Coastguard Worker                (torch.ones(3, 2),),
683*523fa7a6SAndroid Build Coastguard Worker                dynamic_shapes={"x": {0: dim_x}},
684*523fa7a6SAndroid Build Coastguard Worker            ),
685*523fa7a6SAndroid Build Coastguard Worker            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
686*523fa7a6SAndroid Build Coastguard Worker        )
687*523fa7a6SAndroid Build Coastguard Worker        new_prog = prog.transform([EdgeToBackendOpsPass()])
688*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_prog.exported_program().graph_module)
689*523fa7a6SAndroid Build Coastguard Worker        converted_gm = new_prog.exported_program().graph_module
690*523fa7a6SAndroid Build Coastguard Worker
691*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check("torch.ops.aten.sym_size.int").check(
692*523fa7a6SAndroid Build Coastguard Worker            "executorch_exir_dialects_backend__ops_executorch_prim_sub_Scalar"
693*523fa7a6SAndroid Build Coastguard Worker        ).check_not("operator.sub").run(converted_gm.code)
694*523fa7a6SAndroid Build Coastguard Worker
695*523fa7a6SAndroid Build Coastguard Worker    def test_alloc_node_spec(self) -> None:
696*523fa7a6SAndroid Build Coastguard Worker        """
697*523fa7a6SAndroid Build Coastguard Worker        Make sure every memory.alloc node including those in sub graph modules
698*523fa7a6SAndroid Build Coastguard Worker        have a TensorSpec.
699*523fa7a6SAndroid Build Coastguard Worker        """
700*523fa7a6SAndroid Build Coastguard Worker        eager_model = FTMapBasic()
701*523fa7a6SAndroid Build Coastguard Worker        inputs = eager_model.get_random_inputs()
702*523fa7a6SAndroid Build Coastguard Worker        prog = to_edge(
703*523fa7a6SAndroid Build Coastguard Worker            export(
704*523fa7a6SAndroid Build Coastguard Worker                eager_model,
705*523fa7a6SAndroid Build Coastguard Worker                inputs,
706*523fa7a6SAndroid Build Coastguard Worker            ),
707*523fa7a6SAndroid Build Coastguard Worker            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
708*523fa7a6SAndroid Build Coastguard Worker        )
709*523fa7a6SAndroid Build Coastguard Worker        passes = [
710*523fa7a6SAndroid Build Coastguard Worker            SpecPropPass(),
711*523fa7a6SAndroid Build Coastguard Worker            HintBasedSymShapeEvalPass(),
712*523fa7a6SAndroid Build Coastguard Worker        ]
713*523fa7a6SAndroid Build Coastguard Worker        new_prog = prog.transform(passes)
714*523fa7a6SAndroid Build Coastguard Worker
715*523fa7a6SAndroid Build Coastguard Worker        new_gm_res = ToOutVarPass()(new_prog.exported_program().graph_module)
716*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_gm_res)
717*523fa7a6SAndroid Build Coastguard Worker        new_gm = new_gm_res.graph_module
718*523fa7a6SAndroid Build Coastguard Worker
719*523fa7a6SAndroid Build Coastguard Worker        new_gm_res = MemoryPlanningPass()(new_gm)
720*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_gm_res)
721*523fa7a6SAndroid Build Coastguard Worker        new_gm = new_gm_res.graph_module
722*523fa7a6SAndroid Build Coastguard Worker
723*523fa7a6SAndroid Build Coastguard Worker        alloc_nodes = []
724*523fa7a6SAndroid Build Coastguard Worker        for subgm in new_gm.modules():
725*523fa7a6SAndroid Build Coastguard Worker            if isinstance(subgm, torch.fx.GraphModule):
726*523fa7a6SAndroid Build Coastguard Worker                for node in subgm.graph.nodes:
727*523fa7a6SAndroid Build Coastguard Worker                    if node.target == memory.alloc:
728*523fa7a6SAndroid Build Coastguard Worker                        alloc_nodes.append(node)
729*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(len(alloc_nodes) > 0)
730*523fa7a6SAndroid Build Coastguard Worker        for node in alloc_nodes:
731*523fa7a6SAndroid Build Coastguard Worker            self.assertTrue(isinstance(node.meta.get("spec", None), TensorSpec))
732*523fa7a6SAndroid Build Coastguard Worker
733*523fa7a6SAndroid Build Coastguard Worker    def test_debug_pass_file_log(self) -> None:
734*523fa7a6SAndroid Build Coastguard Worker        eager_model = Mul()
735*523fa7a6SAndroid Build Coastguard Worker        inputs = eager_model.get_random_inputs()
736*523fa7a6SAndroid Build Coastguard Worker
737*523fa7a6SAndroid Build Coastguard Worker        # the debug pass works with a graph generated with make_fx directly
738*523fa7a6SAndroid Build Coastguard Worker        gm = make_fx(eager_model)(*inputs)
739*523fa7a6SAndroid Build Coastguard Worker
740*523fa7a6SAndroid Build Coastguard Worker        try:
741*523fa7a6SAndroid Build Coastguard Worker            fd, path = tempfile.mkstemp()
742*523fa7a6SAndroid Build Coastguard Worker
743*523fa7a6SAndroid Build Coastguard Worker            print(f"Write DebugPass output to {path}")
744*523fa7a6SAndroid Build Coastguard Worker            DebugPass(log_filename=path)(gm)
745*523fa7a6SAndroid Build Coastguard Worker            with open(path) as f:
746*523fa7a6SAndroid Build Coastguard Worker                file_cont = f.read()
747*523fa7a6SAndroid Build Coastguard Worker            self.assertTrue("torch.ops.aten.mul" in file_cont)
748*523fa7a6SAndroid Build Coastguard Worker        finally:
749*523fa7a6SAndroid Build Coastguard Worker            os.close(fd)
750*523fa7a6SAndroid Build Coastguard Worker            os.unlink(path)
751*523fa7a6SAndroid Build Coastguard Worker
752*523fa7a6SAndroid Build Coastguard Worker    def test_dce_recursive(self) -> None:
753*523fa7a6SAndroid Build Coastguard Worker        eager_model = FTCondDeadCode()
754*523fa7a6SAndroid Build Coastguard Worker        inputs = eager_model.get_random_inputs()
755*523fa7a6SAndroid Build Coastguard Worker        gm = export(
756*523fa7a6SAndroid Build Coastguard Worker            eager_model,
757*523fa7a6SAndroid Build Coastguard Worker            inputs,
758*523fa7a6SAndroid Build Coastguard Worker        ).graph_module
759*523fa7a6SAndroid Build Coastguard Worker
760*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(torch.ops.aten.sub.Tensor in collect_ops(gm))
761*523fa7a6SAndroid Build Coastguard Worker        dead_code_elimination_pass(gm)
762*523fa7a6SAndroid Build Coastguard Worker        gm.print_readable()
763*523fa7a6SAndroid Build Coastguard Worker        self.assertFalse(torch.ops.aten.sub.Tensor in collect_ops(gm))
764*523fa7a6SAndroid Build Coastguard Worker
765*523fa7a6SAndroid Build Coastguard Worker    def test_propagate_dynamic_shape(self) -> None:
766*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
767*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
768*523fa7a6SAndroid Build Coastguard Worker                y = x
769*523fa7a6SAndroid Build Coastguard Worker                for _ in range(2):
770*523fa7a6SAndroid Build Coastguard Worker                    y = y + x
771*523fa7a6SAndroid Build Coastguard Worker                return y
772*523fa7a6SAndroid Build Coastguard Worker
773*523fa7a6SAndroid Build Coastguard Worker        f = Foo()
774*523fa7a6SAndroid Build Coastguard Worker
775*523fa7a6SAndroid Build Coastguard Worker        prog = to_edge(
776*523fa7a6SAndroid Build Coastguard Worker            export(
777*523fa7a6SAndroid Build Coastguard Worker                f,
778*523fa7a6SAndroid Build Coastguard Worker                (torch.rand(5),),
779*523fa7a6SAndroid Build Coastguard Worker            ),
780*523fa7a6SAndroid Build Coastguard Worker            # missing dispatch key
781*523fa7a6SAndroid Build Coastguard Worker            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
782*523fa7a6SAndroid Build Coastguard Worker        ).transform(propagate_dynamic_shape())
783*523fa7a6SAndroid Build Coastguard Worker        gm = prog.exported_program().graph_module
784*523fa7a6SAndroid Build Coastguard Worker        nspec = 0
785*523fa7a6SAndroid Build Coastguard Worker        for n in gm.graph.nodes:
786*523fa7a6SAndroid Build Coastguard Worker            for spec in pytree.tree_flatten(n.meta["spec"])[0]:
787*523fa7a6SAndroid Build Coastguard Worker                self.assertTrue(all(isinstance(x, int) for x in spec.shape))
788*523fa7a6SAndroid Build Coastguard Worker                nspec += 1
789*523fa7a6SAndroid Build Coastguard Worker
790*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(nspec > 0)
791*523fa7a6SAndroid Build Coastguard Worker
792*523fa7a6SAndroid Build Coastguard Worker    def test_losing_symbolic_info(self) -> None:
793*523fa7a6SAndroid Build Coastguard Worker        """
794*523fa7a6SAndroid Build Coastguard Worker        Guard against an issue that after calling ConvertSymbolicOpsPass(),
795*523fa7a6SAndroid Build Coastguard Worker        future ExportPass will encounter symbolic information loss.
796*523fa7a6SAndroid Build Coastguard Worker        """
797*523fa7a6SAndroid Build Coastguard Worker
798*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
799*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
800*523fa7a6SAndroid Build Coastguard Worker                return torch.add(x, x.shape[0] - 1)
801*523fa7a6SAndroid Build Coastguard Worker
802*523fa7a6SAndroid Build Coastguard Worker        f = Foo()
803*523fa7a6SAndroid Build Coastguard Worker
804*523fa7a6SAndroid Build Coastguard Worker        dim_x = torch.export.Dim("dim_x", max=3)
805*523fa7a6SAndroid Build Coastguard Worker        prog = to_edge(
806*523fa7a6SAndroid Build Coastguard Worker            export(
807*523fa7a6SAndroid Build Coastguard Worker                f,
808*523fa7a6SAndroid Build Coastguard Worker                (torch.ones(3, 2),),
809*523fa7a6SAndroid Build Coastguard Worker                dynamic_shapes={"x": {0: dim_x}},
810*523fa7a6SAndroid Build Coastguard Worker            ),
811*523fa7a6SAndroid Build Coastguard Worker            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
812*523fa7a6SAndroid Build Coastguard Worker        )
813*523fa7a6SAndroid Build Coastguard Worker
814*523fa7a6SAndroid Build Coastguard Worker        new_prog = prog.transform([EdgeToBackendOpsPass()])
815*523fa7a6SAndroid Build Coastguard Worker        gm = new_prog.exported_program().graph_module
816*523fa7a6SAndroid Build Coastguard Worker        gm.print_readable()
817*523fa7a6SAndroid Build Coastguard Worker        *_, ones, out = gm.graph.nodes
818*523fa7a6SAndroid Build Coastguard Worker        print(f"Before ExportPass: {ones.format_node()}")
819*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(isinstance(ones.meta["val"].shape[0], torch.SymInt))
820*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(len(ones.meta["val"].shape[0].node.expr.free_symbols) > 0)
821*523fa7a6SAndroid Build Coastguard Worker
822*523fa7a6SAndroid Build Coastguard Worker        new_prog = new_prog.transform([ExportPass()])
823*523fa7a6SAndroid Build Coastguard Worker        gm = new_prog.exported_program().graph_module
824*523fa7a6SAndroid Build Coastguard Worker        gm.print_readable()
825*523fa7a6SAndroid Build Coastguard Worker        *_, ones, out = gm.graph.nodes
826*523fa7a6SAndroid Build Coastguard Worker        print(f"After ExportPass: {ones.format_node()}")
827*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(isinstance(ones.meta["val"].shape[0], torch.SymInt))
828*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(len(ones.meta["val"].shape[0].node.expr.free_symbols) > 0)
829*523fa7a6SAndroid Build Coastguard Worker
830*523fa7a6SAndroid Build Coastguard Worker    def test_to_edge_with_edge_ops(self) -> None:
831*523fa7a6SAndroid Build Coastguard Worker        x = torch.randn([2, 3, 4, 5])
832*523fa7a6SAndroid Build Coastguard Worker
833*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
834*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
835*523fa7a6SAndroid Build Coastguard Worker                return x + x
836*523fa7a6SAndroid Build Coastguard Worker
837*523fa7a6SAndroid Build Coastguard Worker        f = Foo()
838*523fa7a6SAndroid Build Coastguard Worker
839*523fa7a6SAndroid Build Coastguard Worker        gm = (
840*523fa7a6SAndroid Build Coastguard Worker            to_edge(
841*523fa7a6SAndroid Build Coastguard Worker                export(
842*523fa7a6SAndroid Build Coastguard Worker                    f,
843*523fa7a6SAndroid Build Coastguard Worker                    (x,),
844*523fa7a6SAndroid Build Coastguard Worker                )
845*523fa7a6SAndroid Build Coastguard Worker            )
846*523fa7a6SAndroid Build Coastguard Worker            .exported_program()
847*523fa7a6SAndroid Build Coastguard Worker            .graph_module
848*523fa7a6SAndroid Build Coastguard Worker        )
849*523fa7a6SAndroid Build Coastguard Worker        for node in gm.graph.nodes:
850*523fa7a6SAndroid Build Coastguard Worker            if node.op == "call_function":
851*523fa7a6SAndroid Build Coastguard Worker                self.assertEqual(type(node.target), EdgeOpOverload)
852*523fa7a6SAndroid Build Coastguard Worker
853*523fa7a6SAndroid Build Coastguard Worker    # TODO(T143084047)
854*523fa7a6SAndroid Build Coastguard Worker    @unittest.expectedFailure
855*523fa7a6SAndroid Build Coastguard Worker    def test_backend_fused_op_retraceable(self) -> None:
856*523fa7a6SAndroid Build Coastguard Worker        """This test makes sure the backend op is still retraceable, with the pattern being registered as kernel."""
857*523fa7a6SAndroid Build Coastguard Worker
858*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
859*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
860*523fa7a6SAndroid Build Coastguard Worker                z = x + y
861*523fa7a6SAndroid Build Coastguard Worker                return torch.ops.aten.relu.default(z)
862*523fa7a6SAndroid Build Coastguard Worker
863*523fa7a6SAndroid Build Coastguard Worker        f = Foo()
864*523fa7a6SAndroid Build Coastguard Worker
865*523fa7a6SAndroid Build Coastguard Worker        gm = export(
866*523fa7a6SAndroid Build Coastguard Worker            f,
867*523fa7a6SAndroid Build Coastguard Worker            (
868*523fa7a6SAndroid Build Coastguard Worker                torch.randn(2, 2),
869*523fa7a6SAndroid Build Coastguard Worker                torch.randn(2, 2),
870*523fa7a6SAndroid Build Coastguard Worker            ),
871*523fa7a6SAndroid Build Coastguard Worker        )
872*523fa7a6SAndroid Build Coastguard Worker        # should look like:
873*523fa7a6SAndroid Build Coastguard Worker        # graph():
874*523fa7a6SAndroid Build Coastguard Worker        #     %ph_0 : [#users=1] = placeholder[target=ph_0]
875*523fa7a6SAndroid Build Coastguard Worker        #     %ph_1 : [#users=1] = placeholder[target=ph_1]
876*523fa7a6SAndroid Build Coastguard Worker        #     %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, %ph_1), kwargs = {})
877*523fa7a6SAndroid Build Coastguard Worker        #     %relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%add_tensor,), kwargs = {})
878*523fa7a6SAndroid Build Coastguard Worker        #     return [relu_default]
879*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check("torch.ops.aten.add.Tensor").check(
880*523fa7a6SAndroid Build Coastguard Worker            "torch.ops.aten.relu.default"
881*523fa7a6SAndroid Build Coastguard Worker        ).run(gm.graph_module.code)
882*523fa7a6SAndroid Build Coastguard Worker
883*523fa7a6SAndroid Build Coastguard Worker        class AddReluFusionPass(ExportPass):
884*523fa7a6SAndroid Build Coastguard Worker            def call(self, graph_module: GraphModule) -> PassResult:
885*523fa7a6SAndroid Build Coastguard Worker                # decorator registers this pattern as a CompositeExplicitAutograd kernel, since there's no kernel registered before.
886*523fa7a6SAndroid Build Coastguard Worker                @bind_pattern_to_op(lib, "add_relu")
887*523fa7a6SAndroid Build Coastguard Worker                def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
888*523fa7a6SAndroid Build Coastguard Worker                    z = torch.ops.aten.add.Tensor(x, y)
889*523fa7a6SAndroid Build Coastguard Worker                    out = torch.ops.aten.relu.default(z)
890*523fa7a6SAndroid Build Coastguard Worker                    return out
891*523fa7a6SAndroid Build Coastguard Worker
892*523fa7a6SAndroid Build Coastguard Worker                def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
893*523fa7a6SAndroid Build Coastguard Worker                    return ops.backend.DO_NOT_USE_TEST_ONLY.add_relu.default(x, y)
894*523fa7a6SAndroid Build Coastguard Worker
895*523fa7a6SAndroid Build Coastguard Worker                subgraph_rewriter.replace_pattern(graph_module, pattern, replacement)
896*523fa7a6SAndroid Build Coastguard Worker                return PassResult(graph_module, True)
897*523fa7a6SAndroid Build Coastguard Worker
898*523fa7a6SAndroid Build Coastguard Worker        # TODO: larryliu this pass needs to be in to_executorch()
899*523fa7a6SAndroid Build Coastguard Worker        class OpReplacePass(ExportPass):
900*523fa7a6SAndroid Build Coastguard Worker            def call_operator(self, op, args, kwargs, meta):
901*523fa7a6SAndroid Build Coastguard Worker                if op == torch.ops.DO_NOT_USE_TEST_ONLY.add_relu.default:
902*523fa7a6SAndroid Build Coastguard Worker                    return super().call_operator(
903*523fa7a6SAndroid Build Coastguard Worker                        ops.backend.DO_NOT_USE_TEST_ONLY.add_relu.default,
904*523fa7a6SAndroid Build Coastguard Worker                        args,
905*523fa7a6SAndroid Build Coastguard Worker                        kwargs,
906*523fa7a6SAndroid Build Coastguard Worker                        meta,
907*523fa7a6SAndroid Build Coastguard Worker                    )
908*523fa7a6SAndroid Build Coastguard Worker                return super().call_operator(op, args, kwargs, meta)
909*523fa7a6SAndroid Build Coastguard Worker
910*523fa7a6SAndroid Build Coastguard Worker        gm_lowered = to_edge(
911*523fa7a6SAndroid Build Coastguard Worker            gm,
912*523fa7a6SAndroid Build Coastguard Worker            compile_config=EdgeCompileConfig(
913*523fa7a6SAndroid Build Coastguard Worker                _check_ir_validity=False,
914*523fa7a6SAndroid Build Coastguard Worker            ),
915*523fa7a6SAndroid Build Coastguard Worker        ).transform([AddReluFusionPass(), OpReplacePass()])
916*523fa7a6SAndroid Build Coastguard Worker
917*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check(
918*523fa7a6SAndroid Build Coastguard Worker            "executorch_exir_dialects_backend__ops_DO_NOT_USE_TEST_ONLY_add_relu_default"
919*523fa7a6SAndroid Build Coastguard Worker        ).run(gm_lowered.exported_program().graph_module.code)
920*523fa7a6SAndroid Build Coastguard Worker        # lowered module:
921*523fa7a6SAndroid Build Coastguard Worker        # def forward(self, ph_0, ph_1):
922*523fa7a6SAndroid Build Coastguard Worker        #     do_not_use_test_only_add_relu_default = executorch_exir_dialects_backend__ops_DO_NOT_USE_TEST_ONLY_add_relu_default(ph_0, ph_1);  ph_0 = ph_1 = None
923*523fa7a6SAndroid Build Coastguard Worker        #     return [do_not_use_test_only_add_relu_default]
924*523fa7a6SAndroid Build Coastguard Worker
925*523fa7a6SAndroid Build Coastguard Worker        # Retrace:
926*523fa7a6SAndroid Build Coastguard Worker        # If not backend op retrace will error out because no CPU/CompositeExplicitAutograd kernel registered.
927*523fa7a6SAndroid Build Coastguard Worker        gm_retraced = to_edge(
928*523fa7a6SAndroid Build Coastguard Worker            export(
929*523fa7a6SAndroid Build Coastguard Worker                gm_lowered.exported_program().module(),
930*523fa7a6SAndroid Build Coastguard Worker                (
931*523fa7a6SAndroid Build Coastguard Worker                    torch.randn(2, 2),
932*523fa7a6SAndroid Build Coastguard Worker                    torch.randn(2, 2),
933*523fa7a6SAndroid Build Coastguard Worker                ),
934*523fa7a6SAndroid Build Coastguard Worker            )
935*523fa7a6SAndroid Build Coastguard Worker        )
936*523fa7a6SAndroid Build Coastguard Worker        # Retrace-able, the graph "promote" back to ATen dialect, showing up add and relu, which is expected.
937*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check("torch.ops.aten.add.Tensor").check(
938*523fa7a6SAndroid Build Coastguard Worker            "torch.ops.aten.relu.default"
939*523fa7a6SAndroid Build Coastguard Worker        ).run(gm_retraced.exported_program().graph_module.code)
940*523fa7a6SAndroid Build Coastguard Worker
941*523fa7a6SAndroid Build Coastguard Worker    def test_debug_handle_generator_pass(self) -> None:
942*523fa7a6SAndroid Build Coastguard Worker        eager_model = MLP(2, output_size=4)
943*523fa7a6SAndroid Build Coastguard Worker        inputs = eager_model.get_random_inputs()
944*523fa7a6SAndroid Build Coastguard Worker
945*523fa7a6SAndroid Build Coastguard Worker        graph_module = (
946*523fa7a6SAndroid Build Coastguard Worker            to_edge(
947*523fa7a6SAndroid Build Coastguard Worker                export(
948*523fa7a6SAndroid Build Coastguard Worker                    eager_model,
949*523fa7a6SAndroid Build Coastguard Worker                    inputs,
950*523fa7a6SAndroid Build Coastguard Worker                )
951*523fa7a6SAndroid Build Coastguard Worker            )
952*523fa7a6SAndroid Build Coastguard Worker            .exported_program()
953*523fa7a6SAndroid Build Coastguard Worker            .graph_module
954*523fa7a6SAndroid Build Coastguard Worker        )
955*523fa7a6SAndroid Build Coastguard Worker        for node in graph_module.graph.nodes:
956*523fa7a6SAndroid Build Coastguard Worker            self.assertIn("debug_handle", node.meta)
957*523fa7a6SAndroid Build Coastguard Worker        ScalarToTensorPass()(graph_module)
958*523fa7a6SAndroid Build Coastguard Worker        for node in graph_module.graph.nodes:
959*523fa7a6SAndroid Build Coastguard Worker            self.assertIn("debug_handle", node.meta)
960*523fa7a6SAndroid Build Coastguard Worker
961*523fa7a6SAndroid Build Coastguard Worker    def test_generate_missing_debug_handles(self) -> None:
962*523fa7a6SAndroid Build Coastguard Worker        eager_model = MLP(2, output_size=4)
963*523fa7a6SAndroid Build Coastguard Worker        inputs = eager_model.get_random_inputs()
964*523fa7a6SAndroid Build Coastguard Worker
965*523fa7a6SAndroid Build Coastguard Worker        ep = to_edge(
966*523fa7a6SAndroid Build Coastguard Worker            export(
967*523fa7a6SAndroid Build Coastguard Worker                eager_model,
968*523fa7a6SAndroid Build Coastguard Worker                inputs,
969*523fa7a6SAndroid Build Coastguard Worker            )
970*523fa7a6SAndroid Build Coastguard Worker        ).exported_program()
971*523fa7a6SAndroid Build Coastguard Worker
972*523fa7a6SAndroid Build Coastguard Worker        list(ep.graph.nodes)[0].meta.pop("debug_handle")
973*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is None)
974*523fa7a6SAndroid Build Coastguard Worker        generate_missing_debug_handles(ep)
975*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is not None)
976*523fa7a6SAndroid Build Coastguard Worker
977*523fa7a6SAndroid Build Coastguard Worker    def test_debug_handle_generator_pass_with_control_flow(self) -> None:
978*523fa7a6SAndroid Build Coastguard Worker        def true_nested(y: torch.Tensor) -> torch.Tensor:
979*523fa7a6SAndroid Build Coastguard Worker            y = y + y
980*523fa7a6SAndroid Build Coastguard Worker            y = torch.mm(y, y)
981*523fa7a6SAndroid Build Coastguard Worker            return y
982*523fa7a6SAndroid Build Coastguard Worker
983*523fa7a6SAndroid Build Coastguard Worker        def false_nested(y: torch.Tensor) -> torch.Tensor:
984*523fa7a6SAndroid Build Coastguard Worker            return torch.mm(y, y)
985*523fa7a6SAndroid Build Coastguard Worker
986*523fa7a6SAndroid Build Coastguard Worker        def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
987*523fa7a6SAndroid Build Coastguard Worker            z = control_flow.cond(pred2, true_nested, false_nested, [x])
988*523fa7a6SAndroid Build Coastguard Worker            return x + z
989*523fa7a6SAndroid Build Coastguard Worker
990*523fa7a6SAndroid Build Coastguard Worker        def false_fn(x: torch.Tensor, _) -> torch.Tensor:
991*523fa7a6SAndroid Build Coastguard Worker            return x.cos()
992*523fa7a6SAndroid Build Coastguard Worker
993*523fa7a6SAndroid Build Coastguard Worker        def map_fn(
994*523fa7a6SAndroid Build Coastguard Worker            x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor
995*523fa7a6SAndroid Build Coastguard Worker        ) -> torch.Tensor:
996*523fa7a6SAndroid Build Coastguard Worker            x = x.cos()
997*523fa7a6SAndroid Build Coastguard Worker            y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
998*523fa7a6SAndroid Build Coastguard Worker            x = x + y
999*523fa7a6SAndroid Build Coastguard Worker            return x.sin()
1000*523fa7a6SAndroid Build Coastguard Worker
1001*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
1002*523fa7a6SAndroid Build Coastguard Worker            def forward(
1003*523fa7a6SAndroid Build Coastguard Worker                self,
1004*523fa7a6SAndroid Build Coastguard Worker                xs: torch.Tensor,
1005*523fa7a6SAndroid Build Coastguard Worker                pred1: torch.Tensor,
1006*523fa7a6SAndroid Build Coastguard Worker                pred2: torch.Tensor,
1007*523fa7a6SAndroid Build Coastguard Worker                y: torch.Tensor,
1008*523fa7a6SAndroid Build Coastguard Worker            ) -> torch.Tensor:
1009*523fa7a6SAndroid Build Coastguard Worker                y = torch.mm(y, y)
1010*523fa7a6SAndroid Build Coastguard Worker                return control_flow.map(map_fn, xs, pred1, pred2, y)
1011*523fa7a6SAndroid Build Coastguard Worker
1012*523fa7a6SAndroid Build Coastguard Worker        f = Foo()
1013*523fa7a6SAndroid Build Coastguard Worker
1014*523fa7a6SAndroid Build Coastguard Worker        inputs = (
1015*523fa7a6SAndroid Build Coastguard Worker            torch.ones(2, 2),
1016*523fa7a6SAndroid Build Coastguard Worker            torch.tensor([False]),
1017*523fa7a6SAndroid Build Coastguard Worker            torch.tensor([False]),
1018*523fa7a6SAndroid Build Coastguard Worker            torch.ones(2, 2),
1019*523fa7a6SAndroid Build Coastguard Worker        )
1020*523fa7a6SAndroid Build Coastguard Worker
1021*523fa7a6SAndroid Build Coastguard Worker        ep = to_edge(
1022*523fa7a6SAndroid Build Coastguard Worker            export(
1023*523fa7a6SAndroid Build Coastguard Worker                f,
1024*523fa7a6SAndroid Build Coastguard Worker                inputs,
1025*523fa7a6SAndroid Build Coastguard Worker            )
1026*523fa7a6SAndroid Build Coastguard Worker        ).exported_program()
1027*523fa7a6SAndroid Build Coastguard Worker        graph_module = ep.graph_module
1028*523fa7a6SAndroid Build Coastguard Worker
1029*523fa7a6SAndroid Build Coastguard Worker        def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:
1030*523fa7a6SAndroid Build Coastguard Worker            queue = [graph_module]
1031*523fa7a6SAndroid Build Coastguard Worker            while queue:
1032*523fa7a6SAndroid Build Coastguard Worker                current_graph_module = queue.pop(0)
1033*523fa7a6SAndroid Build Coastguard Worker                for node in current_graph_module.graph.nodes:
1034*523fa7a6SAndroid Build Coastguard Worker                    self.assertIn("debug_handle", node.meta)
1035*523fa7a6SAndroid Build Coastguard Worker                control_flow_submodules = [
1036*523fa7a6SAndroid Build Coastguard Worker                    submodule
1037*523fa7a6SAndroid Build Coastguard Worker                    for _, submodule, _ in get_control_flow_submodules(
1038*523fa7a6SAndroid Build Coastguard Worker                        current_graph_module
1039*523fa7a6SAndroid Build Coastguard Worker                    )
1040*523fa7a6SAndroid Build Coastguard Worker                ]
1041*523fa7a6SAndroid Build Coastguard Worker                queue.extend(control_flow_submodules)
1042*523fa7a6SAndroid Build Coastguard Worker
1043*523fa7a6SAndroid Build Coastguard Worker        DebugHandleGeneratorPass()(graph_module)
1044*523fa7a6SAndroid Build Coastguard Worker        check_debug_handle_metadata(graph_module)
1045*523fa7a6SAndroid Build Coastguard Worker        generate_missing_debug_handles(ep)
1046*523fa7a6SAndroid Build Coastguard Worker
1047*523fa7a6SAndroid Build Coastguard Worker        # Check debug handle still preserved after ScalarToTensorPass
1048*523fa7a6SAndroid Build Coastguard Worker        ScalarToTensorPass()(graph_module)
1049*523fa7a6SAndroid Build Coastguard Worker        check_debug_handle_metadata(graph_module)
1050*523fa7a6SAndroid Build Coastguard Worker
1051*523fa7a6SAndroid Build Coastguard Worker    def test_symint_conversion(self) -> None:
1052*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
1053*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
1054*523fa7a6SAndroid Build Coastguard Worker                return torch.add(x, x.shape[0] - 1)
1055*523fa7a6SAndroid Build Coastguard Worker
1056*523fa7a6SAndroid Build Coastguard Worker        f = Foo()
1057*523fa7a6SAndroid Build Coastguard Worker
1058*523fa7a6SAndroid Build Coastguard Worker        dim_x = torch.export.Dim("dim_x", max=3)
1059*523fa7a6SAndroid Build Coastguard Worker        prog = to_edge(
1060*523fa7a6SAndroid Build Coastguard Worker            export(
1061*523fa7a6SAndroid Build Coastguard Worker                f,
1062*523fa7a6SAndroid Build Coastguard Worker                (torch.ones(3, 2),),
1063*523fa7a6SAndroid Build Coastguard Worker                dynamic_shapes={"x": {0: dim_x}},
1064*523fa7a6SAndroid Build Coastguard Worker            ),
1065*523fa7a6SAndroid Build Coastguard Worker            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
1066*523fa7a6SAndroid Build Coastguard Worker        )
1067*523fa7a6SAndroid Build Coastguard Worker        prog = prog.transform([SymToTensorPass()])
1068*523fa7a6SAndroid Build Coastguard Worker
1069*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check("torch.ops.aten.scalar_tensor.default").run(
1070*523fa7a6SAndroid Build Coastguard Worker            prog.exported_program().graph_module.code
1071*523fa7a6SAndroid Build Coastguard Worker        )
1072*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(
1073*523fa7a6SAndroid Build Coastguard Worker            torch.allclose(
1074*523fa7a6SAndroid Build Coastguard Worker                f(torch.ones(3, 2)), prog.exported_program().module()(torch.ones(3, 2))
1075*523fa7a6SAndroid Build Coastguard Worker            )
1076*523fa7a6SAndroid Build Coastguard Worker        )
1077*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(
1078*523fa7a6SAndroid Build Coastguard Worker            torch.allclose(
1079*523fa7a6SAndroid Build Coastguard Worker                f(torch.zeros(3, 2)),
1080*523fa7a6SAndroid Build Coastguard Worker                prog.exported_program().module()(torch.zeros(3, 2)),
1081*523fa7a6SAndroid Build Coastguard Worker            )
1082*523fa7a6SAndroid Build Coastguard Worker        )
1083*523fa7a6SAndroid Build Coastguard Worker
1084*523fa7a6SAndroid Build Coastguard Worker    def test_remove_assert_pass(self) -> None:
1085*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
1086*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
1087*523fa7a6SAndroid Build Coastguard Worker                assert x.shape[0] == 5
1088*523fa7a6SAndroid Build Coastguard Worker                return x * x
1089*523fa7a6SAndroid Build Coastguard Worker
1090*523fa7a6SAndroid Build Coastguard Worker        f = Foo()
1091*523fa7a6SAndroid Build Coastguard Worker
1092*523fa7a6SAndroid Build Coastguard Worker        gm = to_edge(
1093*523fa7a6SAndroid Build Coastguard Worker            export(
1094*523fa7a6SAndroid Build Coastguard Worker                f,
1095*523fa7a6SAndroid Build Coastguard Worker                (torch.randn(5),),
1096*523fa7a6SAndroid Build Coastguard Worker            ),
1097*523fa7a6SAndroid Build Coastguard Worker            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
1098*523fa7a6SAndroid Build Coastguard Worker        )
1099*523fa7a6SAndroid Build Coastguard Worker        new_gm = gm.transform([RemoveGraphAssertsPass()])
1100*523fa7a6SAndroid Build Coastguard Worker        num_asserts = [
1101*523fa7a6SAndroid Build Coastguard Worker            node
1102*523fa7a6SAndroid Build Coastguard Worker            for node in new_gm.exported_program().graph.nodes
1103*523fa7a6SAndroid Build Coastguard Worker            if node.op == "call_function"
1104*523fa7a6SAndroid Build Coastguard Worker            and node.target == torch.ops.aten._assert_async.msg
1105*523fa7a6SAndroid Build Coastguard Worker        ]
1106*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(len(num_asserts), 0)
1107*523fa7a6SAndroid Build Coastguard Worker
1108*523fa7a6SAndroid Build Coastguard Worker    def test_arange(self) -> None:
1109*523fa7a6SAndroid Build Coastguard Worker        class M(torch.nn.Module):
1110*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
1111*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
1112*523fa7a6SAndroid Build Coastguard Worker                self.a = torch.ones(2)
1113*523fa7a6SAndroid Build Coastguard Worker
1114*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
1115*523fa7a6SAndroid Build Coastguard Worker                return torch.arange(start=0, end=2) + x
1116*523fa7a6SAndroid Build Coastguard Worker
1117*523fa7a6SAndroid Build Coastguard Worker        _ = to_edge(
1118*523fa7a6SAndroid Build Coastguard Worker            export(
1119*523fa7a6SAndroid Build Coastguard Worker                M(),
1120*523fa7a6SAndroid Build Coastguard Worker                (torch.randn(2),),
1121*523fa7a6SAndroid Build Coastguard Worker            )
1122*523fa7a6SAndroid Build Coastguard Worker        ).to_executorch()
1123*523fa7a6SAndroid Build Coastguard Worker
1124*523fa7a6SAndroid Build Coastguard Worker    def test_replace_slice(self) -> None:
1125*523fa7a6SAndroid Build Coastguard Worker        class M(torch.nn.Module):
1126*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
1127*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
1128*523fa7a6SAndroid Build Coastguard Worker                self.a = torch.ones(10)
1129*523fa7a6SAndroid Build Coastguard Worker
1130*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
1131*523fa7a6SAndroid Build Coastguard Worker                return self.a[:2] + x
1132*523fa7a6SAndroid Build Coastguard Worker
1133*523fa7a6SAndroid Build Coastguard Worker        gm = (
1134*523fa7a6SAndroid Build Coastguard Worker            to_edge(
1135*523fa7a6SAndroid Build Coastguard Worker                export(
1136*523fa7a6SAndroid Build Coastguard Worker                    M(),
1137*523fa7a6SAndroid Build Coastguard Worker                    (torch.randn(2),),
1138*523fa7a6SAndroid Build Coastguard Worker                )
1139*523fa7a6SAndroid Build Coastguard Worker            )
1140*523fa7a6SAndroid Build Coastguard Worker            .exported_program()
1141*523fa7a6SAndroid Build Coastguard Worker            .graph_module
1142*523fa7a6SAndroid Build Coastguard Worker        )
1143*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check(
1144*523fa7a6SAndroid Build Coastguard Worker            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"
1145*523fa7a6SAndroid Build Coastguard Worker        ).run(gm.code)
1146*523fa7a6SAndroid Build Coastguard Worker
1147*523fa7a6SAndroid Build Coastguard Worker    def test_constant_prop_pass_for_add(self) -> None:
1148*523fa7a6SAndroid Build Coastguard Worker        class Add(torch.nn.Module):
1149*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
1150*523fa7a6SAndroid Build Coastguard Worker                return x + 3
1151*523fa7a6SAndroid Build Coastguard Worker
1152*523fa7a6SAndroid Build Coastguard Worker        add = Add()
1153*523fa7a6SAndroid Build Coastguard Worker
1154*523fa7a6SAndroid Build Coastguard Worker        edge = to_edge(
1155*523fa7a6SAndroid Build Coastguard Worker            export(add, (torch.ones(1),)),
1156*523fa7a6SAndroid Build Coastguard Worker            compile_config=EdgeCompileConfig(_skip_dim_order=False),
1157*523fa7a6SAndroid Build Coastguard Worker        )
1158*523fa7a6SAndroid Build Coastguard Worker        edge = edge.transform([ScalarToTensorPass(), RemoveMixedTypeOperators()])
1159*523fa7a6SAndroid Build Coastguard Worker        exported_program = lift_constant_tensor_pass(edge.exported_program())
1160*523fa7a6SAndroid Build Coastguard Worker
1161*523fa7a6SAndroid Build Coastguard Worker        # Check there is a lifted tensor followed by a to_copy node
1162*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check("_lifted_tensor_constant0").check(
1163*523fa7a6SAndroid Build Coastguard Worker            "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
1164*523fa7a6SAndroid Build Coastguard Worker        ).run(exported_program.graph_module.code)
1165*523fa7a6SAndroid Build Coastguard Worker
1166*523fa7a6SAndroid Build Coastguard Worker        new_ep = constant_prop_pass(exported_program)
1167*523fa7a6SAndroid Build Coastguard Worker
1168*523fa7a6SAndroid Build Coastguard Worker        # Check (_lifted_tensor_constant + to_copy) node is replaced by prop tensor
1169*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check_not("_lifted_tensor_constant").check(
1170*523fa7a6SAndroid Build Coastguard Worker            "_prop_tensor_constant0"
1171*523fa7a6SAndroid Build Coastguard Worker        ).check_not(
1172*523fa7a6SAndroid Build Coastguard Worker            "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
1173*523fa7a6SAndroid Build Coastguard Worker        ).run(
1174*523fa7a6SAndroid Build Coastguard Worker            new_ep.graph_module.code
1175*523fa7a6SAndroid Build Coastguard Worker        )
1176*523fa7a6SAndroid Build Coastguard Worker
1177*523fa7a6SAndroid Build Coastguard Worker    def test_constant_prop_pass_for_parameter(self) -> None:
1178*523fa7a6SAndroid Build Coastguard Worker        def count_additions(gm: torch.fx.GraphModule) -> int:
1179*523fa7a6SAndroid Build Coastguard Worker            return sum(
1180*523fa7a6SAndroid Build Coastguard Worker                (node.target == torch.ops.aten.add.Tensor) for node in gm.graph.nodes
1181*523fa7a6SAndroid Build Coastguard Worker            )
1182*523fa7a6SAndroid Build Coastguard Worker
1183*523fa7a6SAndroid Build Coastguard Worker        class M(torch.nn.Module):
1184*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
1185*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
1186*523fa7a6SAndroid Build Coastguard Worker                self.a = torch.nn.Parameter(torch.ones(1, 2, 3))
1187*523fa7a6SAndroid Build Coastguard Worker
1188*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
1189*523fa7a6SAndroid Build Coastguard Worker                b = self.a + self.a
1190*523fa7a6SAndroid Build Coastguard Worker                c = torch.cat([self.a, b])
1191*523fa7a6SAndroid Build Coastguard Worker                return (c + c) + x
1192*523fa7a6SAndroid Build Coastguard Worker
1193*523fa7a6SAndroid Build Coastguard Worker        aten = export(
1194*523fa7a6SAndroid Build Coastguard Worker            M(),
1195*523fa7a6SAndroid Build Coastguard Worker            (torch.zeros(2, 2, 3),),
1196*523fa7a6SAndroid Build Coastguard Worker        )
1197*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(count_additions(aten.graph_module), 3)
1198*523fa7a6SAndroid Build Coastguard Worker        new_ep = constant_prop_pass(aten)
1199*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(count_additions(new_ep.graph_module), 1)
1200*523fa7a6SAndroid Build Coastguard Worker
1201*523fa7a6SAndroid Build Coastguard Worker    def test_constant_prop_pass_graph_signature(self) -> None:
1202*523fa7a6SAndroid Build Coastguard Worker        def count_additions(gm: torch.fx.GraphModule) -> int:
1203*523fa7a6SAndroid Build Coastguard Worker            return sum(
1204*523fa7a6SAndroid Build Coastguard Worker                (node.target == torch.ops.aten.add.Tensor) for node in gm.graph.nodes
1205*523fa7a6SAndroid Build Coastguard Worker            )
1206*523fa7a6SAndroid Build Coastguard Worker
1207*523fa7a6SAndroid Build Coastguard Worker        class M(torch.nn.Module):
1208*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
1209*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
1210*523fa7a6SAndroid Build Coastguard Worker                self.a = torch.nn.Parameter(torch.ones(1, 2, 3))
1211*523fa7a6SAndroid Build Coastguard Worker
1212*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
1213*523fa7a6SAndroid Build Coastguard Worker                b = self.a + self.a
1214*523fa7a6SAndroid Build Coastguard Worker                c = torch.cat([self.a, b])
1215*523fa7a6SAndroid Build Coastguard Worker                return (c + c) + x
1216*523fa7a6SAndroid Build Coastguard Worker
1217*523fa7a6SAndroid Build Coastguard Worker        aten = export(
1218*523fa7a6SAndroid Build Coastguard Worker            M(),
1219*523fa7a6SAndroid Build Coastguard Worker            (torch.zeros(2, 2, 3),),
1220*523fa7a6SAndroid Build Coastguard Worker        )
1221*523fa7a6SAndroid Build Coastguard Worker        # Input signature will have two entries:
1222*523fa7a6SAndroid Build Coastguard Worker        # (1) parameter `a` and (2) user input `x`.
1223*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(len(aten.graph_signature.input_specs), 2)
1224*523fa7a6SAndroid Build Coastguard Worker        new_ep = constant_prop_pass(aten)
1225*523fa7a6SAndroid Build Coastguard Worker        # Check that there are exactly two propagated tensors - (1) propagated
1226*523fa7a6SAndroid Build Coastguard Worker        # constant and (2) user input.
1227*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(
1228*523fa7a6SAndroid Build Coastguard Worker            new_ep.graph_signature.input_specs,
1229*523fa7a6SAndroid Build Coastguard Worker            [
1230*523fa7a6SAndroid Build Coastguard Worker                InputSpec(
1231*523fa7a6SAndroid Build Coastguard Worker                    kind=InputKind.CONSTANT_TENSOR,
1232*523fa7a6SAndroid Build Coastguard Worker                    arg=TensorArgument(name="_prop_tensor_constant0"),
1233*523fa7a6SAndroid Build Coastguard Worker                    target="_prop_tensor_constant0",
1234*523fa7a6SAndroid Build Coastguard Worker                    persistent=True,
1235*523fa7a6SAndroid Build Coastguard Worker                ),
1236*523fa7a6SAndroid Build Coastguard Worker                # User input graph signature.
1237*523fa7a6SAndroid Build Coastguard Worker                aten.graph_signature.input_specs[-1],
1238*523fa7a6SAndroid Build Coastguard Worker            ],
1239*523fa7a6SAndroid Build Coastguard Worker        )
1240*523fa7a6SAndroid Build Coastguard Worker
1241*523fa7a6SAndroid Build Coastguard Worker    def test_constant_prop_pass_for_parameter_slice(self) -> None:
1242*523fa7a6SAndroid Build Coastguard Worker        def count_slice(gm: torch.fx.GraphModule) -> int:
1243*523fa7a6SAndroid Build Coastguard Worker            return sum(
1244*523fa7a6SAndroid Build Coastguard Worker                (node.target == torch.ops.aten.slice_copy.Tensor)
1245*523fa7a6SAndroid Build Coastguard Worker                for node in gm.graph.nodes
1246*523fa7a6SAndroid Build Coastguard Worker            )
1247*523fa7a6SAndroid Build Coastguard Worker
1248*523fa7a6SAndroid Build Coastguard Worker        class M(torch.nn.Module):
1249*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
1250*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
1251*523fa7a6SAndroid Build Coastguard Worker                self.a = torch.nn.Parameter(torch.ones(3, 2, 2))
1252*523fa7a6SAndroid Build Coastguard Worker
1253*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
1254*523fa7a6SAndroid Build Coastguard Worker                # Create slice of shape (1, 2, 2)
1255*523fa7a6SAndroid Build Coastguard Worker                slice_tensor = torch.slice_copy(self.a, dim=0, start=0, end=1)
1256*523fa7a6SAndroid Build Coastguard Worker                return torch.cat([x, slice_tensor])
1257*523fa7a6SAndroid Build Coastguard Worker
1258*523fa7a6SAndroid Build Coastguard Worker        aten = export(
1259*523fa7a6SAndroid Build Coastguard Worker            M(),
1260*523fa7a6SAndroid Build Coastguard Worker            (torch.zeros(2, 2, 2),),
1261*523fa7a6SAndroid Build Coastguard Worker        )
1262*523fa7a6SAndroid Build Coastguard Worker        self.assertIn("a", aten.state_dict)
1263*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(count_slice(aten.graph_module), 1)
1264*523fa7a6SAndroid Build Coastguard Worker
1265*523fa7a6SAndroid Build Coastguard Worker        new_ep = constant_prop_pass(aten)
1266*523fa7a6SAndroid Build Coastguard Worker        # Check there is a propagated tensor.
1267*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check("_prop_tensor_constant0").run(aten.graph_module.code)
1268*523fa7a6SAndroid Build Coastguard Worker        self.assertIn("_prop_tensor_constant0", new_ep.constants)
1269*523fa7a6SAndroid Build Coastguard Worker        self.assertNotIn("a", new_ep.state_dict)
1270*523fa7a6SAndroid Build Coastguard Worker        # No more slice copy.
1271*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(count_slice(new_ep.graph_module), 0)
1272*523fa7a6SAndroid Build Coastguard Worker
1273*523fa7a6SAndroid Build Coastguard Worker    def test_constant_prop_pass_no_propagate(self) -> None:
1274*523fa7a6SAndroid Build Coastguard Worker        def count_placeholder(gm: torch.fx.GraphModule) -> int:
1275*523fa7a6SAndroid Build Coastguard Worker            return sum((node.op == "placeholder") for node in gm.graph.nodes)
1276*523fa7a6SAndroid Build Coastguard Worker
1277*523fa7a6SAndroid Build Coastguard Worker        class M(torch.nn.Module):
1278*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
1279*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
1280*523fa7a6SAndroid Build Coastguard Worker                self.a = torch.nn.Parameter(torch.ones(3, 2, 4))
1281*523fa7a6SAndroid Build Coastguard Worker
1282*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x, y):
1283*523fa7a6SAndroid Build Coastguard Worker                # y is unused.
1284*523fa7a6SAndroid Build Coastguard Worker                return x + self.a
1285*523fa7a6SAndroid Build Coastguard Worker
1286*523fa7a6SAndroid Build Coastguard Worker        aten = export(
1287*523fa7a6SAndroid Build Coastguard Worker            M(),
1288*523fa7a6SAndroid Build Coastguard Worker            (torch.zeros(3, 2, 4), torch.zeros(3, 2, 4)),
1289*523fa7a6SAndroid Build Coastguard Worker        )
1290*523fa7a6SAndroid Build Coastguard Worker        self.assertIn("a", aten.state_dict)
1291*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(count_placeholder(aten.graph_module), 3)
1292*523fa7a6SAndroid Build Coastguard Worker
1293*523fa7a6SAndroid Build Coastguard Worker        new_ep = constant_prop_pass(aten)
1294*523fa7a6SAndroid Build Coastguard Worker        # Check there is no propagated tensor.
1295*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check("p_a").check("x").check("y").run(aten.graph_module.code)
1296*523fa7a6SAndroid Build Coastguard Worker        self.assertNotIn("_prop_tensor_constant0", new_ep.constants)
1297*523fa7a6SAndroid Build Coastguard Worker        self.assertIn("a", new_ep.state_dict)
1298*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(count_placeholder(new_ep.graph_module), 3)
1299*523fa7a6SAndroid Build Coastguard Worker
1300*523fa7a6SAndroid Build Coastguard Worker    def test_constant_prop_pass_for_control_flow(self) -> None:
1301*523fa7a6SAndroid Build Coastguard Worker        class Module(torch.nn.Module):
1302*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
1303*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
1304*523fa7a6SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(3, 3)
1305*523fa7a6SAndroid Build Coastguard Worker
1306*523fa7a6SAndroid Build Coastguard Worker            def t(self, val):
1307*523fa7a6SAndroid Build Coastguard Worker                return val + 1
1308*523fa7a6SAndroid Build Coastguard Worker
1309*523fa7a6SAndroid Build Coastguard Worker            def f(self, val):
1310*523fa7a6SAndroid Build Coastguard Worker                return val - 1
1311*523fa7a6SAndroid Build Coastguard Worker
1312*523fa7a6SAndroid Build Coastguard Worker            def true_fn(self, val):
1313*523fa7a6SAndroid Build Coastguard Worker                return self.linear(val) + self.t(val)
1314*523fa7a6SAndroid Build Coastguard Worker
1315*523fa7a6SAndroid Build Coastguard Worker            def false_fn(self, val):
1316*523fa7a6SAndroid Build Coastguard Worker                return self.linear(val) - self.f(val)
1317*523fa7a6SAndroid Build Coastguard Worker
1318*523fa7a6SAndroid Build Coastguard Worker            def forward(self, pred, x):
1319*523fa7a6SAndroid Build Coastguard Worker                return torch.ops.higher_order.cond(
1320*523fa7a6SAndroid Build Coastguard Worker                    pred, self.true_fn, self.false_fn, [x]
1321*523fa7a6SAndroid Build Coastguard Worker                )
1322*523fa7a6SAndroid Build Coastguard Worker
1323*523fa7a6SAndroid Build Coastguard Worker        mod = Module()
1324*523fa7a6SAndroid Build Coastguard Worker        x = torch.randn([3, 3])
1325*523fa7a6SAndroid Build Coastguard Worker        pred = torch.tensor(x[0][0].item() < 0)
1326*523fa7a6SAndroid Build Coastguard Worker        edge = to_edge(
1327*523fa7a6SAndroid Build Coastguard Worker            export(mod, (pred, x)),
1328*523fa7a6SAndroid Build Coastguard Worker            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
1329*523fa7a6SAndroid Build Coastguard Worker        )
1330*523fa7a6SAndroid Build Coastguard Worker        error_msg = r"constant_prop_pass for control flow is not supported yet."
1331*523fa7a6SAndroid Build Coastguard Worker
1332*523fa7a6SAndroid Build Coastguard Worker        # TODO(chenlai): enable constant prop pass for control flow
1333*523fa7a6SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1334*523fa7a6SAndroid Build Coastguard Worker            RuntimeError,
1335*523fa7a6SAndroid Build Coastguard Worker            error_msg,
1336*523fa7a6SAndroid Build Coastguard Worker        ):
1337*523fa7a6SAndroid Build Coastguard Worker            _ = constant_prop_pass(edge.exported_program())
1338*523fa7a6SAndroid Build Coastguard Worker
1339*523fa7a6SAndroid Build Coastguard Worker    def test_mutable_buffers(self) -> None:
1340*523fa7a6SAndroid Build Coastguard Worker        def count_copies(gm: torch.fx.GraphModule) -> int:
1341*523fa7a6SAndroid Build Coastguard Worker            return sum(
1342*523fa7a6SAndroid Build Coastguard Worker                (node.target == torch.ops.aten.copy_.default) for node in gm.graph.nodes
1343*523fa7a6SAndroid Build Coastguard Worker            )
1344*523fa7a6SAndroid Build Coastguard Worker
1345*523fa7a6SAndroid Build Coastguard Worker        class MutableStateModule(torch.nn.Module):
1346*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
1347*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
1348*523fa7a6SAndroid Build Coastguard Worker                self.register_buffer("state", torch.zeros(1))
1349*523fa7a6SAndroid Build Coastguard Worker
1350*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
1351*523fa7a6SAndroid Build Coastguard Worker                y = x + self.state
1352*523fa7a6SAndroid Build Coastguard Worker                self.state.add_(1)
1353*523fa7a6SAndroid Build Coastguard Worker                return y
1354*523fa7a6SAndroid Build Coastguard Worker
1355*523fa7a6SAndroid Build Coastguard Worker        model = to_edge(
1356*523fa7a6SAndroid Build Coastguard Worker            export(
1357*523fa7a6SAndroid Build Coastguard Worker                MutableStateModule(),
1358*523fa7a6SAndroid Build Coastguard Worker                (torch.zeros(1),),
1359*523fa7a6SAndroid Build Coastguard Worker            )
1360*523fa7a6SAndroid Build Coastguard Worker        )
1361*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(count_copies(model.exported_program().graph_module), 0)
1362*523fa7a6SAndroid Build Coastguard Worker        # Before
1363*523fa7a6SAndroid Build Coastguard Worker        # graph():
1364*523fa7a6SAndroid Build Coastguard Worker        #     %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
1365*523fa7a6SAndroid Build Coastguard Worker        #     %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
1366*523fa7a6SAndroid Build Coastguard Worker        #     %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
1367*523fa7a6SAndroid Build Coastguard Worker        #     %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
1368*523fa7a6SAndroid Build Coastguard Worker        #     %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
1369*523fa7a6SAndroid Build Coastguard Worker        #     %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
1370*523fa7a6SAndroid Build Coastguard Worker        #     return (aten_add_tensor_1, aten_add_tensor)
1371*523fa7a6SAndroid Build Coastguard Worker        gm, _ = insert_write_back_for_buffers_pass(model.exported_program())
1372*523fa7a6SAndroid Build Coastguard Worker
1373*523fa7a6SAndroid Build Coastguard Worker        # After
1374*523fa7a6SAndroid Build Coastguard Worker        # graph():
1375*523fa7a6SAndroid Build Coastguard Worker        #     %arg0_1 : [num_users=3] = placeholder[target=arg0_1]
1376*523fa7a6SAndroid Build Coastguard Worker        #     %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
1377*523fa7a6SAndroid Build Coastguard Worker        #     %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
1378*523fa7a6SAndroid Build Coastguard Worker        #     %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
1379*523fa7a6SAndroid Build Coastguard Worker        #     %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
1380*523fa7a6SAndroid Build Coastguard Worker        #     %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
1381*523fa7a6SAndroid Build Coastguard Worker        #     %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
1382*523fa7a6SAndroid Build Coastguard Worker        #     return (copy__default, aten_add_tensor)
1383*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(count_copies(gm), 1)
1384*523fa7a6SAndroid Build Coastguard Worker
1385*523fa7a6SAndroid Build Coastguard Worker    def test_remove_quantized_op_noop_pass(self) -> None:
1386*523fa7a6SAndroid Build Coastguard Worker        class TestAddSliceNoop(torch.nn.Module):
1387*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
1388*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
1389*523fa7a6SAndroid Build Coastguard Worker
1390*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
1391*523fa7a6SAndroid Build Coastguard Worker                x = x + x
1392*523fa7a6SAndroid Build Coastguard Worker                x = x + x[:]
1393*523fa7a6SAndroid Build Coastguard Worker                return x
1394*523fa7a6SAndroid Build Coastguard Worker
1395*523fa7a6SAndroid Build Coastguard Worker        class TestAddSliceNotNoop(torch.nn.Module):
1396*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
1397*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
1398*523fa7a6SAndroid Build Coastguard Worker
1399*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
1400*523fa7a6SAndroid Build Coastguard Worker                x = x + x
1401*523fa7a6SAndroid Build Coastguard Worker                x = x + x[:1]
1402*523fa7a6SAndroid Build Coastguard Worker                return x
1403*523fa7a6SAndroid Build Coastguard Worker
1404*523fa7a6SAndroid Build Coastguard Worker        def count_dq_nodes(gm: torch.fx.GraphModule) -> int:
1405*523fa7a6SAndroid Build Coastguard Worker            return sum(
1406*523fa7a6SAndroid Build Coastguard Worker                (
1407*523fa7a6SAndroid Build Coastguard Worker                    node.target
1408*523fa7a6SAndroid Build Coastguard Worker                    in (
1409*523fa7a6SAndroid Build Coastguard Worker                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1410*523fa7a6SAndroid Build Coastguard Worker                        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
1411*523fa7a6SAndroid Build Coastguard Worker                    )
1412*523fa7a6SAndroid Build Coastguard Worker                )
1413*523fa7a6SAndroid Build Coastguard Worker                for node in gm.graph.nodes
1414*523fa7a6SAndroid Build Coastguard Worker            )
1415*523fa7a6SAndroid Build Coastguard Worker
1416*523fa7a6SAndroid Build Coastguard Worker        def count_q_nodes(gm: torch.fx.GraphModule) -> int:
1417*523fa7a6SAndroid Build Coastguard Worker            return sum(
1418*523fa7a6SAndroid Build Coastguard Worker                (
1419*523fa7a6SAndroid Build Coastguard Worker                    node.target
1420*523fa7a6SAndroid Build Coastguard Worker                    in (
1421*523fa7a6SAndroid Build Coastguard Worker                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
1422*523fa7a6SAndroid Build Coastguard Worker                        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
1423*523fa7a6SAndroid Build Coastguard Worker                    )
1424*523fa7a6SAndroid Build Coastguard Worker                )
1425*523fa7a6SAndroid Build Coastguard Worker                for node in gm.graph.nodes
1426*523fa7a6SAndroid Build Coastguard Worker            )
1427*523fa7a6SAndroid Build Coastguard Worker
1428*523fa7a6SAndroid Build Coastguard Worker        def quantize_model(
1429*523fa7a6SAndroid Build Coastguard Worker            m_eager: torch.nn.Module, example_inputs: Tuple[torch.Tensor]
1430*523fa7a6SAndroid Build Coastguard Worker        ) -> Tuple[EdgeProgramManager, int, int]:
1431*523fa7a6SAndroid Build Coastguard Worker            # program capture
1432*523fa7a6SAndroid Build Coastguard Worker            m = torch.export.export_for_training(
1433*523fa7a6SAndroid Build Coastguard Worker                m_eager,
1434*523fa7a6SAndroid Build Coastguard Worker                example_inputs,
1435*523fa7a6SAndroid Build Coastguard Worker            ).module()
1436*523fa7a6SAndroid Build Coastguard Worker
1437*523fa7a6SAndroid Build Coastguard Worker            quantizer = XNNPACKQuantizer()
1438*523fa7a6SAndroid Build Coastguard Worker            quantization_config = get_symmetric_quantization_config()
1439*523fa7a6SAndroid Build Coastguard Worker            quantizer.set_global(quantization_config)
1440*523fa7a6SAndroid Build Coastguard Worker            m = prepare_pt2e(m, quantizer)  # pyre-fixme[6]
1441*523fa7a6SAndroid Build Coastguard Worker            m = convert_pt2e(m, fold_quantize=True)
1442*523fa7a6SAndroid Build Coastguard Worker            ep = torch.export.export(m, example_inputs)
1443*523fa7a6SAndroid Build Coastguard Worker            dq_nodes_pre = count_dq_nodes(ep.graph_module)
1444*523fa7a6SAndroid Build Coastguard Worker            q_nodes_pre = count_q_nodes(ep.graph_module)
1445*523fa7a6SAndroid Build Coastguard Worker            edge = to_edge(
1446*523fa7a6SAndroid Build Coastguard Worker                ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)
1447*523fa7a6SAndroid Build Coastguard Worker            )
1448*523fa7a6SAndroid Build Coastguard Worker            return edge, dq_nodes_pre, q_nodes_pre
1449*523fa7a6SAndroid Build Coastguard Worker
1450*523fa7a6SAndroid Build Coastguard Worker        example_inputs = (torch.randn(9, 8),)
1451*523fa7a6SAndroid Build Coastguard Worker        model = TestAddSliceNoop()
1452*523fa7a6SAndroid Build Coastguard Worker        m_eager = model.eval()
1453*523fa7a6SAndroid Build Coastguard Worker        edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs)
1454*523fa7a6SAndroid Build Coastguard Worker
1455*523fa7a6SAndroid Build Coastguard Worker        dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
1456*523fa7a6SAndroid Build Coastguard Worker        q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
1457*523fa7a6SAndroid Build Coastguard Worker        # One dq and one q node around the slice copy should have been removed.
1458*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(dq_nodes_pre - dq_nodes_post, 1)
1459*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(q_nodes_pre - q_nodes_post, 1)
1460*523fa7a6SAndroid Build Coastguard Worker
1461*523fa7a6SAndroid Build Coastguard Worker        # Check that the slice_copy is removed by the RemoveNoopPass.
1462*523fa7a6SAndroid Build Coastguard Worker        for node in edge.exported_program().graph_module.graph.nodes:
1463*523fa7a6SAndroid Build Coastguard Worker            self.assertFalse("slice" in str(node.target))
1464*523fa7a6SAndroid Build Coastguard Worker
1465*523fa7a6SAndroid Build Coastguard Worker        model = TestAddSliceNotNoop()
1466*523fa7a6SAndroid Build Coastguard Worker        m_eager = model.eval()
1467*523fa7a6SAndroid Build Coastguard Worker        edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs)
1468*523fa7a6SAndroid Build Coastguard Worker
1469*523fa7a6SAndroid Build Coastguard Worker        dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
1470*523fa7a6SAndroid Build Coastguard Worker        q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
1471*523fa7a6SAndroid Build Coastguard Worker        # One dq and one q node around the slice copy should have been removed.
1472*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(dq_nodes_pre, dq_nodes_post)
1473*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(q_nodes_pre, q_nodes_post)
1474*523fa7a6SAndroid Build Coastguard Worker
1475*523fa7a6SAndroid Build Coastguard Worker        # Check that the slice_copy is not removed by the RemoveNoopPass.
1476*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(
1477*523fa7a6SAndroid Build Coastguard Worker            any(
1478*523fa7a6SAndroid Build Coastguard Worker                "slice" in str(node.target)
1479*523fa7a6SAndroid Build Coastguard Worker                for node in edge.exported_program().graph_module.graph.nodes
1480*523fa7a6SAndroid Build Coastguard Worker            )
1481*523fa7a6SAndroid Build Coastguard Worker        )
1482*523fa7a6SAndroid Build Coastguard Worker
1483*523fa7a6SAndroid Build Coastguard Worker    def test_dq_q_no_op_pass(self) -> None:
1484*523fa7a6SAndroid Build Coastguard Worker        class TestDqQ(torch.nn.Module):
1485*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
1486*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
1487*523fa7a6SAndroid Build Coastguard Worker
1488*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
1489*523fa7a6SAndroid Build Coastguard Worker                dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
1490*523fa7a6SAndroid Build Coastguard Worker                    x, 1.0, 0, -128, 127, torch.int8
1491*523fa7a6SAndroid Build Coastguard Worker                )
1492*523fa7a6SAndroid Build Coastguard Worker                q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
1493*523fa7a6SAndroid Build Coastguard Worker                    dq, 1.0, 0, -128, 127, torch.int8
1494*523fa7a6SAndroid Build Coastguard Worker                )
1495*523fa7a6SAndroid Build Coastguard Worker                return q
1496*523fa7a6SAndroid Build Coastguard Worker
1497*523fa7a6SAndroid Build Coastguard Worker        model = TestDqQ()
1498*523fa7a6SAndroid Build Coastguard Worker        m_eager = model.eval()
1499*523fa7a6SAndroid Build Coastguard Worker        ep = torch.export.export(m_eager, (torch.randn(9, 8),))
1500*523fa7a6SAndroid Build Coastguard Worker        edge = to_edge(ep)
1501*523fa7a6SAndroid Build Coastguard Worker        # Check that the dq and q nodes are not touched by the RemoveNoopPass.
1502*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(
1503*523fa7a6SAndroid Build Coastguard Worker            any(
1504*523fa7a6SAndroid Build Coastguard Worker                "dequantize" in str(node.target)
1505*523fa7a6SAndroid Build Coastguard Worker                for node in edge.exported_program().graph_module.graph.nodes
1506*523fa7a6SAndroid Build Coastguard Worker            )
1507*523fa7a6SAndroid Build Coastguard Worker        )
1508*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(
1509*523fa7a6SAndroid Build Coastguard Worker            any(
1510*523fa7a6SAndroid Build Coastguard Worker                "quantize" in str(node.target)
1511*523fa7a6SAndroid Build Coastguard Worker                for node in edge.exported_program().graph_module.graph.nodes
1512*523fa7a6SAndroid Build Coastguard Worker            )
1513*523fa7a6SAndroid Build Coastguard Worker        )
1514*523fa7a6SAndroid Build Coastguard Worker
1515*523fa7a6SAndroid Build Coastguard Worker    def test_dq_q_different_qparams(self) -> None:
1516*523fa7a6SAndroid Build Coastguard Worker        class TestDqQDifferentQParam(torch.nn.Module):
1517*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
1518*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
1519*523fa7a6SAndroid Build Coastguard Worker
1520*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
1521*523fa7a6SAndroid Build Coastguard Worker                dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
1522*523fa7a6SAndroid Build Coastguard Worker                    x, 1.0, 0, -128, 127, torch.int8
1523*523fa7a6SAndroid Build Coastguard Worker                )
1524*523fa7a6SAndroid Build Coastguard Worker                slice_copy_output = torch.ops.aten.slice_copy.Tensor(dq, 0, 0)
1525*523fa7a6SAndroid Build Coastguard Worker                q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
1526*523fa7a6SAndroid Build Coastguard Worker                    slice_copy_output, 1.0, 0, -127, 127, torch.int8
1527*523fa7a6SAndroid Build Coastguard Worker                )
1528*523fa7a6SAndroid Build Coastguard Worker                return q
1529*523fa7a6SAndroid Build Coastguard Worker
1530*523fa7a6SAndroid Build Coastguard Worker        model = TestDqQDifferentQParam()
1531*523fa7a6SAndroid Build Coastguard Worker        m_eager = model.eval()
1532*523fa7a6SAndroid Build Coastguard Worker        ep = torch.export.export(m_eager, (torch.randn(9, 8),))
1533*523fa7a6SAndroid Build Coastguard Worker        edge = to_edge(ep)
1534*523fa7a6SAndroid Build Coastguard Worker        print(edge.exported_program().graph_module.graph)
1535*523fa7a6SAndroid Build Coastguard Worker        # Check that the dq and q nodes are not touched by the RemoveNoopPass.
1536*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(
1537*523fa7a6SAndroid Build Coastguard Worker            any(
1538*523fa7a6SAndroid Build Coastguard Worker                "dequantize" in str(node.target)
1539*523fa7a6SAndroid Build Coastguard Worker                for node in edge.exported_program().graph_module.graph.nodes
1540*523fa7a6SAndroid Build Coastguard Worker            )
1541*523fa7a6SAndroid Build Coastguard Worker        )
1542*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(
1543*523fa7a6SAndroid Build Coastguard Worker            any(
1544*523fa7a6SAndroid Build Coastguard Worker                "quantize" in str(node.target)
1545*523fa7a6SAndroid Build Coastguard Worker                for node in edge.exported_program().graph_module.graph.nodes
1546*523fa7a6SAndroid Build Coastguard Worker            )
1547*523fa7a6SAndroid Build Coastguard Worker        )
1548*523fa7a6SAndroid Build Coastguard Worker        self.assertFalse(
1549*523fa7a6SAndroid Build Coastguard Worker            any(
1550*523fa7a6SAndroid Build Coastguard Worker                "slice" in str(node.target)
1551*523fa7a6SAndroid Build Coastguard Worker                for node in edge.exported_program().graph_module.graph.nodes
1552*523fa7a6SAndroid Build Coastguard Worker            )
1553*523fa7a6SAndroid Build Coastguard Worker        )
1554*523fa7a6SAndroid Build Coastguard Worker
1555*523fa7a6SAndroid Build Coastguard Worker    def test_normalize_view_copy_base_pass(self) -> None:
1556*523fa7a6SAndroid Build Coastguard Worker
1557*523fa7a6SAndroid Build Coastguard Worker        class ViewChain(torch.nn.Module):
1558*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
1559*523fa7a6SAndroid Build Coastguard Worker                x = torch.ops.aten.view_copy.default(x, [30, 1])
1560*523fa7a6SAndroid Build Coastguard Worker                x = torch.ops.aten.view_copy.default(x, [5, 6])
1561*523fa7a6SAndroid Build Coastguard Worker                x = torch.ops.aten.view_copy.default(x, [2, 15])
1562*523fa7a6SAndroid Build Coastguard Worker                x = torch.ops.aten.view_copy.default(x, [3, -1])
1563*523fa7a6SAndroid Build Coastguard Worker                return x
1564*523fa7a6SAndroid Build Coastguard Worker
1565*523fa7a6SAndroid Build Coastguard Worker        def is_view_copy(node: torch.fx.Node) -> bool:
1566*523fa7a6SAndroid Build Coastguard Worker            return (
1567*523fa7a6SAndroid Build Coastguard Worker                node.op == "call_function"
1568*523fa7a6SAndroid Build Coastguard Worker                and node.target == torch.ops.aten.view_copy.default
1569*523fa7a6SAndroid Build Coastguard Worker            )
1570*523fa7a6SAndroid Build Coastguard Worker
1571*523fa7a6SAndroid Build Coastguard Worker        gm = export(ViewChain(), (torch.ones(30),)).graph_module
1572*523fa7a6SAndroid Build Coastguard Worker
1573*523fa7a6SAndroid Build Coastguard Worker        # Check before transformation
1574*523fa7a6SAndroid Build Coastguard Worker        n_view_copy_before = 0
1575*523fa7a6SAndroid Build Coastguard Worker        n_view_copy_bases_before = 0
1576*523fa7a6SAndroid Build Coastguard Worker        for node in gm.graph.nodes:
1577*523fa7a6SAndroid Build Coastguard Worker            if is_view_copy(node):
1578*523fa7a6SAndroid Build Coastguard Worker                n_view_copy_before += 1
1579*523fa7a6SAndroid Build Coastguard Worker                base = node.args[0]
1580*523fa7a6SAndroid Build Coastguard Worker                if is_view_copy(base):
1581*523fa7a6SAndroid Build Coastguard Worker                    n_view_copy_bases_before += 1
1582*523fa7a6SAndroid Build Coastguard Worker
1583*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(n_view_copy_before, 4)
1584*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(n_view_copy_bases_before, 3)
1585*523fa7a6SAndroid Build Coastguard Worker
1586*523fa7a6SAndroid Build Coastguard Worker        # Do transformation
1587*523fa7a6SAndroid Build Coastguard Worker        p = NormalizeViewCopyBasePass()
1588*523fa7a6SAndroid Build Coastguard Worker        gm_res = p(gm)
1589*523fa7a6SAndroid Build Coastguard Worker        assert gm_res is not None
1590*523fa7a6SAndroid Build Coastguard Worker        gm = gm_res.graph_module
1591*523fa7a6SAndroid Build Coastguard Worker
1592*523fa7a6SAndroid Build Coastguard Worker        # Check after transformation
1593*523fa7a6SAndroid Build Coastguard Worker        n_view_copy_after = 0
1594*523fa7a6SAndroid Build Coastguard Worker        n_view_copy_bases_after = 0
1595*523fa7a6SAndroid Build Coastguard Worker        for node in gm.graph.nodes:
1596*523fa7a6SAndroid Build Coastguard Worker            if is_view_copy(node):
1597*523fa7a6SAndroid Build Coastguard Worker                n_view_copy_after += 1
1598*523fa7a6SAndroid Build Coastguard Worker                base = node.args[0]
1599*523fa7a6SAndroid Build Coastguard Worker                if is_view_copy(base):
1600*523fa7a6SAndroid Build Coastguard Worker                    n_view_copy_bases_after += 1
1601*523fa7a6SAndroid Build Coastguard Worker
1602*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(n_view_copy_after, 4)
1603*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(n_view_copy_bases_after, 0)
1604*523fa7a6SAndroid Build Coastguard Worker
1605*523fa7a6SAndroid Build Coastguard Worker    def test_replace_view_copy_with_view_pass(self) -> None:  # noqa: C901
1606*523fa7a6SAndroid Build Coastguard Worker
1607*523fa7a6SAndroid Build Coastguard Worker        # Helper functions
1608*523fa7a6SAndroid Build Coastguard Worker        def is_view_copy(node: torch.fx.Node) -> bool:
1609*523fa7a6SAndroid Build Coastguard Worker            return (
1610*523fa7a6SAndroid Build Coastguard Worker                node.op == "call_function"
1611*523fa7a6SAndroid Build Coastguard Worker                and node.target == torch.ops.aten.view_copy.default
1612*523fa7a6SAndroid Build Coastguard Worker            )
1613*523fa7a6SAndroid Build Coastguard Worker
1614*523fa7a6SAndroid Build Coastguard Worker        def is_memory_view(node: torch.fx.Node) -> bool:
1615*523fa7a6SAndroid Build Coastguard Worker            return node.op == "call_function" and node.target == memory.view
1616*523fa7a6SAndroid Build Coastguard Worker
1617*523fa7a6SAndroid Build Coastguard Worker        # Test example set up
1618*523fa7a6SAndroid Build Coastguard Worker        class TestViewCopies(torch.nn.Module):
1619*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
1620*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
1621*523fa7a6SAndroid Build Coastguard Worker                self.parameter = torch.nn.Parameter(torch.ones(1))
1622*523fa7a6SAndroid Build Coastguard Worker
1623*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
1624*523fa7a6SAndroid Build Coastguard Worker                o1 = torch.ops.aten.view_copy.default(x, [1])
1625*523fa7a6SAndroid Build Coastguard Worker                o2 = torch.ops.aten.view_copy.default(self.parameter, [1])
1626*523fa7a6SAndroid Build Coastguard Worker                # view_copys at the end of a function are not replaced, so add
1627*523fa7a6SAndroid Build Coastguard Worker                # a computation before the end of the graph.
1628*523fa7a6SAndroid Build Coastguard Worker                return torch.ops.aten.add.Tensor(o1, o2)
1629*523fa7a6SAndroid Build Coastguard Worker
1630*523fa7a6SAndroid Build Coastguard Worker        ep = torch.export.export(
1631*523fa7a6SAndroid Build Coastguard Worker            TestViewCopies(),
1632*523fa7a6SAndroid Build Coastguard Worker            args=(torch.ones(1),),
1633*523fa7a6SAndroid Build Coastguard Worker        )
1634*523fa7a6SAndroid Build Coastguard Worker        for node in ep.graph.nodes:
1635*523fa7a6SAndroid Build Coastguard Worker            if node.op == "placeholder":
1636*523fa7a6SAndroid Build Coastguard Worker                node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1))
1637*523fa7a6SAndroid Build Coastguard Worker                node.meta["spec"].shape_dynamism = TensorShapeDynamism.STATIC
1638*523fa7a6SAndroid Build Coastguard Worker
1639*523fa7a6SAndroid Build Coastguard Worker        # Run tests
1640*523fa7a6SAndroid Build Coastguard Worker        gm = ep.graph_module
1641*523fa7a6SAndroid Build Coastguard Worker
1642*523fa7a6SAndroid Build Coastguard Worker        # Check before transformation
1643*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check_count(
1644*523fa7a6SAndroid Build Coastguard Worker            "torch.ops.aten.view_copy.default", 2, exactly=True
1645*523fa7a6SAndroid Build Coastguard Worker        ).run(gm.code)
1646*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check_count("executorch_exir_memory_view", 0, exactly=True).run(
1647*523fa7a6SAndroid Build Coastguard Worker            gm.code
1648*523fa7a6SAndroid Build Coastguard Worker        )
1649*523fa7a6SAndroid Build Coastguard Worker
1650*523fa7a6SAndroid Build Coastguard Worker        # Do transformation
1651*523fa7a6SAndroid Build Coastguard Worker        p = ReplaceViewCopyWithViewPass()
1652*523fa7a6SAndroid Build Coastguard Worker        gm_res = p(gm)
1653*523fa7a6SAndroid Build Coastguard Worker        assert gm_res is not None
1654*523fa7a6SAndroid Build Coastguard Worker        gm = gm_res.graph_module
1655*523fa7a6SAndroid Build Coastguard Worker
1656*523fa7a6SAndroid Build Coastguard Worker        # Check after transformation
1657*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check_count(
1658*523fa7a6SAndroid Build Coastguard Worker            "torch.ops.aten.view_copy.default", 0, exactly=True
1659*523fa7a6SAndroid Build Coastguard Worker        ).run(gm.code)
1660*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check_count("executorch_exir_memory_view", 2, exactly=True).run(
1661*523fa7a6SAndroid Build Coastguard Worker            gm.code
1662*523fa7a6SAndroid Build Coastguard Worker        )
1663*523fa7a6SAndroid Build Coastguard Worker
1664*523fa7a6SAndroid Build Coastguard Worker    def test_constant_prop_pass_for_no_grad(self) -> None:
1665*523fa7a6SAndroid Build Coastguard Worker        class LSTM(torch.nn.Module):
1666*523fa7a6SAndroid Build Coastguard Worker            def __init__(self, input_size, hidden_size, num_layers):
1667*523fa7a6SAndroid Build Coastguard Worker                super(LSTM, self).__init__()
1668*523fa7a6SAndroid Build Coastguard Worker                self.hidden_size = hidden_size
1669*523fa7a6SAndroid Build Coastguard Worker                self.num_layers = num_layers
1670*523fa7a6SAndroid Build Coastguard Worker                self.lstm = torch.nn.LSTM(
1671*523fa7a6SAndroid Build Coastguard Worker                    input_size, hidden_size, num_layers, batch_first=True
1672*523fa7a6SAndroid Build Coastguard Worker                )
1673*523fa7a6SAndroid Build Coastguard Worker
1674*523fa7a6SAndroid Build Coastguard Worker            def forward(self, text_tokens):
1675*523fa7a6SAndroid Build Coastguard Worker                # input: (seq_len, batch, input_size)
1676*523fa7a6SAndroid Build Coastguard Worker                lstm_out, (new_hidden_state, new_cell_state) = self.lstm(
1677*523fa7a6SAndroid Build Coastguard Worker                    input=text_tokens, hx=None
1678*523fa7a6SAndroid Build Coastguard Worker                )
1679*523fa7a6SAndroid Build Coastguard Worker                return lstm_out
1680*523fa7a6SAndroid Build Coastguard Worker
1681*523fa7a6SAndroid Build Coastguard Worker        lstm = LSTM(input_size=200, hidden_size=203, num_layers=2)
1682*523fa7a6SAndroid Build Coastguard Worker        example_input = (torch.rand(2, 10, 200),)
1683*523fa7a6SAndroid Build Coastguard Worker
1684*523fa7a6SAndroid Build Coastguard Worker        aten = torch.export.export(lstm, example_input, strict=False)
1685*523fa7a6SAndroid Build Coastguard Worker        _EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
1686*523fa7a6SAndroid Build Coastguard Worker            _check_ir_validity=True,
1687*523fa7a6SAndroid Build Coastguard Worker            _skip_dim_order=True,  # TODO(T189114319): Reuse dim order op after solving the ios oss issue
1688*523fa7a6SAndroid Build Coastguard Worker        )
1689*523fa7a6SAndroid Build Coastguard Worker
1690*523fa7a6SAndroid Build Coastguard Worker        edge_manager: EdgeProgramManager = to_edge(
1691*523fa7a6SAndroid Build Coastguard Worker            aten,
1692*523fa7a6SAndroid Build Coastguard Worker            compile_config=_EDGE_COMPILE_CONFIG,
1693*523fa7a6SAndroid Build Coastguard Worker        )
1694*523fa7a6SAndroid Build Coastguard Worker        new_ep = constant_prop_pass(edge_manager._edge_programs["forward"])
1695*523fa7a6SAndroid Build Coastguard Worker        _ = copy.deepcopy(new_ep.module_call_graph)
1696*523fa7a6SAndroid Build Coastguard Worker
1697*523fa7a6SAndroid Build Coastguard Worker    def test_dim_order_revert_pass(self) -> None:
1698*523fa7a6SAndroid Build Coastguard Worker        aten_op_str = "torch.ops.aten._to_copy.default"
1699*523fa7a6SAndroid Build Coastguard Worker        edge_aten_op_str = "executorch_exir_dialects_edge__ops_aten__to_copy_default"
1700*523fa7a6SAndroid Build Coastguard Worker        edge_dim_order_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
1701*523fa7a6SAndroid Build Coastguard Worker
1702*523fa7a6SAndroid Build Coastguard Worker        class Module(torch.nn.Module):
1703*523fa7a6SAndroid Build Coastguard Worker            """
1704*523fa7a6SAndroid Build Coastguard Worker            A simple module that has a single to op that converts to channels last and then back to contiguous.
1705*523fa7a6SAndroid Build Coastguard Worker            Assuming contiguous input.
1706*523fa7a6SAndroid Build Coastguard Worker            """
1707*523fa7a6SAndroid Build Coastguard Worker
1708*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
1709*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
1710*523fa7a6SAndroid Build Coastguard Worker
1711*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
1712*523fa7a6SAndroid Build Coastguard Worker                return x.to(memory_format=torch.channels_last).to(
1713*523fa7a6SAndroid Build Coastguard Worker                    memory_format=torch.contiguous_format
1714*523fa7a6SAndroid Build Coastguard Worker                ) + x.to(memory_format=torch.channels_last).to(
1715*523fa7a6SAndroid Build Coastguard Worker                    memory_format=torch.contiguous_format
1716*523fa7a6SAndroid Build Coastguard Worker                )
1717*523fa7a6SAndroid Build Coastguard Worker
1718*523fa7a6SAndroid Build Coastguard Worker            @staticmethod
1719*523fa7a6SAndroid Build Coastguard Worker            def to_copy_count():
1720*523fa7a6SAndroid Build Coastguard Worker                return 4
1721*523fa7a6SAndroid Build Coastguard Worker
1722*523fa7a6SAndroid Build Coastguard Worker        def _do_checks(
1723*523fa7a6SAndroid Build Coastguard Worker            test_str: str, allowed: str, allowed_count: int, not_allowed_list: List[str]
1724*523fa7a6SAndroid Build Coastguard Worker        ) -> None:
1725*523fa7a6SAndroid Build Coastguard Worker            for not_allowed in not_allowed_list:
1726*523fa7a6SAndroid Build Coastguard Worker                FileCheck().check_count(allowed, allowed_count, exactly=True).check_not(
1727*523fa7a6SAndroid Build Coastguard Worker                    not_allowed
1728*523fa7a6SAndroid Build Coastguard Worker                ).run(test_str)
1729*523fa7a6SAndroid Build Coastguard Worker
1730*523fa7a6SAndroid Build Coastguard Worker        m = Module()
1731*523fa7a6SAndroid Build Coastguard Worker        n = m.to_copy_count()
1732*523fa7a6SAndroid Build Coastguard Worker        input = torch.randn([2, 3, 4, 5]).to(memory_format=torch.contiguous_format)
1733*523fa7a6SAndroid Build Coastguard Worker
1734*523fa7a6SAndroid Build Coastguard Worker        # 1. vanilla export, no edge ops
1735*523fa7a6SAndroid Build Coastguard Worker        ep = export(
1736*523fa7a6SAndroid Build Coastguard Worker            m,
1737*523fa7a6SAndroid Build Coastguard Worker            (input,),
1738*523fa7a6SAndroid Build Coastguard Worker        ).run_decompositions({})
1739*523fa7a6SAndroid Build Coastguard Worker        _do_checks(
1740*523fa7a6SAndroid Build Coastguard Worker            ep.graph_module.code,
1741*523fa7a6SAndroid Build Coastguard Worker            aten_op_str,
1742*523fa7a6SAndroid Build Coastguard Worker            n,
1743*523fa7a6SAndroid Build Coastguard Worker            [edge_aten_op_str, edge_dim_order_op_str],
1744*523fa7a6SAndroid Build Coastguard Worker        )
1745*523fa7a6SAndroid Build Coastguard Worker
1746*523fa7a6SAndroid Build Coastguard Worker        # 2a. to edge without dim orders, we should see edge aten ops but not dim order ops
1747*523fa7a6SAndroid Build Coastguard Worker        edge_prog = to_edge(
1748*523fa7a6SAndroid Build Coastguard Worker            ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=True)
1749*523fa7a6SAndroid Build Coastguard Worker        )._edge_programs["forward"]
1750*523fa7a6SAndroid Build Coastguard Worker        _do_checks(
1751*523fa7a6SAndroid Build Coastguard Worker            edge_prog.graph_module.code,
1752*523fa7a6SAndroid Build Coastguard Worker            edge_aten_op_str,
1753*523fa7a6SAndroid Build Coastguard Worker            n,
1754*523fa7a6SAndroid Build Coastguard Worker            [aten_op_str, edge_dim_order_op_str],
1755*523fa7a6SAndroid Build Coastguard Worker        )
1756*523fa7a6SAndroid Build Coastguard Worker
1757*523fa7a6SAndroid Build Coastguard Worker        # 3a. expect no change after the pass, we should see edge aten ops but not dim order ops
1758*523fa7a6SAndroid Build Coastguard Worker        new_res = DimOrderOpsRevertPass()(edge_prog.graph_module)
1759*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_res)
1760*523fa7a6SAndroid Build Coastguard Worker        _do_checks(
1761*523fa7a6SAndroid Build Coastguard Worker            new_res.graph_module.code,
1762*523fa7a6SAndroid Build Coastguard Worker            edge_aten_op_str,
1763*523fa7a6SAndroid Build Coastguard Worker            n,
1764*523fa7a6SAndroid Build Coastguard Worker            [aten_op_str, edge_dim_order_op_str],
1765*523fa7a6SAndroid Build Coastguard Worker        )
1766*523fa7a6SAndroid Build Coastguard Worker
1767*523fa7a6SAndroid Build Coastguard Worker        # 2b. let's try with dim order enabled, we should see edge dim order ops but not edge aten ops
1768*523fa7a6SAndroid Build Coastguard Worker        edge_prog_dim_order = to_edge(
1769*523fa7a6SAndroid Build Coastguard Worker            ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=False)
1770*523fa7a6SAndroid Build Coastguard Worker        )._edge_programs["forward"]
1771*523fa7a6SAndroid Build Coastguard Worker        _do_checks(
1772*523fa7a6SAndroid Build Coastguard Worker            edge_prog_dim_order.graph_module.code,
1773*523fa7a6SAndroid Build Coastguard Worker            edge_dim_order_op_str,
1774*523fa7a6SAndroid Build Coastguard Worker            n,
1775*523fa7a6SAndroid Build Coastguard Worker            [aten_op_str, edge_aten_op_str],
1776*523fa7a6SAndroid Build Coastguard Worker        )
1777*523fa7a6SAndroid Build Coastguard Worker
1778*523fa7a6SAndroid Build Coastguard Worker        # 3b. expect edge aten ops after the pass, we should see not see the edge dim order ops
1779*523fa7a6SAndroid Build Coastguard Worker        new_res_dim_order = DimOrderOpsRevertPass()(edge_prog_dim_order.graph_module)
1780*523fa7a6SAndroid Build Coastguard Worker        self.assertIsNotNone(new_res_dim_order)
1781*523fa7a6SAndroid Build Coastguard Worker        _do_checks(
1782*523fa7a6SAndroid Build Coastguard Worker            new_res_dim_order.graph_module.code,
1783*523fa7a6SAndroid Build Coastguard Worker            edge_aten_op_str,
1784*523fa7a6SAndroid Build Coastguard Worker            n,
1785*523fa7a6SAndroid Build Coastguard Worker            [aten_op_str, edge_dim_order_op_str],
1786*523fa7a6SAndroid Build Coastguard Worker        )
1787*523fa7a6SAndroid Build Coastguard Worker
1788*523fa7a6SAndroid Build Coastguard Worker        output_no_dim_order = new_res.graph_module(input)
1789*523fa7a6SAndroid Build Coastguard Worker        output_no_dim_order_revert = new_res_dim_order.graph_module(input)
1790*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(
1791*523fa7a6SAndroid Build Coastguard Worker            torch.allclose(output_no_dim_order[0], output_no_dim_order_revert[0])
1792*523fa7a6SAndroid Build Coastguard Worker        )
1793