xref: /aosp_15_r20/external/executorch/devtools/etrecord/tests/etrecord_test.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import copy
10import json
11import tempfile
12import unittest
13
14import executorch.exir.tests.models as models
15import torch
16from executorch import exir
17from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
18from executorch.devtools.bundled_program.core import BundledProgram
19from executorch.devtools.etrecord import generate_etrecord, parse_etrecord
20from executorch.devtools.etrecord._etrecord import (
21    _get_reference_outputs,
22    ETRecordReservedFileNames,
23)
24from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
25from torch.export import export
26
27
28# TODO : T154728484  Add test cases to cover multiple entry points
class TestETRecord(unittest.TestCase):
    """Tests that write an ETRecord with ``generate_etrecord`` and read it back
    with ``parse_etrecord``.

    The ``get_test_model*`` helpers build the inputs (all based on
    ``models.BasicSinMax``); the ``test_*`` methods exercise generation,
    parsing and the error paths.
    """

    def _lower_with_capture(self, model, example_inputs):
        """Run ``model`` through ``exir.capture`` -> ``to_edge`` -> ``to_executorch``.

        Shared by ``get_test_model`` and ``get_test_model_with_bundled_program``
        so both use exactly the same lowering configuration.

        Args:
            model: The module to capture.
            example_inputs: Inputs passed to ``exir.capture``.

        Returns:
            A tuple ``(captured_output_copy, edge_output_copy, et_output)``. The
            first two are deep copies taken *before* the next lowering stage is
            invoked on the original object (presumably because a later stage
            may modify the earlier one in place -- confirm against exir).
        """
        captured_output = exir.capture(model, example_inputs, exir.CaptureConfig())
        captured_output_copy = copy.deepcopy(captured_output)
        edge_output = captured_output.to_edge(
            # TODO(gasoon): Remove _use_edge_ops=False once serde is fully migrated to Edge ops
            exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False)
        )
        edge_output_copy = copy.deepcopy(edge_output)
        et_output = edge_output.to_executorch()
        return (captured_output_copy, edge_output_copy, et_output)

    def get_test_model(self):
        """Lower ``models.BasicSinMax`` with random inputs.

        Returns:
            ``(captured_output_copy, edge_output_copy, et_output)`` as produced
            by the capture -> edge -> executorch pipeline.
        """
        f = models.BasicSinMax()
        return self._lower_with_capture(f, f.get_random_inputs())

    def get_test_model_with_bundled_program(self):
        """Lower ``models.BasicSinMax`` and wrap the result in a ``BundledProgram``.

        Two random input sets are generated; each becomes a ``MethodTestCase``
        for the ``"forward"`` method whose expected outputs come from calling
        the eager model. The first input set is also used for capturing.

        Returns:
            ``(captured_output_copy, edge_output_copy, bundled_program)``.
        """
        f = models.BasicSinMax()
        inputs = [f.get_random_inputs() for _ in range(2)]
        m_name = "forward"

        method_test_suites = [
            MethodTestSuite(
                method_name=m_name,
                test_cases=[
                    MethodTestCase(
                        inputs=inp, expected_outputs=getattr(f, m_name)(*inp)
                    )
                    for inp in inputs
                ],
            )
        ]
        (
            captured_output_copy,
            edge_output_copy,
            et_output,
        ) = self._lower_with_capture(f, inputs[0])

        bundled_program = BundledProgram(et_output, method_test_suites)
        return (captured_output_copy, edge_output_copy, bundled_program)

    def get_test_model_with_manager(self):
        """Lower ``models.BasicSinMax`` via ``torch.export.export`` and ``to_edge``.

        Returns:
            ``(aten_dialect, edge_program_copy, executorch_program)`` where the
            edge program manager is deep-copied before ``to_executorch()`` is
            called on the original.
        """
        f = models.BasicSinMax()
        aten_dialect = export(f, f.get_random_inputs())
        edge_program: EdgeProgramManager = to_edge(
            aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=False)
        )
        edge_program_copy = copy.deepcopy(edge_program)
        return (aten_dialect, edge_program_copy, edge_program.to_executorch())

    # Serialized and deserialized graph modules are not completely the same, so we check
    # that they are close enough and match especially on the parameters we care about in the Developer Tools.
    def check_graph_closeness(self, graph_a, graph_b):
        """Assert that two graph modules agree on the fields checked below.

        Compares the node count and then, node by node in iteration order:
        ``target``, number of ``args``, number of ``kwargs``, ``name``,
        ``type`` and ``op``. For every node whose ``op`` is neither
        ``"placeholder"`` nor ``"output"``, ``meta["debug_handle"]`` must also
        match (a missing handle on both sides compares equal as ``None``).

        Args:
            graph_a: Object exposing ``.graph.nodes``.
            graph_b: Object exposing ``.graph.nodes``.
        """
        self.assertEqual(len(graph_a.graph.nodes), len(graph_b.graph.nodes))
        for node_a, node_b in zip(graph_a.graph.nodes, graph_b.graph.nodes):
            self.assertEqual(node_a.target, node_b.target)
            self.assertEqual(len(node_a.args), len(node_b.args))
            self.assertEqual(len(node_a.kwargs), len(node_b.kwargs))
            self.assertEqual(node_a.name, node_b.name)
            self.assertEqual(node_a.type, node_b.type)
            self.assertEqual(node_a.op, node_b.op)
            if node_a.op not in {"placeholder", "output"}:
                self.assertEqual(
                    node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle")
                )

    def test_etrecord_generation(self):
        """Generate an ETRecord with an extra ``aten_dialect_output`` entry and
        check the parsed graphs and debug handle map against the inputs."""
        captured_output, edge_output, et_output = self.get_test_model()
        with tempfile.TemporaryDirectory() as tmpdirname:
            etrecord_path = tmpdirname + "/etrecord.bin"
            generate_etrecord(
                etrecord_path,
                edge_output,
                et_output,
                {
                    "aten_dialect_output": captured_output,
                },
            )

            etrecord = parse_etrecord(etrecord_path)
            self.check_graph_closeness(
                etrecord.graph_map["aten_dialect_output/forward"],
                captured_output.exported_program.graph_module,
            )
            self.check_graph_closeness(
                etrecord.edge_dialect_program,
                edge_output.exported_program.graph_module,
            )
            # Round-trip the in-memory map through JSON so it is compared in
            # its JSON-normalized form (e.g. object keys become strings).
            self.assertEqual(
                etrecord._debug_handle_map,
                json.loads(json.dumps(et_output.debug_handle_map)),
            )

    def test_etrecord_generation_with_bundled_program(self):
        """Generate an ETRecord from a ``BundledProgram`` and check that the
        parsed reference outputs match those derived from the bundled program."""
        (
            captured_output,
            edge_output,
            bundled_program,
        ) = self.get_test_model_with_bundled_program()
        with tempfile.TemporaryDirectory() as tmpdirname:
            etrecord_path = tmpdirname + "/etrecord.bin"
            generate_etrecord(
                etrecord_path,
                edge_output,
                bundled_program,
                {
                    "aten_dialect_output": captured_output,
                },
            )
            etrecord = parse_etrecord(etrecord_path)

            # The bundled program is the source of truth; the parsed record is
            # the value under test.
            expected = _get_reference_outputs(bundled_program)
            actual = etrecord._reference_outputs
            # assertEqual() gives "RuntimeError: Boolean value of Tensor with more than one value is ambiguous" when comparing tensors,
            # so we use torch.equal() to compare the tensors one by one.
            # The bundled program was built with two test cases; compare the
            # first output tensor of each.
            for case_index in (0, 1):
                with self.subTest(case_index=case_index):
                    self.assertTrue(
                        torch.equal(
                            expected["forward"][case_index][0],
                            actual["forward"][case_index][0],
                        )
                    )

    def test_etrecord_generation_with_manager(self):
        """Generate an ETRecord from an ``EdgeProgramManager`` pipeline (no
        extra entries) and check the edge graph and debug handle map."""
        _, edge_output, et_output = self.get_test_model_with_manager()
        with tempfile.TemporaryDirectory() as tmpdirname:
            etrecord_path = tmpdirname + "/etrecord.bin"
            generate_etrecord(
                etrecord_path,
                edge_output,
                et_output,
            )

            etrecord = parse_etrecord(etrecord_path)
            self.check_graph_closeness(
                etrecord.edge_dialect_program,
                edge_output.exported_program().graph_module,
            )
            # Same JSON normalization as in test_etrecord_generation.
            self.assertEqual(
                etrecord._debug_handle_map,
                json.loads(json.dumps(et_output.debug_handle_map)),
            )

    def test_etrecord_invalid_input(self):
        """Passing the ``to_executorch()`` result as a value in the extra-entries
        mapping must raise ``RuntimeError``."""
        _, edge_output, et_output = self.get_test_model()
        with tempfile.TemporaryDirectory() as tmpdirname:
            with self.assertRaises(RuntimeError):
                generate_etrecord(
                    tmpdirname + "/etrecord.bin",
                    edge_output,
                    et_output,
                    {"fail_test_case": et_output},
                )

    def test_etrecord_reserved_name(self):
        """Using any member of ``ETRecordReservedFileNames`` as a key in the
        extra-entries mapping must raise ``RuntimeError``."""
        captured_output, edge_output, et_output = self.get_test_model()
        with tempfile.TemporaryDirectory() as tmpdirname:
            etrecord_path = tmpdirname + "/etrecord.bin"
            for reserved_name in ETRecordReservedFileNames:
                # subTest reports every offending name instead of stopping at
                # the first one and identifies which name failed.
                with self.subTest(reserved_name=reserved_name):
                    with self.assertRaises(RuntimeError):
                        generate_etrecord(
                            etrecord_path,
                            edge_output,
                            et_output,
                            {
                                reserved_name: captured_output.exported_program.graph_module
                            },
                        )
190