xref: /aosp_15_r20/external/executorch/devtools/inspector/tests/inspector_test.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import random
10import statistics
11import tempfile
12import unittest
13from contextlib import redirect_stdout
14
15from typing import Callable, List
16
17from unittest.mock import patch
18
19from executorch.devtools import generate_etrecord, parse_etrecord
20from executorch.devtools.debug_format.et_schema import OperatorNode
21from executorch.devtools.etdump.schema_flatcc import ProfileEvent
22from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord
23
24from executorch.devtools.inspector import (
25    _inspector,
26    Event,
27    EventBlock,
28    Inspector,
29    PerfData,
30)
31from executorch.devtools.inspector._inspector import (
32    DebugEventSignature,
33    flatcc,
34    InstructionEvent,
35    InstructionEventSignature,
36    ProfileEventSignature,
37    TimeScale,
38)
39
40from executorch.exir import ExportedProgram
41
42
# Fixture values shared by the tests in this module.
# Operator type attached to every synthetic event built by the tests.
OP_TYPE: str = "aten::add"
# Name given to the synthetic EventBlock under test.
EVENT_BLOCK_NAME: str = "block_0"
# How many synthetic events a generated EventBlock contains.
EVENTS_SIZE: int = 5
# How many raw samples each generated PerfData holds.
RAW_DATA_SIZE: int = 10
# Placeholder paths handed to Inspector; the functions that would read them
# are patched out by the tests, so the files never need to exist.
ETDUMP_PATH: str = "unittest_etdump_path"
ETRECORD_PATH: str = "unittest_etrecord_path"
49
50
51# TODO: write an E2E test: create an inspector instance, mock just the file reads, and then verify the external correctness
class TestInspector(unittest.TestCase):
    """Unit tests for Inspector, EventBlock, Event and PerfData.

    Tests that construct an ``Inspector`` patch the functions its constructor
    calls (``parse_etrecord``, ``gen_etdump_object``,
    ``EventBlock._gen_from_etdump``, ...) so the placeholder paths
    ``ETDUMP_PATH`` / ``ETRECORD_PATH`` never have to exist on disk.
    """

    def test_perf_data(self) -> None:
        random_floats = self._gen_random_float_list()
        perf_data = PerfData(random_floats)

        # Intentionally use a different way to calculate p50 from the
        # implementation. Because the two computations differ, their results
        # are not guaranteed to be bit-identical for floats, so compare
        # approximately instead of exactly.
        self.assertAlmostEqual(perf_data.p50, statistics.median(random_floats))

    def test_event_block_to_dataframe(self) -> None:
        event_block = EventBlock(
            name=EVENT_BLOCK_NAME, events=self._gen_random_events()
        )

        df = event_block.to_dataframe()
        # Check some fields of the returned dataframe
        self.assertEqual(len(df), EVENTS_SIZE)
        self.assertIn("op_0", df["event_name"].values)
        self.assertEqual(len(df["raw"].values[0]), RAW_DATA_SIZE)
        self.assertEqual(df["op_types"].values[0][0], OP_TYPE)

    def test_inspector_constructor(self):
        # Create a context manager to patch functions called by Inspector.__init__
        with patch.object(
            _inspector, "parse_etrecord", return_value=None
        ) as mock_parse_etrecord, patch.object(
            _inspector, "gen_etdump_object", return_value=None
        ) as mock_gen_etdump, patch.object(
            EventBlock, "_gen_from_etdump"
        ) as mock_gen_from_etdump, patch.object(
            _inspector, "gen_graphs_from_etrecord"
        ) as mock_gen_graphs_from_etrecord:
            # Call the constructor of Inspector
            Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            # Assert that expected functions are called
            mock_parse_etrecord.assert_called_once_with(etrecord_path=ETRECORD_PATH)
            mock_gen_etdump.assert_called_once_with(
                etdump_path=ETDUMP_PATH, etdump_data=None
            )
            mock_gen_from_etdump.assert_called_once()
            # Because we mocked parse_etrecord() to return None, this method shouldn't be called
            mock_gen_graphs_from_etrecord.assert_not_called()

    def test_default_delegate_time_scale_converter(self):
        # Create a context manager to patch functions called by Inspector.__init__
        with patch.object(
            _inspector, "parse_etrecord", return_value=None
        ), patch.object(
            _inspector, "gen_etdump_object", return_value=None
        ), patch.object(
            EventBlock, "_gen_from_etdump"
        ) as mock_gen_from_etdump, patch.object(
            _inspector, "gen_graphs_from_etrecord"
        ), patch.object(
            _inspector, "create_debug_handle_to_op_node_mapping"
        ):
            # Call the constructor of Inspector
            Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
                source_time_scale=TimeScale.US,
                target_time_scale=TimeScale.S,
            )

            # Verify delegate_time_scale_converter is forwarded as a callable.
            # The keyword arguments must be read via `call_args.kwargs`: plain
            # attribute access on a mock `call` object (e.g. `call_args.get`)
            # returns another callable `call` proxy, which would make the
            # isinstance check below pass no matter what was passed.
            mock_gen_from_etdump.assert_called_once()
            converter = mock_gen_from_etdump.call_args.kwargs.get(
                "delegate_time_scale_converter"
            )
            self.assertIsNotNone(converter)
            self.assertIsInstance(converter, Callable)

    def test_inspector_print_data_tabular(self):
        # Create a context manager to patch functions called by Inspector.__init__
        with patch.object(
            _inspector, "parse_etrecord", return_value=None
        ), patch.object(
            _inspector, "gen_etdump_object", return_value=None
        ), patch.object(
            EventBlock, "_gen_from_etdump"
        ), patch.object(
            _inspector, "gen_graphs_from_etrecord"
        ):
            # Call the constructor of Inspector
            inspector_instance = Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            # The mock inspector instance starts with having an empty event blocks list.
            # Add non-empty event blocks to test print_data_tabular().
            inspector_instance.event_blocks = [
                EventBlock(name=EVENT_BLOCK_NAME, events=self._gen_random_events())
            ]
            # Call print_data_tabular(), make sure it doesn't crash
            with redirect_stdout(None):
                inspector_instance.print_data_tabular()

    def test_inspector_associate_with_op_graph_nodes_single_debug_handle(self):
        # Test on an event with a single debug handle
        debug_handle = 111
        event_with_single_debug_handle = Event(
            name="event_with_single_debug_handle",
            perf_data=PerfData(raw=[]),
            debug_handles=debug_handle,
        )
        node_0 = OperatorNode(
            name="node_0",
            metadata={
                "debug_handle": debug_handle,
                "stack_trace": "stack_trace_relu",
                "nn_module_stack": "module_hierarchy_relu",
            },
            op="op",
        )

        # Call the method that's under testing and verify
        event_with_single_debug_handle._associate_with_op_graph_nodes(
            {debug_handle: node_0}
        )

        expected_stack_traces = {"node_0": "stack_trace_relu"}
        self.assertEqual(
            event_with_single_debug_handle.stack_traces, expected_stack_traces
        )
        expected_module_hierarchy = {"node_0": "module_hierarchy_relu"}
        self.assertEqual(
            event_with_single_debug_handle.module_hierarchy, expected_module_hierarchy
        )
        expected_ops = ["op"]
        self.assertEqual(event_with_single_debug_handle.op_types, expected_ops)

    def test_inspector_associate_with_op_graph_nodes_multiple_debug_handles(self):
        # Test on an event with a sequence of debug handles
        debug_handles = [222, 333]
        event_with_multiple_debug_handles = Event(
            name="event_with_multiple_debug_handles",
            perf_data=PerfData(raw=[]),
            debug_handles=debug_handles,
        )
        node_0 = OperatorNode(
            name="node_0",
            metadata={
                "debug_handle": debug_handles[0],
                "stack_trace": "stack_trace_relu",
                "nn_module_stack": "module_hierarchy_relu",
            },
            op="op_0",
        )
        node_1 = OperatorNode(
            name="node_1",
            metadata={
                "debug_handle": debug_handles[1],
                "stack_trace": "stack_trace_conv",
                "nn_module_stack": "module_hierarchy_conv",
            },
            op="op_1",
        )

        # Call the method that's under testing and verify
        event_with_multiple_debug_handles._associate_with_op_graph_nodes(
            {debug_handles[0]: node_0, debug_handles[1]: node_1}
        )

        expected_stack_traces = {
            "node_0": "stack_trace_relu",
            "node_1": "stack_trace_conv",
        }
        self.assertEqual(
            event_with_multiple_debug_handles.stack_traces, expected_stack_traces
        )
        expected_module_hierarchy = {
            "node_0": "module_hierarchy_relu",
            "node_1": "module_hierarchy_conv",
        }
        self.assertEqual(
            event_with_multiple_debug_handles.module_hierarchy,
            expected_module_hierarchy,
        )
        expected_ops = ["op_0", "op_1"]
        self.assertEqual(event_with_multiple_debug_handles.op_types, expected_ops)

    def test_inspector_delegate_time_scale_converter(self):
        def time_scale_converter(event_name, time):
            return time / 10

        event = Event(
            name="",
            _delegate_metadata_parser=None,
            _delegate_time_scale_converter=None,
        )
        event_signature = ProfileEventSignature(
            name="",
            instruction_id=0,
            delegate_id_str="test_event",
        )
        instruction_events = [
            InstructionEvent(
                signature=InstructionEventSignature(0, 0),
                profile_events=[
                    ProfileEvent(
                        name="test_event",
                        chain_index=0,
                        instruction_id=0,
                        delegate_debug_id_int=None,
                        delegate_debug_id_str="test_event_delegated",
                        start_time=100,
                        end_time=200,
                        delegate_debug_metadata=None,
                    )
                ],
            )
        ]
        Event._populate_profiling_related_fields(
            event, event_signature, instruction_events, 1
        )
        # Value of the perf data before scaling is done: 200 - 100.
        self.assertEqual(event.perf_data.raw[0], 100)
        event._delegate_time_scale_converter = time_scale_converter
        Event._populate_profiling_related_fields(
            event, event_signature, instruction_events, 1
        )
        # Value of the perf data after scaling is done. 200/10 - 100/10.
        self.assertEqual(event.perf_data.raw[0], 10)

    def test_inspector_get_exported_program(self):
        # Create a context manager to patch functions called by Inspector.__init__
        with patch.object(
            _inspector, "parse_etrecord", return_value=None
        ), patch.object(
            _inspector, "gen_etdump_object", return_value=None
        ), patch.object(
            EventBlock, "_gen_from_etdump"
        ), patch.object(
            _inspector, "gen_graphs_from_etrecord"
        ), patch.object(
            _inspector, "create_debug_handle_to_op_node_mapping"
        ):
            # Call the constructor of Inspector
            inspector_instance = Inspector(
                etdump_path=ETDUMP_PATH,
                etrecord=ETRECORD_PATH,
            )

            # Gen a mock etrecord
            captured_output, edge_output, et_output = TestETRecord().get_test_model()
            with tempfile.TemporaryDirectory() as tmpdirname:
                etrecord_path = tmpdirname + "/etrecord.bin"
                generate_etrecord(
                    etrecord_path,
                    edge_output,
                    et_output,
                    {
                        "aten_dialect_output": captured_output,
                    },
                )

                # `parse_etrecord` here is this module's import, which is not
                # affected by the patch on `_inspector.parse_etrecord` above.
                inspector_instance._etrecord = parse_etrecord(etrecord_path)

                self.assertIsInstance(
                    inspector_instance.get_exported_program(), ExportedProgram
                )

    def test_populate_debugging_related_fields_raises_for_inconsistent_events(self):
        ret_event: Event = Event(
            name="event",
        )

        # Note the sizes of the second tensor ([1]) differ from the first ([2])
        events = self._gen_instruction_events(
            [
                self._gen_int_tensor_debug_event(sizes=[2], offset=12345),
                self._gen_int_tensor_debug_event(sizes=[1], offset=23456),
            ]
        )

        # Expect AssertionError because 2 tensors have different sizes
        with self.assertRaises(AssertionError):
            Event._populate_debugging_related_fields(
                ret_event=ret_event,
                debug_event_signature=DebugEventSignature(instruction_id=1),
                events=events,
            )

    def test_populate_debugging_related_fields_passes_for_consistent_events(self):
        ret_event: Event = Event(
            name="event",
        )

        # The two debug events are identical except for the tensor offset
        events = self._gen_instruction_events(
            [
                self._gen_int_tensor_debug_event(sizes=[1], offset=12345),
                self._gen_int_tensor_debug_event(sizes=[1], offset=23456),
            ]
        )

        with patch.object(_inspector, "is_inference_output_equal", return_value=True):
            # Expect it runs with no error because is_inference_output_equal() is mocked to return True
            Event._populate_debugging_related_fields(
                ret_event=ret_event,
                debug_event_signature=DebugEventSignature(instruction_id=1),
                events=events,
            )

    def _gen_int_tensor_debug_event(
        self, sizes: List[int], offset: int
    ) -> flatcc.DebugEvent:
        """Build a DebugEvent named "event" whose debug entry is an INT tensor.

        Only the tensor ``sizes`` and ``offset`` vary between callers; every
        other field is fixed so tests can isolate the effect of those two.
        """
        return flatcc.DebugEvent(
            name="event",
            chain_index=1,
            instruction_id=0,
            delegate_debug_id_int=1,
            delegate_debug_id_str=None,
            debug_entry=flatcc.Value(
                val=flatcc.ValueType.TENSOR.value,
                tensor=flatcc.Tensor(
                    scalar_type=flatcc.ScalarType.INT,
                    sizes=sizes,
                    strides=[1],
                    offset=offset,
                ),
                tensor_list=None,
                int_value=None,
                float_value=None,
                double_value=None,
                bool_value=None,
                output=None,
            ),
        )

    def _gen_instruction_events(
        self, debug_events: List[flatcc.DebugEvent]
    ) -> List[InstructionEvent]:
        """Wrap each debug event in its own InstructionEvent with signature (1, 1)."""
        return [
            InstructionEvent(
                signature=InstructionEventSignature(1, 1), debug_events=[debug_event]
            )
            for debug_event in debug_events
        ]

    def _gen_random_float_list(self) -> List[float]:
        """Return RAW_DATA_SIZE random floats drawn uniformly from [0, 10]."""
        return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]

    def _gen_random_events(self) -> List[Event]:
        """Return EVENTS_SIZE events named op_0..op_N, each with random perf data."""
        return [
            Event(
                name=f"op_{i}",
                op_types=[OP_TYPE],
                perf_data=PerfData(self._gen_random_float_list()),
            )
            for i in range(EVENTS_SIZE)
        ]
469