xref: /aosp_15_r20/external/pytorch/test/functorch/test_parsing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: functorch"]
2
3"""Adapted from https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/tests/test_parsing.py.
4
5MIT License
6
7Copyright (c) 2018 Alex Rogozhnikov
8
9Permission is hereby granted, free of charge, to any person obtaining a copy
10of this software and associated documentation files (the "Software"), to deal
11in the Software without restriction, including without limitation the rights
12to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13copies of the Software, and to permit persons to whom the Software is
14furnished to do so, subject to the following conditions:
15
16The above copyright notice and this permission notice shall be included in all
17copies or substantial portions of the Software.
18
19THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25SOFTWARE.
26"""
27from typing import Any, Callable, Dict
28from unittest import mock
29
30from functorch.einops._parsing import (
31    _ellipsis,
32    AnonymousAxis,
33    parse_pattern,
34    ParsedExpression,
35    validate_rearrange_expressions,
36)
37from torch.testing._internal.common_utils import run_tests, TestCase
38
39
# Value-based replacement for ``AnonymousAxis.__eq__``. The tests patch it in so
# that freshly constructed expected axes can be compared with parsed ones by
# their ``value``; anything that is not an ``AnonymousAxis`` compares unequal.
mock_anonymous_axis_eq: Callable[[AnonymousAxis, object], bool] = (
    lambda self, other: (
        self.value == other.value if isinstance(other, AnonymousAxis) else False
    )
)
43
44
class TestAnonymousAxis(TestCase):
    """Equality checks for ``AnonymousAxis`` with and without the patched ``__eq__``."""

    def test_anonymous_axes(self) -> None:
        """Same-valued axes differ by default and match once ``__eq__`` is patched."""
        first = AnonymousAxis("2")
        second = AnonymousAxis("2")
        # Without the patch, two axes built from the same string are expected
        # to compare unequal.
        self.assertNotEqual(first, second)

        with mock.patch.object(AnonymousAxis, "__eq__", mock_anonymous_axis_eq):
            same_value = AnonymousAxis("2")
            other_value = AnonymousAxis("3")

            # Under the value-based comparison both originals match a new "2" axis...
            for axis in (first, second):
                self.assertEqual(axis, same_value)
            # ...and neither matches a "3" axis.
            for axis in (first, second):
                self.assertNotEqual(axis, other_value)

            # The patched comparison also applies element-wise inside lists.
            self.assertListEqual(
                [first, 2, second], [same_value, 2, same_value]
            )
