1# Owner(s): ["module: functorch"] 2 3"""Adapted from https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/tests/test_parsing.py. 4 5MIT License 6 7Copyright (c) 2018 Alex Rogozhnikov 8 9Permission is hereby granted, free of charge, to any person obtaining a copy 10of this software and associated documentation files (the "Software"), to deal 11in the Software without restriction, including without limitation the rights 12to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13copies of the Software, and to permit persons to whom the Software is 14furnished to do so, subject to the following conditions: 15 16The above copyright notice and this permission notice shall be included in all 17copies or substantial portions of the Software. 18 19THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25SOFTWARE. 26""" 27from typing import Any, Callable, Dict 28from unittest import mock 29 30from functorch.einops._parsing import ( 31 _ellipsis, 32 AnonymousAxis, 33 parse_pattern, 34 ParsedExpression, 35 validate_rearrange_expressions, 36) 37from torch.testing._internal.common_utils import run_tests, TestCase 38 39 40mock_anonymous_axis_eq: Callable[[AnonymousAxis, object], bool] = ( 41 lambda self, other: isinstance(other, AnonymousAxis) and self.value == other.value 42) 43 44 45class TestAnonymousAxis(TestCase): 46 def test_anonymous_axes(self) -> None: 47 a, b = AnonymousAxis("2"), AnonymousAxis("2") 48 self.assertNotEqual(a, b) 49 50 with mock.patch.object(AnonymousAxis, "__eq__", mock_anonymous_axis_eq): 51 c, d = AnonymousAxis("2"), AnonymousAxis("3") 52 self.assertEqual(a, c) 53 self.assertEqual(b, c) 54 self.assertNotEqual(a, d) 55 self.assertNotEqual(b, d) 56 self.assertListEqual([a, 2, b], [c, 2, c]) 57 58 59class TestParsedExpression(TestCase): 60 def test_elementary_axis_name(self) -> None: 61 for name in [ 62 "a", 63 "b", 64 "h", 65 "dx", 66 "h1", 67 "zz", 68 "i9123", 69 "somelongname", 70 "Alex", 71 "camelCase", 72 "u_n_d_e_r_score", 73 "unreasonablyLongAxisName", 74 ]: 75 self.assertTrue(ParsedExpression.check_axis_name(name)) 76 77 for name in [ 78 "", 79 "2b", 80 "12", 81 "_startWithUnderscore", 82 "endWithUnderscore_", 83 "_", 84 "...", 85 _ellipsis, 86 ]: 87 self.assertFalse(ParsedExpression.check_axis_name(name)) 88 89 def test_invalid_expressions(self) -> None: 90 # double ellipsis should raise an error 91 ParsedExpression("... a b c d") 92 with self.assertRaises(ValueError): 93 ParsedExpression("... a b c d ...") 94 with self.assertRaises(ValueError): 95 ParsedExpression("... a b c (d ...)") 96 with self.assertRaises(ValueError): 97 ParsedExpression("(... a) b c (d ...)") 98 99 # double/missing/enclosed parenthesis 100 ParsedExpression("(a) b c (d ...)") 101 with self.assertRaises(ValueError): 102 ParsedExpression("(a)) b c (d ...)") 103 with self.assertRaises(ValueError): 104 ParsedExpression("(a b c (d ...)") 105 with self.assertRaises(ValueError): 106 ParsedExpression("(a) (()) b c (d ...)") 107 with self.assertRaises(ValueError): 108 ParsedExpression("(a) ((b c) (d ...))") 109 110 # invalid identifiers 111 ParsedExpression("camelCase under_scored cApiTaLs \u00DF ...") 112 with self.assertRaises(ValueError): 113 ParsedExpression("1a") 114 with self.assertRaises(ValueError): 115 ParsedExpression("_pre") 116 with self.assertRaises(ValueError): 117 ParsedExpression("...pre") 118 with self.assertRaises(ValueError): 119 ParsedExpression("pre...") 120 121 @mock.patch.object(AnonymousAxis, "__eq__", mock_anonymous_axis_eq) 122 def test_parse_expression(self, *mocks: mock.MagicMock) -> None: 123 parsed = ParsedExpression("a1 b1 c1 d1") 124 self.assertSetEqual(parsed.identifiers, {"a1", "b1", "c1", "d1"}) 125 self.assertListEqual(parsed.composition, [["a1"], ["b1"], ["c1"], ["d1"]]) 126 self.assertFalse(parsed.has_non_unitary_anonymous_axes) 127 self.assertFalse(parsed.has_ellipsis) 128 129 parsed = ParsedExpression("() () () ()") 130 self.assertSetEqual(parsed.identifiers, set()) 131 self.assertListEqual(parsed.composition, [[], [], [], []]) 132 self.assertFalse(parsed.has_non_unitary_anonymous_axes) 133 self.assertFalse(parsed.has_ellipsis) 134 135 parsed = ParsedExpression("1 1 1 ()") 136 self.assertSetEqual(parsed.identifiers, set()) 137 self.assertListEqual(parsed.composition, [[], [], [], []]) 138 self.assertFalse(parsed.has_non_unitary_anonymous_axes) 139 self.assertFalse(parsed.has_ellipsis) 140 141 parsed = ParsedExpression("5 (3 4)") 142 self.assertEqual(len(parsed.identifiers), 3) 143 self.assertSetEqual( 144 { 145 i.value if isinstance(i, AnonymousAxis) else i 146 for i in parsed.identifiers 147 }, 148 {3, 4, 5}, 149 ) 150 self.assertListEqual( 151 parsed.composition, 152 [[AnonymousAxis("5")], [AnonymousAxis("3"), AnonymousAxis("4")]], 153 ) 154 self.assertTrue(parsed.has_non_unitary_anonymous_axes) 155 self.assertFalse(parsed.has_ellipsis) 156 157 parsed = ParsedExpression("5 1 (1 4) 1") 158 self.assertEqual(len(parsed.identifiers), 2) 159 self.assertSetEqual( 160 { 161 i.value if isinstance(i, AnonymousAxis) else i 162 for i in parsed.identifiers 163 }, 164 {4, 5}, 165 ) 166 self.assertListEqual( 167 parsed.composition, [[AnonymousAxis("5")], [], [AnonymousAxis("4")], []] 168 ) 169 170 parsed = ParsedExpression("name1 ... a1 12 (name2 14)") 171 self.assertEqual(len(parsed.identifiers), 6) 172 self.assertEqual( 173 len(parsed.identifiers - {"name1", _ellipsis, "a1", "name2"}), 2 174 ) 175 self.assertListEqual( 176 parsed.composition, 177 [ 178 ["name1"], 179 _ellipsis, 180 ["a1"], 181 [AnonymousAxis("12")], 182 ["name2", AnonymousAxis("14")], 183 ], 184 ) 185 self.assertTrue(parsed.has_non_unitary_anonymous_axes) 186 self.assertTrue(parsed.has_ellipsis) 187 self.assertFalse(parsed.has_ellipsis_parenthesized) 188 189 parsed = ParsedExpression("(name1 ... a1 12) name2 14") 190 self.assertEqual(len(parsed.identifiers), 6) 191 self.assertEqual( 192 len(parsed.identifiers - {"name1", _ellipsis, "a1", "name2"}), 2 193 ) 194 self.assertListEqual( 195 parsed.composition, 196 [ 197 ["name1", _ellipsis, "a1", AnonymousAxis("12")], 198 ["name2"], 199 [AnonymousAxis("14")], 200 ], 201 ) 202 self.assertTrue(parsed.has_non_unitary_anonymous_axes) 203 self.assertTrue(parsed.has_ellipsis) 204 self.assertTrue(parsed.has_ellipsis_parenthesized) 205 206 207class TestParsingUtils(TestCase): 208 def test_parse_pattern_number_of_arrows(self) -> None: 209 axes_lengths: Dict[str, int] = {} 210 211 too_many_arrows_pattern = "a -> b -> c -> d" 212 with self.assertRaises(ValueError): 213 parse_pattern(too_many_arrows_pattern, axes_lengths) 214 215 too_few_arrows_pattern = "a" 216 with self.assertRaises(ValueError): 217 parse_pattern(too_few_arrows_pattern, axes_lengths) 218 219 just_right_arrows = "a -> a" 220 parse_pattern(just_right_arrows, axes_lengths) 221 222 def test_ellipsis_invalid_identifier(self) -> None: 223 axes_lengths: Dict[str, int] = {"a": 1, _ellipsis: 2} 224 pattern = f"a {_ellipsis} -> {_ellipsis} a" 225 with self.assertRaises(ValueError): 226 parse_pattern(pattern, axes_lengths) 227 228 def test_ellipsis_matching(self) -> None: 229 axes_lengths: Dict[str, int] = {} 230 231 pattern = "a -> a ..." 232 with self.assertRaises(ValueError): 233 parse_pattern(pattern, axes_lengths) 234 235 # raising an error on this pattern is handled by the rearrange expression validation 236 pattern = "a ... -> a" 237 parse_pattern(pattern, axes_lengths) 238 239 pattern = "a ... -> ... a" 240 parse_pattern(pattern, axes_lengths) 241 242 def test_left_parenthesized_ellipsis(self) -> None: 243 axes_lengths: Dict[str, int] = {} 244 245 pattern = "(...) -> ..." 246 with self.assertRaises(ValueError): 247 parse_pattern(pattern, axes_lengths) 248 249 250class MaliciousRepr: 251 def __repr__(self) -> str: 252 return "print('hello world!')" 253 254 255class TestValidateRearrangeExpressions(TestCase): 256 def test_validate_axes_lengths_are_integers(self) -> None: 257 axes_lengths: Dict[str, Any] = {"a": 1, "b": 2, "c": 3} 258 pattern = "a b c -> c b a" 259 left, right = parse_pattern(pattern, axes_lengths) 260 validate_rearrange_expressions(left, right, axes_lengths) 261 262 axes_lengths = {"a": 1, "b": 2, "c": MaliciousRepr()} 263 left, right = parse_pattern(pattern, axes_lengths) 264 with self.assertRaises(TypeError): 265 validate_rearrange_expressions(left, right, axes_lengths) 266 267 def test_non_unitary_anonymous_axes_raises_error(self) -> None: 268 axes_lengths: Dict[str, int] = {} 269 270 left_non_unitary_axis = "a 2 -> 1 1 a" 271 left, right = parse_pattern(left_non_unitary_axis, axes_lengths) 272 with self.assertRaises(ValueError): 273 validate_rearrange_expressions(left, right, axes_lengths) 274 275 right_non_unitary_axis = "1 1 a -> a 2" 276 left, right = parse_pattern(right_non_unitary_axis, axes_lengths) 277 with self.assertRaises(ValueError): 278 validate_rearrange_expressions(left, right, axes_lengths) 279 280 def test_identifier_mismatch(self) -> None: 281 axes_lengths: Dict[str, int] = {} 282 283 mismatched_identifiers = "a -> a b" 284 left, right = parse_pattern(mismatched_identifiers, axes_lengths) 285 with self.assertRaises(ValueError): 286 validate_rearrange_expressions(left, right, axes_lengths) 287 288 mismatched_identifiers = "a b -> a" 289 left, right = parse_pattern(mismatched_identifiers, axes_lengths) 290 with self.assertRaises(ValueError): 291 validate_rearrange_expressions(left, right, axes_lengths) 292 293 def test_unexpected_axes_lengths(self) -> None: 294 axes_lengths: Dict[str, int] = {"c": 2} 295 296 pattern = "a b -> b a" 297 left, right = parse_pattern(pattern, axes_lengths) 298 with self.assertRaises(ValueError): 299 validate_rearrange_expressions(left, right, axes_lengths) 300 301 302if __name__ == "__main__": 303 run_tests() 304