xref: /aosp_15_r20/external/executorch/examples/models/test/test_export.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.examples.models import MODEL_NAME_TO_MODEL
11from executorch.examples.models.model_factory import EagerModelFactory
12
13from executorch.extension.export_util.utils import export_to_edge
14
15from executorch.extension.pybindings.portable_lib import (  # @manual
16    _load_for_executorch_from_buffer,
17)
18
19
class ExportTest(unittest.TestCase):
    """End-to-end export tests for the example models.

    Each test builds an eager model from ``MODEL_NAME_TO_MODEL``, lowers it to
    an ExecuTorch program, runs both versions on the model's example inputs,
    and asserts that the outputs agree within a tolerance.
    """

    def collect_executorch_and_eager_outputs(
        self,
        eager_model: torch.nn.Module,
        example_inputs,
    ):
        """
        Run an eager mode PyTorch model and its ExecuTorch equivalent on the
        same example inputs and return both outputs (no comparison is done here).

        The model is switched to eval mode, exported with
        ``torch.export.export_for_training``, lowered with ``export_to_edge``,
        converted via ``to_executorch()`` and loaded back from its serialized
        buffer.

        Args:
            eager_model: The PyTorch module to export.
            example_inputs: Positional inputs; unpacked into the eager model
                call and passed as-is to the ExecuTorch ``forward`` method.

        Returns:
            A tuple ``(eager_output, executorch_output)``, where
            ``executorch_output`` is whatever ``run_method("forward", ...)``
            returns.
        """
        eager_model = eager_model.eval()
        model = torch.export.export_for_training(eager_model, example_inputs).module()
        edge_model = export_to_edge(model, example_inputs)

        executorch_prog = edge_model.to_executorch()

        pte_model = _load_for_executorch_from_buffer(executorch_prog.buffer)

        # Both forward passes are inference-only, so one no_grad scope suffices.
        with torch.no_grad():
            eager_output = eager_model(*example_inputs)
            executorch_output = pte_model.run_method("forward", example_inputs)

        return (eager_output, executorch_output)

    def validate_tensor_allclose(
        self, eager_output, executorch_output, rtol=1e-5, atol=1e-5
    ):
        """
        Assert that two outputs match element-wise.

        Checks that both outputs have the same type and length, then compares
        corresponding elements with ``torch.allclose``. When a single tensor
        is passed, ``len`` and iteration act along its first dimension, so its
        slices are compared one by one.

        Args:
            eager_output: Reference output from the eager model.
            executorch_output: Output produced by the ExecuTorch program.
            rtol: Relative tolerance forwarded to ``torch.allclose``.
            atol: Absolute tolerance forwarded to ``torch.allclose``.

        Raises:
            AssertionError: If the types or lengths differ, or if any pair of
                elements is not close. The message for a value mismatch
                includes the index and both offending values.
        """
        self.assertIsInstance(
            eager_output,
            type(executorch_output),
            f"Outputs are not of the same type: eager type: {type(eager_output)}, executorch type: {type(executorch_output)}",
        )
        self.assertEqual(
            len(eager_output),
            len(executorch_output),
            f"len(eager_output)={len(eager_output)}, len(executorch_output)={len(executorch_output)}",
        )
        for i, (eager_value, executorch_value) in enumerate(
            zip(eager_output, executorch_output)
        ):
            # Fail on the first mismatch and put the values in the assertion
            # message (instead of printing them) so they appear in the report.
            self.assertTrue(
                torch.allclose(eager_value, executorch_value, rtol=rtol, atol=atol),
                f"eager output[{i}]: {eager_value}\n"
                f"executorch output[{i}]: {executorch_value}",
            )

    def _collect_model_outputs(self, model_name):
        """
        Build ``model_name`` via ``EagerModelFactory`` and return the
        ``(eager_output, executorch_output)`` pair from
        ``collect_executorch_and_eager_outputs``.

        Args:
            model_name: Key into ``MODEL_NAME_TO_MODEL``.
        """
        eager_model, example_inputs, _, _ = EagerModelFactory.create_model(
            *MODEL_NAME_TO_MODEL[model_name]
        )
        return self.collect_executorch_and_eager_outputs(eager_model, example_inputs)

    # For the single-output models below, the eager output is compared with the
    # first entry of the ExecuTorch output sequence.

    def test_mv3_export_to_executorch(self):
        eager_output, executorch_output = self._collect_model_outputs("mv3")
        # TODO(T166083470): Fix accuracy issue
        self.validate_tensor_allclose(
            eager_output, executorch_output[0], rtol=1e-3, atol=1e-5
        )

    def test_mv2_export_to_executorch(self):
        eager_output, executorch_output = self._collect_model_outputs("mv2")
        self.validate_tensor_allclose(eager_output, executorch_output[0])

    def test_vit_export_to_executorch(self):
        eager_output, executorch_output = self._collect_model_outputs("vit")
        # TODO(T166083470): Fix accuracy, detected on Arm64
        self.validate_tensor_allclose(
            eager_output, executorch_output[0], rtol=1e-2, atol=1e-2
        )

    def test_w2l_export_to_executorch(self):
        eager_output, executorch_output = self._collect_model_outputs("w2l")
        self.validate_tensor_allclose(eager_output, executorch_output[0])

    def test_ic3_export_to_executorch(self):
        eager_output, executorch_output = self._collect_model_outputs("ic3")
        # TODO(T166083470): Fix accuracy issue
        self.validate_tensor_allclose(
            eager_output, executorch_output[0], rtol=1e-3, atol=1e-5
        )

    def test_resnet18_export_to_executorch(self):
        eager_output, executorch_output = self._collect_model_outputs("resnet18")
        self.validate_tensor_allclose(eager_output, executorch_output[0])

    def test_resnet50_export_to_executorch(self):
        eager_output, executorch_output = self._collect_model_outputs("resnet50")
        self.validate_tensor_allclose(eager_output, executorch_output[0])

    def test_dl3_export_to_executorch(self):
        eager_output, executorch_output = self._collect_model_outputs("dl3")
        # The dl3 eager model returns a mapping (it has .values()); its values
        # are compared against the whole ExecuTorch output sequence.
        self.validate_tensor_allclose(list(eager_output.values()), executorch_output)
151