xref: /aosp_15_r20/external/executorch/test/end2end/test_end2end.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# flake8: noqa: F401
8*523fa7a6SAndroid Build Coastguard Workerimport functools
9*523fa7a6SAndroid Build Coastguard Workerimport inspect
10*523fa7a6SAndroid Build Coastguard Workerimport os
11*523fa7a6SAndroid Build Coastguard Workerimport random
12*523fa7a6SAndroid Build Coastguard Workerimport unittest
13*523fa7a6SAndroid Build Coastguard Workerfrom typing import Callable, Dict, Optional, Tuple, Type
14*523fa7a6SAndroid Build Coastguard Workerfrom unittest import skip, skipUnless
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir as exir
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir.control_flow as control_flow
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Worker# @manual=//executorch/extension/pytree:pybindings
21*523fa7a6SAndroid Build Coastguard Workerimport executorch.extension.pytree as pytree
22*523fa7a6SAndroid Build Coastguard Workerimport torch
23*523fa7a6SAndroid Build Coastguard Worker
24*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import (
25*523fa7a6SAndroid Build Coastguard Worker    CaptureConfig,
26*523fa7a6SAndroid Build Coastguard Worker    EdgeCompileConfig,
27*523fa7a6SAndroid Build Coastguard Worker    ExecutorchBackendConfig,
28*523fa7a6SAndroid Build Coastguard Worker    memory,
29*523fa7a6SAndroid Build Coastguard Worker)
30*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
31*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.emit import emit_program
32*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.pass_manager import PassManager
33*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes import (
34*523fa7a6SAndroid Build Coastguard Worker    DebugPass,
35*523fa7a6SAndroid Build Coastguard Worker    MemoryPlanningPass,
36*523fa7a6SAndroid Build Coastguard Worker    to_scratch_op_pass,
37*523fa7a6SAndroid Build Coastguard Worker    ToOutVarPass,
38*523fa7a6SAndroid Build Coastguard Worker)
39*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.print_program import pretty_print, print_program
40*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tensor import make_tensor_value, TensorSpec
41*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tests.control_flow_models import (
42*523fa7a6SAndroid Build Coastguard Worker    FTCondBasic,
43*523fa7a6SAndroid Build Coastguard Worker    FTCondDynShape,
44*523fa7a6SAndroid Build Coastguard Worker    FTMapBasic,
45*523fa7a6SAndroid Build Coastguard Worker    FTMapDynShape,
46*523fa7a6SAndroid Build Coastguard Worker)
47*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tests.dynamic_shape_models import BatchNormModel
48*523fa7a6SAndroid Build Coastguard Worker
49*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tests.transformer import Transformer
50*523fa7a6SAndroid Build Coastguard Workerfrom functorch.experimental.control_flow import cond
51*523fa7a6SAndroid Build Coastguard Worker
# Detect which kernel flavor of the executorch pybindings is available.
# Exactly one of the two extension libraries is expected to be importable:
# "lean" (portable_lib) or "aten" (aten_lib).
kernel_mode = None  # either aten mode or lean mode
try:
    from executorch.extension.pybindings.portable_lib import (
        _load_bundled_program_from_buffer,
        _load_for_executorch_from_buffer,
        _load_for_executorch_from_bundled_program,
    )

    kernel_mode = "lean"
except ImportError as e:
    # Best-effort: the portable (lean) pybindings may not be built.
    print(e)
    pass

try:
    from executorch.extension.pybindings.aten_lib import (
        _load_bundled_program_from_buffer,
        _load_for_executorch_from_buffer,
        _load_for_executorch_from_bundled_program,
    )

    # Both lean and aten libs being importable would be ambiguous; the
    # loader symbols above would silently shadow the lean ones.
    assert kernel_mode is None
    kernel_mode = "aten"
except ImportError as e:
    # Best-effort: the aten pybindings may not be built.
    print(e)
    pass

# At least one of the two pybindings variants must have imported.
assert kernel_mode is not None

