# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for util.traverse_ir."""

import collections

import unittest

from compiler.util import ir_data
from compiler.util import ir_data_utils
from compiler.util import traverse_ir

# Fixture IR shared by every test.  It holds two modules:
#   * "t.emb": a structure `Foo` (fields `field1`, `field2`) with a nested
#     structure subtype `Bar` (fields `bar_field1`, `bar_field2`), plus an
#     enumeration also named `Bar` with values ONE (constant 1) and TWO
#     (the function 1 + 1).
#   * "" (empty file name): an external type `UInt` carrying two attributes.
_EXAMPLE_IR = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, """{
  "module": [
    {
      "type": [
        {
          "structure": {
            "field": [
              {
                "location": {
                  "start": { "constant": { "value": "0" } },
                  "size": { "constant": { "value": "8" } }
                },
                "type": {
                  "atomic_type": {
                    "reference": {
                      "canonical_name": {
                        "module_file": "",
                        "object_path": ["UInt"]
                      }
                    }
                  }
                },
                "name": { "name": { "text": "field1" } }
              },
              {
                "location": {
                  "start": { "constant": { "value": "8" } },
                  "size": { "constant": { "value": "16" } }
                },
                "type": {
                  "array_type": {
                    "base_type": {
                      "atomic_type": {
                        "reference": {
                          "canonical_name": {
                            "module_file": "",
                            "object_path": ["UInt"]
                          }
                        }
                      }
                    },
                    "element_count": { "constant": { "value": "8" } }
                  }
                },
                "name": { "name": { "text": "field2" } }
              }
            ]
          },
          "name": { "name": { "text": "Foo" } },
          "subtype": [
            {
              "structure": {
                "field": [
                  {
                    "location": {
                      "start": { "constant": { "value": "24" } },
                      "size": { "constant": { "value": "32" } }
                    },
                    "type": {
                      "atomic_type": {
                        "reference": {
                          "canonical_name": {
                            "module_file": "",
                            "object_path": ["UInt"]
                          }
                        }
                      }
                    },
                    "name": { "name": { "text": "bar_field1" } }
                  },
                  {
                    "location": {
                      "start": { "constant": { "value": "32" } },
                      "size": { "constant": { "value": "320" } }
                    },
                    "type": {
                      "array_type": {
                        "base_type": {
                          "array_type": {
                            "base_type": {
                              "atomic_type": {
                                "reference": {
                                  "canonical_name": {
                                    "module_file": "",
                                    "object_path": ["UInt"]
                                  }
                                }
                              }
                            },
                            "element_count": { "constant": { "value": "16" } }
                          }
                        },
                        "automatic": { }
                      }
                    },
                    "name": { "name": { "text": "bar_field2" } }
                  }
                ]
              },
              "name": { "name": { "text": "Bar" } }
            }
          ]
        },
        {
          "enumeration": {
            "value": [
              {
                "name": { "name": { "text": "ONE" } },
                "value": { "constant": { "value": "1" } }
              },
              {
                "name": { "name": { "text": "TWO" } },
                "value": {
                  "function": {
                    "function": "ADDITION",
                    "args": [
                      { "constant": { "value": "1" } },
                      { "constant": { "value": "1" } }
                    ],
                    "function_name": { "text": "+" }
                  }
                }
              }
            ]
          },
          "name": { "name": { "text": "Bar" } }
        }
      ],
      "source_file_name": "t.emb"
    },
    {
      "type": [
        {
          "external": { },
          "name": {
            "name": { "text": "UInt" },
            "canonical_name": { "module_file": "", "object_path": ["UInt"] }
          },
          "attribute": [
            {
              "name": { "text": "statically_sized" },
              "value": { "expression": { "boolean_constant": { "value": true } } }
            },
            {
              "name": { "text": "size_in_bits" },
              "value": { "expression": { "constant": { "value": "64" } } }
            }
          ]
        }
      ],
      "source_file_name": ""
    }
  ]
}""")


def _count_entries(sequence):
    """Returns a Counter mapping each entry of `sequence` to its multiplicity.

    Comparing two such counters checks that two sequences hold the same
    entries the same number of times, regardless of their order.
    """
    return collections.Counter(sequence)


def _record_constant(constant, constant_list):
    """Appends the integer value of `constant` to `constant_list`."""
    constant_list.append(int(constant.value))


def _record_field_name_and_constant(constant, constant_list, field):
    """Appends (name text of `field`, integer value of `constant`)."""
    constant_list.append((field.name.name.text, int(constant.value)))


def _record_file_name_and_constant(constant, constant_list, source_file_name):
    """Appends (`source_file_name`, integer value of `constant`)."""
    constant_list.append((source_file_name, int(constant.value)))


def _record_location_parameter_and_constant(constant, constant_list,
                                            location=None):
    """Appends (`location`, integer value of `constant`).

    `location` is recorded exactly as received, so the result shows which
    value of that parameter was in effect when the constant was visited.
    """
    constant_list.append((location, int(constant.value)))


def _record_kind_and_constant(constant, constant_list, type_definition):
    """Appends (kind of `type_definition`, integer value of `constant`).

    The kind is the first of "enumeration", "structure" or "external" for
    which `type_definition.HasField()` is true.

    Raises:
      AssertionError: if none of those three fields is set.
    """
    for kind in ("enumeration", "structure", "external"):
        if type_definition.HasField(kind):
            constant_list.append((kind, int(constant.value)))
            return
    # Raised explicitly (rather than `assert False`) so the check still fires
    # when Python runs with -O, which strips assert statements.
    raise AssertionError("Shouldn't be here.")


