xref: /aosp_15_r20/external/pytorch/test/package/test_package_fx.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: package/deploy"]
2
3from io import BytesIO
4
5import torch
6from torch.fx import Graph, GraphModule, symbolic_trace
7from torch.package import (
8    ObjMismatchError,
9    PackageExporter,
10    PackageImporter,
11    sys_importer,
12)
13from torch.testing._internal.common_utils import run_tests
14
15
16try:
17    from .common import PackageTestCase
18except ImportError:
19    # Support the case where we run this file directly.
20    from common import PackageTestCase
21
# Register ``len`` with ``torch.fx.wrap`` so FX records calls to it made from
# this module as leaf calls. Registering twice is deliberate: it checks that a
# repeated registration has no ill effect. The loop runs at module scope, which
# ``torch.fx.wrap`` requires.
for _ in range(2):
    torch.fx.wrap("len")
25
26
class TestPackageFX(PackageTestCase):
    """Tests for compatibility between ``torch.package`` and ``torch.fx``.

    Covers:
    - packaging FX-traced modules;
    - tracing modules loaded from a package;
    - re-packaging such traces;
    - round-tripping graphs that use custom tracers, ``torch.fx.wrap``-ed
      functions, and functions from interned modules.
    """

    def test_package_fx_simple(self):
        """A symbolically traced local module survives a package round trip."""

        class SimpleTest(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x + 3.0)

        traced = symbolic_trace(SimpleTest())

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.save_pickle("model", "model.pkl", traced)

        f.seek(0)
        pi = PackageImporter(f)
        loaded_traced = pi.load_pickle("model", "model.pkl")
        inp = torch.rand(2, 3)
        self.assertEqual(loaded_traced(inp), traced(inp))

    def test_package_then_fx(self):
        """A module loaded from a package can be symbolically traced."""
        from package_a.test_module import SimpleTest

        model = SimpleTest()
        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", model)

        f.seek(0)
        pi = PackageImporter(f)
        loaded = pi.load_pickle("model", "model.pkl")
        traced = symbolic_trace(loaded)
        inp = torch.rand(2, 3)
        self.assertEqual(loaded(inp), traced(inp))

    def test_package_fx_package(self):
        """A trace of a packaged module can be re-packaged via the package's importer."""
        from package_a.test_module import SimpleTest

        model = SimpleTest()
        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", model)

        f.seek(0)
        pi = PackageImporter(f)
        loaded = pi.load_pickle("model", "model.pkl")
        traced = symbolic_trace(loaded)

        # Re-saving with only the default importer must fail, because the
        # trace references globals that exist only inside the package.
        with self.assertRaises(ObjMismatchError):
            with PackageExporter(BytesIO()) as pe:
                pe.intern("**")
                pe.save_pickle("model", "model.pkl", traced)

        # Write the real export to a fresh buffer. Rewinding the failed
        # attempt's buffer with seek(0) does not truncate it, so stale
        # trailing bytes from that attempt could survive past the new archive.
        f2 = BytesIO()
        with PackageExporter(f2, importer=(pi, sys_importer)) as pe:
            # Make the package available to the exporter's environment.
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", traced)

        f2.seek(0)
        pi2 = PackageImporter(f2)
        loaded2 = pi2.load_pickle("model", "model.pkl")

        inp = torch.rand(2, 3)
        self.assertEqual(loaded(inp), loaded2(inp))

    def test_package_fx_with_imports(self):
        """A hand-built graph calling a function from an interned module round-trips."""
        import package_a.subpackage

        # Manually construct a graph that invokes a leaf function.
        graph = Graph()
        a = graph.placeholder("x")
        b = graph.placeholder("y")
        c = graph.call_function(package_a.subpackage.leaf_function, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", gm)
        f.seek(0)

        pi = PackageImporter(f)
        loaded_gm = pi.load_pickle("model", "model.pkl")
        input_x = torch.rand(2, 3)
        input_y = torch.rand(2, 3)

        self.assertTrue(
            torch.allclose(loaded_gm(input_x, input_y), gm(input_x, input_y))
        )

        # The packaged copy of the leaf_function's module must be a different
        # module object from the one in the outer environment.
        packaged_dependency = pi.import_module("package_a.subpackage")
        self.assertIsNot(packaged_dependency, package_a.subpackage)

    def test_package_fx_custom_tracer(self):
        """A GraphModule subclass traced with a custom tracer round-trips."""
        from package_a.test_all_leaf_modules_tracer import TestAllLeafModulesTracer
        from package_a.test_module import ModWithTwoSubmodsAndTensor, SimpleTest

        class SpecialGraphModule(torch.fx.GraphModule):
            def __init__(self, root, graph, info):
                super().__init__(root, graph)
                self.info = info

        sub_module = SimpleTest()
        module = ModWithTwoSubmodsAndTensor(
            torch.ones(3),
            sub_module,
            sub_module,
        )
        tracer = TestAllLeafModulesTracer()
        graph = tracer.trace(module)

        self.assertEqual(graph._tracer_cls, TestAllLeafModulesTracer)

        gm = SpecialGraphModule(module, graph, "secret")
        self.assertEqual(gm._tracer_cls, TestAllLeafModulesTracer)

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", gm)
        f.seek(0)

        pi = PackageImporter(f)
        loaded_gm = pi.load_pickle("model", "model.pkl")
        # Compare the classes' own __name__. ``type(x).__class__.__name__``
        # is the metaclass name ("type") for any class, so comparing that
        # could never fail.
        self.assertEqual(type(loaded_gm).__name__, type(gm).__name__)
        # The extra attribute set by the subclass must survive the round trip.
        self.assertEqual(loaded_gm.info, "secret")

        input_x = torch.randn(3)
        self.assertEqual(loaded_gm(input_x), gm(input_x))

    def test_package_fx_wrap(self):
        """A trace whose forward calls the ``torch.fx.wrap``-ed ``len`` round-trips."""

        class TestModule(torch.nn.Module):
            def forward(self, a):
                return len(a)

        traced = symbolic_trace(TestModule())

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.save_pickle("model", "model.pkl", traced)
        f.seek(0)

        pi = PackageImporter(f)
        loaded_traced = pi.load_pickle("model", "model.pkl")
        inp = torch.rand(2, 3)
        self.assertEqual(loaded_traced(inp), traced(inp))
189
190
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
193