# Convenience flags used by tests to branch on the kernel flavor.
is_aten_mode = kernel_mode == "aten"
is_lean_mode = kernel_mode == "lean"
82*523fa7a6SAndroid Build Coastguard Worker
83*523fa7a6SAndroid Build Coastguard Workerfrom torch import nn
84*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils import _pytree as torch_pytree
85*523fa7a6SAndroid Build Coastguard Worker
86*523fa7a6SAndroid Build Coastguard Workerfrom .exported_module import ExportedModule
87*523fa7a6SAndroid Build Coastguard Worker
88*523fa7a6SAndroid Build Coastguard Worker
# Set the RUN_SKIPPED environment variable to a non-zero integer to force
# tests that are normally skipped to run anyway.
RUN_SKIPPED = int(os.environ.get("RUN_SKIPPED", "0"))
90*523fa7a6SAndroid Build Coastguard Worker
91*523fa7a6SAndroid Build Coastguard Worker
class ModuleBasic(nn.Module):
    """Minimal test model: returns the maximum of ``sin(x)`` as a scalar."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Reduce to a single scalar tensor output.
        return torch.sin(x).max()

    def get_random_inputs(self):
        # A single flat tensor of 100 random elements.
        return (torch.randn(100),)
101*523fa7a6SAndroid Build Coastguard Worker
102*523fa7a6SAndroid Build Coastguard Worker
class ModuleOpsReturnMulti(nn.Module):
    """Exercises an op (``topk``) that returns multiple tensors; only the
    values output is consumed by the rest of the graph."""

    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        values, _indices = torch.topk(a, 3)
        return values * 2 + b

    def get_random_inputs(self):
        return (torch.randn(10), torch.randn(3))
113*523fa7a6SAndroid Build Coastguard Worker
114*523fa7a6SAndroid Build Coastguard Worker
class ModuleAdd(nn.Module):
    """Single-op model: elementwise add of two same-shaped tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # One aten.add with the default alpha of 1.
        return torch.add(x, y)

    def get_random_inputs(self):
        shape = (2, 2)
        return (torch.randn(shape), torch.randn(shape))
124*523fa7a6SAndroid Build Coastguard Worker
125*523fa7a6SAndroid Build Coastguard Worker
class ModuleFloatAddWithAlpha(nn.Module):
    """Add with a float ``alpha`` scalar argument: computes ``x + c * y``."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, c: float):
        return torch.add(x, y, alpha=c)

    def get_random_inputs(self):
        shape = (2, 2)
        return (torch.randn(shape), torch.randn(shape), random.random())
135*523fa7a6SAndroid Build Coastguard Worker
136*523fa7a6SAndroid Build Coastguard Worker
class ModuleIntAddWithAlpha(nn.Module):
    """Add with an int ``alpha`` scalar argument: computes ``x + c * y`` on
    integer tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, c: int):
        return torch.add(x, y, alpha=c)

    def get_random_inputs(self):
        shape = (2, 2)
        return (
            torch.randint(0, 10, shape),
            torch.randint(0, 10, shape),
            random.randint(0, 10),
        )
150*523fa7a6SAndroid Build Coastguard Worker
151*523fa7a6SAndroid Build Coastguard Worker
class ModuleContainers(nn.Module):
    """Takes a dict input and returns a dict with nested containers, to
    exercise pytree flattening of inputs and outputs."""

    def __init__(self):
        super().__init__()

    def forward(self, d):
        a, b = d["a"], d["b"]
        return {"inputs": (a, b), "c": torch.add(a, b)}

    def get_random_inputs(self):
        return ({"a": torch.randn(2, 2), "b": torch.randn(2, 2)},)
163*523fa7a6SAndroid Build Coastguard Worker
164*523fa7a6SAndroid Build Coastguard Worker
class ToyModelForMemPlanning(nn.Module):
    """Chain of multiply/add ops producing several intermediate tensors, so
    the memory planner has something to schedule."""

    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        out = a
        # Three rounds of (out * a) + b.
        for _ in range(3):
            out = out * a
            out = out + b
        return out

    def get_random_inputs(self):
        return (torch.randn(10), torch.randn(10))
181*523fa7a6SAndroid Build Coastguard Worker
182*523fa7a6SAndroid Build Coastguard Worker
class MemPlanningWithScratchTensor(nn.Module):
    """Two linear layers whose outputs are summed; linear ops require scratch
    tensors, which the memory planner must account for."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 2)
        self.linear2 = nn.Linear(4, 2)

    def forward(self, a, b):
        left = self.linear1(a)
        right = self.linear2(b)
        return left + right

    def get_random_inputs(self):
        shape = (10, 4)
        return (torch.randn(shape), torch.randn(shape))
199*523fa7a6SAndroid Build Coastguard Worker
200*523fa7a6SAndroid Build Coastguard Worker
class ModuleOpsReturnTensorList(nn.Module):
    """Exercises an op (``tensor_split``) whose output is a tensor list; only
    the first element is returned."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        chunks = torch.ops.aten.tensor_split.sections(x, 3)
        return chunks[0]

    def get_random_inputs(self):
        return (torch.randn(100),)
