xref: /aosp_15_r20/external/pytorch/test/jit/test_backends.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport io
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport sys
6*da0073e9SAndroid Build Coastguard Workerimport unittest
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport torch
9*da0073e9SAndroid Build Coastguard Workerimport torch._C
10*da0073e9SAndroid Build Coastguard Workerfrom torch.jit.mobile import _load_for_lite_interpreter
11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
13*da0073e9SAndroid Build Coastguard Worker    find_library_location,
14*da0073e9SAndroid Build Coastguard Worker    IS_FBCODE,
15*da0073e9SAndroid Build Coastguard Worker    IS_MACOS,
16*da0073e9SAndroid Build Coastguard Worker    IS_SANDCASTLE,
17*da0073e9SAndroid Build Coastguard Worker    IS_WINDOWS,
18*da0073e9SAndroid Build Coastguard Worker    skipIfRocm,
19*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ROCM,
20*da0073e9SAndroid Build Coastguard Worker)
21*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
# Make the helper files in test/ importable. This file lives in test/jit/, so
# the helpers' directory is the parent of the directory that holds this file.
_jit_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(_jit_test_dir)
sys.path.append(pytorch_test_dir)
del _jit_test_dir

# Running this file directly is unsupported; tests are collected via test_jit.py.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker
def to_test_backend(module, method_compile_spec):
    """
    Lower only the ``forward`` method of ``module`` to the "test_backend" JIT backend.

    ``method_compile_spec`` is used as the compile spec for ``forward``; the
    value returned by ``torch._C._jit_to_backend`` is passed back unchanged.
    ``module`` is expected to be a ScriptModule.
    """
    spec_by_method = {"forward": method_compile_spec}
    return torch._C._jit_to_backend("test_backend", module, spec_by_method)
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker
def to_test_backend_multi(module, method_compile_spec):
    """
    Lower several methods of ``module`` to the "test_backend" JIT backend.

    ``method_compile_spec`` is forwarded as-is, so it must already map each
    method name to be lowered onto that method's own compile spec.
    """
    backend_name = "test_backend"
    return torch._C._jit_to_backend(backend_name, module, method_compile_spec)
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker
def to_test_backend_selective(module, method_compile_spec, submodules):
    """
    Lower only the listed submodules of ``module`` to the "test_backend" backend.

    ``submodules`` holds dotted submodule paths (e.g. "sub1.submodule").
    ``torch._C._jit_to_backend_selective`` is given a callback that lowers a
    single submodule via ``to_test_backend`` with ``method_compile_spec``.
    """

    def _lower_one(submodule):
        return to_test_backend(submodule, method_compile_spec)

    return torch._C._jit_to_backend_selective(module, _lower_one, submodules)
51*da0073e9SAndroid Build Coastguard Worker
52*da0073e9SAndroid Build Coastguard Worker
class BasicModule(torch.nn.Module):
    """
    Minimal module for exercising the to_backend lowering machinery.

    ``forward`` fans out to two helper methods and returns both results as a
    tuple ``(x + h, x - h)``.
    """

    def forward(self, x, h):
        summed = self.accum(x, h)
        diffed = self.sub_accum(x, h)
        return summed, diffed

    def accum(self, x, h):
        """Element-wise sum of the two inputs."""
        return x + h

    def sub_accum(self, x, h):
        """Element-wise difference ``x - h``."""
        return x - h
