# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""Test helper for exporting an nn.Module to an ExecuTorch program."""

import functools
import inspect
from typing import Callable, Sequence, Type

import executorch.exir as exir
import torch
from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager, to_edge
from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
from executorch.exir.passes import (
    DebugPass,
    MemoryPlanningPass,
    to_scratch_op_pass,
    ToOutVarPass,
)
from torch import nn
from torch.export import export
from torch.export._trace import _export
from torch.export.experimental import _export_forward_backward


class ExportedModule:
    """The result of exporting an nn.Module.

    Attributes:
        eager_module: The original nn.Module that was exported.
        methods: The names of the eager_module methods that were traced.
        executorch_program: The resulting ExecutorchProgramManager.
        exported_program: The resulting ExportedProgram.
        trace_inputs: The inputs that were used when tracing eager_module.
    """

    def __init__(
        self,
        eager_module: nn.Module,
        methods: Sequence[str],
        executorch_program: ExecutorchProgramManager,
        exported_program: torch.export.ExportedProgram,
        trace_inputs: Sequence,
        get_random_inputs_fn: Callable[[], Sequence],
    ):
        """INTERNAL ONLY: Use ExportedModule.export() instead."""
        self.eager_module: nn.Module = eager_module
        self.methods: Sequence[str] = methods
        self.executorch_program: ExecutorchProgramManager = executorch_program
        self.exported_program: torch.export.ExportedProgram = exported_program
        self.trace_inputs: Sequence = trace_inputs
        self.__get_random_inputs_fn = get_random_inputs_fn

    def get_random_inputs(self) -> Sequence:
        """Returns random inputs appropriate for model inference."""
        return self.__get_random_inputs_fn()

    @staticmethod
    def export(
        module_class: Type[nn.Module],
        methods: Sequence[str] = ("forward",),
        ignore_to_out_var_failure: bool = False,
        dynamic_memory_planning_mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND,
        capture_config=None,
        skip_type_promotion: bool = False,
        export_joint_graph: bool = False,
    ) -> "ExportedModule":
        """
        Creates a new ExportedModule for the specified module class.

        Args:
            module_class: The subclass of nn.Module to export.
            methods: The names of the module_class methods to trace.
            ignore_to_out_var_failure: Whether to ignore the failure when a
                functional op does not have an out variant.
            dynamic_memory_planning_mode: The dynamic memory planning mode to
                use.
            capture_config: Optional capture config; must be provided with
                enable_aot=True when module_class defines get_dynamic_shapes().
            skip_type_promotion: Whether to skip type promotion when lowering
                to the edge dialect.
            export_joint_graph: Whether to export a joint forward/backward
                (training) graph; only supported for the "forward" method.
        """

        def get_inputs_adapter(
            worker_fn: Callable, method: str
        ) -> Callable[[], Sequence]:
            """Returns a zero-argument function that binds `method` to
            `worker_fn` when `worker_fn` expects it, and ensures that the
            returned inputs are always a list or tuple.

            Args:
                worker_fn: The function to wrap. Must take zero or one
                    arguments. If it takes one argument, that argument must be
                    called "method" and expect a string.
                method: The name of the method to possibly pass to `worker_fn`.

            Returns:
                A function that takes zero arguments and returns a Sequence.
            """
            # Names of the parameters of worker_fn.
            params = inspect.signature(worker_fn).parameters.keys()
            if len(params) == 1:
                assert "method" in params, f"Expected 'method' param in {params}"
                # Bind our `method` parameter to `worker_fn`, which has the
                # signature `func(method: str)`.
                worker_fn = functools.partial(worker_fn, method)
            else:
                assert len(params) == 0, f"Unexpected params in {params}"
                # worker_fn takes no parameters.

            def return_wrapper():
                inputs = worker_fn()
                # Wrap the return value in a tuple if it's not already a tuple
                # or list.
                if not isinstance(inputs, (tuple, list)):
                    inputs = (inputs,)
                return inputs

            return return_wrapper
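
        # For illustration, both supported worker_fn shapes (hypothetical
        # helpers; a sketch only, not part of this file's API):
        #
        #     def zero_arg():                 # takes no parameters
        #         return torch.randn(2)       # bare tensor will be wrapped
        #
        #     def one_arg(method: str):       # single param must be "method"
        #         return (torch.randn(2),)
        #
        #     get_inputs_adapter(zero_arg, "forward")()  # -> (tensor,)
        #     get_inputs_adapter(one_arg, "forward")()   # -> (tensor,)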

        # Create the eager module.
        eager_module = module_class().eval()

        # Generate inputs to use while tracing.
        trace_inputs_method = "get_upper_bound_inputs"
        get_trace_inputs = get_inputs_adapter(
            (
                # pyre-fixme[6]: For 1st argument expected `(...) -> Any` but got
                #  `Union[Module, Tensor]`.
                getattr(eager_module, trace_inputs_method)
                if hasattr(eager_module, trace_inputs_method)
                else eager_module.get_random_inputs
            ),
            # All exported methods must have the same signature, so just pick the first one.
            methods[0],
        )
        trace_inputs: Sequence = get_trace_inputs()
        method_name_to_args = {}
        for method in methods:
            method_name_to_args[method] = trace_inputs

        method_name_to_dynamic_shapes = None
        if hasattr(eager_module, "get_dynamic_shapes"):
            assert capture_config is not None
            assert capture_config.enable_aot is True
            # pyre-fixme[29]: `Union[nn.modules.module.Module,
            #  torch._tensor.Tensor]` is not a function.
            trace_dynamic_shapes = eager_module.get_dynamic_shapes()
            method_name_to_dynamic_shapes = {}
            for method in methods:
                method_name_to_dynamic_shapes[method] = trace_dynamic_shapes
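
            # For reference, get_dynamic_shapes() must return a value in
            # torch.export's dynamic_shapes format. As a sketch (the module
            # and the bound are hypothetical), a module taking one tensor
            # whose dim 0 is dynamic might return:
            #
            #     from torch.export import Dim
            #     ({0: Dim("batch", max=16)},)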

        memory_planning_pass = MemoryPlanningPass()
        if hasattr(eager_module, "get_memory_planning_pass"):
            # pyre-fixme[29]: `Union[nn.modules.module.Module,
            #  torch._tensor.Tensor]` is not a function.
            memory_planning_pass = eager_module.get_memory_planning_pass()

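        # WrapperModule exposes a single bound method as this module's
        # `forward`, so the method can be traced like a standalone nn.Module.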
        class WrapperModule(nn.Module):
            def __init__(self, method):
                super().__init__()
                self.forward = method

        exported_methods = {}
        # These cleanup passes are required to convert the `add` op to its out
        # variant, along with some other transformations.
        for method_name, method_input in method_name_to_args.items():
            if export_joint_graph:
                # _export was having issues with WrapperModule.
                assert method_name == "forward"
                ep = _export(
                    eager_module,
                    method_input,
                    dynamic_shapes=(
                        method_name_to_dynamic_shapes[method_name]
                        if method_name_to_dynamic_shapes
                        else None
                    ),
                    pre_dispatch=True,
                )
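                # _export_forward_backward (experimental) turns the
                # pre-dispatch export into a joint forward/backward graph
                # for training.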
                exported_methods[method_name] = _export_forward_backward(ep)
            else:
                exported_methods[method_name] = export(
                    eager_module,
                    method_input,
                    dynamic_shapes=(
                        method_name_to_dynamic_shapes[method_name]
                        if method_name_to_dynamic_shapes
                        else None
                    ),
                )

        exec_prog = to_edge(
            exported_methods,
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False, _skip_type_promotion=skip_type_promotion
            ),
        ).to_executorch(
            ExecutorchBackendConfig(
                passes=[
                    DebugPass(
                        show_src=True,
                        show_spec=False,
                        show_full_path=True,
                        show_all_frames=True,
                    ),
                    to_scratch_op_pass,
                ],
                dynamic_memory_planning_mode=dynamic_memory_planning_mode,
                memory_planning_pass=memory_planning_pass,
                to_out_var_pass=ToOutVarPass(ignore_to_out_var_failure),
            )
        )

        # Get the ExportedProgram generated during capture.
        exported_program = exec_prog.exported_program()

        # Get a function that creates random inputs appropriate for testing.
        get_random_inputs_fn = get_inputs_adapter(
            # pyre-fixme[6]: For 1st argument expected `(...) -> Any` but got
            #  `Union[Module, Tensor]`.
            eager_module.get_random_inputs,
            # All exported methods must have the same signature, so just pick the first one.
            methods[0],
        )

        # Create the ExportedModule.
        return ExportedModule(
            eager_module=eager_module,
            methods=methods,
            executorch_program=exec_prog,
            exported_program=exported_program,
            trace_inputs=trace_inputs,
            get_random_inputs_fn=get_random_inputs_fn,
        )
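
# Example usage (a minimal sketch; `MyModel` and its methods are hypothetical
# stand-ins for a real test module, not part of this file):
#
#     class MyModel(torch.nn.Module):
#         def forward(self, x):
#             return x + x
#
#         def get_random_inputs(self):
#             return (torch.randn(4),)
#
#     exported = ExportedModule.export(MyModel)
#     outputs = exported.eager_module(*exported.get_random_inputs())
#     program = exported.executorch_program  # ExecutorchProgramManager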
238