xref: /aosp_15_r20/build/soong/scripts/hiddenapi/signature_trie_test.py (revision 333d2b3687b3a337dbcca9d65000bca186795e39)
1#!/usr/bin/env python
2#
3# Copyright (C) 2022 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the 'License');
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an 'AS IS' BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
"""Unit tests for signature_trie.py."""
17import io
18import unittest
19
20from signature_trie import InteriorNode
21from signature_trie import signature_trie
22
23
class TestSignatureToElements(unittest.TestCase):
    """Round-trip tests between signature strings and element lists.

    Parsing goes through InteriorNode.signature_to_elements and rebuilding
    goes through InteriorNode.elements_to_selector. For signatures that start
    with "L" the tests prepend that "L" to the rebuilt selector before
    comparing it with the original signature.
    """

    @staticmethod
    def signature_to_elements(signature):
        """Return InteriorNode.signature_to_elements(signature)."""
        parsed = InteriorNode.signature_to_elements(signature)
        return parsed

    @staticmethod
    def elements_to_signature(elements):
        """Return InteriorNode.elements_to_selector(elements)."""
        selector = InteriorNode.elements_to_selector(elements)
        return selector

    def _check_round_trip(self, signature, elements, selector_prefix=""):
        """Assert signature parses to elements and elements rebuild it.

        Args:
            signature: the signature or pattern string under test.
            elements: the expected list of (type, value) tuples.
            selector_prefix: text prepended to the rebuilt selector before it
                is compared with signature.
        """
        parsed = self.signature_to_elements(signature)
        self.assertEqual(elements, parsed)
        rebuilt = selector_prefix + self.elements_to_signature(elements)
        self.assertEqual(signature, rebuilt)

    def _parse_failure_message(self, pattern):
        """Assert parsing pattern raises and return the exception text."""
        with self.assertRaises(Exception) as caught:
            self.signature_to_elements(pattern)
        return str(caught.exception)

    def test_nested_inner_classes(self):
        # Every "$" separated name is expected as its own "class" element.
        self._check_round_trip(
            "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V",
            [
                ("package", "java"),
                ("package", "lang"),
                ("class", "ProcessBuilder"),
                ("class", "Redirect"),
                ("class", "1"),
                ("member", "<init>()V"),
            ],
            selector_prefix="L",
        )

    def test_basic_member(self):
        self._check_round_trip(
            "Ljava/lang/Object;->hashCode()I",
            [
                ("package", "java"),
                ("package", "lang"),
                ("class", "Object"),
                ("member", "hashCode()I"),
            ],
            selector_prefix="L",
        )

    def test_double_dollar_class(self):
        # "$$" is expected to produce an empty class name between the two
        # dollar signs.
        self._check_round_trip(
            "Ljava/lang/CharSequence$$ExternalSyntheticLambda0;"
            "-><init>(Ljava/lang/CharSequence;)V",
            [
                ("package", "java"),
                ("package", "lang"),
                ("class", "CharSequence"),
                ("class", ""),
                ("class", "ExternalSyntheticLambda0"),
                ("member", "<init>(Ljava/lang/CharSequence;)V"),
            ],
            selector_prefix="L",
        )

    def test_no_member(self):
        self._check_round_trip(
            "Ljava/lang/CharSequence$$ExternalSyntheticLambda0",
            [
                ("package", "java"),
                ("package", "lang"),
                ("class", "CharSequence"),
                ("class", ""),
                ("class", "ExternalSyntheticLambda0"),
            ],
            selector_prefix="L",
        )

    def test_wildcard(self):
        self._check_round_trip(
            "java/lang/*",
            [
                ("package", "java"),
                ("package", "lang"),
                ("wildcard", "*"),
            ],
        )

    def test_recursive_wildcard(self):
        self._check_round_trip(
            "java/lang/**",
            [
                ("package", "java"),
                ("package", "lang"),
                ("wildcard", "**"),
            ],
        )

    def test_no_packages_wildcard(self):
        self._check_round_trip("*", [("wildcard", "*")])

    def test_no_packages_recursive_wildcard(self):
        self._check_round_trip("**", [("wildcard", "**")])

    def test_non_standard_class_name(self):
        # The class name starts with a lower case letter.
        self._check_round_trip(
            "Ljavax/crypto/extObjectInputStream",
            [
                ("package", "javax"),
                ("package", "crypto"),
                ("class", "extObjectInputStream"),
            ],
            selector_prefix="L",
        )

    def test_invalid_pattern_wildcard(self):
        # A "*" attached to a partial name must be rejected.
        message = self._parse_failure_message("Ljava/lang/Class*")
        self.assertIn("invalid wildcard 'Class*'", message)

    def test_invalid_pattern_wildcard_and_member(self):
        # A wildcard combined with a member signature must be rejected.
        message = self._parse_failure_message("Ljava/lang/*;->hashCode()I")
        self.assertIn(
            "contains wildcard '*' and member signature 'hashCode()I'",
            message)
