xref: /aosp_15_r20/external/emboss/compiler/util/traverse_ir_test.py (revision 99e0aae7469b87d12f0ad23e61142c2d74c1ef70)
1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Tests for util.traverse_ir."""
16
17import collections
18
19import unittest
20
21from compiler.util import ir_data
22from compiler.util import ir_data_utils
23from compiler.util import traverse_ir
24
# Shared IR fixture for the tests in this file, deserialized from JSON into an
# ir_data.EmbossIr.  It contains two modules:
#
#   * "t.emb", which declares:
#       - structure "Foo" with fields "field1" and "field2", plus a nested
#         subtype structure "Bar" with fields "bar_field1" and "bar_field2";
#       - enumeration "Bar" with values "ONE" (the constant 1) and "TWO" (an
#         ADDITION function applied to the constants 1 and 1).
#   * "" (empty source file name), which declares the external type "UInt"
#     with "statically_sized" and "size_in_bits" attributes.
#
# The JSON holds fourteen "constant" nodes in total: ten under structure
# fields (0, 8, 8, 8, 16, 16, 24, 32, 32, 320), three inside the enumeration
# (1, 1, 1) and one in the "size_in_bits" attribute (64).  The expected values
# in the tests are derived from these.
_EXAMPLE_IR = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, """{
"module": [
  {
    "type": [
      {
        "structure": {
          "field": [
            {
              "location": {
                "start": { "constant": { "value": "0" } },
                "size": { "constant": { "value": "8" } }
              },
              "type": {
                "atomic_type": {
                  "reference": {
                    "canonical_name": {
                      "module_file": "",
                      "object_path": ["UInt"]
                    }
                  }
                }
              },
              "name": { "name": { "text": "field1" } }
            },
            {
              "location": {
                "start": { "constant": { "value": "8" } },
                "size": { "constant": { "value": "16" } }
              },
              "type": {
                "array_type": {
                  "base_type": {
                    "atomic_type": {
                      "reference": {
                        "canonical_name": {
                          "module_file": "",
                          "object_path": ["UInt"]
                        }
                      }
                    }
                  },
                  "element_count": { "constant": { "value": "8" } }
                }
              },
              "name": { "name": { "text": "field2" } }
            }
          ]
        },
        "name": { "name": { "text": "Foo" } },
        "subtype": [
          {
            "structure": {
              "field": [
                {
                  "location": {
                    "start": { "constant": { "value": "24" } },
                    "size": { "constant": { "value": "32" } }
                  },
                  "type": {
                    "atomic_type": {
                      "reference": {
                        "canonical_name": {
                          "module_file": "",
                          "object_path": ["UInt"]
                        }
                      }
                    }
                  },
                  "name": { "name": { "text": "bar_field1" } }
                },
                {
                  "location": {
                    "start": { "constant": { "value": "32" } },
                    "size": { "constant": { "value": "320" } }
                  },
                  "type": {
                    "array_type": {
                      "base_type": {
                        "array_type": {
                          "base_type": {
                            "atomic_type": {
                              "reference": {
                                "canonical_name": {
                                  "module_file": "",
                                  "object_path": ["UInt"]
                                }
                              }
                            }
                          },
                          "element_count": { "constant": { "value": "16" } }
                        }
                      },
                      "automatic": { }
                    }
                  },
                  "name": { "name": { "text": "bar_field2" } }
                }
              ]
            },
            "name": { "name": { "text": "Bar" } }
          }
        ]
      },
      {
        "enumeration": {
          "value": [
            {
              "name": { "name": { "text": "ONE" } },
              "value": { "constant": { "value": "1" } }
            },
            {
              "name": { "name": { "text": "TWO" } },
              "value": {
                "function": {
                  "function": "ADDITION",
                  "args": [
                    { "constant": { "value": "1" } },
                    { "constant": { "value": "1" } }
                  ],
                  "function_name": { "text": "+" }
                }
              }
            }
          ]
        },
        "name": { "name": { "text": "Bar" } }
      }
    ],
    "source_file_name": "t.emb"
  },
  {
    "type": [
      {
        "external": { },
        "name": {
          "name": { "text": "UInt" },
          "canonical_name": { "module_file": "", "object_path": ["UInt"] }
        },
        "attribute": [
          {
            "name": { "text": "statically_sized" },
            "value": { "expression": { "boolean_constant": { "value": true } } }
          },
          {
            "name": { "text": "size_in_bits" },
            "value": { "expression": { "constant": { "value": "64" } } }
          }
        ]
      }
    ],
    "source_file_name": ""
  }
]
}""")
179
180
def _count_entries(sequence):
  """Returns a Counter mapping each entry to how often it occurs.

  Comparing two results with == checks that two sequences hold the same
  entries with the same multiplicities, regardless of order.

  Args:
    sequence: An iterable of hashable entries (a list in every caller in this
      file).  Not intended for mappings, which collections.Counter would
      interpret as precomputed counts.

  Returns:
    A collections.Counter keyed by entry.
  """
  return collections.Counter(sequence)
186
187
def _record_constant(constant, constant_list):
  """Appends the integer form of `constant.value` to `constant_list`."""
  numeric_value = int(constant.value)
  constant_list.append(numeric_value)
190
191
def _record_field_name_and_constant(constant, constant_list, field):
  """Appends a (field name text, integer constant value) pair.

  The name is read from `field.name.name.text`; the value from
  `constant.value`.
  """
  field_name = field.name.name.text
  constant_list.append((field_name, int(constant.value)))
194
195
def _record_file_name_and_constant(constant, constant_list, source_file_name):
  """Appends a (source file name, integer constant value) pair."""
  entry = (source_file_name, int(constant.value))
  constant_list.append(entry)
198
199
def _record_location_parameter_and_constant(constant, constant_list,
                                            location=None):
  """Appends a (location, integer constant value) pair.

  `location` is recorded exactly as received and is None when the caller does
  not supply one.
  """
  entry = (location, int(constant.value))
  constant_list.append(entry)
203
204
def _record_kind_and_constant(constant, constant_list, type_definition):
  """Appends a (kind, integer constant value) pair.

  The kind is the first of "enumeration", "structure" or "external" for which
  `type_definition.HasField` returns true.  If none of them is set, the
  function asserts.
  """
  # Probe in the same fixed order every time so the first match wins.
  for kind in ("enumeration", "structure", "external"):
    if type_definition.HasField(kind):
      constant_list.append((kind, int(constant.value)))
      return
  assert False, "Shouldn't be here."
214
215
class TraverseIrTest(unittest.TestCase):
  """Exercises traverse_ir.fast_traverse_ir_top_down on _EXAMPLE_IR.

  Most tests compare results through _count_entries, i.e. as multisets, so
  they do not depend on the order in which nodes are visited.
  """

  def test_filter_on_type(self):
    """A single-type pattern reports every NumericConstant in the IR."""
    constants = []
    traverse_ir.fast_traverse_ir_top_down(
        _EXAMPLE_IR, [ir_data.NumericConstant], _record_constant,
        parameters={"constant_list": constants})
    self.assertEqual(
        _count_entries([0, 8, 8, 8, 16, 24, 32, 16, 32, 320, 1, 1, 1, 64]),
        _count_entries(constants))

  def test_filter_on_type_in_type(self):
    """A Function/Expression/NumericConstant pattern reports only [1, 1].

    Those are the two constants passed as arguments to the ADDITION function
    in the fixture; this test also checks the order of the result.
    """
    constants = []
    traverse_ir.fast_traverse_ir_top_down(
        _EXAMPLE_IR,
        [ir_data.Function, ir_data.Expression, ir_data.NumericConstant],
        _record_constant,
        parameters={"constant_list": constants})
    self.assertEqual([1, 1], constants)

  def test_filter_on_type_star_type(self):
    """A two-type pattern matches constants nested at any depth.

    The constants expected here sit several levels below their Structure or
    Enum rather than being direct children of it.
    """
    struct_constants = []
    traverse_ir.fast_traverse_ir_top_down(
        _EXAMPLE_IR, [ir_data.Structure, ir_data.NumericConstant],
        _record_constant,
        parameters={"constant_list": struct_constants})
    self.assertEqual(_count_entries([0, 8, 8, 8, 16, 24, 32, 16, 32, 320]),
                     _count_entries(struct_constants))
    enum_constants = []
    traverse_ir.fast_traverse_ir_top_down(
        _EXAMPLE_IR, [ir_data.Enum, ir_data.NumericConstant], _record_constant,
        parameters={"constant_list": enum_constants})
    self.assertEqual(_count_entries([1, 1, 1]), _count_entries(enum_constants))

  def test_filter_on_not_type(self):
    """skip_descendants_of=(Structure,) leaves only non-structure constants."""
    notstruct_constants = []
    traverse_ir.fast_traverse_ir_top_down(
        _EXAMPLE_IR, [ir_data.NumericConstant], _record_constant,
        skip_descendants_of=(ir_data.Structure,),
        parameters={"constant_list": notstruct_constants})
    self.assertEqual(_count_entries([1, 1, 1, 64]),
                     _count_entries(notstruct_constants))

  def test_field_is_populated(self):
    """An action taking `field` receives the Field enclosing each constant."""
    constants = []
    traverse_ir.fast_traverse_ir_top_down(
        _EXAMPLE_IR, [ir_data.Field, ir_data.NumericConstant],
        _record_field_name_and_constant,
        parameters={"constant_list": constants})
    self.assertEqual(_count_entries([
        ("field1", 0), ("field1", 8), ("field2", 8), ("field2", 8),
        ("field2", 16), ("bar_field1", 24), ("bar_field1", 32),
        ("bar_field2", 16), ("bar_field2", 32), ("bar_field2", 320)
    ]), _count_entries(constants))

  def test_file_name_is_populated(self):
    """An action taking `source_file_name` receives its module's file name."""
    constants = []
    traverse_ir.fast_traverse_ir_top_down(
        _EXAMPLE_IR, [ir_data.NumericConstant], _record_file_name_and_constant,
        parameters={"constant_list": constants})
    self.assertEqual(_count_entries([
        ("t.emb", 0), ("t.emb", 8), ("t.emb", 8), ("t.emb", 8), ("t.emb", 16),
        ("t.emb", 24), ("t.emb", 32), ("t.emb", 16), ("t.emb", 32),
        ("t.emb", 320), ("t.emb", 1), ("t.emb", 1), ("t.emb", 1), ("", 64)
    ]), _count_entries(constants))

  def test_type_definition_is_populated(self):
    """An action taking `type_definition` receives the enclosing type."""
    constants = []
    traverse_ir.fast_traverse_ir_top_down(
        _EXAMPLE_IR, [ir_data.NumericConstant], _record_kind_and_constant,
        parameters={"constant_list": constants})
    self.assertEqual(_count_entries([
        ("structure", 0), ("structure", 8), ("structure", 8), ("structure", 8),
        ("structure", 16), ("structure", 24), ("structure", 32),
        ("structure", 16), ("structure", 32), ("structure", 320),
        ("enumeration", 1), ("enumeration", 1), ("enumeration", 1),
        ("external", 64)
    ]), _count_entries(constants))

  def test_keyword_args_dict_in_action(self):
    """Actions declared with **kwargs get `field` only when inside a Field."""
    call_counts = {"populated": 0, "not": 0}

    def check_field_is_populated(node, **kwargs):
      del node  # Unused.
      self.assertTrue(kwargs["field"])
      call_counts["populated"] += 1

    def check_field_is_not_populated(node, **kwargs):
      del node  # Unused.
      # assertNotIn reports the offending dict on failure, unlike
      # assertFalse("field" in kwargs), which only says "True is not false".
      self.assertNotIn("field", kwargs)
      call_counts["not"] += 1

    traverse_ir.fast_traverse_ir_top_down(
        _EXAMPLE_IR, [ir_data.Field, ir_data.Type], check_field_is_populated)
    self.assertEqual(7, call_counts["populated"])

    traverse_ir.fast_traverse_ir_top_down(
        _EXAMPLE_IR, [ir_data.Enum, ir_data.EnumValue],
        check_field_is_not_populated)
    self.assertEqual(2, call_counts["not"])

  def test_pass_only_to_sub_nodes(self):
    """Values returned by an incidental action reach only nodes below it.

    Constants inside a field are expected to see that field's (start, size)
    as `location`; constants outside any field are expected to see the
    initial `location` of None.
    """
    constants = []

    def pass_location_down(field):
      return {
          "location": (int(field.location.start.constant.value),
                       int(field.location.size.constant.value))
      }

    traverse_ir.fast_traverse_ir_top_down(
        _EXAMPLE_IR, [ir_data.NumericConstant],
        _record_location_parameter_and_constant,
        incidental_actions={ir_data.Field: pass_location_down},
        parameters={"constant_list": constants, "location": None})
    self.assertEqual(_count_entries([
        ((0, 8), 0), ((0, 8), 8), ((8, 16), 8), ((8, 16), 8), ((8, 16), 16),
        ((24, 32), 24), ((24, 32), 32), ((32, 320), 16), ((32, 320), 32),
        ((32, 320), 320), (None, 1), (None, 1), (None, 1), (None, 64)
    ]), _count_entries(constants))
337
338if __name__ == "__main__":
339  unittest.main()
340