211*523fa7a6SAndroid Build Coastguard Worker
212*523fa7a6SAndroid Build Coastguard Worker
class ModuleReturnInput(nn.Module):
    """Returns its input several times inside nested containers, to exercise
    passthrough outputs and output pytree handling."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return (x, x, {"x": x, "y": x}, [x, x, x])

    def get_random_inputs(self):
        return (torch.randn(1),)
222*523fa7a6SAndroid Build Coastguard Worker
223*523fa7a6SAndroid Build Coastguard Worker
class ModuleIfElse(nn.Module):
    """Data-dependent control flow via ``cond``: squares the input, adds it
    to itself 3 (true branch) or 4 (false branch) times, then squares the
    result."""

    def __init__(self):
        super().__init__()

    def forward(self, c, x):
        x = x * x

        def _add_n_times(t, n):
            # Sum n copies of t using only tensor-tensor adds.
            acc = t
            for _ in range(n - 1):
                acc = acc + t
            return acc

        def _true_branch(c, x):
            return _add_n_times(x, 3)

        def _false_branch(c, x):
            return _add_n_times(x, 4)

        y = cond(c, _true_branch, _false_branch, (c, x))
        return y * y

    def get_random_inputs(self):
        pred = torch.randint(2, [1]) == 0
        return (pred, torch.randn(10))
248*523fa7a6SAndroid Build Coastguard Worker
249*523fa7a6SAndroid Build Coastguard Worker
class ModuleIfElseWithBoolInput(nn.Module):
    """Same control flow as ``ModuleIfElse`` but the predicate is a plain
    Python bool instead of a tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, c: bool, x: torch.Tensor):
        x = x * x

        def _add_n_times(t, n):
            # Sum n copies of t using only tensor-tensor adds.
            acc = t
            for _ in range(n - 1):
                acc = acc + t
            return acc

        def _true_branch(c, x):
            return _add_n_times(x, 3)

        def _false_branch(c, x):
            return _add_n_times(x, 4)

        y = cond(c, _true_branch, _false_branch, (c, x))
        return y * y

    def get_random_inputs(self):
        pred = random.randint(0, 1) == 0
        return (pred, torch.randn(10))
275*523fa7a6SAndroid Build Coastguard Worker
276*523fa7a6SAndroid Build Coastguard Worker
class ModuleWhileIf(nn.Module):
    """A while loop whose body contains a cond: sums cnt down to zero into
    accum. The nested ``tracing_context`` decorators supply example inputs
    for tracing each sub-graph."""

    def __init__(self):
        super().__init__()

    def forward(self, accum, cnt):
        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def loop_cond(accum, cnt):
            # Keep looping until the counter reaches zero.
            return cnt != torch.zeros([1]).to(dtype=torch.long)

        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def loop_body(accum, cnt):
            # return accum + cnt, cnt - torch.ones([1]).to(dtype=torch.long)
            @control_flow.tracing_context(
                inputs=(torch.zeros([1]).to(dtype=torch.long),)
            )
            def true_branch(cnt):
                return cnt

            @control_flow.tracing_context(
                inputs=(torch.zeros([1]).to(dtype=torch.long),)
            )
            def false_branch(cnt):
                return torch.zeros([1], dtype=torch.long)

            # The predicate is constant-true, so this always adds cnt; the
            # cond exists to exercise if-inside-while at runtime.
            accum = accum + cond(
                torch.BoolTensor([True]), true_branch, false_branch, (cnt,)
            )
            # 'cnt - 1' does not work yet since the runtime does not expect
            # tensor to be mixed with scalar for sub op.
            return accum, cnt - torch.ones([1]).to(dtype=torch.long)

        y, _ = control_flow.while_loop(
            loop_cond,
            loop_body,
            (accum, cnt),
        )
        return y

    def get_random_inputs(self):
        # (accumulator starting at 0, random loop count in [10, 100)).
        return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
