# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""Test helper for exporting an nn.Module to an ExecuTorch program."""

import functools
import inspect
from typing import Callable, Sequence, Type

import executorch.exir as exir
import torch
from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager, to_edge
from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
from executorch.exir.passes import (
    DebugPass,
    MemoryPlanningPass,
    to_scratch_op_pass,
    ToOutVarPass,
)
from torch import nn
from torch.export import export
from torch.export._trace import _export
from torch.export.experimental import _export_forward_backward


class ExportedModule:
    """The result of exporting an nn.Module.

    Attributes:
        eager_module: The original nn.Module that was exported.
        methods: The names of the eager_module methods that were traced.
        executorch_program: The resulting ExecutorchProgramManager.
        exported_program: The resulting ExportedProgram. This is the program
            for "forward" when it was exported, otherwise for the first entry
            of `methods`.
        trace_inputs: The inputs that were used when tracing eager_module.
    """

    def __init__(
        self,
        eager_module: nn.Module,
        methods: Sequence[str],
        executorch_program: ExecutorchProgramManager,
        exported_program: torch.export.ExportedProgram,
        trace_inputs: Sequence,
        get_random_inputs_fn: Callable[[], Sequence],
    ):
        """INTERNAL ONLY: Use ExportedModule.export() instead."""
        self.eager_module: nn.Module = eager_module
        self.methods: Sequence[str] = methods
        self.executorch_program: ExecutorchProgramManager = executorch_program
        self.exported_program: torch.export.ExportedProgram = exported_program
        self.trace_inputs: Sequence = trace_inputs
        self.__get_random_inputs_fn = get_random_inputs_fn

    def get_random_inputs(self) -> Sequence:
        """Returns random inputs appropriate for model inference."""
        return self.__get_random_inputs_fn()

    @staticmethod
    def export(
        module_class: Type[nn.Module],
        methods: Sequence[str] = ("forward",),
        ignore_to_out_var_failure: bool = False,
        dynamic_memory_planning_mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND,
        capture_config=None,
        skip_type_promotion: bool = False,
        export_joint_graph: bool = False,
    ) -> "ExportedModule":
        """Creates a new ExportedModule for the specified module class.

        `module_class` is instantiated with no arguments and put in eval mode.
        The instance must provide `get_random_inputs()`. It may also provide:
        - `get_upper_bound_inputs()`, used instead of random inputs for tracing;
        - `get_dynamic_shapes()`;
        - `get_memory_planning_pass()`.
        Each input getter takes either no arguments or a single `method`
        argument.

        Args:
            module_class: The subclass of nn.Module to export.
            methods: The names of the module_class methods to trace. Every
                method is traced with the same inputs (and dynamic shapes), so
                all of them must share a signature.
            ignore_to_out_var_failure: Whether to ignore the failure when a
                functional op does not have an out variant.
            dynamic_memory_planning_mode: The dynamic memory planning mode to
                use.
            capture_config: Only consulted when the module defines
                `get_dynamic_shapes()`. It must then be non-None with
                `enable_aot` set to True.
            skip_type_promotion: Forwarded to `exir.EdgeCompileConfig` as
                `_skip_type_promotion`.
            export_joint_graph: If True, trace with `_export(...,
                pre_dispatch=True)` and convert the result with
                `_export_forward_backward`. Only "forward" is supported in this
                mode.

        Returns:
            An ExportedModule holding the eager module, the ExecuTorch program,
            the exported program and the inputs used for tracing.

        Raises:
            ValueError: If `methods` is empty.
        """
        if not methods:
            raise ValueError("At least one method name must be provided")

        def get_inputs_adapter(
            worker_fn: Callable, method: str
        ) -> Callable[[], Sequence]:
            """Returns a function that may bind `method` as a parameter of
            `worker_fn`, and ensures that `worker_fn` always returns a list or
            tuple.

            Args:
                worker_fn: The function to wrap. Must take zero or one
                    arguments. If it takes one argument, that argument must be
                    called "method" and expect a string.
                method: The name of the method to possibly pass to `worker_fn`.

            Returns:
                A function that takes zero arguments and returns a Sequence.
            """
            # Names of the parameters of worker_fn.
            params = inspect.signature(worker_fn).parameters.keys()
            if len(params) == 1:
                assert "method" in params, f"Expected 'method' param in {params}"
                # Bind our `method` parameter to `worker_fn`, which has the
                # signature `func(method: str)`.
                worker_fn = functools.partial(worker_fn, method)
            else:
                assert len(params) == 0, f"Unexpected params in {params}"
                # worker_fn takes no parameters.

            def return_wrapper():
                inputs = worker_fn()
                # Wrap the return value in a tuple if it's not already a tuple
                # or list.
                if not isinstance(inputs, (tuple, list)):
                    inputs = (inputs,)
                return inputs

            return return_wrapper

        class WrapperModule(nn.Module):
            """Exposes an arbitrary method of `module` as `forward`.

            `module` is assigned as a submodule, so nn.Module registers its
            parameters and buffers on this wrapper. `forward` is the bound
            method itself, so its signature (used when matching
            `dynamic_shapes`) is unchanged.
            """

            def __init__(self, module: nn.Module, method_name: str):
                super().__init__()
                self.module = module
                # A bound method is neither a Module, Parameter nor Tensor, so
                # it is stored as a plain instance attribute and shadows the
                # class-level `forward`.
                self.forward = getattr(module, method_name)

        # Create the eager module.
        eager_module = module_class().eval()

        # Generate inputs to use while tracing. All exported methods must have
        # the same signature, so the inputs are built for the first one and
        # reused for every method.
        trace_inputs_method = "get_upper_bound_inputs"
        get_trace_inputs = get_inputs_adapter(
            (
                # pyre-fixme[6]: For 1st argument expected `(...) -> Any` but got
                #  `Union[Module, Tensor]`.
                getattr(eager_module, trace_inputs_method)
                if hasattr(eager_module, trace_inputs_method)
                else eager_module.get_random_inputs
            ),
            methods[0],
        )
        trace_inputs: Sequence = get_trace_inputs()

        # Dynamic shapes, if any, are likewise shared by every method.
        trace_dynamic_shapes = None
        if hasattr(eager_module, "get_dynamic_shapes"):
            assert capture_config is not None
            assert capture_config.enable_aot is True
            # pyre-fixme[29]: `Union[nn.modules.module.Module,
            #  torch._tensor.Tensor]` is not a function.
            trace_dynamic_shapes = eager_module.get_dynamic_shapes()

        memory_planning_pass = MemoryPlanningPass()
        if hasattr(eager_module, "get_memory_planning_pass"):
            # pyre-fixme[29]: `Union[nn.modules.module.Module,
            #  torch._tensor.Tensor]` is not a function.
            memory_planning_pass = eager_module.get_memory_planning_pass()

        exported_methods = {}
        for method_name in methods:
            if export_joint_graph:
                # _export was having issues with WrapperModule, so joint-graph
                # export only supports tracing the eager module's `forward`.
                assert method_name == "forward"
                ep = _export(
                    eager_module,
                    trace_inputs,
                    dynamic_shapes=trace_dynamic_shapes,
                    pre_dispatch=True,
                )
                exported_methods[method_name] = _export_forward_backward(ep)
            else:
                # `export` always traces a module's `forward`. Any other method
                # must be wrapped, otherwise its entry would hold `forward`'s
                # graph.
                module_to_export = (
                    eager_module
                    if method_name == "forward"
                    else WrapperModule(eager_module, method_name)
                )
                exported_methods[method_name] = export(
                    module_to_export,
                    trace_inputs,
                    dynamic_shapes=trace_dynamic_shapes,
                )

        exec_prog = to_edge(
            exported_methods,
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False, _skip_type_promotion=skip_type_promotion
            ),
        ).to_executorch(
            # These cleanup passes are required to convert the `add` op to its
            # out variant, along with some other transformations.
            ExecutorchBackendConfig(
                passes=[
                    DebugPass(
                        show_src=True,
                        show_spec=False,
                        show_full_path=True,
                        show_all_frames=True,
                    ),
                    to_scratch_op_pass,
                ],
                dynamic_memory_planning_mode=dynamic_memory_planning_mode,
                memory_planning_pass=memory_planning_pass,
                to_out_var_pass=ToOutVarPass(ignore_to_out_var_failure),
            )
        )

        # Retrieve the graph module created during capture. Prefer "forward"
        # (what this helper has always returned). Fall back to the first
        # exported method when "forward" was not requested.
        exported_program = exec_prog.exported_program(
            "forward" if "forward" in methods else methods[0]
        )

        # Get a function that creates random inputs appropriate for testing.
        get_random_inputs_fn = get_inputs_adapter(
            # pyre-fixme[6]: For 1st argument expected `(...) -> Any` but got
            #  `Union[Module, Tensor]`.
            eager_module.get_random_inputs,
            # all exported methods must have the same signature so just pick the first one.
            methods[0],
        )

        # Create the ExportedModule.
        return ExportedModule(
            eager_module=eager_module,
            methods=methods,
            executorch_program=exec_prog,
            exported_program=exported_program,
            trace_inputs=trace_inputs,
            get_random_inputs_fn=get_random_inputs_fn,
        )