66*da0073e9SAndroid Build Coastguard Worker
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends.
69*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
70*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
71*da0073e9SAndroid Build Coastguard Worker    "Non-portable load_library call used in test",
72*da0073e9SAndroid Build Coastguard Worker)
73*da0073e9SAndroid Build Coastguard Workerclass JitBackendTestCase(JitTestCase):
74*da0073e9SAndroid Build Coastguard Worker    """
75*da0073e9SAndroid Build Coastguard Worker    A common base class for JIT backend tests that contains common utility
76*da0073e9SAndroid Build Coastguard Worker    functions for output comparison and serialization/deserialization.
77*da0073e9SAndroid Build Coastguard Worker    """
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
80*da0073e9SAndroid Build Coastguard Worker        super().setUp()
81*da0073e9SAndroid Build Coastguard Worker        lib_file_path = find_library_location("libjitbackend_test.so")
82*da0073e9SAndroid Build Coastguard Worker        torch.ops.load_library(str(lib_file_path))
83*da0073e9SAndroid Build Coastguard Worker        # Subclasses are expected to set up three variables in their setUp methods:
84*da0073e9SAndroid Build Coastguard Worker        # module - a regular, Python version of the module being tested
85*da0073e9SAndroid Build Coastguard Worker        # scripted_module - a scripted version of module
86*da0073e9SAndroid Build Coastguard Worker        # lowered_module - a version of module lowered to a backend
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Worker    def check_function(self, function_name, input):
89*da0073e9SAndroid Build Coastguard Worker        """
90*da0073e9SAndroid Build Coastguard Worker        Check that the function named 'function_name' produces the same output using
91*da0073e9SAndroid Build Coastguard Worker        Python, regular JIT and the backend for the given 'input'.
92*da0073e9SAndroid Build Coastguard Worker        """
93*da0073e9SAndroid Build Coastguard Worker        # Get handles for Python, JIT and backend methods.
94*da0073e9SAndroid Build Coastguard Worker        python_method = self.module.__getattribute__(function_name)
95*da0073e9SAndroid Build Coastguard Worker        jit_method = self.scripted_module.__getattr__(function_name)
96*da0073e9SAndroid Build Coastguard Worker        backend_method = self.lowered_module.__getattr__(function_name)
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker        # Run methods.
99*da0073e9SAndroid Build Coastguard Worker        python_output = python_method(*input)
100*da0073e9SAndroid Build Coastguard Worker        jit_output = jit_method(*input)
101*da0073e9SAndroid Build Coastguard Worker        backend_output = backend_method(*input)
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker        # The answers returned by Python, JIT and to_backend should all match.
104*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(python_output, backend_output)
105*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(jit_output, backend_output)
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker    def save_load(self):
108*da0073e9SAndroid Build Coastguard Worker        """
109*da0073e9SAndroid Build Coastguard Worker        Save and load the lowered module.
110*da0073e9SAndroid Build Coastguard Worker        """
111*da0073e9SAndroid Build Coastguard Worker        self.lowered_module = self.getExportImportCopy(self.lowered_module)
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker    def test_execution(self):
114*da0073e9SAndroid Build Coastguard Worker        """
115*da0073e9SAndroid Build Coastguard Worker        Stub for correctness tests.
116*da0073e9SAndroid Build Coastguard Worker        """
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker    def test_save_load(self):
119*da0073e9SAndroid Build Coastguard Worker        """
120*da0073e9SAndroid Build Coastguard Worker        Stub for serialization tests.
121*da0073e9SAndroid Build Coastguard Worker        """
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker    def test_errors(self):
124*da0073e9SAndroid Build Coastguard Worker        """
125*da0073e9SAndroid Build Coastguard Worker        Stub for testing error checking.
126*da0073e9SAndroid Build Coastguard Worker        """
127*da0073e9SAndroid Build Coastguard Worker
128*da0073e9SAndroid Build Coastguard Worker
class BasicModuleTest(JitBackendTestCase):
    """
    Tests for BasicModule.
    """

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of BasicModule.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        # Lower all three methods, not just forward, so test_execution can
        # compare each of them against Python and JIT.
        self.lowered_module = to_test_backend_multi(
            self.scripted_module,
            {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
        )

    def _lowered_method_compile_spec(self):
        """
        Return the ``__method_compile_spec`` attribute stored on the lowered
        module's ``__loweredModule__`` submodule.
        """
        # String-based getattr() is required for "__method_compile_spec":
        # writing `x.__method_compile_spec` inside a class body would be
        # private-name mangled by Python.
        lowered = getattr(self.lowered_module, "__loweredModule__")
        return getattr(lowered, "__method_compile_spec")

    def test_execution(self):
        # Test execution with backend against Python and JIT.
        x = torch.randn(5)

        # Test all three module methods.
        for method_name in ("accum", "sub_accum", "forward"):
            self.check_function(method_name, (x, x))

    @skipIfRocm
    def test_save_load(self):
        # Lowered module should produce the same outputs.
        self.test_execution()

        # Save the compile spec to compare against the version retrieved after loading.
        pre_compile_spec = self._lowered_method_compile_spec()

        # Save and load the lowered module.
        self.save_load()

        # Compile specs should survive the round trip unchanged.
        post_compile_spec = self._lowered_method_compile_spec()
        self.assertEqual(pre_compile_spec, post_compile_spec)

        # Loaded module should produce the same outputs.
        self.test_execution()
176*da0073e9SAndroid Build Coastguard Worker
177*da0073e9SAndroid Build Coastguard Worker
class BasicModuleUnavailableTest(JitBackendTestCase):
    """
    Tests for BasicModule with a backend that is not available.
    Fundamentally:
      * _jit_to_backend is successful.
      * Execution fails with an exception.
      * Saving is successful.
      * Loading fails with an exception.
    """

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of BasicModule.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        # Lowering itself is expected to succeed even though this backend is
        # unavailable; the failure only surfaces at execution/load time.
        self.lowered_module = torch._C._jit_to_backend(
            "test_backend_unavailable",
            self.scripted_module,
            {"forward": {"": ""}},
        )

    def test_execution(self):
        # Executing on the backend must fail because the backend is not available.
        x = torch.randn(5)

        # Test exception is thrown.
        with self.assertRaisesRegexWithHighlight(
            Exception,
            r"Backend is not available.",
            'raise Exception("Backend is not available."',
        ):
            backend_method = getattr(self.lowered_module, "forward")
            backend_method(x, x)

    @skipIfRocm
    def test_save_load(self):
        # Test that saving the lowered module is OK but loading fails because the backend is not available.
        buffer = io.BytesIO()
        torch.jit.save(self.lowered_module, buffer)
        buffer.seek(0)
        with self.assertRaisesRegexWithHighlight(
            Exception,
            r"Backend is not available.",
            'raise Exception("Backend is not available."',
        ):
            torch.jit.load(buffer)
223*da0073e9SAndroid Build Coastguard Worker            imported = torch.jit.load(buffer)
224*da0073e9SAndroid Build Coastguard Worker
225*da0073e9SAndroid Build Coastguard Worker
class NestedModuleTest(JitBackendTestCase):
    """
    Tests for NestedModule that check that a module lowered to a backend can be used
    as a submodule.
    """

    class NestedModule(torch.nn.Module):
        """
        A Module with one submodule that is used to test that lowered Modules
        can be used as submodules.
        """

        def __init__(self, submodule):
            super().__init__()
            self.submodule = submodule

        def forward(self, x, h):
            return self.submodule.forward(x, h)

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of NestedModule.
        # Both modules in self.module are regular Python modules.
        self.module = NestedModuleTest.NestedModule(BasicModule())
        # Both modules in self.scripted_module are ScriptModules.
        self.scripted_module = torch.jit.script(
            NestedModuleTest.NestedModule(BasicModule())
        )

        # Lower a separately scripted BasicModule (all three of its methods)
        # rather than anything inside self.scripted_module, so that
        # self.scripted_module remains an unlowered reference for comparison.
        lowered_module = to_test_backend_multi(
            torch.jit.script(BasicModule()),
            {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
        )
        # self.lowered_module is a ScriptModule, but its submodule is a lowered module.
        self.lowered_module = torch.jit.script(
            NestedModuleTest.NestedModule(lowered_module)
        )

    def test_execution(self):
        # Test execution with backend against Python and JIT.
        x = torch.randn(5)

        # Test forward.
        self.check_function("forward", (x, x))

    def test_save_load(self):
        # Lowered module should produce the same outputs.
        self.test_execution()

        # Save and load the lowered module.
        self.save_load()

        # Loaded module should produce the same outputs.
        self.test_execution()
282*da0073e9SAndroid Build Coastguard Worker
283*da0073e9SAndroid Build Coastguard Worker
284*da0073e9SAndroid Build Coastguard Workerclass SelectiveLoweringTest(JitBackendTestCase):
285*da0073e9SAndroid Build Coastguard Worker    """
286*da0073e9SAndroid Build Coastguard Worker    Tests for the selective lowering API.
287*da0073e9SAndroid Build Coastguard Worker    """
288*da0073e9SAndroid Build Coastguard Worker
289*da0073e9SAndroid Build Coastguard Worker    class OuterModule(torch.nn.Module):
290*da0073e9SAndroid Build Coastguard Worker        def __init__(self, sub1, sub2, other):
291*da0073e9SAndroid Build Coastguard Worker            super().__init__()
292*da0073e9SAndroid Build Coastguard Worker            self.sub1 = sub1
293*da0073e9SAndroid Build Coastguard Worker            self.sub2 = sub2
294*da0073e9SAndroid Build Coastguard Worker            self.other = other
295*da0073e9SAndroid Build Coastguard Worker
296*da0073e9SAndroid Build Coastguard Worker        def forward(self, x, y):
297*da0073e9SAndroid Build Coastguard Worker            # Call the module that will be lowered directly to test
298*da0073e9SAndroid Build Coastguard Worker            # type remapping in modules that are not its parent.
299*da0073e9SAndroid Build Coastguard Worker            a, b = self.sub1.submodule.forward(x, y)
300*da0073e9SAndroid Build Coastguard Worker            c, d = self.sub2.forward(x, y)
301*da0073e9SAndroid Build Coastguard Worker            e, f = self.other.forward(x, y)
302*da0073e9SAndroid Build Coastguard Worker            return a + c + e, b + d + f
303*da0073e9SAndroid Build Coastguard Worker
    class MiddleModule(torch.nn.Module):
        """
        Pass-through wrapper: ``forward`` delegates to ``submodule.forward``.

        Used as the parent of the BasicModule instances that selective lowering
        replaces, so the tests can check type remapping in the parent's graph.
        """

        def __init__(self, submodule):
            super().__init__()
            self.submodule = submodule

        def forward(self, x, y):
            return self.submodule.forward(x, y)
311*da0073e9SAndroid Build Coastguard Worker
312*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
313*da0073e9SAndroid Build Coastguard Worker        super().setUp()
314*da0073e9SAndroid Build Coastguard Worker        OuterModule = SelectiveLoweringTest.OuterModule
315*da0073e9SAndroid Build Coastguard Worker        MiddleModule = SelectiveLoweringTest.MiddleModule
316*da0073e9SAndroid Build Coastguard Worker
317*da0073e9SAndroid Build Coastguard Worker        def script_without_type_sharing(mod):
318*da0073e9SAndroid Build Coastguard Worker            return torch.jit._recursive.create_script_module(
319*da0073e9SAndroid Build Coastguard Worker                mod, torch.jit._recursive.infer_methods_to_compile, share_types=False
320*da0073e9SAndroid Build Coastguard Worker            )
321*da0073e9SAndroid Build Coastguard Worker
322*da0073e9SAndroid Build Coastguard Worker        # Create Python, JIT and backend versions of a hierarchy that looks like this:
323*da0073e9SAndroid Build Coastguard Worker        #                 --------- OuterModule --------
324*da0073e9SAndroid Build Coastguard Worker        #                 |              |              |
325*da0073e9SAndroid Build Coastguard Worker        #           MiddleModule    MiddleModule   MiddleModule
326*da0073e9SAndroid Build Coastguard Worker        #                |               |              |
327*da0073e9SAndroid Build Coastguard Worker        #           BasicModule     BasicModule    BasicModule
328*da0073e9SAndroid Build Coastguard Worker        #
329*da0073e9SAndroid Build Coastguard Worker        # Two BasicModules will be lowered and the third will not.
330*da0073e9SAndroid Build Coastguard Worker        self.module = OuterModule(
331*da0073e9SAndroid Build Coastguard Worker            MiddleModule(BasicModule()),
332*da0073e9SAndroid Build Coastguard Worker            MiddleModule(BasicModule()),
333*da0073e9SAndroid Build Coastguard Worker            MiddleModule(BasicModule()),
334*da0073e9SAndroid Build Coastguard Worker        )
335*da0073e9SAndroid Build Coastguard Worker        self.scripted_module = script_without_type_sharing(
336*da0073e9SAndroid Build Coastguard Worker            OuterModule(
337*da0073e9SAndroid Build Coastguard Worker                MiddleModule(BasicModule()),
338*da0073e9SAndroid Build Coastguard Worker                MiddleModule(BasicModule()),
339*da0073e9SAndroid Build Coastguard Worker                MiddleModule(BasicModule()),
340*da0073e9SAndroid Build Coastguard Worker            )
341*da0073e9SAndroid Build Coastguard Worker        )
342*da0073e9SAndroid Build Coastguard Worker        self.lowered_module = script_without_type_sharing(
343*da0073e9SAndroid Build Coastguard Worker            OuterModule(
344*da0073e9SAndroid Build Coastguard Worker                MiddleModule(BasicModule()),
345*da0073e9SAndroid Build Coastguard Worker                MiddleModule(BasicModule()),
346*da0073e9SAndroid Build Coastguard Worker                MiddleModule(BasicModule()),
347*da0073e9SAndroid Build Coastguard Worker            )
348*da0073e9SAndroid Build Coastguard Worker        )
349*da0073e9SAndroid Build Coastguard Worker        self.lowered_module = to_test_backend_selective(
350*da0073e9SAndroid Build Coastguard Worker            self.lowered_module, {"forward": ""}, ["sub1.submodule", "sub2.submodule"]
351*da0073e9SAndroid Build Coastguard Worker        )
352*da0073e9SAndroid Build Coastguard Worker
353*da0073e9SAndroid Build Coastguard Worker    def test_execution(self):
354*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(5)
355*da0073e9SAndroid Build Coastguard Worker        self.check_function("forward", (input, input))
356*da0073e9SAndroid Build Coastguard Worker
357*da0073e9SAndroid Build Coastguard Worker        self.test_selective_lowering_type_remap()
358*da0073e9SAndroid Build Coastguard Worker
359*da0073e9SAndroid Build Coastguard Worker    def test_save_load(self):
360*da0073e9SAndroid Build Coastguard Worker        self.test_execution()
361*da0073e9SAndroid Build Coastguard Worker        self.save_load()
362*da0073e9SAndroid Build Coastguard Worker        self.test_execution()
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker        self.test_selective_lowering_type_remap()
365*da0073e9SAndroid Build Coastguard Worker
366*da0073e9SAndroid Build Coastguard Worker    def test_selective_lowering_type_remap(self):
367*da0073e9SAndroid Build Coastguard Worker        """
368*da0073e9SAndroid Build Coastguard Worker        Check that type remapping and replacement occurred during selective lowering.
369*da0073e9SAndroid Build Coastguard Worker        """
370*da0073e9SAndroid Build Coastguard Worker        # Check that self.lowered_module was not lowered, but that it does contain test_backendLoweredModule due to it
371*da0073e9SAndroid Build Coastguard Worker        # calling the lowered module directly.
372*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("OuterModule").check("BasicModule").run(
373*da0073e9SAndroid Build Coastguard Worker            self.scripted_module.graph
374*da0073e9SAndroid Build Coastguard Worker        )
375*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("OuterModule").check_not(
376*da0073e9SAndroid Build Coastguard Worker            "__torch__.torch.classes.__backends__.test_backend"
377*da0073e9SAndroid Build Coastguard Worker        ).check("LoweredWrapper.test_backend").run(self.lowered_module.graph)
378*da0073e9SAndroid Build Coastguard Worker
379*da0073e9SAndroid Build Coastguard Worker        # Check that self.lowered_module.sub1/sub2 were not lowered but that BasicModule has been replaced in their graphs.
380*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("MiddleModule").check("BasicModule").check_not(
381*da0073e9SAndroid Build Coastguard Worker            "LoweredWrapper.test_backend"
382*da0073e9SAndroid Build Coastguard Worker        ).run(self.scripted_module.sub1.graph)
383*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("MiddleModule").check_not(
384*da0073e9SAndroid Build Coastguard Worker            "__torch__.torch.classes.__backends__.test_backend"
385*da0073e9SAndroid Build Coastguard Worker        ).check("LoweredWrapper.test_backend").run(self.lowered_module.sub1.graph)
386*da0073e9SAndroid Build Coastguard Worker
387*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("MiddleModule").check("BasicModule").check_not(
388*da0073e9SAndroid Build Coastguard Worker            "LoweredWrapper.test_backend"
389*da0073e9SAndroid Build Coastguard Worker        ).run(self.scripted_module.sub2.graph)
390*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("MiddleModule").check_not(
391*da0073e9SAndroid Build Coastguard Worker            "__torch__.torch.classes.__backends__.test_backend"
392*da0073e9SAndroid Build Coastguard Worker        ).check("LoweredWrapper.test_backend").run(self.lowered_module.sub2.graph)
393*da0073e9SAndroid Build Coastguard Worker
394*da0073e9SAndroid Build Coastguard Worker        # Check that self.lowered_module.sub1/sub2.submodule were lowered. They should have a new attribute
395*da0073e9SAndroid Build Coastguard Worker        # __loweredModule__ whose graph should mention __torch__.torch.classes.__backends__.test_backend,
396*da0073e9SAndroid Build Coastguard Worker        # the TorchBind class for executing functions on the test JIT backend.
397*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("LoweredModule.test_backend").check(
398*da0073e9SAndroid Build Coastguard Worker            "__torch__.torch.classes.__backends__.test_backend"
399*da0073e9SAndroid Build Coastguard Worker        ).run(self.lowered_module.sub1.submodule.__loweredModule__.graph)
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("LoweredModule.test_backend").check(
402*da0073e9SAndroid Build Coastguard Worker            "__torch__.torch.classes.__backends__.test_backend"
403*da0073e9SAndroid Build Coastguard Worker        ).run(self.lowered_module.sub2.submodule.__loweredModule__.graph)
404*da0073e9SAndroid Build Coastguard Worker
405*da0073e9SAndroid Build Coastguard Worker        # Check that self.other and self.other.submodule have been left untouched by the selective lowering process.
406*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("MiddleModule").check("BasicModule").check_not(
407*da0073e9SAndroid Build Coastguard Worker            "__torch__.torch.classes.__backends__.test_backend"
408*da0073e9SAndroid Build Coastguard Worker        ).check_not("LoweredWrapper.test_backend").run(self.scripted_module.other.graph)
409*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("BasicModule").check_not(
410*da0073e9SAndroid Build Coastguard Worker            "__torch__.torch.classes.__backends__.test_backend"
411*da0073e9SAndroid Build Coastguard Worker        ).check_not("LoweredModule.test_backend").run(
412*da0073e9SAndroid Build Coastguard Worker            self.scripted_module.other.submodule.graph
413*da0073e9SAndroid Build Coastguard Worker        )
414*da0073e9SAndroid Build Coastguard Worker
415*da0073e9SAndroid Build Coastguard Worker    def test_errors(self):
416*da0073e9SAndroid Build Coastguard Worker        """
417*da0073e9SAndroid Build Coastguard Worker        Check errors associated with selective lowering.
418*da0073e9SAndroid Build Coastguard Worker        """
419*da0073e9SAndroid Build Coastguard Worker        # Check error messages thrown when attempting to lower something that is not a ScriptModule.
420*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
421*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Object .* is not a ScriptModule", ""
422*da0073e9SAndroid Build Coastguard Worker        ):
423*da0073e9SAndroid Build Coastguard Worker            to_test_backend_selective(torch.nn.ReLU(), {"forward": ""}, ["submodule"])
424*da0073e9SAndroid Build Coastguard Worker
425*da0073e9SAndroid Build Coastguard Worker        MiddleModule = SelectiveLoweringTest.MiddleModule
426*da0073e9SAndroid Build Coastguard Worker        mod = MiddleModule(BasicModule())
427*da0073e9SAndroid Build Coastguard Worker        mod.new_attr = 3
428*da0073e9SAndroid Build Coastguard Worker
429*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
430*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Attribute named new_attr is not a Module", ""
431*da0073e9SAndroid Build Coastguard Worker        ):
432*da0073e9SAndroid Build Coastguard Worker            to_test_backend_selective(
433*da0073e9SAndroid Build Coastguard Worker                torch.jit.script(mod), {"forward": ""}, ["new_attr"]
434*da0073e9SAndroid Build Coastguard Worker            )
435*da0073e9SAndroid Build Coastguard Worker
436*da0073e9SAndroid Build Coastguard Worker        # Check error message thrown when module hierarchy doesn't have unique types.
437*da0073e9SAndroid Build Coastguard Worker        OuterModule = SelectiveLoweringTest.OuterModule
438*da0073e9SAndroid Build Coastguard Worker        mod = OuterModule(
439*da0073e9SAndroid Build Coastguard Worker            MiddleModule(BasicModule()),
440*da0073e9SAndroid Build Coastguard Worker            MiddleModule(BasicModule()),
441*da0073e9SAndroid Build Coastguard Worker            MiddleModule(BasicModule()),
442*da0073e9SAndroid Build Coastguard Worker        )
443*da0073e9SAndroid Build Coastguard Worker
444*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
445*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
446*da0073e9SAndroid Build Coastguard Worker            r"Selective lowering is only supported for module hierarchies with unique types",
447*da0073e9SAndroid Build Coastguard Worker            "",
448*da0073e9SAndroid Build Coastguard Worker        ):
449*da0073e9SAndroid Build Coastguard Worker            to_test_backend_selective(
450*da0073e9SAndroid Build Coastguard Worker                torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"]
451*da0073e9SAndroid Build Coastguard Worker            )
452*da0073e9SAndroid Build Coastguard Worker
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
455*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
456*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
457*da0073e9SAndroid Build Coastguard Worker    "Non-portable load_library call used in test",
458*da0073e9SAndroid Build Coastguard Worker)
459*da0073e9SAndroid Build Coastguard Workerclass TestBackends(JitTestCase):
460*da0073e9SAndroid Build Coastguard Worker    """
461*da0073e9SAndroid Build Coastguard Worker    This class wraps and invokes all subclasses of JitBackendTestCase so that each one
462*da0073e9SAndroid Build Coastguard Worker    does not have to be individually imported in test_jit.py.
463*da0073e9SAndroid Build Coastguard Worker    """
464*da0073e9SAndroid Build Coastguard Worker
465*da0073e9SAndroid Build Coastguard Worker    def __init__(self, name):
466*da0073e9SAndroid Build Coastguard Worker        super().__init__(name)
467*da0073e9SAndroid Build Coastguard Worker        self.basic_module_test = BasicModuleTest(name)
468*da0073e9SAndroid Build Coastguard Worker        self.basic_module_unavailable_test = BasicModuleUnavailableTest(name)
469*da0073e9SAndroid Build Coastguard Worker        self.nested_module_test = NestedModuleTest(name)
470*da0073e9SAndroid Build Coastguard Worker        self.selective_lowering_test = SelectiveLoweringTest(name)
471*da0073e9SAndroid Build Coastguard Worker
472*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
473*da0073e9SAndroid Build Coastguard Worker        super().setUp()
474*da0073e9SAndroid Build Coastguard Worker        if not TEST_WITH_ROCM:
475*da0073e9SAndroid Build Coastguard Worker            self.basic_module_test.setUp()
476*da0073e9SAndroid Build Coastguard Worker            self.basic_module_unavailable_test.setUp()
477*da0073e9SAndroid Build Coastguard Worker            self.nested_module_test.setUp()
478*da0073e9SAndroid Build Coastguard Worker            self.selective_lowering_test.setUp()
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
481*da0073e9SAndroid Build Coastguard Worker    def test_execution(self):
482*da0073e9SAndroid Build Coastguard Worker        self.basic_module_test.test_execution()
483*da0073e9SAndroid Build Coastguard Worker        self.basic_module_unavailable_test.test_execution()
484*da0073e9SAndroid Build Coastguard Worker        self.nested_module_test.test_execution()
485*da0073e9SAndroid Build Coastguard Worker        self.selective_lowering_test.test_execution()
486*da0073e9SAndroid Build Coastguard Worker
487*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
488*da0073e9SAndroid Build Coastguard Worker    def test_save_load(self):
489*da0073e9SAndroid Build Coastguard Worker        self.basic_module_test.test_save_load()
490*da0073e9SAndroid Build Coastguard Worker        self.basic_module_unavailable_test.test_save_load()
491*da0073e9SAndroid Build Coastguard Worker        self.nested_module_test.test_save_load()
492*da0073e9SAndroid Build Coastguard Worker        self.selective_lowering_test.test_save_load()
493*da0073e9SAndroid Build Coastguard Worker
494*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
495*da0073e9SAndroid Build Coastguard Worker    def test_errors(self):
496*da0073e9SAndroid Build Coastguard Worker        self.selective_lowering_test.test_errors()
497*da0073e9SAndroid Build Coastguard Worker
498*da0073e9SAndroid Build Coastguard Worker
499*da0073e9SAndroid Build Coastguard Worker"""
500*da0073e9SAndroid Build Coastguard WorkerUnit Tests for backend with compiler
501*da0073e9SAndroid Build Coastguard WorkerThis test case and the existing TestBackends are separate because they cover different aspects.
502*da0073e9SAndroid Build Coastguard WorkerThe actual backend implementation in this test is different.
503*da0073e9SAndroid Build Coastguard WorkerIt has a simple demo compiler to test the end-to-end flow in mobile.
504*da0073e9SAndroid Build Coastguard WorkerHowever, this test cannot cover the selective_lowering for now, which is covered in TestBackends.
505*da0073e9SAndroid Build Coastguard Worker"""
506*da0073e9SAndroid Build Coastguard Worker
507*da0073e9SAndroid Build Coastguard Worker
508*da0073e9SAndroid Build Coastguard Workerclass BasicModuleAdd(torch.nn.Module):
509*da0073e9SAndroid Build Coastguard Worker    """
510*da0073e9SAndroid Build Coastguard Worker    A simple add Module used to test to_backend lowering machinery.
511*da0073e9SAndroid Build Coastguard Worker    """
512*da0073e9SAndroid Build Coastguard Worker
513*da0073e9SAndroid Build Coastguard Worker    def forward(self, x, h):
514*da0073e9SAndroid Build Coastguard Worker        return x + h
515*da0073e9SAndroid Build Coastguard Worker
516*da0073e9SAndroid Build Coastguard Worker
517*da0073e9SAndroid Build Coastguard Worker# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends.
518*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
519*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
520*da0073e9SAndroid Build Coastguard Worker    "Non-portable load_library call used in test",
521*da0073e9SAndroid Build Coastguard Worker)
522*da0073e9SAndroid Build Coastguard Workerclass JitBackendTestCaseWithCompiler(JitTestCase):
523*da0073e9SAndroid Build Coastguard Worker    """
524*da0073e9SAndroid Build Coastguard Worker    A common base class for JIT backend tests with compilers that contains common utility
525*da0073e9SAndroid Build Coastguard Worker    functions for output comparison.
526*da0073e9SAndroid Build Coastguard Worker    """
527*da0073e9SAndroid Build Coastguard Worker
528*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
529*da0073e9SAndroid Build Coastguard Worker        super().setUp()
530*da0073e9SAndroid Build Coastguard Worker        lib_file_path = find_library_location("libbackend_with_compiler.so")
531*da0073e9SAndroid Build Coastguard Worker        torch.ops.load_library(str(lib_file_path))
532*da0073e9SAndroid Build Coastguard Worker        # Subclasses are expected to set up four variables in their setUp methods:
533*da0073e9SAndroid Build Coastguard Worker        # module - a regular, Python version of the module being tested
534*da0073e9SAndroid Build Coastguard Worker        # scripted_module - a scripted version of module
535*da0073e9SAndroid Build Coastguard Worker        # lowered_module - a version of module lowered to a backend
536*da0073e9SAndroid Build Coastguard Worker        # mobile_module - a module with a format that Pytorch Mobile can execute
537*da0073e9SAndroid Build Coastguard Worker
538*da0073e9SAndroid Build Coastguard Worker    def check_forward(self, input):
539*da0073e9SAndroid Build Coastguard Worker        """
540*da0073e9SAndroid Build Coastguard Worker        Check that the forward function produces the same output using
541*da0073e9SAndroid Build Coastguard Worker        Python, regular JIT, the backend, and mobile for the given 'input'.
542*da0073e9SAndroid Build Coastguard Worker        """
543*da0073e9SAndroid Build Coastguard Worker
544*da0073e9SAndroid Build Coastguard Worker        # Get outputs from forward.
545*da0073e9SAndroid Build Coastguard Worker        python_output = self.module.forward(*input)
546*da0073e9SAndroid Build Coastguard Worker        jit_output = self.scripted_module.forward(*input)
547*da0073e9SAndroid Build Coastguard Worker        backend_output = self.lowered_module(*input)
548*da0073e9SAndroid Build Coastguard Worker        mobile_output = self.mobile_module(*input)
549*da0073e9SAndroid Build Coastguard Worker
550*da0073e9SAndroid Build Coastguard Worker        # The answers returned by Python, JIT, to_backend, and mobile should all match.
551*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(python_output, backend_output)
552*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(jit_output, backend_output)
553*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mobile_output, backend_output)
554*da0073e9SAndroid Build Coastguard Worker
555*da0073e9SAndroid Build Coastguard Worker    def test_execution(self):
556*da0073e9SAndroid Build Coastguard Worker        """
557*da0073e9SAndroid Build Coastguard Worker        Stub for correctness tests.
558*da0073e9SAndroid Build Coastguard Worker        """
559*da0073e9SAndroid Build Coastguard Worker
560*da0073e9SAndroid Build Coastguard Worker    def test_errors(self):
561*da0073e9SAndroid Build Coastguard Worker        """
562*da0073e9SAndroid Build Coastguard Worker        Stub for testing error checking.
563*da0073e9SAndroid Build Coastguard Worker        """
564*da0073e9SAndroid Build Coastguard Worker
565*da0073e9SAndroid Build Coastguard Worker
566*da0073e9SAndroid Build Coastguard Workerclass BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
567*da0073e9SAndroid Build Coastguard Worker    """
568*da0073e9SAndroid Build Coastguard Worker    Tests for BasicModuleAdd.
569*da0073e9SAndroid Build Coastguard Worker    """
570*da0073e9SAndroid Build Coastguard Worker
571*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
572*da0073e9SAndroid Build Coastguard Worker        super().setUp()
573*da0073e9SAndroid Build Coastguard Worker        # Create Python, JIT and backend versions of BasicModuleAdd.
574*da0073e9SAndroid Build Coastguard Worker        self.module = BasicModuleAdd()
575*da0073e9SAndroid Build Coastguard Worker        self.scripted_module = torch.jit.script(BasicModuleAdd())
576*da0073e9SAndroid Build Coastguard Worker        compile_spec = {
577*da0073e9SAndroid Build Coastguard Worker            "forward": {
578*da0073e9SAndroid Build Coastguard Worker                "input_shapes": "((1, 1, 320, 240), (1, 3))",
579*da0073e9SAndroid Build Coastguard Worker                "some_other_option": "True",
580*da0073e9SAndroid Build Coastguard Worker            },
581*da0073e9SAndroid Build Coastguard Worker        }
582*da0073e9SAndroid Build Coastguard Worker        self.lowered_module = torch._C._jit_to_backend(
583*da0073e9SAndroid Build Coastguard Worker            "backend_with_compiler_demo", self.scripted_module, compile_spec
584*da0073e9SAndroid Build Coastguard Worker        )
585*da0073e9SAndroid Build Coastguard Worker        # Create mobile version of BasicModuleAdd
586*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO(self.lowered_module._save_to_buffer_for_lite_interpreter())
587*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
588*da0073e9SAndroid Build Coastguard Worker        self.mobile_module = _load_for_lite_interpreter(buffer)
589*da0073e9SAndroid Build Coastguard Worker
590*da0073e9SAndroid Build Coastguard Worker    def test_execution(self):
591*da0073e9SAndroid Build Coastguard Worker        # Test execution with backend against Python and JIT.
592*da0073e9SAndroid Build Coastguard Worker        input = torch.ones(1, dtype=torch.float)
593*da0073e9SAndroid Build Coastguard Worker        self.check_forward((input, input))
594*da0073e9SAndroid Build Coastguard Worker
595*da0073e9SAndroid Build Coastguard Worker
596*da0073e9SAndroid Build Coastguard Workerclass ErrorMessagesWithCompiler(JitBackendTestCase):
597*da0073e9SAndroid Build Coastguard Worker    """
598*da0073e9SAndroid Build Coastguard Worker    Tests for errors that occur with compiler, specifically:
599*da0073e9SAndroid Build Coastguard Worker        * an operator is not supported by the backend
600*da0073e9SAndroid Build Coastguard Worker    """
601*da0073e9SAndroid Build Coastguard Worker
602*da0073e9SAndroid Build Coastguard Worker    class ModuleNotSupported(torch.nn.Module):
603*da0073e9SAndroid Build Coastguard Worker        """
604*da0073e9SAndroid Build Coastguard Worker        A module with an operator that is not supported.
605*da0073e9SAndroid Build Coastguard Worker        """
606*da0073e9SAndroid Build Coastguard Worker
607*da0073e9SAndroid Build Coastguard Worker        def forward(self, x, h):
608*da0073e9SAndroid Build Coastguard Worker            return x * h
609*da0073e9SAndroid Build Coastguard Worker            self._loweredmodule.forward()
610*da0073e9SAndroid Build Coastguard Worker
611*da0073e9SAndroid Build Coastguard Worker    def test_errors(self):
612*da0073e9SAndroid Build Coastguard Worker        scripted_module_n = torch.jit.script(
613*da0073e9SAndroid Build Coastguard Worker            ErrorMessagesWithCompiler.ModuleNotSupported()
614*da0073e9SAndroid Build Coastguard Worker        )
615*da0073e9SAndroid Build Coastguard Worker        # Test exception is thrown when lowering a module with an unsupported operator
616*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
617*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
618*da0073e9SAndroid Build Coastguard Worker            # Special escape characters are replaced with '.'
619*da0073e9SAndroid Build Coastguard Worker            r"""The node of aten::mul is not supported in this compiler. .*
620*da0073e9SAndroid Build Coastguard Worker        def forward.self, x, h.:
621*da0073e9SAndroid Build Coastguard Worker            return x . h
622*da0073e9SAndroid Build Coastguard Worker                   ~~~~~ <--- HERE
623*da0073e9SAndroid Build Coastguard Worker            self._loweredmodule.forward..
624*da0073e9SAndroid Build Coastguard Worker""",
625*da0073e9SAndroid Build Coastguard Worker            "",
626*da0073e9SAndroid Build Coastguard Worker        ):
627*da0073e9SAndroid Build Coastguard Worker            lowered_module_n = torch._C._jit_to_backend(
628*da0073e9SAndroid Build Coastguard Worker                "backend_with_compiler_demo", scripted_module_n, {"forward": {"": ""}}
629*da0073e9SAndroid Build Coastguard Worker            )
630*da0073e9SAndroid Build Coastguard Worker
631*da0073e9SAndroid Build Coastguard Worker
632*da0073e9SAndroid Build Coastguard Workerclass CompModuleTestWithCompiler(JitBackendTestCase):
633*da0073e9SAndroid Build Coastguard Worker    """
634*da0073e9SAndroid Build Coastguard Worker    Tests for CompModule, which is a module with two lowered submodules
635*da0073e9SAndroid Build Coastguard Worker    """
636*da0073e9SAndroid Build Coastguard Worker
637*da0073e9SAndroid Build Coastguard Worker    class BasicModuleSub(torch.nn.Module):
638*da0073e9SAndroid Build Coastguard Worker        """
639*da0073e9SAndroid Build Coastguard Worker        A simple subtraction Module to be used in CompModule.
640*da0073e9SAndroid Build Coastguard Worker        """
641*da0073e9SAndroid Build Coastguard Worker
642*da0073e9SAndroid Build Coastguard Worker        def forward(self, x, h):
643*da0073e9SAndroid Build Coastguard Worker            return x - h
644*da0073e9SAndroid Build Coastguard Worker
645*da0073e9SAndroid Build Coastguard Worker    class CompModule(torch.nn.Module):
646*da0073e9SAndroid Build Coastguard Worker        """
647*da0073e9SAndroid Build Coastguard Worker        A module with two lowered submodules.
648*da0073e9SAndroid Build Coastguard Worker        """
649*da0073e9SAndroid Build Coastguard Worker
650*da0073e9SAndroid Build Coastguard Worker        def __init__(self, addmodule, submodule):
651*da0073e9SAndroid Build Coastguard Worker            super().__init__()
652*da0073e9SAndroid Build Coastguard Worker            self.lowered_add = addmodule
653*da0073e9SAndroid Build Coastguard Worker            self.lowered_sub = submodule
654*da0073e9SAndroid Build Coastguard Worker
655*da0073e9SAndroid Build Coastguard Worker        def forward(self, a, b, s):
656*da0073e9SAndroid Build Coastguard Worker            c = self.lowered_add.forward(a, b)
657*da0073e9SAndroid Build Coastguard Worker            d = self.lowered_sub.forward(a, b)
658*da0073e9SAndroid Build Coastguard Worker            y = s * (c * d)
659*da0073e9SAndroid Build Coastguard Worker            return y
660*da0073e9SAndroid Build Coastguard Worker
661*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
662*da0073e9SAndroid Build Coastguard Worker        super().setUp()
663*da0073e9SAndroid Build Coastguard Worker        # Create Python and JIT versions of CompModule with lowered submodules.
664*da0073e9SAndroid Build Coastguard Worker        compile_spec = {
665*da0073e9SAndroid Build Coastguard Worker            "forward": {
666*da0073e9SAndroid Build Coastguard Worker                "input_shapes": "((1, 1, 320, 240), (1, 3))",
667*da0073e9SAndroid Build Coastguard Worker                "some_other_option": "True",
668*da0073e9SAndroid Build Coastguard Worker            },
669*da0073e9SAndroid Build Coastguard Worker        }
670*da0073e9SAndroid Build Coastguard Worker        lowered_add = torch._C._jit_to_backend(
671*da0073e9SAndroid Build Coastguard Worker            "backend_with_compiler_demo",
672*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(BasicModuleAdd()),
673*da0073e9SAndroid Build Coastguard Worker            compile_spec,
674*da0073e9SAndroid Build Coastguard Worker        )
675*da0073e9SAndroid Build Coastguard Worker        lowered_sub = torch._C._jit_to_backend(
676*da0073e9SAndroid Build Coastguard Worker            "backend_with_compiler_demo",
677*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(CompModuleTestWithCompiler.BasicModuleSub()),
678*da0073e9SAndroid Build Coastguard Worker            {"forward": {"": ""}},
679*da0073e9SAndroid Build Coastguard Worker        )
680*da0073e9SAndroid Build Coastguard Worker        self.module = CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
681*da0073e9SAndroid Build Coastguard Worker        self.scripted_module = torch.jit.script(
682*da0073e9SAndroid Build Coastguard Worker            CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
683*da0073e9SAndroid Build Coastguard Worker        )
684*da0073e9SAndroid Build Coastguard Worker        # No backend version of CompModule currently, so this is filler.
685*da0073e9SAndroid Build Coastguard Worker        self.lowered_module = self.scripted_module
686*da0073e9SAndroid Build Coastguard Worker        # Create a mobile version of CompModule from JIT version
687*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO(self.scripted_module._save_to_buffer_for_lite_interpreter())
688*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
689*da0073e9SAndroid Build Coastguard Worker        self.mobile_module = _load_for_lite_interpreter(buffer)
690*da0073e9SAndroid Build Coastguard Worker
691*da0073e9SAndroid Build Coastguard Worker    def test_execution(self):
692*da0073e9SAndroid Build Coastguard Worker        # Test execution with backend against Python and JIT.
693*da0073e9SAndroid Build Coastguard Worker        input1 = torch.ones(1, dtype=torch.float)
694*da0073e9SAndroid Build Coastguard Worker        input2 = torch.ones(1, dtype=torch.float)
695*da0073e9SAndroid Build Coastguard Worker
696*da0073e9SAndroid Build Coastguard Worker        # Test forward.
697*da0073e9SAndroid Build Coastguard Worker        self.check_function("forward", (input1, input2, input2))
698*da0073e9SAndroid Build Coastguard Worker
699*da0073e9SAndroid Build Coastguard Worker
700*da0073e9SAndroid Build Coastguard Worker# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
701*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
702*da0073e9SAndroid Build Coastguard Worker    IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
703*da0073e9SAndroid Build Coastguard Worker    "Non-portable load_library call used in test",
704*da0073e9SAndroid Build Coastguard Worker)
705*da0073e9SAndroid Build Coastguard Workerclass TestBackendsWithCompiler(JitTestCase):
706*da0073e9SAndroid Build Coastguard Worker    """
707*da0073e9SAndroid Build Coastguard Worker    This class wraps and invokes all subclasses of JitBackendTestCaseWithCompiler
708*da0073e9SAndroid Build Coastguard Worker    so that each one does not have to be individually imported in test_jit.py.
709*da0073e9SAndroid Build Coastguard Worker    """
710*da0073e9SAndroid Build Coastguard Worker
711*da0073e9SAndroid Build Coastguard Worker    def __init__(self, name):
712*da0073e9SAndroid Build Coastguard Worker        super().__init__(name)
713*da0073e9SAndroid Build Coastguard Worker        self.basic_module_compiler_test = BasicModuleTestWithCompiler(name)
714*da0073e9SAndroid Build Coastguard Worker        self.error_module_compiler_test = ErrorMessagesWithCompiler(name)
715*da0073e9SAndroid Build Coastguard Worker        self.comp_module_compiler_test = CompModuleTestWithCompiler(name)
716*da0073e9SAndroid Build Coastguard Worker
717*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
718*da0073e9SAndroid Build Coastguard Worker        super().setUp()
719*da0073e9SAndroid Build Coastguard Worker        self.basic_module_compiler_test.setUp()
720*da0073e9SAndroid Build Coastguard Worker        self.error_module_compiler_test.setUp()
721*da0073e9SAndroid Build Coastguard Worker        self.comp_module_compiler_test.setUp()
722*da0073e9SAndroid Build Coastguard Worker
723*da0073e9SAndroid Build Coastguard Worker    def test_execution(self):
724*da0073e9SAndroid Build Coastguard Worker        self.basic_module_compiler_test.test_execution()
725*da0073e9SAndroid Build Coastguard Worker        self.comp_module_compiler_test.test_execution()
726*da0073e9SAndroid Build Coastguard Worker
727*da0073e9SAndroid Build Coastguard Worker    def test_errors(self):
728*da0073e9SAndroid Build Coastguard Worker        self.error_module_compiler_test.test_errors()
729*da0073e9SAndroid Build Coastguard Worker
730*da0073e9SAndroid Build Coastguard Worker
731*da0073e9SAndroid Build Coastguard Workerclass CompModuleTestSameNameWithCompiler(JitBackendTestCase):
732*da0073e9SAndroid Build Coastguard Worker    """
733*da0073e9SAndroid Build Coastguard Worker    Tests for CompModule, which is a module with two lowered submodules with same module name
734*da0073e9SAndroid Build Coastguard Worker    """
735*da0073e9SAndroid Build Coastguard Worker
736*da0073e9SAndroid Build Coastguard Worker    class ModuleAdd(torch.nn.Module):
737*da0073e9SAndroid Build Coastguard Worker        """
738*da0073e9SAndroid Build Coastguard Worker        A simple Module used to test to_backend lowering machinery.
739*da0073e9SAndroid Build Coastguard Worker        """
740*da0073e9SAndroid Build Coastguard Worker
741*da0073e9SAndroid Build Coastguard Worker        def forward(self, x, h):
742*da0073e9SAndroid Build Coastguard Worker            return x + h
743*da0073e9SAndroid Build Coastguard Worker
744*da0073e9SAndroid Build Coastguard Worker    class CompModule(torch.nn.Module):
745*da0073e9SAndroid Build Coastguard Worker        """
746*da0073e9SAndroid Build Coastguard Worker        A module with two lowered submodules.
747*da0073e9SAndroid Build Coastguard Worker        """
748*da0073e9SAndroid Build Coastguard Worker
749*da0073e9SAndroid Build Coastguard Worker        def __init__(self) -> None:
750*da0073e9SAndroid Build Coastguard Worker            super().__init__()
751*da0073e9SAndroid Build Coastguard Worker            compile_spec = {
752*da0073e9SAndroid Build Coastguard Worker                "forward": {
753*da0073e9SAndroid Build Coastguard Worker                    "some_other_option": "True",
754*da0073e9SAndroid Build Coastguard Worker                },
755*da0073e9SAndroid Build Coastguard Worker            }
756*da0073e9SAndroid Build Coastguard Worker            self.add = torch._C._jit_to_backend(
757*da0073e9SAndroid Build Coastguard Worker                "backend_with_compiler_demo",
758*da0073e9SAndroid Build Coastguard Worker                torch.jit.script(ModuleAdd()),  # noqa: F821
759*da0073e9SAndroid Build Coastguard Worker                compile_spec,
760*da0073e9SAndroid Build Coastguard Worker            )
761*da0073e9SAndroid Build Coastguard Worker            self.sub = torch._C._jit_to_backend(
762*da0073e9SAndroid Build Coastguard Worker                "backend_with_compiler_demo",
763*da0073e9SAndroid Build Coastguard Worker                torch.jit.script(ModuleAdd()),  # noqa: F821
764*da0073e9SAndroid Build Coastguard Worker                compile_spec,
765*da0073e9SAndroid Build Coastguard Worker            )
766*da0073e9SAndroid Build Coastguard Worker
767*da0073e9SAndroid Build Coastguard Worker        def forward(self, a, b, s: int):
768*da0073e9SAndroid Build Coastguard Worker            c = self.add.forward(a, b)
769*da0073e9SAndroid Build Coastguard Worker            d = self.sub.forward(a, b)
770*da0073e9SAndroid Build Coastguard Worker            y = s * (c * d)
771*da0073e9SAndroid Build Coastguard Worker            return y
772*da0073e9SAndroid Build Coastguard Worker
773*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
774*da0073e9SAndroid Build Coastguard Worker        super().setUp()
775*da0073e9SAndroid Build Coastguard Worker
776*da0073e9SAndroid Build Coastguard Worker        self.module = CompModule()  # noqa: F821
777*da0073e9SAndroid Build Coastguard Worker        self.scripted_module = torch.jit.script(self.module)
778*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO(self.scripted_module._save_to_buffer_for_lite_interpreter())
779*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
780*da0073e9SAndroid Build Coastguard Worker        self.mobile_module = _load_for_lite_interpreter(buffer)
781*da0073e9SAndroid Build Coastguard Worker
782*da0073e9SAndroid Build Coastguard Worker    def test_execution(self):
783*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1)
784*da0073e9SAndroid Build Coastguard Worker        b = 3 * torch.ones(1)
785*da0073e9SAndroid Build Coastguard Worker        s = 3
786*da0073e9SAndroid Build Coastguard Worker        # Test forward.
787*da0073e9SAndroid Build Coastguard Worker        self.check_function("forward", (a, b, s))
788*da0073e9SAndroid Build Coastguard Worker
789*da0073e9SAndroid Build Coastguard Worker
790*da0073e9SAndroid Build Coastguard Workerclass AddedAttributesTest(JitBackendTestCase):
791*da0073e9SAndroid Build Coastguard Worker    """
792*da0073e9SAndroid Build Coastguard Worker    Tests for adding attributes to a model after lowering.
793*da0073e9SAndroid Build Coastguard Worker    """
794*da0073e9SAndroid Build Coastguard Worker
795*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
796*da0073e9SAndroid Build Coastguard Worker        super().setUp()
797*da0073e9SAndroid Build Coastguard Worker        # Create Python, JIT and backend versions of BasicModule.
798*da0073e9SAndroid Build Coastguard Worker        self.module = BasicModule()
799*da0073e9SAndroid Build Coastguard Worker        self.scripted_module = torch.jit.script(BasicModule())
800*da0073e9SAndroid Build Coastguard Worker        self.lowered_module = to_test_backend_multi(
801*da0073e9SAndroid Build Coastguard Worker            self.scripted_module,
802*da0073e9SAndroid Build Coastguard Worker            {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
803*da0073e9SAndroid Build Coastguard Worker        )
804*da0073e9SAndroid Build Coastguard Worker
805*da0073e9SAndroid Build Coastguard Worker    def test_attribute(self):
806*da0073e9SAndroid Build Coastguard Worker        input = [(torch.ones(5),)]
807*da0073e9SAndroid Build Coastguard Worker        pre_bundled = self.lowered_module(*input[0])
808*da0073e9SAndroid Build Coastguard Worker        # Attach bundled inputs which adds several attributes and functions to the model
809*da0073e9SAndroid Build Coastguard Worker        self.lowered_module = (
810*da0073e9SAndroid Build Coastguard Worker            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
811*da0073e9SAndroid Build Coastguard Worker                lowered_module, input  # noqa: F821
812*da0073e9SAndroid Build Coastguard Worker            )
813*da0073e9SAndroid Build Coastguard Worker        )
814*da0073e9SAndroid Build Coastguard Worker        post_bundled = self.lowered_module(
815*da0073e9SAndroid Build Coastguard Worker            *self.lowered_module.get_all_bundled_inputs()[0]
816*da0073e9SAndroid Build Coastguard Worker        )
817*da0073e9SAndroid Build Coastguard Worker        # Save and load the lowered module.
818*da0073e9SAndroid Build Coastguard Worker        self.save_load()
819*da0073e9SAndroid Build Coastguard Worker        # Use bundled after save and load to prove its preserved
820*da0073e9SAndroid Build Coastguard Worker        post_load = self.lowered_module(
821*da0073e9SAndroid Build Coastguard Worker            *self.lowered_module.get_all_bundled_inputs()[0]
822*da0073e9SAndroid Build Coastguard Worker        )
823*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(pre_bundled, post_bundled)
824*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(post_bundled, post_load)
825