321*523fa7a6SAndroid Build Coastguard Worker
322*523fa7a6SAndroid Build Coastguard Worker
class ModuleIfWhile(nn.Module):
    """A cond whose true branch contains a while loop (the reverse nesting of
    ``ModuleWhileIf``). The ``tracing_context`` decorators supply example
    inputs for tracing each sub-graph."""

    def __init__(self):
        super().__init__()

    def forward(self, accum, cnt):
        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def true_branch(accum, cnt):
            @control_flow.tracing_context(
                inputs=(
                    torch.zeros([1]).to(dtype=torch.long),
                    torch.randint(10, 100, [1]),
                )
            )
            def loop_cond(accum, cnt):
                # Keep looping until the counter reaches zero.
                return cnt != torch.zeros([1]).to(dtype=torch.long)

            @control_flow.tracing_context(
                inputs=(
                    torch.zeros([1]).to(dtype=torch.long),
                    torch.randint(10, 100, [1]),
                )
            )
            def loop_body(accum, cnt):
                return accum + cnt, cnt - torch.ones([1]).to(dtype=torch.long)

            return control_flow.while_loop(loop_cond, loop_body, (accum, cnt))

        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def false_branch(accum, cnt):
            # Identity branch: pass both carried values through unchanged.
            return accum, cnt

        # Constant-true predicate; both branches must still be traceable.
        # Only the accumulator (element 0 of the carried pair) is returned.
        return cond(torch.BoolTensor([True]), true_branch, false_branch, (accum, cnt))[
            0
        ]

    def get_random_inputs(self):
        # (accumulator starting at 0, random loop count in [10, 100)).
        return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
364*523fa7a6SAndroid Build Coastguard Worker
365*523fa7a6SAndroid Build Coastguard Worker
class ModuleContiguousTensor(nn.Module):
    """A single linear layer; its weight layout exercises the contiguity
    requirements of emitted tensors."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 32)

    def forward(self, arg):
        return self.linear(arg)

    def get_random_inputs(self):
        return (torch.randn(3, 8),)
376*523fa7a6SAndroid Build Coastguard Worker
377*523fa7a6SAndroid Build Coastguard Worker
class ModuleInputDynamicShape(nn.Module):
    """Model whose input length varies between calls; the upper-bound input
    defines the largest shape the memory planner must accommodate."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Four rounds of doubling then squaring.
        for _ in range(4):
            doubled = x + x
            x = doubled * doubled
        return x

    def get_upper_bound_inputs(self):
        # Largest input the model will ever see (length 10).
        return (torch.randn(10),)

    def get_random_inputs(self):
        length = random.randint(1, 10)
        return (torch.randn(length),)
