xref: /aosp_15_r20/external/pytorch/test/export/test_serialize.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
2PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
3with test_sym_bool)
4"""
5
6
7# Owner(s): ["oncall: export"]
8import copy
9import io
10import tempfile
11import unittest
12import zipfile
13from pathlib import Path
14
15import torch
16import torch._dynamo as torchdynamo
17import torch.export._trace
18import torch.utils._pytree as pytree
19from torch._export.db.case import ExportCase, SupportLevel
20from torch._export.db.examples import all_examples
21from torch._export.serde.serialize import (
22    canonicalize,
23    deserialize,
24    ExportedProgramDeserializer,
25    ExportedProgramSerializer,
26    serialize,
27    SerializeError,
28)
29from torch._higher_order_ops.torchbind import enable_torchbind_tracing
30from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
31from torch.export import Dim, export, load, save
32from torch.fx.experimental.symbolic_shapes import is_concrete_int, ValueRanges
33from torch.testing._internal.common_utils import (
34    instantiate_parametrized_tests,
35    IS_WINDOWS,
36    parametrize,
37    run_tests,
38    TemporaryFileName,
39    TestCase,
40)
41from torch.testing._internal.torchbind_impls import init_torchbind_implementations
42
43
44def get_filtered_export_db_tests():
45    return [
46        (name, case)
47        for name, case in all_examples().items()
48        if case.support_level == SupportLevel.SUPPORTED
49    ]
50
51
52@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
53class TestSerialize(TestCase):
54    def test_export_with_extension_op_serialization(self):
55        class TestModule(torch.nn.Module):
56            def forward(self, x):
57                return x + x
58
59        class FooExtensionOp:
60            def __hash__(self):
61                return 0
62
63            def __eq__(self, other):
64                return type(other) == type(self)
65
66            def __call__(self, *args, **kwargs):
67                return torch.ops.aten.add.Tensor(*args, **kwargs)
68
69            @property
70            def __name__(self):
71                return "foo.my_op"
72
73        class ExtensionVerifier(torch._export.verifier.Verifier):
74            dialect = "FOO"
75
76            def allowed_op_types(self):
77                return super().allowed_op_types() + (FooExtensionOp,)
78
79        class FooExtensionHandler(torch._export.serde.serialize.ExtensionHandler):
80            @classmethod
81            def namespace(cls):
82                return "foo"
83
84            @classmethod
85            def to_op_name(cls, op):
86                return "my_op"
87
88            @classmethod
89            def from_op_name(cls, name: str):
90                self.assertEqual(name, "my_op")
91                return FooExtensionOp()
92
93            @classmethod
94            def op_schema(cls, op):
95                return torch.ops.aten.add.Tensor._schema
96
97        inp = (torch.ones(10),)
98        ep = export(TestModule(), inp)
99
100        # Register the custom op handler.
101        foo_custom_op = FooExtensionOp()
102        torch._export.serde.serialize.register_extension(
103            FooExtensionOp, FooExtensionHandler
104        )
105
106        new_gm = copy.deepcopy(ep.graph_module)
107        # Inject the custom operator.
108        for node in new_gm.graph.nodes:
109            if node.name == "add":
110                node.target = foo_custom_op
111
112        new_ep = ep._update(new_gm, ep.graph_signature, verifiers=[ExtensionVerifier])
113        serialized = serialize(new_ep)
114        deserialized = deserialize(serialized)
115        self.assertEqual(
116            len(
117                deserialized.graph.find_nodes(op="call_function", target=foo_custom_op)
118            ),
119            1,
120        )
121
122    def test_predispatch_export_with_autograd_op(self):
123        class Foo(torch.nn.Module):
124            def __init__(self) -> None:
125                super().__init__()
126
127            def forward(self, x):
128                with torch.enable_grad():
129                    return x + x
130
131        inp = (torch.ones(10),)
132        with torch.no_grad():
133            from torch.export._trace import _export
134
135            ep = _export(Foo(), inp, pre_dispatch=True)
136
137        buffer = io.BytesIO()
138        torch.export.save(ep, buffer)
139        buffer.seek(0)
140        loaded_ep = torch.export.load(buffer)
141
142        exp_out = ep.module()(*inp)
143        actual_out = loaded_ep.module()(*inp)
144        self.assertEqual(exp_out, actual_out)
145        self.assertEqual(exp_out.requires_grad, actual_out.requires_grad)
146
147    def test_export_example_inputs_preserved(self):
148        class MyModule(torch.nn.Module):
149            """A test module with that has multiple args and uses kwargs"""
150
151            def __init__(self) -> None:
152                super().__init__()
153                self.p = torch.nn.Parameter(torch.ones(2, 3))
154
155            def forward(self, x, y, use_p=False):
156                out = x + y
157                if use_p:
158                    out += self.p
159                return out
160
161        model = MyModule().eval()
162        random_inputs = (torch.rand([2, 3]), torch.rand([2, 3]))
163        exp_program = torch.export.export(model, random_inputs, {"use_p": True})
164
165        output_buffer = io.BytesIO()
166        # Tests that example inputs are preserved when saving and loading module.
167        torch.export.save(exp_program, output_buffer)
168        loaded_model = torch.export.load(output_buffer)
169        # Extract the example inputs from before and after saving.
170        orig_args, orig_kwargs = exp_program.example_inputs
171        loaded_args, loaded_kwargs = loaded_model.example_inputs
172        # Run both modules and confirm that outputs match.
173        orig_out = exp_program.module()(*orig_args, **orig_kwargs)
174        loaded_out = loaded_model.module()(*loaded_args, **loaded_kwargs)
175        self.assertEqual(orig_out, loaded_out)
176
177    def test_metadata_parsing_with_layer_split(self):
178        # Tests that modules with more complicated layer patterns can be serialized
179        # and deserialized correctly.
180        class MyModule(torch.nn.Module):
181            def __init__(self) -> None:
182                super().__init__()
183                self.layers = torch.nn.Sequential(
184                    torch.nn.SiLU(),
185                    torch.nn.SiLU(),
186                    torch.nn.SiLU(),
187                )
188
189            def forward(self, x):
190                # Splitting layers of a sequential stack introduces commas and parens
191                # into metadata trace.
192                out_start, out_rest = self.layers[0], self.layers[1:]
193                h = out_start(x)
194                h = out_rest(h)
195                return h
196
197        inp = (torch.ones(10),)
198        # Module will only be able to roundtrip if metadata
199        # can be correctly parsed.
200        ep = export(MyModule(), inp)
201        buffer = io.BytesIO()
202        save(ep, buffer)
203        loaded_ep = load(buffer)
204
205        # Check that both modules run to confirm load was successful.
206        exp_out = ep.module()(*inp)
207        actual_out = loaded_ep.module()(*inp)
208        self.assertEqual(exp_out, actual_out)
209
210    def test_serialize_constant_outputs(self):
211        class MyModule(torch.nn.Module):
212            def __init__(self) -> None:
213                super().__init__()
214
215            def forward(self, x):
216                # Along with tensor output, return Nonetype
217                # and constant. Although these outputs aren't
218                # very useful, they do show up in graphs.
219                return x + 1, None, 1024
220
221        # Check that module can be roundtripped, thereby confirming proper deserialization.
222        inp = (torch.ones(10),)
223        ep = export(MyModule(), inp)
224        buffer = io.BytesIO()
225        save(ep, buffer)
226        loaded_ep = load(buffer)
227
228        exp_out = ep.module()(*inp)
229        actual_out = loaded_ep.module()(*inp)
230        self.assertEqual(exp_out, actual_out)
231
232    def test_serialize_multiple_returns_from_node(self) -> None:
233        class MyModule(torch.nn.Module):
234            def __init__(self) -> None:
235                super().__init__()
236
237            def forward(self, x, w, b):
238                return torch.nn.functional.layer_norm(
239                    x,
240                    x.size()[1:],
241                    weight=w,
242                    bias=b,
243                    eps=1e-5,
244                )
245
246        exported_module = export(
247            MyModule(),
248            (
249                torch.ones([512, 512], requires_grad=True),
250                torch.ones([512]),
251                torch.ones([512]),
252            ),
253        ).run_decompositions()
254
255        serialized = ExportedProgramSerializer().serialize(exported_module)
256        node = serialized.exported_program.graph_module.graph.nodes[-1]
257        self.assertEqual(node.target, "torch.ops.aten.native_layer_norm.default")
258        # aten::native_layer_norm returns 3 tensors
259        self.assertEqual(len(node.outputs), 3)
260
261        # check the names are unique
262        seen = set()
263        for output in node.outputs:
264            name = output.as_tensor.name
265            self.assertNotIn(name, seen)
266            seen.add(name)
267
268    def test_serialize_sym_int(self) -> None:
269        class DynamicShapeSimpleModel(torch.nn.Module):
270            def __init__(self):
271                super().__init__()
272
273            def forward(self, a, b, c) -> torch.Tensor:
274                d = (torch.matmul(a, b) + c) / 2
275                d_s0 = d.shape[0]
276                d_s1 = d.shape[1]
277                d_s3 = d_s0 * d_s1
278                e = d.view(d_s3)
279                return torch.cat([e, e])
280
281        inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
282        dim0_ac = torch.export.Dim("dim0_ac")
283        dim1_bc = torch.export.Dim("dim1_b")
284        dynamic_shapes = {
285            "a": {0: dim0_ac},
286            "b": {1: dim1_bc},
287            "c": {0: dim0_ac, 1: dim1_bc},
288        }
289        exported_module = export(
290            DynamicShapeSimpleModel(), inputs, dynamic_shapes=dynamic_shapes
291        ).run_decompositions()
292        serialized = ExportedProgramSerializer().serialize(exported_module)
293        sym_size_nodes = [
294            node
295            for node in serialized.exported_program.graph_module.graph.nodes
296            if node.target == "torch.ops.aten.sym_size.int"
297        ]
298        for node in sym_size_nodes:
299            self.assertEqual(node.inputs[0].name, "self")
300            self.assertEqual(node.inputs[1].name, "dim")
301
302    def test_serialize_list_returns(self) -> None:
303        class MyModule(torch.nn.Module):
304            def __init__(self) -> None:
305                super().__init__()
306
307            def forward(self, x):
308                return torch.split(x, 2)
309
310        input = torch.arange(10.0).reshape(5, 2)
311        exported_module = export(MyModule(), (input,)).run_decompositions()
312
313        serialized = ExportedProgramSerializer().serialize(exported_module)
314        node = serialized.exported_program.graph_module.graph.nodes[-1]
315        # split.Tensor gets decomposed to split_with_sizes by the core ATen decomposition table
316        self.assertEqual(node.target, "torch.ops.aten.split_with_sizes.default")
317        self.assertEqual(len(node.outputs), 1)
318        # Input looks like:
319        # tensor([[0, 1],
320        #         [2, 3],
321        #         [4, 5],
322        #         [6, 7],
323        #         [8, 9]])
324        # Output looks like:
325        # (tensor([[0, 1],
326        #          [2, 3]]),
327        #  tensor([[4, 5],
328        #          [6, 7]]),
329        #  tensor([[8, 9]]))
330        self.assertEqual(len(node.outputs[0].as_tensors), 3)
331
332        # check the names are unique
333        seen = set()
334        for output in node.outputs[0].as_tensors:
335            name = output.name
336            self.assertNotIn(name, seen)
337            seen.add(name)
338
339    def test_multi_return_some_unused(self) -> None:
340        """
341        Make sure the serialized output matches the op schema, even if some of
342        the arguments are never used in the graph.
343        """
344
345        class MyModule(torch.nn.Module):
346            def __init__(self) -> None:
347                super().__init__()
348
349            def forward(self, x):
350                return torch.ops.aten.var_mean.correction(x, [1])[0]
351
352        exported_module = export(
353            MyModule(),
354            (torch.ones([512, 512], requires_grad=True),),
355        ).run_decompositions()
356
357        serialized = ExportedProgramSerializer().serialize(exported_module)
358        node = serialized.exported_program.graph_module.graph.nodes[-1]
359        self.assertEqual(node.target, "torch.ops.aten.var_mean.correction")
360        self.assertEqual(len(node.outputs), 2)
361
362        # check the names are unique
363        seen = set()
364        for output in node.outputs:
365            name = output.as_tensor.name
366            self.assertNotIn(name, seen)
367            seen.add(name)
368
369    def test_rational_ranges(self) -> None:
370        class M(torch.nn.Module):
371            def forward(self, x):
372                return x + x
373
374        ep = torch.export.export(
375            M(), (torch.randn(4),), dynamic_shapes=({0: Dim("temp")},)
376        )
377
378        range_constraints = list(ep.range_constraints.keys())
379        assert len(range_constraints) == 1
380        symint = range_constraints[0]
381
382        import sympy
383
384        upper_range = sympy.Rational(10, 3)
385        lower_range = sympy.Rational(10, 6)
386        ep.range_constraints[symint] = ValueRanges(lower=lower_range, upper=upper_range)
387
388        serialized = ExportedProgramSerializer().serialize(ep)
389        self.assertEqual(serialized.exported_program.range_constraints["s0"].min_val, 2)
390        self.assertEqual(serialized.exported_program.range_constraints["s0"].max_val, 3)
391
392    def test_kwargs_default(self) -> None:
393        """
394        Tests that the kwargs default values are serialized even if they are not
395        specified
396        """
397
398        class Foo(torch.nn.Module):
399            def forward(self, x: torch.Tensor) -> torch.Tensor:
400                values = torch.randn(3, 2)
401                return torch.searchsorted(x, values, side="right", right=True)
402
403        f = Foo()
404
405        x, _ = torch.sort(torch.randn(3, 4))
406        exported_module = export(f, (x,)).run_decompositions()
407        serialized = ExportedProgramSerializer().serialize(exported_module)
408
409        node = serialized.exported_program.graph_module.graph.nodes[-1]
410        self.assertEqual(node.target, "torch.ops.aten.searchsorted.Tensor")
411        self.assertEqual(len(node.inputs), 4)
412        self.assertEqual(node.inputs[2].name, "right")
413        self.assertEqual(node.inputs[2].arg.as_bool, True)
414        self.assertEqual(node.inputs[3].name, "side")
415        self.assertEqual(node.inputs[3].arg.as_string, "right")
416
417    def test_canonicalize(self) -> None:
418        class Module(torch.nn.Module):
419            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
420                a = y + x
421                b = x + y
422                return b + a
423
424        ep = torch.export.export(Module(), (torch.randn(3, 2), torch.randn(3, 2)))
425        s = ExportedProgramSerializer().serialize(ep)
426        c = canonicalize(s.exported_program)
427        g = c.graph_module.graph
428        self.assertLess(
429            g.nodes[0].inputs[0].arg.as_tensor.name,
430            g.nodes[1].inputs[0].arg.as_tensor.name,
431        )
432
433    def test_int_list(self) -> None:
434        class M(torch.nn.Module):
435            def forward(self, x):
436                return torch.ops.aten.sum.dim_IntList(x, [])
437
438        ep = torch.export.export(M(), (torch.randn(3, 2),))
439        serialized = ExportedProgramSerializer().serialize(ep)
440        for node in serialized.exported_program.graph_module.graph.nodes:
441            if "aten.sum.dim_IntList" in node.target:
442                self.assertEqual(node.inputs[1].arg.type, "as_ints")
443
444
445@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
446@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
447class TestDeserialize(TestCase):
    def setUp(self):
        """Run the base-class setup, then call ``init_torchbind_implementations``."""
        super().setUp()
        init_torchbind_implementations()
451
452    def _check_graph_nodes(self, gm1, gm2, _check_meta=True):
453        # TODO: The _check_meta flag bypasses checking for
454        # source_fn/nn_module_stack as there is an issue with
455        # roundtripping the source_fn value on torch.ops.map nodes
456        # original source_fn: <functorch.experimental._map.MapWrapper object at 0x7f80a0549930>
457        # deserialized source_fn: 'functorch.experimental._map.map'
458
459        self.assertEqual(len(gm1.graph.nodes), len(gm2.graph.nodes))
460
461        for node1, node2 in zip(gm1.graph.nodes, gm2.graph.nodes):
462            self.assertEqual(node1.op, node2.op)
463            if node1.op == "call_function":
464                # Check "val" metadata
465                val1 = node1.meta.get("val", None)
466                val2 = node2.meta.get("val", None)
467                if val1 is None or val2 is None:
468                    # Either both are None
469                    self.assertEqual(val1, val2)
470                elif isinstance(val1, FakeTensor) and isinstance(val2, FakeTensor):
471                    # Or both are fake tensors with the same shape/dtype
472                    self.assertEqual(len(val1.shape), len(val2.shape))
473                    for s1, s2 in zip(val1.shape, val2.shape):
474                        if is_concrete_int(s1) and is_concrete_int(s2):
475                            self.assertEqual(s1, s2)
476                        else:
477                            self.assertEqual(str(s1), str(s2))
478                    self.assertEqual(val1.dtype, val2.dtype)
479                elif isinstance(val1, (list, tuple)) and isinstance(
480                    val2, (list, tuple)
481                ):
482                    # Or both are fake tensors lists with one element and with the
483                    # same shape/dtype
484                    for v1, v2 in zip(
485                        pytree.tree_leaves(val1), pytree.tree_leaves(val2)
486                    ):
487                        if isinstance(v1, FakeTensor):
488                            self.assertEqual(v1.shape, v2.shape)
489                            self.assertEqual(v1.dtype, v2.dtype)
490                else:
491                    # For expressions like 's0 < 10' can only compare through string
492                    self.assertEqual(str(val1), str(val2))
493
494                # Check "stack_trace" metadata
495                self.assertEqual(
496                    node1.meta.get("stack_trace", None),
497                    node2.meta.get("stack_trace", None),
498                )
499
500                if node1.target == torch.ops.higher_order.cond:
501                    true_graph1 = getattr(gm1, node1.args[1].target)
502                    true_graph2 = getattr(gm2, node2.args[1].target)
503                    self._check_graph_nodes(true_graph1, true_graph2)
504
505                    false_graph1 = getattr(gm1, node1.args[2].target)
506                    false_graph2 = getattr(gm2, node2.args[2].target)
507                    self._check_graph_nodes(false_graph1, false_graph2)
508                elif node1.target == torch.ops.higher_order.map_impl:
509                    map_graph1 = getattr(gm1, node1.args[0].target)
510                    map_graph2 = getattr(gm2, node2.args[0].target)
511                    self._check_graph_nodes(map_graph1, map_graph2, False)
512
513            if _check_meta and node1.op not in ("get_attr", "placeholder", "output"):
514                # Check "nn_module_stack" metadata
515                self.assertEqual(
516                    node1.meta.get("nn_module_stack", None),
517                    node2.meta.get("nn_module_stack", None),
518                )
519                # Check "source_fn_stack" metadata
520                self.assertEqual(
521                    node1.meta.get("source_fn_stack", None),
522                    node2.meta.get("source_fn_stack", None),
523                )
524
525    def check_graph(
526        self,
527        fn,
528        inputs,
529        dynamic_shapes=None,
530        _check_meta=True,
531        use_pre_dispatch=True,
532        strict=True,
533    ) -> None:
534        """Export a graph, serialize it, deserialize it, and compare the results."""
535
536        def _deepcopy_inputs(inputs):
537            # copy.deepcopy(deepcopy) can fail if tensor inputs have attribute (i.e. __dict__).
538            # we remove __dict__ when deepcopying.
539            dict_mapping = dict()
540            inputs_clone = ()
541            for idx, i in enumerate(inputs):
542                if isinstance(i, torch.Tensor) and hasattr(inputs[0], "__dict__"):
543                    dict_mapping[idx] = i.__dict__
544                    i.__dict__ = {}
545                inputs_clone += (copy.deepcopy(i),)
546
547            # Add __dict__ back.
548            for k, v in dict_mapping.items():
549                inputs[k].__dict__ = v
550                inputs_clone[k].__dict__ = v
551            return inputs_clone
552
553        def _check_graph(pre_dispatch):
554            if pre_dispatch:
555                ep = torch.export._trace._export(
556                    fn,
557                    _deepcopy_inputs(inputs),
558                    {},
559                    dynamic_shapes=dynamic_shapes,
560                    pre_dispatch=True,
561                    strict=strict,
562                )
563            else:
564                ep = torch.export.export(
565                    fn,
566                    _deepcopy_inputs(inputs),
567                    {},
568                    dynamic_shapes=dynamic_shapes,
569                    strict=strict,
570                )
571            ep.graph.eliminate_dead_code()
572
573            serialized_artifact = serialize(ep, opset_version={"aten": 0})
574            deserialized_ep = deserialize(
575                serialized_artifact, expected_opset_version={"aten": 0}
576            )
577            deserialized_ep.graph.eliminate_dead_code()
578
579            orig_outputs = ep.module()(*_deepcopy_inputs(inputs))
580            loaded_outputs = deserialized_ep.module()(*_deepcopy_inputs(inputs))
581
582            flat_orig_outputs = pytree.tree_leaves(orig_outputs)
583            flat_loaded_outputs = pytree.tree_leaves(loaded_outputs)
584
585            for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
586                self.assertEqual(type(orig), type(loaded))
587                if isinstance(orig, torch.Tensor):
588                    if orig.is_meta:
589                        self.assertEqual(orig, loaded)
590                    else:
591                        self.assertTrue(torch.allclose(orig, loaded))
592                else:
593                    self.assertEqual(orig, loaded)
594            self._check_graph_nodes(
595                ep.graph_module, deserialized_ep.graph_module, _check_meta
596            )
597
598        if use_pre_dispatch:
599            _check_graph(pre_dispatch=True)
600            _check_graph(pre_dispatch=False)
601        else:
602            _check_graph(pre_dispatch=False)
603
604    def test_optional_tuple(self):
605        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
606            torch.library.define(
607                "mylib::foo",
608                "(Tensor a, Tensor b, Tensor? c) -> (Tensor, Tensor?)",
609                tags=torch.Tag.pt2_compliant_tag,
610                lib=lib,
611            )
612
613            @torch.library.impl("mylib::foo", "cpu", lib=lib)
614            @torch.library.impl_abstract("mylib::foo")
615            def foo_impl(a, b, c):
616                res2 = None
617                if c is not None:
618                    res2 = c + a + b
619                return a + b, res2
620
621            class M(torch.nn.Module):
622                def forward(self, a, b, c):
623                    return torch.ops.mylib.foo(a, b, c)
624
625            self.check_graph(M(), (torch.randn(3), torch.randn(3), torch.randn(3)))
626
627    def test_auto_functionalize(self):
628        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
629            torch.library.define(
630                "mylib::foo1",
631                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> Tensor",
632                tags=torch.Tag.pt2_compliant_tag,
633                lib=lib,
634            )
635            torch.library.define(
636                "mylib::foo2",
637                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
638                tags=torch.Tag.pt2_compliant_tag,
639                lib=lib,
640            )
641            torch.library.define(
642                "mylib::foo3",
643                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
644                tags=torch.Tag.pt2_compliant_tag,
645                lib=lib,
646            )
647
648            @torch.library.impl("mylib::foo1", "cpu", lib=lib)
649            @torch.library.impl_abstract("mylib::foo1")
650            def foo1_impl(x, y, z, w, n):
651                x.add_(y[0] + w)
652                z.add_(y[1] + n)
653                return n + n
654
655            @torch.library.impl("mylib::foo2", "cpu", lib=lib)
656            @torch.library.impl_abstract("mylib::foo2")
657            def foo2_impl(x, y, z, w, n):
658                x.add_(y[0] + w)
659                z.add_(y[1] + n)
660                return (n + n, n * n)
661
662            @torch.library.impl("mylib::foo3", "cpu", lib=lib)
663            @torch.library.impl_abstract("mylib::foo3")
664            def foo3_impl(x, y, z, w, n):
665                x.add_(y[0] + w)
666                z.add_(y[1] + n)
667                return
668
669            class M(torch.nn.Module):
670                def forward(self, x, y, z, n):
671                    n = torch.ops.mylib.foo1(x, y, z, 2, n)
672                    torch.ops.mylib.foo3(x, y, z, 2, n)
673                    return torch.ops.mylib.foo2(x, y, z, 2, n)
674
675            x = torch.randn(3)
676            y = (torch.randn(3), torch.randn(3))
677            z = torch.randn(3)
678            n = torch.randn(3)
679            orig_args = (x, y, z, n)
680
681            # TODO Auto_functionalize is not supported on pre_dispatch IR
682            self.check_graph(M(), orig_args, use_pre_dispatch=False)
683
684    def test_multi_return(self) -> None:
685        """
686        Test multiple return from a single node (ex. layer_norm has 2 outputs)
687        """
688
689        class MyModule(torch.nn.Module):
690            def __init__(self) -> None:
691                super().__init__()
692
693            def forward(self, x, w, b):
694                return torch.nn.functional.layer_norm(
695                    x,
696                    x.size()[1:],
697                    weight=w,
698                    bias=b,
699                    eps=1e-5,
700                )
701
702        inputs = (
703            torch.ones([512, 512], requires_grad=True),
704            torch.ones([512]),
705            torch.ones([512]),
706        )
707        self.check_graph(MyModule(), inputs)
708
709    def test_basic(self) -> None:
710        class MyModule(torch.nn.Module):
711            def __init__(self) -> None:
712                super().__init__()
713
714            def forward(self, x):
715                x = x + x
716                x = x * x
717                x = x / x
718                return x, x.clone()
719
720        inputs = (torch.ones([512], requires_grad=True),)
721        self.check_graph(MyModule(), inputs)
722
723    def test_dynamic(self) -> None:
724        class DynamicShapeSimpleModel(torch.nn.Module):
725            def __init__(self) -> None:
726                super().__init__()
727
728            def forward(self, a, b, c) -> torch.Tensor:
729                d = (torch.matmul(a, b) + c) / 2
730                d_s0 = d.shape[0]
731                d_s1 = d.shape[1]
732                d_s3 = d_s0 * d_s1
733                e = d.view(d_s3)
734                return torch.cat([e, e])
735
736        inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
737        dim0_ac = torch.export.Dim("dim0_ac")
738        dynamic_shapes = {"a": {0: dim0_ac}, "b": None, "c": {0: dim0_ac}}
739        self.check_graph(DynamicShapeSimpleModel(), inputs, dynamic_shapes)
740
741    def test_sym_bool(self):
742        class Module(torch.nn.Module):
743            def forward(self, x, y):
744                assert x.size(0) in y
745                return x + y
746
747        f = Module()
748        self.check_graph(f, (torch.ones(1), torch.ones(3)))
749
750    def test_shape(self):
751        class Foo(torch.nn.Module):
752            def forward(self, x):
753                z, y = x.size()
754                return z + y + x[0], z
755
756        inputs = (torch.ones(2, 3),)
757        dim0_x, dim1_x = torch.export.dims("dim0_x", "dim1_x")
758        dynamic_shapes = {"x": (dim0_x, dim1_x)}
759        self.check_graph(Foo(), inputs, dynamic_shapes)
760
761    def test_module(self):
762        class M(torch.nn.Module):
763            def __init__(self) -> None:
764                super().__init__()
765                self.linear1 = torch.nn.Linear(3, 3)
766                self.relu = torch.nn.ReLU()
767                self.linear2 = torch.nn.Linear(3, 5)
768
769            def forward(self, x):
770                x = self.linear1(x)
771                x = self.linear1(x)
772                x = torch.nn.functional.relu(x)
773                x = self.linear2(x)
774                return x
775
776        inputs = (torch.randn(3, 3),)
777        self.check_graph(M(), inputs)
778
779    def test_module_meta(self):
780        class M(torch.nn.Module):
781            def __init__(self) -> None:
782                super().__init__()
783                self.p = torch.nn.Parameter(torch.ones(3, 3))
784
785            def forward(self, x):
786                return self.p + x
787
788        with torch.device("meta"):
789            mod = M()
790
791        inputs = (torch.randn(3, 3, device="meta"),)
792        self.check_graph(mod, inputs)
793
794    def test_cond(self):
795        from functorch.experimental.control_flow import cond
796
797        inputs = torch.ones(4, 3), torch.zeros(4, 3)
798
799        class M(torch.nn.Module):
800            def forward(self, x, y):
801                def t(x, y):
802                    return x + y
803
804                def f(x, y):
805                    return x - y
806
807                return cond(x[0][0] > 4, t, f, [x, y])
808
809        self.check_graph(M(), inputs)
810
811    def test_map(self):
812        from functorch.experimental import control_flow
813
814        def f(x, y):
815            return x + y
816
817        class Module(torch.nn.Module):
818            def forward(self, xs, y):
819                return control_flow.map(f, xs, y)
820
821        g = Module()
822        inputs = (torch.ones(3, 2, 2), torch.ones(2))
823        self.check_graph(g, inputs, _check_meta=False)
824
825    def test_tensor_tensor_list(self):
826        with torch.library._scoped_library("_export", "FRAGMENT") as lib:
827            lib.define(
828                "_test_tensor_tensor_list_output(Tensor x, Tensor y) -> (Tensor, Tensor[])",
829                tags=torch.Tag.pt2_compliant_tag,
830            )
831
832            def _test_tensor_tensor_list_output(x, y):
833                return y, [x]
834
835            lib.impl(
836                "_test_tensor_tensor_list_output",
837                _test_tensor_tensor_list_output,
838                "CPU",
839            )
840            lib.impl(
841                "_test_tensor_tensor_list_output",
842                _test_tensor_tensor_list_output,
843                "Meta",
844            )
845
846            class M(torch.nn.Module):
847                def forward(self, x, y):
848                    a, b = torch.ops._export._test_tensor_tensor_list_output.default(
849                        x, y
850                    )
851                    return a + b[0]
852
853            self.check_graph(M(), (torch.rand(3, 2), torch.rand(3, 2)))
854
855    def test_list_of_optional_tensors(self) -> None:
856        class MyModule(torch.nn.Module):
857            def __init__(self) -> None:
858                super().__init__()
859
860            def forward(self, x, y, z):
861                indices = [None, None, torch.tensor([1, 3, 5, 7])]
862                indexed = torch.ops.aten.index.Tensor(x + y, indices)
863                return indexed + z
864
865        inputs = (torch.rand(8, 8, 8), torch.rand(8, 8, 8), torch.rand(8, 8, 4))
866        self.check_graph(MyModule(), inputs)
867
868    def test_sym_ite(self):
869        class Foo(torch.nn.Module):
870            def forward(self, x):
871                b = x.shape[0] == 5
872                ret = torch.sym_ite(b, x.shape[0], x.shape[1])
873                return ret
874
875        dynamic_shapes = {"x": {0: Dim("dim0"), 1: Dim("dim1")}}
876        self.check_graph(Foo(), (torch.ones(4, 5),), dynamic_shapes=dynamic_shapes)
877
    def test_multiple_getitem(self):
        # Duplicates a getitem node in an exported graph by hand, round-trips
        # the program through serialize/deserialize, and checks that the
        # results still match and that the deserialized code reads each topk
        # output through a single getitem.
        class M(torch.nn.Module):
            def forward(self, x):
                a, b = torch.topk(x, 2)
                a = a * 2
                return a, b

        ep = torch.export.export(M(), (torch.ones(3),))

        # insert another getitem node
        for node in ep.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.mul.Tensor:
                # First operand of the existing mul (the node feeding `a * 2`).
                getitem_0 = node.args[0]
                with ep.graph.inserting_before(getitem_0):
                    # Copy that node and feed the copy into a new mul-by-2.
                    getitem_copy = ep.graph.node_copy(getitem_0)
                    mul_node = ep.graph.call_function(
                        torch.ops.aten.mul.Tensor, (getitem_copy, 2)
                    )
                    # The new mul reuses a shallow copy of the copied node's meta.
                    mul_node.meta = copy.copy(getitem_copy.meta)
                    # The original mul now multiplies getitem_0 by the new mul.
                    node.args = (getitem_0, mul_node)

        deserialized_ep = deserialize(serialize(ep))

        # Both programs must agree on a fresh random input.
        inp = (torch.randn(3),)
        orig_res = ep.module()(*inp)
        res = deserialized_ep.module()(*inp)
        self.assertTrue(torch.allclose(orig_res[0], res[0]))
        self.assertTrue(torch.allclose(orig_res[1], res[1]))

        # The deserialized graph should have deduped getitem calls
        self.assertExpectedInline(
            deserialized_ep.graph_module.code.strip("\n"),
            """\