class TraverseIrTest(unittest.TestCase):
    """Tests for traverse_ir.fast_traverse_ir_top_down, run on _EXAMPLE_IR."""

    def test_filter_on_type(self):
        """A one-element pattern reaches every NumericConstant in the IR."""
        constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.NumericConstant], _record_constant,
            parameters={"constant_list": constants})
        self.assertEqual(
            _count_entries(
                [0, 8, 8, 8, 16, 24, 32, 16, 32, 320, 1, 1, 1, 64]),
            _count_entries(constants))

    def test_filter_on_type_in_type(self):
        """A Function/Expression/NumericConstant pattern finds only args."""
        constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR,
            [ir_data.Function, ir_data.Expression, ir_data.NumericConstant],
            _record_constant,
            parameters={"constant_list": constants})
        # Only the two arguments of the ADDITION function match.
        self.assertEqual([1, 1], constants)

    def test_filter_on_type_star_type(self):
        """A two-type pattern matches constants at any depth under the first."""
        struct_constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.Structure, ir_data.NumericConstant],
            _record_constant,
            parameters={"constant_list": struct_constants})
        self.assertEqual(
            _count_entries([0, 8, 8, 8, 16, 24, 32, 16, 32, 320]),
            _count_entries(struct_constants))

        enum_constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.Enum, ir_data.NumericConstant],
            _record_constant,
            parameters={"constant_list": enum_constants})
        self.assertEqual(_count_entries([1, 1, 1]),
                         _count_entries(enum_constants))

    def test_filter_on_not_type(self):
        """skip_descendants_of=(Structure,) hides constants under structures."""
        notstruct_constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.NumericConstant], _record_constant,
            skip_descendants_of=(ir_data.Structure,),
            parameters={"constant_list": notstruct_constants})
        self.assertEqual(_count_entries([1, 1, 1, 64]),
                         _count_entries(notstruct_constants))

    def test_field_is_populated(self):
        """An action taking `field` receives the Field enclosing the node."""
        constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.Field, ir_data.NumericConstant],
            _record_field_name_and_constant,
            parameters={"constant_list": constants})
        self.assertEqual(_count_entries([
            ("field1", 0), ("field1", 8), ("field2", 8), ("field2", 8),
            ("field2", 16), ("bar_field1", 24), ("bar_field1", 32),
            ("bar_field2", 16), ("bar_field2", 32), ("bar_field2", 320),
        ]), _count_entries(constants))

    def test_file_name_is_populated(self):
        """An action taking `source_file_name` receives the module's name."""
        constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.NumericConstant],
            _record_file_name_and_constant,
            parameters={"constant_list": constants})
        self.assertEqual(_count_entries([
            ("t.emb", 0), ("t.emb", 8), ("t.emb", 8), ("t.emb", 8),
            ("t.emb", 16), ("t.emb", 24), ("t.emb", 32), ("t.emb", 16),
            ("t.emb", 32), ("t.emb", 320), ("t.emb", 1), ("t.emb", 1),
            ("t.emb", 1), ("", 64),
        ]), _count_entries(constants))

    def test_type_definition_is_populated(self):
        """An action taking `type_definition` receives the enclosing type."""
        constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.NumericConstant], _record_kind_and_constant,
            parameters={"constant_list": constants})
        self.assertEqual(_count_entries([
            ("structure", 0), ("structure", 8), ("structure", 8),
            ("structure", 8), ("structure", 16), ("structure", 24),
            ("structure", 32), ("structure", 16), ("structure", 32),
            ("structure", 320), ("enumeration", 1), ("enumeration", 1),
            ("enumeration", 1), ("external", 64),
        ]), _count_entries(constants))

    def test_keyword_args_dict_in_action(self):
        """Actions declaring **kwargs get `field` only beneath a Field."""
        call_counts = {"populated": 0, "not": 0}

        def check_field_is_populated(node, **kwargs):
            del node  # Unused.
            self.assertTrue(kwargs["field"])
            call_counts["populated"] += 1

        def check_field_is_not_populated(node, **kwargs):
            del node  # Unused.
            self.assertNotIn("field", kwargs)
            call_counts["not"] += 1

        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.Field, ir_data.Type],
            check_field_is_populated)
        # One Type per atomic or array level across the four fields:
        # 1 (field1) + 2 (field2) + 1 (bar_field1) + 3 (bar_field2).
        self.assertEqual(7, call_counts["populated"])

        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.Enum, ir_data.EnumValue],
            check_field_is_not_populated)
        # The enumeration has exactly two values, ONE and TWO.
        self.assertEqual(2, call_counts["not"])

    def test_pass_only_to_sub_nodes(self):
        """Values returned by an incidental action reach only sub-nodes."""
        constants = []

        def pass_location_down(field):
            return {
                "location": (int(field.location.start.constant.value),
                             int(field.location.size.constant.value))
            }

        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.NumericConstant],
            _record_location_parameter_and_constant,
            incidental_actions={ir_data.Field: pass_location_down},
            parameters={"constant_list": constants, "location": None})
        # Constants outside any Field (enum values, the UInt attribute) must
        # still see the initial `location` of None.
        self.assertEqual(_count_entries([
            ((0, 8), 0), ((0, 8), 8), ((8, 16), 8), ((8, 16), 8),
            ((8, 16), 16), ((24, 32), 24), ((24, 32), 32), ((32, 320), 16),
            ((32, 320), 32), ((32, 320), 320), (None, 1), (None, 1),
            (None, 1), (None, 64),
        ]), _count_entries(constants))


if __name__ == "__main__":
    unittest.main()