# xref: /aosp_15_r20/external/executorch/extension/pybindings/test/make_test.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
8
9import unittest
10from types import ModuleType
11from typing import Any, Callable, Optional, Tuple
12
13import torch
14from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager, to_edge
15from executorch.exir.passes import MemoryPlanningPass
16from torch.export import export
17
18
class ModuleAdd(torch.nn.Module):
    """Single-method module whose ``forward`` returns the sum of its two inputs.

    The simplest program the pybindings tests serialize and execute.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x + y

    def get_methods_to_export(self):
        """Return the names of the methods to export into the program."""
        return ("forward",)

    def get_inputs(self):
        """Return the example inputs: two 2x2 tensors of ones."""
        return (torch.ones(2, 2), torch.ones(2, 2))
33
34
class ModuleMulti(torch.nn.Module):
    """Module with two exported entry points.

    ``forward`` returns ``x + y`` and ``forward2`` returns ``x + y + 1``. Both are
    exported into a single program to exercise multi-method programs.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x + y

    def forward2(self, x, y):
        return x + y + 1

    def get_methods_to_export(self):
        """Return the names of the methods to export into the program."""
        return ("forward", "forward2")

    def get_inputs(self):
        """Return the example inputs (shared by both methods): two 2x2 ones tensors."""
        return (torch.ones(2, 2), torch.ones(2, 2))
