xref: /aosp_15_r20/external/pytorch/test/export/test_export.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: export"]
2# flake8: noqa
3import copy
4import dataclasses
5import io
6import logging
7import operator
8import re
9import unittest
10import warnings
11from contextlib import contextmanager
12from dataclasses import dataclass
13from re import escape
14from typing import Dict, List
15
16import torch
17import torch._dynamo as torchdynamo
18import torch.nn.functional as F
19from functorch.experimental.control_flow import cond, map
20from torch import Tensor
21from torch._decomp import get_decompositions
22from torch._dynamo.test_case import TestCase
23from torch._dynamo.testing import normalize_gm
24from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
25from torch._export.utils import (
26    get_buffer,
27    get_param,
28    is_buffer,
29    is_param,
30    register_dataclass_as_pytree_node,
31)
32from torch._higher_order_ops.hints_wrap import hints_wrapper
33from torch._inductor.compile_fx import split_const_gm
34from torch._subclasses import FakeTensorMode
35from torch.export import Dim, export, unflatten
36from torch.export._trace import (
37    _export,
38    _export_to_torch_ir,
39    DEFAULT_EXPORT_DYNAMO_CONFIG,
40)
41from torch.export.graph_signature import (
42    ExportGraphSignature,
43    InputKind,
44    OutputKind,
45    OutputSpec,
46    TensorArgument,
47)
48from torch.fx.experimental.proxy_tensor import make_fx
49from torch.fx.experimental.symbolic_shapes import ShapeEnv
50from torch.testing import FileCheck
51from torch.testing._internal.common_cuda import (
52    PLATFORM_SUPPORTS_FLASH_ATTENTION,
53    SM90OrLater,
54)
55from torch.testing._internal.common_device_type import onlyCPU, onlyCUDA
56from torch.testing._internal.common_utils import (
57    find_library_location,
58    IS_FBCODE,
59    IS_MACOS,
60    IS_SANDCASTLE,
61    IS_WINDOWS,
62    run_tests,
63    TEST_TRANSFORMERS,
64    TestCase as TorchTestCase,
65)
66from torch.utils._pytree import (
67    LeafSpec,
68    tree_flatten,
69    tree_map,
70    tree_unflatten,
71    TreeSpec,
72    treespec_dumps,
73    treespec_loads,
74)
75
76
77try:
78    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
79
80    HAS_TORCHREC = True
81except ImportError:
82    HAS_TORCHREC = False
83
84try:
85    from . import testing
86except ImportError:
87    import testing
88# The following import pattern matters as `test_export.export` is patched
89# in other files (like test_export_nonstrict.py). `torch.export.export`
90# will invalidate the patch.
91from torch.export import export
92
93
94torch.library.define("testlib::returns_tensor_symint", "(Tensor x) -> (Tensor, SymInt)")
95torch.library.define(
96    "testlib::foo",
97    "(Tensor(a!) x, Tensor(b!) z) -> (Tensor, Tensor, Tensor)",
98    tags=torch.Tag.pt2_compliant_tag,
99)
100torch.library.define(
101    "testlib::foo_mutated",
102    "(Tensor(a!) x) -> (Tensor, Tensor)",
103    tags=torch.Tag.pt2_compliant_tag,
104)
105torch.library.define(
106    "testlib::foo_functional",
107    "(Tensor x) -> (Tensor)",
108    tags=torch.Tag.pt2_compliant_tag,
109)
110torch.library.define(
111    "testlib::foo_unbacked",
112    "(Scalar x) -> (Tensor)",
113    tags=torch.Tag.pt2_compliant_tag,
114)
115
116
117@torch.library.impl("testlib::returns_tensor_symint", "cpu")
118@torch.library.impl_abstract("testlib::returns_tensor_symint")
119def returns_tensor_symint_impl(x):
120    return x, x.shape[0]
121
122
123@torch.library.impl("testlib::foo", "cpu")
124@torch._dynamo.disable
125def foo_impl(x, z):
126    x.add_(5)
127    z.add_(5)
128    return x, z, x + z
129
130
131@torch.library.impl_abstract("testlib::foo")
132def foo_abstract(x, z):
133    return x, z, x + z
134
135
136@torch.library.impl("testlib::foo_mutated", "CompositeImplicitAutograd")
137def foo_mutated(x):
138    a, b, c = torch.ops.testlib.foo(x, x.cos())
139    return a, a.cos()
140
141
142@torch.library.impl("testlib::foo_functional", "CompositeImplicitAutograd")
143def foo_functional(x):
144    a, b, c = torch.ops.testlib.foo(x.cos(), x.cos())
145    return a.cos()
146
147
148@torch.library.impl("testlib::foo_unbacked", "CompositeImplicitAutograd")
149def foo_unbacked(x):
150    if x > 2:
151        return torch.ones(4, 4)
152    if x < 6:
153        return torch.ones(4, 4)
154    return torch.ones(4, 4)
155
156
157@dataclass
158class Inp:
159    x: Tensor
160    y: List[Tensor]
161    z: Dict[str, Tensor]
162
163
164NON_STRICT_SUFFIX = "_non_strict"
165RETRACEABILITY_SUFFIX = "_retraceability"
166SERDES_SUFFIX = "_serdes"
167PREDISPATCH_SUFFIX = "_pre_dispatch"
168TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp"
169TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_non_strict"
170
171
172def is_non_strict_test(test_name):
173    return test_name.endswith(NON_STRICT_SUFFIX)
174
175
176def is_retracebility_test(test_name):
177    return test_name.endswith(RETRACEABILITY_SUFFIX)
178
179
180def is_serdes_test(test_name):
181    return test_name.endswith(SERDES_SUFFIX)
182
183
184def is_training_ir_test(test_name):
185    return test_name.endswith(TRAINING_IR_DECOMP_STRICT_SUFFIX) or test_name.endswith(
186        TRAINING_IR_DECOMP_NON_STRICT_SUFFIX
187    )
188
189
190def get_hop_schema(ep: torch.export.ExportedProgram):
191    hop_node = next(
192        node
193        for node in ep.graph.nodes
194        if isinstance(node.target, torch._ops.HigherOrderOperator)
195    )
196    return torch._library.utils.hop_schema_from_fx_node(hop_node)
197
198
199@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
200class TestDynamismExpression(TestCase):
201    def test_export_inline_constraints(self):
202        class Module(torch.nn.Module):
203            def forward(self, x):
204                b = x.item()
205                torch._check_is_size(b)
206                return torch.full((b, 1), 1)
207
208        f = Module()
209        inp = (torch.tensor([3]),)
210        ref = f(*inp)
211
212        gm = export(f, inp)
213        res = gm.module()(*inp)
214
215        self.assertTrue(torchdynamo.utils.same(ref, res))
216
217        gm = make_fx(f, tracing_mode="symbolic")(*inp)
218        res = gm(*inp)
219        self.assertTrue(torchdynamo.utils.same(ref, res))
220
221    def test_export_constraints_error_not_in_range(self):
222        class InvalidInputConflictWithInputConstraints(torch.nn.Module):
223            def forward(self, x):
224                return x + 1
225
226        inp = torch.zeros([3])
227        dim_x = torch.export.Dim("dim_x", min=6)
228        with self.assertRaisesRegex(torch._dynamo.exc.UserError, "not in range"):
229            torch.export.export(
230                InvalidInputConflictWithInputConstraints(),
231                (inp,),
232                dynamic_shapes={"x": {0: dim_x}},
233            )
234
235    def test_export_slice_maxsize(self):
236        class Slice(torch.nn.Module):
237            def forward(self, *args):
238                return torch.ops.aten.slice.Tensor(*args)
239
240        inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807)
241        dynamic_shapes = (({0: Dim("dim")}, None, None, None),)
242        torch.export.export(
243            Slice(),
244            inp,
245            dynamic_shapes=dynamic_shapes,
246        )
247
248    def test_export_constraints_error(self):
249        class ConflictingConstraints(torch.nn.Module):
250            def forward(self, x):
251                b = x.item()
252                torch._check_is_size(b)
253                torch._check(b >= 4)
254                torch._check(b <= 5)
255                torch._check(b <= 5)
256                torch._check(True)
257                return torch.full((b, 1), 1)
258
259        inp = (torch.tensor([3]),)
260        ep = export(ConflictingConstraints(), inp)
261
262        with self.assertRaisesRegex(
263            RuntimeError, r"Runtime assertion failed for expression u[\d+] \>\= 4"
264        ):
265            ep.module()(torch.tensor([3]))
266
267    def test_export_assume_static_by_default(self):
268        class Module(torch.nn.Module):
269            def forward(self, x: torch.Tensor):
270                if x.shape[0] == 4:
271                    return x + 1
272                else:
273                    return x
274
275        branch_on_shape = Module()
276        inp = (torch.rand(4, 5),)
277
278        # Being able to export means shape is preserved as static
279        export(branch_on_shape, inp)
280
281
282@unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
283@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
284class TestExport(TestCase):
285    def _test_export_same_as_eager(self, f, args, kwargs=None):
286        kwargs = kwargs or {}
287        exported_program = export(f, args, kwargs)
288        self.assertEqual(exported_program.module()(*args, **kwargs), f(*args, **kwargs))
289        # this is not supported by .module()
290        # reversed_kwargs = {key: kwargs[key] for key in reversed(kwargs)}
291        # self.assertEqual(
292        #     exported_program.module()(*args, **reversed_kwargs), f(*args, **reversed_kwargs)
293        # )
294
295    def _check_dynamic_shapes_specs_and_shapes(
296        self, model, inputs, specs, passing_shapes, failing_shapes, test_serdes=False
297    ):
298        from torch._export.serde.dynamic_shapes import (
299            _dump_dynamic_shapes,
300            _load_dynamic_shapes,
301        )
302        from torch.utils._pytree import tree_map
303
304        def _construct_inputs(shapes):
305            def _is_tensor_leaf(x):
306                return isinstance(x, tuple) and all(isinstance(y, int) for y in x)
307
308            return tree_map(
309                lambda x: torch.randn(*x) if _is_tensor_leaf(x) else x,
310                shapes,
311                is_leaf=_is_tensor_leaf,
312            )
313
314        # exports with a list of equivalent dynamic shapes specs,
315        # then tests for pass/fail on list of shapes
316        for _specs in specs:
317            ep = export(model, inputs, dynamic_shapes=_specs)
318            eps = [ep]
319            if test_serdes:
320                # test dynamic shapes serialization
321                # test that behavior remains the same when exporting with ser/des specs:
322                # serialize + deserialize original specs, and export.
323                ep_serdes = export(
324                    model,
325                    inputs,
326                    dynamic_shapes=_load_dynamic_shapes(
327                        _dump_dynamic_shapes(_specs, inputs)
328                    ),
329                )
330                eps.append(ep_serdes)
331
332            for ep in eps:
333                for shapes in passing_shapes:
334                    test_inputs = _construct_inputs(shapes)
335                    ep.module()(*test_inputs)
336                for shapes in failing_shapes:
337                    test_inputs = _construct_inputs(shapes)
338                    with self.assertRaises(RuntimeError):
339                        ep.module()(*test_inputs)
340
341    def test_basic(self):
342        class Module(torch.nn.Module):
343            def forward(self, x, y):
344                return x[0] + y
345
346        f = Module()
347        inp = ([torch.ones(1, 3)], torch.ones(1, 3))
348        self._test_export_same_as_eager(f, inp)
349
350    def test_no_tensor_computation(self):
351        class Module(torch.nn.Module):
352            def forward(self, x, y):
353                return y
354
355        f = Module()
356        inp = ([torch.ones(1, 3)], 1)
357        ep = export(f, inp)
358        self.assertEqual(ep.module()(*inp), f(*inp))
359        self.assertExpectedInline(
360            str(ep.graph).strip(),
361            """\
362graph():
363    %x_0 : [num_users=0] = placeholder[target=x_0]
364    %y : [num_users=0] = placeholder[target=y]
365    return (1,)""",
366        )
367
368    def test_no_tensor_computation_2(self):
369        class Module(torch.nn.Module):
370            def forward(self, x, y):
371                return x
372
373        f = Module()
374        inp = (torch.randn(3), 1)
375        ep = export(f, inp)
376        self.assertEqual(ep.module()(*inp), f(*inp))
377        self.assertExpectedInline(
378            str(ep.graph).strip(),
379            """\
380graph():
381    %x : [num_users=1] = placeholder[target=x]
382    %y : [num_users=0] = placeholder[target=y]
383    return (x,)""",
384        )
385
386    def test_no_tensor_computation_3(self):
387        class Module(torch.nn.Module):
388            def forward(self, x, y):
389                return 5
390
391        f = Module()
392        inp = (2, 1)
393        ep = export(f, inp)
394        self.assertEqual(ep.module()(*inp), f(*inp))
395        self.assertExpectedInline(
396            str(ep.graph).strip(),
397            """\
398graph():
399    %x : [num_users=0] = placeholder[target=x]
400    %y : [num_users=0] = placeholder[target=y]
401    return (5,)""",
402        )
403
404    def test_no_tensor_computation_4(self):
405        class Module(torch.nn.Module):
406            def forward(self, x, y):
407                return x
408
409        f = Module()
410        inp = ([torch.randn(3)], 1)
411        ep = export(f, inp)
412        self.assertEqual(ep.module()(*inp), f(*inp))
413        self.assertExpectedInline(
414            str(ep.graph).strip(),
415            """\
416graph():
417    %x_0 : [num_users=1] = placeholder[target=x_0]
418    %y : [num_users=0] = placeholder[target=y]
419    return (x_0,)""",
420        )
421
422    def test_not_registered_parameter(self):
423        class Basic(torch.nn.Module):
424            def __init__(self):
425                super().__init__()
426                self.params = {"foo": torch.nn.Parameter(torch.ones(3, 3))}
427
428            def forward(self, x):
429                return x + self.params["foo"]
430
431        f = Basic()
432        args = (torch.randn(1, 3),)
433        # strict-mode will error out because foo is registered as parameter
434        # in dynamo (a behavior that's different from eager). We decided to
435        # follow eager behavior.
436        ep = export(f, args, strict=False)
437        gm = ep.module()
438        self.assertEqual(len(ep.graph_signature.lifted_tensor_constants), 1)
439        self.assertEqual(len(ep.graph_signature.parameters), 0)
440        # check foo is not a parameter in the final graph
441        self.assertEqual(len(list(gm.named_parameters())), 0)
442        self.assertEqual(gm(*args), f(*args))
443        self.assertExpectedInline(
444            str(gm.graph).strip(),
445            """\
446graph():
447    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
448    %x : [num_users=1] = placeholder[target=x]
449    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lifted_tensor_0), kwargs = {})
450    return (add,)""",
451        )
452
453    def test_external_call_non_strict_real_tensor(self):
454        class ExternalMethod:
455            def add(self, x):
456                return x + x
457
458        class Basic(torch.nn.Module):
459            def __init__(self) -> None:
460                super().__init__()
461                self.external_add = ExternalMethod().add
462
463            def forward(self, x):
464                return self.external_add(x)
465
466        f = Basic()
467        args = (torch.randn(1, 3),)
468        ep = export(f, args, strict=False)
469        self.assertEqual(ep.module()(*args), f(*args))
470
471    def test_colon_parameter(self):
472        class M(torch.nn.Module):
473            def __init__(self) -> None:
474                super().__init__()
475                self.register_parameter("foo:bar", torch.nn.Parameter(torch.ones(3, 3)))
476
477            def forward(self, x):
478                return x + getattr(self, "foo:bar")
479
480        ep = export(M(), (torch.randn(3, 3),))
481        x = torch.randn(3, 3)
482        self.assertEqual(ep.module()(x), M()(x))
483
484    def test_conv_dynamic(self):
485        # Simple module for demonstration
486        class M(torch.nn.Module):
487            def __init__(self) -> None:
488                super().__init__()
489                self.conv = torch.nn.Conv2d(
490                    in_channels=3, out_channels=32, kernel_size=3, padding=1
491                )
492                self.relu = torch.nn.ReLU()
493                self.maxpool = torch.nn.MaxPool2d(kernel_size=3)
494
495            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
496                a = self.conv(x)
497                a.add_(y)
498                return self.maxpool(self.relu(a))
499
500        example_args = (torch.randn(2, 3, 256, 256), torch.ones(2, 32, 256, 256))
501        dynamic_shapes = {"x": {0: Dim("batch")}, "y": {0: Dim("batch")}}
502        m = M()
503        exported_program: torch.export.ExportedProgram = export(
504            m, args=example_args, dynamic_shapes=dynamic_shapes
505        )
506
507        args = (torch.randn(17, 3, 256, 256), torch.ones(17, 32, 256, 256))
508        self.assertEqual(exported_program.module()(*args), m(*args))
509        args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
510        self.assertEqual(exported_program.module()(*args), m(*args))
511
512        from torch._export import capture_pre_autograd_graph
513
514        gm: torch.fx.GraphModule = capture_pre_autograd_graph(
515            m, args=example_args, dynamic_shapes=dynamic_shapes
516        )
517
518        args = (torch.randn(17, 3, 256, 256), torch.ones(17, 32, 256, 256))
519        self.assertEqual(gm(*args), m(*args))
520        args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
521        self.assertEqual(gm(*args), m(*args))
522
523    def test_masked_select_dynamic(self):
524        class M(torch.nn.Module):
525            def __init__(self) -> None:
526                super().__init__()
527
528            def forward(self, x: torch.Tensor) -> torch.Tensor:
529                mask = x.ge(0.5)
530                return torch.masked_select(x, mask)
531
532        example_args = (torch.randn(3, 4, 5),)
533        dim0_x_max, dim1_x_max = 100, 7
534        dynamic_shapes = {
535            "x": {
536                0: Dim("dim0_x", max=dim0_x_max),
537                1: Dim("dim1_x_max", max=dim1_x_max),
538            }
539        }
540        m = M()
541        exported_program: torch.export.ExportedProgram = export(
542            m, args=example_args, dynamic_shapes=dynamic_shapes
543        )
544
545        # Test that the expected upper bound is among the range constraints.
546        expected_upper_bound = dim0_x_max * dim1_x_max * 5
547        vr_upper_bounds = [
548            vr.upper for vr in exported_program.range_constraints.values()
549        ]
550        self.assertTrue(expected_upper_bound in set(vr_upper_bounds))
551        # Test that none of the upper bounds are larger.
552        for vr_upper in vr_upper_bounds:
553            self.assertTrue(vr_upper <= expected_upper_bound)
554
555    def test_setgrad_lifted_tensor(self):
556        class M(torch.nn.Module):
557            def forward(self, x, y):
558                with torch.enable_grad():
559                    c = torch.tensor(4)
560                    z = c + x + y
561
562                return z * z
563
564        m = M()
565        x = torch.randn(4)
566        y = torch.randn(4)
567        # Need to surround export with no_grad to bypass AutogradStateOpsFailSafeguard.
568        with torch.no_grad():
569            ep = export(m, (x, y))
570        self.assertEqual(ep.module()(x, y), m(x, y))
571
572    def test_basic_non_strict_real_tensor(self):
573        class Basic(torch.nn.Module):
574            def __init__(self) -> None:
575                super().__init__()
576                self.param = torch.nn.Parameter(torch.randn(1, 3))
577
578            def forward(self, x, y):
579                return x[0] + y - self.param
580
581        f = Basic()
582        args = ([torch.randn(1, 3)], torch.randn(1, 3))
583        ep = export(f, args, strict=False)
584        self.assertEqual(ep.module()(*args), f(*args))
585
586    def test_basic_non_strict_fake_tensor(self):
587        class Basic(torch.nn.Module):
588            def __init__(self) -> None:
589                super().__init__()
590                self.param = torch.nn.Parameter(torch.randn(3, 2))
591
592            def forward(self, x, y):
593                return x[0] + y - self.param
594
595        fake_mode = FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))
596        f = Basic()
597        with fake_mode:
598            args = ([torch.empty(3, 2)], torch.empty(3, 2))
599        ep = export(f, args, strict=False)
600        inputs = ([torch.randn(3, 2)], torch.randn(3, 2))
601        self.assertEqual(ep.module()(*inputs), f(*inputs))
602
603    def test_non_strict_dynamic_shapes(self):
604        class Foo(torch.nn.Module):
605            def __init__(self) -> None:
606                super().__init__()
607                self.u = torch.nn.Buffer(torch.ones(1))
608                self.v = torch.nn.Buffer(torch.ones(1))
609
610            def forward(self, x, ys, zs, c):
611                y = ys[0] + ys[1] + zs["a"] + zs["b"]
612                self.v.add_(3)
613                w = self.u - self.v
614                if x.shape[0] < 3 and c.shape[0] != 4:
615                    return x + w, x + y
616                else:
617                    return x - w, x - y
618
619        foo = Foo()
620
621        inp = (
622            torch.ones(5),
623            [torch.zeros(5), torch.ones(5)],
624            {"a": torch.zeros(5), "b": torch.ones(5)},
625            torch.ones(4),
626        )
627        dim = torch.export.Dim("dim", min=3)
628        dynamic_shapes = (
629            {0: dim},
630            [{0: dim}, {0: dim}],
631            {"a": {0: dim}, "b": {0: dim}},
632            None,
633        )
634
635        ep_ns = torch.export.export(
636            foo, inp, dynamic_shapes=dynamic_shapes, strict=False
637        )
638
639        bad_runtime_inp1 = (
640            torch.ones(6),
641            [torch.zeros(5), torch.ones(5)],
642            {"a": torch.zeros(5), "b": torch.ones(5)},
643            torch.ones(4),
644        )
645        with self.assertRaisesRegex(
646            RuntimeError,
647            escape(
648                "Expected input at *args[1][0].shape[0] to be equal to 6, but got 5"
649            ),
650        ):
651            ep_ns.module()(*bad_runtime_inp1)
652
653        bad_runtime_inp2 = (
654            torch.ones(5),
655            [torch.zeros(5), torch.ones(5)],
656            {"a": torch.zeros(5), "b": torch.ones(5)},
657            torch.ones(6),
658        )
659        with self.assertRaisesRegex(
660            RuntimeError,
661            escape("Expected input at *args[3].shape[0] to be equal to 4, but got 6"),
662        ):
663            ep_ns.module()(*bad_runtime_inp2)
664
665        good_runtime_inp = (
666            torch.ones(7),
667            [torch.zeros(7), torch.ones(7)],
668            {"a": torch.zeros(7), "b": torch.ones(7)},
669            torch.ones(4),
670        )
671        ep_ns.module()(*good_runtime_inp)
672
673        bad_example_inp = (
674            torch.ones(2),
675            [torch.zeros(2), torch.ones(2)],
676            {"a": torch.zeros(2), "b": torch.ones(2)},
677            torch.ones(4),
678        )
679        with self.assertRaisesRegex(
680            torch.fx.experimental.symbolic_shapes.ConstraintViolationError,
681            "2 not in range.*3,",
682        ):
683            ep_ns = torch.export.export(
684                foo, bad_example_inp, dynamic_shapes=dynamic_shapes, strict=False
685            )
686
687    def test_non_strict_dynamic_shapes_suggested_fixes(self):
688        class Foo(torch.nn.Module):
689            def forward(self, x, c):
690                if x.shape[0] <= 6:
691                    return x + 1, c + 2
692                else:
693                    return x - 1, c - 2
694
695        foo = Foo()
696
697        bad_example_inp = (
698            torch.ones(5),
699            torch.ones(4),
700        )
701        dim = torch.export.Dim("dim", min=3)
702        dynamic_shapes = (
703            {0: dim},
704            None,
705        )
706
707        with self.assertRaisesRegex(
708            torch._dynamo.exc.UserError,
709            "Constraints violated \\(dim\\)!(.*\n)*.*"
710            "Not all values of dim.*satisfy the generated guard(.*\n)*.*"
711            "Suggested fixes:(.*\n)*.*"
712            "dim = Dim\\('dim', min=3, max=6\\)",
713        ):
714            torch.export.export(
715                foo, bad_example_inp, dynamic_shapes=dynamic_shapes, strict=False
716            )
717
718    def test_unbacked_to_cond(self):
719        class M(torch.nn.Module):
720            def forward(self, a):
721                az = a.nonzero()
722
723                def true_fn(x):
724                    return (x + 1).sum()
725
726                def false_fn(x):
727                    return (x + 3).sum()
728
729                r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,))
730                return r * 2
731
732        M()(torch.randn(7))
733        torch.export.export(M(), (torch.randn(7),))
734
735    def test_unbacked_to_cond_passthrough(self):
736        class M(torch.nn.Module):
737            def forward(self, a):
738                az = a.nonzero()
739
740                def true_fn(x):
741                    return x + 1
742
743                def false_fn(x):
744                    return x + 3
745
746                r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,))
747                return r * 2
748
749        M()(torch.randn(7))
750        torch.export.export(M(), (torch.randn(7),))
751
752    @torch._dynamo.config.patch(capture_scalar_outputs=True)
753    def test_cond_contains_unbacked_no_escape(self):
754        class M(torch.nn.Module):
755            def forward(self, a, b1, b2, c):
756                def true_fn(x):
757                    return x * b1.item()
758
759                def false_fn(x):
760                    return x * b2.item()
761
762                r = torch.cond(a, true_fn, false_fn, (c,))
763                return r * 2
764
765        args = (
766            torch.tensor(True),
767            torch.tensor([4]),
768            torch.tensor([4]),
769            torch.randn(10, requires_grad=True),
770        )
771        torch.export.export(M(), args)
772
773    def test_state_tensors(self):
774        class M(torch.nn.Module):  # simple with register buffer
775            def __init__(self) -> None:
776                super().__init__()
777                self.buf = torch.nn.Buffer(torch.ones(2, 3), persistent=False)
778
779            def forward(self, x):
780                # x = 2
781                y = self.buf
782                # y = 1
783                w1 = self.buf + 3
784                w2 = self.buf + 4
785                w3 = self.buf + 5
786                self.buf = w1
787                z = self.buf
788                self.buf = w3
789                # z = 4
790                return x + y + z + w2
791
792        ep = torch.export.export(M(), (torch.randn(2, 3),), strict=False)
793        self.assertEqual(ep.graph_signature.buffers_to_mutate, {"add_2": "buf"})
794        self.assertTrue(
795            torch.allclose(ep.module()(torch.ones(2, 3) + 1), torch.ones(2, 3) * 12)
796        )
797
798        class M(torch.nn.Module):  # simple without register buffer
799            def __init__(self) -> None:
800                super().__init__()
801                self.buf = torch.ones(2, 3)
802
803            def forward(self, x):
804                # x = 2
805                y = self.buf
806                # y = 1
807                self.buf = self.buf + 3
808                z = self.buf
809                # z = 3
810                return x + y + z
811
812        with self.assertRaisesRegex(
813            ValueError,
814            "The tensor attribute self.buf was assigned during export",
815        ):
816            torch.export.export(M(), (torch.randn(2, 3),), strict=False)
817
818        class M(torch.nn.Module):  # complex with register buffer
819            def __init__(self) -> None:
820                super().__init__()
821                tensors = [torch.ones(2, 3), torch.ones(2, 3)]
822                for i, tensor in enumerate(tensors):
823                    self.register_buffer(f"buf_{i}", tensor, persistent=False)
824
825            def get_tensor(self, i):
826                return getattr(self, f"buf_{i}")
827
828            def set_tensor(self, i, val):
829                setattr(self, f"buf_{i}", val)
830
831            def forward(self, x):
832                # x = 2
833                y = self.get_tensor(0) + self.get_tensor(1)
834                # y = 1 + 1
835                self.set_tensor(0, torch.ones(2, 3) + 2)
836                self.set_tensor(1, torch.ones(2, 3) + 2)
837                z = self.get_tensor(0) + self.get_tensor(1)
838                # z = 3 + 3
839                return x + y + z
840
841        ep = torch.export.export(M(), (torch.randn(2, 3),), strict=False)
842        self.assertEqual(
843            ep.graph_signature.buffers_to_mutate, {"add_1": "buf_0", "add_2": "buf_1"}
844        )
845        self.assertTrue(
846            torch.allclose(ep.module()(torch.ones(2, 3) + 1), torch.ones(2, 3) * 10)
847        )
848
849        class M(torch.nn.Module):  # complex without register buffer
850            def __init__(self) -> None:
851                super().__init__()
852                self.tensors = [torch.ones(2, 3), torch.ones(2, 3)]
853
854            def get_tensor(self, i):
855                return self.tensors[i]
856
857            def set_tensor(self, i, val):
858                self.tensors[i] = val
859
860            def forward(self, x):
861                # x = 2
862                y = self.get_tensor(0) + self.get_tensor(1)
863                # y = 1 + 1
864                self.set_tensor(0, torch.ones(2, 3) + 2)
865                self.set_tensor(1, torch.ones(2, 3) + 2)
866                z = self.get_tensor(0) + self.get_tensor(1)
867                # z = 3 + 3
868                return x + y + z
869
870        with self.assertRaisesRegex(
871            ValueError,
872            "The tensor attributes self.tensors\\[0\\], self.tensors\\[1\\] were assigned during export",
873        ):
874            torch.export.export(M(), (torch.randn(2, 3),), strict=False)
875
876    def test_state_primitives(self):
877        class M(torch.nn.Module):
878            def __init__(self) -> None:
879                super().__init__()
880                self.x = 1
881                self.y = {"k": 2}
882                self.z = (3,)
883
884            def forward(self, x):
885                self.x = self.x + 4
886                self.y["k"] = self.y["k"] + 5
887                self.z = (self.z[0] + 6,)
888                return x + self.x + self.y["k"] + self.z[0]
889
890        ep = export(M(), (torch.randn(2, 3),))
891        self.assertTrue(
892            torch.allclose(ep.module()(torch.zeros(2, 3)), torch.ones(2, 3) * 21)
893        )
894
895    def test_export_script_module(self):
896        class Foo(torch.nn.Module):
897            def forward(self, rv: torch.Tensor, t: torch.Tensor):
898                i = t.item()
899                return rv + i
900
901        foo = Foo()
902        foo_script = torch.jit.script(foo)
903        inp = (torch.zeros(3, 4), torch.tensor(7))
904
905        with self.assertRaisesRegex(
906            ValueError, "Exporting a ScriptModule is not supported"
907        ):
908            export(foo_script, inp)
909
910        from torch._export.converter import TS2EPConverter
911
912        TS2EPConverter(foo_script, inp).convert()
913
914    def test_torch_fn(self):
915        class M1(torch.nn.Module):
916            def __init__(self) -> None:
917                super().__init__()
918                self.linear = torch.nn.Linear(3, 3)
919                self.relu = torch.nn.ReLU()
920
921            def forward(self, x):
922                x = self.linear(x)
923                x = self.linear(x)
924                x = self.relu(x)
925                x = x + x
926                return x
927
928        ep1 = export(M1(), (torch.randn(3, 3),)).run_decompositions()
929        expected_result = [
930            ("linear_1", "builtin_function_or_method.linear"),
931            ("linear_1", "builtin_function_or_method.linear"),
932            ("linear_2", "builtin_function_or_method.linear"),
933            ("linear_2", "builtin_function_or_method.linear"),
934            ("relu_1", "function.relu"),
935            ("add_1", "method_descriptor.add"),
936        ]
937        actual_result = []
938        for i, node in enumerate(ep1.graph.nodes):
939            if node.op == "call_function":
940                actual_result.append(node.meta.get("torch_fn"))
941        self.assertEqual(actual_result, expected_result)
942
943        class M2(torch.nn.Module):
944            def __init__(self) -> None:
945                super().__init__()
946
947            def forward(self, x, weight, bias):
948                x = torch.nn.functional.linear(x, weight, bias)
949                x = torch.nn.functional.relu(x)
950                x = torch.add(x, x)
951                return x
952
953        ep2 = export(
954            M2(), (torch.randn(3, 3), torch.randn(3, 3), torch.randn(3))
955        ).run_decompositions()
956        expected_result = [
957            ("linear_1", "builtin_function_or_method.linear"),
958            ("linear_1", "builtin_function_or_method.linear"),
959            ("relu_1", "function.relu"),
960            ("add_1", "builtin_function_or_method.add"),
961        ]
962        actual_result = []
963        for i, node in enumerate(ep2.graph.nodes):
964            if node.op == "call_function":
965                actual_result.append(node.meta.get("torch_fn"))
966        self.assertEqual(actual_result, expected_result)
967
968    @testing.expectedFailureSerDer  # failed serializing SymInt nodes in subgraph (known issue)
969    def test_hoo_inline_users_issue(self):
970        # This came from an issue where replace_with_hop passes would inline subgraphs,
971        # and mess up node.users for nodes present in multiple subgraphs (e.g. _x in SetGradCase
972        # below, since it's used in both set_grad_enabled HOO modules).
973        # This checks that node.users and node.args are in correspondence.
974        def check_users_for_graph(graph):
975            def _tuple_contains(_tuple, val):
976                # check nested, since output node args have format ((x, y, ...),)
977                return any(
978                    _tuple_contains(x, val) if isinstance(x, tuple) else x == val
979                    for x in _tuple
980                )
981
982            for node in graph.nodes:
983                # check node.users
984                for user in node.users.keys():
985                    assert _tuple_contains(user.args, node)
986                # check node.args
987                for arg in node.args:
988                    if isinstance(arg, torch.fx.Node):
989                        assert _tuple_contains(arg.users, node)
990
991        # check set grad enabled
992        class SetGradCase(torch.nn.Module):
993            def forward(self, x):
994                _x = x.shape[0] + 2
995                _xx = _x + 2
996                with torch.no_grad():
997                    y = _x * 4
998                return _xx, y
999
1000        ep = export(
1001            SetGradCase(),
1002            (torch.randn(6),),
1003            dynamic_shapes={"x": (Dim("dx"),)},
1004            strict=False,
1005        )
1006        check_users_for_graph(ep.graph)
1007
1008    def test_export_predispatch_custom_ops_warnings(self):
1009        @torch.library.custom_op("mylib::foo", mutates_args={})
1010        def foo(x: torch.Tensor) -> torch.Tensor:
1011            return x.sin()
1012
1013        @foo.register_fake
1014        def _(x):
1015            return torch.empty_like(x)
1016
1017        class Foo(torch.nn.Module):
1018            def forward(self, x):
1019                return foo(x)
1020
1021        x = torch.randn(3)
1022
1023        # Assert no warnings
1024        with warnings.catch_warnings():
1025            warnings.simplefilter("error")
1026            torch.export.export(Foo(), (x,))
1027
1028        # Assert warning for CompositeImplictAutograd op
1029        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
1030            lib.define("foo123(Tensor x) -> Tensor")
1031            lib.impl("foo123", lambda x: x.sin(), "CompositeImplicitAutograd")
1032
1033            class Bar(torch.nn.Module):
1034                def forward(self, x):
1035                    return torch.ops.mylib.foo123(x)
1036
1037            with self.assertWarnsRegex(
1038                UserWarning, "CompositeImplicitAutograd and have functional schema"
1039            ):
1040                with warnings.catch_warnings():
1041                    warnings.simplefilter("always")
1042                    torch.export.export(Bar(), (x,))
1043
1044    def test_export_preserve_linear_at_aot_level(self):
1045        class Foo(torch.nn.Module):
1046            def __init__(self) -> None:
1047                super().__init__()
1048                self.linear = torch.nn.Linear(3, 3)
1049
1050            def forward(self, x):
1051                x = self.linear(x)
1052                return torch.ops.aten.chunk.default(x, 3, 0)
1053
1054        gm = (
1055            torch.export.export(
1056                Foo(),
1057                (torch.randn(3, 3),),
1058            )
1059            .run_decompositions({}, _preserve_ops=(torch.ops.aten.linear.default,))
1060            .graph_module
1061        )
1062        # linear is CompositeImplicitAutograd functional op so we should preserve it
1063        # chunk is CompositeImplicitAutograd non-functional op we decompose.
1064        self.assertExpectedInline(
1065            str(gm.code).strip(),
1066            """\
1067def forward(self, p_linear_weight, p_linear_bias, x):
1068    linear = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias);  x = p_linear_weight = p_linear_bias = None
1069    split = torch.ops.aten.split.Tensor(linear, 1);  linear = None
1070    getitem = split[0]
1071    getitem_1 = split[1]
1072    getitem_2 = split[2];  split = None
1073    return (getitem, getitem_1, getitem_2)""",
1074        )
1075
1076    def test_export_cond_preserve_torch_fn_for_subgraphs(self):
1077        class MySubModule(torch.nn.Module):
1078            def foo(self, x):
1079                return x.cos()
1080
1081            def forward(self, x):
1082                return self.foo(x)
1083
1084        class CondBranchClassMethod(torch.nn.Module):
1085            def __init__(self) -> None:
1086                super().__init__()
1087                self.subm = MySubModule()
1088
1089            def bar(self, x):
1090                return x.sin()
1091
1092            def forward(self, x):
1093                return cond(x.sum() <= 2, self.subm.forward, self.bar, [x])
1094
1095        example_inputs = (torch.randn(1, 3, 3, 3),)
1096        m = CondBranchClassMethod()
1097        m.eval()
1098        gm = export(m, example_inputs).module()
1099
1100        actual_torch_fns = []
1101        for mod in gm.modules():
1102            for node in mod.graph.nodes:
1103                if node.name in {"sin", "cos"}:
1104                    torch_fn = node.meta.get("torch_fn")
1105                    print(torch_fn)
1106                    actual_torch_fns.append(torch_fn)
1107        exp_torch_fns = [
1108            ("cos_1", "method_descriptor.cos"),
1109            ("sin_1", "method_descriptor.sin"),
1110        ]
1111        self.assertEqual(actual_torch_fns, exp_torch_fns)
1112
1113    def test_derived_dim_basic(self):
1114        class Foo(torch.nn.Module):
1115            def forward(self, x, y):
1116                return x + y[1:]
1117
1118        foo = Foo()
1119
1120        x, y = torch.randn(5), torch.randn(6)
1121        dimx = torch.export.Dim("dimx", min=3, max=6)
1122
1123        dimy = torch.export.Dim("dimy", min=4, max=7)  # doesn't work
1124        with self.assertRaisesRegex(
1125            torch._dynamo.exc.UserError,
1126            (
1127                "Constraints violated \\(dimy\\)!(.*\n)*.*"
1128                "The values of dimy.*must always be related to the values of dimx.*by.*(.*\n)*.*"
1129                "Suggested fixes:(.*\n)*.*"
1130                "dimy = dimx \\+ 1"
1131            ),
1132        ):
1133            export(
1134                foo,
1135                (x, y),
1136                dynamic_shapes=({0: dimx}, {0: dimy}),
1137            )
1138
1139        dimy = dimx * 2  # doesn't work
1140        with self.assertRaisesRegex(
1141            torch._dynamo.exc.UserError,
1142            "Expected input.*size.* to be equal to 2\\*dimx, where dimx = 5, but got 6",
1143        ):
1144            export(
1145                foo,
1146                (x, y),
1147                dynamic_shapes=({0: dimx}, {0: dimy}),
1148            )
1149
1150        dimy = dimx + 1  # works
1151        ep = export(
1152            foo,
1153            (x, y),
1154            dynamic_shapes=({0: dimx}, {0: dimy}),
1155        )
1156        with self.assertRaisesRegex(
1157            RuntimeError,
1158            "Expected input.*shape.*to be equal to 5, but got 6",
1159        ):
1160            ep.module()(torch.randn(4), torch.randn(6))
1161
1162        self.assertEqual(ep.module()(torch.randn(4), torch.randn(5)).size()[0], 4)
1163
1164    def test_derived_dim_nested(self):
1165        class Foo(torch.nn.Module):
1166            def forward(self, x, y):
1167                return x + y[1::2]
1168
1169        foo = Foo()
1170
1171        x, y = torch.randn(5), torch.randn(11)
1172        dimx = torch.export.Dim("dimx", min=3, max=6)
1173        dimy = dimx * 2 + 1  # works
1174        ep = export(
1175            foo,
1176            (x, y),
1177            dynamic_shapes=({0: dimx}, {0: dimy}),
1178        )
1179        self.assertEqual(ep.module()(torch.randn(4), torch.randn(9)).size()[0], 4)
1180
1181        class Foo(torch.nn.Module):
1182            def forward(self, z, y):
1183                return z[1:] + y[1::2]
1184
1185        foo = Foo()
1186
1187        z, y = torch.randn(6), torch.randn(11)
1188
1189        dimz = dimx
1190        dimy = dimx * 2 - 1  # works
1191        ep = export(
1192            foo,
1193            (z, y),
1194            dynamic_shapes=({0: dimz}, {0: dimy}),
1195        )
1196        self.assertEqual(ep.module()(torch.randn(5), torch.randn(9)).size()[0], 4)
1197
1198        dimz = dimx + 1
1199        dimy = dimx * 2 - 1  # doesn't work
1200
1201        with self.assertRaisesRegex(
1202            torch._dynamo.exc.UserError,
1203            "Expected input.*size.*to be equal to 2\\*dimx - 1, where dimx = 5, but got 11",
1204        ):
1205            export(
1206                foo,
1207                (z, y),
1208                dynamic_shapes=({0: dimz}, {0: dimy}),
1209            )
1210
1211        dimy = dimx * 2 + 1  # works
1212        ep = export(
1213            foo,
1214            (z, y),
1215            dynamic_shapes=({0: dimz}, {0: dimy}),
1216        )
1217        with self.assertRaisesRegex(
1218            RuntimeError, "Expected input.*shape.*to be <= 7, but got 8"
1219        ):
1220            ep.module()(torch.randn(8), torch.randn(15))
1221        with self.assertRaisesRegex(
1222            RuntimeError,
1223            "Expected input.*shape.*to be equal to 9, but got 8",
1224        ):
1225            ep.module()(torch.randn(5), torch.randn(8))
1226
1227        self.assertEqual(ep.module()(torch.randn(5), torch.randn(9)).size()[0], 4)
1228
1229    def test_derived_dim_integer(self):
1230        class Foo(torch.nn.Module):
1231            def forward(self, w):
1232                if w.shape[0] % 2 == 0:
1233                    return w[::2]
1234                else:
1235                    return w[1:-1:2]
1236
1237        foo = Foo()
1238
1239        w = torch.randn(10)
1240        dimx = torch.export.Dim("dimx", min=3, max=6)
1241        dimw = dimx * 2 + 1  # doesn't work
1242        with self.assertRaisesRegex(
1243            torch._dynamo.exc.UserError,
1244            "Expected shape.*= 10 of input Tensor to be "
1245            "of the form 2\\*dimx \\+ 1, where dimx is an integer",
1246        ):
1247            export(
1248                foo,
1249                (w,),
1250                dynamic_shapes=({0: dimw},),
1251            )
1252
1253        dimw = dimx * 2  # works
1254        ep = export(
1255            foo,
1256            (w,),
1257            dynamic_shapes=({0: dimw},),
1258        )
1259        with self.assertRaisesRegex(
1260            RuntimeError,
1261            "Expected input.*shape.*= 9 to be "
1262            "of the form 2\\*s1, where s1 is an integer",
1263        ):
1264            ep.module()(torch.randn(9))
1265
1266        self.assertEqual(ep.module()(torch.randn(8)).size()[0], 4)
1267        with self.assertRaisesRegex(
1268            RuntimeError,
1269            "Expected input.*shape.*to be <= 12, but got 14",
1270        ):
1271            ep.module()(torch.randn(14))
1272
1273    def test_derived_dim_repeat_derived(self):
1274        class Foo(torch.nn.Module):
1275            def forward(self, u, v):
1276                return u[::2] + v[::2]
1277
1278        foo = Foo()
1279
1280        u, v = torch.randn(10), torch.randn(10)
1281        dimx = torch.export.Dim("dimx", min=3, max=6)
1282        dimw = dimx * 2  # works
1283        ep = export(
1284            foo,
1285            (u, v),
1286            dynamic_shapes=({0: dimw}, {0: dimw}),
1287        )
1288        self.assertEqual(ep.module()(torch.randn(8), torch.randn(8)).size()[0], 4)
1289
1290    def test_derived_dim_out_of_order(self):
1291        dimy = torch.export.Dim("dimy", min=5, max=7)
1292        dimx = dimy - 1  # out of order, effectively dimy = dimx + 1
1293        dimz = dimy + 1  # out of order, effectively dimz = dimx + 2
1294
1295        class Foo(torch.nn.Module):
1296            def forward(self, x, y, z):
1297                return x + y[1:] + z[2:]
1298
1299        foo = Foo()
1300
1301        u, v, w = torch.randn(5), torch.randn(6), torch.randn(7)
1302        ep = export(
1303            foo,
1304            (u, v, w),
1305            dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
1306        )
1307        with self.assertRaisesRegex(
1308            RuntimeError,
1309            "Expected input.*shape.*to be equal to 8, but got 5",
1310        ):
1311            ep.module()(torch.randn(6), torch.randn(7), torch.randn(5))
1312
1313        self.assertEqual(
1314            ep.module()(torch.randn(6), torch.randn(7), torch.randn(8)).size()[0], 6
1315        )
1316
1317    def test_derived_dim_out_of_order_repeat_derived(self):
1318        dimy = torch.export.Dim("dimy", min=5, max=7)
1319        dimx = dimy - 1  # out of order, effectively dimy = dimx + 1
1320        dimz = dimy + 1  # out of order, effectively dimz = dimx + 2
1321        dimx1 = dimx
1322        dimx2 = dimz - 2  # works, effectively = dimx
1323
1324        class Foo(torch.nn.Module):
1325            def forward(self, x, y, z, x1, x2):
1326                return x + y[1:] + z[2:] + x1 + x2
1327
1328        foo = Foo()
1329
1330        u, v, w, u1, u2 = (
1331            torch.randn(5),
1332            torch.randn(6),
1333            torch.randn(7),
1334            torch.randn(5),
1335            torch.randn(5),
1336        )
1337        ep = export(
1338            foo,
1339            (u, v, w, u1, u2),
1340            dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}, {0: dimx1}, {0: dimx2}),
1341        )
1342        with self.assertRaisesRegex(
1343            RuntimeError,
1344            "Expected input.*shape.*to be equal to 6, but got 5",
1345        ):
1346            ep.module()(
1347                torch.randn(6),
1348                torch.randn(7),
1349                torch.randn(8),
1350                torch.randn(6),
1351                torch.randn(5),
1352            )
1353
1354        self.assertEqual(
1355            ep.module()(
1356                torch.randn(6),
1357                torch.randn(7),
1358                torch.randn(8),
1359                torch.randn(6),
1360                torch.randn(6),
1361            ).size()[0],
1362            6,
1363        )
1364
1365        ep = export(
1366            foo,
1367            (u, v, w, u, u),  # reused inputs
1368            dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}, {0: dimx1}, {0: dimx2}),
1369        )
1370        with self.assertRaisesRegex(
1371            RuntimeError,
1372            "Expected input.*shape.*to be equal to 6, but got 5",
1373        ):
1374            ep.module()(
1375                torch.randn(6),
1376                torch.randn(7),
1377                torch.randn(8),
1378                torch.randn(6),
1379                torch.randn(5),
1380            )
1381
1382        self.assertEqual(
1383            ep.module()(
1384                torch.randn(6),
1385                torch.randn(7),
1386                torch.randn(8),
1387                torch.randn(6),
1388                torch.randn(6),
1389            ).size()[0],
1390            6,
1391        )
1392
1393    def test_specialize_derived_dim_roots(self):
1394        # dim & derived dim both specialize
1395        class Foo(torch.nn.Module):
1396            def forward(self, x, y):
1397                return x.reshape([-1]) + y
1398
1399        dy = Dim("dy", min=6)
1400        x, y = torch.randn(6, 2), torch.randn(12)
1401        dynamic_shapes = {
1402            "x": (dy - 6, 2),
1403            "y": (dy,),
1404        }
1405        try:
1406            export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
1407            raise Exception(
1408                "export() call should have failed with dynamic shapes error."
1409            )
1410        except torch._dynamo.exc.UserError as exc:
1411            expected_error_msg = (
1412                "Specializations unexpectedly required \(dy\)!(.*\n)*.*"
1413                ".*solving the guards generated for dy - 6.*resulted in a specialized value of 6(.*\n)*.*"
1414                "Suggested fixes(.*\n)*.*"
1415                ".*dy = 12(.*\n)*.*"
1416            )
1417            self.assertTrue(re.search(expected_error_msg, exc.args[0]) is not None)
1418            self.assertTrue(
1419                "dy - 6 = 6" not in exc.args[0]
1420            )  # don't suggest fix for non-root dim
1421
1422    def test_keep_composite_ops_invalid(self):
1423        class Foo(torch.nn.Module):
1424            def __init__(self) -> None:
1425                super().__init__()
1426                self.linear = torch.nn.Linear(3, 3)
1427
1428            def forward(self, x):
1429                x = self.linear(x)
1430                return torch.ops.aten.chunk.default(x, 3, 0)
1431
1432        with self.assertRaisesRegex(
1433            RuntimeError, "aten.chunk.default is a mutating/aliasing op"
1434        ):
1435            _ = torch.export.export(
1436                Foo(),
1437                (torch.randn(3, 3),),
1438            ).run_decompositions({}, _preserve_ops=(torch.ops.aten.chunk.default,))
1439
1440        with self.assertRaisesRegex(
1441            RuntimeError, "aten.sym_size.default is a metadata query function"
1442        ):
1443            _ = torch.export.export(
1444                Foo(),
1445                (torch.randn(3, 3),),
1446            ).run_decompositions({}, _preserve_ops=(torch.ops.aten.sym_size.default,))
1447
1448        with self.assertRaisesRegex(
1449            RuntimeError,
1450            "We can't detect aten.native_batch_norm.default as a functional op statically",
1451        ):
1452            _ = torch.export.export(
1453                Foo(),
1454                (torch.randn(3, 3),),
1455            ).run_decompositions(
1456                {}, _preserve_ops=(torch.ops.aten.native_batch_norm.default,)
1457            )
1458
1459    def test_keep_composite_ops_linear_convd(self):
1460        class MyLinear(torch.nn.Module):
1461            def __init__(self) -> None:
1462                super().__init__()
1463                self.weight = torch.randn(20, 98)
1464                self.bias = torch.randn(20)
1465
1466            def forward(self, x):
1467                return torch.nn.functional.linear(x, self.weight, self.bias)
1468
1469        class Foo(torch.nn.Module):
1470            def __init__(self) -> None:
1471                super().__init__()
1472                self.conv = torch.nn.Conv2d(16, 33, 3)
1473                self.conv1d = torch.nn.Conv1d(16, 33, 3)
1474                self.linear = MyLinear()
1475
1476            def forward(self, x, y):
1477                x_conv = self.conv(x)
1478                y_conv_1d = self.conv1d(y)
1479                x_linear = self.linear(x_conv)
1480                return x_linear.cos() + y_conv_1d.sum()
1481
1482        ep = torch.export.export(
1483            Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))
1484        )
1485        ep_has_linear_convd = ep.run_decompositions(
1486            decomp_table={},
1487            _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY,
1488        )
1489        self.assertExpectedInline(
1490            str(ep_has_linear_convd.graph_module.code).strip(),
1491            """\
1492def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y):
1493    conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias);  x = p_conv_weight = p_conv_bias = None
1494    conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias);  y = p_conv1d_weight = p_conv1d_bias = None
1495    linear = torch.ops.aten.linear.default(conv2d, c_linear_weight, c_linear_bias);  conv2d = c_linear_weight = c_linear_bias = None
1496    cos = torch.ops.aten.cos.default(linear);  linear = None
1497    sum_1 = torch.ops.aten.sum.default(conv1d);  conv1d = None
1498    add = torch.ops.aten.add.Tensor(cos, sum_1);  cos = sum_1 = None
1499    return (add,)""",
1500        )
1501
1502        ep_has_convd = ep.run_decompositions(
1503            decomp_table=None,
1504            _preserve_ops=[
1505                torch.ops.aten.conv2d.default,
1506                torch.ops.aten.conv1d.default,
1507            ],
1508        )
1509        self.assertExpectedInline(
1510            str(ep_has_convd.graph_module.code).strip(),
1511            """\
1512def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y):
1513    conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias);  x = p_conv_weight = p_conv_bias = None
1514    conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias);  y = p_conv1d_weight = p_conv1d_bias = None
1515    view = torch.ops.aten.view.default(conv2d, [31680, 98]);  conv2d = None
1516    permute = torch.ops.aten.permute.default(c_linear_weight, [1, 0]);  c_linear_weight = None
1517    addmm = torch.ops.aten.addmm.default(c_linear_bias, view, permute);  c_linear_bias = view = permute = None
1518    view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]);  addmm = None
1519    cos = torch.ops.aten.cos.default(view_1);  view_1 = None
1520    sum_1 = torch.ops.aten.sum.dim_IntList(conv1d, []);  conv1d = None
1521    add = torch.ops.aten.add.Tensor(cos, sum_1);  cos = sum_1 = None
1522    return (add,)""",
1523        )
1524
1525        ep_has_convd = ep_has_convd.run_decompositions(
1526            decomp_table=None, _preserve_ops=[torch.ops.aten.conv2d.default]
1527        )
1528        self.assertExpectedInline(
1529            str(ep_has_convd.graph_module.code).strip(),
1530            """\
1531def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y):
1532    conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias);  x = p_conv_weight = p_conv_bias = None
1533    convolution = torch.ops.aten.convolution.default(y, p_conv1d_weight, p_conv1d_bias, [1], [0], [1], False, [0], 1);  y = p_conv1d_weight = p_conv1d_bias = None
1534    view = torch.ops.aten.view.default(conv2d, [31680, 98]);  conv2d = None
1535    permute = torch.ops.aten.permute.default(c_linear_weight, [1, 0]);  c_linear_weight = None
1536    addmm = torch.ops.aten.addmm.default(c_linear_bias, view, permute);  c_linear_bias = view = permute = None
1537    view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]);  addmm = None
1538    cos = torch.ops.aten.cos.default(view_1);  view_1 = None
1539    sum_1 = torch.ops.aten.sum.dim_IntList(convolution, []);  convolution = None
1540    add = torch.ops.aten.add.Tensor(cos, sum_1);  cos = sum_1 = None
1541    return (add,)""",
1542        )
1543
1544    def test_keep_composite_ops_linear_convd_for_training_ir(self):
1545        class MyLinear(torch.nn.Module):
1546            def __init__(self) -> None:
1547                super().__init__()
1548                self.weight = torch.nn.Buffer(torch.randn(20, 98))
1549                self.bias = torch.nn.Buffer(torch.randn(20))
1550
1551            def forward(self, x):
1552                return torch.nn.functional.linear(x, self.weight, self.bias)
1553
1554        class Foo(torch.nn.Module):
1555            def __init__(self) -> None:
1556                super().__init__()
1557                self.conv = torch.nn.Conv2d(16, 33, 3)
1558                self.conv1d = torch.nn.Conv1d(16, 33, 3)
1559                self.linear = MyLinear()
1560
1561            def forward(self, x, y):
1562                x_conv = self.conv(x)
1563                y_conv_1d = self.conv1d(y)
1564                x_linear = self.linear(x_conv)
1565                return x_linear.cos() + y_conv_1d.sum()
1566
1567        ep = torch.export.export_for_training(
1568            Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))
1569        )
1570        ep_has_linear_convd = ep.run_decompositions(
1571            decomp_table={},
1572            _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY,
1573        )
1574
1575        self.assertExpectedInline(
1576            str(ep_has_linear_convd.graph_module.code).strip(),
1577            """\
1578def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_linear_weight, b_linear_bias, x, y):
1579    conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias);  x = p_conv_weight = p_conv_bias = None
1580    conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias);  y = p_conv1d_weight = p_conv1d_bias = None
1581    linear = torch.ops.aten.linear.default(conv2d, b_linear_weight, b_linear_bias);  conv2d = b_linear_weight = b_linear_bias = None
1582    cos = torch.ops.aten.cos.default(linear);  linear = None
1583    sum_1 = torch.ops.aten.sum.default(conv1d);  conv1d = None
1584    add = torch.ops.aten.add.Tensor(cos, sum_1);  cos = sum_1 = None
1585    return (add,)""",
1586        )
1587
1588        ep_has_convd = ep.run_decompositions(
1589            decomp_table=None,
1590            _preserve_ops=[
1591                torch.ops.aten.conv2d.default,
1592                torch.ops.aten.conv1d.default,
1593            ],
1594        )
1595
1596        self.assertExpectedInline(
1597            str(ep_has_convd.graph_module.code).strip(),
1598            """\
1599def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_linear_weight, b_linear_bias, x, y):
1600    conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias);  x = p_conv_weight = p_conv_bias = None
1601    conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias);  y = p_conv1d_weight = p_conv1d_bias = None
1602    view = torch.ops.aten.view.default(conv2d, [31680, 98]);  conv2d = None
1603    permute = torch.ops.aten.permute.default(b_linear_weight, [1, 0]);  b_linear_weight = None
1604    addmm = torch.ops.aten.addmm.default(b_linear_bias, view, permute);  b_linear_bias = view = permute = None
1605    view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]);  addmm = None
1606    cos = torch.ops.aten.cos.default(view_1);  view_1 = None
1607    sum_1 = torch.ops.aten.sum.dim_IntList(conv1d, []);  conv1d = None
1608    add = torch.ops.aten.add.Tensor(cos, sum_1);  cos = sum_1 = None
1609    return (add,)""",
1610        )
1611
1612        ep_has_convd = ep_has_convd.run_decompositions(
1613            decomp_table=None, _preserve_ops=[torch.ops.aten.conv2d.default]
1614        )
1615
1616        self.assertExpectedInline(
1617            str(ep_has_convd.graph_module.code).strip(),
1618            """\
1619def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_linear_weight, b_linear_bias, x, y):
1620    conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias);  x = p_conv_weight = p_conv_bias = None
1621    convolution = torch.ops.aten.convolution.default(y, p_conv1d_weight, p_conv1d_bias, [1], [0], [1], False, [0], 1);  y = p_conv1d_weight = p_conv1d_bias = None
1622    view = torch.ops.aten.view.default(conv2d, [31680, 98]);  conv2d = None
1623    permute = torch.ops.aten.permute.default(b_linear_weight, [1, 0]);  b_linear_weight = None
1624    addmm = torch.ops.aten.addmm.default(b_linear_bias, view, permute);  b_linear_bias = view = permute = None
1625    view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]);  addmm = None
1626    cos = torch.ops.aten.cos.default(view_1);  view_1 = None
1627    sum_1 = torch.ops.aten.sum.dim_IntList(convolution, []);  convolution = None
1628    add = torch.ops.aten.add.Tensor(cos, sum_1);  cos = sum_1 = None
1629    return (add,)""",
1630        )
1631
1632    def test_set_grad_empty(self):
1633        class M(torch.nn.Module):
1634            def forward(self, x):
1635                with torch.no_grad():
1636                    x = x + 1
1637                    return x, None
1638
1639        ep = export(M(), (torch.ones(3, 3),))
1640        inp = torch.randn(3, 3)
1641        self.assertTrue(torch.allclose(ep.module()(inp)[0], inp + 1))
1642
1643    def test_derived_dim_out_of_order_simplified(self):
1644        _dimz = torch.export.Dim("_dimz", min=6, max=8)
1645        dimy = _dimz - 1
1646        dimx = dimy - 1
1647        dimz = torch.export.Dim("dimz", min=6, max=8)  # doesn't work, should be = _dimz
1648
1649        class Foo(torch.nn.Module):
1650            def forward(self, x, y, z):
1651                return x + y[1:] + z[2:]
1652
1653        foo = Foo()
1654        u, v, w = torch.randn(5), torch.randn(6), torch.randn(7)
1655        try:
1656            export(
1657                foo,
1658                (u, v, w),
1659                dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
1660            )
1661        except torch._dynamo.exc.UserError as exc:
1662            expected_error_msg = (
1663                "Constraints violated \(dimz\)!(.*\n)*.*"
1664                "The values of dimz.*must always be related to the values of _dimz - 2.*by.*(.*\n)*.*"
1665                "Suggested fixes:(.*\n)*.*"
1666                "dimz = _dimz"
1667            )
1668            self.assertTrue(re.search(expected_error_msg, exc.args[0]) is not None)
1669            # don't suggest fix for non-root dims, and no need to update root here
1670            self.assertTrue("_dimz - 2 = Dim(" not in exc.args[0])
1671            self.assertTrue("_dimz - 1 = _dimz - 1" not in exc.args[0])
1672            self.assertTrue("_dimz = Dim(" not in exc.args[0])
1673
1674        dimz = dimx + 2  # works, effectively = _dimz
1675        ep = export(
1676            foo,
1677            (u, v, w),
1678            dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
1679        )
1680        with self.assertRaisesRegex(
1681            RuntimeError,
1682            "Expected input.*shape.*to be equal to 8, but got 5",
1683        ):
1684            ep.module()(torch.randn(6), torch.randn(7), torch.randn(5))
1685
1686        self.assertEqual(
1687            ep.module()(torch.randn(6), torch.randn(7), torch.randn(8)).size()[0], 6
1688        )
1689
1690    def test_simple_export_for_training(self):
1691        class Foo(torch.nn.Module):
1692            def __init__(self) -> None:
1693                super().__init__()
1694                self.linear = torch.nn.Linear(2, 2)
1695
1696            def forward(self, x):
1697                return self.linear(x)
1698
1699        eager_model = Foo()
1700        ep_for_training = torch.export.export_for_training(
1701            eager_model, (torch.ones(2, 2),)
1702        )
1703        self.assertExpectedInline(
1704            str(ep_for_training.graph_module.code).strip(),
1705            """\
1706def forward(self, p_linear_weight, p_linear_bias, x):
1707    linear = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias);  x = p_linear_weight = p_linear_bias = None
1708    return (linear,)""",
1709        )
1710        gm = ep_for_training.module()
1711        self.assertExpectedInline(
1712            str(gm.code).strip(),
1713            """\
1714def forward(self, x):
1715    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
1716    linear_weight = self.linear.weight
1717    linear_bias = self.linear.bias
1718    linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias);  x = linear_weight = linear_bias = None
1719    return pytree.tree_unflatten((linear,), self._out_spec)""",
1720        )
1721
1722        self.assertTrue(
1723            torch.allclose(gm(torch.ones(2, 2)), eager_model(torch.ones(2, 2)))
1724        )
1725
1726    def test_export_for_training_with_mutation(self):
1727        class Foo(torch.nn.Module):
1728            def __init__(self) -> None:
1729                super().__init__()
1730                self.buffer = torch.nn.Buffer(torch.ones(4, 4))
1731
1732            def forward(self, x):
1733                x.add_(5)
1734                self.buffer.add_(5)
1735                return x + self.buffer
1736
1737        eager_model_for_export = Foo()
1738        eager_model_for_testing = Foo()
1739        ep_for_training = torch.export.export_for_training(
1740            eager_model_for_export, (torch.ones(4, 4),)
1741        )
1742        self.assertExpectedInline(
1743            str(ep_for_training.graph_module.code).strip(),
1744            """\
1745def forward(self, b_buffer, x):
1746    add_ = torch.ops.aten.add_.Tensor(x, 5);  x = None
1747    add__1 = torch.ops.aten.add_.Tensor(b_buffer, 5);  b_buffer = None
1748    add = torch.ops.aten.add.Tensor(add_, add__1);  add_ = add__1 = None
1749    return (add,)""",
1750        )
1751        gm = ep_for_training.module()
1752        self.assertExpectedInline(
1753            str(gm.code).strip(),
1754            """\
1755def forward(self, x):
1756    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
1757    buffer = self.buffer
1758    add_ = torch.ops.aten.add_.Tensor(x, 5);  x = None
1759    add__1 = torch.ops.aten.add_.Tensor(buffer, 5);  buffer = None
1760    add = torch.ops.aten.add.Tensor(add_, add__1);  add_ = add__1 = None
1761    return pytree.tree_unflatten((add,), self._out_spec)""",
1762        )
1763
1764        self.assertTrue(
1765            torch.allclose(
1766                gm(torch.ones(4, 4)), eager_model_for_testing(torch.ones(4, 4))
1767            )
1768        )
1769
1770    def test_export_for_training_with_dynamic_shapes(self):
1771        class Foo(torch.nn.Module):
1772            def __init__(self) -> None:
1773                super().__init__()
1774                self.buffer = torch.nn.Buffer(torch.ones(4, 4))
1775
1776            def forward(self, x):
1777                x.add_(5)
1778                self.buffer.add_(5)
1779                return x + self.buffer.sum()
1780
1781        eager_model_for_export_training = Foo()
1782        eager_model_for_export_inference = Foo()
1783        eager_model_for_testing = Foo()
1784        ep_for_training = torch.export.export_for_training(
1785            eager_model_for_export_training,
1786            (torch.ones(4, 4),),
1787            dynamic_shapes=({0: Dim("x")},),
1788        )
1789
1790        self.assertTrue(
1791            torch.allclose(
1792                ep_for_training.module()(torch.ones(2, 4)),
1793                eager_model_for_testing(torch.ones(2, 4)),
1794            )
1795        )
1796
1797        ep_for_real = export(
1798            eager_model_for_export_inference,
1799            (torch.ones(4, 4),),
1800            dynamic_shapes=({0: Dim("x")},),
1801        )
1802
1803        self.assertEqual(
1804            str(ep_for_training.range_constraints), str(ep_for_real.range_constraints)
1805        )
1806
1807    def test_export_for_training_with_container_type(self):
1808        class Foo(torch.nn.Module):
1809            def __init__(self) -> None:
1810                super().__init__()
1811                self.buffer = torch.nn.Buffer(torch.ones(4, 4))
1812
1813            def forward(self, container):
1814                x = container[0][0]
1815                y = container[0][1]
1816                x.add_(5)
1817                y.add_(5)
1818                return x + y + self.buffer.sum()
1819
1820        eager_model = Foo()
1821        ep_for_training = torch.export.export_for_training(
1822            eager_model,
1823            ([torch.ones(4, 4), torch.ones(4, 4)],),
1824        )
1825
1826        self.assertTrue(
1827            torch.allclose(
1828                ep_for_training.module()(
1829                    ([torch.ones(4, 4), torch.ones(4, 4)]),
1830                ),
1831                eager_model(([torch.ones(4, 4), torch.ones(4, 4)])),
1832            )
1833        )
1834
1835    def test_export_for_training_run_decomp(self):
1836        class Foo(torch.nn.Module):
1837            def __init__(self) -> None:
1838                super().__init__()
1839                self.buffer = torch.nn.Buffer(torch.ones(2, 2))
1840                self.linear = torch.nn.Linear(2, 2)
1841
1842            def forward(self, x):
1843                self.buffer.add_(5)
1844                return self.linear(x) + self.buffer.sum()
1845
1846        eager_model = Foo()
1847        ep_for_training = torch.export.export_for_training(
1848            eager_model,
1849            (torch.ones(2, 2),),
1850        )
1851        ep_for_inference = ep_for_training.run_decompositions()
1852        self.assertExpectedInline(
1853            str(ep_for_inference.graph_module.code).strip(),
1854            """\
1855def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
1856    add = torch.ops.aten.add.Tensor(b_buffer, 5);  b_buffer = None
1857    permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]);  p_linear_weight = None
1858    addmm = torch.ops.aten.addmm.default(p_linear_bias, x, permute);  p_linear_bias = x = permute = None
1859    sum_1 = torch.ops.aten.sum.dim_IntList(add, [])
1860    add_1 = torch.ops.aten.add.Tensor(addmm, sum_1);  addmm = sum_1 = None
1861    return (add, add_1)""",
1862        )
1863
1864    def test_derived_dim_out_of_order_simplified_repeat_non_derived(self):
1865        class Foo(torch.nn.Module):
1866            def forward(self, x, y, y1, z):
1867                return x + y[1:] + y1[1:] + z[2:]
1868
1869        foo = Foo()
1870
1871        u, v, v1, w = torch.randn(5), torch.randn(6), torch.randn(6), torch.randn(7)
1872        _dimz = torch.export.Dim("_dimz", min=6, max=8)
1873        dimy = _dimz - 1
1874        dimx = dimy - 1
1875        dimz = dimx + 2  # works, effectively = _dimz
1876        ep = export(
1877            foo,
1878            (u, v, v1, w),
1879            dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimy}, {0: dimz}),
1880        )
1881        with self.assertRaisesRegex(
1882            RuntimeError,
1883            "Expected input.*shape.*to be equal to 7, but got 5",
1884        ):
1885            ep.module()(
1886                torch.randn(6),
1887                torch.randn(7),
1888                torch.randn(5),
1889                torch.randn(8),
1890            )
1891
1892        self.assertEqual(
1893            ep.module()(
1894                torch.randn(6),
1895                torch.randn(7),
1896                torch.randn(7),
1897                torch.randn(8),
1898            ).size()[0],
1899            6,
1900        )
1901
1902    def test_static_dim_constraints(self):
1903        class Foo(torch.nn.Module):
1904            def __init__(self) -> None:
1905                super().__init__()
1906                self.l = torch.nn.Linear(6, 4)
1907
1908            def forward(self, x, y, z):
1909                x0 = self.l(x) + y[1:]
1910                return x0, z * 2.0
1911
1912        foo = Foo()
1913        inputs = (torch.randn(4, 6), torch.randn(5, 4), torch.randn(3, 3))
1914        dx = Dim("dx", min=3, max=6)
1915        dy = dx + 1
1916        dz = Dim("dz", min=3, max=6)
1917
1918        # test that tweaking shapes fails
1919        wrong_shape_inputs = [
1920            (torch.randn(4, 7), torch.randn(5, 4), torch.randn(3, 3)),
1921            (torch.randn(4, 6), torch.randn(5, 5), torch.randn(3, 3)),
1922            (torch.randn(4, 6), torch.randn(5, 4), torch.randn(3, 4)),
1923        ]
1924
1925        # all of these should be fine
1926        for dynamic_shapes in [
1927            ({0: dx, 1: 6}, {0: dy, 1: 4}, {0: dz, 1: 3}),
1928            ((dx, None), (dy, 4), (dz, 3)),
1929            ((None, 6), (5, None), (None, None)),
1930            ((4, 6), {0: None, 1: 4}, {0: None, 1: 3}),
1931            (None, None, (Dim.STATIC, Dim.STATIC)),
1932        ]:
1933            ep = export(foo, inputs, dynamic_shapes=dynamic_shapes)
1934            self.assertEqual(foo(*inputs), ep.module()(*inputs))
1935            for wrong_inputs in wrong_shape_inputs:
1936                with self.assertRaises(RuntimeError):
1937                    ep.module()(*wrong_inputs)
1938
1939        # check range_constraints - static dims shouldn't be present
1940        ep = export(foo, inputs, dynamic_shapes=((dx, None), (dy, 4), (dz, 3)))
1941        self.assertEqual(len(ep.range_constraints), 3)
1942        for vr in ep.range_constraints.values():
1943            self.assertTrue(vr.lower < vr.upper)
1944
1945        # check raised errors
1946        with self.assertRaisesRegex(
1947            (
1948                torch.fx.experimental.symbolic_shapes.ConstraintViolationError,
1949                torch._dynamo.exc.UserError,
1950            ),
1951            "Static shape constraint of 5 does not match input size of 4, for .*",
1952        ):
1953            _ = export(foo, inputs, dynamic_shapes=((5, None), None, None))
1954        with self.assertRaisesRegex(
1955            (
1956                torch.fx.experimental.symbolic_shapes.ConstraintViolationError,
1957                torch._dynamo.exc.UserError,
1958            ),
1959            "Static shape constraint of 9 does not match input size of 6, for .*",
1960        ):
1961            _ = export(foo, inputs, dynamic_shapes=((dx, 9), (dy, 4), (3, 3)))
1962
1963    def test_dim_1_2(self):
1964        class Foo(torch.nn.Module):
1965            def forward(self, x):
1966                return x * 2
1967
1968        dx = Dim("dx", min=1, max=2)
1969        ep = export(Foo(), (torch.randn(2, 2),), dynamic_shapes=({0: dx, 1: None},))
1970        ep.module()(torch.randn(1, 2))
1971        ep.module()(torch.randn(2, 2))
1972        with self.assertRaisesRegex(
1973            RuntimeError, "Expected input at .* to be <= 2, but got 3"
1974        ):
1975            ep.module()(torch.randn(3, 2))
1976        vr = list(ep.range_constraints.values())[0]
1977        self.assertEqual(vr.lower, 1)
1978        self.assertEqual(vr.upper, 2)
1979
1980    def test_derived_dim_1_2(self):
1981        class Bar(torch.nn.Module):
1982            def forward(self, x, y):
1983                return x + y[1:]
1984
1985        dx = Dim("dx", min=1, max=2)
1986        ep = export(
1987            Bar(),
1988            (torch.randn(2, 2), torch.randn(3, 2)),
1989            dynamic_shapes=({0: dx, 1: None}, {0: dx + 1, 1: None}),
1990        )
1991        ep.module()(torch.randn(1, 2), torch.randn(2, 2))
1992        range_lower_bounds = sorted(vr.lower for vr in ep.range_constraints.values())
1993        range_upper_bounds = sorted(vr.upper for vr in ep.range_constraints.values())
1994        self.assertEqual(range_lower_bounds, [1, 2])
1995        self.assertEqual(range_upper_bounds, [2, 3])
1996
1997    def test_dynamic_shapes_builder_basic(self):
1998        class M(torch.nn.Module):
1999            def forward(self, x, y, z):
2000                return x + y[0] + z["k"]
2001
2002        m = M()
2003
2004        x = torch.randn(4)
2005        y = [torch.randn(4)]
2006        z = {"k": torch.randn(4)}
2007        args = (x, y, z)
2008
2009        shapes_collection = torch.export.ShapesCollection()
2010        dim = torch.export.Dim("dim", max=10)
2011        shapes_collection[x] = (dim,)
2012        shapes_collection[y[0]] = (dim,)
2013        shapes_collection[z["k"]] = (dim,)
2014
2015        ep = export(m, args, dynamic_shapes=shapes_collection)
2016        sym = next(iter(ep.range_constraints.keys()))
2017        for node in ep.graph.nodes:
2018            if node.op == "placeholder":
2019                self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")
2020
2021    def test_dynamic_shapes_builder_kwargs(self):
2022        class M(torch.nn.Module):
2023            def forward(self, x, y, z):
2024                return x + y[0] + z["k"]
2025
2026        m = M()
2027
2028        x = torch.randn(4)
2029        y = [torch.randn(4)]
2030        z = {"k": torch.randn(4)}
2031        args = (x,)
2032        kwargs = {"z": z, "y": y}
2033
2034        shapes_collection = torch.export.ShapesCollection()
2035        dim = torch.export.Dim("dim", max=10)
2036        shapes_collection[x] = (dim,)
2037        shapes_collection[y[0]] = (dim,)
2038        shapes_collection[z["k"]] = (dim,)
2039
2040        ep = export(m, args, kwargs=kwargs, dynamic_shapes=shapes_collection)
2041        sym = next(iter(ep.range_constraints.keys()))
2042        for node in ep.graph.nodes:
2043            if node.op == "placeholder":
2044                self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")
2045
2046    # retracing doesn't seem to like dataclass registration,
2047    # raising a dynamo error in fx_pytree.tree_flatten_spec
2048    @testing.expectedFailureRetraceability
2049    def test_dynamic_shapes_builder_pytree(self):
2050        torch.export.register_dataclass(
2051            Inp,
2052            serialized_type_name="test_dynamic_shapes_builder_pytree.Inp",
2053        )
2054
2055        class M(torch.nn.Module):
2056            def forward(self, inp: Inp):
2057                return inp.x + inp.y[0] + inp.z["k"]
2058
2059        m = M()
2060        x = torch.randn(4)
2061        y = [torch.randn(4)]
2062        z = {"k": torch.randn(4)}
2063        args = (Inp(x, y, z),)
2064
2065        shapes_collection = torch.export.ShapesCollection()
2066        dim = torch.export.Dim("dim", max=10)
2067        shapes_collection[x] = (dim,)
2068        shapes_collection[y[0]] = (dim,)
2069        shapes_collection[z["k"]] = (dim,)
2070
2071        ep = export(m, args, dynamic_shapes=shapes_collection.dynamic_shapes(m, args))
2072        sym = next(iter(ep.range_constraints.keys()))
2073        for node in ep.graph.nodes:
2074            if node.op == "placeholder":
2075                self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")
2076
2077    def test_mismatched_dynamic_shapes(self):
2078        AUTO, STATIC = Dim.AUTO, Dim.STATIC
2079
2080        class M(torch.nn.Module):
2081            def forward(self, x):
2082                return x["k"]["k"][0] + x["k"]["k"][1]
2083
2084        inputs = ({"k": {"k": [torch.rand(4), torch.rand(4)]}},)
2085        dim = torch.export.Dim("dim")
2086
2087        dynamic_shapes = {
2088            "k": {"k": [dim, dim]}
2089        }  # ValueError: Node keys mismatch; missing key(s): {'x'}; extra key(s): {'k'}.
2090        with self.assertRaisesRegex(
2091            torch._dynamo.exc.UserError,
2092            re.escape(
2093                "When `dynamic_shapes` is specified as a dict, its top-level keys "
2094                "must be the arg names ['x'] of `inputs`, but here they are ['k']. "
2095                "Since here `inputs` is a list/tuple enclosing a single dict, "
2096                "maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?"
2097            ),
2098        ):
2099            export(M(), inputs, dynamic_shapes=dynamic_shapes)
2100
2101        dynamic_shapes = (
2102            {"k": {"k": [dim, dim]}},
2103        )  # torch._dynamo.exc.UserError: Unexpected dynamic_shape .*dim.* of Tensor, try None instead
2104        with self.assertRaisesRegex(
2105            torch._dynamo.exc.UserError,
2106            "Unexpected input tensor shape .*dim.* "
2107            + re.escape(
2108                "specified at `dynamic_shapes[0]['k']['k'][0]` "
2109                "(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions,"
2110                " where each dimension is an int, a Dim, Dim.AUTO, or Dim.STATIC)"
2111            ),
2112        ):
2113            export(M(), inputs, dynamic_shapes=dynamic_shapes)
2114
2115        dynamic_shapes = (
2116            {"k": {"k": (dim, dim)}},
2117        )  # ValueError: Node type mismatch; expected <class 'list'>, but got <class 'tuple'>.
2118        with self.assertRaisesRegex(
2119            torch._dynamo.exc.UserError,
2120            re.escape(
2121                "Detected mismatch between the structure of `inputs` and `dynamic_shapes`: "
2122                "`inputs[0]['k']['k']` is a <class 'list'>, but `dynamic_shapes[0]['k']['k']` is a <class 'tuple'>"
2123            ),
2124        ):
2125            export(M(), inputs, dynamic_shapes=dynamic_shapes)
2126
2127        dynamic_shapes = ({"k": {"k": [(dim,), (dim,)]}},)  # ok
2128        export(M(), inputs, dynamic_shapes=dynamic_shapes)
2129
2130        dynamic_shapes = (
2131            {"k": {"k": dim}},
2132        )  # ValueError: Node type mismatch; expected <class 'list'>, but got .*_Dim.*.
2133        with self.assertRaisesRegex(
2134            torch._dynamo.exc.UserError,
2135            re.escape(
2136                "Detected mismatch between the structure of `inputs` and `dynamic_shapes`: "
2137                "`inputs[0]['k']['k']` is a <class 'list'>, but `dynamic_shapes[0]['k']['k']` is not"
2138            ),
2139        ):
2140            export(M(), inputs, dynamic_shapes=dynamic_shapes)
2141
2142        dynamic_shapes = {
2143            "x": {"k": [(dim,), (dim,)]},
2144            "k": {"k": [(dim,), (dim,)]},
2145        }  # ValueError: Node arity mismatch; expected 1, but got 2.
2146        with self.assertRaisesRegex(
2147            torch._dynamo.exc.UserError,
2148            re.escape(
2149                "When `dynamic_shapes` is specified as a dict, its top-level keys "
2150                "must be the arg names ['x'] of `inputs`, but here they are ['x', 'k']. "
2151                "Alternatively, you could also ignore arg names entirely "
2152                "and specify `dynamic_shapes` as a list/tuple matching `inputs`."
2153            ),
2154        ):
2155            export(M(), inputs, dynamic_shapes=dynamic_shapes)
2156
2157        dynamic_shapes = (
2158            {"k": {"k": [(dim,), (dim,), (dim,)]}},
2159        )  # ValueError: Node arity mismatch; expected 2, but got 3.
2160        with self.assertRaisesRegex(
2161            torch._dynamo.exc.UserError,
2162            re.escape(
2163                "Detected mismatch between the structure of `inputs` and `dynamic_shapes`: "
2164                "`inputs[0]['k']['k']` has 2 elements, but `dynamic_shapes[0]['k']['k']` has 3 elements"
2165            ),
2166        ):
2167            export(M(), inputs, dynamic_shapes=dynamic_shapes)
2168
2169        dynamic_shapes = (
2170            {"k": {"K": [(dim,), (dim,), (dim,)]}},
2171        )  # ValueError: Node keys mismatch; missing key(s): {'k'}; extra key(s): {'K'}.
2172        with self.assertRaisesRegex(
2173            torch._dynamo.exc.UserError,
2174            re.escape(
2175                "Detected mismatch between the structure of `inputs` and `dynamic_shapes`: "
2176                "`inputs[0]['k']` has keys ['k'], but `dynamic_shapes[0]['k']` has keys ['K']"
2177            ),
2178        ):
2179            export(M(), inputs, dynamic_shapes=dynamic_shapes)
2180
2181        dynamic_shapes = {
2182            "x": {"k": {"k": [(dim,), (AUTO,)]}}
2183        }  # mixing AUTO and Dims is not well supported.
2184        with self.assertRaisesRegex(
2185            torch._dynamo.exc.UserError,
2186            re.escape(
2187                "Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, "
2188                "and can easily lead to constraint violation errors or obscure errors in torch.export."
2189            ),
2190        ):
2191            export(M(), inputs, dynamic_shapes=dynamic_shapes)
2192
2193        class N(torch.nn.Module):
2194            def forward(self, x):
2195                return x["k"]["k1"][0] + x["k"]["k2"][0]
2196
2197        inputs = ({"k": {"k1": [torch.rand(4)], "k2": [torch.rand(4)]}},)
2198        dim = torch.export.Dim("dim")
2199
2200        dynamic_shapes = ({"k": {"k2": [(dim,)], "k1": [(dim,)]}},)  # ok
2201        export(N(), inputs, dynamic_shapes=dynamic_shapes)
2202
2203    def test_torch_check_eq_commutativity(self):
2204        class M1(torch.nn.Module):
2205            def forward(self, x1, x2, x3, y):
2206                z1 = x1.item()
2207                z2 = x2.item()
2208                z3 = x3.item()
2209                # instead of: torch._check((z2 + z3) == z1)
2210                torch._check(z1 == (z2 + z3))
2211                if z2 + z3 == z1:
2212                    return y * 2
2213                else:
2214                    return y + 3
2215
2216        export(
2217            M1(),
2218            (torch.tensor(6), torch.tensor(3), torch.tensor(3), torch.randn(1)),
2219        )
2220
2221        class M2(torch.nn.Module):
2222            def forward(self, x1, x2, x3, y):
2223                z1 = x1.item()
2224                z2 = x2.item()
2225                z3 = x3.item()
2226                # instead of: torch._check((z2 + z3) != z1)
2227                torch._check(z1 != (z2 + z3))
2228                if z2 + z3 == z1:
2229                    return y * 2
2230                else:
2231                    return y + 3
2232
2233        export(
2234            M2(),
2235            (torch.tensor(6), torch.tensor(6), torch.tensor(6), torch.randn(1)),
2236        )
2237
2238    def test_raise_user_error_when_guard_on_data_dependent_operation(self):
2239        class M(torch.nn.Module):
2240            def forward(self, x):
2241                y = x.nonzero()
2242                z = y.shape[0]
2243                if z > 2:
2244                    return x.cos()
2245                else:
2246                    return x.sin()
2247
2248        with self.assertRaisesRegex(
2249            (
2250                torchdynamo.exc.UserError,
2251                torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode,
2252            ),
2253            "Could not guard on data-dependent expression",
2254        ):
2255            _ = export(M(), (torch.tensor([2, 3, 5]),))
2256
2257    def test_suggested_fixes_for_data_dependent_errors_basic(self):
2258        # suggested fixes for data-dependent errors only work in non-strict mode
2259        strict = False
2260        error_type = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
2261
2262        # Just to introduce some indirection: N is a top-level module N that calls
2263        # module M, defined next.
2264        class N(torch.nn.Module):
2265            def __init__(self) -> None:
2266                super().__init__()
2267                self.m = M()
2268
2269            def forward(self, t):
2270                return self.m(t) + 1
2271
2272        # example input
2273        t = torch.tensor([1, 4, 4], dtype=torch.int32)
2274
2275        # We define a series of versions of M() below. Each version has
2276        # raises a data-dependent error that the next version fixes, by
2277        # copy-pasting a suggested fix in the error message. The fix is
2278        # always a torch.check() on an unresolved condition (or its negation)
2279        # on unbacked symints mentioned in the error message.
2280        # Note that the suggested fixes are in terms of local variables
2281        # near the location of error that "contain" the unbacked symints
2282        # in the unresolved condition (either directly or indirectly, e.g.,
2283        # inside a list or inside the shape of a tensor).
2284
2285        class M_v0(torch.nn.Module):
2286            def forward(self, t):
2287                items = [t[i].item() for i in range(t.numel())]
2288                r = torch.randn([items[0], items[1]])
2289                # Could not guard on data-dependent expression Eq(u2, -1)
2290                return r.view(items[0], items[2])
2291
2292        M = M_v0
2293        with self.assertRaisesRegex(
2294            error_type,
2295            "The following call raised this error(.*\n)+"
2296            f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
2297            "To fix the error, insert one of the following checks before this call.*:\n"
2298            f".*{re.escape('torch._check(items[2] == (-1))')}.*\n"
2299            f".*{re.escape('torch._check(items[2] != (-1))')}(.*\n)+"
2300            f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in Eq(u2, -1) and its negation.)')}",
2301        ):
2302            export(N(), (t,), strict=strict)
2303
2304        class M_v1(torch.nn.Module):
2305            def forward(self, t):
2306                items = [t[i].item() for i in range(t.numel())]
2307                r = torch.randn([items[0], items[1]])
2308                # Could not guard on data-dependent expression Eq(u2, -1)
2309                torch._check(items[2] != -1)
2310                # Could not guard on data-dependent expression u2 >= 0
2311                return r.view(items[0], items[2])
2312
2313        M = M_v1
2314        with self.assertRaisesRegex(
2315            error_type,
2316            "The following call raised this error(.*\n)+"
2317            f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
2318            "To fix the error, insert one of the following checks before this call.*:\n"
2319            f".*{re.escape('torch._check(items[2] >= 0)')}.*\n"
2320            f".*{re.escape('torch._check(items[2] < 0)')}(.*\n)+"
2321            f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in u2 >= 0 and its negation.)')}",
2322        ):
2323            export(N(), (t,), strict=strict)
2324
2325        class M_v2(torch.nn.Module):
2326            def forward(self, t):
2327                items = [t[i].item() for i in range(t.numel())]
2328                r = torch.randn([items[0], items[1]])
2329                # Could not guard on data-dependent expression Eq(u2, -1)
2330                torch._check(items[2] != -1)
2331                # Could not guard on data-dependent expression u2 >= 0
2332                torch._check(items[2] >= 0)
2333                # Could not guard on data-dependent expression Eq(u1, u2)
2334                return r.view(items[0], items[2])
2335
2336        M = M_v2
2337        with self.assertRaisesRegex(
2338            error_type,
2339            "The following call raised this error(.*\n)+"
2340            f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
2341            "To fix the error, insert one of the following checks before this call.*:\n"
2342            f".*{re.escape('torch._check(items[2] == items[1])')}.*\n"
2343            f".*{re.escape('torch._check(items[2] != items[1])')}(.*\n)+"
2344            f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1] or r.shape[1], `u2` with items[2] in Eq(u2, u1) and its negation.)')}",
2345        ):
2346            export(N(), (t,), strict=strict)
2347
2348        class M_v3(torch.nn.Module):
2349            def forward(self, t):
2350                items = [t[i].item() for i in range(t.numel())]
2351                r = torch.randn([items[0], items[1]])
2352                # Could not guard on data-dependent expression Eq(u2, -1)
2353                torch._check(items[2] != -1)
2354                # Could not guard on data-dependent expression u2 >= 0
2355                torch._check(items[2] >= 0)
2356                # Could not guard on data-dependent expression Eq(u1, u2)
2357                torch._check(items[2] == r.shape[1])
2358                return r.view(items[0], items[2])
2359
2360        M = M_v3
2361        export(N(), (t,), strict=strict)
2362
2363    @testing.expectedFailureSerDer  # T195866111
2364    def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):
2365        # suggested fixes for data-dependent errors only work in non-strict mode
2366        strict = False
2367        error_type = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
2368
2369        def retry_export(m, inp, fixes):
2370            # API that applies a series of fixes, retrying export after applying each fix,
2371            # and asserting the applied fix was suggested in the previous try.
2372            # Using this API avoids the need to define multiple versions of the same test
2373            # module, as in `test_suggested_fixes_for_data_dependent_errors_basic` above.
2374            def code(snippets):
2375                return f"[{', '.join(snippets)}]"
2376
2377            for i in range(len(fixes)):
2378                with self.assertRaisesRegex(error_type, re.escape(fixes[i])):
2379                    export(m, (*inp, code(fixes[:i])), strict=strict)
2380            export(m, (*inp, code(fixes)), strict=strict)
2381
2382        # The following examples are lifted from @ezyang's "Data-dependent shape puzzlers"
2383        # notebook at https://www.internalfb.com/intern/anp/view/?id=5330476
2384
2385        # These test modules are written in a way that works well with retry_export above.
2386        # Specifically, they take an extra `fixes` argument and `eval` it at the location
2387        # that is expected to raise errors.
2388
2389        class cf_implicitsize(torch.nn.Module):
2390            def forward(self, x, y, fixes):
2391                i = x.item()
2392                eval(fixes)
2393                # instead of y[i]
2394                return y.narrow(0, i, 1).squeeze()
2395
2396        retry_export(
2397            cf_implicitsize(),
2398            (torch.tensor(2), torch.randn(10)),
2399            fixes=[
2400                # Could not guard on data-dependent expression u0 < 0
2401                "torch._check(i >= 0)",
2402            ],
2403        )
2404
2405        class cf_nomemo(torch.nn.Module):
2406            def forward(self, x, y, fixes):
2407                i = y[0].item()
2408                eval(fixes)
2409                return x.unsqueeze(1).expand(-1, i)
2410
2411        retry_export(
2412            cf_nomemo(),
2413            (torch.randn(8), torch.tensor([2])),
2414            fixes=[
2415                # Could not guard on data-dependent expression Eq(u0, 1)
2416                "torch._check(i != 1)",
2417                # Could not guard on data-dependent expression Ne(u0, -1)
2418                "torch._check(i != (-1))",
2419            ],
2420        )
2421
2422        class cf_changevar(torch.nn.Module):
2423            def forward(self, x, fixes):
2424                i = x.item()
2425                eval(fixes)
2426                r = torch.arange(i // 2)
2427                return r + r
2428
2429        retry_export(
2430            cf_changevar(),
2431            (torch.tensor(20),),
2432            fixes=[
2433                # Could not guard on data-dependent expression Eq((u0//2), 0)
2434                "torch._check(((i//2)) != 0)",
2435                # Could not guard on data-dependent expression Eq((u0//2), 1)
2436                "torch._check(((i//2)) != 1)",
2437            ],
2438        )
2439
2440        class cf_stacklist(torch.nn.Module):
2441            def forward(self, xs, y, fixes):
2442                i = y.item()
2443                eval(fixes)
2444                # instead of xs[i]
2445                return torch.stack(xs, 0).narrow(0, i, 1).squeeze()
2446
2447        retry_export(
2448            cf_stacklist(),
2449            ([torch.ones(5) * i for i in range(10)], torch.tensor(2)),
2450            fixes=[
2451                # Could not guard on data-dependent expression u0 < 0
2452                "torch._check(i >= 0)",
2453            ],
2454        )
2455
2456        class cf_tensorsplit(torch.nn.Module):
2457            def forward(self, x, offsets_t, fixes):
2458                lengths = torch.diff(offsets_t).tolist()
2459                rs = []
2460                start = 0
2461                for length in lengths:
2462                    eval(fixes)
2463                    rs.append(x.narrow(0, start, length))
2464                    start += length
2465                return rs
2466
2467        retry_export(
2468            cf_tensorsplit(),
2469            (torch.arange(10), torch.tensor([0, 2, 5, 7, 10])),
2470            fixes=[],  # nothing to fix!
2471        )
2472
2473    def test_no_suggested_fixes_for_data_dependent_errors(self):
2474        # suggested fixes for data-dependent errors only work in non-strict mode
2475        strict = False
2476        error_type = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
2477
2478        class cf_stacklist(torch.nn.Module):
2479            def forward(self, xs, y):
2480                # y.item() is not a local, so we can't suggest a fix
2481                return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
2482
2483        with self.assertRaisesRegex(
2484            error_type,
2485            "Could not guard on data-dependent expression u0 < 0",
2486        ):
2487            export(
2488                cf_stacklist(),
2489                ([torch.ones(5) * i for i in range(10)], torch.tensor(2)),
2490                strict=strict,
2491            )
2492
2493    def test_tolist(self):
2494        class M(torch.nn.Module):
2495            def forward(self, x):
2496                return x.tolist()
2497
2498        ep = export(M(), (torch.ones(3, dtype=torch.int),))
2499        self.assertEqual(ep.module()(torch.tensor([1, 2, 3])), [1, 2, 3])
2500
2501    def test_if_functional(self):
2502        class Module(torch.nn.Module):
2503            def forward(self, x):
2504                z = x + 4
2505                z.add_(4)
2506                y = z.view(x.shape)
2507                return x.cos() + y.cos()
2508
2509        foo = Module()
2510        gm = export(foo, (torch.tensor([2, 3, 5]),))
2511
2512        view_count = 0
2513        for node in gm.graph.nodes:
2514            if node.op == "call_function" and node.target == torch.ops.aten.add_.Tensor:
2515                # No more inplace mutation
2516                self.assertNotEqual(
2517                    node.target,
2518                    torch.ops.aten.add_.Tensor,
2519                    "There shouldn't be any inplace mutation node in the graph.",
2520                )
2521            if (
2522                node.op == "call_function"
2523                and node.target == torch.ops.aten.view.default
2524            ):
2525                view_count += 1
2526
2527        # There should be nonzero view nodes in the graph
2528        self.assertTrue(view_count > 0)
2529
2530    def test_solver_unsupported_sympy_function(self):
2531        # repro of https://github.com/pytorch/pytorch/issues/131897
2532
2533        class MyModule(torch.nn.Module):
2534            def __init__(self):
2535                super().__init__()
2536
2537            def forward(self, x, y):
2538                x = torch.nn.functional.interpolate(
2539                    x, scale_factor=0.5, mode="bilinear"
2540                )
2541                x = torch.nn.functional.interpolate(
2542                    x, scale_factor=2.0, mode="bilinear"
2543                )
2544                x = x + y
2545                return x
2546
2547        model = MyModule().eval()
2548
2549        inputs = (
2550            torch.rand((1, 1, 32, 32)),
2551            torch.rand((1, 1, 32, 32)),
2552        )
2553
2554        dim = torch.export.Dim("Dim", min=16, max=64)
2555        dynamic_shapes = {"x": {2: dim, 3: dim}, "y": {2: dim, 3: dim}}
2556
2557        exported_program = export(model, inputs, dynamic_shapes=dynamic_shapes)
2558        self.assertEqual(exported_program.module()(*inputs), model(*inputs))
2559
2560    def test_export_mod_constraints(self):
2561        class BasicDynamiShapeModel(torch.nn.Module):
2562            def forward(self, x: torch.Tensor) -> torch.Tensor:
2563                return x.view(x.shape[0] - 1, -1)
2564
2565        m = BasicDynamiShapeModel()
2566        a = torch.randn(3, 4)
2567        dim0_x = torch.export.Dim("dim0_x", min=3)
2568        dim1_x = torch.export.Dim("dim1_x", max=8000)
2569        dynamic_shapes = {"x": (dim0_x, dim1_x)}
2570        em = torch.export._trace._export(
2571            m,
2572            (a,),
2573            dynamic_shapes=dynamic_shapes,
2574            allow_complex_guards_as_runtime_asserts=True,
2575        )
2576        em.module()(torch.randn(4, 3))
2577        with self.assertRaisesRegex(
2578            RuntimeError,
2579            r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, s0 \- 1\), 0\)",
2580        ):
2581            em.module()(torch.randn(4, 5))
2582
2583        dim0_x = None
2584        dim1_x = 2 * torch.export.Dim("_dim1_x", max=4000)
2585        dynamic_shapes = {"x": (dim0_x, dim1_x)}
2586        em = torch.export.export(m, (a,), dynamic_shapes=dynamic_shapes)
2587        x = torch.randn(3, 5)
2588        with self.assertRaisesRegex(
2589            RuntimeError,
2590            "Expected.*shape\\[1\\] = 5 to be of the form 2\\*s1, where s1 is an integer",
2591        ):
2592            em.module()(x)
2593
2594    def test_mark_and_auto_dynamic(self):
2595        # for this use case, mark_dynamic() and AUTO should have same effect.
2596        # check that same symbol gets allocated to both dims without raising constraint violation.
2597        AUTO, STATIC = Dim.AUTO, Dim.STATIC
2598
2599        class Foo(torch.nn.Module):
2600            def forward(self, x, y):
2601                torch._check(x.shape[0] == y.shape[0])
2602                torch._check(x.shape[0] <= 64)
2603                return x + 2, y + 2
2604
2605        inputs = (torch.randn(4, 4), torch.randn(4, 4))
2606        ep_auto = torch.export.export(
2607            Foo(), inputs, dynamic_shapes={"x": (AUTO, None), "y": (AUTO, None)}
2608        )
2609        torch._dynamo.mark_dynamic(inputs[0], 0)
2610        torch._dynamo.mark_dynamic(inputs[1], 0)
2611        ep_dynamic = torch.export.export(Foo(), inputs)
2612
2613        # test both programs have same effect
2614        for ep in [ep_auto, ep_dynamic]:
2615            gm = ep.module()
2616            gm(torch.randn(32, 4), torch.randn(32, 4))
2617            gm(torch.randn(1, 4), torch.randn(1, 4))
2618            with self.assertRaises(RuntimeError):
2619                gm(torch.randn(33, 4), torch.randn(32, 4))
2620                gm(torch.randn(128, 4), torch.randn(128, 4))
2621
2622    def test_dont_duck_size_for_auto_dynamic(self):
2623        # for this use case, mark_dynamic() and AUTO should have same effect.
2624        # check that same symbol gets allocated to both dims without raising constraint violation.
2625        AUTO, STATIC = Dim.AUTO, Dim.STATIC
2626
2627        class Foo(torch.nn.Module):
2628            def forward(self, x, y):
2629                # x: [s0, s1], y: [s0 + 1, 4]
2630                assert y.shape[1] == 4
2631                assert x.shape[0] == y.shape[0] - 1
2632                return x * 2, y * 2
2633
2634        # duck sizing would make all static based on these sample inputs
2635        inputs = (torch.randn(4, 4), torch.randn(5, 4))
2636        shapes = {
2637            "x": (AUTO, AUTO),
2638            "y": (AUTO, AUTO),
2639        }
2640        ep = export(Foo(), inputs, dynamic_shapes=shapes)
2641        ep.module()(torch.randn(6, 3), torch.randn(7, 4))
2642
2643    @testing.expectedFailureRetraceability  # T183144629
2644    def test_map(self):
2645        class Module(torch.nn.Module):
2646            def forward(self, xs, y, z):
2647                def body(x, y, z):
2648                    return x + y + z
2649
2650                return map(body, xs, y, z)
2651
2652        list_tensor_map = Module()
2653        inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
2654        self._test_export_same_as_eager(list_tensor_map, inps)
2655
2656    @unittest.expectedFailure
2657    def test_crop_like(self):
2658        # https://fb.workplace.com/groups/1405155842844877/posts/8195050017188725/
2659
2660        # Minimal crop code copied from https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/functional
2661        class CropLike(torch.nn.Module):
2662            def forward(self, image, crop_height, crop_width):
2663                c, image_height, image_width = image.shape
2664                crop_top = int(round((image_height - crop_height) / 2.0))
2665                crop_left = int(round((image_width - crop_width) / 2.0))
2666                return image[
2667                    ...,
2668                    crop_top : crop_top + crop_height,
2669                    crop_left : crop_left + crop_width,
2670                ]
2671
2672        crop = CropLike()
2673        imagew = Dim("width")
2674        imageh = Dim("height")
2675        dynamic_dims = {
2676            "image": {0: None, 1: imageh, 2: imagew},
2677            "crop_height": None,
2678            "crop_width": None,
2679        }
2680        args = (torch.rand(3, 512, 512), 150, 150)
2681        ecrop = export(crop, args=args, dynamic_shapes=dynamic_dims)
2682
2683        args = (torch.rand(3, 700, 700), 150, 150)
2684        self.assertEqual(ecrop.module()(*args), ecrop(*args))
2685
2686    def test_export_func_with_kwargs(self):
2687        class Module(torch.nn.Module):
2688            def forward(self, arg1, arg2, kw1, kw2):
2689                return arg1 + arg2, kw1 + kw2
2690
2691        kw_func = Module()
2692        args = (torch.ones(6, 4), torch.ones(1, 1))
2693        kwargs = {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)}
2694        self._test_export_same_as_eager(kw_func, args, kwargs)
2695
2696    def test_export_func_with_pytree_kwargs(self):
2697        class Module(torch.nn.Module):
2698            def forward(self, arg1, arg2, a, b):
2699                return arg1 + a["kw1"] + b[0], arg2 + a["kw2"] + b[1]
2700
2701        kw_func = Module()
2702        args = (torch.ones(2, 3), torch.ones(3, 4))
2703        kwargs = {
2704            "a": {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4)},
2705            "b": [torch.ones(2, 3), torch.ones(3, 4)],
2706        }
2707        self._test_export_same_as_eager(kw_func, args, kwargs)
2708
2709    def test_export_func_with_default_kwargs(self):
2710        class Module(torch.nn.Module):
2711            def forward(self, arg1, arg2, a, b=1):
2712                return arg1 + arg2, a["kw1"] + a["kw2"] + b
2713
2714        kw_func = Module()
2715
2716        class Module2(torch.nn.Module):
2717            def forward(self, arg1, arg2, a=1, b=2):
2718                return arg1 + a, arg2 + b
2719
2720        kw_func2 = Module2()
2721
2722        args = (torch.ones(6, 4), torch.ones(1, 1))
2723        kwargs1 = {"a": {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)}}
2724        kwargs2 = {"a": {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)}, "b": 2}
2725        self._test_export_same_as_eager(kw_func, args, kwargs1)
2726        self._test_export_same_as_eager(kw_func, args, kwargs2)
2727        kwargs3 = {"b": 1}
2728        self._test_export_same_as_eager(kw_func2, args, kwargs3)
2729
2730    def test_export_func_with_var_postional_args(self):
2731        class Module(torch.nn.Module):
2732            def forward(self, arg1, arg2, *args):
2733                return arg1 + args[0], arg2 + args[1]
2734
2735        kw_func = Module()
2736        args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4))
2737        self._test_export_same_as_eager(kw_func, args)
2738
2739    def test_export_func_with_keyword_only_args(self):
2740        class Module(torch.nn.Module):
2741            def forward(self, arg1, arg2, *args, kw1, kw2):
2742                return arg1 + args[0] + kw1, arg2 + args[1] + kw2
2743
2744        kw_func = Module()
2745        args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4))
2746        kwargs = {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4)}
2747        self._test_export_same_as_eager(kw_func, args, kwargs)
2748
2749    def test_export_func_with_var_keyword_args(self):
2750        class Module(torch.nn.Module):
2751            def forward(self, arg1, arg2, *args, kw1, kw2, **kwargs):
2752                return (
2753                    arg1 + args[0] + kw1 + kwargs["kw3"],
2754                    arg2 + args[1] + kw2 + kwargs["kw4"],
2755                )
2756
2757        kw_func = Module()
2758        args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4))
2759        kwargs = {
2760            "kw1": torch.ones(2, 3),
2761            "kw2": torch.ones(3, 4),
2762            "kw3": torch.ones(2, 3),
2763            "kw4": torch.ones(3, 4),
2764        }
2765        self._test_export_same_as_eager(kw_func, args, kwargs)
2766
2767    def test_unbacked_slice(self):
2768        class M(torch.nn.Module):
2769            def forward(self, scores, score_thr, topk: torch.Tensor, results=None):
2770                valid_mask = scores > score_thr
2771                scores = scores[valid_mask]
2772                valid_idxs = torch.nonzero(valid_mask).to(scores.device)
2773
2774                num_topk = torch.minimum(topk, torch.tensor(valid_idxs.shape[0])).item()
2775                torch._check_is_size(num_topk)
2776                torch._check(scores.shape[0] >= num_topk)
2777                scores, idxs = scores.sort(descending=True)
2778                scores = scores[:num_topk]
2779                topk_idxs = valid_idxs[idxs[:num_topk]]
2780                keep_idxs, labels = topk_idxs.unbind(dim=1)
2781
2782                return scores, labels, keep_idxs
2783
2784        score = torch.tensor(
2785            [[0.1, 0.3, 0.2], [0.12, 0.7, 0.9], [0.02, 0.8, 0.08], [0.4, 0.1, 0.08]]
2786        )
2787        bbox_pred = torch.tensor([[0.2, 0.3], [0.4, 0.7], [0.1, 0.1], [0.5, 0.1]])
2788        score_thr = 0.15
2789        nms_pre = torch.tensor(4)
2790        inputs = (score, score_thr, nms_pre, dict(bbox_pred=bbox_pred))
2791
2792        ep = torch.export.export(M(), inputs)
2793        orig_res = M()(*inputs)
2794        ep_res = ep.module()(*inputs)
2795        self.assertTrue(torch.allclose(orig_res[0], ep_res[0]))
2796        self.assertTrue(torch.allclose(orig_res[1], ep_res[1]))
2797        self.assertTrue(torch.allclose(orig_res[2], ep_res[2]))
2798
2799    def test_unflatten_asserts(self):
2800        # TODO: strict-export fails
2801        class M1(torch.nn.Module):
2802            def forward(self, x, y):
2803                b = x.item()
2804
2805                torch._check_is_size(b)
2806                torch._check(b < y.size(0))
2807                return y[:b]
2808
2809        class M3(torch.nn.Module):
2810            def forward(self, x, y):
2811                b = x.item()
2812
2813                torch._check_is_size(b)
2814                torch._check(b < y.size(0) * 2)
2815                return y[:b]
2816
2817        class M2(torch.nn.Module):
2818            def __init__(self) -> None:
2819                super().__init__()
2820                self.m1 = M1()
2821                self.m3 = M3()
2822
2823            def forward(self, x, y):
2824                return self.m1(x, y) + self.m3(x, y)
2825
2826        inputs = (torch.tensor(3), torch.randn(10))
2827
2828        ep = torch.export.export(
2829            M2(), inputs, dynamic_shapes={"x": None, "y": (Dim("moo"),)}, strict=False
2830        )
2831        orig_res = M2()(*inputs)
2832        ep_res = ep.module()(*inputs)
2833        self.assertTrue(torch.allclose(orig_res[0], ep_res[0]))
2834        self.assertTrue(torch.allclose(orig_res[1], ep_res[1]))
2835        self.assertTrue(torch.allclose(orig_res[2], ep_res[2]))
2836
2837        unflattened = torch.export.unflatten(ep)
2838        ep_res = unflattened(*inputs)
2839        self.assertTrue(torch.allclose(orig_res[0], ep_res[0]))
2840        self.assertTrue(torch.allclose(orig_res[1], ep_res[1]))
2841        self.assertTrue(torch.allclose(orig_res[2], ep_res[2]))
2842
2843    def test_export_func_with_var_keyword_pytree_args(self):
2844        class Module(torch.nn.Module):
2845            def forward(self, arg1, arg2, *args, kw1, kw2, **kwargs):
2846                return (
2847                    arg1 + arg2[0][0] + args[0] + kw1[0] + kwargs["kw3"][0],
2848                    arg2[1] + args[1] + kw2 + kwargs["kw4"],
2849                )
2850
2851        kw_func = Module()
2852        args = (
2853            torch.ones(2, 3),
2854            [(torch.ones(2, 3),), torch.ones(3, 4)],
2855            torch.ones(2, 3),
2856            torch.ones(3, 4),
2857        )
2858        kwargs = {
2859            "kw1": (torch.ones(2, 3),),
2860            "kw2": torch.ones(3, 4),
2861            "kw3": (torch.ones(2, 3), torch.ones(3, 4)),
2862            "kw4": torch.ones(3, 4),
2863        }
2864        self._test_export_same_as_eager(kw_func, args, kwargs)
2865
2866    @testing.expectedFailureSerDer  # we don't save placeholder metadata
2867    @testing.expectedFailureNonStrict
2868    @testing.expectedFailureTrainingIRToRunDecompNonStrict  # source_fn_stack failure
2869    def test_linear_conv(self):
2870        class MyLinear(torch.nn.Module):
2871            def __init__(self) -> None:
2872                super().__init__()
2873                self.weight = torch.randn(20, 98)
2874                self.bias = torch.randn(20)
2875
2876            def forward(self, x):
2877                return torch.nn.functional.linear(x, self.weight, self.bias)
2878
2879        class Foo(torch.nn.Module):
2880            def __init__(self) -> None:
2881                super().__init__()
2882                self.conv = torch.nn.Conv2d(16, 33, 3)
2883                self.linear = MyLinear()
2884
2885            def forward(self, x):
2886                x_conv = self.conv(x)
2887                x_linear = self.linear(x_conv)
2888                return x_linear.cos()
2889
2890        ep = export(Foo(), (torch.randn(20, 16, 50, 100),))
2891        for node in ep.graph.nodes:
2892            if (
2893                node.op == "placeholder"
2894                and node.name in ep.graph_signature.inputs_to_buffers
2895                or node.name in ep.graph_signature.inputs_to_parameters
2896            ):
2897                self.assertTrue("source_fn_stack" in node.meta)
2898
2899    def test_export_api_with_dynamic_shapes(self):
2900        from torch.export import Dim, dims, export
2901
2902        # pass dynamic shapes of inputs [args]
2903        class Foo(torch.nn.Module):
2904            def forward(self, x, y):
2905                return torch.matmul(x, y)
2906
2907        foo = Foo()
2908        inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4))
2909        batch = Dim("batch")
2910        efoo = export(
2911            foo,
2912            inputs,
2913            dynamic_shapes={k: {0: batch} for k in ["x", "y"]},
2914        )
2915        self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
2916
2917        foo = Foo()
2918        inputs = (torch.randn(10, 2, 3),)
2919        kwinputs = {"y": torch.randn(10, 3, 4)}
2920        batch = Dim("batch")
2921        efoo = export(
2922            foo, inputs, kwinputs, dynamic_shapes={k: {0: batch} for k in ["x", "y"]}
2923        )
2924        self.assertEqual(
2925            efoo.module()(*inputs, **kwinputs).shape, foo(*inputs, **kwinputs).shape
2926        )
2927
2928        # pass dynamic shapes of inputs [partial, error]
2929        foo = Foo()
2930        inputs = (torch.randn(10, 2, 3),)
2931        kwinputs = {"y": torch.randn(10, 3, 4)}
2932        batch = Dim("batch")
2933        with self.assertRaisesRegex(
2934            torch._dynamo.exc.UserError,
2935            (
2936                "Constraints violated \\(batch\\)!(.*\n)*.*"
2937                "batch was inferred to be a constant(.*\n)*.*"
2938                "Suggested fixes:(.*\n)*.*"
2939                "batch = 10"
2940            ),
2941        ):
2942            export(
2943                foo,
2944                inputs,
2945                kwinputs,
2946                dynamic_shapes={"x": {0: batch}, "y": None},
2947            )
2948
2949        # pass dynamic shapes of inputs [module]
2950        foo = Foo()
2951        inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4))
2952        batch = Dim("batch")
2953        efoo = export(
2954            foo,
2955            inputs,
2956            dynamic_shapes={"x": {0: batch}, "y": {0: batch}},
2957        )
2958        self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
2959
2960        # pass dynamic shapes of inputs [bounds, mostly shared]
2961        foo = Foo()
2962        inputs = (torch.randn(10, 3, 3), torch.randn(10, 3, 3))
2963        batch = Dim("batch", min=8, max=64)
2964        size = Dim("size")
2965        efoo = export(
2966            foo,
2967            inputs,
2968            dynamic_shapes={
2969                "x": (batch, size, size),
2970                "y": (batch, size, size),
2971            },
2972        )
2973        self.assertEqual(
2974            [
2975                str(node.meta["val"].shape)
2976                for node in efoo.graph_module.graph.nodes
2977                if node.op == "placeholder"
2978            ],
2979            ["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
2980        )
2981        self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
2982
2983        # pass dynamic shapes of inputs [multiple, mostly distinct]
2984        inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4))
2985        batch, M, K, N = dims("batch", "M", "K", "N")
2986        efoo = export(
2987            Foo(),
2988            inputs,
2989            dynamic_shapes={"x": (batch, M, K), "y": (batch, K, N)},
2990        )
2991        self.assertEqual(
2992            [
2993                str(node.meta["val"].shape)
2994                for node in efoo.graph_module.graph.nodes
2995                if node.op == "placeholder"
2996            ],
2997            ["torch.Size([s0, s1, s2])", "torch.Size([s0, s2, s5])"],
2998        )
2999        self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
3000
3001        # pass dynamic shapes of inputs [dict]
3002        class Foo(torch.nn.Module):
3003            def forward(self, inputs):
3004                return torch.matmul(inputs["x"], inputs["y"])
3005
3006        foo = Foo()
3007        inputs = ({"x": torch.randn(10, 2, 3), "y": torch.randn(10, 3, 4)},)
3008        batch = Dim("batch")
3009        efoo = export(
3010            foo, inputs, dynamic_shapes={"inputs": {k: {0: batch} for k in ["x", "y"]}}
3011        )
3012        self.assertEqual(
3013            [
3014                str(node.meta["val"].shape)
3015                for node in efoo.graph_module.graph.nodes
3016                if node.op == "placeholder"
3017            ],
3018            ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
3019        )
3020        self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
3021
3022        # pass dynamic shapes of inputs [list]
3023        class Foo(torch.nn.Module):
3024            def forward(self, inputs):
3025                return torch.matmul(inputs[0], inputs[1])
3026
3027        foo = Foo()
3028        inputs = ([torch.randn(10, 2, 3), torch.randn(10, 3, 4)],)
3029        batch = Dim("batch")
3030        efoo = export(
3031            foo, inputs, dynamic_shapes={"inputs": [{0: batch} for _ in range(2)]}
3032        )
3033        self.assertEqual(
3034            [
3035                str(node.meta["val"].shape)
3036                for node in efoo.graph_module.graph.nodes
3037                if node.op == "placeholder"
3038            ],
3039            ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
3040        )
3041        self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
3042
3043        # pass dynamic shapes of inputs [dataclass]
3044
3045        # TODO(avik): This part of the test should have failed both serde and retracing
3046        # but these failures are hidden because of the local import of `export` in this test.
3047        # The serde failure is benign, and easily avoided by moving the dataclass definition
3048        # to the top-level. OTOH the retracing failure needs further investigation.
3049        @dataclass
3050        class DataClass:
3051            a: Tensor
3052            b: Tensor
3053
3054        register_dataclass_as_pytree_node(
3055            DataClass,
3056            serialized_type_name="test_export_api_with_dynamic_shapes.DataClass",
3057        )
3058
3059        class Foo(torch.nn.Module):
3060            def forward(self, inputs):
3061                return torch.matmul(inputs.a, inputs.b)
3062
3063        foo = Foo()
3064        inputs = (DataClass(a=torch.randn(10, 2, 3), b=torch.randn(10, 3, 4)),)
3065        batch = Dim("batch")
3066        efoo = export(
3067            foo,
3068            inputs,
3069            dynamic_shapes={"inputs": [{0: batch}, {0: batch}]},
3070        )
3071        self.assertEqual(
3072            [
3073                str(node.meta["val"].shape)
3074                for node in efoo.graph_module.graph.nodes
3075                if node.op == "placeholder"
3076            ],
3077            ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
3078        )
3079
3080        # pass dynamic shapes of inputs [pytree-registered classes]
3081        if HAS_TORCHREC:
3082            # skipping tests if torchrec not available
3083            class Foo(torch.nn.Module):
3084                def forward(self, kjt) -> torch.Tensor:
3085                    return kjt.values() + 0, kjt.offsets() + 0
3086
3087            foo = Foo()
3088            kjt = KeyedJaggedTensor(
3089                values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
3090                keys=["index_0", "index_1"],
3091                lengths=torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3]),
3092                offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]),
3093            )
3094            inputs = (kjt,)
3095            dim = Dim("dim")
3096            dim_plus_one = Dim("dim_plus_one")
3097            efoo = torch.export.export(
3098                foo,
3099                inputs,
3100                dynamic_shapes={"kjt": [{0: dim}, None, {0: dim}, {0: dim_plus_one}]},
3101            )
3102            self.assertEqual(
3103                [out.shape for out in efoo.module()(*inputs)],
3104                [out.shape for out in foo(*inputs)],
3105            )
3106
3107        # pass dynamic shapes of inputs [distinct, error]
3108        class Foo(torch.nn.Module):
3109            def forward(self, x, y):
3110                return torch.matmul(x, y)
3111
3112        foo = Foo()
3113        inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4))
3114        batch, M, K1, K2, N = dims("batch", "M", "K1", "K2", "N")
3115        with self.assertRaisesRegex(
3116            torch._dynamo.exc.UserError,
3117            (
3118                "Constraints violated \\(K2\\)!(.*\n)*.*"
3119                "K2.*and.*K1.*must always be equal(.*\n)*.*"
3120                "Suggested fixes:(.*\n)*.*"
3121                "K2 = K1"
3122            ),
3123        ):
3124            export(
3125                foo,
3126                inputs,
3127                dynamic_shapes={"x": (batch, M, K1), "y": (batch, K2, N)},
3128            )
3129
3130        # pass dynamic shapes of inputs [specialized, error]
3131        foo = Foo()
3132        inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4))
3133        batch, M, K1, N = dims("batch", "M", "K1", "N")
3134        with self.assertRaisesRegex(
3135            torch._dynamo.exc.UserError,
3136            (
3137                "Constraints violated \\(K1\\)!(.*\n)*.*"
3138                "K1 was inferred to be a constant(.*\n)*.*"
3139                "Suggested fixes:(.*\n)*.*"
3140                "K1 = 3"
3141            ),
3142        ):
3143            export(
3144                foo,
3145                inputs,
3146                dynamic_shapes={"x": (batch, M, K1), "y": (batch, None, N)},
3147            )
3148
3149        # pass dynamic shapes of inputs [guards, error]
3150        class Foo(torch.nn.Module):
3151            def forward(self, x, y):
3152                if x.shape[0] < 16 and y.shape[1] % 3 == 0:
3153                    return torch.matmul(x, y)
3154                else:
3155                    return x + y
3156
3157        foo = Foo()
3158        inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4))
3159        batch, M, K, N = dims("batch", "M", "K", "N")
3160        with self.assertRaisesRegex(
3161            torch._dynamo.exc.UserError,
3162            (
3163                "Constraints violated.*!(.*\n)*.*"
3164                "Not all values of K.*satisfy the generated guard(.*\n)*.*"
3165                "Not all values of batch.*satisfy the generated guard(.*\n)*.*"
3166                "Suggested fixes:(.*\n)*.*"
3167                "batch = Dim\\('batch', max=15\\)(.*\n)*.*"
3168                "K = 3\\*_K"
3169            ),
3170        ):
3171            export(
3172                foo,
3173                inputs,
3174                dynamic_shapes={"x": (batch, M, K), "y": (batch, K, N)},
3175            )
3176
3177    def test_suggested_fixes_new_roots(self):
3178        from torch.export import dims
3179
3180        # suggested fixes should introduce new root dim for modulo guard
3181        class Foo(torch.nn.Module):
3182            def forward(self, x, y, z):
3183                # dy = 3 * _dx
3184                # dx = 3 * _dx - 1
3185                # dz = 3 * _dx + 2
3186                # suggested fixes results will look something like
3187                # {"dx": {"eq": 3*_dx-1, "min": 5, "max": 36}, "dy": {"eq": dx+1}, ...}
3188                if x.shape[0] >= 5 and x.shape[0] <= 36 and y.shape[0] % 3 == 0:
3189                    return x + y[1:] + z[3:]
3190
3191        foo = Foo()
3192        inputs = (
3193            torch.randn(
3194                11,
3195            ),
3196            torch.randn(
3197                12,
3198            ),
3199            torch.randn(
3200                14,
3201            ),
3202        )
3203        dx, dy, dz = dims("dx", "dy", "dz")
3204        dynamic_shapes = {
3205            "x": (dx,),
3206            "y": (dy,),
3207            "z": (dz,),
3208        }
3209        with self.assertRaisesRegex(  # figure out regex later
3210            torch._dynamo.exc.UserError,
3211            (
3212                "Constraints violated.*!(.*\n)*.*"
3213                "Suggested fixes(.*\n)*.*"
3214                "_dx = Dim\(\\'_dx\\', max=12\)(.*\n)*.*"
3215                "dx = 3\*_dx - 1(.*\n)*.*"
3216                "dy = 3\*_dx(.*\n)*.*"
3217                "dz = 3\*_dx \+ 2"
3218            ),
3219        ):
3220            export(Foo(), inputs, dynamic_shapes=dynamic_shapes)
3221        # retry export
3222        _dx = Dim("_dx", min=2, max=12)
3223        dynamic_shapes = {"x": (3 * _dx - 1,), "y": (3 * _dx,), "z": (3 * _dx + 2,)}
3224        export(Foo(), inputs, dynamic_shapes=dynamic_shapes)
3225
3226    def test_refine_dynamic_shapes_from_suggested_fixes(self):
3227        from torch.export.dynamic_shapes import (
3228            refine_dynamic_shapes_from_suggested_fixes,
3229        )
3230
3231        def helper(model, inputs, dynamic_shapes):
3232            # export, fail, parse & refine suggested fixes, re-export
3233            try:
3234                export(Foo(), inps, dynamic_shapes=dynamic_shapes)
3235                raise Exception("should have raised constraint violation error")
3236            except torch._dynamo.exc.UserError as exc:
3237                new_shapes = refine_dynamic_shapes_from_suggested_fixes(
3238                    exc.msg, dynamic_shapes
3239                )
3240                export(Foo(), inps, dynamic_shapes=new_shapes)
3241                return new_shapes
3242
3243        # specialize dims + derived dims
3244        class Foo(torch.nn.Module):
3245            def forward(self, x, y, z):
3246                x0 = x + y[1:] + z[2:]
3247                x1 = x @ torch.randn(4, 4)
3248                return x0, x1
3249
3250        inps = (
3251            torch.randn(
3252                4,
3253            ),
3254            torch.randn(
3255                5,
3256            ),
3257            torch.randn(
3258                6,
3259            ),
3260        )
3261        dx = Dim("dx", max=16)
3262        dynamic_shapes = {"x": (dx,), "y": (dx + 1,), "z": (dx + 2,)}
3263        new_shapes = helper(Foo(), inps, dynamic_shapes)
3264        self.assertEqual(new_shapes["x"][0], 4)
3265        self.assertEqual(new_shapes["z"][0], 6)
3266
3267        # refine lower, upper bound
3268        class Foo(torch.nn.Module):
3269            def forward(self, x, y):
3270                if x.shape[0] >= 6 and y.shape[0] <= 16:
3271                    return x * 2.0, y + 1
3272
3273        inps = (torch.randn(16), torch.randn(12))
3274        dynamic_shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)}
3275        new_shapes = helper(Foo(), inps, dynamic_shapes)
3276        self.assertEqual(new_shapes["x"][0].min, 6)
3277        self.assertEqual(new_shapes["y"][0].max, 16)
3278
3279        # divisiblity, will introduce new root
3280        class Foo(torch.nn.Module):
3281            def forward(self, x):
3282                if x.shape[0] >= 9:
3283                    return x.reshape([-1, 3])
3284
3285        inps = (
3286            torch.randn(
3287                15,
3288            ),
3289        )
3290        dynamic_shapes = ((Dim("dx"),),)
3291        new_shapes = helper(Foo(), inps, dynamic_shapes)
3292        dim = new_shapes[0][0]
3293        root = dim.root
3294        self.assertEqual(dim.fn(2), 6)
3295        self.assertEqual(root.min, 3)
3296
3297        # turn dim into derived dim/relation
3298        class Foo(torch.nn.Module):
3299            def forward(self, x, y):
3300                return x + y[4:]
3301
3302        inps = (torch.randn(6, 4), torch.randn(10, 4))
3303        dynamic_shapes = {
3304            "x": (Dim("dx0"), Dim("dx1")),
3305            "y": (Dim("dy0"), Dim("dy1")),
3306        }
3307        new_shapes = helper(Foo(), inps, dynamic_shapes)
3308        self.assertEqual(new_shapes["x"][0], new_shapes["y"][0].root)  # dy0 = dx0 + 4
3309        self.assertEqual(new_shapes["y"][0].fn(5), 9)
3310        self.assertEqual(new_shapes["x"][1], new_shapes["y"][1])  # dx1 = dy1
3311
3312        # nested dynamic shapes spec
3313        class Foo(torch.nn.Module):
3314            def forward(self, x, y):
3315                x0 = x[0]["data"] + x[1] + x[2][2:]
3316                x1 = y["a"] @ torch.randn(4, 4)
3317                x2 = y["b"] @ torch.randn(6, 6)
3318                return x0, x1, x2
3319
3320        inps = (
3321            [
3322                {"data": torch.randn(4, 4)},
3323                torch.randn(4, 4),
3324                torch.randn(6, 4),
3325            ],
3326            {
3327                "a": torch.randn(8, 4),
3328                "b": torch.randn(9, 6),
3329            },
3330        )
3331        dynamic_shapes = {
3332            "x": [
3333                {"data": (Dim("dx00"), Dim("dx01"))},
3334                (Dim("dx10"), Dim("dx11")),
3335                (Dim("dx20"), Dim("dx21")),
3336            ],
3337            "y": {
3338                "a": (Dim("dya0"), Dim("dya1")),
3339                "b": (Dim("dyb0"), Dim("dyb1")),
3340            },
3341        }
3342        new_shapes = helper(Foo(), inps, dynamic_shapes)
3343        self.assertEqual(
3344            new_shapes["x"][0]["data"][0], new_shapes["x"][1][0]
3345        )  # dx10 = dx00
3346        self.assertEqual(
3347            new_shapes["x"][2][0].root, new_shapes["x"][0]["data"][0]
3348        )  # dx20 = dx00 + 2
3349        self.assertEqual(new_shapes["x"][2][0].fn(10), 12)
3350        self.assertEqual(
3351            new_shapes["x"][0]["data"][1], new_shapes["x"][1][1]
3352        )  # dx11 = dx01
3353        self.assertEqual(new_shapes["y"]["a"][1], 4)
3354        self.assertEqual(new_shapes["y"]["b"][1], 6)
3355        self.assertEqual(new_shapes["y"]["b"][0].__name__, "dyb0")  # unchanged
3356
3357    def test_dynamic_shapes_spec_with_pytree(self):
3358        from torch.export import Dim, export
3359        from torch.utils._pytree import tree_map
3360
3361        inputs = {
3362            "tensor": torch.randn(3),
3363            "dict_of_tensors": {k: torch.randn(3) for k in ["A", "B", "C", "D"]},
3364            "list_of_tensors": [torch.randn(3) for _ in range(4)],
3365        }
3366
3367        batch = Dim("batch")
3368        # uniformly specify dynamic shapes for all inputs
3369        spec = tree_map(lambda x: {0: batch}, inputs)
3370
3371        class Foo(torch.nn.Module):
3372            def forward(self, inputs):
3373                return (
3374                    inputs["tensor"]
3375                    + inputs["dict_of_tensors"]["A"]
3376                    + inputs["list_of_tensors"][0]
3377                )
3378
3379        ep = export(Foo(), (inputs,), dynamic_shapes={"inputs": spec})
3380        input_shapes = [
3381            str(node.meta["val"].shape)
3382            for node in ep.graph_module.graph.nodes
3383            if node.op == "placeholder"
3384        ]
3385        self.assertEqual(len(input_shapes), 9)
3386        self.assertTrue(all(shape == "torch.Size([s0])" for shape in input_shapes))
3387
3388    def test_error_does_not_reference_eager_fallback(self):
3389        class Module(torch.nn.Module):
3390            def forward(self, x):
3391                y = x.nonzero()
3392                z = y.shape[0]
3393                if z > 2:
3394                    return x.cos()
3395                else:
3396                    return x.sin()
3397
3398        fn_ddo = Module()
3399        if is_non_strict_test(self._testMethodName):
3400            error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
3401            error_msg = r"Could not guard on data-dependent expression"
3402        else:
3403            error = torchdynamo.exc.UserError
3404            error_msg = r"^(?!.*fall back to eager).*"
3405        with self.assertRaisesRegex(error, error_msg):
3406            _ = export(fn_ddo, (torch.tensor([2, 3, 5]),))
3407
3408    def test_pytree_register_data_class(self):
3409        @dataclass
3410        class MyDataClass:
3411            x: int
3412            y: int
3413            z: int = None
3414
3415        dt = MyDataClass(x=3, y=4)
3416        flat, spec = tree_flatten(dt)
3417        self.assertTrue(spec, LeafSpec())
3418        self.assertTrue(len(flat) == 1)
3419
3420        register_dataclass_as_pytree_node(
3421            MyDataClass,
3422            serialized_type_name="test_pytree_register_data_class.MyDataClass",
3423        )
3424
3425        flat, spec = tree_flatten(dt)
3426        self.assertEqual(
3427            spec,
3428            TreeSpec(MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]),
3429        )
3430        self.assertEqual(flat, [3, 4])
3431
3432        orig_dt = tree_unflatten(flat, spec)
3433        self.assertTrue(isinstance(orig_dt, MyDataClass))
3434        self.assertEqual(orig_dt.x, 3)
3435        self.assertEqual(orig_dt.y, 4)
3436        self.assertEqual(orig_dt.z, None)
3437
3438        roundtrip_spec = treespec_loads(treespec_dumps(spec))
3439        self.assertEqual(roundtrip_spec, spec)
3440
3441        @dataclass
3442        class MyOtherDataClass:  # the pytree registration don't allow registering the same class twice
3443            x: int
3444            y: int
3445            z: int = None
3446
3447        # Override the registration with keep none fields
3448        register_dataclass_as_pytree_node(
3449            MyOtherDataClass,
3450            return_none_fields=True,
3451            serialized_type_name="test_pytree_regster_data_class.MyOtherDataClass",
3452        )
3453
3454        dt = MyOtherDataClass(x=3, y=4)
3455        flat, spec = tree_flatten(dt)
3456        self.assertEqual(
3457            spec,
3458            TreeSpec(
3459                MyOtherDataClass,
3460                [["x", "y", "z"], []],
3461                [LeafSpec(), LeafSpec(), LeafSpec()],
3462            ),
3463        )
3464        self.assertEqual(flat, [3, 4, None])
3465
3466        orig_dt = tree_unflatten(flat, spec)
3467        self.assertTrue(isinstance(orig_dt, MyOtherDataClass))
3468        self.assertEqual(orig_dt.x, 3)
3469        self.assertEqual(orig_dt.y, 4)
3470        self.assertEqual(orig_dt.z, None)
3471
3472        roundtrip_spec = treespec_loads(treespec_dumps(spec))
3473        self.assertEqual(roundtrip_spec, spec)
3474
3475    def test_pytree_register_nested_data_class(self):
3476        @dataclass
3477        class Inner:
3478            x: int
3479            y: int
3480
3481        @dataclass
3482        class Outer:
3483            xy: Inner
3484            ab: Inner
3485
3486        xy = Inner(1, 2)
3487        ab = Inner(3, 4)
3488        dt = Outer(xy, ab)
3489        inp = {"dt1": (dt, ({},)), "dt2": ((torch.ones(1),), dt)}
3490
3491        register_dataclass_as_pytree_node(
3492            Inner, serialized_type_name="test_pytree_register_nested_data_class.Inner"
3493        )
3494        register_dataclass_as_pytree_node(
3495            Outer, serialized_type_name="test_pytree_register_nested_data_class.Outer"
3496        )
3497
3498        flat, spec = tree_flatten(inp)
3499        self.assertEqual(flat, [1, 2, 3, 4, torch.ones(1), 1, 2, 3, 4])
3500
3501        unflat = tree_unflatten(flat, spec)
3502        self.assertEqual(unflat, inp)
3503
3504        roundtrip_spec = treespec_loads(treespec_dumps(spec))
3505        self.assertEqual(roundtrip_spec, spec)
3506
3507    def test_param_util(self):
3508        class Basic(torch.nn.Module):
3509            def __init__(self) -> None:
3510                super().__init__()
3511                self.lin = torch.nn.Linear(10, 1)
3512
3513            def forward(self, x):
3514                return self.lin(x)
3515
3516        ep = export(Basic(), (torch.randn(5, 10),))
3517        num_params = 0
3518        params = []
3519        for node in ep.graph.nodes:
3520            if is_param(ep, node):
3521                num_params += 1
3522                params.append(get_param(ep, node))
3523        self.assertEqual(num_params, 2)
3524        self.assertEqual(params[0].shape, [1, 10])  # weight
3525        self.assertEqual(params[1].shape, [1])  # bias
3526
3527    def test_buffer_util(self):
3528        ep = export(
3529            torch.nn.BatchNorm2d(100, affine=False), (torch.ones(20, 100, 35, 45),)
3530        )
3531        num_buffer = 0
3532        buffer = []
3533
3534        for node in ep.graph.nodes:
3535            if is_buffer(ep, node):
3536                num_buffer += 1
3537                buffer.append(get_buffer(ep, node))
3538        self.assertEqual(num_buffer, 3)
3539
3540        self.assertEqual(buffer[0].shape, torch.Size([100]))  # running_mean
3541        self.assertEqual(buffer[1].shape, torch.Size([100]))  # running_var
3542        self.assertEqual(buffer[2].shape, torch.Size([]))  # num_batches_tracked
3543
3544    def test_export_dynamo_config(self):
3545        class MyModule(torch.nn.Module):
3546            def __init__(self) -> None:
3547                super().__init__()
3548                self.lstm = torch.nn.LSTM(input_size=4, hidden_size=5, num_layers=1)
3549
3550            def forward(self, inputs: torch.Tensor) -> torch.Tensor:
3551                return self.lstm(inputs)
3552
3553        config = DEFAULT_EXPORT_DYNAMO_CONFIG
3554        mod = MyModule()
3555
3556        @contextmanager
3557        def _patch_config(kwargs):
3558            orig_config_dict = dataclasses.asdict(config)
3559
3560            try:
3561                for k, v in kwargs.items():
3562                    setattr(config, k, v)
3563                yield
3564            finally:
3565                for k, v in orig_config_dict.items():
3566                    setattr(config, k, v)
3567
3568        inp = (torch.rand(5, 4),)
3569        exported_program = export(mod, inp, strict=True)
3570
3571        with _patch_config({"allow_rnn": False}):
3572            with self.assertRaisesRegex(
3573                torch._dynamo.exc.Unsupported,
3574                "TorchDynamo purposely graph breaks on RNN, GRU, LSTMs",
3575            ):
3576                _ = export(mod, inp, strict=True)
3577
3578    def test_device_to_static(self):
3579        class Module(torch.nn.Module):
3580            def forward(self, x):
3581                return x.to("cpu")
3582
3583        ep = export(Module(), (torch.tensor(1, device="cpu"),))
3584        ops = []
3585        for node in ep.graph.nodes:
3586            if node.op == "call_function":
3587                ops.append(node.target)
3588        self.assertGreater(len(ops), 0)
3589        for op in ops:
3590            self.assertIn(op, (torch.ops.aten._to_copy.default,))
3591
3592    def test_device_to_dynamic(self):
3593        class Module(torch.nn.Module):
3594            def forward(self, x):
3595                return x.to("cpu")
3596
3597        ep = export(
3598            Module(),
3599            (torch.tensor([1, 2], device="cpu"),),
3600            dynamic_shapes={"x": {0: Dim("i")}},
3601        )
3602        ops = []
3603        for node in ep.graph.nodes:
3604            if node.op == "call_function":
3605                ops.append(node.target)
3606        self.assertGreater(len(ops), 0)
3607        for op in ops:
3608            self.assertIn(op, (torch.ops.aten._to_copy.default,))
3609
3610    def test_device_to_mutation(self):
3611        class Module(torch.nn.Module):
3612            def forward(self, x):
3613                y = x.to("cpu")
3614                y.add_(1)
3615                return y, x
3616
3617        with self.assertRaisesRegex(
3618            RuntimeError, "cannot mutate tensors with frozen storage"
3619        ):
3620            export(Module(), (torch.tensor(1, device="cpu"),))
3621
3622    def test_float_conversion(self):
3623        class Module(torch.nn.Module):
3624            def forward(self, x):
3625                return x.float()
3626
3627        ep = export(Module(), (torch.tensor(1, dtype=torch.float),))
3628        ops = []
3629        for node in ep.graph.nodes:
3630            if node.op == "call_function":
3631                ops.append(node.target)
3632        self.assertGreater(len(ops), 0)
3633        for op in ops:
3634            self.assertIn(op, (torch.ops.aten._to_copy.default,))
3635
3636    def test_device_to_mutation_float(self):
3637        class Module(torch.nn.Module):
3638            def forward(self, x):
3639                y = x.float()
3640                y.add_(1)
3641                return y, x
3642
3643        with self.assertRaisesRegex(
3644            RuntimeError, "cannot mutate tensors with frozen storage"
3645        ):
3646            export(Module(), (torch.tensor(1, dtype=torch.float),))
3647
3648    def test_module(self):
3649        class MyLinear(torch.nn.Module):
3650            def __init__(self) -> None:
3651                super().__init__()
3652                self.weight = torch.randn(20, 98)
3653                self.bias = torch.randn(20)
3654
3655            def forward(self, x):
3656                return torch.nn.functional.linear(x, self.weight, self.bias)
3657
3658        class Foo(torch.nn.Module):
3659            def __init__(self) -> None:
3660                super().__init__()
3661                self.conv = torch.nn.Conv2d(16, 33, 3)
3662                self.linear = MyLinear()
3663
3664            def forward(self, x):
3665                a, b = x
3666                a_conv = self.conv(a)
3667                a_linear = self.linear(a_conv)
3668                b_conv = self.conv(b)
3669                b_linear = self.linear(b_conv)
3670                return (
3671                    a_linear.cos() + b_linear.sin(),
3672                    a_linear.sin() + b_linear.cos(),
3673                )
3674
3675        inp_container = ((torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),)
3676
3677        ep = export(Foo(), inp_container)
3678        ep_rexported = export(ep.module(), inp_container)
3679
3680        inp_test = ((torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),)
3681
3682        self.assertTrue(
3683            torch.allclose(
3684                ep.module()(*inp_test)[0], ep_rexported.module()(*inp_test)[0]
3685            )
3686        )
3687        self.assertTrue(
3688            torch.allclose(
3689                ep.module()(*inp_test)[1], ep_rexported.module()(*inp_test)[1]
3690            )
3691        )
3692
3693    def test_use_embedding_twice(self):
3694        class Foo(torch.nn.Module):
3695            def __init__(self):
3696                super().__init__()
3697                self.embed = torch.nn.Embedding(4, 4)
3698
3699            def forward(self, x):
3700                return self.embed(x) + self.embed.weight[x]
3701
3702        inputs = (torch.tensor([0, 1, 2, 3]),)
3703        ep = export(Foo(), inputs)
3704
3705    def test_module_with_dict_container_inp_out(self):
3706        class MyLinear(torch.nn.Module):
3707            def __init__(self) -> None:
3708                super().__init__()
3709                self.weight = torch.randn(20, 98)
3710                self.bias = torch.randn(20)
3711
3712            def forward(self, x):
3713                return torch.nn.functional.linear(x, self.weight, self.bias)
3714
3715        class Foo(torch.nn.Module):
3716            def __init__(self) -> None:
3717                super().__init__()
3718                self.conv = torch.nn.Conv2d(16, 33, 3)
3719                self.linear = MyLinear()
3720
3721            def forward(self, x):
3722                a1, a2 = x["a"]
3723                b = x["b"]
3724                a1_conv = self.conv(a1)
3725                a1_linear = self.linear(a1_conv)
3726                a2_conv = self.conv(a2)
3727                a2_linear = self.linear(a2_conv)
3728                b_conv = self.conv(b)
3729                b_linear = self.linear(b_conv)
3730                return {
3731                    "a": a1_linear.cos() + b_linear.sin(),
3732                    "b": a2_linear.sin() + b_linear.cos(),
3733                }
3734
3735        inp_container = (
3736            {
3737                "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
3738                "b": torch.randn(20, 16, 50, 100),
3739            },
3740        )
3741
3742        ep = export(Foo(), inp_container)
3743        ep_rexported = export(ep.module(), inp_container)
3744
3745        inp_test = (
3746            {
3747                "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
3748                "b": torch.randn(20, 16, 50, 100),
3749            },
3750        )
3751
3752        self.assertTrue(
3753            torch.allclose(
3754                ep.module()(*inp_test)["a"], ep_rexported.module()(*inp_test)["a"]
3755            )
3756        )
3757        self.assertTrue(
3758            torch.allclose(
3759                ep.module()(*inp_test)["b"], ep_rexported.module()(*inp_test)["b"]
3760            )
3761        )
3762
3763    def test_args_type_checked(self):
3764        class M(torch.nn.Module):
3765            def forward(self, x):
3766                return x + 1
3767
3768        inp = torch.rand(2, 2)
3769        with self.assertRaisesRegex(torch._dynamo.exc.UserError, "to be a tuple"):
3770            # Intentionally not wrapping `inp` in a tuple to trigger the error
3771            _ = export(M(), inp)
3772
3773    def test_decomp_batch_norm_functional_predispatch(self):
3774        class ConvBatchnorm(torch.nn.Module):
3775            def __init__(self) -> None:
3776                super().__init__()
3777                self.conv = torch.nn.Conv2d(1, 3, 1, 1)
3778                self.bn = torch.nn.BatchNorm2d(3)
3779
3780            def forward(self, x):
3781                x = self.conv(x)
3782                x = self.bn(x)
3783                return (x,)
3784
3785        mod = ConvBatchnorm()
3786        mod.eval()
3787        inp = torch.randn(1, 1, 3, 3)
3788
3789        gm = torch.export._trace._export(mod, (inp,), pre_dispatch=True).module()
3790        self.assertExpectedInline(
3791            str(gm.code).strip(),
3792            """\
3793def forward(self, x):
3794    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
3795    conv_weight = self.conv.weight
3796    conv_bias = self.conv.bias
3797    bn_weight = self.bn.weight
3798    bn_bias = self.bn.bias
3799    bn_running_mean = self.bn.running_mean
3800    bn_running_var = self.bn.running_var
3801    bn_num_batches_tracked = self.bn.num_batches_tracked;  bn_num_batches_tracked = None
3802    conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias);  x = conv_weight = conv_bias = None
3803    _native_batch_norm_legit_no_training = torch.ops.aten._native_batch_norm_legit_no_training.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, 0.1, 1e-05);  conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None
3804    getitem = _native_batch_norm_legit_no_training[0];  _native_batch_norm_legit_no_training = None
3805    return pytree.tree_unflatten((getitem,), self._out_spec)""",
3806        )
3807
3808        mod.train()
3809        gm_train = _export(mod, (inp,), pre_dispatch=True).module()
3810        self.assertExpectedInline(
3811            str(gm_train.code).strip(),
3812            """\
3813def forward(self, x):
3814    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
3815    conv_weight = self.conv.weight
3816    conv_bias = self.conv.bias
3817    bn_weight = self.bn.weight
3818    bn_bias = self.bn.bias
3819    bn_running_mean = self.bn.running_mean
3820    bn_running_var = self.bn.running_var
3821    bn_num_batches_tracked = self.bn.num_batches_tracked
3822    conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias);  x = conv_weight = conv_bias = None
3823    add = torch.ops.aten.add.Tensor(bn_num_batches_tracked, 1)
3824    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.1, 1e-05);  conv2d = bn_weight = bn_bias = None
3825    getitem = _native_batch_norm_legit_functional[0]
3826    getitem_3 = _native_batch_norm_legit_functional[3]
3827    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
3828    copy__default = torch.ops.aten.copy_.default(bn_running_mean, getitem_3);  bn_running_mean = getitem_3 = copy__default = None
3829    copy__default_1 = torch.ops.aten.copy_.default(bn_running_var, getitem_4);  bn_running_var = getitem_4 = copy__default_1 = None
3830    copy__default_2 = torch.ops.aten.copy_.default(bn_num_batches_tracked, add);  bn_num_batches_tracked = add = copy__default_2 = None
3831    return pytree.tree_unflatten((getitem,), self._out_spec)""",
3832        )
3833
3834    def test_constrain_size_in_eager(self):
3835        class Module(torch.nn.Module):
3836            def forward(self, x, y):
3837                n = x.max().item()
3838                torch._check_is_size(n)
3839                return y + n
3840
3841        fn = Module()
3842        ep = export(
3843            fn,
3844            (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3))),
3845        )
3846        test_inp = (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))
3847        self.assertTrue(torch.allclose(ep.module()(*test_inp), fn(*test_inp)))
3848
3849    def test_constrain_size_with_constrain_value(self):
3850        class Module(torch.nn.Module):
3851            def forward(self, x, y):
3852                n = x.max().item()
3853                torch._check(n >= 2)
3854                torch._check(n <= 10)
3855                torch._check_is_size(n)
3856                return y + n
3857
3858        fn = Module()
3859        with self.assertRaisesRegex(
3860            RuntimeError, r"Expected cond to be True, but got False"
3861        ):
3862            _ = fn(torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))
3863
3864        ep = export(
3865            fn,
3866            (torch.randint(3, 4, (2, 2)), torch.randint(3, 5, (2, 3))),
3867        )
3868        with self.assertRaisesRegex(
3869            RuntimeError, r"Runtime assertion failed for expression u[\d+] \>\= 2"
3870        ):
3871            test_inp = (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))
3872            _ = ep.module()(*test_inp)
3873
3874    def test_constrain_size_with_various_cases(self):
3875        class Module1(torch.nn.Module):
3876            def forward(self, x, y):
3877                n = x.item()
3878                torch._check_is_size(n)
3879                torch._check(n >= 0)
3880                return y.sum() + torch.ones(n, 5).sum()
3881
3882        case1 = Module1()
3883
3884        class Module2(torch.nn.Module):
3885            def forward(self, x, y):
3886                n = x.item()
3887                torch._check_is_size(n)
3888                torch._check(n >= 0)
3889                torch._check(n <= 6)
3890                return y.sum() + torch.ones(n, 5).sum()
3891
3892        case2 = Module2()
3893
3894        class Module3(torch.nn.Module):
3895            def forward(self, x, y):
3896                n = x.item()
3897                torch._check_is_size(n)
3898                torch._check(n >= 0)
3899                torch._check(n <= 1)
3900                return y.sum() + torch.ones(n, 5).sum()
3901
3902        case3 = Module3()
3903
3904        class Module4(torch.nn.Module):
3905            def forward(self, x, y):
3906                n = x.item()
3907                torch._check_is_size(n)
3908                torch._check(n >= 2)
3909                return y.sum() + torch.ones(n, 5).sum()
3910
3911        case4 = Module4()
3912
3913        class Module5(torch.nn.Module):
3914            def forward(self, x, y):
3915                n = x.item()
3916                torch._check_is_size(n)
3917                torch._check(n >= 1)
3918                return y.sum() + torch.ones(n, 5).sum()
3919
3920        case5 = Module5()
3921
3922        ep = export(case1, (torch.tensor(1), torch.ones(4, 5)))
3923
3924        with self.assertRaisesRegex(
3925            RuntimeError, r"Expected cond to be True, but got False"
3926        ):
3927            _ = case1(torch.tensor(-1), torch.randn(4, 5))
3928
3929        self.assertTrue(
3930            torch.allclose(
3931                ep.module()(torch.tensor(1), torch.ones(4, 5)),
3932                case1(torch.tensor(1), torch.ones(4, 5)),
3933            )
3934        )
3935
3936        ep = export(case2, (torch.tensor(5), torch.randn(4, 5)))
3937
3938        with self.assertRaisesRegex(
3939            RuntimeError,
3940            r"Expected cond to be True, but got False",
3941        ):
3942            _ = case2(torch.tensor(7), torch.randn(4, 5))
3943
3944        with self.assertRaisesRegex(
3945            RuntimeError,
3946            r"Expected cond to be True, but got False",
3947        ):
3948            _ = case2(torch.tensor(9), torch.randn(4, 5))
3949
3950        self.assertTrue(
3951            torch.allclose(
3952                ep.module()(torch.tensor(5), torch.ones(4, 5)),
3953                case2(torch.tensor(5), torch.ones(4, 5)),
3954            )
3955        )
3956
3957        _ = case3(torch.tensor(1), torch.randn(4, 5))
3958
3959        with self.assertRaisesRegex(
3960            RuntimeError,
3961            r"Expected cond to be True, but got False",
3962        ):
3963            _ = case4(torch.tensor(1), torch.randn(4, 5))
3964
3965        ep = export(case4, (torch.tensor(5), torch.randn(4, 5)))
3966
3967        with self.assertRaisesRegex(
3968            RuntimeError,
3969            r"Expected cond to be True, but got False",
3970        ):
3971            _ = case4(torch.tensor(1), torch.randn(4, 5))
3972
3973        self.assertTrue(
3974            torch.allclose(
3975                ep.module()(torch.tensor(5), torch.ones(4, 5)),
3976                case4(torch.tensor(5), torch.ones(4, 5)),
3977            )
3978        )
3979
3980        ep = export(case5, (torch.tensor(5), torch.randn(4, 5)))
3981
3982        with self.assertRaisesRegex(
3983            RuntimeError,
3984            r"Expected cond to be True, but got False",
3985        ):
3986            _ = case5(torch.tensor(0), torch.randn(4, 5))
3987
3988        self.assertTrue(
3989            torch.allclose(
3990                ep.module()(torch.tensor(5), torch.ones(4, 5)),
3991                case5(torch.tensor(5), torch.ones(4, 5)),
3992            )
3993        )
3994
3995    def test_automatic_constrain_size(self):
3996        class M(torch.nn.Module):
3997            def forward(self, x, y):
3998                n = x.item()
3999                return y.sum() + torch.ones(n, 5).sum()
4000
4001        ep = export(M(), (torch.tensor(1), torch.ones(4, 5)))
4002
4003        # This is because we insert sym_constrain_range in the graph now
4004        error_msg = r"Invalid value range for -1 between"
4005        with self.assertRaisesRegex(RuntimeError, error_msg):
4006            _ = ep.module()(torch.tensor(-1), torch.randn(4, 5))
4007
4008        self.assertTrue(
4009            torch.allclose(
4010                ep.module()(torch.tensor(1), torch.ones(4, 5)),
4011                M()(torch.tensor(1), torch.ones(4, 5)),
4012            )
4013        )
4014
4015    def test_constrain_decomp(self) -> None:
4016        class M(torch.nn.Module):
4017            def __init__(self) -> None:
4018                super().__init__()
4019                self.freq = torch.ones(5, 5)
4020
4021            def forward(self, start_pos: torch.Tensor):
4022                pos = start_pos.item()
4023                torch._check_is_size(pos)
4024                torch._check(pos >= 0)
4025                torch._check(pos <= 4)
4026                return self.freq[pos] * self.freq[pos]
4027
4028        ep = torch.export.export(M(), (torch.tensor(1),))
4029        FileCheck().check_count(
4030            "torch.ops.aten._assert_scalar.default", 2, exactly=True
4031        ).run(ep.graph_module.code)
4032        FileCheck().check_count(
4033            "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True
4034        ).run(ep.graph_module.code)
4035
4036        decompose_ep = ep.run_decompositions()
4037        FileCheck().check_count(
4038            "torch.ops.aten._assert_scalar.default", 2, exactly=True
4039        ).run(ep.graph_module.code)
4040        FileCheck().check_count(
4041            "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True
4042        ).run(ep.graph_module.code)
4043
4044    def test_mixed_input(self):
4045        class Module(torch.nn.Module):
4046            def forward(self, a, b, alpha: int):
4047                return torch.add(a, b, alpha=alpha)
4048
4049        func = Module()
4050
4051        a = torch.rand(1, 2)
4052        b = torch.rand(1, 2)
4053        alpha = 10
4054
4055        exported = export(func, (a, b, alpha))
4056        for node in exported.graph_module.graph.nodes:
4057            if node.op == "placeholder":
4058                self.assertTrue(isinstance(node.meta["val"], (Tensor, int)))
4059
4060    def test_export_with_inline_constraints(self):
4061        class Module(torch.nn.Module):
4062            def forward(self, x):
4063                a = x.item()
4064                torch._check(a >= 4)
4065                torch._check(a <= 7)
4066                return torch.empty((a, 4))
4067
4068        f = Module()
4069        ep = export(f, (torch.tensor([5]),))
4070        self.assertEqual(ep.module()(torch.tensor([6])).shape, (6, 4))
4071
4072        FileCheck().check_count(
4073            "torch.ops.aten._assert_scalar.default", 2, exactly=True
4074        ).run(ep.graph_module.code)
4075        FileCheck().check_count(
4076            "torch.ops.aten.sym_constrain_range.default", 0, exactly=True
4077        ).run(ep.graph_module.code)
4078        FileCheck().check_count(
4079            "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True
4080        ).run(ep.graph_module.code)
4081
4082        with self.assertRaisesRegex(
4083            RuntimeError,
4084            r"Runtime assertion failed for expression u[\d+] \<\= 7",
4085        ) as cm:
4086            ep.module()(torch.tensor([30]))
4087
4088    def test_export_with_inline_constraints_complex(self):
4089        class Module(torch.nn.Module):
4090            def forward(self, x):
4091                a = x.item()
4092                torch._check(a >= 4)
4093                torch._check(a <= 7)
4094                empty = torch.empty((a, 4))
4095
4096                return torch.cat((empty.transpose(0, 1), torch.zeros(6, a)), 0)
4097
4098        f = Module()
4099        ep = export(f, (torch.tensor([6]),))
4100        self.assertEqual(ep.module()(torch.tensor([5])).shape, (10, 5))
4101        FileCheck().check_count(
4102            "torch.ops.aten._assert_scalar.default", 2, exactly=True
4103        ).run(ep.graph_module.code)
4104        FileCheck().check_count(
4105            "torch.ops.aten.sym_constrain_range.default", 0, exactly=True
4106        ).run(ep.graph_module.code)
4107        FileCheck().check_count(
4108            "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True
4109        ).run(ep.graph_module.code)
4110
4111    def test_to_module_with_mutated_buffer(self):
4112        class Foo(torch.nn.Module):
4113            def __init__(self) -> None:
4114                super().__init__()
4115                self.buf = torch.nn.Buffer(torch.zeros(1))
4116
4117            def forward(self, x):
4118                self.buf.add_(1)
4119                return x.sum() + self.buf.sum()
4120
4121        exported = export(Foo(), (torch.ones(5, 5),))
4122        stateful_gm = exported.module()
4123        export_return_val = stateful_gm(torch.ones(5, 5))
4124        eager = Foo()
4125        eager_return_val = eager(torch.ones(5, 5))
4126        self.assertTrue(torch.allclose(eager_return_val, export_return_val))
4127
4128        for name, buffer in stateful_gm.named_buffers():
4129            self.assertTrue(torch.allclose(torch.ones(1), buffer))
4130
4131        changed = stateful_gm.graph.eliminate_dead_code()
4132        self.assertFalse(changed)
4133        self.assertTrue(
4134            torch.allclose(stateful_gm(torch.ones(5, 5)), eager(torch.ones(5, 5)))
4135        )
4136
4137        for name, buffer in stateful_gm.named_buffers():
4138            self.assertTrue(torch.allclose(torch.tensor(2, dtype=torch.float), buffer))
4139
4140    def test_to_module_with_mutated_buffer_multiple(self):
4141        class Bar(torch.nn.Module):
4142            def __init__(self) -> None:
4143                super().__init__()
4144                self.buf = torch.nn.Buffer(torch.ones(1))
4145
4146            def forward(self, x):
4147                self.buf.add_(1)
4148                return x.sum() + self.buf.sum()
4149
4150        class Foo(torch.nn.Module):
4151            def __init__(self) -> None:
4152                super().__init__()
4153                self.buf = torch.nn.Buffer(torch.zeros(1))
4154                self.bar = Bar()
4155
4156            def forward(self, x):
4157                self.buf.add_(1)
4158                self.bar.buf.add_(2)
4159                bar = self.bar(x)
4160                return bar.sum() + self.buf.sum()
4161
4162        exported = export(Foo(), (torch.ones(5, 5),))
4163        stateful_gm = exported.module()
4164        export_return_val = stateful_gm(torch.ones(5, 5))
4165        eager = Foo()
4166        eager_return_val = eager(torch.ones(5, 5))
4167        self.assertTrue(torch.allclose(eager_return_val, export_return_val))
4168
4169        for name, buffer in stateful_gm.named_buffers():
4170            if name == "L__self___buf":
4171                self.assertTrue(torch.allclose(torch.ones(1), buffer))
4172            if name == "L__self___bar_buf":
4173                self.assertTrue(
4174                    torch.allclose(torch.tensor(4, dtype=torch.float), buffer)
4175                )
4176
4177        changed = stateful_gm.graph.eliminate_dead_code()
4178        self.assertFalse(changed)
4179        self.assertTrue(
4180            torch.allclose(stateful_gm(torch.ones(5, 5)), eager(torch.ones(5, 5)))
4181        )
4182
4183        for name, buffer in stateful_gm.named_buffers():
4184            if name == "L__self___buf":
4185                self.assertTrue(
4186                    torch.allclose(torch.tensor(2, dtype=torch.float), buffer)
4187                )
4188            if name == "L__self___bar_buf":
4189                self.assertTrue(
4190                    torch.allclose(torch.tensor(7, dtype=torch.float), buffer)
4191                )
4192
4193    def test_runtime_assert_for_prim(self):
4194        class Foo(torch.nn.Module):
4195            def forward(self, x, y):
4196                return x + y
4197
4198        foo = Foo()
4199        tensor_inp = torch.ones(7, 5)
4200        dim0_x = torch.export.Dim("dim0_x", min=6)
4201        dynamic_shapes = {"x": {0: dim0_x}, "y": None}
4202        exported = torch.export.export(
4203            foo, (tensor_inp, 5), dynamic_shapes=dynamic_shapes
4204        )
4205        self.assertTrue(
4206            torch.allclose(
4207                exported.module()(torch.ones(8, 5), 5), foo(torch.ones(8, 5), 5)
4208            )
4209        )
4210        with self.assertRaisesRegex(
4211            RuntimeError,
4212            escape("Expected input at *args[1] to be equal to 5, but got 6"),
4213        ):
4214            _ = exported.module()(torch.ones(8, 5), 6)
4215
4216        exported = torch.export.export(
4217            foo, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes
4218        )
4219        with self.assertRaisesRegex(
4220            RuntimeError,
4221            escape("Expected input at *args[1] to be equal to 5.0, but got 6.0"),
4222        ):
4223            _ = exported.module()(torch.ones(7, 5), 6.0)
4224
4225    def test_runtime_assert_for_prm_str(self):
4226        class Foo(torch.nn.Module):
4227            def forward(self, a, b, mode):
4228                return torch.div(a, b, rounding_mode=mode)
4229
4230        foo = Foo()
4231        inps = (torch.randn(4, 4), torch.randn(4), "trunc")
4232        exported = export(foo, inps)
4233        with self.assertRaisesRegex(
4234            RuntimeError, "to be equal to trunc, but got floor"
4235        ):
4236            _ = exported.module()(torch.randn(4, 4), torch.randn(4), "floor")
4237        self.assertTrue(torch.allclose(exported.module()(*inps), foo(*inps)))
4238
4239    def test_redundant_assert_max_upper_bound(self):
4240        class M(torch.nn.Module):
4241            def forward(self, x):
4242                b = x.nonzero()
4243                torch._check(b.shape[0] >= 3)
4244                return b
4245
4246        m = M()
4247        inp = (torch.tensor([1, 1, 1, 0, 1]),)
4248        dim = torch.export.Dim("dim")
4249        ep = export(m, inp, dynamic_shapes=((dim,),))
4250        FileCheck().check_count(
4251            "torch.ops.aten._assert_scalar.default", 1, exactly=True
4252        ).run(ep.graph_module.code)
4253
4254    def test_to_module_with_mutated_buffer_multiple_update_sub_later(self):
4255        class Bar(torch.nn.Module):
4256            def __init__(self) -> None:
4257                super().__init__()
4258                self.buf = torch.nn.Buffer(torch.ones(1))
4259
4260            def forward(self, x):
4261                self.buf.add_(1)
4262                return x.sum() + self.buf.sum()
4263
4264        class Foo(torch.nn.Module):
4265            def __init__(self) -> None:
4266                super().__init__()
4267                self.buf = torch.nn.Buffer(torch.zeros(1))
4268                self.bar = Bar()
4269
4270            def forward(self, x):
4271                self.buf.add_(1)
4272                bar = self.bar(x)
4273                self.bar.buf.add_(2)
4274                return bar.sum() + self.buf.sum()
4275
4276        exported = export(Foo(), (torch.ones(5, 5),))
4277        stateful_gm = exported.module()
4278        export_return_val = stateful_gm(torch.ones(5, 5))
4279        eager = Foo()
4280        eager_return_val = eager(torch.ones(5, 5))
4281        self.assertTrue(torch.allclose(eager_return_val, export_return_val))
4282
4283        for name, buffer in stateful_gm.named_buffers():
4284            if name == "L__self___buf":
4285                self.assertTrue(torch.allclose(torch.ones(1), buffer))
4286            if name == "L__self___bar_buf":
4287                self.assertTrue(
4288                    torch.allclose(torch.tensor(4, dtype=torch.float), buffer)
4289                )
4290
4291        changed = stateful_gm.graph.eliminate_dead_code()
4292        self.assertFalse(changed)
4293        self.assertTrue(
4294            torch.allclose(stateful_gm(torch.ones(5, 5)), eager(torch.ones(5, 5)))
4295        )
4296
4297        for name, buffer in stateful_gm.named_buffers():
4298            if name == "L__self___buf":
4299                self.assertTrue(
4300                    torch.allclose(torch.tensor(2, dtype=torch.float), buffer)
4301                )
4302            if name == "L__self___bar_buf":
4303                self.assertTrue(
4304                    torch.allclose(torch.tensor(7, dtype=torch.float), buffer)
4305                )
4306
4307    def test_retracable_ep(self):
4308        class Bar(torch.nn.Module):
4309            def __init__(self) -> None:
4310                super().__init__()
4311                self.buf = torch.nn.Buffer(torch.ones(1))
4312
4313            def forward(self, x):
4314                self.buf.add_(1)
4315                return x.sum() + self.buf.sum()
4316
4317        class Foo(torch.nn.Module):
4318            def __init__(self) -> None:
4319                super().__init__()
4320                self.buf = torch.nn.Buffer(torch.zeros(1))
4321                self.bar = Bar()
4322
4323            def forward(self, x):
4324                self.buf.add_(1)
4325                bar = self.bar(x)
4326                self.bar.buf.add_(2)
4327                return bar.sum() + self.buf.sum()
4328
4329        inp = torch.ones(5, 5)
4330        exported = torch.export.export(Foo(), (inp,))
4331        reexported = torch.export.export(exported.module(), (inp,))
4332
4333        self.assertTrue(torch.allclose(Foo()(inp), reexported.module()(inp)))
4334
4335        dim0_x = torch.export.Dim("dim0_x")
4336        exported = torch.export.export(Foo(), (inp,), dynamic_shapes=({0: dim0_x},))
4337        reexported = torch.export.export(exported.module(), (inp,))
4338        with self.assertRaisesRegex(
4339            RuntimeError, "shape\[0\] to be equal to 5, but got 7"
4340        ):
4341            reexported.module()(torch.ones(7, 5))
4342
4343        reexported = torch.export.export(
4344            exported.module(), (inp,), dynamic_shapes=({0: dim0_x},)
4345        )
4346        self.assertTrue(
4347            torch.allclose(
4348                Foo()(torch.ones(7, 5)), reexported.module()(torch.ones(7, 5))
4349            )
4350        )
4351
4352        # can't retrace with invalid inputs with respect to the original ExportedProgram
4353        dim0_x_v2 = torch.export.Dim("dim0_x_v2", min=3)
4354        exported_v2 = torch.export.export(
4355            Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x_v2}}
4356        )
4357        with self.assertRaisesRegex(
4358            RuntimeError,
4359            escape("Expected input at *args[0].shape[0] to be >= 3, but got 2"),
4360        ):
4361            torch.export.export(exported_v2.module(), (torch.randn(2, 2),))
4362
4363    def test_export_cond_symbool_pred(self):
4364        class A(torch.nn.Module):
4365            def __init__(self) -> None:
4366                super().__init__()
4367                self.buffer = torch.nn.Buffer(torch.ones(6, 4))
4368
4369            def forward(self):
4370                return self.buffer.cos()
4371
4372        class Foo(torch.nn.Module):
4373            def __init__(self) -> None:
4374                super().__init__()
4375                self.a = A()
4376
4377            def forward(self, x):
4378                def true_fn(x):
4379                    return x.cos() + self.a().sum()
4380
4381                def false_fn(x):
4382                    return x.sin()
4383
4384                return cond(x.shape[0] > 4, true_fn, false_fn, [x])
4385
4386        dim0 = torch.export.Dim("dim0", min=3)
4387        inp = torch.ones(6, 4)
4388        ep = export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0}})
4389        schema = get_hop_schema(ep)
4390        self.assertExpectedInline(
4391            str(schema),
4392            """cond(SymBool pred, GraphModule true_fn, GraphModule false_fn, Tensor[2] operands) -> Tensor[1]""",
4393        )
4394        self.assertExpectedInline(
4395            ep.graph_module.code.strip(),
4396            """\
4397def forward(self, b_a_buffer, x):
4398    sym_size_int_1 = torch.ops.aten.sym_size.int(x, 0)
4399    gt = sym_size_int_1 > 4;  sym_size_int_1 = None
4400    true_graph_0 = self.true_graph_0
4401    false_graph_0 = self.false_graph_0
4402    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x, b_a_buffer]);  gt = true_graph_0 = false_graph_0 = x = b_a_buffer = None
4403    getitem = cond[0];  cond = None
4404    return (getitem,)""",
4405        )
4406        self.assertTrue(
4407            torch.allclose(ep.module()(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))
4408        )
4409
4410    def test_aten_lift_fresh_copy(self):
4411        class M(torch.nn.Module):
4412            def forward(self, x):
4413                return torch.ops.aten.lift_fresh_copy(x)
4414
4415        ep = export(M(), (torch.ones(6, 4),))
4416        found = False
4417
4418        op = "torch.ops.aten.clone.default"
4419        FileCheck().check_count(op, 1, exactly=True).run(ep.graph_module.code)
4420
4421    def test_cond_buffers(self):
4422        class M(torch.nn.Module):
4423            def __init__(self) -> None:
4424                super().__init__()
4425                self.register_parameter(
4426                    "param", torch.nn.Parameter(torch.ones(2, 3), requires_grad=False)
4427                )
4428                self.buffer = torch.nn.Buffer(torch.ones(2, 3) + 1)
4429
4430            def true_fn(self, x):
4431                return x + self.param
4432
4433            def false_fn(self, x):
4434                return x + self.buffer
4435
4436            def forward(self, x):
4437                return cond(x.shape[0] == 4, self.true_fn, self.false_fn, [x])
4438
4439        inp = torch.ones(2, 3)
4440        ep = torch.export.export(M(), (inp,))
4441        inp = torch.randn(2, 3)
4442        epm = ep.module()
4443        self.assertTrue(torch.allclose(epm(inp), M()(inp)))
4444
4445        for gm in epm.named_modules():
4446            if not isinstance(gm, torch.fx.GraphModule):
4447                continue
4448            self.assertEqual(
4449                len([node for node in gm.graph.nodes if node.op == "placeholder"]), 1
4450            )
4451
4452    # map_fn references module outside the module hierarchy
4453    @unittest.expectedFailure
4454    def test_map_buffers(self):
4455        class M1(torch.nn.Module):
4456            def __init__(self) -> None:
4457                super().__init__()
4458                self.register_parameter(
4459                    "param", torch.nn.Parameter(torch.tensor(5), requires_grad=False)
4460                )
4461                self.buffer = torch.nn.Buffer(torch.tensor(6) + 1)
4462
4463        m1 = M1()
4464
4465        def map_fn(x, y):
4466            z = x + y + m1.param + m1.buffer
4467            z.add_(4)
4468            return z
4469
4470        class M(torch.nn.Module):
4471            def forward(self, xs, y):
4472                return map(map_fn, xs, y)
4473
4474        example_inputs = (torch.ones(3, 2), torch.tensor(3))
4475        ep = torch.export.export(M(), example_inputs)
4476        example_inputs = (torch.randn(3, 2), torch.tensor(3))
4477        epm = ep.module()
4478        self.assertTrue(torch.allclose(epm(*example_inputs), M()(*example_inputs)))
4479
4480        for gm in epm.named_modules():
4481            if not isinstance(gm, torch.fx.GraphModule):
4482                continue
4483            self.assertEqual(
4484                len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
4485            )
4486
4487    def test_check_is_size_error(self):
4488        class Module(torch.nn.Module):
4489            def forward(self, x):
4490                a = x.item()
4491                # We cannot automatically infer a is a size here because view
4492                # accepts -1
4493                return torch.randn(24).view(a, 4)
4494
4495        f = Module()
4496        if is_non_strict_test(self._testMethodName):
4497            error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
4498        else:
4499            error = torch._dynamo.exc.UserError
4500        error_msg = r"Could not guard on data-dependent expression"
4501        with self.assertRaisesRegex(error, error_msg):
4502            _ = export(f, (torch.tensor(6),))
4503
4504    def test_train_eval_on_exported_preautograd_module(self):
4505        class Foo(torch.nn.Module):
4506            def __init__(self) -> None:
4507                super().__init__()
4508
4509            def forward(self, x):
4510                if x.shape[0] > 4:
4511                    return x.cos()
4512                return x.sin()
4513
4514        graph_module = _export(Foo(), (torch.ones(7, 5),), pre_dispatch=True).module()
4515        with self.assertRaisesRegex(
4516            NotImplementedError, r"Calling train\(\) is not supported yet."
4517        ):
4518            graph_module.train()
4519
4520        with self.assertRaisesRegex(
4521            NotImplementedError, r"Calling eval\(\) is not supported yet."
4522        ):
4523            graph_module.eval()
4524
4525    def test_lifted_constants(self) -> None:
4526        class Module(torch.nn.Module):
4527            def forward(self, x):
4528                return x + torch.tensor(3)
4529
4530        f = Module()
4531        ep = export(f, (torch.tensor(1),))
4532
4533        self.assertEqual(len(ep.graph_signature.input_specs), 2)
4534        self.assertEqual(len(ep.constants), 1)
4535
4536        class Foo(torch.nn.Module):
4537            def __init__(self) -> None:
4538                super().__init__()
4539                self.a = torch.tensor(3)
4540
4541            def forward(self, x):
4542                list_tensor = [torch.tensor(3), torch.tensor(4)]
4543                return x + self.a + list_tensor[0] + list_tensor[1]
4544
4545        ep = export(Foo(), (torch.tensor(1),))
4546
4547        self.assertEqual(len(ep.graph_signature.input_specs), 4)
4548        self.assertEqual(len(ep.state_dict), 0)
4549        self.assertEqual(len(ep.constants), 3)
4550
4551        inp = (torch.tensor(5),)
4552        self.assertTrue(torch.allclose(ep.module()(*inp), Foo()(*inp)))
4553
4554        transform = ep.run_decompositions()
4555        self.assertEqual(len(ep.graph_signature.input_specs), 4)
4556        self.assertTrue(torch.allclose(ep.module()(*inp), transform.module()(*inp)))
4557
4558    def test_tensor_attribute_zero_args(self):
4559        class Foo(torch.nn.Module):
4560            def __init__(self, value):
4561                super().__init__()
4562                self.x = torch.tensor(value)
4563
4564            def forward(self):
4565                return self.x.clone()
4566
4567        m = Foo([1, 2])
4568        ep = export(m, ())
4569        self.assertEqual(ep.graph_signature.lifted_tensor_constants, ["x"])
4570
4571    def test_preserve_shape_dynamism_for_unused_inputs(self):
4572        @dataclass
4573        class Input:
4574            f: torch.Tensor
4575            p: torch.Tensor
4576
4577        torch._export.utils.register_dataclass_as_pytree_node(
4578            Input,
4579            serialized_type_name="test_preserve_shape_dynamism_for_unused_inputs.Input",
4580        )
4581
4582        class Module(torch.nn.Module):
4583            def forward(self, x: Input):
4584                return x.f + 1
4585
4586        mod = Module()
4587        example_inputs = (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),)
4588        ep_static = torch.export.export(mod, example_inputs)
4589        for node in ep_static.graph.nodes:
4590            if node.op == "placeholder":
4591                for s in node.meta["val"].shape:
4592                    self.assertIsInstance(s, int)
4593
4594        dim0_x_f, dim0_x_p = torch.export.dims("dim0_x_f", "dim0_x_p")
4595        dynamic_shapes = {"x": [{0: dim0_x_f}, {0: dim0_x_p}]}
4596        ep_dynamic = torch.export.export(
4597            mod, example_inputs, dynamic_shapes=dynamic_shapes
4598        )
4599        for node in ep_dynamic.graph.nodes:
4600            if node.op == "placeholder":
4601                for i, s in enumerate(node.meta["val"].shape):
4602                    if i == 0:
4603                        self.assertIsInstance(s, torch.SymInt)
4604                    else:
4605                        self.assertIsInstance(s, int)
4606
4607    def test_multiple_definitions_same_name_dim(self):
4608        class Foo(torch.nn.Module):
4609            def forward(self, x, y):
4610                return torch.matmul(x, y)
4611
4612        A = torch.export.Dim("C", min=3)
4613        B = torch.export.Dim("C", max=12)
4614        with self.assertRaisesRegex(
4615            torch._dynamo.exc.UserError,
4616            "Found different definitions Dim\\(.*min=3\\) and Dim\\(.*max=12\\) "
4617            "for the same symbolic dimension",
4618        ):
4619            torch.export.export(
4620                Foo(),
4621                (torch.randn(10, 10), torch.randn(10, 10)),
4622                dynamic_shapes={"x": (A, B), "y": (B, A)},
4623            )
4624
4625    def test_export_with_wrong_inputs(self):
4626        class MyModule(torch.nn.Module):
4627            def forward(self, x):
4628                return x + x
4629
4630        exported_program = export(MyModule(), (torch.rand(2, 3),), {})
4631        with self.assertRaisesRegex(ValueError, "Trying to flatten user inputs"):
4632            exported_program.module()(torch.rand(2, 3), torch.rand(2, 3))
4633
4634    def test_export_decomps_simple(self):
4635        class M(torch.nn.Module):
4636            def __init__(self) -> None:
4637                super().__init__()
4638                self.lin = torch.nn.Linear(10, 1)
4639
4640            def forward(self, x):
4641                return self.lin(x)
4642
4643        inp = (torch.randn(5, 10),)
4644        m = M()
4645        ep = export(m, inp)
4646        state_dict = ep.state_dict
4647
4648        self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp)))
4649
4650        core_aten_ep = ep.run_decompositions()
4651        FileCheck().check_count("torch.ops.aten.permute.default", 1, exactly=True).run(
4652            core_aten_ep.graph_module.code
4653        )
4654        FileCheck().check_count("torch.ops.aten.t.default", 0, exactly=True).run(
4655            core_aten_ep.graph_module.code
4656        )
4657        self.assertTrue(torch.allclose(core_aten_ep.module()(*inp), m(*inp)))
4658        self.assertEqual(id(state_dict), id(ep.state_dict))
4659
4660    def test_export_decomps_dynamic(self):
4661        class M(torch.nn.Module):
4662            def __init__(self) -> None:
4663                super().__init__()
4664                self.lin = torch.nn.Linear(10, 1)
4665
4666            def forward(self, x):
4667                return self.lin(x)
4668
4669        inp = (torch.randn(5, 10),)
4670        m = M()
4671        ep = export(m, inp, dynamic_shapes={"x": {0: Dim("batch")}})
4672
4673        core_aten_ep = ep.run_decompositions()
4674
4675        input_node = [
4676            node for node in core_aten_ep.graph.nodes if node.op == "placeholder"
4677        ][-1]
4678        self.assertTrue(isinstance(input_node.meta["val"].shape[0], torch.SymInt))
4679
4680        FileCheck().check_count("torch.ops.aten.permute.default", 1, exactly=True).run(
4681            core_aten_ep.graph_module.code
4682        )
4683        FileCheck().check_count("torch.ops.aten.t.default", 0, exactly=True).run(
4684            core_aten_ep.graph_module.code
4685        )
4686        self.assertTrue(torch.allclose(core_aten_ep.module()(*inp), m(*inp)))
4687
4688    def test_nonzero_2(self):
4689        class Module(torch.nn.Module):
4690            def forward(self, x):
4691                return torch.nonzero(x)
4692
4693        f = Module()
4694        ep = export(f, (torch.ones(2),))
4695        inp = torch.randn(2)
4696        self.assertTrue(torch.allclose(ep.module()(inp), torch.nonzero(inp)))
4697
4698    def test_redundant_asserts(self):
4699        class Foo(torch.nn.Module):
4700            def forward(self, x):
4701                y = x.item()
4702                torch._check_is_size(y)
4703                return torch.zeros(y)
4704
4705        f = Foo()
4706
4707        ep = export(f, (torch.tensor([3]),))
4708
4709        FileCheck().check_count(
4710            "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True
4711        ).run(ep.graph_module.code)
4712        FileCheck().check_count(
4713            "torch.ops.aten._assert_scalar.default", 1, exactly=True
4714        ).run(ep.graph_module.code)
4715
4716        ep = ep.run_decompositions()
4717
4718        FileCheck().check_count(
4719            "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True
4720        ).run(ep.graph_module.code)
4721        FileCheck().check_count(
4722            "torch.ops.aten._assert_scalar.default", 1, exactly=True
4723        ).run(ep.graph_module.code)
4724
4725    def test_non_arg_name_dynamic_shapes_api(self):
4726        class Foo(torch.nn.Module):
4727            def forward(self, a, b):
4728                return a.sum() + b.sum()
4729
4730        foo = Foo()
4731        dim = torch.export.Dim("dim")
4732        ep = torch.export.export(
4733            foo,
4734            (torch.randn(4, 4), torch.randn(4, 4)),
4735            dynamic_shapes=(None, {0: dim}),
4736        )
4737
4738        test_inp = (torch.randn(4, 4), torch.randn(7, 4))
4739        self.assertEqual(ep.module()(*test_inp), foo(*test_inp))
4740
4741        ep_v2 = torch.export.export(
4742            foo,
4743            (torch.randn(4, 4), torch.randn(4, 4)),
4744            dynamic_shapes=(None, None),
4745        )
4746        with self.assertRaisesRegex(
4747            RuntimeError, "shape\[0\] to be equal to 4, but got 7"
4748        ):
4749            ep_v2.module()(*test_inp)
4750
4751    def test_constant_output(self):
4752        class ModuleConstant(torch.nn.Module):
4753            def __init__(self) -> None:
4754                super().__init__()
4755                self.b = torch.randn(3, 2)
4756
4757            def forward(self):
4758                return self.b
4759
4760        class ModuleNestedConstant(torch.nn.Module):
4761            def __init__(self) -> None:
4762                super().__init__()
4763                self.bff = torch.randn(3, 2)
4764
4765            def forward(self, x, y):
4766                return {"prediction": (x + y, self.bff)}
4767
4768        mod = ModuleConstant()
4769        ep = torch.export.export(mod, ())
4770        self.assertEqual(ep.module()(), mod())
4771
4772        args = (torch.randn(3, 2), torch.randn(3, 2))
4773        mod = ModuleNestedConstant()
4774        ep = torch.export.export(mod, args)
4775        self.assertEqual(ep.module()(*args), mod(*args))
4776
4777    def test_non_arg_name_dynamic_shapes_api_with_kwarg(self):
4778        class Foo(torch.nn.Module):
4779            def forward(self, a, b, kw1, kw2):
4780                return a.sum() + b.sum() + kw1.sum() - kw2.sum()
4781
4782        foo = Foo()
4783        dim = torch.export.Dim("dim")
4784        dim_for_kw1 = torch.export.Dim("dim_for_kw1")
4785        ep = torch.export.export(
4786            foo,
4787            (torch.randn(4, 4), torch.randn(4, 4)),
4788            {"kw2": torch.ones(4, 4), "kw1": torch.zeros(4, 4)},
4789            # We are specifying dynamism on the first kwarg even though user passed in
4790            # different order
4791            dynamic_shapes=(None, {0: dim}, {0: dim_for_kw1}, None),
4792        )
4793
4794        test_inp = (torch.randn(4, 4), torch.randn(7, 4))
4795        test_kwargs = {"kw2": torch.ones(4, 4), "kw1": torch.zeros(9, 4)}
4796        # This should work even if the kwarg order are flipped.
4797        self.assertEqual(
4798            ep.module()(*test_inp, **test_kwargs), foo(*test_inp, **test_kwargs)
4799        )
4800
4801    def test_non_arg_name_dynamic_shapes_api_with_container_type(self):
4802        class Foo(torch.nn.Module):
4803            def forward(self, a, b):
4804                return a[0].sum() + a[1].sum() + b.sum()
4805
4806        inp_a = (torch.randn(4, 4), torch.randn(4, 4))
4807        inp_b = torch.randn(4, 4)
4808        inp = (inp_a, inp_b)
4809
4810        count = 0
4811
4812        def dynamify_inp(x):
4813            # Mark the second input a[1] dynamic
4814            nonlocal count
4815            if count == 1:
4816                dim = torch.export.Dim("dim", min=3)
4817                count += 1
4818                return {0: dim}
4819            count += 1
4820            return None
4821
4822        dynamic_shapes = tree_map(dynamify_inp, inp)
4823
4824        foo = Foo()
4825        ep = torch.export.export(foo, inp, dynamic_shapes=dynamic_shapes)
4826
4827        test_inp = ((torch.randn(4, 4), torch.randn(2, 4)), torch.randn(4, 4))
4828        with self.assertRaisesRegex(RuntimeError, "shape\[0\] to be >= 3, but got 2"):
4829            ep.module()(*test_inp)
4830
4831    def test_nested_module(self):
4832        class M1(torch.nn.Module):
4833            def forward(self, x):
4834                return x + x
4835
4836        class M2(torch.nn.Module):
4837            def forward(self, x):
4838                m = M1()
4839                return m(x) * x
4840
4841        inps = (torch.randn(3, 3),)
4842        ep = export(M2(), inps)
4843        self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
4844
4845        add_nodes = [
4846            node
4847            for node in ep.graph.nodes
4848            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor
4849        ]
4850        self.assertEqual(len(add_nodes), 1)
4851        add_node = add_nodes[0]
4852        self.assertEqual(len(add_node.meta["nn_module_stack"]), 1)
4853        self.assertTrue("M2" in list(add_node.meta["nn_module_stack"].values())[0][1])
4854
4855        self.assertExpectedInline(
4856            str(ep.graph).strip(),
4857            """\
4858graph():
4859    %x : [num_users=2] = placeholder[target=x]
4860    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
4861    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
4862    return (mul,)""",
4863        )
4864
4865        unflattened = unflatten(ep)
4866        self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
4867
4868    def test_nested_module_with_init_buffer(self):
4869        class M1(torch.nn.Module):
4870            def __init__(self) -> None:
4871                super().__init__()
4872                self.b = torch.ones(3, 3)
4873
4874            def forward(self, x):
4875                return x + self.b
4876
4877        class M2(torch.nn.Module):
4878            def forward(self, x):
4879                m = M1()
4880                return m(x) * x
4881
4882        inps = (torch.randn(3, 3),)
4883        ep = export(M2(), inps)
4884        self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
4885
4886        self.assertEqual(len(ep.state_dict), 0)
4887        self.assertEqual(len(ep.constants), 0)
4888
4889        self.assertExpectedInline(
4890            str(ep.graph).strip(),
4891            """\
4892graph():
4893    %x : [num_users=2] = placeholder[target=x]
4894    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False})
4895    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %ones), kwargs = {})
4896    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
4897    return (mul,)""",
4898        )
4899
4900        unflattened = unflatten(ep)
4901        self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
4902
4903    @testing.expectedFailureRetraceability  # Retracing tensor constants results in buffers
4904    def test_nested_module_with_constant_buffer(self):
4905        class M1(torch.nn.Module):
4906            def __init__(self) -> None:
4907                super().__init__()
4908                self.b = torch.tensor(5)
4909
4910            def forward(self, x):
4911                return x + self.b
4912
4913        class M2(torch.nn.Module):
4914            def forward(self, x):
4915                m = M1()
4916                return m(x) * x
4917
4918        inps = (torch.randn(3, 3),)
4919        ep = export(M2(), inps)
4920        self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
4921
4922        self.assertEqual(len(ep.state_dict), 0)
4923        self.assertEqual(len(ep.constants), 1)
4924
4925        if is_training_ir_test(self._testMethodName):
4926            self.assertExpectedInline(
4927                str(ep.graph).strip(),
4928                """\
4929graph():
4930    %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
4931    %x : [num_users=2] = placeholder[target=x]
4932    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
4933    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lift_fresh_copy), kwargs = {})
4934    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
4935    return (mul,)""",
4936            )
4937        else:
4938            self.assertExpectedInline(
4939                str(ep.graph).strip(),
4940                """\
4941graph():
4942    %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
4943    %x : [num_users=2] = placeholder[target=x]
4944    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
4945    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {})
4946    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %detach), kwargs = {})
4947    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
4948    return (mul,)""",
4949            )
4950
4951        unflattened = unflatten(ep)
4952        self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
4953
4954    def test_nested_module_with_parameter(self):
4955        class M1(torch.nn.Module):
4956            def __init__(self) -> None:
4957                super().__init__()
4958                self.a = torch.nn.Parameter(torch.ones(3, 3))
4959                self.b = torch.nn.Parameter(torch.tensor(5.0))
4960
4961            def forward(self, x):
4962                return x + self.a * self.b
4963
4964        class M2(torch.nn.Module):
4965            def forward(self, x):
4966                m = M1()
4967                return m(x) * x
4968
4969        inps = (torch.randn(3, 3),)
4970        # Strict export segfaults (Issue #128109)
4971        ep = torch.export.export(M2(), inps, strict=False)
4972        self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
4973
4974        self.assertEqual(len(ep.state_dict), 0)
4975        self.assertEqual(len(ep.constants), 1)
4976
4977        self.assertExpectedInline(
4978            str(ep.graph).strip(),
4979            """\
4980graph():
4981    %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
4982    %x : [num_users=2] = placeholder[target=x]
4983    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False})
4984    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%ones,), kwargs = {})
4985    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
4986    %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {})
4987    %detach_2 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_1,), kwargs = {})
4988    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach, %detach_2), kwargs = {})
4989    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {})
4990    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
4991    return (mul_1,)""",
4992        )
4993
4994        unflattened = unflatten(ep)
4995        self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
4996
4997    def test_lazy_module_kwargs(self):
4998        class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
4999            def initialize_parameters(self, *args, **kwargs):
5000                pass
5001
5002            def forward(self, x, y):
5003                return x + y
5004
5005        m = LazyModule()
5006        ep = torch.export.export(
5007            m, (), {"x": torch.randn(3, 3), "y": torch.randn(3, 3)}
5008        )
5009        inputs = {"x": torch.randn(3, 3), "y": torch.randn(3, 3)}
5010        self.assertEqual(ep.module()(**inputs), m(**inputs))
5011
5012    def test_retrace_pre_autograd(self):
5013        class Foo(torch.nn.Module):
5014            def __init__(self) -> None:
5015                super().__init__()
5016                self.buffer = torch.nn.Buffer(torch.ones(4, 4))
5017
5018            def forward(self, x):
5019                self.buffer.add_(4)
5020                return x.sum() + self.buffer.sum()
5021
5022        inp = torch.randn(4, 4)
5023        gm = _export(
5024            Foo(),
5025            (inp,),
5026            dynamic_shapes=({0: torch.export.Dim("dim", min=3)},),
5027            pre_dispatch=True,
5028        ).module()
5029
5030        with self.assertRaisesRegex(
5031            RuntimeError, escape("Expected input at *args[0].shape[0]")
5032        ):
5033            gm(torch.randn(2, 2))
5034
5035        with self.assertRaisesRegex(
5036            RuntimeError, escape("Expected input at *args[0].shape[0]")
5037        ):
5038            torch.export.export(gm, (torch.randn(2, 2),))
5039
5040        ep = torch.export.export(
5041            gm,
5042            (torch.randn(5, 4),),
5043            dynamic_shapes=({0: torch.export.Dim("dim", min=3)},),
5044        )
5045
5046        test_inp = torch.ones(8, 4)
5047        self.assertTrue(torch.allclose(ep.module()(test_inp), Foo().forward(test_inp)))
5048
5049    def test_runtime_assert_with_size(self):
5050        class M(torch.nn.Module):
5051            def forward(self, x, y):
5052                a = x.item()
5053                torch._check_is_size(a)
5054                torch._check(a <= y.size(0))
5055                return y[:a]
5056
5057        ep = export(
5058            M(),
5059            (torch.tensor(5), torch.ones(10)),
5060            dynamic_shapes={"x": None, "y": {0: torch.export.Dim("t")}},
5061        )
5062        inp = (torch.tensor(6), torch.randn(13))
5063        self.assertTrue(torch.allclose(ep.module()(*inp), M()(*inp)))
5064
5065    @unittest.skip("Test is only supposed to work with non-strict mode")
5066    def test_issue_113041(self):
5067        class TestModule(torch.nn.Module):
5068            def __init__(self) -> None:
5069                super().__init__()
5070                self.a = torch.tensor(1.0)
5071
5072            def forward(self, x: torch.Tensor) -> torch.Tensor:
5073                return x + self.a
5074
5075        def forward_hook(module: torch.nn.Module, inputs, output) -> torch.Tensor:
5076            return 2 * output
5077
5078        seq = torch.nn.Sequential(TestModule()).eval()
5079        seq.b = torch.tensor(2)
5080        handle = seq.register_forward_hook(forward_hook)
5081
5082        class M(torch.nn.Module):
5083            def __init__(self) -> None:
5084                super().__init__()
5085                self.seq = seq
5086
5087            def forward(self, x):
5088                return self.seq(x) + self.seq.b
5089
5090        inp = (torch.randn(2, 8),)
5091        ep = export(M(), inp)  # This errors because dynamo adds an extra input
5092
5093    def test_export_with_fake_tensor_inputs(self):
5094        fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
5095
5096        class Model(torch.nn.Module):
5097            def __init__(self) -> None:
5098                super().__init__()
5099                self.linear = torch.nn.Linear(2, 2)
5100
5101            def forward(self, x):
5102                out = self.linear(x)
5103                return out
5104
5105        # Put the inputs on a device
5106        with fake_mode, torch.device("meta"):
5107            x = torch.rand(5, 2, 2)
5108            model = Model()
5109
5110            exported_program = torch.export.export(model, (x,))
5111            export_res = exported_program.module()(x)
5112            exp_res = model(x)
5113            all_meta_val = [
5114                node.meta["val"]
5115                for node in exported_program.graph_module.graph.nodes
5116                if "val" in node.meta
5117            ]
5118            self.assertTrue(export_res.size() == exp_res.size())
5119            self.assertTrue(all(val.device == x.device for val in all_meta_val))
5120            self.assertTrue(
5121                all(val.fake_mode is all_meta_val[0].fake_mode for val in all_meta_val)
5122            )
5123            decomposed_ep = exported_program.run_decompositions()
5124            export_res = decomposed_ep.module()(x)
5125            self.assertTrue(export_res.size() == exp_res.size())
5126
5127    def test_export_with_fake_tensor_inputs_on_cuda_devices(self):
5128        fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
5129
5130        class Model(torch.nn.Module):
5131            def __init__(self) -> None:
5132                super().__init__()
5133                self.linear = torch.nn.Linear(2, 2)
5134
5135            def forward(self, x):
5136                out = self.linear(x)
5137                return out
5138
5139        # Put the inputs on a device
5140        with fake_mode, torch.device("meta"):
5141            x = torch.rand(5, 2, 2)
5142            model = Model()
5143
5144        # Manualy set the fake_device of fake tensors.
5145        x.fake_device = torch.device("cuda:0")
5146        for n, p in model.named_parameters():
5147            p.fake_device = torch.device("cuda:0")
5148
5149        # Need to set all the requires_grad of tensors to False, because fake_tensor with CUDA device
5150        # doesn't quite work well with aot_autograd right now due to some logic fails
5151        # the check in call getDeviceGuardImpl in InputMetadata.
5152        x.requires_grad = False
5153        for n, p in model.named_parameters():
5154            p.requires_grad = False
5155
5156        def check_device_and_fake_mode():
5157            exported_program = torch.export.export(model, (x,))
5158            export_res = exported_program.module()(x)
5159            exp_res = model(x)
5160            all_meta_val = [
5161                node.meta["val"]
5162                for node in exported_program.graph_module.graph.nodes
5163                if "val" in node.meta
5164            ]
5165            self.assertTrue(export_res.size() == exp_res.size())
5166            self.assertTrue(all(val.device == x.device for val in all_meta_val))
5167            self.assertTrue(
5168                all(val.fake_mode is all_meta_val[0].fake_mode for val in all_meta_val)
5169            )
5170
5171        check_device_and_fake_mode()
5172
5173    def test_run_decomposition_supports_user_input_mutation(self):
5174        class SingleOp(torch.nn.Module):
5175            def __init__(self) -> None:
5176                super().__init__()
5177                self.op = torch.ops.aten.native_batch_norm
5178
5179            def forward(
5180                self,
5181                input,
5182                weight,
5183                bias,
5184                running_mean,
5185                running_var,
5186                training,
5187                momentum,
5188                eps,
5189                **kwargs,
5190            ):
5191                return self.op(
5192                    input,
5193                    weight,
5194                    bias,
5195                    running_mean,
5196                    running_var,
5197                    training,
5198                    momentum,
5199                    eps,
5200                    **kwargs,
5201                )
5202
5203        input = torch.randn(5, 5, 5)
5204        weight = torch.randn(5)
5205        bias = torch.randn(5)
5206        running_mean = torch.randn(5)
5207        running_var = torch.randn(5)
5208        training = True
5209        momentum = 0.5
5210        eps = 0.6
5211
5212        model = SingleOp()
5213        output = model(
5214            input, weight, bias, running_mean, running_var, training, momentum, eps
5215        )
5216
5217        ep = torch.export.export(
5218            model,
5219            args=(
5220                input,
5221                weight,
5222                bias,
5223                running_mean,
5224                running_var,
5225                training,
5226                momentum,
5227                eps,
5228            ),
5229        )
5230        ep.run_decompositions(decomp_table=torch._decomp.decomposition_table)
5231        self.assertEqual(
5232            ep.module()(
5233                input, weight, bias, running_mean, running_var, training, momentum, eps
5234            ),
5235            output,
5236        )
5237
5238    def test_export_graph_with_no_inputs(self):
5239        # We saw this pattern when users want to export
5240        # a graph that initlizes the states of a model.
5241        class Module(torch.nn.Module):
5242            def forward(self):
5243                return torch.randn(3, 4), torch.randn(3, 4)
5244
5245        f = Module()
5246        ep = torch.export.export(f, ())
5247        a, b = ep.module()()
5248        self.assertEqual(a.size(), torch.Size([3, 4]))
5249        self.assertEqual(b.size(), torch.Size([3, 4]))
5250
5251        # Contains unbacked symint
5252        class M(torch.nn.Module):
5253            def forward(self):
5254                full = torch.full((), 11)
5255                i0 = full.item()
5256                return (torch.full((i0,), 0.0),)
5257
5258        f = M()
5259        ep = torch.export.export(f, ())
5260        a = ep.module()()[0]
5261        self.assertEqual(a.size(), torch.Size([11]))
5262        self.assertEqual(a, torch.zeros(11))
5263
5264    def test_pad_sequence(self):
5265        class Module(torch.nn.Module):
5266            def forward(self, x):
5267                return torch._C._nn.pad_sequence([x])
5268
5269        m0 = Module()
5270        inputs = (torch.randn(3, 2),)
5271        ep = torch.export.export(
5272            m0, inputs, dynamic_shapes={"x": {0: Dim("batch_size")}}
5273        )
5274        self.assertEqual(ep.module()(*inputs), m0(*inputs))
5275
5276        class ModuleBatchFirst(torch.nn.Module):
5277            def forward(self, x):
5278                return torch._C._nn.pad_sequence([x], batch_first=True)
5279
5280        m1 = ModuleBatchFirst()
5281        inputs = (torch.randn(3, 2),)
5282        ep = torch.export.export(
5283            m1, inputs, dynamic_shapes={"x": {0: Dim("batch_size")}}
5284        )
5285        self.assertEqual(ep.module()(*inputs), m1(*inputs))
5286
5287        class ModuleMulti(torch.nn.Module):
5288            def forward(self, x, y, z):
5289                return torch._C._nn.pad_sequence([x, y, z])
5290
5291        m2 = ModuleMulti()
5292        inputs = (torch.randn(5, 2), torch.randn(4, 2), torch.randn(3, 2))
5293        ep = torch.export.export(
5294            m2,
5295            inputs,
5296            dynamic_shapes={
5297                "x": {0: Dim("batch_size")},
5298                "y": {0: Dim("y")},
5299                "z": {0: Dim("z")},
5300            },
5301        )
5302        self.assertEqual(ep.module()(*inputs), m2(*inputs))
5303
5304        class ModuleMultiBatchFirst(torch.nn.Module):
5305            def forward(self, x, y, z):
5306                return torch._C._nn.pad_sequence([x, y, z], batch_first=True)
5307
5308        m3 = ModuleMulti()
5309        inputs = (torch.randn(5, 2), torch.randn(4, 2), torch.randn(3, 2))
5310        ep = torch.export.export(
5311            m2,
5312            inputs,
5313            dynamic_shapes={
5314                "x": {0: Dim("batch_size")},
5315                "y": {0: Dim("y")},
5316                "z": {0: Dim("z")},
5317            },
5318        )
5319        self.assertEqual(ep.module()(*inputs), m3(*inputs))
5320
5321    def test_export_then_compile_tensor_ctor(self):
5322        class M(torch.nn.Module):
5323            def forward(self, scores, mask):
5324                scores = scores.masked_fill(
5325                    mask, torch.tensor(torch.finfo(scores.dtype).min)
5326                )  # (bs, n_heads, q_length, k_length)
5327                return scores
5328
5329        tensor_cpu = torch.randn(2, 4)
5330        mask_cpu = torch.BoolTensor(
5331            [[False, True, False, False], [False, False, False, False]]
5332        )
5333
5334        m = M().eval()
5335        # res_ref = m(tensor_cpu, mask_cpu)
5336        # print("res_ref is: {}".format(res_ref), flush=True)
5337
5338        exported_model = _export(m, (tensor_cpu, mask_cpu), pre_dispatch=True).module()
5339        optimized_model = torch.compile(exported_model)
5340        optimized_model(tensor_cpu, mask_cpu)
5341
5342    def test_export_input_mutation_static_shape(self):
5343        class MutationModel(torch.nn.Module):
5344            def forward(self, x, y):
5345                x.view(3, 2, -1).add_(y)
5346                return x
5347
5348        inputs = (torch.randn(12), torch.tensor(2))
5349        model = MutationModel()
5350        ep = export(model, inputs)
5351        inputs_export = copy.deepcopy(inputs)
5352        inputs_model = copy.deepcopy(inputs)
5353        self.assertEqual(ep.module()(*inputs_export), model(*inputs_model))
5354        self.assertEqual(inputs[0] + torch.tensor(2), inputs_model[0])
5355        self.assertEqual(inputs[0] + torch.tensor(2), inputs_export[0])
5356
5357    def test_export_input_mutation_dynamic_shape(self):
5358        class MutationModel(torch.nn.Module):
5359            def forward(self, x, y):
5360                x[0].mul_(y)
5361                return x
5362
5363        inputs = ((torch.randn(12), torch.randn(3, 2)), 2.0)
5364        model = MutationModel()
5365        ep = torch.export.export(
5366            model,
5367            inputs,
5368            dynamic_shapes={"x": ({0: torch.export.Dim("dim")}, None), "y": None},
5369        )
5370        nodes = list(ep.graph.nodes)
5371        self.assertEqual(nodes[0].op, "placeholder")
5372        self.assertIsInstance(nodes[0].meta["val"], torch.Tensor)
5373        self.assertIsInstance(nodes[0].meta["val"].shape[0], torch.SymInt)
5374
5375        inputs_export = copy.deepcopy(inputs)
5376        inputs_model = copy.deepcopy(inputs)
5377        self.assertEqual(ep.module()(*inputs_export), model(*inputs_model))
5378        self.assertEqual(inputs[0][0] * 2.0, inputs_model[0][0])
5379        self.assertEqual(inputs[0][0] * 2.0, inputs_export[0][0])
5380
5381    def test_export_input_mutation_bug(self):
5382        class M(torch.nn.Module):
5383            def forward(self, x):
5384                x[:, :2, :] = x[:, :2, :] + 1
5385                return x
5386
5387        inputs = (torch.ones(4, 4, 4),)
5388        ep = torch.export.export(M(), inputs)
5389        m = ep.module()
5390
5391        # Make the name conflict with a placeholder name that we get from
5392        # aot_export
5393        for i, node in enumerate(m.graph.nodes):
5394            if node.op == "placeholder":
5395                node.name = f"arg0_{i + 1}"
5396        m.recompile()
5397
5398        ep = torch.export.export(m, inputs)
5399
5400        inputs = (torch.randn(4, 4, 4),)
5401        self.assertEqual(
5402            ep.module()(*copy.deepcopy(inputs)), M()(*copy.deepcopy(inputs))
5403        )
5404
5405    def test__scaled_dot_product_flash_attention(self):
5406        class Module(torch.nn.Module):
5407            def forward(self, q, k, v):
5408                res = torch.nn.functional.scaled_dot_product_attention(q, k, v)
5409                return res[0]
5410
5411        m = Module()
5412        inputs = (
5413            torch.randn(5, 4, 3, 2),
5414            torch.randn(5, 4, 3, 2),
5415            torch.randn(5, 4, 3, 2),
5416        )
5417        ep = export(m, inputs)
5418        self.assertEqual(ep.module()(*inputs), m(*inputs))
5419
5420    @testing.expectedFailureSerDer  # symfloat nyi
5421    def test_sym_sqrt(self):
5422        import math
5423
5424        class M(torch.nn.Module):
5425            def forward(self, x):
5426                return x / torch.sym_sqrt(x.shape[0])
5427
5428        ep = export(M(), (torch.ones(16, 4),), dynamic_shapes={"x": {0: Dim("dim")}})
5429        _ExportPassBaseDeprecatedDoNotUse()(ep.graph_module)
5430        FileCheck().check_count("torch._sym_sqrt", 1, exactly=True).run(
5431            ep.graph_module.code
5432        )
5433
5434    def test_check_specialized_int(self):
5435        class SingleOp(torch.nn.Module):
5436            def __init__(self) -> None:
5437                super().__init__()
5438                self.op = torch.ops.aten.scatter_add
5439
5440            def forward(self, t, dim, index, src, **kwargs):
5441                return self.op(t, dim, index, src, **kwargs)
5442
5443        t = torch.randn(10, 5)
5444        dim = -1
5445        index = torch.tensor(
5446            [
5447                [2, 4, 3, 1, 0],
5448                [0, 2, 1, 4, 3],
5449                [3, 1, 4, 2, 0],
5450                [4, 0, 3, 1, 2],
5451                [3, 0, 4, 1, 2],
5452            ]
5453        )
5454        src = torch.randn(5, 5)
5455
5456        model = SingleOp()
5457        output = model(t, dim, index, src)
5458
5459        ep = torch.export.export(model, args=(t, dim, index, src))
5460        ep.run_decompositions(decomp_table=torch._decomp.decomposition_table)
5461        self.assertEqual(ep.module()(t, dim, index, src), output)
5462
5463    def test_fqn(self):
5464        class NestedChild(torch.nn.Module):
5465            def forward(self, x):
5466                return x / x
5467
5468        class Child1(torch.nn.Module):
5469            def __init__(self) -> None:
5470                super().__init__()
5471                self.nested = NestedChild()
5472                self.register_parameter(
5473                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
5474                )
5475
5476            def forward(self, x):
5477                x = self.nested(x)
5478                return x + self.child1param
5479
5480        class Child2(torch.nn.Module):
5481            def __init__(self) -> None:
5482                super().__init__()
5483                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))
5484
5485            def forward(self, x):
5486                return x - self.child2buffer
5487
5488        class MyModule(torch.nn.Module):
5489            def __init__(self) -> None:
5490                super().__init__()
5491                self.foo = Child1()
5492                self.bar = Child2()
5493                self.register_parameter(
5494                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
5495                )
5496
5497            def forward(self, x):
5498                x = x * self.rootparam
5499                x = self.foo(x)
5500                x = self.bar(x)
5501                return x
5502
5503        orig_eager = MyModule()
5504        test_inp = torch.randn(2, 3)
5505
5506        torch_gm = _export_to_torch_ir(orig_eager, (torch.rand(2, 3),), {})
5507        for k, v in orig_eager.state_dict().items():
5508            normalized_k = k.replace(".", "_")
5509            self.assertIn(normalized_k, torch_gm.state_dict())
5510            self.assertEqual(v, torch_gm.state_dict()[normalized_k])
5511        self.assertTrue(torch.allclose(torch_gm(test_inp), orig_eager(test_inp)))
5512
5513        pre_autograd_gm = torch.export._trace._export(
5514            orig_eager, (torch.rand(2, 3),), {}, pre_dispatch=True
5515        ).module()
5516        for k, v in orig_eager.state_dict().items():
5517            normalized_k = k.replace(".", "_")
5518            self.assertIn(k, pre_autograd_gm.state_dict())
5519            self.assertEqual(v, pre_autograd_gm.state_dict()[k])
5520        self.assertTrue(torch.allclose(pre_autograd_gm(test_inp), orig_eager(test_inp)))
5521
5522        ep = export(orig_eager, (torch.rand(2, 3),), {})
5523        for k, v in orig_eager.state_dict().items():
5524            # We do not need to normalize the key here because exported
5525            # program's state dict is able to contain the module information.
5526            self.assertIn(k, ep.state_dict)
5527            self.assertEqual(v, ep.state_dict[k])
5528        self.assertTrue(torch.allclose(ep.module()(test_inp), orig_eager(test_inp)))
5529
5530    def test_nn_module_stack(self):
5531        class Leaf(torch.nn.Module):
5532            def __init__(self) -> None:
5533                super().__init__()
5534                self.linear = torch.nn.Linear(4, 4)
5535
5536            def forward(self, x):
5537                return self.linear(x)
5538
5539        class Bar(torch.nn.Module):
5540            def __init__(self) -> None:
5541                super().__init__()
5542                self.leaf = Leaf()
5543                self.buffer = torch.nn.Buffer(torch.randn(4, 4))
5544
5545            def forward(self, x):
5546                return self.buffer.sum() + self.leaf(x).sum()
5547
5548        class Foo(torch.nn.Module):
5549            def __init__(self) -> None:
5550                super().__init__()
5551                self.bar = Bar()
5552
5553            def forward(self, x):
5554                y = self.bar.buffer + x
5555                return (self.bar(x) + y.sum(),)
5556
5557        inp = (torch.randn(4, 4),)
5558        mod = Foo()
5559        ep_strict = torch.export.export(mod, inp).run_decompositions()
5560        ep_non_strict = torch.export.export(mod, inp, strict=False).run_decompositions()
5561
5562        gm_unflat_non_strict = unflatten(ep_non_strict)
5563        self.assertTrue(hasattr(gm_unflat_non_strict, "bar"))
5564        self.assertTrue(hasattr(gm_unflat_non_strict.bar, "buffer"))
5565        self.assertTrue(hasattr(gm_unflat_non_strict.bar, "leaf"))
5566
5567        gm_unflat_strict = unflatten(ep_strict)
5568
5569        self.assertEqual(gm_unflat_non_strict(*inp), gm_unflat_strict(*inp))
5570        self.assertExpectedInline(
5571            str(gm_unflat_non_strict.bar.leaf.linear.graph).strip(),
5572            """\
5573graph():
5574    %x : [num_users=1] = placeholder[target=x]
5575    %weight : [num_users=1] = get_attr[target=weight]
5576    %bias : [num_users=1] = get_attr[target=bias]
5577    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%weight, [1, 0]), kwargs = {})
5578    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%bias, %x, %permute), kwargs = {})
5579    return addmm""",
5580        )
5581
5582        gm_flat_non_strict = ep_non_strict.module()
5583        gm_flat_strict = ep_strict.module()
5584
5585        self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp))
5586
5587    def test_nn_module_stack_shared_submodule(self):
5588        class Leaf(torch.nn.Module):
5589            def __init__(self) -> None:
5590                super().__init__()
5591                self.linear = torch.nn.Linear(4, 4)
5592
5593            def forward(self, x):
5594                return self.linear(x)
5595
5596        class Bar(torch.nn.Module):
5597            def __init__(self) -> None:
5598                super().__init__()
5599                self.leaf = Leaf()
5600                self.buffer = torch.nn.Buffer(torch.randn(4, 4))
5601
5602            def forward(self, x):
5603                return self.buffer.sum() + self.leaf(x).sum()
5604
5605        class BarDifferent(torch.nn.Module):
5606            def __init__(self) -> None:
5607                super().__init__()
5608                self.leaf = Leaf()
5609
5610            def forward(self, x):
5611                a = self.leaf(x).sum()
5612                b = self.leaf(x).sum()
5613                return a + b
5614
5615        class Foo(torch.nn.Module):
5616            def __init__(self) -> None:
5617                super().__init__()
5618                self.bar = Bar()
5619                self.bar_different = BarDifferent()
5620
5621            def forward(self, x):
5622                y = self.bar.buffer + x
5623                return (
5624                    self.bar(x) + self.bar_different(x + 2),
5625                    y.sum(),
5626                )
5627
5628        inp = (torch.randn(4, 4),)
5629        mod = Foo()
5630        ep_strict = torch.export.export(mod, inp)
5631        ep_non_strict = torch.export.export(mod, inp, strict=False)
5632
5633        gm_unflat_non_strict = unflatten(ep_non_strict)
5634        self.assertTrue(hasattr(gm_unflat_non_strict, "bar"))
5635        self.assertTrue(hasattr(gm_unflat_non_strict.bar, "buffer"))
5636        self.assertTrue(hasattr(gm_unflat_non_strict.bar, "leaf"))
5637        self.assertTrue(hasattr(gm_unflat_non_strict.bar_different, "leaf"))
5638
5639        gm_unflat_strict = unflatten(ep_strict)
5640
5641        self.assertEqual(gm_unflat_non_strict(*inp), gm_unflat_strict(*inp))
5642        self.assertExpectedInline(
5643            str(gm_unflat_non_strict.bar.leaf.linear.graph).strip(),
5644            """\
5645graph():
5646    %x : [num_users=1] = placeholder[target=x]
5647    %weight : [num_users=1] = get_attr[target=weight]
5648    %bias : [num_users=1] = get_attr[target=bias]
5649    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %weight, %bias), kwargs = {})
5650    return linear""",
5651        )
5652        self.assertExpectedInline(
5653            str(gm_unflat_non_strict.bar_different.leaf.linear.graph).strip(),
5654            """\
5655graph():
5656    %add_2 : [num_users=1] = placeholder[target=add_2]
5657    %weight : [num_users=1] = get_attr[target=weight]
5658    %bias : [num_users=1] = get_attr[target=bias]
5659    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%add_2, %weight, %bias), kwargs = {})
5660    return linear_1""",
5661        )
5662
5663        gm_flat_non_strict = ep_non_strict.module()
5664        gm_flat_strict = ep_strict.module()
5665
5666        self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp))
5667
5668    def test_stack_trace(self):
5669        class Foo(torch.nn.Module):
5670            def __init__(self) -> None:
5671                super().__init__()
5672                self.linear = torch.nn.Linear(4, 4)
5673
5674            def forward(self, x):
5675                x = self.linear(x)
5676                x *= 2.0
5677                return x
5678
5679        ep = export(
5680            Foo(),
5681            (torch.randn(4, 4),),
5682        )
5683        # check correct lines are in stack trace
5684        trace_mul = [node for node in ep.graph.nodes if node.name == "mul"][0].meta.get(
5685            "stack_trace", ""
5686        )
5687        self.assertTrue(
5688            re.search(r"test_export.py.*in forward\n.*x \*= 2.0", trace_mul)
5689        )
5690        trace_addmm = [
5691            node for node in ep.graph.nodes if node.name in ["addmm", "linear"]
5692        ][0].meta.get("stack_trace", "")
5693        self.assertTrue(
5694            re.search(
5695                r"test_export.py.*in forward\n.*x = self.linear\(x\)", trace_addmm
5696            )
5697        )
5698
5699    def test_cond_with_module_stack_export_with(self):
5700        class Bar(torch.nn.Module):
5701            def __init__(self) -> None:
5702                super().__init__()
5703                self.linear = torch.nn.Linear(4, 4)
5704
5705            def forward(self, x):
5706                def true_fn(x):
5707                    return self.linear(x).cos()
5708
5709                def false_fn(x):
5710                    return self.linear(x).sin()
5711
5712                return torch.cond(x.sum() > 4, true_fn, false_fn, [x])
5713
5714        class CondExport(torch.nn.Module):
5715            def __init__(self) -> None:
5716                super().__init__()
5717                self.bar = Bar()
5718
5719            def forward(self, x):
5720                return x.cos() + self.bar(x)
5721
5722        inp = (torch.randn(4, 4),)
5723        ep = torch.export.export(CondExport(), inp, strict=False)
5724        self.assertExpectedInline(
5725            ep.graph_module.code.strip(),
5726            """\
5727def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
5728    cos = torch.ops.aten.cos.default(x)
5729    sum_1 = torch.ops.aten.sum.default(x)
5730    gt = torch.ops.aten.gt.Scalar(sum_1, 4);  sum_1 = None
5731    true_graph_0 = self.true_graph_0
5732    false_graph_0 = self.false_graph_0
5733    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [p_bar_linear_bias, p_bar_linear_weight, x]);  gt = true_graph_0 = false_graph_0 = p_bar_linear_bias = p_bar_linear_weight = x = None
5734    getitem = cond[0];  cond = None
5735    add = torch.ops.aten.add.Tensor(cos, getitem);  cos = getitem = None
5736    return (add,)""",
5737        )
5738        schema = get_hop_schema(ep)
5739        self.assertExpectedInline(
5740            str(schema),
5741            """cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]""",
5742        )
5743
5744        cond_top_level_nn_module_stack = [
5745            node.meta["nn_module_stack"]
5746            for node in ep.graph.nodes
5747            if node.name == "true_graph_0"
5748        ][0]
5749
5750        self.assertTrue(
5751            "test_cond_with_module_stack_export_with.<locals>.Bar"
5752            in str(cond_top_level_nn_module_stack)
5753        )
5754
5755    # TODO: See https://github.com/pytorch/pytorch/issues/115790
5756    @unittest.expectedFailure
5757    def test_cond_with_module_stack_export_with_unflatten(self):
5758        class Bar(torch.nn.Module):
5759            def __init__(self) -> None:
5760                super().__init__()
5761                self.linear = torch.nn.Linear(4, 4)
5762
5763            def forward(self, x):
5764                def true_fn(x):
5765                    return self.linear(x).cos()
5766
5767                def false_fn(x):
5768                    return self.linear(x).sin()
5769
5770                return torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
5771
5772        class CondExport(torch.nn.Module):
5773            def __init__(self) -> None:
5774                super().__init__()
5775                self.bar = Bar()
5776
5777            def forward(self, x):
5778                return x.cos() + self.bar(x)
5779
5780        inp = (torch.randn(4, 4),)
5781        ep = torch.export.export(CondExport(), inp, strict=False)
5782
5783        cond_top_level_nn_module_stack = [
5784            node.meta["nn_module_stack"]
5785            for node in ep.graph.nodes
5786            if node.name == "true_graph_0"
5787        ][0]
5788
5789        # we can't preserve nn_module_stack for the subgraphs for now.
5790        for node in ep.graph_module.true_graph_0.graph.nodes:
5791            self.assertEqual(
5792                node.meta["nn_module_stack"], cond_top_level_nn_module_stack
5793            )
5794
5795        # this doesn't work today
5796        gm_unflat_strict = unflatten(ep)
5797
5798    def test_predispatch_cond(self):
5799        class Model(torch.nn.Module):
5800            def __init__(self) -> None:
5801                super().__init__()
5802                self.pred = torch.nn.Buffer(torch.tensor(False))
5803                self.t = torch.nn.Buffer(torch.tensor(10))
5804
5805            def forward(self, x, y):
5806                def true_fn(x, y):
5807                    with torch.enable_grad():
5808                        return x - 1 + self.t + y
5809
5810                return torch.cond(
5811                    self.pred,
5812                    true_fn,
5813                    lambda x, y: x + 1 - self.t + y,
5814                    [x, y],
5815                )
5816
5817        model = Model()
5818        with torch.no_grad():
5819            exported_program = torch.export._trace._export(
5820                model,
5821                (torch.tensor(10), torch.tensor(12)),
5822                {},
5823                dynamic_shapes=None,
5824                pre_dispatch=True,
5825                strict=False,
5826            )
5827
5828        schema = get_hop_schema(exported_program)
5829        self.assertExpectedInline(
5830            str(schema),
5831            """cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]""",  # noqa: B950
5832        )
5833
5834        self.assertExpectedInline(
5835            str(exported_program.graph_module.code.strip()),
5836            """\
5837def forward(self, b_pred, b_t, x, y):
5838    true_graph_0 = self.true_graph_0
5839    false_graph_0 = self.false_graph_0
5840    cond = torch.ops.higher_order.cond(b_pred, true_graph_0, false_graph_0, [b_t, x, y]);  b_pred = true_graph_0 = false_graph_0 = b_t = x = y = None
5841    getitem = cond[0];  cond = None
5842    return (getitem,)""",
5843        )  # noqa: B950
5844
5845        self.assertExpectedInline(
5846            str(exported_program.graph_module.true_graph_0.code.strip()),
5847            """\
5848def forward(self, b_t, x, y):
5849    submod_3 = self.submod_1
5850    add_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_3, x, b_t, y);  submod_3 = x = b_t = y = None
5851    getitem = add_1[0];  add_1 = None
5852    return (getitem,)""",
5853        )
5854
5855        self.assertExpectedInline(
5856            str(exported_program.graph_module.true_graph_0.submod_1.code.strip()),
5857            """\
5858def forward(self, x, b_t, y):
5859    sub = torch.ops.aten.sub.Tensor(x, 1);  x = None
5860    add = torch.ops.aten.add.Tensor(sub, b_t);  sub = b_t = None
5861    add_1 = torch.ops.aten.add.Tensor(add, y);  add = y = None
5862    return (add_1,)""",
5863        )
5864
5865    def test_predispatch_grad_wrappers(self):
5866        class Model(torch.nn.Module):
5867            def forward(self, x, y):
5868                with torch.enable_grad():
5869                    x = x - y
5870                with torch.no_grad():
5871                    x = x + y
5872                return x
5873
5874        # no grad
5875        model = Model()
5876        with torch.no_grad():
5877            ep_nograd = torch.export._trace._export(
5878                model,
5879                (torch.tensor(10), torch.tensor(12)),
5880                {},
5881                dynamic_shapes=None,
5882                pre_dispatch=True,
5883                strict=False,
5884            )
5885        # check that only sub op is wrapped with grad_enabled
5886        getattr_nodes = [
5887            node for node in ep_nograd.graph.nodes if node.op == "get_attr"
5888        ]
5889        self.assertEqual(len(getattr_nodes), 1)
5890        grad_subgraph = getattr(ep_nograd.graph_module, getattr_nodes[0].target)
5891        op_node = [
5892            node for node in grad_subgraph.graph.nodes if node.op == "call_function"
5893        ][0]
5894        self.assertEqual(op_node.target._name, "aten::sub.Tensor")
5895
5896        # enable grad
5897        model = Model()
5898        ep_grad = torch.export._trace._export(
5899            model,
5900            (torch.tensor(10), torch.tensor(12)),
5901            {},
5902            dynamic_shapes=None,
5903            pre_dispatch=True,
5904            strict=False,
5905        )
5906        # check that only add op is wrapped with grad_enabled
5907        getattr_nodes = [node for node in ep_grad.graph.nodes if node.op == "get_attr"]
5908        self.assertEqual(len(getattr_nodes), 1)
5909        grad_subgraph = getattr(ep_grad.graph_module, getattr_nodes[0].target)
5910        op_node = [
5911            node for node in grad_subgraph.graph.nodes if node.op == "call_function"
5912        ][0]
5913        self.assertEqual(op_node.target._name, "aten::add.Tensor")
5914
5915    @testing.expectedFailureRetraceability
5916    def test_layer_sharing(self):
5917        N, C, H, W = 1, 2, 2, 3
5918
5919        class Module(torch.nn.Module):
5920            def __init__(self) -> None:
5921                super().__init__()
5922                layer = torch.nn.LayerNorm([C, H, W])
5923                self.norms = torch.nn.ModuleList(
5924                    [
5925                        layer,
5926                        layer,
5927                    ]
5928                )
5929
5930            def forward(self, x):
5931                for norm in self.norms:
5932                    x = norm(x)
5933                return x
5934
5935        m = Module()
5936        copied_m = copy.deepcopy(m)
5937        ep = export(copied_m, (torch.randn(N, C, H, W),))
5938        self.assertEqual(copied_m.state_dict(), m.state_dict())
5939        self.assertEqual(ep.state_dict, m.state_dict())
5940
5941    def test_non_persistent_buffer(self):
5942        class MyModule(torch.nn.Module):
5943            def __init__(self) -> None:
5944                super().__init__()
5945                self.foo = torch.nn.Buffer(torch.rand(2, 3), persistent=False)
5946
5947            def forward(self, x):
5948                return self.foo + x
5949
5950        class MyOuterModule(torch.nn.Module):
5951            def __init__(self) -> None:
5952                super().__init__()
5953                self.inner = MyModule()
5954
5955            def forward(self, x):
5956                return self.inner(x)
5957
5958        inp = torch.rand(2, 3)
5959
5960        def _test(m, non_persistent_buffer):
5961            ep = export(m, (inp,), {})
5962
5963            self.assertEqual(ep.module()(inp), m(inp))
5964            # Non-persistent buffers should not show up in the state dict
5965            self.assertNotIn(non_persistent_buffer, ep.state_dict)
5966            named_buffers = {name: buffer for (name, buffer) in ep.named_buffers()}
5967            # But they should show up in named_buffers()
5968            self.assertIn(non_persistent_buffer, named_buffers)
5969            self.assertIn(non_persistent_buffer, ep.constants)
5970            self.assertEqual(len(ep.constants), 1)
5971
5972            # Check the same properties of the unlifted module
5973            mod = ep.module()
5974            self.assertNotIn(non_persistent_buffer, mod.state_dict())
5975            mod_named_buffers = {name: buffer for (name, buffer) in mod.named_buffers()}
5976            self.assertIn(non_persistent_buffer, mod_named_buffers)
5977            self.assertIn(non_persistent_buffer, ep.constants)
5978            self.assertEqual(len(ep.constants), 1)
5979            self.assertEqual(mod(inp), m(inp))
5980
5981        _test(MyModule(), "foo")
5982        _test(MyOuterModule(), "inner.foo")
5983
5984    def test_export_with_set_grad_enabled(self):
5985        class Model(torch.nn.Module):
5986            def __init__(self) -> None:
5987                super().__init__()
5988                self.linear = torch.nn.Linear(4, 4)
5989
5990            def forward(self, x):
5991                with torch.no_grad():
5992                    return self.linear(x)
5993
5994        model = Model()
5995        ep = export(model, (torch.randn(4, 4),), {})
5996        # _export_for_traininig is using pre_dispatch=False
5997        # Therefore the set_grad calls are not replaced with a hop.
5998        if not is_training_ir_test(self._testMethodName):
5999            self.assertIn(
6000                "torch.ops.higher_order.wrap_with_set_grad_enabled",
6001                ep.graph_module.code,
6002            )
6003
6004    def test_export_as_backend(self):
6005        def f(x, y):
6006            return x + y
6007
6008        def my_custom_backend(gm, example_inputs):
6009            gm = (
6010                torch.export.export(gm, tuple(example_inputs), strict=False)
6011                .run_decompositions()
6012                .module()
6013            )
6014            return gm
6015
6016        inp = (torch.randn(3, 3), torch.randn(3, 3))
6017        new_res = torch.compile(f, backend=my_custom_backend)(*inp)
6018        self.assertTrue(torch.allclose(f(*inp), new_res))
6019
6020    def test_nonstrict_retrace_preserves_metadata(self):
6021        class MyModule(torch.nn.Module):
6022            def __init__(self) -> None:
6023                super().__init__()
6024                self.linear = torch.nn.Linear(4, 4)
6025
6026            def forward(self, x):
6027                return self.linear(x)
6028
6029        inp = torch.randn(4, 4)
6030        m = MyModule()
6031        ep = torch.export.export(m, (inp,), {}, strict=False)
6032        # retrace
6033        ep2 = torch.export.export(ep.module(), (inp,), {}, strict=False)
6034
6035        for n1, n2 in zip(list(ep.graph.nodes), list(ep2.graph.nodes)):
6036            self.assertEqual(n1.meta.get("stack_trace"), n2.meta.get("stack_trace"))
6037
6038    def test_fake_weights(self):
6039        class MyModule(torch.nn.Module):
6040            def __init__(self) -> None:
6041                super().__init__()
6042                self.foo = torch.nn.Parameter(torch.randn(4, 4))
6043                self.bar = torch.nn.Buffer(torch.randn(4, 4), persistent=False)
6044                self.baz = torch.nn.Buffer(torch.randn(4, 4), persistent=True)
6045
6046            def forward(self, x):
6047                return self.foo + x + self.bar + self.baz
6048
6049        fake_mode = torch._subclasses.FakeTensorMode(
6050            shape_env=ShapeEnv(tracked_fakes=[])
6051        )
6052        with fake_mode:
6053            m = MyModule()
6054        inp = torch.randn(4, 4)
6055        ep = export(m, (inp,))
6056        # Can't compare outputs because the module has fake weights.
6057
6058    def test_fake_inputs(self):
6059        class MyModule(torch.nn.Module):
6060            def __init__(self) -> None:
6061                super().__init__()
6062                self.foo = torch.nn.Parameter(torch.randn(4, 4))
6063
6064            def forward(self, x):
6065                return self.foo + x
6066
6067        fake_mode = torch._subclasses.FakeTensorMode(
6068            shape_env=ShapeEnv(tracked_fakes=[])
6069        )
6070        m = MyModule()
6071        with fake_mode:
6072            inp = torch.randn(4, 4)
6073
6074        ep = export(m, (inp,))
6075        self.assertEqual(ep.module()(torch.ones(4, 4)), m(torch.ones(4, 4)))
6076
6077    def test_trace_under_fake(self):
6078        class MyModule(torch.nn.Module):
6079            def __init__(self) -> None:
6080                super().__init__()
6081                self.foo = torch.nn.Parameter(torch.randn(4, 4))
6082
6083            def forward(self, x):
6084                return self.foo + x
6085
6086        fake_mode = torch._subclasses.FakeTensorMode(
6087            shape_env=ShapeEnv(tracked_fakes=[])
6088        )
6089        with fake_mode:
6090            m = MyModule()
6091            inp = torch.randn(4, 4)
6092            # Can't use unqualified export() as it will attempt to deserialize
6093            # under a new FakeTensorMode.
6094            ep = torch.export.export(m, (inp,))
6095
6096    def test_compiling_state(self):
6097        class TestModule1(torch.nn.Module):
6098            def forward(self, x):
6099                if torch._dynamo.is_compiling():
6100                    return x * 2
6101                else:
6102                    return x * 3
6103
6104        class TestModule2(torch.nn.Module):
6105            def forward(self, x):
6106                if torch._utils.is_compiling():
6107                    return x * 2
6108                else:
6109                    return x * 3
6110
6111        class TestModule3(torch.nn.Module):
6112            def forward(self, x):
6113                if torch.compiler.is_compiling():
6114                    return x * 2
6115                else:
6116                    return x * 3
6117
6118        for m in [TestModule1(), TestModule2(), TestModule3()]:
6119            input = torch.randn(5)
6120            ep_strict = export(m, (input,), strict=True)
6121            ep_non_strict = export(m, (input,), strict=False)
6122
6123            self.assertTrue(torch.allclose(input * 3, m(input)))
6124            self.assertTrue(torch.allclose(input * 2, ep_strict.module()(input)))
6125            self.assertTrue(torch.allclose(input * 2, ep_non_strict.module()(input)))
6126
6127    def test_user_input_and_buffer_mutation(self):
6128        class MyModule(torch.nn.Module):
6129            def __init__(self) -> None:
6130                super().__init__()
6131                self.foo = torch.nn.Buffer(torch.randn(4, 4))
6132
6133            def forward(self, x):
6134                self.foo.add_(1)
6135                x.add_(1)
6136                return self.foo + x
6137
6138        mod = MyModule()
6139        mod_copy = copy.deepcopy(mod)
6140        ep = export(mod_copy, (torch.rand(4, 4),))
6141
6142        self.assertEqual(mod.foo, ep.module().foo)
6143        self.assertEqual(mod(torch.ones(4, 4)), ep.module()(torch.ones(4, 4)))
6144
6145    def test_symint_tensor_return(self):
6146        class Module(torch.nn.Module):
6147            def forward(self, x):
6148                return torch.ops.testlib.returns_tensor_symint(x)[0]
6149
6150        self._test_export_same_as_eager(Module(), (torch.randn(4, 4),))
6151
6152    def test_custom_op_auto_functionalize(self):
6153        class M(torch.nn.Module):
6154            def __init__(self) -> None:
6155                super().__init__()
6156
6157            def forward(self, x, z):
6158                return torch.ops.testlib.foo(x, z)
6159
6160        inps = (torch.ones(5), torch.ones(5))
6161        inps_for_export = (torch.ones(5), torch.ones(5))
6162        inps_for_export_with_decomp = (torch.ones(5), torch.ones(5))
6163
6164        ep = torch.export.export(M(), inps_for_export)
6165        x_new_eager, z_new_eager, legit_eager = M()(*inps)
6166        x_new_export, z_new_export, legit_export = ep.module()(*inps_for_export)
6167        self.assertTrue(torch.allclose(x_new_eager, x_new_export))
6168        self.assertTrue(torch.allclose(z_new_eager, z_new_export))
6169        self.assertTrue(torch.allclose(legit_eager, legit_export))
6170
6171        ep = ep.run_decompositions()
6172        x_new_export, z_new_export, legit_export = ep.module()(
6173            *inps_for_export_with_decomp
6174        )
6175        self.assertTrue(torch.allclose(x_new_eager, x_new_export))
6176        self.assertTrue(torch.allclose(z_new_eager, z_new_export))
6177        self.assertTrue(torch.allclose(legit_eager, legit_export))
6178
6179    def test_custom_op_auto_functionalize_pre_dispatch(self):
6180        class M(torch.nn.Module):
6181            def __init__(self) -> None:
6182                super().__init__()
6183
6184            def forward(self, x):
6185                return torch.ops.testlib.foo_mutated(x)
6186
6187        inps = (torch.ones(5),)
6188
6189        ep = torch.export.export(M(), inps)
6190        self.assertExpectedInline(
6191            str(ep.graph_module.code.strip()),
6192            """\
6193def forward(self, x):
6194    cos = torch.ops.aten.cos.default(x)
6195    auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = x, z = cos);  x = cos = None
6196    getitem_3 = auto_functionalized[3];  auto_functionalized = None
6197    cos_1 = torch.ops.aten.cos.default(getitem_3)
6198    return (getitem_3, getitem_3, cos_1)""",
6199        )
6200
6201        ep = torch.export._trace._export(M(), inps, pre_dispatch=True)
6202        self.assertExpectedInline(
6203            str(ep.graph_module.code.strip()),
6204            """\
6205def forward(self, x):
6206    cos = torch.ops.aten.cos.default(x)
6207    auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = x, z = cos);  x = cos = None
6208    getitem_3 = auto_functionalized[3];  auto_functionalized = None
6209    cos_1 = torch.ops.aten.cos.default(getitem_3)
6210    return (getitem_3, getitem_3, cos_1)""",
6211        )
6212
6213    def test_custom_op_auto_warn_pre_dispatch(self):
6214        class M(torch.nn.Module):
6215            def __init__(self) -> None:
6216                super().__init__()
6217
6218            def forward(self, x):
6219                return torch.ops.testlib.foo_functional(x)
6220
6221        inps = (torch.ones(5),)
6222
6223        ep = torch.export.export(M(), inps).run_decompositions()
6224        self.assertExpectedInline(
6225            str(ep.graph_module.code.strip()),
6226            """\
6227def forward(self, x):
6228    cos = torch.ops.aten.cos.default(x)
6229    cos_1 = torch.ops.aten.cos.default(x);  x = None
6230    auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = cos, z = cos_1);  cos = cos_1 = None
6231    getitem_3 = auto_functionalized[3];  auto_functionalized = None
6232    cos_2 = torch.ops.aten.cos.default(getitem_3);  getitem_3 = None
6233    return (cos_2,)""",
6234        )
6235
6236        ep = torch.export._trace._export(M(), inps, pre_dispatch=True)
6237        self.assertExpectedInline(
6238            str(ep.graph_module.code.strip()),
6239            """\
6240def forward(self, x):
6241    foo_functional = torch.ops.testlib.foo_functional.default(x);  x = None
6242    return (foo_functional,)""",
6243        )
6244
6245    def test_placeholder_naming_collisions(self):
6246        # test collisions between nested user inputs
6247        class Foo(torch.nn.Module):
6248            def forward(self, x, x_foo, x_foo_0):
6249                return x["foo"][0] + x_foo[0] + x_foo_0
6250
6251        inputs = (
6252            {"foo": [torch.randn(4, 4)]},
6253            (torch.randn(4, 4),),
6254            torch.randn(4, 4),
6255        )
6256        ep = export(Foo(), inputs)
6257        expected_names = ["x_foo_0", "x_foo_0_1", "x_foo_0_2"]
6258        real_names = [spec.arg.name for spec in ep.graph_signature.input_specs]
6259        self.assertEqual(expected_names, real_names)
6260
6261        # test collisions between user inputs and params, buffers, constants
6262        class Foo(torch.nn.Module):
6263            def __init__(self) -> None:
6264                super().__init__()
6265                self.param = torch.nn.Parameter(torch.randn(4))
6266                self.alpha = torch.nn.Buffer(torch.randn(4), persistent=True)
6267                self.beta = torch.nn.Buffer(torch.randn(4), persistent=False)
6268                self.gamma = torch.randn(4)
6269
6270            def forward(self, p, b_alpha, b, c_gamma):
6271                p = p["param"] + self.param
6272                b = self.alpha + self.beta + b_alpha + b["beta"]
6273                c = self.gamma + c_gamma
6274                return p, b, c
6275
6276        inputs = (
6277            {"param": torch.randn(4)},
6278            torch.randn(4),
6279            {"beta": torch.randn(4)},
6280            torch.randn(4),
6281        )
6282        ep = export(Foo(), inputs)
6283        expected_names = [  # user inputs should be prioritized, unprefixed
6284            ("p_param_1", InputKind.PARAMETER),
6285            ("b_alpha_1", InputKind.BUFFER),
6286            ("b_beta_1", InputKind.BUFFER),
6287            ("c_gamma_1", InputKind.CONSTANT_TENSOR),
6288            ("p_param", InputKind.USER_INPUT),
6289            ("b_alpha", InputKind.USER_INPUT),
6290            ("b_beta", InputKind.USER_INPUT),
6291            ("c_gamma", InputKind.USER_INPUT),
6292        ]
6293        real_names = [
6294            (spec.arg.name, spec.kind) for spec in ep.graph_signature.input_specs
6295        ]
6296        self.assertEqual(expected_names, real_names)
6297
6298        # test collisions between user inputs & call_function nodes
6299        class Foo(torch.nn.Module):
6300            def forward(self, mul, add, add_1):
6301                return mul * mul + add * add_1
6302
6303        ep = export(Foo(), (torch.randn(4, 4), torch.randn(4, 4), torch.randn(4, 4)))
6304        expected_names_and_ops = [
6305            ("mul", "placeholder"),
6306            ("add", "placeholder"),
6307            ("add_1", "placeholder"),
6308            ("mul_1", "call_function"),
6309            ("mul_2", "call_function"),
6310            ("add_2", "call_function"),
6311            ("output", "output"),
6312        ]
6313        real_names_and_ops = [(node.name, node.op) for node in ep.graph.nodes]
6314        self.assertEqual(expected_names_and_ops, real_names_and_ops)
6315
6316    def test_placeholder_naming_collisions_hoo_subgraphs(self):
6317        # test collisions between user inputs, top-level nodes, and HOO subgraph nodes
6318        class Foo(torch.nn.Module):
6319            def forward(self, x, mul, mul_1):
6320                _mul = x * x
6321                y = cond(
6322                    _mul.sum() > 0,
6323                    lambda x, y, z: x * y * z,
6324                    lambda x, y, z: x + y + z,
6325                    [_mul, mul, mul_1],
6326                )
6327                with torch.enable_grad():
6328                    y = y * y
6329                return y
6330
6331        with torch.no_grad():
6332            ep = torch.export._trace._export(
6333                Foo(),
6334                (torch.randn(4), torch.randn(4), torch.randn(4)),
6335                pre_dispatch=True,
6336            )
6337
6338        schema = get_hop_schema(ep)
6339        self.assertExpectedInline(
6340            str(schema),
6341            """cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]""",
6342        )
6343        # test cond subgraph
6344        expected_names_and_ops = [
6345            ("mul_2", "placeholder"),
6346            ("mul", "placeholder"),
6347            ("mul_1", "placeholder"),
6348            ("mul_3", "call_function"),
6349            ("mul_4", "call_function"),
6350            ("output", "output"),
6351        ]
6352        real_names_and_ops = [
6353            (node.name, node.op) for node in ep.graph_module.true_graph_0.graph.nodes
6354        ]
6355        self.assertEqual(expected_names_and_ops, real_names_and_ops)
6356        # test set_grad_enabled subgraph
6357        expected_names_and_ops = [
6358            ("getitem", "placeholder"),
6359            ("mul_1", "call_function"),
6360            ("output", "output"),
6361        ]
6362        real_names_and_ops = [
6363            (node.name, node.op) for node in ep.graph_module.submod_1.graph.nodes
6364        ]
6365        self.assertEqual(expected_names_and_ops, real_names_and_ops)
6366
6367        # test collisions between user inputs & higher order op subgraphs
6368        # (please never do this)
6369        class Foo(torch.nn.Module):
6370            def forward(self, input, true_graph, body_graph):
6371                x = input + true_graph[0] + true_graph[1]
6372                x = cond(x.sum() > 0, lambda x: x * 2.0, lambda x: x + 2.0, [x])
6373                x = cond(x.sum() > 0, lambda x: x * 2.0, lambda x: x + 2.0, [x])
6374                return x
6375
6376        inputs = (
6377            torch.randn(10, 4),
6378            (torch.randn(4), torch.randn(4)),
6379            (torch.randn(4),),
6380        )
6381        ep = export(Foo(), inputs)
6382        expected_getattr_names = [
6383            "true_graph_2",
6384            "false_graph_0",
6385            "true_graph_3",
6386            "false_graph_1",
6387        ]
6388        real_getattr_names = [
6389            node.name for node in ep.graph.nodes if node.op == "get_attr"
6390        ]
6391        self.assertEqual(expected_getattr_names, real_getattr_names)
6392
6393    def test_constant_input_naming(self):
6394        class Foo(torch.nn.Module):
6395            def forward(self, x, y, div="floor"):
6396                return torch.div(x, y, rounding_mode=div)
6397
6398        f = Foo()
6399        inputs = (torch.randn(4), torch.randn(4), "floor")
6400        ep = export(f, inputs)
6401        div_spec = ep.graph_signature.input_specs[2]
6402        self.assertEqual(div_spec.arg.name, "div")
6403        self.assertEqual(div_spec.arg.value, "floor")
6404
6405    def test_unbacked_deferred_runtime_retrace(self):
6406        class Foo(torch.nn.Module):
6407            def forward(self, x, y):
6408                y_sum = y.sin().sum()
6409                with torch.no_grad():
6410                    a = x.item()
6411                    torch._check_is_size(a)
6412                    torch._check(a > 2)
6413                    torch._check(a < 6)
6414                    unbacked_shape = torch.ops.testlib.foo_unbacked(a)
6415                return y + y_sum + unbacked_shape.sum()
6416
6417        inps = (torch.tensor(4), torch.randn(5, 5))
6418        from torch.export import _trace
6419
6420        ep_pre = _trace._export(Foo(), inps, pre_dispatch=True, strict=False)
6421        self.assertExpectedInline(
6422            str(ep_pre.graph_module.submod_1.code).strip(),
6423            """\
6424def forward(self, x):
6425    item = torch.ops.aten.item.default(x);  x = None
6426    sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item);  sym_constrain_range_for_size_default = None
6427    ge_1 = item >= 3
6428    _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 3 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
6429    le = item <= 5
6430    _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u1 <= 5 on node 'le'");  le = _assert_scalar_default_1 = None
6431    gt_1 = item > 2
6432    _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 2 < u1 on node 'gt_1'");  gt_1 = _assert_scalar_default_2 = None
6433    lt_1 = item < 6
6434    _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u1 < 6 on node 'lt_1'");  lt_1 = _assert_scalar_default_3 = None
6435    foo_unbacked = torch.ops.testlib.foo_unbacked.default(item);  item = None
6436    return (foo_unbacked,)""",
6437        )
6438        ep_aot = ep_pre.run_decompositions()
6439        self.assertExpectedInline(
6440            str(ep_aot.graph_module.code).strip(),
6441            """\
6442def forward(self, x, y):
6443    sin = torch.ops.aten.sin.default(y)
6444    sum_1 = torch.ops.aten.sum.dim_IntList(sin, []);  sin = None
6445    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x);  x = None
6446    sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense);  sym_constrain_range_for_size_default = None
6447    ge_1 = _local_scalar_dense >= 3
6448    _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u3 >= 3 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
6449    le_1 = _local_scalar_dense <= 5;  _local_scalar_dense = None
6450    _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u3 <= 5 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None
6451    full = torch.ops.aten.full.default([4, 4], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
6452    add = torch.ops.aten.add.Tensor(y, sum_1);  y = sum_1 = None
6453    sum_2 = torch.ops.aten.sum.dim_IntList(full, []);  full = None
6454    add_1 = torch.ops.aten.add.Tensor(add, sum_2);  add = sum_2 = None
6455    return (add_1,)""",
6456        )
6457
6458    def test_nested_dynamic_shapes_spec(self):
6459        class Foo(torch.nn.Module):
6460            def forward(self, x):
6461                (a0, a1), (b0, b1), (c0, c1, c2) = x
6462                return a0 + a1 + b0 + b1 + c0 + c1 + c2
6463
6464        f = Foo()
6465        inputs = (
6466            (1, 2),
6467            (
6468                torch.randn(4, 4),
6469                torch.randn(4, 4),
6470            ),
6471            (
6472                torch.randn(4, 4),
6473                torch.randn(4, 4),
6474                torch.randn(4, 4),
6475            ),
6476        )
6477        # make sure this gets parsed correctly as 7 individual inputs, not 3 tensors
6478        dynamic_shapes = {
6479            "x": (
6480                (None, None),
6481                (None, None),
6482                (None, None, None),
6483            )
6484        }
6485        export(f, (inputs,), dynamic_shapes=dynamic_shapes)
6486
6487    def test_disable_forced_specializations_ok(self):
6488        # check that we don't force specialization, and defer to runtime asserts
6489        # with allow_complex_guards_as_runtime_asserts=True to successfully export
6490        # case 1: modulo guards
6491        from torch.export import dims
6492
6493        class Mod4Reshape(torch.nn.Module):
6494            def forward(self, x):
6495                return x.reshape(x.shape[0] - 1, 4, -1)  # Mod(s0*s1, 4*(s0-1)) = 0
6496
6497        inputs = (torch.randn(10, 72),)
6498        dx, dy = dims("dx", "dy")
6499        ep = torch.export._trace._export(
6500            Mod4Reshape(),
6501            inputs,
6502            dynamic_shapes={"x": (dx, dy)},
6503            allow_complex_guards_as_runtime_asserts=True,
6504        )
6505        out1 = ep.module()(torch.randn(8, 7))
6506        self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape)
6507        out2 = ep.module()(torch.randn(12, 11))
6508        self.assertEqual(out2.shape, torch.ones(11, 4, 3).shape)
6509        with self.assertRaisesRegex(
6510            RuntimeError,
6511            r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, 4\*s0 \- 4\), 0\) on node 'eq.*'",
6512        ):
6513            ep.module()(torch.randn(8, 8))  # fail
6514
6515        # case 2: 2d reshape
6516        class FreeReshape(torch.nn.Module):
6517            def forward(self, x, y, z):
6518                return x.reshape([-1]) + y.reshape([-1]) + z  # s0*s1 = s2*s3 = s4
6519
6520        inputs = (
6521            torch.randn(6, 8),
6522            torch.randn(3, 16),
6523            torch.randn(48),
6524        )
6525        dynamic_shapes = {
6526            "x": [Dim(f"dx{i}", min=2) for i in range(2)],
6527            "y": [Dim(f"dy{i}", min=2) for i in range(2)],
6528            "z": [Dim(f"dz{i}", min=4) for i in range(1)],
6529        }
6530        ep = torch.export._trace._export(
6531            FreeReshape(),
6532            inputs,
6533            dynamic_shapes=dynamic_shapes,
6534            allow_complex_guards_as_runtime_asserts=True,
6535        )
6536        ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes)
6537        out1 = ep.module()(torch.randn(48, 1), torch.randn(4, 12), torch.randn(48))
6538        self.assertEqual(out1.shape, torch.ones(48).shape)
6539        out2 = ep.module()(torch.randn(5, 8), torch.randn(4, 10), torch.randn(40))
6540        self.assertEqual(out2.shape, torch.ones(40).shape)
6541        with self.assertRaisesRegex(
6542            RuntimeError,
6543            r"Runtime assertion failed for expression Eq\(s0\*s1, s2\*s3\) on node 'eq.*'",
6544        ):  # fail only at runtime
6545            ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30))  # fail
6546
6547        # case 3: 3d reshape (previously failing with different issue)
6548        class Reshape3d(torch.nn.Module):
6549            def forward(self, x, y):
6550                return x.reshape([-1]) + y  # s0*s1*s2 = s3
6551
6552        inputs = (
6553            torch.randn(4, 3, 2),
6554            torch.randn(24),
6555        )
6556        dynamic_shapes = {
6557            "x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)),
6558            "y": (Dim("dy", min=8),),
6559        }
6560        ep = torch.export._trace._export(
6561            Reshape3d(),
6562            inputs,
6563            dynamic_shapes=dynamic_shapes,
6564            allow_complex_guards_as_runtime_asserts=True,
6565        )
6566        out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126))
6567        self.assertEqual(out1.shape, torch.ones(126).shape)
6568        with self.assertRaisesRegex(
6569            RuntimeError,
6570            r"Runtime assertion failed for expression Eq\(s0\*s1\*s2, s3\) on node 'eq.*'",
6571        ):  # fail only at runtime
6572            ep.module()(torch.randn(4, 3, 2), torch.randn(10))  # fail
6573
6574    def test_disable_forced_specializations_errors(self):
6575        # check error messages with hybrid symints
6576        class Foo(torch.nn.Module):
6577            def forward(self, w, x, y, z):
6578                return w.reshape([-1]) + x, y + z  # simple: s0*s1 = s2, s3 = s4
6579
6580        inputs = (
6581            torch.randn(3, 4),
6582            torch.randn(12),
6583            torch.randn(4),
6584            torch.randn(4),
6585        )
6586        dynamic_shapes = {
6587            "w": [Dim(f"dw{i}") for i in range(2)],
6588            "x": [Dim(f"dx{i}") for i in range(1)],
6589            "y": [Dim("dy")],  # y & z incorrect, export is supposed to fail.
6590            "z": [Dim("dz")],  # suggested fix should be to match these up.
6591        }
6592        with self.assertRaisesRegex(  # if disable=True, suggested fixes should not specialize.
6593            torch._dynamo.exc.UserError,
6594            r".*Constraints violated(.*\n)*"
6595            r"Suggested fixes:(.*\n)*"
6596            r".*dz = dy(.*\n)*",
6597        ) as msg:
6598            export(
6599                Foo(),
6600                inputs,
6601                dynamic_shapes=dynamic_shapes,
6602                strict=False,
6603            )
6604
6605    # TODO requires_grad doesn't seem to work with serialization.
6606    @testing.expectedFailureSerDer
6607    def test_preserve_requires_grad_placeholders(self):
6608        class Module(torch.nn.Module):
6609            def __init__(self) -> None:
6610                super().__init__()
6611                self.p = torch.nn.Parameter(torch.randn(3, 3))
6612
6613            def forward(self, x, y):
6614                return self.p + x + y
6615
6616        m = Module()
6617        ep = export(m, (torch.randn(3, 3), torch.randn(3, 3, requires_grad=True)))
6618        placeholders = [
6619            node for node in ep.graph_module.graph.nodes if node.op == "placeholder"
6620        ]
6621        self.assertTrue(placeholders[0].meta["val"].requires_grad)
6622        self.assertFalse(placeholders[1].meta["val"].requires_grad)
6623        self.assertTrue(placeholders[2].meta["val"].requires_grad)
6624
6625    def test_reshape_view_helper(self):
6626        # see: https://github.com/pytorch/pytorch/issues/126607
6627        class Model(torch.nn.Module):
6628            def __init__(self) -> None:
6629                super().__init__()
6630
6631            def forward(self, x):
6632                x = x.view(x.size(1), -1)
6633                # torch/_refs/__init__/_reshape_view_helper() will generate guards on reshape kernel(?)
6634                # Ne(s0, 20), so that reshape isn't no-op
6635                # Ne(Mod(s0, 20), 0), so that reshape needs to first flatten [s0, 20, 16] -> [s0*20, 16]
6636                # then split_dim -> [20, s0, 16]
6637                # check that these show up in graph
6638                return torch.nn.functional.softmax(
6639                    x, dim=0
6640                )  # don't think softmax actually creates any issues, just part of original test
6641
6642        model = Model()
6643        x = torch.rand(1024, 20, 16)
6644        dynamic_shapes = {"x": {0: Dim("batch")}}
6645        ep = torch.export._trace._export(
6646            model,
6647            (x,),
6648            dynamic_shapes=dynamic_shapes,
6649            allow_complex_guards_as_runtime_asserts=True,
6650        )
6651        with self.assertRaisesRegex(
6652            RuntimeError,
6653            r"Runtime assertion failed for expression Ne\(s0, 20\)",
6654        ):
6655            ep.module()(torch.randn(20, 20, 16))
6656        with self.assertRaisesRegex(
6657            RuntimeError,
6658            r"Runtime assertion failed for expression Ne\(Mod\(s0, 20\), 0\)",
6659        ):
6660            ep.module()(torch.randn(400, 20, 16))
6661        ep.module()(torch.randn(42, 20, 16))
6662
6663    def test_allow_explicit_guards_as_runtime_asserts(self):
6664        # check that explicit guards are treated as runtime assertions
6665        class Foo(torch.nn.Module):
6666            def forward(self, x, y):
6667                # check that negation of first guard also shows up as runtime assertion
6668                if x.shape[0] == y.shape[0]:  # False
6669                    return x + y
6670                elif x.shape[0] == y.shape[0] ** 3:  # False
6671                    return x + 2, y + 3
6672                elif x.shape[0] ** 2 == y.shape[0] * 3:  # True
6673                    return x * 2.0, y * 3.0
6674
6675        inputs = (torch.randn(6), torch.randn(12))
6676        dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]}
6677        ep = torch.export._trace._export(
6678            Foo(),
6679            inputs,
6680            dynamic_shapes=dynamic_shapes,
6681            allow_complex_guards_as_runtime_asserts=True,
6682        )
6683        # check forward pass
6684        out0, out1 = ep.module()(torch.randn(9), torch.randn(27))
6685        self.assertEqual(out0.shape, torch.ones(9).shape)
6686        self.assertEqual(out1.shape, torch.ones(27).shape)
6687        with self.assertRaisesRegex(
6688            RuntimeError,
6689            r"Runtime assertion failed for expression Ne\(s0, s1\)",
6690        ):  # fail only at runtime
6691            ep.module()(torch.randn(4), torch.randn(4))  # fail
6692        with self.assertRaisesRegex(
6693            RuntimeError,
6694            r"Runtime assertion failed for expression Ne\(s0, s1\**3\)",
6695        ):
6696            ep.module()(torch.randn(64), torch.randn(4))  # fail
6697        with self.assertRaisesRegex(
6698            RuntimeError,
6699            r"Runtime assertion failed for expression Eq\(s0\**2, 3\*s1\)",
6700        ):
6701            ep.module()(torch.randn(10), torch.randn(9))  # fail
6702
6703        # this should be set with command line flag TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1,
6704        # but dynamo checks that at torch import time, so setting os.environ makes no difference
6705        # instead, manually patch dynamo config and test.
6706        # test that setting this flag removes runtime asserts
6707        from torch._dynamo import config as _dynamo_config
6708
6709        with _dynamo_config.patch(
6710            do_not_emit_runtime_asserts=True,
6711        ):
6712            ep = torch.export._trace._export(
6713                Foo(),
6714                inputs,
6715                dynamic_shapes=dynamic_shapes,
6716                allow_complex_guards_as_runtime_asserts=True,
6717            ).run_decompositions()
6718
6719        self.assertEqual(
6720            [
6721                node.target == torch.ops.aten._assert_scalar.default
6722                for node in ep.graph.nodes
6723            ].count(True),
6724            0,
6725        )
6726
6727    def test_constant_aliasing(self):
6728        class M1(torch.nn.Module):
6729            def __init__(self, m2, foo):
6730                super().__init__()
6731                self.m2 = m2
6732                self.foo = foo
6733
6734            def forward(self, x):
6735                return x + self.foo + self.m2(x)
6736
6737        class M2(torch.nn.Module):
6738            def __init__(self) -> None:
6739                super().__init__()
6740                self.foo = torch.ones(3, 3)
6741
6742            def forward(self, x):
6743                return x + self.foo
6744
6745        m2 = M2()
6746        m1 = M1(m2, m2.foo)
6747        inps = (torch.ones(3, 3),)
6748        ep = torch.export.export(m1, inps, strict=False)
6749        # check both constants appear in list
6750        self.assertEqual(sorted(list(ep.constants)), ["foo", "m2.foo"])
6751        # check only one input spec exists
6752        num_constant_inputs = [
6753            spec.kind == InputKind.CONSTANT_TENSOR
6754            for spec in ep.graph_signature.input_specs
6755        ].count(True)
6756        self.assertEqual(num_constant_inputs, 1)
6757        # unflatten
6758        unflattened = unflatten(ep)
6759        self.assertTrue(torch.allclose(m1(*inps), unflattened(*inps)))
6760
6761    @testing.expectedFailureRetraceability
6762    def test_unused_aliases(self):
6763        class Foo(torch.nn.Module):
6764            def __init__(self) -> None:
6765                super().__init__()
6766                # param
6767                self.alpha = torch.nn.Parameter(torch.randn(4))
6768                self.beta = self.alpha
6769                self.gamma = self.alpha
6770
6771            def forward(self, x):
6772                return x + self.gamma
6773
6774        inps = (torch.randn(4),)
6775        ep = export(Foo(), inps)
6776        # placeholder nodes will be deduplicated in strict-mode,
6777        # but check that all params still appear in state dict
6778        for param in ["alpha", "beta", "gamma"]:
6779            self.assertTrue(param in ep.state_dict)
6780
6781        # check that they also appear in unflattened state dict
6782        unep = unflatten(ep)
6783        for param in ["alpha", "beta", "gamma"]:
6784            self.assertTrue(param in unep.state_dict())
6785
6786    def test_intermediate_shape_comp(self):
6787        class Foo(torch.nn.Module):
6788            def forward(self, x, y):
6789                z = torch.cat([x, x], dim=0)
6790                w = z.repeat(y.shape[0])
6791                return w.shape[0] + x.shape[0]
6792
6793        inputs = (torch.randn(6), torch.randn(4))
6794        shapes = {
6795            "x": (Dim("dx0"),),
6796            "y": (Dim("dy"),),
6797        }
6798        ep = export(
6799            Foo(),
6800            inputs,
6801            dynamic_shapes=shapes,
6802        )
6803        # test that shape is from size compute, not sym_size call
6804        add_node = [node for node in ep.graph.nodes if node.target == operator.add][0]
6805        self.assertTrue(add_node.args[0].target == operator.mul)
6806        # test sym_size calls only happen on placeholders
6807        sym_size_nodes = [
6808            node
6809            for node in ep.graph.nodes
6810            if node.target == torch.ops.aten.sym_size.int
6811        ]
6812        self.assertEqual(len(sym_size_nodes), 2)
6813        self.assertTrue(
6814            all(node.args[0].op == "placeholder" for node in sym_size_nodes)
6815        )
6816        # dynamo will DCE the repeat node, AOTAutograd will leave it
6817        # training IR will also DCE due to retracing
6818        repeat_nodes = [
6819            node
6820            for node in ep.graph.nodes
6821            if node.target == torch.ops.aten.repeat.default
6822        ]
6823        self.assertEqual(
6824            len(repeat_nodes),
6825            1
6826            if is_non_strict_test(self._testMethodName)
6827            and not is_training_ir_test(self._testMethodName)
6828            else 0,
6829        )
6830
6831    def test_checks_to_constrain_range(self):
6832        class Foo(torch.nn.Module):
6833            def forward(self, x, y):
6834                n = y.item()
6835                m = y.item()
6836                torch._check_is_size(n)
6837                torch._check(m >= 0)
6838                torch._check(n >= 3)
6839                torch._check(-m >= -9)  # m <= 9
6840                torch._check(n <= 6)
6841                # n has range [3, 9]
6842                return x[:n]
6843
6844        inputs = (torch.randn(10), torch.tensor(6))
6845        ep = export(Foo(), inputs)
6846        FileCheck().check_count(
6847            "torch.ops.aten._assert_scalar.default", 2, exactly=True
6848        ).run(ep.graph_module.code)
6849        FileCheck().check_count(
6850            "torch.ops.aten.sym_constrain_range.default", 0, exactly=True
6851        ).run(ep.graph_module.code)
6852        FileCheck().check_count(
6853            "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True
6854        ).run(ep.graph_module.code)
6855
6856        ep = ep.run_decompositions()
6857        FileCheck().check_count(
6858            "torch.ops.aten._assert_scalar.default", 2, exactly=True
6859        ).run(ep.graph_module.code)
6860        FileCheck().check_count(
6861            "torch.ops.aten.sym_constrain_range.default", 0, exactly=True
6862        ).run(ep.graph_module.code)
6863        FileCheck().check_count(
6864            "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True
6865        ).run(ep.graph_module.code)
6866
6867        # check runtime
6868        ep.module()(torch.randn(10), torch.tensor(5))
6869        with self.assertRaisesRegex(
6870            RuntimeError,
6871            r"Runtime assertion failed for expression u[\d+] \>\= 3",
6872        ):
6873            ep.module()(torch.randn(10), torch.tensor(2))
6874
6875    def test_cse_for_symint(self):
6876        class Foo(torch.nn.Module):
6877            # check sym ops only get computed once
6878            def forward(self, x, y):
6879                if (
6880                    x.shape[0] ** 2 - y.shape[0] ** 2 >= 4  # 16
6881                    and x.shape[0] ** 2 - y.shape[0] ** 2 <= 20
6882                    and x.shape[0] ** 2 - y.shape[0] ** 2 != 15
6883                ):
6884                    return x * 2, y * 2
6885
6886        inputs = (torch.randn(5), torch.randn(3))
6887        shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)}
6888        ep = torch.export._trace._export(
6889            Foo(),
6890            inputs,
6891            dynamic_shapes=shapes,
6892            allow_complex_guards_as_runtime_asserts=True,
6893        )
6894        # count 2 pow nodes, 2 sym_size.int nodes
6895        self.assertEqual(
6896            [node.target for node in ep.graph.nodes].count(
6897                operator.pow,
6898            ),
6899            2,
6900        )
6901        FileCheck().check_count("torch.ops.aten.sym_size.int", 2, exactly=True).run(
6902            ep.graph_module.code
6903        )
6904
6905        ep = ep.run_decompositions()
6906        self.assertEqual(
6907            [node.target for node in ep.graph.nodes].count(
6908                operator.pow,
6909            ),
6910            2,
6911        )
6912        FileCheck().check_count("torch.ops.aten.sym_size.int", 2, exactly=True).run(
6913            ep.graph_module.code
6914        )
6915
6916    def test_slice_with_floordiv(self):
6917        # slice operation emits runtime assert s0//2 <= s1
6918        class M1(torch.nn.Module):
6919            def forward(self, x, y):
6920                d = x.size(0) // 2
6921                return y[d:]
6922
6923        class M(torch.nn.Module):
6924            def __init__(self) -> None:
6925                super().__init__()
6926                self.m1 = M1()
6927
6928            def forward(self, x, y):
6929                d = x.size(0) // 2
6930                m1_res = self.m1(x, y)
6931                return y[d:] + m1_res
6932
6933        inputs = (torch.ones(10), torch.ones(10))
6934        d0 = torch.export.Dim("d0", max=2048)
6935        d1 = torch.export.Dim("d1", max=2048)
6936        ep = export(
6937            M(),
6938            inputs,
6939            dynamic_shapes=((d0,), (d1,)),
6940        )
6941        ep.module()(torch.ones(8), torch.ones(4))
6942        ep.module()(torch.ones(8), torch.ones(5))
6943        with self.assertRaisesRegex(
6944            RuntimeError,
6945            r"Runtime assertion failed for expression \(s0//2\) \<\= s1",
6946        ):
6947            ep.module()(torch.ones(10), torch.ones(4))
6948
6949    def test_split_const_gm_with_lifted_constants(self):
6950        class Model(torch.nn.Module):
6951            def __init__(self) -> None:
6952                super().__init__()
6953                self.w_pre = torch.randn(4, 4)
6954                self.b = torch.randn(4)
6955
6956            def forward(self, x):
6957                w_transpose = torch.transpose(self.w_pre, 0, 1)
6958                w_relu = torch.nn.functional.relu(w_transpose)
6959                w = w_relu + self.b
6960                return torch.matmul(x, w)
6961
6962        example_inputs = (torch.randn(4, 4),)
6963        mod = Model()
6964        ep = torch.export.export(mod, example_inputs)
6965        new_gm = copy.deepcopy(ep.graph_module)
6966        new_sig = copy.deepcopy(ep.graph_signature)
6967        placeholder_nodes = [
6968            node for node in new_gm.graph.nodes if node.op == "placeholder"
6969        ]
6970        constants = {**ep.state_dict, **ep.constants}
6971        lifted_constants = {
6972            n.name: constants[spec.target]
6973            for n, spec in zip(placeholder_nodes, new_sig.input_specs)
6974            if spec.target is not None
6975        }
6976        const_gm, _ = split_const_gm(new_gm, lifted_constants)
6977        counter = 0
6978        for node in const_gm.graph.nodes:
6979            if node.op == "call_function":
6980                counter += 1
6981        self.assertTrue(counter > 0)
6982        test_input = torch.randn(4, 4)
6983        expected = new_gm(None, None, test_input)[0]
6984        actual = mod(test_input)
6985        self.assertEqual(actual, expected)
6986        const_gm, _ = split_const_gm(ep.graph_module, lifted_constants, lambda x: True)
6987        counter = 0
6988        for node in const_gm.graph.nodes:
6989            if node.op == "call_function":
6990                self.assertTrue(False)
6991
6992    @testing.expectedFailureTrainingIRToRunDecomp  # T200904004
6993    @testing.expectedFailureTrainingIRToRunDecompNonStrict
6994    def test_istft_op(self):
6995        class istft_class(torch.nn.Module):
6996            def forward(self, spec):
6997                window = torch.hann_window(1024).type(torch.FloatTensor)
6998                return torch.istft(
6999                    spec,
7000                    n_fft=1024,
7001                    hop_length=512,
7002                    window=window,
7003                    length=144000,
7004                )
7005
7006        model = istft_class()
7007        real_part = torch.randn(1, 513, 282, dtype=torch.float32)
7008        imaginary_part = torch.randn(1, 513, 282, dtype=torch.float32)
7009        spec = torch.complex(real_part, imaginary_part)
7010        export(model, (spec,))
7011
7012    def test_automatic_dynamic_shapes_simple_equality(self):
7013        # The next 3 test cases tests for automatic dynamic shapes specs, verifying that automatic dynamism
7014        # leads to replacement symbols being set for equalities, and inferred relationships being checked
7015        # with runtime asserts. Check that we specialize to static values when the program says so.
7016        AUTO, STATIC = Dim.AUTO, Dim.STATIC
7017
7018        # case 1: direct equality between symbols
7019        class SimpleEquality(torch.nn.Module):
7020            def forward(self, x, y, z):
7021                # all inputs should have shape [s0, s1]
7022                return x + y + z
7023
7024        inputs = tuple(torch.randn(6, 3) for _ in range(3))
7025        # fully dynamic
7026        self._check_dynamic_shapes_specs_and_shapes(
7027            SimpleEquality(),
7028            inputs,
7029            specs=[
7030                ((AUTO, AUTO), (AUTO, AUTO), (AUTO, AUTO)),
7031                [[AUTO, AUTO], [AUTO, AUTO], [AUTO, AUTO]],
7032                {"x": (AUTO, AUTO), "y": (AUTO, AUTO), "z": (AUTO, AUTO)},
7033            ],
7034            passing_shapes=[
7035                ((4, 4), (4, 4), (4, 4)),
7036                ((1, 1), (1, 1), (1, 1)),
7037                ((0, 9), (0, 9), (0, 9)),
7038            ],
7039            failing_shapes=[
7040                ((4, 4), (4, 4), (4, 3)),
7041                ((4, 4), (5, 4), (4, 5)),
7042            ],
7043            test_serdes=True,
7044        )
7045        # static s1
7046        self._check_dynamic_shapes_specs_and_shapes(
7047            # specifying just one dimension as static should be enough to specialize all s1
7048            SimpleEquality(),
7049            inputs,
7050            specs=[
7051                [{0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}, (AUTO, None)],
7052                {"x": (AUTO, AUTO), "y": (AUTO, AUTO), "z": (AUTO, None)},
7053            ],
7054            passing_shapes=[
7055                ((4, 3), (4, 3), (4, 3)),
7056                ((1, 3), (1, 3), (1, 3)),
7057                ((0, 3), (0, 3), (0, 3)),
7058            ],
7059            failing_shapes=[
7060                ((4, 4), (4, 4), (4, 4)),
7061                ((1, 1), (1, 1), (1, 1)),
7062                ((0, 9), (0, 9), (0, 9)),
7063            ],
7064            test_serdes=True,
7065        )
7066        # fully static
7067        self._check_dynamic_shapes_specs_and_shapes(
7068            # this should specialize all
7069            SimpleEquality(),
7070            inputs,
7071            specs=[{"x": (None, AUTO), "y": (AUTO, AUTO), "z": (AUTO, None)}],
7072            passing_shapes=[
7073                ((6, 3), (6, 3), (6, 3)),
7074            ],
7075            failing_shapes=[
7076                ((6, 4), (6, 4), (6, 4)),
7077                ((1, 3), (1, 3), (1, 3)),
7078                ((0, 9), (0, 9), (0, 9)),
7079            ],
7080            test_serdes=True,
7081        )
7082
7083    def test_automatic_dynamic_shapes_constant_relation(self):
7084        AUTO, STATIC = Dim.AUTO, Dim.STATIC
7085
7086        # case 2: related by constant: s0 + 4 = s1
7087        class OffBy4(torch.nn.Module):
7088            def forward(self, x, y):
7089                return x + y[4:]
7090
7091        inputs = (torch.randn(6), torch.randn(10))
7092        # fully dynamic
7093        self._check_dynamic_shapes_specs_and_shapes(
7094            OffBy4(),
7095            inputs,
7096            specs=[
7097                ((AUTO,), (AUTO,)),
7098                {"x": (AUTO,), "y": (AUTO,)},
7099            ],
7100            passing_shapes=[
7101                ((10,), (14,)),
7102                ((3,), (7,)),
7103                ((2,), (6,)),
7104            ],
7105            failing_shapes=[
7106                ((10,), (13,)),
7107            ],
7108            test_serdes=True,
7109        )
7110        # static s1 should specialize s0
7111        self._check_dynamic_shapes_specs_and_shapes(
7112            OffBy4(),
7113            inputs,
7114            specs=[
7115                {"x": (AUTO,), "y": (None,)},
7116            ],
7117            passing_shapes=[
7118                ((6,), (10,)),
7119            ],
7120            failing_shapes=[
7121                ((10,), (14,)),
7122                ((3,), (7,)),
7123                ((2,), (6,)),
7124            ],
7125            test_serdes=True,
7126        )
7127
7128    def test_automatic_dynamic_shapes_linear_relation(self):
7129        AUTO, STATIC = Dim.AUTO, Dim.STATIC
7130
7131        # case 3: linear relation
7132        class LinearRel(torch.nn.Module):
7133            def forward(self, x, y):
7134                # x: [s0], y: [s1]
7135                # relation seems to be (s0 + 2) // 4 == s1
7136                return x[1::4] + y
7137
7138        inputs = (torch.randn(21), torch.randn(5))
7139
7140        # fully dynamic
7141        self._check_dynamic_shapes_specs_and_shapes(
7142            LinearRel(),
7143            inputs,
7144            specs=[
7145                ((AUTO,), (AUTO,)),
7146                {"x": (AUTO,), "y": (AUTO,)},
7147            ],
7148            passing_shapes=[
7149                ((33,), (8,)),
7150                ((32,), (8,)),
7151                ((31,), (8,)),
7152                ((30,), (8,)),
7153            ],
7154            failing_shapes=[
7155                ((34,), (8,)),
7156                ((22,), (5,)),
7157            ],
7158            test_serdes=False,
7159        )
7160        # static s1 shouldn't actually specialize s0 (guard: (s0 + 2) // 4 == 5)
7161        self._check_dynamic_shapes_specs_and_shapes(
7162            LinearRel(),
7163            inputs,
7164            specs=[
7165                ((AUTO,), None),
7166                {"x": (AUTO,), "y": None},
7167            ],
7168            passing_shapes=[
7169                ((21,), (5,)),
7170                ((20,), (5,)),
7171                ((19,), (5,)),
7172                ((18,), (5,)),
7173            ],
7174            failing_shapes=[
7175                ((33,), (8,)),
7176            ],
7177            test_serdes=False,
7178        )
7179        # but static s0 will definitely specialize s1 (guard: (21 + 2) // 4 == s1 -> 5 == s1)
7180        self._check_dynamic_shapes_specs_and_shapes(
7181            LinearRel(),
7182            inputs,
7183            specs=[
7184                (None, (AUTO,)),
7185            ],
7186            passing_shapes=[
7187                ((21,), (5,)),
7188            ],
7189            failing_shapes=[
7190                ((22,), (5,)),
7191            ],
7192            test_serdes=True,
7193        )
7194
7195    def test_dynamic_shapes_serdes_generic(self):
7196        from torch._export.serde.dynamic_shapes import (
7197            _dump_dynamic_shapes,
7198            _load_dynamic_shapes,
7199        )
7200
7201        class Foo(torch.nn.Module):
7202            def forward(self, a, b, c, d):
7203                if d == "hello":
7204                    x = a[0] + a[1][1:]
7205                    b = torch.cat([b, b], dim=0).reshape([-1, 1])
7206                    return x + b, c * 2
7207
7208        # test de/serialization on some generic specs
7209        dz = Dim("dz", min=4, max=16)
7210        dx = 2 * dz
7211        dy = dx + 1
7212        inputs = (
7213            [
7214                torch.randn(8, 4),
7215                torch.randn(9, 4),
7216            ],
7217            torch.randn(4),
7218            torch.randn(4, 4),
7219            "hello",
7220        )
7221        dynamic_shapes = {
7222            "a": [
7223                (dx, 4),
7224                (dy, 4),
7225            ],
7226            "b": (dz,),
7227            "c": None,
7228            "d": None,
7229        }
7230        ep = export(Foo(), inputs, dynamic_shapes=dynamic_shapes)
7231        self._check_dynamic_shapes_specs_and_shapes(
7232            Foo(),
7233            inputs,
7234            [dynamic_shapes],
7235            [
7236                ([(16, 4), (17, 4)], (8,), (4, 4), "hello"),
7237                ([(24, 4), (25, 4)], (12,), (4, 4), "hello"),
7238            ],
7239            [
7240                ([(16, 4), (17, 4)], (8,), (5, 5), "hello"),
7241            ],
7242            test_serdes=True,
7243        )
7244        self.assertExpectedInline(
7245            _dump_dynamic_shapes(dynamic_shapes, inputs),
7246            """DynamicShapesSpec(dynamic_shapes=([['2*dz', 4], ['2*dz + 1', 4]], ['dz'], ['_DimHint.STATIC', '_DimHint.STATIC'], None), dims={'dz': RootDim(min=4, max=16, derived=['2*dz', '2*dz + 1'])})""",
7247        )
7248        self.assertExpectedInline(
7249            _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True),
7250            """{'dynamic_shapes': ([['2*dz', 4], ['2*dz + 1', 4]], ['dz'], ['_DimHint.STATIC', '_DimHint.STATIC'], None), 'dims': {'dz': {'min': 4, 'max': 16, 'derived': ['2*dz', '2*dz + 1']}}}""",
7251        )
7252        ((dx, _), (dy, _)), (dz,), (_, _), _ = _load_dynamic_shapes(
7253            _dump_dynamic_shapes(dynamic_shapes, inputs)
7254        )
7255        self.assertEqual(dx.root, dz)
7256        self.assertEqual(dy.root, dz)
7257
7258    def test_dynamic_shapes_serdes_various(self):
7259        # serialization for dataclass inputs, Dim.AUTO/STATIC, and kwargs
7260        from torch._export.serde.dynamic_shapes import (
7261            _dump_dynamic_shapes,
7262            _load_dynamic_shapes,
7263        )
7264
7265        auto, static = Dim.AUTO, Dim.STATIC
7266
7267        @dataclass
7268        class Input:
7269            a: Tensor
7270            b: Tensor
7271
7272        register_dataclass_as_pytree_node(
7273            Input,
7274            serialized_type_name="test_dynamic_shapes_serdes_various.Input",
7275        )
7276
7277        class Foo(torch.nn.Module):
7278            def forward(self, x, y, z):
7279                return x - torch.randn(4), y.a + y.b + z[1:]
7280
7281        args = (torch.randn(4, 4),)
7282        kwargs = {
7283            "y": Input(a=torch.randn(8, 8), b=torch.randn(8, 8)),
7284            "z": torch.randn(9, 8),
7285        }
7286        dynamic_shapes = {
7287            "x": (auto, static),
7288            "y": [(auto, auto), (auto, auto)],
7289            "z": (auto, 8),
7290        }
7291
7292        # dump dynamic_shapes
7293        self.assertExpectedInline(
7294            _dump_dynamic_shapes(dynamic_shapes, args, kwargs),
7295            """DynamicShapesSpec(dynamic_shapes=(['_DimHint.AUTO', '_DimHint.STATIC'], [['_DimHint.AUTO', '_DimHint.AUTO'], ['_DimHint.AUTO', '_DimHint.AUTO']], ['_DimHint.AUTO', 8]), dims={})""",
7296        )
7297        self.assertExpectedInline(
7298            _dump_dynamic_shapes(dynamic_shapes, args, kwargs, to_dict=True),
7299            """{'dynamic_shapes': (['_DimHint.AUTO', '_DimHint.STATIC'], [['_DimHint.AUTO', '_DimHint.AUTO'], ['_DimHint.AUTO', '_DimHint.AUTO']], ['_DimHint.AUTO', 8]), 'dims': {}}""",
7300        )
7301
7302    def test_dynamic_shapes_serdes_user_errors(self):
7303        # check error messages for dynamic shapes de/serialization
7304        from torch._export.serde.dynamic_shapes import (
7305            _dump_dynamic_shapes,
7306            _load_dynamic_shapes,
7307            DynamicShapesSpec,
7308            RootDim,
7309        )
7310        from torch._export.serde.serialize import _dataclass_to_dict
7311
7312        # this stuff should be well tested in `test_mismatched_dynamic_shapes`
7313        with self.assertRaisesRegex(
7314            torch._dynamo.exc.UserError,
7315            re.escape(
7316                "Detected mismatch between the structure of `inputs` and `dynamic_shapes`: `inputs[0]['k']` "
7317                "is a <class 'list'>, but `dynamic_shapes[0]['k']` is a <class 'tuple'>"
7318            ),
7319        ):
7320            dynamic_shapes = {"x": {"k": (Dim("dx"), Dim("dy"))}}
7321            _dump_dynamic_shapes(dynamic_shapes, ({"k": [torch.randn(4, 4)]},))
7322
7323        # loading with from_dict=True/False
7324        spec = DynamicShapesSpec(
7325            dynamic_shapes=[["dx"]],
7326            dims={"dx": RootDim(min=4, max=16, derived=[])},
7327        )
7328        spec_dict = _dataclass_to_dict(spec)
7329        with self.assertRaisesRegex(
7330            torch._dynamo.exc.UserError,
7331            re.escape(
7332                "With from_dict=True, expected `spec` to be a dict, "
7333                "got <class 'torch._export.serde.dynamic_shapes.DynamicShapesSpec'>"
7334            ),
7335        ):
7336            _load_dynamic_shapes(spec, from_dict=True)
7337
7338        with self.assertRaisesRegex(
7339            torch._dynamo.exc.UserError,
7340            re.escape("Expected `spec` to be a DynamicShapesSpec, got <class 'dict'>"),
7341        ):
7342            _load_dynamic_shapes(spec_dict, from_dict=False)
7343
7344        self.assertExpectedInline(
7345            _load_dynamic_shapes(spec, from_dict=False),
7346            """[[<class 'torch._export.serde.dynamic_shapes.dx'>]]""",
7347        )
7348
7349        # check incorrect info in dims
7350        with self.assertRaisesRegex(
7351            torch._dynamo.exc.UserError,
7352            re.escape(
7353                "Expected dims in `spec['dims']` to map `min` to an int, got dx: None"
7354            ),
7355        ):
7356            spec = {
7357                "dynamic_shapes": [["dx"]],
7358                "dims": {
7359                    "dx": {
7360                        "min": None,
7361                        "max": 4,
7362                        "derived": [],
7363                    },
7364                },
7365            }
7366            _load_dynamic_shapes(spec, from_dict=True)
7367
7368        with self.assertRaisesRegex(
7369            torch._dynamo.exc.UserError,
7370            re.escape(
7371                "Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, "
7372                "got dx which is not in dict_keys(['dy'])"
7373            ),
7374        ):
7375            spec = {
7376                "dynamic_shapes": [["dx"]],
7377                "dims": {
7378                    "dy": {
7379                        "min": 2,
7380                        "max": 4,
7381                        "derived": [],
7382                    },
7383                },
7384            }
7385            _load_dynamic_shapes(spec, from_dict=True)
7386
7387        with self.assertRaisesRegex(
7388            torch._dynamo.exc.UserError,
7389            re.escape(
7390                "Expected derived expressions to be linear expressions, got dx**2 + 4"
7391            ),
7392        ):
7393            spec = {
7394                "dynamic_shapes": [["dx"]],
7395                "dims": {
7396                    "dx": {
7397                        "min": 2,
7398                        "max": 4,
7399                        "derived": ["dx**2 + 4"],
7400                    },
7401                },
7402            }
7403            _load_dynamic_shapes(spec, from_dict=True)
7404
7405    @testing.expectedFailureNonStrict
7406    @testing.expectedFailureTrainingIRToRunDecompNonStrict  # unbacked symint not tracked?
7407    @testing.expectedFailureSerDer  # T195866111
7408    def test_hints_wrapper(self):
7409        class M(torch.nn.Module):
7410            def __init__(self) -> None:
7411                super().__init__()
7412
7413            def forward(self, x, y):
7414                x = x + y
7415
7416                def inner_body_fn(x, y):
7417                    x = torch.relu(x)
7418                    x = x + y
7419                    return x
7420
7421                def outer_body_fn(x, y):
7422                    x = hints_wrapper(
7423                        inner_body_fn, (x, y), {}, hints={"inner_body": True}
7424                    )
7425                    x = torch.abs(x)
7426                    return x
7427
7428                res = hints_wrapper(
7429                    outer_body_fn, (x, y), {}, hints={"outer_body": True}
7430                )
7431                return res
7432
7433        x = torch.randn(2, 4)
7434        y = torch.ones(4)
7435
7436        ep = export(M(), (x, y))
7437        export_res = ep.module()(x, y)
7438        ref_res = M()(x, y)
7439        self.assertEqual(export_res, ref_res)
7440        self.assertExpectedInline(
7441            normalize_gm(ep.graph_module.print_readable(print_output=False)),
7442            """\
7443class GraphModule(torch.nn.Module):
7444    def forward(self, x: "f32[2, 4]", y: "f32[4]"):
7445        add: "f32[2, 4]" = torch.ops.aten.add.Tensor(x, y);  x = None
7446
7447        hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0
7448        hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (add, y), {}, hints = {'outer_body': True});  hints_wrapper_body_graph_0 = add = y = None
7449        getitem: "f32[2, 4]" = hints_wrapper[0];  hints_wrapper = None
7450        return (getitem,)
7451
7452    class hints_wrapper_body_graph_0(torch.nn.Module):
7453        def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"):
7454            hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0
7455            hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (arg0_1, arg1_1), {}, hints = {'inner_body': True});  hints_wrapper_body_graph_0 = arg0_1 = arg1_1 = None
7456            getitem: "f32[2, 4]" = hints_wrapper[0];  hints_wrapper = None
7457            abs_1: "f32[2, 4]" = torch.ops.aten.abs.default(getitem);  getitem = None
7458            return (abs_1,)
7459
7460        class hints_wrapper_body_graph_0(torch.nn.Module):
7461            def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"):
7462                relu: "f32[2, 4]" = torch.ops.aten.relu.default(arg0_1);  arg0_1 = None
7463                add: "f32[2, 4]" = torch.ops.aten.add.Tensor(relu, arg1_1);  relu = arg1_1 = None
7464                return (add,)
7465""",
7466        )
7467
7468
7469@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
7470class TestOneOffModelExportResult(TestCase):
7471    def test_scaled_dot_product_attention_cpu(self):
7472        """
7473        This test makes sure we are always getting the same decomposition result for SDPA.
7474        As of now _scaled_dot_product_flash_attention_for_cpu is expected to show up in
7475        export() result. Some downstream backend then further decompose it into core ATen
7476        ops in torch/_decomp/decompositions.py (search for
7477        _scaled_dot_product_flash_attention_for_cpu).
7478
7479        Export is decomposing based on the CompositeImplicitAutograd kernel implementation
7480        of SDPA. If this test fails, it means the kernel is being modified. In this case
7481        we strongly encourage you to change the decomposition rule under
7482        torch/_decomp/decompositions.py along with the kernel changes, so all of the
7483        downstream backends are not being affected.
7484        """
7485
7486        class ScaledDotProductAttention(torch.nn.Module):
7487            def __init__(self) -> None:
7488                super().__init__()
7489
7490            def forward(self, q, k, v):
7491                attn_output = F.scaled_dot_product_attention(
7492                    q, k, v, None, dropout_p=0.0, is_causal=True
7493                )
7494                return attn_output
7495
7496        q = torch.randn(1, 1, 8, 8, device="cpu")
7497        k = torch.randn(1, 1, 8, 8, device="cpu")
7498        v = torch.randn(1, 1, 8, 8, device="cpu")
7499
7500        from torch.nn.attention import SDPBackend
7501
7502        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]):
7503            ep = torch.export.export(ScaledDotProductAttention(), (q, k, v))
7504            print(ep.graph)
7505            ep.run_decompositions()
7506            print(ep.graph)
7507
7508    #         self.assertExpectedInline(ep.graph_module.code.strip(), """\
7509    # def forward(self, arg0_1, arg1_1, arg2_1):
7510    #     _scaled_dot_product_flash_attention_for_cpu = torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default(arg0_1, arg1_1, arg2_1, 0.0, True);  arg0_1 = arg1_1 = arg2_1 = None
7511    #     getitem = _scaled_dot_product_flash_attention_for_cpu[0];  _scaled_dot_product_flash_attention_for_cpu = None
7512    #     return (getitem,)""")
7513
7514    @unittest.skipIf(
7515        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
7516        "Can't run fused SDPA on this platform",
7517    )
7518    def test_scaled_dot_product_attention_cuda(self):
7519        """
7520        This test makes sure we are always getting the same decomposition result for SDPA.
7521        As of now _scaled_dot_product_flash_attention is expected to show up in
7522        export() result (GPU tensors are given). Currently there's no downstream
7523        backend relies on this export result so if this test fails, feel free to
7524        change it to the latest export() result.
7525        """
7526
7527        class ScaledDotProductAttention(torch.nn.Module):
7528            def __init__(self) -> None:
7529                super().__init__()
7530
7531            def forward(self, q, k, v):
7532                attn_output = F.scaled_dot_product_attention(
7533                    q, k, v, None, dropout_p=0.0, is_causal=True
7534                )
7535                return attn_output
7536
7537        q = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda")
7538        k = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda")
7539        v = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda")
7540
7541        ep = torch.export.export(
7542            ScaledDotProductAttention(), (q, k, v)
7543        ).run_decompositions()
7544        code_str = """\
7545def forward(self, q, k, v):
7546    _scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_flash_attention.default(q, k, v, 0.0, True, scale = 0.125);  q = k = v = None
7547    getitem = _scaled_dot_product_flash_attention[0];  _scaled_dot_product_flash_attention = None
7548    return (getitem,)"""
7549        if SM90OrLater and not torch.version.hip:
7550            code_str = """\
7551def forward(self, q, k, v):
7552    _scaled_dot_product_cudnn_attention = torch.ops.aten._scaled_dot_product_cudnn_attention.default(q, k, v, None, False, 0.0, True);  q = k = v = None
7553    getitem = _scaled_dot_product_cudnn_attention[0];  _scaled_dot_product_cudnn_attention = None
7554    return (getitem,)"""
7555        self.assertExpectedInline(
7556            ep.graph_module.code.strip(),
7557            code_str,
7558        )
7559
7560    def test_int_list_output(self):
7561        class M(torch.nn.Module):
7562            def forward(self, x):
7563                return [((1, 3), [x + x, x * x])]
7564
7565        ep = torch.export.export(M(), (torch.ones(2, 3),))
7566        res = ep.module()(torch.ones(2, 3))
7567        self.assertEqual(res[0][0], (1, 3))
7568
7569    def test_primitive_constant_output(self):
7570        class Z(torch.nn.Module):
7571            def forward(self, x, y):
7572                with torch.no_grad():
7573                    return y * x, "moo"
7574
7575        ep = torch.export.export(Z(), (torch.tensor(3), 5))
7576        res = ep.module()(torch.tensor(4), 5)
7577        self.assertEqual(res[0], torch.tensor(20))
7578        self.assertEqual(res[1], "moo")
7579
7580        class B(torch.nn.Module):
7581            def forward(self, x, y):
7582                return y * x, y
7583
7584        ep = torch.export.export(B(), (torch.tensor(3), 5))
7585        res = ep.module()(torch.tensor(4), 5)
7586        self.assertEqual(res[0], torch.tensor(20))
7587        self.assertEqual(res[1], 5)
7588
7589        with self.assertRaisesRegex(
7590            RuntimeError,
7591            escape("Expected input at *args[1] to be equal to 5, but got 20"),
7592        ):
7593            res = ep.module()(torch.tensor(4), 20)
7594
7595        class F(torch.nn.Module):
7596            def forward(self, x):
7597                # return a constant of primitive type
7598                y = 5
7599                return y * x, y
7600
7601        ep = torch.export.export(F(), (torch.tensor(3),))
7602        res = ep.module()(torch.tensor(4))
7603        self.assertEqual(res[0], torch.tensor(20))
7604        self.assertEqual(res[1], 5)
7605
7606        class Q(torch.nn.Module):
7607            def forward(self, x, y):
7608                return y * x, y - 1
7609
7610        ep = torch.export.export(Q(), (torch.tensor(3), 5))
7611        res = ep.module()(torch.tensor(4), 5)
7612        self.assertEqual(res[0], torch.tensor(20))
7613        self.assertEqual(res[1], 4)
7614
7615    def test_unbacked_sdpa(self):
7616        import torch
7617        from torch.nn.attention import sdpa_kernel, SDPBackend
7618        from torch.nn.functional import scaled_dot_product_attention
7619
7620        class Module(torch.nn.Module):
7621            def forward(
7622                self, query: torch.Tensor, cache: torch.Tensor, start_pos: torch.Tensor
7623            ) -> torch.Tensor:
7624                # x.sizes(): 1, 128, 16, 128
7625                sp = start_pos.item()
7626                torch._check_is_size(sp)
7627                torch._check(sp >= 0)
7628                torch._check(sp <= 126)
7629                key = cache[:, : sp + 1, :, :]  # 1, sp+1, 16, 128
7630                value = cache[:, : sp + 1, :, :]  # 1, sp+1, 16, 128
7631                query = query.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
7632                key = key.transpose(1, 2)
7633                value = value.transpose(1, 2)
7634                # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp#L732
7635                return scaled_dot_product_attention(query, key, value)
7636
7637        cache = torch.randn(1, 128, 16, 128, dtype=torch.float16)
7638        query = torch.randn(1, 1, 16, 128, dtype=torch.float16)
7639        start_pos = torch.tensor([0])
7640        with sdpa_kernel(SDPBackend.MATH), torch.no_grad():
7641            ep = torch.export.export(Module(), (query, cache, start_pos))
7642            args = (query, cache, start_pos)
7643            self.assertEqual(ep.module()(*args), Module()(*args))
7644            args = (query, cache, torch.tensor([3]))
7645            self.assertEqual(ep.module()(*args), Module()(*args))
7646            args = (query, cache, torch.tensor([126]))
7647            self.assertEqual(ep.module()(*args), Module()(*args))
7648
7649    def test_none_input_output(self):
7650        class Z(torch.nn.Module):
7651            def forward(self, x, y):
7652                return x * x
7653
7654        ep = torch.export.export(Z(), (torch.tensor(3), None))
7655        res = ep.module()(torch.tensor(4), None)
7656        self.assertEqual(res, torch.tensor(16))
7657
7658        class B(torch.nn.Module):
7659            def forward(self, x, y):
7660                return x * x, y
7661
7662        ep = torch.export.export(B(), (torch.tensor(3), None))
7663        res = ep.module()(torch.tensor(4), None)
7664        self.assertEqual(res[0], torch.tensor(16))
7665        self.assertEqual(res[1], None)
7666
7667        decomp = ep.run_decompositions()
7668        gm = decomp.module()
7669        res = gm(torch.tensor(4), None)
7670        self.assertEqual(res[0], torch.tensor(16))
7671        self.assertEqual(res[1], None)
7672
7673    def test_print(self):
7674        class M(torch.nn.Module):
7675            def forward(self, x):
7676                print("start")
7677                x1 = x + x
7678                print(x1)
7679                x2 = x1 * x1
7680                print(1, 2, 3)
7681                x3 = x2 + x2
7682                return (x1, x3)
7683
7684        gm = export(M(), (torch.randn(3, 3),)).graph_module
7685        self.assertExpectedInline(
7686            gm.code.strip(),
7687            """\
7688def forward(self, x):
7689    add = torch.ops.aten.add.Tensor(x, x);  x = None
7690    mul = torch.ops.aten.mul.Tensor(add, add)
7691    add_1 = torch.ops.aten.add.Tensor(mul, mul);  mul = None
7692    return (add, add_1)""",
7693        )
7694
7695    def test_logging_logger(self):
7696        logger = logging.getLogger(__name__)
7697
7698        class M(torch.nn.Module):
7699            def forward(self, x):
7700                logger.log("start")
7701                x1 = x + x
7702                logger.debug(x1)
7703                x2 = x1 * x1
7704                logger.info(1, 2, 3)
7705                x3 = x2 + x2
7706                return (x1, x3)
7707
7708        gm = export(M(), (torch.randn(3, 3),)).graph_module
7709        self.assertExpectedInline(
7710            gm.code.strip(),
7711            """\
7712def forward(self, x):
7713    add = torch.ops.aten.add.Tensor(x, x);  x = None
7714    mul = torch.ops.aten.mul.Tensor(add, add)
7715    add_1 = torch.ops.aten.add.Tensor(mul, mul);  mul = None
7716    return (add, add_1)""",
7717        )
7718
7719    @unittest.skipIf(not TEST_TRANSFORMERS, "No transformers")
7720    def test_hf_logging_logger(self):
7721        import transformers
7722
7723        logger = transformers.utils.logging.get_logger(__name__)
7724
7725        class M(torch.nn.Module):
7726            def forward(self, x):
7727                logger.warning_once("start")
7728                x1 = x + x
7729                x2 = x1 * x1
7730                x3 = x2 + x2
7731                return (x1, x3)
7732
7733        gm = export(M(), (torch.randn(3, 3),)).graph_module
7734        self.assertExpectedInline(
7735            gm.code.strip(),
7736            """\
7737def forward(self, x):
7738    add = torch.ops.aten.add.Tensor(x, x);  x = None
7739    mul = torch.ops.aten.mul.Tensor(add, add)
7740    add_1 = torch.ops.aten.add.Tensor(mul, mul);  mul = None
7741    return (add, add_1)""",
7742        )
7743
7744    def test_warning(self):
7745        class M(torch.nn.Module):
7746            def forward(self, x):
7747                warnings.warn("moo")
7748                res = x + x
7749                warnings.warn(f"{res}")
7750                return res
7751
7752        gm = export(M(), (torch.randn(3, 3),)).graph_module
7753        self.assertExpectedInline(
7754            gm.code.strip(),
7755            """\
7756def forward(self, x):
7757    add = torch.ops.aten.add.Tensor(x, x);  x = None
7758    return (add,)""",
7759        )
7760
7761    def test_constant_fqn(self):
7762        class Nested(torch.nn.Module):
7763            def __init__(self) -> None:
7764                super().__init__()
7765                self.constant = torch.rand(2, 3)
7766                self.parameter = torch.nn.Parameter(torch.rand(2, 3))
7767
7768            def forward(self, x):
7769                return x + self.constant
7770
7771        class Mod(torch.nn.Module):
7772            def __init__(self) -> None:
7773                super().__init__()
7774                self.nested = Nested()
7775
7776            def forward(self, x):
7777                return self.nested(x) + self.nested.constant + self.nested.parameter
7778
7779        m = Mod()
7780        ep = export(m, (torch.rand(2, 3),), strict=True)
7781        self.assertEqual(ep.constants["nested.constant"], m.nested.constant)
7782        self.assertEqual(ep.module()(torch.ones(2, 3)), m(torch.ones(2, 3)))
7783
7784    def test_constant_name(self):
7785        class Nested(torch.nn.Module):
7786            def __init__(self) -> None:
7787                super().__init__()
7788                self.constant = torch.rand(2, 3)
7789                self.parameter = torch.nn.Parameter(torch.rand(2, 3))
7790
7791            def forward(self, x):
7792                return x + self.constant
7793
7794        class Mod(torch.nn.Module):
7795            def __init__(self) -> None:
7796                super().__init__()
7797                self.nested_1 = Nested()
7798                self.nested_2 = Nested()
7799
7800            def forward(self, x):
7801                return (
7802                    self.nested_1(x)
7803                    + self.nested_2(x)
7804                    + self.nested_1.constant
7805                    + self.nested_2.constant
7806                    + self.nested_1.parameter
7807                    + self.nested_2.parameter
7808                )
7809
7810        m = Mod()
7811        ep = export(m, (torch.rand(2, 3),), strict=False)
7812        self.assertEqual(ep.module()(torch.ones(2, 3)), m(torch.ones(2, 3)))
7813
7814        # check constant fqn when there are multiple instances of the same class
7815        self.assertEqual(ep.constants["nested_1.constant"], m.nested_1.constant)
7816        self.assertEqual(ep.constants["nested_2.constant"], m.nested_2.constant)
7817
7818        # check constant_name in the graph
7819        placeholders = [
7820            node for node in ep.graph_module.graph.nodes if node.op == "placeholder"
7821        ]
7822        self.assertEqual(len(placeholders), 5)
7823        self.assertTrue(all(ph.name == ph.target for ph in placeholders))
7824        # suffix should be added to duplicated constant_name
7825        self.assertEqual(placeholders[2].name, "c_nested_1_constant")
7826        self.assertEqual(placeholders[3].name, "c_nested_2_constant")
7827
7828    def test_nested_retrace(self):
7829        class Nested(torch.nn.Module):
7830            def __init__(self) -> None:
7831                super().__init__()
7832                self.param = torch.nn.Parameter(torch.randn(3))
7833
7834            def forward(self, x):
7835                return x + self.param
7836
7837        class Foo(torch.nn.Module):
7838            def __init__(self) -> None:
7839                super().__init__()
7840                self.nested = Nested()
7841
7842            def forward(self, x):
7843                return x + self.nested(x)
7844
7845        # first export
7846        foo = Foo().to("meta")
7847        inputs = (torch.ones(3, device="meta"),)
7848        foo(*inputs)
7849        ep = torch.export.export(foo, inputs, strict=False)
7850
7851        # second export
7852        foo_1 = ep.module()
7853        ep_1 = torch.export.export(foo_1, inputs, strict=False)
7854
7855        for node1, node2 in zip(ep.graph.nodes, ep_1.graph.nodes):
7856            nn_module_stack_1 = node1.meta.get("nn_module_stack", None)
7857            nn_module_stack_2 = node2.meta.get("nn_module_stack", None)
7858
7859            if nn_module_stack_1 is None:
7860                self.assertTrue(nn_module_stack_2 is None)
7861            else:
7862                for v1, v2 in zip(
7863                    nn_module_stack_1.values(), nn_module_stack_2.values()
7864                ):
7865                    self.assertEqual(v1, v2)
7866
7867    def test_duplicated_getitem(self):
7868        class Foo(torch.nn.Module):
7869            def forward(self, x):
7870                return torch.topk(x, 2)
7871
7872        foo = Foo()
7873        inputs = (torch.randn(3),)
7874        ep = torch.export.export(foo, inputs, strict=False)
7875
7876        graph_module = copy.deepcopy(ep.graph_module)
7877
7878        call_function_node = None
7879        num_getitems = 0
7880        for node in graph_module.graph.nodes:
7881            if (
7882                node.op == "call_function"
7883                and node.target == torch.ops.aten.topk.default
7884            ):
7885                call_function_node = node
7886            elif node.op == "call_function" and node.target == operator.getitem:
7887                self.assertIs(node.args[0], call_function_node)
7888                num_getitems += 1
7889
7890        self.assertIsNotNone(call_function_node)
7891        self.assertEqual(num_getitems, 2)
7892
7893        output_node = list(graph_module.graph.nodes)[-1]
7894
7895        nodes = []
7896        with graph_module.graph.inserting_before(output_node):
7897            nodes.append(
7898                graph_module.graph.call_function(
7899                    operator.getitem, (call_function_node, 1)
7900                )
7901            )
7902            nodes.append(
7903                graph_module.graph.call_function(
7904                    operator.getitem, (call_function_node, 0)
7905                )
7906            )
7907            nodes.append(
7908                graph_module.graph.call_function(
7909                    operator.getitem, (call_function_node, 0)
7910                )
7911            )
7912            nodes.append(
7913                graph_module.graph.call_function(
7914                    operator.getitem, (call_function_node, 1)
7915                )
7916            )
7917        signature = ExportGraphSignature(
7918            input_specs=ep.graph_signature.input_specs,
7919            output_specs=ep.graph_signature.output_specs
7920            + [
7921                OutputSpec(
7922                    kind=OutputKind.USER_OUTPUT,
7923                    arg=TensorArgument(name=node.name),
7924                    target=None,
7925                )
7926                for node in nodes
7927            ],
7928        )
7929        output_node.args = (output_node.args[0] + tuple(nodes),)
7930        graph_module.recompile()
7931        new_ep = ep._update(graph_module, signature)
7932
7933        new_num_getitems = 0
7934        for node in new_ep.graph.nodes:
7935            if (
7936                node.op == "call_function"
7937                and node.target == torch.ops.aten.topk.default
7938            ):
7939                call_function_node = node
7940            elif node.op == "call_function" and node.target == operator.getitem:
7941                self.assertIs(node.args[0], call_function_node)
7942                new_num_getitems += 1
7943        self.assertEqual(num_getitems, new_num_getitems)
7944        self.assertEqual(
7945            len(list(new_ep.graph.nodes)[-1].args[0]), len(signature.output_specs)
7946        )
7947
7948
7949@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
7950class TestExportCustomClass(TorchTestCase):
7951    def setUp(self):
7952        if IS_FBCODE:
7953            lib_file_path = "//caffe2/test/cpp/jit:test_custom_class_registrations"
7954        elif IS_SANDCASTLE or IS_MACOS:
7955            raise unittest.SkipTest("non-portable load_library call used in test")
7956        elif IS_WINDOWS:
7957            lib_file_path = find_library_location("torchbind_test.dll")
7958        else:
7959            lib_file_path = find_library_location("libtorchbind_test.so")
7960        torch.ops.load_library(str(lib_file_path))
7961
7962    def test_lift_custom_obj(self):
7963        # TODO: fix this test once custom class tracing is implemented
7964
7965        custom_obj = torch.classes._TorchScriptTesting._PickleTester([3, 4])
7966
7967        class Foo(torch.nn.Module):
7968            def forward(self, x):
7969                return x + x
7970
7971        f = Foo()
7972
7973        inputs = (torch.zeros(4, 4),)
7974        ep = export(f, inputs)
7975
7976        # Replace one of the values with an instance of our custom class
7977        for node in ep.graph.nodes:
7978            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
7979                with ep.graph.inserting_before(node):
7980                    setattr(ep.graph_module, "custom_obj", custom_obj)
7981                    getattr_node = ep.graph.get_attr("custom_obj")
7982                    # Copy over an nn_module_stack as they are required.
7983                    getattr_node.meta["nn_module_stack"] = node.meta["nn_module_stack"]
7984                    custom_node = ep.graph.call_function(
7985                        torch.ops._TorchScriptTesting.take_an_instance.default,
7986                        (getattr_node,),
7987                    )
7988                    custom_node.meta["val"] = torch.ones(4, 4)
7989                    # Copy over an nn_module_stack as they are required.
7990                    custom_node.meta["nn_module_stack"] = node.meta["nn_module_stack"]
7991                    custom_node.meta["torch_fn"] = (
7992                        "custom_op",
7993                        "torch.ops._TorchScriptTesting.take_an_instance.default",
7994                    )
7995                    arg0, _ = node.args
7996                    node.args = (arg0, custom_node)
7997
7998        from torch._export.passes.lift_constants_pass import lift_constants_pass
7999        from torch._export.serde.serialize import deserialize, serialize
8000
8001        constants = lift_constants_pass(ep.graph_module, ep.graph_signature, {})
8002        for k, v in constants.items():
8003            assert k not in ep.constants
8004            ep._constants[k] = v
8005        serialized_vals = serialize(ep)
8006        deserialized_ep = deserialize(serialized_vals)
8007
8008        for node in deserialized_ep.graph.nodes:
8009            if (
8010                node.op == "call_function"
8011                and node.target
8012                == torch.ops._TorchScriptTesting.take_an_instance.default
8013            ):
8014                arg = node.args[0]
8015                self.assertTrue(arg.op == "placeholder")
8016
8017    def test_preserve_non_cia_op(self):
8018        class M(torch.nn.Module):
8019            def forward(self, x):
8020                return torch.nn.functional.elu(x)
8021
8022        ep = export(M(), (torch.randn(2, 3, 4, 5),))
8023        FileCheck().check_count("torch.ops.aten.elu.default", 1, exactly=True).run(
8024            ep.graph_module.code
8025        )
8026
8027        ep = ep.run_decompositions(
8028            decomp_table=get_decompositions([torch.ops.aten.elu.default]),
8029            _preserve_ops=[torch.ops.aten.elu.default],
8030        )
8031        FileCheck().check_count("torch.ops.aten.elu.default", 1, exactly=True).run(
8032            ep.graph_module.code
8033        )
8034
8035    def test_preserve_cia_op(self):
8036        class StaticResizeBilinear2dModule(torch.nn.Module):
8037            def forward(self, x):
8038                a = torch.nn.functional.interpolate(
8039                    x,
8040                    size=(x.shape[2] * 2, x.shape[3] * 3),
8041                    mode="bilinear",
8042                    align_corners=False,
8043                    antialias=False,
8044                )
8045                return a
8046
8047        ep = export(StaticResizeBilinear2dModule(), (torch.randn(2, 3, 4, 5),))
8048        FileCheck().check_count(
8049            "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
8050        ).run(ep.graph_module.code)
8051
8052        decomp_table = get_decompositions([torch.ops.aten.upsample_bilinear2d.vec])
8053        ep = ep.run_decompositions(
8054            decomp_table=decomp_table,
8055            _preserve_ops=[torch.ops.aten.upsample_bilinear2d.vec],
8056        )
8057        assert torch.ops.aten.upsample_bilinear2d.vec in decomp_table
8058        FileCheck().check_count(
8059            "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
8060        ).run(ep.graph_module.code)
8061
8062
8063if __name__ == "__main__":
8064    run_tests()
8065