xref: /aosp_15_r20/external/executorch/devtools/etdump/tests/serialize_test.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import difflib
10import json
11import unittest
12from pprint import pformat
13from typing import List
14
15import executorch.devtools.etdump.schema_flatcc as flatcc
16
17from executorch.devtools.etdump.serialize import (
18    deserialize_from_etdump_flatcc,
19    serialize_to_etdump_flatcc,
20)
21from executorch.exir._serialize._dataclass import _DataclassEncoder
22
23
def diff_jsons(a: str, b: str) -> List[str]:
    """Return a unified diff between two JSON documents.

    Both inputs are parsed with ``json.loads`` and pretty-printed with
    ``pprint.pformat`` before diffing, so the diff compares the decoded
    structures rather than the raw JSON text formatting.

    Args:
        a: JSON text for the "before" side of the diff.
        b: JSON text for the "after" side of the diff.

    Returns:
        The diff as a list of lines with no trailing newlines (ready for
        ``"\\n".join``); empty when both documents pretty-print identically.

    Raises:
        json.JSONDecodeError: If either input is not valid JSON.
    """
    lines_a = pformat(json.loads(a)).splitlines()
    lines_b = pformat(json.loads(b)).splitlines()
    # ``splitlines`` strips line endings, so the diff's own header lines must
    # not add one either; otherwise joining the result with "\n" produces stray
    # blank lines after "---", "+++" and every "@@" hunk header.
    return list(difflib.unified_diff(lines_a, lines_b, lineterm=""))
31
32
def get_sample_etdump_flatcc() -> flatcc.ETDumpFlatCC:
    """Build a small ETDumpFlatCC instance used as round-trip test input.

    The sample contains a single run named "test_block" with one allocator and
    four events, in this order:

    * a profile event whose integer delegate debug id is -1,
    * a profile event whose integer delegate debug id is 13,
    * an allocation event,
    * a debug event whose value is tagged as a tensor and also fills the
      tensor-list, int, float, double, bool and output payloads.
    """

    def make_profile_event(event_name: str, debug_id_int: int) -> flatcc.Event:
        # The two profile events differ only in their name and integer
        # delegate debug id; every other field is shared.
        profile = flatcc.ProfileEvent(
            name=event_name,
            chain_index=1,
            instruction_id=1,
            delegate_debug_id_str="",
            delegate_debug_id_int=debug_id_int,
            delegate_debug_metadata=bytes(),
            start_time=1001,
            end_time=2002,
        )
        return flatcc.Event(
            profile_event=profile,
            allocation_event=None,
            debug_event=None,
        )

    def make_int_tensor() -> flatcc.Tensor:
        # Returns a fresh object on every call so the sample holds no
        # aliased tensors.
        return flatcc.Tensor(
            scalar_type=flatcc.ScalarType.INT,
            sizes=[1],
            strides=[1],
            offset=12345,
        )

    allocation_evt = flatcc.Event(
        profile_event=None,
        allocation_event=flatcc.AllocationEvent(
            allocator_id=1,
            allocation_size=8,
        ),
        debug_event=None,
    )

    debug_value = flatcc.Value(
        val=flatcc.ValueType.TENSOR.value,
        tensor=make_int_tensor(),
        tensor_list=flatcc.TensorList([make_int_tensor()]),
        int_value=flatcc.Int(1),
        float_value=flatcc.Float(1.0),
        double_value=flatcc.Double(1.0),
        bool_value=flatcc.Bool(False),
        output=flatcc.Bool(True),
    )
    debug_evt = flatcc.Event(
        profile_event=None,
        allocation_event=None,
        debug_event=flatcc.DebugEvent(
            name="test_debug_event",
            chain_index=1,
            instruction_id=0,
            delegate_debug_id_str="56",
            delegate_debug_id_int=-1,
            debug_entry=debug_value,
        ),
    )

    single_run = flatcc.RunData(
        name="test_block",
        bundled_input_index=-1,
        allocators=[flatcc.Allocator(name="test_allocator")],
        events=[
            make_profile_event("test_profile_event", -1),
            make_profile_event("test_profile_event_delegated", 13),
            allocation_evt,
            debug_evt,
        ],
    )
    return flatcc.ETDumpFlatCC(version=0, run_data=[single_run])
121
122
class TestSerializeFlatCC(unittest.TestCase):
    """Round-trip tests for the flatcc-based ETDump serializer."""

    def test_serialize(self) -> None:
        """Serialize the sample ETDump, deserialize it, and expect an equal object.

        On mismatch, the failure message is a unified diff of the two objects
        rendered as indented JSON, so structural differences are easy to spot.
        """
        program = get_sample_etdump_flatcc()

        flatcc_from_py = serialize_to_etdump_flatcc(program)
        # NOTE(review): presumably the serializer emits an unprefixed buffer,
        # hence size_prefixed=False — confirm against serialize_to_etdump_flatcc.
        deserialized_obj = deserialize_from_etdump_flatcc(
            flatcc_from_py, size_prefixed=False
        )
        self.assertEqual(
            program,
            deserialized_obj,
            msg="\n".join(
                diff_jsons(
                    json.dumps(program, cls=_DataclassEncoder, indent=4),
                    json.dumps(deserialized_obj, cls=_DataclassEncoder, indent=4),
                )
            ),
        )
143