57
58
class TestParsedExpression(TestCase):
    """Tests for axis-name checking and for parsing one side of a pattern."""

    def test_elementary_axis_name(self) -> None:
        """``check_axis_name`` is truthy for the first group and falsy for the second."""
        accepted = (
            "a",
            "b",
            "h",
            "dx",
            "h1",
            "zz",
            "i9123",
            "somelongname",
            "Alex",
            "camelCase",
            "u_n_d_e_r_score",
            "unreasonablyLongAxisName",
        )
        rejected = (
            "",
            "2b",
            "12",
            "_startWithUnderscore",
            "endWithUnderscore_",
            "_",
            "...",
            _ellipsis,
        )

        for candidate in accepted:
            self.assertTrue(ParsedExpression.check_axis_name(candidate))

        for candidate in rejected:
            self.assertFalse(ParsedExpression.check_axis_name(candidate))

    def test_invalid_expressions(self) -> None:
        """Malformed expressions raise ``ValueError``; their valid neighbours parse."""

        def expect_rejected(*expressions: str) -> None:
            # Each expression is checked in its own assertRaises context, in order.
            for expression in expressions:
                with self.assertRaises(ValueError):
                    ParsedExpression(expression)

        # A single ellipsis is fine; a second one is an error.
        ParsedExpression("... a b c d")
        expect_rejected(
            "... a b c d ...",
            "... a b c (d ...)",
            "(... a) b c (d ...)",
        )

        # Unbalanced, empty-nested, or nested parentheses are errors.
        ParsedExpression("(a) b c (d ...)")
        expect_rejected(
            "(a)) b c (d ...)",
            "(a b c (d ...)",
            "(a) (()) b c (d ...)",
            "(a) ((b c) (d ...))",
        )

        # Identifier rules: the first expression parses, the rest are rejected.
        ParsedExpression("camelCase under_scored cApiTaLs \u00DF ...")
        expect_rejected("1a", "_pre", "...pre", "pre...")

    @mock.patch.object(AnonymousAxis, "__eq__", mock_anonymous_axis_eq)
    def test_parse_expression(self, *mocks: mock.MagicMock) -> None:
        """Check identifiers, composition and flags for a range of expressions.

        ``AnonymousAxis.__eq__`` is patched for the whole test so the expected
        compositions can be written with freshly constructed axes.
        """

        def axis_values(expression: ParsedExpression) -> set:
            # Replace every anonymous axis by its integer value; keep names as-is.
            return {
                ident.value if isinstance(ident, AnonymousAxis) else ident
                for ident in expression.identifiers
            }

        # Named axes only; runs of whitespace between names are tolerated.
        result = ParsedExpression("a1  b1   c1    d1")
        self.assertSetEqual(result.identifiers, {"a1", "b1", "c1", "d1"})
        self.assertListEqual(
            result.composition, [[name] for name in ("a1", "b1", "c1", "d1")]
        )
        self.assertFalse(result.has_non_unitary_anonymous_axes)
        self.assertFalse(result.has_ellipsis)

        # Empty groups and literal ``1`` axes both yield empty composition entries.
        for unit_only in ("() () () ()", "1 1 1 ()"):
            result = ParsedExpression(unit_only)
            self.assertSetEqual(result.identifiers, set())
            self.assertListEqual(result.composition, [[], [], [], []])
            self.assertFalse(result.has_non_unitary_anonymous_axes)
            self.assertFalse(result.has_ellipsis)

        # Anonymous axes larger than one are kept as identifiers.
        result = ParsedExpression("5 (3 4)")
        self.assertEqual(len(result.identifiers), 3)
        self.assertSetEqual(axis_values(result), {3, 4, 5})
        expected = [[AnonymousAxis("5")], [AnonymousAxis("3"), AnonymousAxis("4")]]
        self.assertListEqual(result.composition, expected)
        self.assertTrue(result.has_non_unitary_anonymous_axes)
        self.assertFalse(result.has_ellipsis)

        # Unit axes mixed with larger anonymous axes contribute no identifiers.
        result = ParsedExpression("5 1 (1 4) 1")
        self.assertEqual(len(result.identifiers), 2)
        self.assertSetEqual(axis_values(result), {4, 5})
        expected = [[AnonymousAxis("5")], [], [AnonymousAxis("4")], []]
        self.assertListEqual(result.composition, expected)

        # Names, ellipsis and anonymous axes together; ellipsis outside parentheses.
        result = ParsedExpression("name1 ... a1 12 (name2 14)")
        self.assertEqual(len(result.identifiers), 6)
        named = {"name1", _ellipsis, "a1", "name2"}
        # Removing the named identifiers leaves exactly the two anonymous axes.
        self.assertEqual(len(result.identifiers - named), 2)
        expected = [
            ["name1"],
            _ellipsis,
            ["a1"],
            [AnonymousAxis("12")],
            ["name2", AnonymousAxis("14")],
        ]
        self.assertListEqual(result.composition, expected)
        self.assertTrue(result.has_non_unitary_anonymous_axes)
        self.assertTrue(result.has_ellipsis)
        self.assertFalse(result.has_ellipsis_parenthesized)

        # Same axes, but with the ellipsis inside a parenthesized group.
        result = ParsedExpression("(name1 ... a1 12) name2 14")
        self.assertEqual(len(result.identifiers), 6)
        named = {"name1", _ellipsis, "a1", "name2"}
        self.assertEqual(len(result.identifiers - named), 2)
        expected = [
            ["name1", _ellipsis, "a1", AnonymousAxis("12")],
            ["name2"],
            [AnonymousAxis("14")],
        ]
        self.assertListEqual(result.composition, expected)
        self.assertTrue(result.has_non_unitary_anonymous_axes)
        self.assertTrue(result.has_ellipsis)
        self.assertTrue(result.has_ellipsis_parenthesized)