def forward(self, x):
    topk_default = torch.ops.aten.topk.default(x, 2);  x = None
    getitem = topk_default[0]
    getitem_1 = topk_default[1];  topk_default = None
    mul_tensor = torch.ops.aten.mul.Tensor(getitem, 2)
    mul = torch.ops.aten.mul.Tensor(getitem, mul_tensor);  getitem = mul_tensor = None
    return (mul, getitem_1)
    """,
        )
920
921    @parametrize(
922        "name,case",
923        get_filtered_export_db_tests(),
924        name_fn=lambda name, case: f"case_{name}",
925    )
926    def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
927        model = case.model
928        _check_meta = "map" not in name
929        self.check_graph(model, case.example_args, _check_meta=_check_meta)
930
931    def test_constraints(self):
932        class Module(torch.nn.Module):
933            def forward(self, x, y):
934                n = x.item()
935                torch._check_is_size(n)
936                return y.sum() + torch.ones(n, 5).sum()
937
938        f = Module()
939        self.check_graph(f, (torch.tensor(3), torch.randn(4, 5)))
940
941    def test_get_attr(self) -> None:
942        class Module(torch.nn.Module):
943            def forward(self, x):
944                return x + torch.tensor(3)
945
946        f = Module()
947        self.check_graph(f, (torch.tensor(3),))
948
949    def test_get_attr_list(self) -> None:
950        class Module(torch.nn.Module):
951            def forward(self, x):
952                return torch.cat([x, torch.tensor([1, 1])])
953
954        f = Module()
955        self.check_graph(f, (torch.tensor([1, 1]),))
956
957    @unittest.skipIf(not torch.cuda.is_available(), "Requires cuda")
958    def test_device(self) -> None:
959        class MyModule(torch.nn.Module):
960            def __init__(self) -> None:
961                super().__init__()
962                self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
963                self.relu = torch.nn.ReLU()
964
965            def forward(self, x):
966                conv = self.conv(x)
967                relu = self.relu(conv)
968                mul = relu * 0.5
969                return mul
970
971        inp = torch.randn((1, 3, 224, 224), dtype=torch.float).to("cuda")
972        model = MyModule().eval().cuda()
973        self.check_graph(model, (inp,))
974
975    def test_custom_obj_tuple_out(self):
976        class MyModule(torch.nn.Module):
977            def __init__(self) -> None:
978                super().__init__()
979                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
980
981            def forward(self, x):
982                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
983                y = a[0] + a[1]
984                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
985                return x + b
986
987        m = MyModule()
988        inputs = (torch.ones(2, 3),)
989        self.check_graph(m, inputs, strict=False)
990
991    def test_custom_obj(self):
992        class MyModule(torch.nn.Module):
993            def __init__(self) -> None:
994                super().__init__()
995                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
996
997            def forward(self, x):
998                a = torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
999                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, a)
1000                return x + b
1001
1002        m = MyModule()
1003        inputs = (torch.ones(2, 3),)
1004        self.check_graph(m, inputs, strict=False)
1005
1006    def test_custom_obj_list_out(self):
1007        class MyModule(torch.nn.Module):
1008            def __init__(self) -> None:
1009                super().__init__()
1010                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
1011
1012            def forward(self, x):
1013                a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x)
1014                y = a[0] + a[1] + a[2]
1015                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
1016                return x + b
1017
1018        m = MyModule()
1019        inputs = (torch.ones(2, 3),)
1020        self.check_graph(m, inputs, strict=False)
1021
1022    def test_export_no_inputs(self):
1023        class M(torch.nn.Module):
1024            def __init__(self) -> None:
1025                super().__init__()
1026                self.p = torch.ones(3, 3)
1027
1028            def forward(self):
1029                return self.p * self.p
1030
1031        ep = torch.export.export(M(), ())
1032        ep._example_inputs = None
1033        roundtrip_ep = deserialize(serialize(ep))
1034        self.assertTrue(torch.allclose(ep.module()(), roundtrip_ep.module()()))
1035
1036
# Generate the concrete per-case test methods for the @parametrize-decorated
# tests of TestDeserialize (named like test_exportdb_supported_case_<name>).
instantiate_parametrized_tests(TestDeserialize)
1038
1039
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSchemaVersioning(TestCase):
    # Covers the error raised when a serialized program carries a schema
    # version that the deserializer does not accept.

    def test_error(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                return x + x

        exported = export(Module(), (torch.randn(1, 3),))
        artifact = ExportedProgramSerializer().serialize(exported)

        # Overwrite the major schema version with an invalid value.
        artifact.exported_program.schema_version.major = -1

        mismatch_error = r"Serialized schema version .* does not match our current"
        with self.assertRaisesRegex(SerializeError, mismatch_error):
            ExportedProgramDeserializer().deserialize(
                artifact.exported_program,
                artifact.state_dict,
                artifact.constants,
                artifact.example_inputs,
            )
1061
1062
# test_exportdb_supported only forwards each case's positional example_args to
# check_graph; keyword-argument inputs are not set up yet, so the exportdb case
# "fn_with_kwargs" is marked as an expected failure.
unittest.expectedFailure(TestDeserialize.test_exportdb_supported_case_fn_with_kwargs)
1065
1066
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSaveLoad(TestCase):
    # Each test exports a small module, writes it with save() and reads it
    # back with load(), varying the kind of destination or the saved content.

    def test_save_buffer(self):
        # Destination: an in-memory BytesIO buffer.
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = x + 1
                y = x.t()
                y = y.relu()
                y = self.linear(y)
                return y

        inp = (torch.tensor([0.1, 0.1]),)
        ep = export(Module(), inp)

        stream = io.BytesIO()
        save(ep, stream)
        stream.seek(0)
        loaded_ep = load(stream)

        expected = ep.module()(*inp)
        actual = loaded_ep.module()(*inp)
        self.assertTrue(torch.allclose(expected, actual))

    def test_save_file(self):
        # Destination: an open NamedTemporaryFile handle.
        class Foo(torch.nn.Module):
            def forward(self, x):
                return x * x

        inp = (torch.randn(2, 2),)
        ep = export(Foo(), inp)

        with tempfile.NamedTemporaryFile() as tmp_file:
            save(ep, tmp_file)
            tmp_file.seek(0)
            loaded_ep = load(tmp_file)

        expected = ep.module()(*inp)
        actual = loaded_ep.module()(*inp)
        self.assertTrue(torch.allclose(expected, actual))

    def test_save_path(self):
        # Destination: a pathlib.Path pointing at a temporary file name.
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        inp = (torch.tensor([6]), torch.tensor([7]))
        ep = export(Foo(), inp)

        with TemporaryFileName() as fname:
            target = Path(fname)
            save(ep, target)
            loaded_ep = load(target)

        expected = ep.module()(*inp)
        actual = loaded_ep.module()(*inp)
        self.assertTrue(torch.allclose(expected, actual))

    def test_save_extra(self):
        # Saves an extra file next to the program and checks that load()
        # fills in its content for the requested key.
        class Foo(torch.nn.Module):
            def forward(self, x):
                return x * x + x

        inp = (torch.tensor([0.1, 0.1]),)
        ep = export(Foo(), inp)

        stream = io.BytesIO()
        save(ep, stream, extra_files={"extra.txt": "moo"})
        stream.seek(0)

        # Starts out empty; asserted below to hold the saved content.
        requested_files = {"extra.txt": ""}
        loaded_ep = load(stream, extra_files=requested_files)

        expected = ep.module()(*inp)
        actual = loaded_ep.module()(*inp)
        self.assertTrue(torch.allclose(expected, actual))
        self.assertEqual(requested_files["extra.txt"], "moo")

    def test_version_error(self):
        # Rewrites the archive's "version" entry and expects load() to reject it.
        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        ep = export(Foo(), (torch.randn(1, 3),))

        with tempfile.NamedTemporaryFile() as tmp_file:
            save(ep, tmp_file)
            tmp_file.seek(0)

            # Modify the version
            with zipfile.ZipFile(tmp_file, "a") as archive:
                archive.writestr("version", "-1.1")

            mismatch_error = r"Serialized version .* does not match our current"
            with self.assertRaisesRegex(RuntimeError, mismatch_error):
                tmp_file.seek(0)
                load(tmp_file)

    def test_save_constants(self):
        # The module uses a tensor attribute and tensor literals created in
        # forward(); the loaded program must produce the same output.
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor(3)

            def forward(self, x):
                list_tensor = [torch.tensor(3), torch.tensor(4)]
                return x + self.a + list_tensor[0] + list_tensor[1]

        ep = export(Foo(), (torch.tensor(1),))

        stream = io.BytesIO()
        save(ep, stream)
        stream.seek(0)
        loaded_ep = load(stream)

        inp = (torch.tensor(1),)
        expected = ep.module()(*inp)
        actual = loaded_ep.module()(*inp)
        self.assertTrue(torch.allclose(expected, actual))
1188
1189
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSerializeCustomClass(TestCase):
    # Tests for serializing programs that reference torchbind custom-class
    # objects, and for preserving user "custom" metadata on graphs and nodes.

    def setUp(self):
        super().setUp()
        init_torchbind_implementations()

    def test_custom_class(self):
        # A custom-class instance passed as a node argument must appear in the
        # serialized program and be restored as an equivalent ScriptObject.
        custom_obj = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        inputs = (torch.zeros(4, 4),)
        ep = export(f, inputs)

        # Replace one of the values with an instance of our custom class
        for node in ep.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                with ep.graph.inserting_before(node):
                    custom_node = ep.graph.call_function(
                        torch.ops._TorchScriptTesting.take_an_instance.default,
                        (custom_obj,),
                    )
                    custom_node.meta["val"] = torch.ones(4, 4)
                    custom_node.meta["torch_fn"] = (
                        "take_an_instance",
                        "take_an_instance",
                    )
                    arg0, _ = node.args
                    node.args = (arg0, custom_node)

        serialized_vals = serialize(ep)

        # Use unittest assertions rather than bare `assert`: they are not
        # stripped under `python -O` and they report the offending values.
        ep_str = serialized_vals.exported_program.decode("utf-8")
        self.assertIn("class_fqn", ep_str)
        self.assertIn(custom_obj._type().qualified_name(), ep_str)

        deserialized_ep = deserialize(serialized_vals)

        for node in deserialized_ep.graph.nodes:
            if (
                node.op == "call_function"
                and node.target
                == torch.ops._TorchScriptTesting.take_an_instance.default
            ):
                arg = node.args[0]
                self.assertIsInstance(arg, torch._C.ScriptObject)
                self.assertEqual(arg._type(), custom_obj._type())
                self.assertEqual(arg.__getstate__(), custom_obj.__getstate__())
                self.assertEqual(arg.top(), 7)

    def test_custom_class_containing_fake_tensor(self):
        # A custom object built under FakeTensorMode must still hand back a
        # FakeTensor from get() after a serialize/deserialize round trip.
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.custom_obj = torch.classes._TorchScriptTesting._ContainsTensor(
                    torch.rand(2, 3)
                )

            def forward(self, x):
                return x + self.custom_obj.get()

        with FakeTensorMode():
            f = Foo()

        inputs = (torch.zeros(2, 3),)
        with enable_torchbind_tracing():
            ep = export(f, inputs, strict=False)

        serialized_vals = serialize(ep)
        ep = deserialize(serialized_vals)
        self.assertIsInstance(ep.constants["custom_obj"].get(), FakeTensor)

    def test_custom_tag_metadata_serialization(self):
        # meta["custom"] set on the graph module and on the add node must
        # survive a serialize/deserialize round trip.
        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        inputs = (torch.zeros(4, 4),)
        ep = export(f, inputs)

        new_gm = copy.deepcopy(ep.graph_module)
        new_gm.meta["custom"] = {}
        new_gm.meta["custom"]["f"] = "bar"

        for node in new_gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                node.meta["custom"] = {}
                node.meta["custom"]["quantization_tag"] = "foo"

        new_ep = ep._update(new_gm, ep.graph_signature)
        serialized_vals = serialize(new_ep)
        new_ep = deserialize(serialized_vals)

        self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar")
        counter = 0
        for node in new_ep.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                counter += 1
                self.assertEqual(node.meta["custom"]["quantization_tag"], "foo")
        self.assertEqual(counter, 1)

    def test_custom_tag_metadata_decomp(self):
        # After run_decompositions(), every call_function node is expected to
        # carry the tag that was placed on the single linear node.
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                return self.linear(x)

        f = Foo()

        inputs = (torch.ones(2, 2),)
        ep = export(f, inputs)

        new_gm = copy.deepcopy(ep.graph_module)
        new_gm.meta["custom"] = {}
        new_gm.meta["custom"]["f"] = "bar"

        counter = 0
        for node in new_gm.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.linear.default
            ):
                counter += 1
                node.meta["custom"] = {}
                node.meta["custom"]["quantization_tag"] = "foo"
        self.assertEqual(counter, 1)

        new_ep = ep._update(new_gm, ep.graph_signature)
        new_ep = new_ep.run_decompositions()

        self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar")
        counter = 0
        for node in new_ep.graph.nodes:
            if node.op == "call_function":
                counter += 1
                self.assertEqual(node.meta["custom"]["quantization_tag"], "foo")
        # More than one call_function node is expected after decomposition.
        self.assertGreater(counter, 1)

    # TODO For some reason, this doesn't work on Windows ONLY.
    # def test_custom_tag_metadata_reexport(self):
    #     class Foo(torch.nn.Module):
    #         def forward(self, x):
    #             return x + x
    #
    #     f = Foo()
    #
    #     inputs = (torch.zeros(4, 4),)
    #     ep = export(f, inputs)
    #
    #     new_gm = copy.deepcopy(ep.graph_module)
    #     new_gm.meta["custom"] = {}
    #     new_gm.meta["custom"]["f"] = "bar"
    #
    #     for node in new_gm.graph.nodes:
    #         if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
    #             node.meta["custom"] = {}
    #             node.meta["custom"]["quantization_tag"] = "foo"
    #
    #     new_ep = ep._update(new_gm, ep.graph_signature)
    #     new_ep = torch.export.export(new_ep.module(), inputs)
    #
    #     self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar")
    #     counter = 0
    #     for node in new_ep.graph.nodes:
    #         if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
    #             counter += 1
    #             self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo")
    #     self.assertEqual(counter, 1)

    def test_custom_tag_metadata_copy(self):
        # meta["custom"] on the graph module and on the add node must survive
        # a copy.deepcopy of the graph module.
        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        inputs = (torch.zeros(4, 4),)
        ep = export(f, inputs)

        new_gm = copy.deepcopy(ep.graph_module)
        new_gm.meta["custom"] = {}
        new_gm.meta["custom"]["f"] = "bar"

        for node in new_gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                node.meta["custom"] = {}
                node.meta["custom"]["quantization_tag"] = "foo"

        new_gm = copy.deepcopy(new_gm)

        self.assertEqual(new_gm.meta["custom"]["f"], "bar")
        counter = 0
        for node in new_gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                counter += 1
                self.assertEqual(node.meta["custom"]["quantization_tag"], "foo")
        self.assertEqual(counter, 1)
1396
1397
# Run this file's tests when it is executed directly as a script.
if __name__ == "__main__":
    run_tests()
1400