# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest
from types import ModuleType
from typing import Any, Callable, Optional, Tuple

import torch
from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager, to_edge
from executorch.exir.passes import MemoryPlanningPass
from torch.export import export


class ModuleAdd(torch.nn.Module):
    """Two-input module whose ``forward`` returns ``x + y``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y

    def get_methods_to_export(self) -> Tuple[str, ...]:
        """Names of the methods that should be exported."""
        return ("forward",)

    def get_inputs(self) -> Tuple[torch.Tensor, ...]:
        """Example inputs: two 2x2 tensors of ones."""
        return (torch.ones(2, 2), torch.ones(2, 2))


class ModuleMulti(torch.nn.Module):
    """Two-input module exposing two entry points: ``forward`` and ``forward2``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y

    def forward2(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y + 1

    def get_methods_to_export(self) -> Tuple[str, ...]:
        """Names of the methods that should be exported."""
        return ("forward", "forward2")

    def get_inputs(self) -> Tuple[torch.Tensor, ...]:
        """Example inputs: two 2x2 tensors of ones (shared by both methods)."""
        return (torch.ones(2, 2), torch.ones(2, 2))


class ModuleAddSingleInput(torch.nn.Module):
    """Single-input module whose ``forward`` returns ``x + x``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + x

    def get_methods_to_export(self) -> Tuple[str, ...]:
        """Names of the methods that should be exported."""
        return ("forward",)

    def get_inputs(self) -> Tuple[torch.Tensor, ...]:
        """Example inputs: a single 2x2 tensor of ones."""
        return (torch.ones(2, 2),)


class ModuleAddConstReturn(torch.nn.Module):
    """Single-input module that returns ``x + state`` together with ``state`` itself."""

    def __init__(self) -> None:
        super().__init__()
        # Stored as a plain tensor attribute; it is returned unchanged as the
        # second output of ``forward``.
        self.state = torch.ones(2, 2)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return x + self.state, self.state

    def get_methods_to_export(self) -> Tuple[str, ...]:
        """Names of the methods that should be exported."""
        return ("forward",)

    def get_inputs(self) -> Tuple[torch.Tensor, ...]:
        """Example inputs: a single 2x2 tensor of ones."""
        return (torch.ones(2, 2),)


def create_program(
    eager_module: torch.nn.Module,
    et_config: Optional[ExecutorchBackendConfig] = None,
) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
    """Export ``eager_module`` and lower it to an ExecuTorch program.

    Every method named by ``eager_module.get_methods_to_export()`` is exported
    using the example inputs returned by ``eager_module.get_inputs()``; the
    same inputs are used for all methods.

    Args:
        eager_module: Module providing ``get_inputs()`` and
            ``get_methods_to_export()`` in addition to the methods to export.
        et_config: Optional config forwarded to ``to_executorch``.

    Returns:
        A ``(program_manager, inputs)`` tuple, where ``inputs`` are the example
        inputs the methods were exported with.
    """
    # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`
    #  is not a function.
    inputs = eager_module.get_inputs()
    # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`
    #  is not a function.
    method_names = eager_module.get_methods_to_export()

    class WrapperModule(torch.nn.Module):
        """Exposes an arbitrary callable as ``forward`` so it can be exported."""

        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, *args, **kwargs):
            return self.fn(*args, **kwargs)

    # Trace each requested method of the test module.
    exported_methods = {
        method_name: export(WrapperModule(getattr(eager_module, method_name)), inputs)
        for method_name in method_names
    }

    # Lower to an ExecuTorch program. Per the original note, the passes run
    # here are required to convert the `add` op to its out variant, along
    # with some other transformations.
    exec_prog = to_edge(exported_methods).to_executorch(config=et_config)

    # Verbose dump of the resulting program (presumably a debugging aid).
    exec_prog.dump_executorch_program(verbose=True)
    return (exec_prog, inputs)


def make_test(  # noqa: C901
    tester: unittest.TestCase,
    runtime: ModuleType,
) -> Callable[[unittest.TestCase], None]:
    """
    Returns a function that operates as a test case within a unittest.TestCase class.

    Used to allow the test code for pybindings to be shared across different pybinding libs
    which will all have different load functions. In this case each individual test case is a
    subfunction of wrapper.

    Args:
        tester: Not used by this function itself; the returned wrapper takes
            its own ``tester`` argument, which is the one the assertions use.
        runtime: Pybinding module; must provide
            ``_load_for_executorch_from_buffer`` and ``Verification``.

    Returns:
        A callable that runs every test case below, in order, against the
        ``unittest.TestCase`` it is given.
    """
    load_fn: Callable = runtime._load_for_executorch_from_buffer

    def wrapper(tester: unittest.TestCase) -> None:

        ######### TEST CASES #########

        def test_e2e(tester):
            # Create an ExecuTorch program from ModuleAdd.
            exported_program, inputs = create_program(ModuleAdd())

            # Use pybindings to load and execute the program.
            executorch_module = load_fn(exported_program.buffer)
            executorch_output = executorch_module.forward(inputs)[0]

            # The test module adds the two inputs, so its output should be the same
            # as adding them directly.
            expected = inputs[0] + inputs[1]

            tester.assertEqual(str(expected), str(executorch_output))

        def test_multiple_entry(tester):
            program, inputs = create_program(ModuleMulti())
            executorch_module = load_fn(program.buffer)

            # `forward` computes x + y.
            executorch_output = executorch_module.forward(inputs)[0]
            tester.assertTrue(torch.allclose(executorch_output, torch.ones(2, 2) * 2))

            # `forward2` computes x + y + 1 and is reached through run_method.
            executorch_output2 = executorch_module.run_method("forward2", inputs)[0]
            tester.assertTrue(torch.allclose(executorch_output2, torch.ones(2, 2) * 3))

        def test_output_lifespan(tester):
            def lower_function_call():
                program, inputs = create_program(ModuleMulti())
                executorch_module = load_fn(program.buffer)

                return executorch_module.forward(inputs)
                # executorch_module is destructed here and all of its memory is freed

            # The outputs must remain usable after the module has gone away.
            outputs = lower_function_call()
            tester.assertTrue(torch.allclose(outputs[0], torch.ones(2, 2) * 2))

        def test_module_callable(tester):
            # Create an ExecuTorch program from ModuleAdd.
            exported_program, inputs = create_program(ModuleAdd())

            # Use pybindings to load and execute the program.
            executorch_module = load_fn(exported_program.buffer)
            # Invoke the callable on executorch_module instead of calling module.forward.
            executorch_output = executorch_module(inputs)[0]

            # The test module adds the two inputs, so its output should be the same
            # as adding them directly.
            expected = inputs[0] + inputs[1]
            tester.assertEqual(str(expected), str(executorch_output))

        def test_module_single_input(tester):
            # Create an ExecuTorch program from ModuleAddSingleInput.
            exported_program, inputs = create_program(ModuleAddSingleInput())

            # Use pybindings to load and execute the program.
            executorch_module = load_fn(exported_program.buffer)
            # Invoke the callable on executorch_module instead of calling module.forward.
            # Pass the bare tensor (not wrapped in a tuple) to test this case.
            executorch_output = executorch_module(inputs[0])[0]

            # The test module adds its input to itself, so its output should be
            # the same as doing that directly.
            expected = inputs[0] + inputs[0]
            tester.assertEqual(str(expected), str(executorch_output))

        def test_stderr_redirect(tester):
            from contextlib import redirect_stderr
            from io import StringIO

            with redirect_stderr(StringIO()) as captured:
                # Create an ExecuTorch program from ModuleAdd.
                exported_program, inputs = create_program(ModuleAdd())

                # Use pybindings to load and execute the program.
                executorch_module = load_fn(exported_program.buffer)

                # Add an extra input to trigger an error.
                bad_inputs = (*inputs, 1)

                # Only the call itself is expected to raise. Using assertRaises
                # (rather than try/except Exception around an assertFalse) means
                # a call that unexpectedly succeeds fails the test instead of
                # having its AssertionError swallowed by the handler.
                with tester.assertRaises(Exception):
                    executorch_module(bad_inputs)

            # assertIn instead of assertTrue(str.find(...)): find() returns -1
            # (truthy) when the text is missing, so the old check could not fail.
            tester.assertIn("The length of given input array", captured.getvalue())

        def test_quantized_ops(tester):
            eager_module = ModuleAdd()

            # Imported locally so they are only required when this test runs.
            from executorch.exir import EdgeCompileConfig
            from executorch.exir.passes.quant_fusion_pass import QuantFusionPass

            # Not referenced by name below; NOTE(review): presumably imported
            # for its side effect of making the quantized kernels available.
            from executorch.kernels import quantized  # noqa: F401
            from torch.ao.quantization import get_default_qconfig_mapping
            from torch.ao.quantization.backend_config.executorch import (
                get_executorch_backend_config,
            )
            from torch.ao.quantization.quantize_fx import (
                _convert_to_reference_decomposed_fx,
                prepare_fx,
            )

            qconfig_mapping = get_default_qconfig_mapping("qnnpack")
            example_inputs = (
                torch.ones(1, 5, dtype=torch.float32),
                torch.ones(1, 5, dtype=torch.float32),
            )
            m = prepare_fx(
                eager_module,
                qconfig_mapping,
                example_inputs,
                backend_config=get_executorch_backend_config(),
            )
            m = _convert_to_reference_decomposed_fx(m)
            config = EdgeCompileConfig(_check_ir_validity=False)
            m = to_edge(export(m, example_inputs), compile_config=config)
            m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])

            exec_prog = m.to_executorch()

            executorch_module = load_fn(exec_prog.buffer)
            executorch_output = executorch_module.forward(example_inputs)[0]

            expected = example_inputs[0] + example_inputs[1]
            tester.assertEqual(str(expected), str(executorch_output))

        def test_constant_output_not_memory_planned(tester):
            # Create an ExecuTorch program from ModuleAddConstReturn with
            # graph outputs excluded from memory planning. create_program
            # already dumps the program, so no extra dump is done here.
            exported_program, inputs = create_program(
                ModuleAddConstReturn(),
                et_config=ExecutorchBackendConfig(
                    memory_planning_pass=MemoryPlanningPass(alloc_graph_output=False)
                ),
            )

            # Use pybindings to load and execute the program.
            executorch_module = load_fn(exported_program.buffer)
            # Invoke the callable on executorch_module instead of calling module.forward.
            # `inputs` is the single-element tuple (torch.ones(2, 2),).
            executorch_output = executorch_module(inputs)

            # The test module adds the input to torch.ones(2,2), so its output should be the same
            # as adding them directly.
            expected = torch.ones(2, 2) + torch.ones(2, 2)
            tester.assertEqual(str(expected), str(executorch_output[0]))

            # The test module returns the state. Check that its value is correct.
            tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1]))

        def test_method_meta(tester) -> None:
            exported_program, inputs = create_program(ModuleAdd())

            # Use pybindings to load the program and query its metadata.
            executorch_module = load_fn(exported_program.buffer)
            meta = executorch_module.method_meta("forward")

            # Ensure that all these APIs work even if the module object is destroyed.
            del executorch_module
            tester.assertEqual(meta.name(), "forward")
            tester.assertEqual(meta.num_inputs(), 2)
            tester.assertEqual(meta.num_outputs(), 1)
            # Common string for all these tensors.
            tensor_info = "TensorInfo(sizes=[2, 2], dtype=Float, is_memory_planned=True, nbytes=16)"
            # Integer value expected from dtype() for these Float tensors.
            float_dtype = 6
            tester.assertEqual(
                str(meta),
                "MethodMeta(name='forward', num_inputs=2, "
                f"input_tensor_meta=['{tensor_info}', '{tensor_info}'], "
                f"num_outputs=1, output_tensor_meta=['{tensor_info}'])",
            )

            input_tensors = [meta.input_tensor_meta(i) for i in range(2)]
            output_tensor = meta.output_tensor_meta(0)
            # Check that accessing out of bounds raises IndexError.
            with tester.assertRaises(IndexError):
                meta.input_tensor_meta(2)
            # Test that tensor metadata can outlive method metadata.
            del meta
            tester.assertEqual([t.sizes() for t in input_tensors], [(2, 2), (2, 2)])
            tester.assertEqual(
                [t.dtype() for t in input_tensors], [float_dtype, float_dtype]
            )
            tester.assertEqual(
                [t.is_memory_planned() for t in input_tensors], [True, True]
            )
            tester.assertEqual([t.nbytes() for t in input_tensors], [16, 16])
            tester.assertEqual(str(input_tensors), f"[{tensor_info}, {tensor_info}]")

            tester.assertEqual(output_tensor.sizes(), (2, 2))
            tester.assertEqual(output_tensor.dtype(), float_dtype)
            tester.assertEqual(output_tensor.is_memory_planned(), True)
            tester.assertEqual(output_tensor.nbytes(), 16)
            tester.assertEqual(str(output_tensor), tensor_info)

        def test_bad_name(tester) -> None:
            # Create an ExecuTorch program from ModuleAdd.
            exported_program, inputs = create_program(ModuleAdd())

            # Use pybindings to load the program.
            executorch_module = load_fn(exported_program.buffer)
            # Running a method that does not exist must raise.
            with tester.assertRaises(RuntimeError):
                executorch_module.run_method("not_a_real_method", inputs)

        def test_verification_config(tester) -> None:
            # Create an ExecuTorch program from ModuleAdd.
            exported_program, inputs = create_program(ModuleAdd())
            Verification = runtime.Verification

            # Use pybindings to load and execute the program once per
            # verification level.
            for config in [Verification.Minimal, Verification.InternalConsistency]:
                executorch_module = load_fn(
                    exported_program.buffer,
                    enable_etdump=False,
                    debug_buffer_size=0,
                    program_verification=config,
                )

                executorch_output = executorch_module.forward(inputs)[0]

                # The test module adds the two inputs, so its output should be the same
                # as adding them directly.
                expected = inputs[0] + inputs[1]

                tester.assertEqual(str(expected), str(executorch_output))

        ######### RUN TEST CASES #########
        test_cases = (
            test_e2e,
            test_multiple_entry,
            test_output_lifespan,
            test_module_callable,
            test_module_single_input,
            test_stderr_redirect,
            test_quantized_ops,
            test_constant_output_not_memory_planned,
            test_method_meta,
            test_bad_name,
            test_verification_config,
        )
        for test_case in test_cases:
            test_case(tester)

    return wrapper