394*523fa7a6SAndroid Build Coastguard Worker
395*523fa7a6SAndroid Build Coastguard Worker
class ModuleIntermediateDynamicShape(nn.Module):
    """Model with a dynamically shaped intermediate: ``nonzero`` produces a
    tensor whose first dimension depends on the input values."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x * x
        # We should use x[torch.nonzero(x)] ideally, but index op is not
        # supported in the runtime so far.
        indices = torch.nonzero(x)
        return indices + indices

    def get_random_inputs(self):
        # Roughly half the entries are zero, so nonzero's output size varies.
        return (torch.randint(0, 2, (10,), dtype=torch.float),)
410*523fa7a6SAndroid Build Coastguard Worker
411*523fa7a6SAndroid Build Coastguard Worker
def allclose(lhs, rhs, rtol=1e-5, atol=1e-8):
    r"""
    Structure-aware version of ``torch.allclose``: besides Tensor arguments it
    also handles lists, tuples, dicts, and arbitrary nesting of these.

    Returns True iff both structures match shape-wise and all corresponding
    tensors are elementwise close within ``rtol``/``atol``.

    Raises:
        RuntimeError: if the two arguments have mismatched or unsupported types.
    """
    if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
        return torch.allclose(lhs, rhs, rtol, atol)
    if isinstance(lhs, (tuple, list)) and isinstance(rhs, (tuple, list)):
        if len(lhs) != len(rhs):
            return False
        return all(allclose(a, b, rtol, atol) for a, b in zip(lhs, rhs))
    if isinstance(lhs, dict) and isinstance(rhs, dict):
        # Both dicts must have exactly the same key set.
        if set(lhs.keys()) != set(rhs.keys()):
            return False
        return all(allclose(v, rhs[k], rtol, atol) for k, v in lhs.items())
    raise RuntimeError(
        f"Unexpected types: lhs type {type(lhs)}, rhs type {type(rhs)}"
    )
433*523fa7a6SAndroid Build Coastguard Worker
434*523fa7a6SAndroid Build Coastguard Worker
def validate_contiguous_tensors(program):
    """Assert that every tensor value in the program is contiguous
    (torch.memory_format=torch.contiguous), since the runtime can not handle
    non-contiguous tensors so far."""

    def _has_identity_dim_order(tensor: exir.schema.Tensor):
        # A contiguous tensor's dim_order is the identity permutation
        # (0, 1, ..., ndim-1).
        assert len(tensor.sizes) == len(tensor.dim_order)
        return all(pos == dim for pos, dim in enumerate(tensor.dim_order))

    for execution_plan in program.execution_plan:
        for value in execution_plan.values:
            if isinstance(value.val, exir.schema.Tensor):
                assert _has_identity_dim_order(
                    value.val
                ), f"Non-contiguous tensor found: size {value.val.sizes} stride {value.val.strides}. constant_buffer_idx {value.val.constant_buffer_idx}. allocation_info {value.val.allocation_info}."
455*523fa7a6SAndroid Build Coastguard Worker
456*523fa7a6SAndroid Build Coastguard Worker
class BoundMethod(object):
    """Binds a callable to a fixed instance, mimicking a bound method.

    Calling the object invokes ``callable(instance, *args, **kwargs)``.
    """

    def __init__(self, instance, callable):
        self._instance = instance
        self._callable = callable

    def __call__(self, *args, **kwargs):
        # Bug fix: the original read `self.instance`, which does not exist
        # (the attribute is stored as `self._instance`), so every call
        # raised AttributeError.
        return self._callable(self._instance, *args, **kwargs)
464*523fa7a6SAndroid Build Coastguard Worker
465*523fa7a6SAndroid Build Coastguard Worker
def maketest(
    module_cls: Type[nn.Module],
    niter: int = 10,
    run_executor: bool = True,
    do_tree_flatten: bool = False,
    run_graph_module: bool = True,
    atol: float = 1e-8,
    rtol: float = 1e-5,
    ignore_to_out_var_failure: bool = False,
    allow_non_contiguous_tensor: bool = False,
    method: str = "forward",
    dynamic_memory_planning_mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND,
    capture_config: Optional[CaptureConfig] = None,
    verify_graph: Optional[Callable] = None,
) -> Callable[[unittest.TestCase], None]:
    r"""Returns a TestCase method to test the provided module class and method.

    Args:
        module_cls: The subclass of nn.Module to export.
        niter: The number of random input data sets to test with.
        run_executor: Whether to run the model on the executor. We may want to
            skip running a model thru executor since some kernels are not
            implemented.
        do_tree_flatten: Whether to flatten input and unflatten output.
        run_graph_module: Whether to run the traced and transformed GraphModule.
            One may want to skip this if some custom ops do not have
            implementation in torch.ops but is implemented in the executor.
        atol: Absolute tolerance used in allclose and torch.allclose
        rtol: Relative tolerance used in allclose and torch.allclose
        ignore_to_out_var_failure: Whether to ignore the failure when a
            functional op does not have an out variant.
        allow_non_contiguous_tensor: If false, will validate that the emitted
            program only contains contiguous tensors.
        method: The name of the module_cls method to trace.
        dynamic_memory_planning_mode: The dynamic memory planning mode to use.
        capture_config: Optional CaptureConfig forwarded to
            ExportedModule.export (e.g. to enable dynamic shapes).
        verify_graph: Optional callback invoked as
            verify_graph(test_case, graph_module) right after export, for
            extra graph-level assertions.

    Returns:
        A TestCase method that tests the provided module class and method.
    """

    def wrapper(self: unittest.TestCase) -> None:
        """A TestCase method that traces/exports/tests an nn.Module and method."""
        module = ExportedModule.export(
            module_class=module_cls,
            # testend2end only supports modules with single methods defined
            methods=(method,),
            ignore_to_out_var_failure=ignore_to_out_var_failure,
            dynamic_memory_planning_mode=dynamic_memory_planning_mode,
            capture_config=capture_config,
        )
        if verify_graph:
            verify_graph(self, module.exported_program.graph_module)
        print(f"inputs for tracing: {module.trace_inputs}")

        # compare the result between the eager module and graph module
        inputs_list = [module.get_random_inputs() for _ in range(niter)]

        if run_graph_module:
            for inputs in inputs_list:
                with torch.no_grad():
                    # only one method is supported so just grab that single method
                    expected = getattr(module.eager_module, module.methods[0])(*inputs)
                with torch.no_grad():
                    result = module.exported_program.module()(*inputs)
                self.assertTrue(allclose(expected, result, rtol, atol))

        program = module.executorch_program.executorch_program
        # Dump the program for debugging before running any assertions on it.
        pretty_print(program)
        print_program(program, show_meminfo=True, mark_dynamic_shape_tensor=True)
        print(f"mem buffer sizes: {program.execution_plan[0].non_const_buffer_sizes}")
        if not allow_non_contiguous_tensor:
            validate_contiguous_tensors(program)
        # At least two memory buffer sizes are expected in the plan —
        # presumably one reserved slot plus mutable memory; TODO confirm
        # the exact buffer layout convention.
        self.assertTrue(len(program.execution_plan[0].non_const_buffer_sizes) >= 2)
        # We should not enable the following assertion since for some models
        # that simply returning graph input, no mutable memory should be allocated
        # self.assertTrue(all(s > 0 for s in program.program.execution_plan[0].non_const_buffer_sizes[1:]))

        # NOTE(review): version is pinned to 0 on the in-memory program here,
        # after serialization output was already produced — presumably to keep
        # later comparisons/printing version-independent; confirm intent.
        program.version = 0
        buff = module.executorch_program.buffer
        # Check that the magic version number is in the expected place, and
        # follows the expected pattern.
        self.assertRegex(buff[4:8].decode(errors="replace"), r"^ET[0-9][0-9]$")

        if run_executor:
            print("Running on the runtime")
            executorch_module = _load_for_executorch_from_buffer(buff)
            # compare the result between eager module and executor
            for idx, inputs in enumerate(inputs_list):
                with torch.no_grad():
                    expected = getattr(module.eager_module, method)(*inputs)

                if do_tree_flatten:
                    # Flatten the (possibly nested) inputs for the runtime and
                    # unflatten its outputs using the spec embedded in the program.
                    # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
                    flatten_inputs, inputs_spec = pytree.tree_flatten(*inputs)
                    executorch_result = executorch_module.forward([*flatten_inputs])
                    # pyre-fixme[16]: Module `pytree` has no attribute `TreeSpec`.
                    executorch_result_unflatten = pytree.TreeSpec.from_str(
                        program.execution_plan[0].container_meta_type.encoded_out_str
                    ).tree_unflatten(executorch_result)
                    actual = executorch_result_unflatten
                else:
                    actual = executorch_module.forward(inputs)[0]
                is_close = allclose(expected, actual, rtol, atol)
                if not is_close:
                    # Print the failing case before asserting to ease debugging.
                    print(f"Fail for {idx}th inputs: {inputs}")
                    print(f"expected result: {expected}")
                    print(f"actual result: {actual}")
                self.assertTrue(is_close)

    return wrapper
576*523fa7a6SAndroid Build Coastguard Worker
577*523fa7a6SAndroid Build Coastguard Worker
class E2ETest(unittest.TestCase):
    r"""
    End-to-end tests driven by maketest().

    To add a new case, define an nn.Module and add a single maketest()
    call if possible; maketest() handles all the export/run/compare
    boilerplate.
    """

    def test_basic(self):
        # Skip the executor run because aten::sin.out is not defined in the
        # executor currently. aten::max.default also has no out variant, so
        # ignore_to_out_var_failure must be True.
        test_fn = maketest(
            ModuleBasic, run_executor=False, ignore_to_out_var_failure=True
        )
        test_fn(self)

    def test_ops_return_multi(self):
        # Exercises ops that return multiple values (e.g. topk); at one
        # point TensorSpec setup for Fx nodes returning multiple tensors
        # was broken.
        #
        # Skip the executor run because aten::topk.values is not defined in
        # the executor currently.
        test_fn = maketest(ModuleOpsReturnMulti, run_executor=False)
        test_fn(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_mem_planning_toy_model(self):
        test_fn = maketest(
            ToyModelForMemPlanning,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
        )
        test_fn(self)

    def test_mem_planning_scratch_tensor(self):
        # TODO: add ops implementations and turn on 'run_executor'
        test_fn = maketest(
            MemPlanningWithScratchTensor,
            run_graph_module=False,
            run_executor=False,
            atol=1e-5,
        )
        test_fn(self)

    def test_executorch_forward(self):
        test_fn = maketest(ModuleAdd)
        test_fn(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_containers(self):
        test_fn = maketest(
            ModuleContainers,
            do_tree_flatten=True,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
        )
        test_fn(self)

    def test_ops_return_tensorlist(self):
        # Can not run the graph module since the out variant with a tensor
        # list out argument returns None rather than a tensor list.
        #
        # Can not run in the executor since the kernel for tensor splitting
        # is not implemented.
        test_fn = maketest(
            ModuleOpsReturnTensorList, run_graph_module=False, run_executor=False
        )
        test_fn(self)

    # Failed to produce a graph during tracing w/ dynamo because there are no torch ops
    # test_return_input = maketest(ModuleReturnInput, do_tree_flatten=True)

    # can not run this on the executor because missing the following ops:
    #   aten::select_copy.int_out, aten::eq.Scalar_out
    # TODO(zhxchen17) re-enable these tests.
    # test_control_flow_cond = maketest(ControlFlowCond, run_executor=False)
    # fail to trace with functionalization enabled
    # test_ifelse = maketest(ModuleIfElse)

    # fail to trace with functionalization enabled
    # Fail with error: Missing out variants: {'aten::select', 'aten::_shape_as_tensor', 'aten::tensor_split'}
    # TODO(zhxchen17) re-enable these tests.
    # test_while_0 = maketest(
    #     ControlFlowWhile,
    #     ignore_to_out_var_failure=True,
    #     run_executor=False,
    # )

    # test_while = maketest(ModuleWhile)

    # test_while_if = maketest(ModuleWhileIf)
    # test_if_while = maketest(ModuleIfWhile)
    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) This fails on OSS macos job")
    def test_contiguous_tensor(self):
        test_fn = maketest(ModuleContiguousTensor, run_executor=False)
        test_fn(self)
668*523fa7a6SAndroid Build Coastguard Worker
669*523fa7a6SAndroid Build Coastguard Worker
class DynamicModelE2ETest(unittest.TestCase):
    """
    End-to-end tests for dynamic models, i.e. models that involve control
    flow or dynamic tensor shapes.
    """

    @skip("Revisit when unbacked symint is ready")
    def test_intermediate_dynamic_shape(self):
        test_fn = maketest(
            ModuleIntermediateDynamicShape,
            run_graph_module=False,
            allow_non_contiguous_tensor=True,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
        )
        test_fn(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) This fails on OSS macos job")
    def test_transformer_encode(self):
        # TODO(shunting): some non constant tensors for transformer are
        # non-contiguous; ignore for now and debug more later.
        # NOTE: can not run on runtime since missing these ops: P535190636
        test_fn = maketest(
            Transformer,
            method="encode",
            allow_non_contiguous_tensor=True,
            run_executor=False,
        )
        test_fn(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_ft_cond_basic(self):
        # Basic coverage for functorch torch.ops.higher_order.cond.
        test_fn = maketest(
            FTCondBasic,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )
        test_fn(self)

    @skipUnless(RUN_SKIPPED, "Emitter is not ready yet")
    def test_ft_map_basic(self):
        test_fn = maketest(
            FTMapBasic,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )
        test_fn(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_ft_cond_dynshape(self):
        test_fn = maketest(
            FTCondDynShape,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )
        test_fn(self)

    @skipUnless(RUN_SKIPPED, "Emitter is not ready yet")
    def test_ft_map_dynshape(self):
        test_fn = maketest(
            FTMapDynShape,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )
        test_fn(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_batch_norm(self):
        test_fn = maketest(
            BatchNormModel,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
            verify_graph=BatchNormModel.verify_graph,
            # TODO: lean mode does not have native_batch_norm.out implemented;
            # run the executor only in aten mode.
            run_executor=is_aten_mode,
        )
        test_fn(self)
752