52
53
class ModuleAddSingleInput(torch.nn.Module):
    """Single-input module whose ``forward`` returns its input added to itself."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x + x

    def get_methods_to_export(self):
        """Return the names of the methods to export into the program."""
        return ("forward",)

    def get_inputs(self):
        """Return the example inputs: a one-element tuple holding a 2x2 ones tensor."""
        return (torch.ones(2, 2),)
68
69
class ModuleAddConstReturn(torch.nn.Module):
    """Module that adds a constant tensor to its input and also returns that constant.

    ``forward(x)`` returns ``(x + self.state, self.state)``. ``state`` is assigned as a
    plain tensor attribute (not via ``register_buffer`` or as a ``Parameter``);
    presumably export therefore treats it as a constant, which is what the
    "constant output not memory planned" test relies on — confirm if this changes.
    """

    def __init__(self):
        super().__init__()
        # Plain attribute on purpose; see the class docstring.
        self.state = torch.ones(2, 2)

    def forward(self, x):
        return x + self.state, self.state

    def get_methods_to_export(self):
        """Return the names of the methods to export into the program."""
        return ("forward",)

    def get_inputs(self):
        """Return the example inputs: a one-element tuple holding a 2x2 ones tensor."""
        return (torch.ones(2, 2),)
85
86
def create_program(
    eager_module: torch.nn.Module,
    et_config: Optional[ExecutorchBackendConfig] = None,
) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
    """Export ``eager_module`` to an ExecuTorch program and return it with its inputs.

    ``eager_module`` must provide ``get_inputs()`` (the example inputs) and
    ``get_methods_to_export()`` (the names of the methods to export), like the test
    modules in this file. Every listed method is exported with the same example
    inputs, and all of them end up in a single program.

    Args:
        eager_module: The eager module to export.
        et_config: Optional backend config forwarded to ``to_executorch``.

    Returns:
        A tuple of the resulting ``ExecutorchProgramManager`` and the example inputs
        returned by ``eager_module.get_inputs()``.
    """
    # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`
    #  is not a function.
    inputs = eager_module.get_inputs()
    # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`
    #  is not a function.
    method_names = eager_module.get_methods_to_export()
    # Every exported method is traced with the same example inputs.
    input_map = {method_name: inputs for method_name in method_names}

    class WrapperModule(torch.nn.Module):
        # Exposes an arbitrary bound method as `forward` so it can be exported.
        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, *args, **kwargs):
            return self.fn(*args, **kwargs)

    # Trace each method separately; the dict keys become the program's method names.
    exported_methods = {
        method_name: export(
            WrapperModule(getattr(eager_module, method_name)), method_input
        )
        for method_name, method_input in input_map.items()
    }

    # Lower to an ExecuTorch program. These cleanup passes are required to convert
    # the `add` op to its out variant, along with some other transformations.
    exec_prog = to_edge(exported_methods).to_executorch(config=et_config)

    # Dump the generated program (verbose) to help debug test failures.
    exec_prog.dump_executorch_program(verbose=True)
    return (exec_prog, inputs)
123
124
def make_test(  # noqa: C901
    tester: unittest.TestCase,
    runtime: ModuleType,
) -> Callable[[unittest.TestCase], None]:
    """
    Returns a function that operates as a test case within a unittest.TestCase class.

    Used to allow the test code for pybindings to be shared across different pybinding libs
    which will all have different load functions. In this case each individual test case is a
    subfunction of wrapper.

    Args:
        tester: Not used by this function; kept so existing callers keep working.
            Assertions run against the ``unittest.TestCase`` passed to the returned
            ``wrapper`` instead.
        runtime: The pybindings module under test. Programs are loaded with its
            ``_load_for_executorch_from_buffer``; ``runtime.Verification`` is read by
            the verification-config test.

    Returns:
        ``wrapper``, which runs every shared test case in order against the given
        ``unittest.TestCase``. The first failing assertion propagates and stops the run.
    """
    load_fn: Callable = runtime._load_for_executorch_from_buffer

    def wrapper(tester: unittest.TestCase) -> None:

        ######### TEST CASES #########

        def test_e2e(tester):
            # Create an ExecuTorch program from ModuleAdd.
            exported_program, inputs = create_program(ModuleAdd())

            # Use pybindings to load and execute the program.
            executorch_module = load_fn(exported_program.buffer)
            executorch_output = executorch_module.forward(inputs)[0]

            # The test module adds the two inputs, so its output should be the same
            # as adding them directly.
            expected = inputs[0] + inputs[1]

            tester.assertEqual(str(expected), str(executorch_output))

        def test_multiple_entry(tester):
            # ModuleMulti exports both `forward` and `forward2` into one program.
            program, inputs = create_program(ModuleMulti())
            executorch_module = load_fn(program.buffer)

            # forward: ones + ones
            executorch_output = executorch_module.forward(inputs)[0]
            tester.assertTrue(torch.allclose(executorch_output, torch.ones(2, 2) * 2))

            # forward2: ones + ones + 1, invoked by method name.
            executorch_output2 = executorch_module.run_method("forward2", inputs)[0]
            tester.assertTrue(torch.allclose(executorch_output2, torch.ones(2, 2) * 3))

        def test_output_lifespan(tester):
            def lower_function_call():
                program, inputs = create_program(ModuleMulti())
                executorch_module = load_fn(program.buffer)

                # executorch_module goes out of scope when this function returns; the
                # returned outputs must remain usable after the module is released.
                return executorch_module.forward(inputs)

            outputs = lower_function_call()
            tester.assertTrue(torch.allclose(outputs[0], torch.ones(2, 2) * 2))

        def test_module_callable(tester):
            # Create an ExecuTorch program from ModuleAdd.
            exported_program, inputs = create_program(ModuleAdd())

            # Use pybindings to load and execute the program.
            executorch_module = load_fn(exported_program.buffer)
            # Invoke the callable on executorch_module instead of calling module.forward.
            executorch_output = executorch_module(inputs)[0]

            # The test module adds the two inputs, so its output should be the same
            # as adding them directly.
            expected = inputs[0] + inputs[1]
            tester.assertEqual(str(expected), str(executorch_output))

        def test_module_single_input(tester):
            # Create an ExecuTorch program from ModuleAddSingleInput.
            exported_program, inputs = create_program(ModuleAddSingleInput())

            # Use pybindings to load and execute the program.
            executorch_module = load_fn(exported_program.buffer)
            # Invoke the callable on executorch_module instead of calling module.forward,
            # passing the single tensor itself rather than the inputs tuple.
            executorch_output = executorch_module(inputs[0])[0]

            # The test module adds its single input to itself, so its output should be
            # the same as doing that directly.
            expected = inputs[0] + inputs[0]
            tester.assertEqual(str(expected), str(executorch_output))

        def test_stderr_redirect(tester):
            from contextlib import redirect_stderr
            from io import StringIO

            captured = StringIO()
            with redirect_stderr(captured):
                # Create an ExecuTorch program from ModuleAdd.
                exported_program, inputs = create_program(ModuleAdd())

                # Use pybindings to load the program.
                executorch_module = load_fn(exported_program.buffer)

                # Add an extra input to trigger an error.
                inputs = (*inputs, 1)

                # assertRaises (instead of try/except Exception around an
                # "unreachable" assertFalse) so a call that unexpectedly succeeds
                # fails the test rather than having its AssertionError swallowed.
                with tester.assertRaises(Exception):
                    executorch_module(inputs)

            # assertIn rather than assertTrue(str.find(...)): find() returns -1
            # (truthy) when the text is absent, so the old check could never fail.
            # NOTE(review): this expects the runtime's error log to reach Python's
            # sys.stderr while redirected; confirm for every pybinding lib sharing it.
            tester.assertIn("The length of given input array", captured.getvalue())

        def test_quantized_ops(tester):
            eager_module = ModuleAdd()

            from executorch.exir import EdgeCompileConfig
            from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
            from executorch.kernels import quantized  # noqa: F401
            from torch.ao.quantization import get_default_qconfig_mapping
            from torch.ao.quantization.backend_config.executorch import (
                get_executorch_backend_config,
            )
            from torch.ao.quantization.quantize_fx import (
                _convert_to_reference_decomposed_fx,
                prepare_fx,
            )

            qconfig_mapping = get_default_qconfig_mapping("qnnpack")
            example_inputs = (
                torch.ones(1, 5, dtype=torch.float32),
                torch.ones(1, 5, dtype=torch.float32),
            )
            m = prepare_fx(
                eager_module,
                qconfig_mapping,
                example_inputs,
                backend_config=get_executorch_backend_config(),
            )
            m = _convert_to_reference_decomposed_fx(m)
            config = EdgeCompileConfig(_check_ir_validity=False)
            m = to_edge(export(m, example_inputs), compile_config=config)
            m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])

            exec_prog = m.to_executorch()

            executorch_module = load_fn(exec_prog.buffer)
            executorch_output = executorch_module.forward(example_inputs)[0]

            expected = example_inputs[0] + example_inputs[1]
            tester.assertEqual(str(expected), str(executorch_output))

        def test_constant_output_not_memory_planned(tester):
            # Create an ExecuTorch program from ModuleAddConstReturn, with graph
            # outputs excluded from memory planning.
            exported_program, inputs = create_program(
                ModuleAddConstReturn(),
                et_config=ExecutorchBackendConfig(
                    memory_planning_pass=MemoryPlanningPass(alloc_graph_output=False)
                ),
            )

            # Use pybindings to load and execute the program via the callable form.
            executorch_module = load_fn(exported_program.buffer)
            executorch_output = executorch_module(inputs)

            # The test module adds its state (torch.ones(2, 2)) to the input (also
            # torch.ones(2, 2)), so output 0 should match adding them directly.
            expected = torch.ones(2, 2) + torch.ones(2, 2)
            tester.assertEqual(str(expected), str(executorch_output[0]))

            # The test module returns the state. Check that its value is correct.
            tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1]))

        def test_method_meta(tester) -> None:
            exported_program, _ = create_program(ModuleAdd())

            # Use pybindings to load the program and query its metadata.
            executorch_module = load_fn(exported_program.buffer)
            meta = executorch_module.method_meta("forward")

            # Ensure that all these APIs work even if the module object is destroyed.
            del executorch_module
            tester.assertEqual(meta.name(), "forward")
            tester.assertEqual(meta.num_inputs(), 2)
            tester.assertEqual(meta.num_outputs(), 1)
            # Common string for all these tensors.
            tensor_info = "TensorInfo(sizes=[2, 2], dtype=Float, is_memory_planned=True, nbytes=16)"
            # Integer code dtype() is expected to return for these Float tensors.
            float_dtype = 6
            tester.assertEqual(
                str(meta),
                "MethodMeta(name='forward', num_inputs=2, "
                f"input_tensor_meta=['{tensor_info}', '{tensor_info}'], "
                f"num_outputs=1, output_tensor_meta=['{tensor_info}'])",
            )

            input_tensors = [meta.input_tensor_meta(i) for i in range(2)]
            output_tensor = meta.output_tensor_meta(0)
            # Check that accessing out of bounds raises IndexError.
            with tester.assertRaises(IndexError):
                meta.input_tensor_meta(2)
            # Test that tensor metadata can outlive method metadata.
            del meta
            tester.assertEqual([t.sizes() for t in input_tensors], [(2, 2), (2, 2)])
            tester.assertEqual(
                [t.dtype() for t in input_tensors], [float_dtype, float_dtype]
            )
            tester.assertEqual(
                [t.is_memory_planned() for t in input_tensors], [True, True]
            )
            tester.assertEqual([t.nbytes() for t in input_tensors], [16, 16])
            tester.assertEqual(str(input_tensors), f"[{tensor_info}, {tensor_info}]")

            tester.assertEqual(output_tensor.sizes(), (2, 2))
            tester.assertEqual(output_tensor.dtype(), float_dtype)
            tester.assertEqual(output_tensor.is_memory_planned(), True)
            tester.assertEqual(output_tensor.nbytes(), 16)
            tester.assertEqual(str(output_tensor), tensor_info)

        def test_bad_name(tester) -> None:
            # Create an ExecuTorch program from ModuleAdd.
            exported_program, inputs = create_program(ModuleAdd())

            # Use pybindings to load the program.
            executorch_module = load_fn(exported_program.buffer)
            # Running a method name that is not in the program should raise.
            with tester.assertRaises(RuntimeError):
                executorch_module.run_method("not_a_real_method", inputs)

        def test_verification_config(tester) -> None:
            # Create an ExecuTorch program from ModuleAdd.
            exported_program, inputs = create_program(ModuleAdd())
            Verification = runtime.Verification

            # Load and execute the program once per verification level.
            for config in [Verification.Minimal, Verification.InternalConsistency]:
                executorch_module = load_fn(
                    exported_program.buffer,
                    enable_etdump=False,
                    debug_buffer_size=0,
                    program_verification=config,
                )

                executorch_output = executorch_module.forward(inputs)[0]

                # The test module adds the two inputs, so its output should be the same
                # as adding them directly.
                expected = inputs[0] + inputs[1]

                tester.assertEqual(str(expected), str(executorch_output))

        ######### RUN TEST CASES #########
        test_e2e(tester)
        test_multiple_entry(tester)
        test_output_lifespan(tester)
        test_module_callable(tester)
        test_module_single_input(tester)
        test_stderr_redirect(tester)
        test_quantized_ops(tester)
        test_constant_output_not_memory_planned(tester)
        test_method_meta(tester)
        test_bad_name(tester)
        test_verification_config(tester)

    return wrapper
400