143
144
class TestValues(unittest.TestCase):
    """Checks node types, selectors and stored values of a populated trie."""

    def test_add_then_get(self):
        # Three members below a/b/C, two of them inside the nested class D,
        # each stored with a different kind of value.
        trie = signature_trie()
        for signature, value in (
                ("La/b/C;->l()", 1),
                ("La/b/C$D;->m()", "A"),
                ("La/b/C$D;->n()", {}),
        ):
            trie.add(signature, value)

        # Walk down through the first child at each level, checking the type
        # and the accumulated selector of every node on the way.
        node = trie
        for expected_type, expected_selector in (
                ("package", "a"),
                ("package", "a/b"),
                ("class", "a/b/C"),
        ):
            node = next(iter(node.child_nodes()))
            self.assertEqual(expected_type, node.type)
            self.assertEqual(expected_selector, node.selector)

        # node is now the class C node; an accept-everything filter is
        # expected to yield all three values in insertion order.
        self.assertEqual([1, "A", {}], node.values(lambda _: True))
165
class TestGetMatchingRows(unittest.TestCase):
    """Tests get_matching_rows against a small fixed set of signatures.

    Each fixture signature is added to the trie with itself as the value, so
    the rows returned for a pattern are compared directly (after sorting)
    with the expected signatures.
    """

    extractInput = """
Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;
Ljava/lang/Character;->serialVersionUID:J
Ljava/lang/Object;->hashCode()I
Ljava/lang/Object;->toString()Ljava/lang/String;
Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V
Ljava/util/zip/ZipFile;-><clinit>()V
"""

    # pylint: disable=line-too-long
    # Sorted fixture rows below java/lang.
    _JAVA_LANG_ROWS = [
        "Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;",
        "Ljava/lang/Character;->serialVersionUID:J",
        "Ljava/lang/Object;->hashCode()I",
        "Ljava/lang/Object;->toString()Ljava/lang/String;",
        "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V",
    ]

    # Every fixture row, sorted.
    _ALL_ROWS = _JAVA_LANG_ROWS + [
        "Ljava/util/zip/ZipFile;-><clinit>()V",
    ]
    # pylint: enable=line-too-long

    def read_trie(self):
        """Build a trie with every line of extractInput stored under itself."""
        trie = signature_trie()
        with io.StringIO(self.extractInput.strip()) as stream:
            for raw_line in stream:
                signature = raw_line.rstrip()
                trie.add(signature, signature)
        return trie

    def check_patterns(self, pattern, expected):
        """Check pattern against the root of a freshly built trie."""
        self.check_node_patterns(self.read_trie(), pattern, expected)

    def check_node_patterns(self, node, pattern, expected):
        """Assert the sorted rows node returns for pattern equal expected."""
        matched = sorted(node.get_matching_rows(pattern))
        self.assertEqual(expected, matched)

    def test_member_pattern(self):
        only_row = "Ljava/util/zip/ZipFile;-><clinit>()V"
        self.check_patterns("java/util/zip/ZipFile;-><clinit>()V", [only_row])

    def test_class_pattern(self):
        object_rows = [
            "Ljava/lang/Object;->hashCode()I",
            "Ljava/lang/Object;->toString()Ljava/lang/String;",
        ]
        self.check_patterns("java/lang/Object", object_rows)

    # pylint: disable=line-too-long
    def test_nested_class_pattern(self):
        # A class pattern is expected to match the nested class's rows too.
        character_rows = [
            "Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;",
            "Ljava/lang/Character;->serialVersionUID:J",
        ]
        self.check_patterns("java/lang/Character", character_rows)
    # pylint: enable=line-too-long

    def test_wildcard(self):
        # The java/util/zip row is not expected for "java/lang/*".
        self.check_patterns("java/lang/*", self._JAVA_LANG_ROWS)

    def test_recursive_wildcard(self):
        self.check_patterns("java/**", self._ALL_ROWS)

    def test_node_wildcard(self):
        # Query from the root's first child node instead of the root itself.
        first_child = next(iter(self.read_trie().child_nodes()))
        self.check_node_patterns(first_child, "**", self._ALL_ROWS)
242
243
if __name__ == "__main__":
    # Run the suite only when executed as a script, not on import;
    # verbosity=2 asks the runner for per-test output.
    unittest.main(verbosity=2)
246