# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.examples.models import MODEL_NAME_TO_MODEL
from executorch.examples.models.model_factory import EagerModelFactory

from executorch.extension.export_util.utils import export_to_edge

from executorch.extension.pybindings.portable_lib import (  # @manual
    _load_for_executorch_from_buffer,
)


class ExportTest(unittest.TestCase):
    """Check that example models give matching eager and ExecuTorch outputs."""

    def collect_executorch_and_eager_outputs(
        self,
        eager_model: torch.nn.Module,
        example_inputs,
    ):
        """
        Run a model in eager mode and as an ExecuTorch program on the same inputs.

        The model is put in eval mode, exported with
        ``torch.export.export_for_training``, lowered with ``export_to_edge``,
        converted with ``to_executorch()`` and loaded back from the resulting
        buffer. No comparison is done here; see ``validate_tensor_allclose``.

        Args:
            eager_model: The eager mode PyTorch module to export.
            example_inputs: Positional inputs for the model. They are unpacked
                with ``*`` for the eager call and passed as-is to the
                ExecuTorch ``forward`` method.

        Returns:
            A tuple ``(eager_output, executorch_output)``.
        """
        eager_model = eager_model.eval()
        model = torch.export.export_for_training(eager_model, example_inputs).module()
        edge_model = export_to_edge(model, example_inputs)

        executorch_prog = edge_model.to_executorch()

        pte_model = _load_for_executorch_from_buffer(executorch_prog.buffer)

        # Both runs are inference only, so one no_grad scope covers them.
        with torch.no_grad():
            eager_output = eager_model(*example_inputs)
            executorch_output = pte_model.run_method("forward", example_inputs)

        return (eager_output, executorch_output)

    def validate_tensor_allclose(
        self, eager_output, executorch_output, rtol=1e-5, atol=1e-5
    ):
        """
        Assert that two outputs agree element by element within tolerance.

        Both arguments must be sized and iterable (e.g. a tensor or a list of
        tensors). Corresponding items are compared with ``torch.allclose``, and
        the first mismatch fails the test with both values in the message.

        Args:
            eager_output: Reference output from the eager model.
            executorch_output: Output from the ExecuTorch program.
            rtol: Relative tolerance forwarded to ``torch.allclose``.
            atol: Absolute tolerance forwarded to ``torch.allclose``.

        Raises:
            AssertionError: If the types or lengths differ, or any pair of
                items is not close.
        """
        self.assertIsInstance(
            eager_output,
            type(executorch_output),
            f"Outputs are not of the same type: eager type: {type(eager_output)}, executorch type: {type(executorch_output)}",
        )
        self.assertEqual(
            len(eager_output),
            len(executorch_output),
            f"len(eager_output)={len(eager_output)}, len(executorch_output)={len(executorch_output)}",
        )
        # Lengths are equal at this point, so zip() drops nothing.
        for index, (expected, actual) in enumerate(
            zip(eager_output, executorch_output)
        ):
            # Put the offending values in the failure message instead of
            # stdout so they show up in the test report next to the failure.
            self.assertTrue(
                torch.allclose(expected, actual, rtol=rtol, atol=atol),
                f"output[{index}] differs beyond rtol={rtol}, atol={atol}\n"
                f"eager output[{index}]: {expected}\n"
                f"executorch output[{index}]: {actual}",
            )

    def _run_example_model(self, model_name):
        """Build the named example model and return (eager, executorch) outputs."""
        eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
            *MODEL_NAME_TO_MODEL[model_name]
        )
        return self.collect_executorch_and_eager_outputs(eager_model, example_inputs)

    def _assert_first_output_matches(self, model_name, **tolerances):
        """Compare the eager output with the first ExecuTorch output.

        ``tolerances`` may carry ``rtol`` and/or ``atol``; anything omitted
        falls back to the defaults of ``validate_tensor_allclose``.
        """
        eager_output, executorch_output = self._run_example_model(model_name)
        self.validate_tensor_allclose(
            eager_output, executorch_output[0], **tolerances
        )

    def test_mv3_export_to_executorch(self):
        # TODO(T166083470): Fix accuracy issue
        self._assert_first_output_matches("mv3", rtol=1e-3, atol=1e-5)

    def test_mv2_export_to_executorch(self):
        self._assert_first_output_matches("mv2")

    def test_vit_export_to_executorch(self):
        # TODO(T166083470): Fix accuracy, detected on Arm64
        self._assert_first_output_matches("vit", rtol=1e-2, atol=1e-2)

    def test_w2l_export_to_executorch(self):
        self._assert_first_output_matches("w2l")

    def test_ic3_export_to_executorch(self):
        # TODO(T166083470): Fix accuracy issue
        self._assert_first_output_matches("ic3", rtol=1e-3, atol=1e-5)

    def test_resnet18_export_to_executorch(self):
        self._assert_first_output_matches("resnet18")

    def test_resnet50_export_to_executorch(self):
        self._assert_first_output_matches("resnet50")

    def test_dl3_export_to_executorch(self):
        eager_output, executorch_output = self._run_example_model("dl3")
        # The eager result is consumed via .values(), so its values are
        # compared as a list against the full ExecuTorch output list.
        self.validate_tensor_allclose(list(eager_output.values()), executorch_output)