205
206
class TestParsingUtils(TestCase):
    """Tests for ``parse_pattern`` on complete ``left -> right`` patterns."""

    def test_parse_pattern_number_of_arrows(self) -> None:
        """Three arrows or no arrow are rejected; a single arrow parses."""
        no_lengths: Dict[str, int] = {}

        # First the pattern with too many arrows, then the one with none.
        for bad_pattern in ("a -> b -> c -> d", "a"):
            with self.assertRaises(ValueError):
                parse_pattern(bad_pattern, no_lengths)

        parse_pattern("a -> a", no_lengths)

    def test_ellipsis_invalid_identifier(self) -> None:
        """Passing the ellipsis as a key of ``axes_lengths`` raises ``ValueError``."""
        lengths_with_ellipsis: Dict[str, int] = {"a": 1, _ellipsis: 2}
        pattern = f"a {_ellipsis} -> {_ellipsis} a"
        with self.assertRaises(ValueError):
            parse_pattern(pattern, lengths_with_ellipsis)

    def test_ellipsis_matching(self) -> None:
        """An ellipsis that appears only on the right-hand side is rejected."""
        no_lengths: Dict[str, int] = {}

        with self.assertRaises(ValueError):
            parse_pattern("a -> a ...", no_lengths)

        # A left-only ellipsis is not rejected here; that error is raised by
        # the rearrange expression validation instead.
        parse_pattern("a ... -> a", no_lengths)

        parse_pattern("a ... -> ... a", no_lengths)

    def test_left_parenthesized_ellipsis(self) -> None:
        """A parenthesized ellipsis on the left-hand side raises ``ValueError``."""
        no_lengths: Dict[str, int] = {}

        with self.assertRaises(ValueError):
            parse_pattern("(...) -> ...", no_lengths)
248
249
class MaliciousRepr:
    """Object whose ``repr`` is the text of a Python ``print`` call.

    ``TestValidateRearrangeExpressions`` passes an instance as a non-integer
    axis length.
    """

    def __repr__(self) -> str:
        payload = "print('hello world!')"
        return payload
253
254
class TestValidateRearrangeExpressions(TestCase):
    """Tests for ``validate_rearrange_expressions`` on already-parsed patterns."""

    def test_validate_axes_lengths_are_integers(self) -> None:
        """Integer lengths validate; a non-integer length raises ``TypeError``."""
        pattern = "a b c -> c b a"

        lengths: Dict[str, Any] = {"a": 1, "b": 2, "c": 3}
        lhs, rhs = parse_pattern(pattern, lengths)
        validate_rearrange_expressions(lhs, rhs, lengths)

        lengths = {"a": 1, "b": 2, "c": MaliciousRepr()}
        # Parsing happens outside the assertRaises context: only validation is
        # expected to raise.
        lhs, rhs = parse_pattern(pattern, lengths)
        with self.assertRaises(TypeError):
            validate_rearrange_expressions(lhs, rhs, lengths)

    def test_non_unitary_anonymous_axes_raises_error(self) -> None:
        """An anonymous axis of size 2 on either side raises ``ValueError``."""
        no_lengths: Dict[str, int] = {}

        # Left-hand offender first, then right-hand offender.
        for pattern in ("a 2 -> 1 1 a", "1 1 a -> a 2"):
            lhs, rhs = parse_pattern(pattern, no_lengths)
            with self.assertRaises(ValueError):
                validate_rearrange_expressions(lhs, rhs, no_lengths)

    def test_identifier_mismatch(self) -> None:
        """Identifiers present on only one side raise ``ValueError``."""
        no_lengths: Dict[str, int] = {}

        for pattern in ("a -> a b", "a b -> a"):
            lhs, rhs = parse_pattern(pattern, no_lengths)
            with self.assertRaises(ValueError):
                validate_rearrange_expressions(lhs, rhs, no_lengths)

    def test_unexpected_axes_lengths(self) -> None:
        """A length given for an axis that is absent from the pattern raises ``ValueError``."""
        stray_lengths: Dict[str, int] = {"c": 2}

        lhs, rhs = parse_pattern("a b -> b a", stray_lengths)
        with self.assertRaises(ValueError):
            validate_rearrange_expressions(lhs, rhs, stray_lengths)
300
301
# Hand control to the run_tests helper from torch.testing._internal.common_utils
# when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
304