1"""
2PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
3with test_functionalization_with_native_python_assertion)
4"""
5
6# Owner(s): ["oncall: export"]
7import math
8import operator
9import unittest
10from re import escape
11from typing import List, Set
12
13import torch
14from functorch.experimental.control_flow import cond
15from torch._dynamo.eval_frame import is_dynamo_supported
16from torch._export.non_strict_utils import (
17    _fakify_script_objects,
18    _gather_constant_attrs,
19)
20from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
21from torch._export.passes.replace_set_grad_with_hop_pass import (
22    _is_set_grad_enabled_node,
23    _is_set_grad_enabled_sub_mod,
24)
25from torch._export.passes.replace_view_ops_with_view_copy_ops_pass import (
26    get_view_copy_of_view_op,
27    is_view_op,
28    ReplaceViewOpsWithViewCopyOpsPass,
29)
30from torch._export.utils import (
31    node_inline_,
32    nodes_count,
33    nodes_filter,
34    nodes_map,
35    sequential_split,
36)
37from torch._higher_order_ops.auto_functionalize import auto_functionalized
38from torch._subclasses.fake_tensor import FakeTensorMode
39from torch.export import export
40from torch.export._remove_auto_functionalized_pass import (
41    unsafe_remove_auto_functionalized_pass,
42)
43from torch.export._remove_effect_tokens_pass import _remove_effect_tokens
44from torch.export.passes import move_to_device_pass
45from torch.fx.experimental.symbolic_shapes import ShapeEnv
46from torch.fx.passes.infra.partitioner import Partition
47from torch.fx.passes.operator_support import OperatorSupport
48from torch.library import _scoped_library, impl
49from torch.testing._internal.common_cuda import TEST_CUDA
50from torch.testing._internal.common_utils import (
51    IS_WINDOWS,
52    run_tests,
53    skipIfTorchDynamo,
54    TestCase,
55)
56from torch.testing._internal.torchbind_impls import init_torchbind_implementations
57from torch.utils import _pytree as pytree
58
59
60def count_call_function(graph: torch.fx.Graph, target: torch.ops.OpOverload) -> int:
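    """Count how many call_function nodes in the graph have the given op as target."""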
    count = 0
    for node in graph.nodes:
        if node.op == "call_function" and node.target == target:
            count += 1
    return count


class _AddOperatorSupport(OperatorSupport):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in {operator.add}


class _AtenAddOperatorSupport(OperatorSupport):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in {torch.ops.aten.add.Tensor}


def _to_partition_names(partitions: List[Partition]) -> List[Set[str]]:
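    """Map each fx partition to the set of node names it contains."""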
    return [{n.name for n in p.nodes} for p in partitions]


def _get_output_names(gm: torch.fx.GraphModule) -> List[str]:
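    """Return the stringified pytree leaves of the graph's output node."""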
    output_node = next(n for n in gm.graph.nodes if n.op == "output")
    args = pytree.tree_leaves(output_node.args)
    # if isinstance(args, tuple) and len(args) == 1:
    #     args = args[0]
    return [str(arg) for arg in args]


class ModelsWithScriptObjectAttr:
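    """Modules holding _TorchScriptTesting._Foo script objects as attributes,
    both directly and nested inside pytree containers and submodules."""
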
    class Simple(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

    class SimpleWithAttrInContainer(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
            self.pytree_attr2 = [
                torch.classes._TorchScriptTesting._Foo(1, 2),
                {
                    torch.classes._TorchScriptTesting._Foo(3, 4),
                },
                {"foo": torch.classes._TorchScriptTesting._Foo(5, 6)},
            ]

    class NestedWithAttrInContainer(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
            self.pytree_attr2 = [
                torch.classes._TorchScriptTesting._Foo(1, 2),
                {
                    torch.classes._TorchScriptTesting._Foo(3, 4),
                },
                {"foo": torch.classes._TorchScriptTesting._Foo(5, 6)},
            ]
            self.sub_mod = ModelsWithScriptObjectAttr.Simple()
            self.sub_mod2 = ModelsWithScriptObjectAttr.SimpleWithAttrInContainer()

    class MoreNestedWithAttrInContainer(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
            self.pytree_attr2 = [
                torch.classes._TorchScriptTesting._Foo(1, 2),
                {
                    torch.classes._TorchScriptTesting._Foo(3, 4),
                },
                {"foo": torch.classes._TorchScriptTesting._Foo(5, 6)},
            ]
            self.sub_mod = ModelsWithScriptObjectAttr.Simple()
            self.sub_mod2 = ModelsWithScriptObjectAttr.NestedWithAttrInContainer()


def _set_grad_enabled_tests():
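    """Build inputs for the set_grad -> higher-order-op tests.

    Returns a dict mapping case name to a tuple of
    (eager module, pre-dispatch exported module, example args).
    """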
    from torch.export._trace import _export

    class SetGradOp(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            torch._C._set_grad_enabled(True)
            c = x.sin().sum()
            torch._C._set_grad_enabled(False)
            d = c + 1
            torch._C._set_grad_enabled(True)
            e = d - 1
            return d, e

    class SetGradCtxManager(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.enable_grad():
                c = x.sin().sum()
            with torch.no_grad():
                d = c + 1
            with torch.enable_grad():
                e = d - 1
            return d, e

    class SetGradCtxManagerMultiDep(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.enable_grad():
                c1 = x.sin().sum()
                c2 = x.cos().sum()
            with torch.no_grad():
                d1 = c1 + 1
                d2 = c2 + 1
            with torch.enable_grad():
                e1 = d1 - 1
                e2 = d2 - 1
            return d1, d2, e1, e2

    x = torch.randn(2, 2)

    def _get_predispatch_module(mod, args, ambient_grad_enabled=True):
        with torch.set_grad_enabled(ambient_grad_enabled):
            return _export(mod, args, pre_dispatch=True).module()

    return {
        "ctx_manager": (
            SetGradCtxManager(),
            _get_predispatch_module(SetGradCtxManager(), (x,)),
            (x,),
        ),
        "ctx_manager_under_no_grad": (
            SetGradCtxManager(),
            _get_predispatch_module(SetGradCtxManager(), (x,), False),
            (x,),
        ),
        "ctx_manager_multi_dep": (
            SetGradCtxManagerMultiDep(),
            _get_predispatch_module(SetGradCtxManagerMultiDep(), (x,)),
            (x,),
        ),
        "ctx_manager_multi_dep_no_grad": (
            SetGradCtxManagerMultiDep(),
            _get_predispatch_module(SetGradCtxManagerMultiDep(), (x,), False),
            (x,),
        ),
        "op": (SetGradOp(), _get_predispatch_module(SetGradOp(), (x,)), (x,)),
        "op_under_no_grad": (
            SetGradOp(),
            _get_predispatch_module(SetGradOp(), (x,), False),
            (x,),
        ),
    }


def _with_autocast_tests():
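    """Build inputs for the autocast -> higher-order-op tests.

    Returns a dict mapping case name to a tuple of
    (eager module, pre-dispatch exported module, example args).
    """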
    from torch.export._trace import _export

    class WithAutocastOp(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.autocast(device_type="cpu", enabled=True):
                c = x.sin().sum()
            with torch.autocast(device_type="cpu", enabled=False):
                d = c + 1
            with torch.autocast(device_type="cpu", enabled=True):
                e = d - 1
            return d, e

    class WithAutocastOpMultiDep(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.autocast(device_type="cpu", enabled=True):
                c1 = x.sin().sum()
                c2 = x.cos().sum()
            with torch.autocast(device_type="cpu", enabled=False):
                d1 = c1 + 1
                d2 = c2 + 1
            with torch.autocast(device_type="cpu", enabled=True):
                e1 = d1 - 1
                e2 = d2 - 1
            return d1, d2, e1, e2

    class SplitAutocastOp(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.autocast(device_type="cpu", enabled=True):
                c = x.sin().sum()
            d = c + 1
            with torch.autocast(device_type="cpu", enabled=True):
                e = d - 1
            return d, e

    x = torch.randn(2, 2)

    def _get_predispatch_module(mod, args):
        return _export(mod, args, pre_dispatch=True).module()

    return {
        "ctx_manager": (
            WithAutocastOp(),
            _get_predispatch_module(WithAutocastOp(), (x,)),
            (x,),
        ),
        "ctx_manager_multi_dep": (
            WithAutocastOpMultiDep(),
            _get_predispatch_module(WithAutocastOpMultiDep(), (x,)),
            (x,),
        ),
        "ctx_manager_split": (
            SplitAutocastOp(),
            _get_predispatch_module(SplitAutocastOp(), (x,)),
            (x,),
        ),
    }


def _sequential_split_inline_tests():
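    """Build pre-dispatch modules with manually inserted set_grad_enabled
    delimiter nodes, used to exercise sequential_split and node_inline_.
    """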
    from torch.export._trace import _export

    class Simple(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            c = x.sin().sum()
            d = c + 1
            e = d - 1
            return d, e

    class MultiDep(torch.nn.Module):
        def forward(self, x1, x2):
            x1 = x1 + 1
            x2 = x2 + 1
            c1 = x1.sin()
            c2 = x2.cos()
            d1 = c1 + 1
            d2 = c2 + 1
            e1 = d1 - 1
            e2 = d2 - 1
            return d1, d2, e1, e2

    def _get_predispatch_module(mod, args):
        return _export(mod, args, pre_dispatch=True).module()

    def _insert_delimiter_nodes(gm: torch.fx.GraphModule, step: int = 1):
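        """Insert a torch._C._set_grad_enabled delimiter before every
        `step`-th call_function node, alternating between True and False."""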
        insert_locs = []
        for i, node in enumerate(
            nodes_filter(gm.graph.nodes, lambda n: n.op == "call_function")
        ):
            if i % step == 0:
                insert_locs.append(node)

        for i, node in enumerate(insert_locs):
            with gm.graph.inserting_before(node):
                gm.graph.call_function(torch._C._set_grad_enabled, (i % 2 == 0,), {})
        return gm

    x = torch.randn(2, 2)
    simple = _get_predispatch_module(Simple(), (x,))
    simple1 = _get_predispatch_module(Simple(), (x,))
    multi_dep = _get_predispatch_module(MultiDep(), (x, x.sin()))
    multi_dep1 = _get_predispatch_module(MultiDep(), (x, x.sin()))
    return {
        "simple_step1": (_insert_delimiter_nodes(simple1, 1), (x,)),
        "simple_step2": (_insert_delimiter_nodes(simple, 2), (x,)),
        "multi_dep_step2": (_insert_delimiter_nodes(multi_dep, 2), (x, x.sin())),
        "multi_dep_step3": (_insert_delimiter_nodes(multi_dep1, 3), (x, x.sin())),
    }


@skipIfTorchDynamo("recursively running dynamo on export is unlikely")
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
class TestPasses(TestCase):
    def setUp(self):
        super().setUp()
        self.SEQUENTIAL_SPLIT_INLINE_TESTS = _sequential_split_inline_tests()
        self.SET_GRAD_ENABLED_TESTS = _set_grad_enabled_tests()
        self.WITH_AUTOCAST_TESTS = _with_autocast_tests()

        init_torchbind_implementations()

    def tearDown(self):
        self.SEQUENTIAL_SPLIT_INLINE_TESTS.clear()
        self.SET_GRAD_ENABLED_TESTS.clear()
        self.WITH_AUTOCAST_TESTS.clear()
        super().tearDown()

    def test_runtime_assert_one_dim(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x.cos()

        x = torch.zeros(2, 2, 3)

        dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
        ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}})

        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
        ):
            ep.module()(torch.zeros(2, 7, 3))

        self.assertEqual(
            ep.module()(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3))
        )

    def test_runtime_assert_multiple_dims(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                return x.cos().sum() + y.sin().sum()

        x = torch.zeros(4, 2, 3)
        y = torch.zeros(5, 5, 5)

        dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
        dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y", min=3)

        ep = torch.export.export(
            M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}}
        )

        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
        ):
            ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[1].shape[0] to be >= 3, but got 2"),
        ):
            ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

    def test_runtime_assert_some_dims_not_specified(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                return x.cos().sum() + y.sin().sum()

        x = torch.zeros(4, 2, 3)
        y = torch.zeros(5, 5, 5)

        dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
        dim0_x = torch.export.Dim("dim0_x", min=3)

        ep = torch.export.export(
            M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None}
        )

        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
        ):
            ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

        # y is specialized to 5
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"),
        ):
            ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

        # Since we didn't insert the constraint for x[1] >= 2, it should work for case where x[1] == 1
        gm_result_for_1_size = ep.module()(torch.ones(3, 1, 3), torch.ones(5, 5, 5))
        eager_result_for_1_size = M().forward(torch.ones(3, 1, 3), torch.ones(5, 5, 5))

        self.assertEqual(gm_result_for_1_size, eager_result_for_1_size)

    def test_runtime_assert_some_inps_not_used(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                return y.cos().sum()

        x = torch.zeros(4, 2, 3)
        y = torch.zeros(5, 5, 5)

        dim1_y = torch.export.Dim("dim1_y", min=3, max=6)
        ep = torch.export.export(
            M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}}
        )

        with self.assertRaisesRegex(RuntimeError, escape("shape[1] to be equal to 2")):
            ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

        # y is specialized to 5
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"),
        ):
            ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

        # Inputs matching the traced shapes should run fine and agree with eager
        gm_result = ep.module()(torch.zeros(4, 2, 3), torch.ones(5, 5, 5))
        eager_result = M().forward(torch.zeros(4, 2, 3), torch.ones(5, 5, 5))

        self.assertEqual(gm_result, eager_result)

    def test_view_to_view_copy(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                z = x.view(x.shape)
                return z.cos().sum()

        x = torch.zeros(4, 2, 3)

        ep = export(M(), (x,))
        self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 1)

        ep = ep._transform_do_not_use(ReplaceViewOpsWithViewCopyOpsPass())
        self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 0)

    def test_functionalization_with_view_copy(self) -> None:
        class Module(torch.nn.Module):
            def forward(self, x):
                y = x + 4
                y.add_(4)
                z = y.view(y.shape)
                return x.cos() + z.cos()

        x = torch.zeros(4, 2, 3)
        foo = Module()
        ep = export(foo, (x,))._transform_do_not_use(
            ReplaceViewOpsWithViewCopyOpsPass()
        )
        # After this pass, there shouldn't be any view nodes in the graph
        self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 0)
        self.assertGreater(
            count_call_function(ep.graph, torch.ops.aten.view_copy.default), 0
        )

    def test_views_op_having_view_copy(self) -> None:
        schemas = torch._C._dispatch_get_registrations_for_dispatch_key("")
        aten_schemas = [s[6:] for s in schemas if s.startswith("aten::")]

        for aten_schema in aten_schemas:
            val = aten_schema.split(".")
            assert len(val) <= 2
            name = ""
            overload = ""
            if len(val) == 1:
                name = val[0]
                overload = "default"
            else:
                name, overload = val[0], val[1]

            op_overload = getattr(getattr(torch.ops.aten, name), overload)
            if torch.Tag.core in op_overload.tags and is_view_op(op_overload._schema):
                self.assertIsNotNone(get_view_copy_of_view_op(op_overload._schema))

    def test_custom_obj_tuple_out(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        ep = torch.export.export(m, inputs, strict=False)

        inp = torch.randn(2, 3)
        orig_res = m(inp)
        ep_res = ep.module()(inp)

        without_token_ep = _remove_effect_tokens(ep)
        without_token_ep.verifier().check(without_token_ep)
        without_token_res = without_token_ep.module()(inp)

        self.assertTrue(torch.allclose(orig_res, ep_res))
        self.assertTrue(torch.allclose(orig_res, without_token_res))

    def test_remove_effect_token_kwargs(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(
                    foo=self.attr, x=x
                )
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(foo=self.attr, x=y)
                return b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        ep = torch.export.export(m, inputs, strict=False)
        without_token_ep = _remove_effect_tokens(ep)
        self.assertExpectedInline(
            without_token_ep.graph_module.code.strip(),
            """\
def forward(self, token, obj_attr, x):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, foo = obj_attr, x = x);  token = x = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]
    getitem_2 = with_effects[2];  with_effects = None
    add = torch.ops.aten.add.Tensor(getitem_1, getitem_2);  getitem_1 = getitem_2 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, foo = obj_attr, x = add);  getitem = obj_attr = add = None
    getitem_3 = with_effects_1[0]
    getitem_4 = with_effects_1[1];  with_effects_1 = None
    return (getitem_3, getitem_4)""",  # noqa: B950
        )

    def test_fakify_script_objects(self):
        for m in [
            ModelsWithScriptObjectAttr.Simple(),
            ModelsWithScriptObjectAttr.SimpleWithAttrInContainer(),
            ModelsWithScriptObjectAttr.NestedWithAttrInContainer(),
            ModelsWithScriptObjectAttr.MoreNestedWithAttrInContainer(),
        ]:
            constant_attrs = _gather_constant_attrs(m)
            fake_mode = FakeTensorMode(
                shape_env=ShapeEnv(tracked_fakes=[]),
                allow_non_fake_inputs=True,
            )
            with _fakify_script_objects(m, (), {}, fake_mode) as (
                patched_mod,
                _,
                _,
                fake_constant_attrs,
                fake_to_real,
            ):
                self.assertEqual(len(fake_constant_attrs), len(constant_attrs))
                for fake_obj, fqn in fake_constant_attrs.items():
                    self.assertEqual(constant_attrs[fake_to_real[fake_obj]], fqn)

    # TODO: _gather_constants doesn't recursively look into the pytree containers.
    @unittest.expectedFailure
    def test_fakify_script_objects_properly_handle_containers(self):
        m = ModelsWithScriptObjectAttr.SimpleWithAttrInContainer()
        constant_attrs = _gather_constant_attrs(m)
        fake_mode = FakeTensorMode(
            shape_env=ShapeEnv(tracked_fakes=[]),
            allow_non_fake_inputs=True,
        )
        with _fakify_script_objects(m, (), {}, fake_mode) as (
            patched_mod,
            _,
            _,
            fake_constant_attrs,
            fake_to_real,
        ):
            self.assertTrue("attr" in fake_constant_attrs.values())
            self.assertTrue("pytree_attr2" in fake_constant_attrs.values())

    def test_runtime_assert_inline_constraints_for_item(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                b = x.item()
                torch._check(b >= 2)
                torch._check(b <= 5)
                return b

        x = torch.tensor([2])
        mod = M()
        ep = export(mod, (x,))

        with self.assertRaisesRegex(
            RuntimeError, r"Runtime assertion failed for expression u[\d+] \<\= 5"
        ):
            ep.module()(torch.tensor([6]))

        new_inp = torch.tensor([5])
        self.assertEqual(mod(new_inp), ep.module()(new_inp))

    def test_runtime_assert_inline_constraints_for_nonzero(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                b = x.nonzero()
                torch._check(b.shape[0] >= 3)
                torch._check(b.shape[0] <= 5)
                return b

        x = torch.tensor([2, 1, 2, 3, 5, 0])

        mod = M()
        dim0_x = torch.export.Dim("dim0_x")
        ep = torch.export.export(mod, (x,), dynamic_shapes={"x": {0: dim0_x}})

        num_assert = count_call_function(
            ep.graph, torch.ops.aten._assert_scalar.default
        )
        self.assertEqual(num_assert, 2)
        num_constrain_range = count_call_function(
            ep.graph, torch.ops.aten.sym_constrain_range.default
        )
        self.assertEqual(num_constrain_range, 0)

        with self.assertRaisesRegex(
            RuntimeError,
            r"Runtime assertion failed for expression u[\d+] \>\= 3",
        ):
            ep.module()(torch.tensor([1, 1, 0, 0, 0]))

        with self.assertRaisesRegex(
            RuntimeError,
            r"Runtime assertion failed for expression u[\d+] \<\= 5",
        ):
            ep.module()(torch.ones(6))

        new_inp = torch.tensor([1, 1, 1, 1])
        self.assertEqual(mod(new_inp), ep.module()(new_inp))

    @unittest.skipIf(IS_WINDOWS, "Windows not supported")
    @unittest.expectedFailure
    # TODO(pianpwk): add back runtime asserts to subgraphs
    def test_runtime_assert_inline_constraints_for_cond(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, pred, x, y):
                def true_fn(x, y):
                    b = x.item()
                    torch._check(b >= 2)
                    torch._check(b <= 5)
                    return x - b

                def false_fn(x, y):
                    c = y.item()
                    torch._check(c >= 2)
                    torch._check(c <= 5)
                    return y - c

                ret = cond(pred, true_fn, false_fn, [x, y])
                return ret

        x = torch.tensor([2])
        y = torch.tensor([5])
        mod = M()
        ep = export(mod, (torch.tensor(True), x, y))

        with self.assertRaisesRegex(
            RuntimeError, "is outside of inline constraint \\[2, 5\\]."
        ):
            ep.module()(torch.tensor(False), torch.tensor([6]), torch.tensor([6]))

    def test_math_ops(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                return (
                    torch.tensor([math.ceil(x.item())]),
                    torch.tensor([math.floor(x.item())]),
                )

        func = Module()
        x = torch.randn(1, dtype=torch.float32)
        ep = torch.export.export(func, args=(x,))
        _ExportPassBaseDeprecatedDoNotUse()(ep.graph_module)

    def test_predispatch_set_grad(self):
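        # Pre-dispatch export should wrap regions delimited by _set_grad_enabled
        # calls into torch.ops.higher_order.wrap_with_set_grad_enabled submodules.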
        def _check_node_users_in_the_same_graph(gm):
            for node in gm.graph.nodes:
                for user in node.users:
                    self.assertTrue(user.graph is gm.graph)

        mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS["op"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1);  x = None
    sin = torch.ops.aten.sin.default(add);  add = None
    sum_1 = torch.ops.aten.sum.default(sin);  sin = None
    submod_4 = self.submod_2
    add_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_4, sum_1);  submod_4 = sum_1 = None
    getitem = add_1[0];  add_1 = None
    sub = torch.ops.aten.sub.Tensor(getitem, 1)
    return pytree.tree_unflatten((getitem, sub), self._out_spec)
    """,
        )

        mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS["op_under_no_grad"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1);  x = None
    sin = torch.ops.aten.sin.default(add);  add = None
    sum_1 = torch.ops.aten.sum.default(sin);  sin = None
    submod_4 = self.submod_2
    add_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_4, sum_1);  submod_4 = sum_1 = None
    getitem = add_1[0];  add_1 = None
    sub = torch.ops.aten.sub.Tensor(getitem, 1)
    return pytree.tree_unflatten((getitem, sub), self._out_spec)
    """,
        )

        mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1);  x = None
    sin = torch.ops.aten.sin.default(add);  add = None
    sum_1 = torch.ops.aten.sum.default(sin);  sin = None
    submod_3 = self.submod_1
    add_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, sum_1);  submod_3 = sum_1 = None
    getitem = add_1[0];  add_1 = None
    sub = torch.ops.aten.sub.Tensor(getitem, 1)
    return pytree.tree_unflatten((getitem, sub), self._out_spec)
    """,
        )

        mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager_under_no_grad"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1);  x = None
    submod_5 = self.submod_1
    sum_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add);  submod_5 = add = None
    getitem = sum_1[0];  sum_1 = None
    add_1 = torch.ops.aten.add.Tensor(getitem, 1);  getitem = None
    submod_6 = self.submod_3
    sub = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_6, add_1);  submod_6 = None
    getitem_1 = sub[0];  sub = None
    return pytree.tree_unflatten((add_1, getitem_1), self._out_spec)
    """,
        )

        mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager_multi_dep"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1);  x = None
    sin = torch.ops.aten.sin.default(add)
    sum_1 = torch.ops.aten.sum.default(sin);  sin = None
    cos = torch.ops.aten.cos.default(add);  add = None
    sum_2 = torch.ops.aten.sum.default(cos);  cos = None
    submod_3 = self.submod_1
    wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, sum_1, sum_2);  submod_3 = sum_1 = sum_2 = None
    add_1 = wrap_with_set_grad_enabled[0]
    add_2 = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None
    sub = torch.ops.aten.sub.Tensor(add_1, 1)
    sub_1 = torch.ops.aten.sub.Tensor(add_2, 1)
    return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec)
    """,  # noqa: B950
        )

        mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS[
            "ctx_manager_multi_dep_no_grad"
        ]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1);  x = None
    submod_5 = self.submod_1
    wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add);  submod_5 = add = None
    sum_1 = wrap_with_set_grad_enabled[0]
    sum_2 = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1);  sum_1 = None
    add_2 = torch.ops.aten.add.Tensor(sum_2, 1);  sum_2 = None
    submod_6 = self.submod_3
    wrap_with_set_grad_enabled_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_6, add_1, add_2);  submod_6 = None
    sub = wrap_with_set_grad_enabled_1[0]
    sub_1 = wrap_with_set_grad_enabled_1[1];  wrap_with_set_grad_enabled_1 = None
    return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec)
    """,  # noqa: B950
        )

    def test_sequential_split(self):
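        # Splitting on set_grad nodes should produce one submodule per
        # delimiter and must not change the module's behavior.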
        for gm, args in self.SEQUENTIAL_SPLIT_INLINE_TESTS.values():
            set_grad_counts = nodes_count(gm.graph.nodes, _is_set_grad_enabled_node)
            new_gm = sequential_split(gm, _is_set_grad_enabled_node)
            new_set_grad_counts = nodes_count(
                new_gm.graph.nodes, _is_set_grad_enabled_sub_mod
            )
            self.assertEqual(set_grad_counts, new_set_grad_counts)
            self.assertEqual(gm(*args), new_gm(*args))

    def test_sequential_split_graph(self):
        gm, args = self.SEQUENTIAL_SPLIT_INLINE_TESTS["multi_dep_step2"]

        new_gm = sequential_split(gm, _is_set_grad_enabled_node)
        self.assertEqual(gm(*args), new_gm(*args))
        self.assertExpectedInline(
            new_gm.code.strip("\n"),
            """\
def forward(self, x1, x2):
    x1, x2, = fx_pytree.tree_flatten_spec(([x1, x2], {}), self._in_spec)
    submod_1 = self.submod_1(x1, x2);  x1 = x2 = None
    getitem = submod_1[0]
    getitem_1 = submod_1[1];  submod_1 = None
    submod_2 = self.submod_2(getitem, getitem_1);  getitem = getitem_1 = None
    getitem_2 = submod_2[0]
    getitem_3 = submod_2[1];  submod_2 = None
    submod_3 = self.submod_3(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
    getitem_4 = submod_3[0]
    getitem_5 = submod_3[1];  submod_3 = None
    submod_4 = self.submod_4(getitem_4, getitem_5)
    getitem_6 = submod_4[0]
    getitem_7 = submod_4[1];  submod_4 = None
    return pytree.tree_unflatten((getitem_4, getitem_5, getitem_6, getitem_7), self._out_spec)
    """,
        )
        self.assertExpectedInline(
            new_gm.submod_1.code.strip("\n"),
            """\
def forward(self, x1, x2):
    _set_grad_enabled = torch._C._set_grad_enabled(True);  _set_grad_enabled = None
    add = torch.ops.aten.add.Tensor(x1, 1);  x1 = None
    add_1 = torch.ops.aten.add.Tensor(x2, 1);  x2 = None
    return (add, add_1)
    """,
        )
        self.assertExpectedInline(
            new_gm.submod_2.code.strip("\n"),
            """\
def forward(self, add, add_1):
    _set_grad_enabled_1 = torch._C._set_grad_enabled(False);  _set_grad_enabled_1 = None
    sin = torch.ops.aten.sin.default(add);  add = None
    cos = torch.ops.aten.cos.default(add_1);  add_1 = None
    return (sin, cos)
    """,
        )
        self.assertExpectedInline(
            new_gm.submod_3.code.strip("\n"),
            """\
def forward(self, sin, cos):
    _set_grad_enabled_2 = torch._C._set_grad_enabled(True);  _set_grad_enabled_2 = None
    add_2 = torch.ops.aten.add.Tensor(sin, 1);  sin = None
    add_3 = torch.ops.aten.add.Tensor(cos, 1);  cos = None
    return (add_2, add_3)
    """,
        )

    def test_predispatch_autocast(self):
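        # Mirrors test_predispatch_set_grad: autocast regions should become
        # torch.ops.higher_order.wrap_with_autocast submodule calls.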
        def _check_node_users_in_the_same_graph(gm):
            for node in gm.graph.nodes:
                for user in node.users:
                    self.assertTrue(user.graph is gm.graph)

        mod_orig, mod, args = self.WITH_AUTOCAST_TESTS["ctx_manager"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1);  x = None
    submod_4 = self.submod_1
    sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add);  submod_4 = add = None
    getitem = sum_1[0];  sum_1 = None
    submod_5 = self.submod_2
    add_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, False, None, submod_5, getitem);  submod_5 = getitem = None
    getitem_1 = add_1[0];  add_1 = None
    submod_6 = self.submod_3
    sub = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_6, getitem_1);  submod_6 = None
    getitem_2 = sub[0];  sub = None
    return pytree.tree_unflatten((getitem_1, getitem_2), self._out_spec)
    """,
        )

        self.assertExpectedInline(
            mod.submod_1.code.strip("\n"),
            """\
def forward(self, add):
    sin = torch.ops.aten.sin.default(add);  add = None
    sum_1 = torch.ops.aten.sum.default(sin);  sin = None
    return (sum_1,)
    """,
        )

        self.assertExpectedInline(
            mod.submod_2.code.strip("\n"),
            """\
def forward(self, sum_1):
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1);  sum_1 = None
    return (add_1,)
    """,
        )

        self.assertExpectedInline(
            mod.submod_3.code.strip("\n"),
            """\
def forward(self, add_1):
    sub = torch.ops.aten.sub.Tensor(add_1, 1);  add_1 = None
    return (sub,)
    """,
        )

        mod_orig, mod, args = self.WITH_AUTOCAST_TESTS["ctx_manager_multi_dep"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1);  x = None
    submod_4 = self.submod_1
    wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add);  submod_4 = add = None
    sum_1 = wrap_with_autocast[0]
    sum_2 = wrap_with_autocast[1];  wrap_with_autocast = None
    submod_5 = self.submod_2
    wrap_with_autocast_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, False, None, submod_5, sum_1, sum_2);  submod_5 = sum_1 = sum_2 = None
    add_1 = wrap_with_autocast_1[0]
    add_2 = wrap_with_autocast_1[1];  wrap_with_autocast_1 = None
    submod_6 = self.submod_3
    wrap_with_autocast_2 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_6, add_1, add_2);  submod_6 = None
    sub = wrap_with_autocast_2[0]
    sub_1 = wrap_with_autocast_2[1];  wrap_with_autocast_2 = None
    return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec)
    """,  # noqa: B950
        )

        self.assertExpectedInline(
            mod.submod_1.code.strip("\n"),
            """\
def forward(self, add):
    sin = torch.ops.aten.sin.default(add)
    sum_1 = torch.ops.aten.sum.default(sin);  sin = None
    cos = torch.ops.aten.cos.default(add);  add = None
    sum_2 = torch.ops.aten.sum.default(cos);  cos = None
    return (sum_1, sum_2)
    """,
        )

        self.assertExpectedInline(
            mod.submod_2.code.strip("\n"),
            """\
def forward(self, sum_1, sum_2):
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1);  sum_1 = None
    add_2 = torch.ops.aten.add.Tensor(sum_2, 1);  sum_2 = None
    return (add_1, add_2)
    """,
        )

        self.assertExpectedInline(
            mod.submod_3.code.strip("\n"),
            """\
def forward(self, add_1, add_2):
    sub = torch.ops.aten.sub.Tensor(add_1, 1);  add_1 = None
    sub_1 = torch.ops.aten.sub.Tensor(add_2, 1);  add_2 = None
    return (sub, sub_1)
    """,
        )

        mod_orig, mod, args = self.WITH_AUTOCAST_TESTS["ctx_manager_split"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1);  x = None
    submod_4 = self.submod_1
    sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add);  submod_4 = add = None
    getitem = sum_1[0];  sum_1 = None
    add_1 = torch.ops.aten.add.Tensor(getitem, 1);  getitem = None
    submod_5 = self.submod_3
    sub = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_5, add_1);  submod_5 = None
    getitem_1 = sub[0];  sub = None
    return pytree.tree_unflatten((add_1, getitem_1), self._out_spec)
    """,
        )

        self.assertExpectedInline(
            mod.submod_1.code.strip("\n"),
            """\
def forward(self, add):
    sin = torch.ops.aten.sin.default(add);  add = None
    sum_1 = torch.ops.aten.sum.default(sin);  sin = None
    return (sum_1,)
    """,
        )

        self.assertExpectedInline(
            mod.submod_3.code.strip("\n"),
            """\
def forward(self, add_1):
    sub = torch.ops.aten.sub.Tensor(add_1, 1);  add_1 = None
    return (sub,)
    """,
        )

    def test_inline_(self):
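        # sequential_split followed by inlining every call_module node should
        # round-trip back to the original readable graph.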
        for gm, args in self.SEQUENTIAL_SPLIT_INLINE_TESTS.values():
            before_str = gm.print_readable(print_output=False)
            new_gm = sequential_split(gm, _is_set_grad_enabled_node)
            nodes_map(
                new_gm.graph.nodes,
                lambda node: node_inline_(node) if node.op == "call_module" else node,
            )
            after_inline_str = new_gm.print_readable(print_output=False)
            self.assertEqual(before_str, after_inline_str)
            self.assertEqual(gm(*args), new_gm(*args))

    def test_remove_auto_functionalized_pass(self) -> None:
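        # After unsafe_remove_auto_functionalized_pass, auto_functionalized
        # nodes (and their getitem projections) should be replaced by the
        # underlying mutating op.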
        with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib:
            lib.define("custom_mutator(Tensor x, Tensor(a!) y) -> Tensor")

            @impl(lib, "custom_mutator", "Meta")
            def custom_mutator_meta(
                x: torch.Tensor,
                y: torch.Tensor,
            ) -> torch.Tensor:
                return torch.empty_like(x)

            @impl(lib, "custom_mutator", "CompositeExplicitAutograd")
            def custom_mutator(
                x: torch.Tensor,
                y: torch.Tensor,
            ) -> torch.Tensor:
                return x + y.add_(1)

            class M(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.state = torch.nn.Buffer(torch.zeros(1))

                def forward(self, x):
                    return torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator(x, self.state)

            mod = M()
            x = torch.randn([3, 3])
            ep = export(mod, (x,))
            inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
            nodes = inplace_ep.graph.nodes
            for node in nodes:
                if node.op == "call_function":
                    self.assertFalse(node.target is auto_functionalized)
                    self.assertFalse(node.target is operator.getitem)

            for spec in inplace_ep.graph_signature.output_specs:
                self.assertFalse("getitem" in spec.arg.name)

    def test_remove_auto_functionalized_pass_tuple(self) -> None:
        with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib:
            lib.define(
                "custom_mutator_tuple(Tensor x, Tensor(a!) y) -> (Tensor, Tensor)"
            )

            @impl(lib, "custom_mutator_tuple", "Meta")
            def custom_mutator_tuple_meta(
                x: torch.Tensor,
                y: torch.Tensor,
            ):
                return (torch.empty_like(x), torch.empty_like(x))

            @impl(lib, "custom_mutator_tuple", "CompositeExplicitAutograd")
            def custom_mutator_tuple(
                x: torch.Tensor,
                y: torch.Tensor,
            ):
                return (x, x + y.add_(1))

            class M(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.state = torch.nn.Buffer(torch.zeros(1))

                def forward(self, x):
                    return torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator_tuple(
                        x, self.state
                    )

            mod = M()
            x = torch.randn([3, 3])
            ep = export(mod, (x,))
            inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
            graph_text = str(inplace_ep.graph)
            self.assertExpectedInline(
                graph_text,
                """\
graph():
    %b_state : [num_users=2] = placeholder[target=b_state]
    %x : [num_users=1] = placeholder[target=x]
    %custom_mutator_tuple_default : [num_users=2] = call_function[target=torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator_tuple.\
default](args = (%x, %b_state), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 0), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 1), kwargs = {})
    return (b_state, getitem_3, getitem_4)""",
            )

    @unittest.skipIf(not TEST_CUDA, "requires cuda")
    def test_move_to_device_pass(self):
        class Model(torch.nn.Module):
            def __init__(self, size=4, h_dim=10):
                super().__init__()
                self.rnn = torch.nn.GRU(size, h_dim, batch_first=True)

            def forward(self, x):
                _, states = self.rnn(x)
                return states

        # move the exported program from cpu to cuda:0
        mod = Model()
        example_inputs = (torch.rand(1, 10, 4),)
        ep = export(mod, example_inputs)
        location = torch.device("cuda:0")
        ep = move_to_device_pass(ep, location=location)
        gm = ep.module()
        test_inputs = (torch.rand(1, 10, 4).to("cuda:0"),)
        outputs = gm(*test_inputs)
        self.assertEqual(outputs.device, torch.device("cuda:0"))
        # move it back to cpu
        location = "cpu"
        ep = move_to_device_pass(ep, location=location)
        gm = ep.module()
        test_inputs = (torch.rand(1, 10, 4).to("cpu"),)
        outputs = gm(*test_inputs)
        self.assertEqual(outputs.device, torch.device("cpu"))
        # move it to cuda:0 again
        location = {"cpu": "cuda:0"}
        ep = move_to_device_pass(ep, location=location)
        gm = ep.module()
        test_inputs = (torch.rand(1, 10, 4).to("cuda:0"),)
        outputs = gm(*test_inputs)
        self.assertEqual(outputs.device, torch.device("cuda:0"))


if __name__ == "__main__":
    run_tests()