# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import difflib
import json
import unittest
from pprint import pformat
from typing import List

import executorch.devtools.etdump.schema_flatcc as flatcc

from executorch.devtools.etdump.serialize import (
    deserialize_from_etdump_flatcc,
    serialize_to_etdump_flatcc,
)
from executorch.exir._serialize._dataclass import _DataclassEncoder


def diff_jsons(a: str, b: str) -> List[str]:
    """Return a unified diff between two JSON documents.

    Both inputs are parsed with ``json.loads`` and pretty-printed with
    ``pformat`` before diffing, so differences in the raw JSON text's
    whitespace/layout do not show up in the diff.

    Args:
        a: JSON text for the "from" side of the diff.
        b: JSON text for the "to" side of the diff.

    Returns:
        The unified-diff lines (without trailing newlines); an empty list
        when the pretty-printed forms are identical.
    """
    pretty_a = pformat(json.loads(a)).splitlines()
    pretty_b = pformat(json.loads(b)).splitlines()
    # splitlines() strips newlines from the content lines, so the diff's
    # control lines ("---", "+++", "@@") must not add one either; otherwise
    # joining the result with "\n" produces stray blank lines.
    return list(difflib.unified_diff(pretty_a, pretty_b, lineterm=""))


def get_sample_etdump_flatcc() -> flatcc.ETDumpFlatCC:
    """Build a small ETDumpFlatCC instance used as a round-trip fixture.

    The dump holds a single RunData named "test_block" with one allocator and
    four events: a profile event, a second profile event that sets
    ``delegate_debug_id_int=13``, an allocation event, and a debug event whose
    Value has ``val`` set to ``ValueType.TENSOR``. Each Event sets exactly one
    of ``profile_event`` / ``allocation_event`` / ``debug_event`` and leaves
    the other two as None.
    """
    return flatcc.ETDumpFlatCC(
        version=0,
        run_data=[
            flatcc.RunData(
                name="test_block",
                # NOTE(review): -1 presumably means "no bundled input"; confirm.
                bundled_input_index=-1,
                allocators=[
                    flatcc.Allocator(
                        name="test_allocator",
                    )
                ],
                events=[
                    flatcc.Event(
                        profile_event=flatcc.ProfileEvent(
                            name="test_profile_event",
                            chain_index=1,
                            instruction_id=1,
                            delegate_debug_id_str="",
                            delegate_debug_id_int=-1,
                            delegate_debug_metadata=bytes(),
                            start_time=1001,
                            end_time=2002,
                        ),
                        allocation_event=None,
                        debug_event=None,
                    ),
                    flatcc.Event(
                        profile_event=flatcc.ProfileEvent(
                            name="test_profile_event_delegated",
                            chain_index=1,
                            instruction_id=1,
                            delegate_debug_id_str="",
                            delegate_debug_id_int=13,
                            delegate_debug_metadata=bytes(),
                            start_time=1001,
                            end_time=2002,
                        ),
                        allocation_event=None,
                        debug_event=None,
                    ),
                    flatcc.Event(
                        profile_event=None,
                        allocation_event=flatcc.AllocationEvent(
                            allocator_id=1,
                            allocation_size=8,
                        ),
                        debug_event=None,
                    ),
                    flatcc.Event(
                        profile_event=None,
                        allocation_event=None,
                        debug_event=flatcc.DebugEvent(
                            name="test_debug_event",
                            chain_index=1,
                            instruction_id=0,
                            delegate_debug_id_str="56",
                            delegate_debug_id_int=-1,
                            # Every Value field is populated, not only the one
                            # selected by `val` -- presumably so that each field
                            # is exercised by the round trip.
                            debug_entry=flatcc.Value(
                                val=flatcc.ValueType.TENSOR.value,
                                tensor=flatcc.Tensor(
                                    scalar_type=flatcc.ScalarType.INT,
                                    sizes=[1],
                                    strides=[1],
                                    offset=12345,
                                ),
                                tensor_list=flatcc.TensorList(
                                    [
                                        flatcc.Tensor(
                                            scalar_type=flatcc.ScalarType.INT,
                                            sizes=[1],
                                            strides=[1],
                                            offset=12345,
                                        )
                                    ]
                                ),
                                int_value=flatcc.Int(1),
                                float_value=flatcc.Float(1.0),
                                double_value=flatcc.Double(1.0),
                                bool_value=flatcc.Bool(False),
                                output=flatcc.Bool(True),
                            ),
                        ),
                    ),
                ],
            )
        ],
    )


class TestSerializeFlatCC(unittest.TestCase):
    """Round-trip tests for the ETDump flatcc serializer/deserializer."""

    def test_serialize(self) -> None:
        """Serialize the sample dump, deserialize it, and check equality."""
        program = get_sample_etdump_flatcc()

        flatcc_from_py = serialize_to_etdump_flatcc(program)
        # NOTE(review): size_prefixed=False presumably matches the output
        # format of serialize_to_etdump_flatcc; confirm against the serializer.
        deserialized_obj = deserialize_from_etdump_flatcc(
            flatcc_from_py, size_prefixed=False
        )
        # On mismatch, show a readable diff of both objects' JSON encodings.
        self.assertEqual(
            program,
            deserialized_obj,
            msg="\n".join(
                diff_jsons(
                    json.dumps(program, cls=_DataclassEncoder, indent=4),
                    json.dumps(deserialized_obj, cls=_DataclassEncoder, indent=4),
                )
            ),
        )