xref: /aosp_15_r20/external/cronet/third_party/re2/src/python/re2_test.py (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1# Copyright 2019 The RE2 Authors.  All Rights Reserved.
2# Use of this source code is governed by a BSD-style
3# license that can be found in the LICENSE file.
4"""Tests for google3.third_party.re2.python.re2."""
5
6import collections
7import pickle
8import re
9
10from absl.testing import absltest
11from absl.testing import parameterized
12import re2
13
14
class OptionsTest(parameterized.TestCase):
  """Checks that each re2.Options attribute can be written and read back."""

  @parameterized.parameters(*re2.Options.NAMES)
  def test_option(self, name):
    """Assigns a different value to option `name` and reads it back.

    Args:
      name: the attribute name, taken from re2.Options.NAMES.

    Raises:
      TypeError: if the option's initial value is of an unhandled type.
    """
    options = re2.Options()
    initial = getattr(options, name)
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(initial, re2.Options.Encoding):
      candidates = type(initial).__members__.values()
      replacement = next(
          candidate for candidate in candidates if candidate != initial)
    elif isinstance(initial, bool):
      replacement = not initial
    elif isinstance(initial, int):
      replacement = initial + 1
    else:
      raise TypeError(
          'option {!r}: {!r} {!r}'.format(name, type(initial), initial))
    setattr(options, name, replacement)
    self.assertEqual(replacement, getattr(options, name))
31
32
class Re2CompileTest(parameterized.TestCase):
  """Compilation tests that are run against the re2 module alone.

  re2 and Python disagree on the string types of group names, so these
  cases make no attempt to verify consistency with the re module.
  """

  @parameterized.parameters(
      (u'(foo*)(?P<bar>qux+)', 2, [(u'bar', 2)]),
      (b'(foo*)(?P<bar>qux+)', 2, [(b'bar', 2)]),
      (u'(foo*)(?P<中文>qux+)', 2, [(u'中文', 2)]),
  )
  def test_compile(self, pattern, expected_groups, expected_groupindex):
    """Compiles `pattern`; checks caching, option rejection and group info."""
    compiled = re2.compile(pattern)
    # Recompiling the pattern, or passing the compiled object itself, must
    # hand back the very same (cached) object.
    for source in (pattern, compiled):
      self.assertIs(compiled, re2.compile(source))
    rejection = ('pattern is already compiled, '
                 'so options may not be specified')
    with self.assertRaisesRegex(re2.error, rejection):
      changed = re2.Options()
      changed.log_errors = not changed.log_errors
      re2.compile(compiled, options=changed)
    self.assertIsNotNone(compiled.options)
    self.assertEqual(expected_groups, compiled.groups)
    self.assertDictEqual(dict(expected_groupindex), compiled.groupindex)

  def test_compile_with_options(self):
    """Expects 'pattern too large' once max_mem has been set to 100."""
    limited = re2.Options()
    limited.max_mem = 100
    with self.assertRaisesRegex(re2.error, 'pattern too large'):
      re2.compile('.{1000}', options=limited)

  def test_programsize_reverseprogramsize(self):
    """Pins the forward and reverse program sizes reported for 'a+b'."""
    compiled = re2.compile('a+b')
    self.assertEqual(7, compiled.programsize)
    self.assertEqual(7, compiled.reverseprogramsize)

  def test_programfanout_reverseprogramfanout(self):
    """Pins the forward and reverse program fanouts reported for 'a+b'."""
    compiled = re2.compile('a+b')
    self.assertListEqual([1, 1], compiled.programfanout)
    self.assertListEqual([3], compiled.reverseprogramfanout)

  @parameterized.parameters(
      (u'abc', 0, None),
      (b'abc', 0, None),
      (u'abc', 10, (b'abc', b'abc')),
      (b'abc', 10, (b'abc', b'abc')),
      (u'ab*c', 10, (b'ab', b'ac')),
      (b'ab*c', 10, (b'ab', b'ac')),
      (u'ab+c', 10, (b'abb', b'abc')),
      (b'ab+c', 10, (b'abb', b'abc')),
      (u'ab?c', 10, (b'abc', b'ac')),
      (b'ab?c', 10, (b'abc', b'ac')),
      (u'.*', 10, (b'', b'\xf4\xbf\xbf\xc0')),
      (b'.*', 10, None),
      (u'\\C*', 10, None),
      (b'\\C*', 10, None),
  )
  def test_possiblematchrange(self, pattern, maxlen, expected_min_max):
    """Checks possiblematchrange(maxlen) for `pattern`.

    Args:
      pattern: str or bytes; its string type selects the encoding.
      maxlen: the argument passed to possiblematchrange().
      expected_min_max: the expected (min, max) pair, or None when the call
        is expected to raise re2.error.
    """
    # To keep the table short, str patterns are compiled as UTF8 and bytes
    # patterns as LATIN1. Bytes with UTF8 would also be possible, but as per
    # the module docstring, str with LATIN1 isn't permitted.
    options = re2.Options()
    options.encoding = (
        re2.Options.Encoding.UTF8 if isinstance(pattern, str)
        else re2.Options.Encoding.LATIN1)
    compiled = re2.compile(pattern, options=options)
    if not expected_min_max:
      with self.assertRaisesRegex(re2.error, 'failed to compute match range'):
        compiled.possiblematchrange(maxlen)
      return
    self.assertEqual(expected_min_max, compiled.possiblematchrange(maxlen))
106
107
# One shared test case: a pattern, the text to run it on, the expected span
# of every group (or None when nothing is expected to match), and whether
# search(), match() and fullmatch() respectively are expected to succeed.
Params = collections.namedtuple(
    'Params',
    ['pattern', 'text', 'spans', 'search', 'match', 'fullmatch'],
)

# Each case appears twice: once with str arguments, once with bytes.
PARAMS = [
    Params(pattern=u'\\d+', text=u'Hello, world.', spans=None,
           search=False, match=False, fullmatch=False),
    Params(pattern=b'\\d+', text=b'Hello, world.', spans=None,
           search=False, match=False, fullmatch=False),
    Params(pattern=u'\\s+', text=u'Hello, world.', spans=[(6, 7)],
           search=True, match=False, fullmatch=False),
    Params(pattern=b'\\s+', text=b'Hello, world.', spans=[(6, 7)],
           search=True, match=False, fullmatch=False),
    Params(pattern=u'\\w+', text=u'Hello, world.', spans=[(0, 5)],
           search=True, match=True, fullmatch=False),
    Params(pattern=b'\\w+', text=b'Hello, world.', spans=[(0, 5)],
           search=True, match=True, fullmatch=False),
    Params(pattern=u'(\\d+)?', text=u'Hello, world.',
           spans=[(0, 0), (-1, -1)],
           search=True, match=True, fullmatch=False),
    Params(pattern=b'(\\d+)?', text=b'Hello, world.',
           spans=[(0, 0), (-1, -1)],
           search=True, match=True, fullmatch=False),
    Params(pattern=u'youtube(_device|_md|_gaia|_multiday|_multiday_gaia)?',
           text=u'youtube_ads', spans=[(0, 7), (-1, -1)],
           search=True, match=True, fullmatch=False),
    Params(pattern=b'youtube(_device|_md|_gaia|_multiday|_multiday_gaia)?',
           text=b'youtube_ads', spans=[(0, 7), (-1, -1)],
           search=True, match=True, fullmatch=False),
]
125
126
def upper(match):
  """Returns the whole matched text of `match` converted to upper case."""
  whole = match.group()
  return whole.upper()
129
130
class ReRegexpTest(parameterized.TestCase):
  """Module-level regexp tests, run against the module bound to MODULE.

  MODULE is `re` here. A subclass can rebind MODULE to run the same cases
  against another module that exposes the same functions.
  """

  MODULE = re

  def _assert_spans(self, match, expected_spans):
    """Checks the outcome of search(), match() or fullmatch().

    Args:
      match: the object returned by the function under test.
      expected_spans: None when no match is expected; otherwise the expected
        span of every group, starting with group 0.
    """
    if expected_spans is None:
      self.assertIsNone(match)
      return
    # Fail with a readable assertion instead of an AttributeError on None.
    self.assertIsNotNone(match)
    spans = [match.span(group) for group in range(match.re.groups + 1)]
    self.assertListEqual(expected_spans, spans)

  @parameterized.parameters((p.pattern,) for p in PARAMS)
  def test_pickle(self, pattern):
    """Round-trips a compiled pattern through pickle."""
    regexp = self.MODULE.compile(pattern)
    restored = pickle.loads(pickle.dumps(regexp))
    self.assertEqual(regexp.pattern, restored.pattern)

  @parameterized.parameters(
      (p.pattern, p.text, (p.spans if p.search else None)) for p in PARAMS)
  def test_search(self, pattern, text, expected_spans):
    """Checks MODULE.search() against the shared PARAMS table."""
    self._assert_spans(self.MODULE.search(pattern, text), expected_spans)

  def test_search_with_pos_and_endpos(self):
    """Searches every [pos, endpos) window of the text with '.+'."""
    regexp = self.MODULE.compile(u'.+')  # empty string NOT allowed
    text = u'I \u2665 RE2!'
    # Note that len(text) is the position of the empty string at the end of
    # text, so range() stops at len(text) + 1 in order to include len(text).
    for pos in range(len(text) + 1):
      for endpos in range(pos, len(text) + 1):
        match = regexp.search(text, pos=pos, endpos=endpos)
        if pos == endpos:
          self.assertIsNone(match)
        else:
          self.assertEqual(pos, match.pos)
          self.assertEqual(endpos, match.endpos)
          self.assertEqual(pos, match.start())
          self.assertEqual(endpos, match.end())
          self.assertTupleEqual((pos, endpos), match.span())

  def test_search_with_bogus_pos_and_endpos(self):
    """Checks that out-of-range pos/endpos are clamped to [0, len(text)]."""
    regexp = self.MODULE.compile(u'.*')  # empty string allowed
    text = u'I \u2665 RE2!'

    match = regexp.search(text, pos=-100)
    self.assertEqual(0, match.pos)
    match = regexp.search(text, pos=100)
    self.assertEqual(len(text), match.pos)

    match = regexp.search(text, endpos=-100)
    self.assertEqual(0, match.endpos)
    match = regexp.search(text, endpos=100)
    self.assertEqual(len(text), match.endpos)

    match = regexp.search(text, pos=100, endpos=-100)
    self.assertIsNone(match)

  @parameterized.parameters(
      (p.pattern, p.text, (p.spans if p.match else None)) for p in PARAMS)
  def test_match(self, pattern, text, expected_spans):
    """Checks MODULE.match() against the shared PARAMS table."""
    self._assert_spans(self.MODULE.match(pattern, text), expected_spans)

  @parameterized.parameters(
      (p.pattern, p.text, (p.spans if p.fullmatch else None)) for p in PARAMS)
  def test_fullmatch(self, pattern, text, expected_spans):
    """Checks MODULE.fullmatch() against the shared PARAMS table."""
    self._assert_spans(self.MODULE.fullmatch(pattern, text), expected_spans)

  @parameterized.parameters(
      (u'', u'', [(0, 0)]),
      (b'', b'', [(0, 0)]),
      (u'', u'x', [(0, 0), (1, 1)]),
      (b'', b'x', [(0, 0), (1, 1)]),
      (u'', u'xy', [(0, 0), (1, 1), (2, 2)]),
      (b'', b'xy', [(0, 0), (1, 1), (2, 2)]),
      (u'.', u'xy', [(0, 1), (1, 2)]),
      (b'.', b'xy', [(0, 1), (1, 2)]),
      (u'x', u'xy', [(0, 1)]),
      (b'x', b'xy', [(0, 1)]),
      (u'y', u'xy', [(1, 2)]),
      (b'y', b'xy', [(1, 2)]),
      (u'z', u'xy', []),
      (b'z', b'xy', []),
      (u'\\w*', u'Hello, world.', [(0, 5), (5, 5), (6, 6), (7, 12), (12, 12),
                                   (13, 13)]),
      (b'\\w*', b'Hello, world.', [(0, 5), (5, 5), (6, 6), (7, 12), (12, 12),
                                   (13, 13)]),
  )
  def test_finditer(self, pattern, text, expected_matches):
    """Compares the span of every finditer() match with the expected list."""
    matches = [match.span() for match in self.MODULE.finditer(pattern, text)]
    self.assertListEqual(expected_matches, matches)

  @parameterized.parameters(
      (u'\\w\\w+', u'Hello, world.', [u'Hello', u'world']),
      (b'\\w\\w+', b'Hello, world.', [b'Hello', b'world']),
      (u'(\\w)\\w+', u'Hello, world.', [u'H', u'w']),
      (b'(\\w)\\w+', b'Hello, world.', [b'H', b'w']),
      (u'(\\w)(\\w+)', u'Hello, world.', [(u'H', u'ello'), (u'w', u'orld')]),
      (b'(\\w)(\\w+)', b'Hello, world.', [(b'H', b'ello'), (b'w', b'orld')]),
      (u'(\\w)(\\w+)?', u'Hello, w.', [(u'H', u'ello'), (u'w', u'')]),
      (b'(\\w)(\\w+)?', b'Hello, w.', [(b'H', b'ello'), (b'w', b'')]),
  )
  def test_findall(self, pattern, text, expected_matches):
    """Compares the findall() result with the expected list."""
    matches = self.MODULE.findall(pattern, text)
    self.assertListEqual(expected_matches, matches)

  @parameterized.parameters(
      (u'\\W+', u'Hello, world.', -1, [u'Hello, world.']),
      (b'\\W+', b'Hello, world.', -1, [b'Hello, world.']),
      (u'\\W+', u'Hello, world.', 0, [u'Hello', u'world', u'']),
      (b'\\W+', b'Hello, world.', 0, [b'Hello', b'world', b'']),
      (u'\\W+', u'Hello, world.', 1, [u'Hello', u'world.']),
      (b'\\W+', b'Hello, world.', 1, [b'Hello', b'world.']),
      (u'(\\W+)', u'Hello, world.', -1, [u'Hello, world.']),
      (b'(\\W+)', b'Hello, world.', -1, [b'Hello, world.']),
      (u'(\\W+)', u'Hello, world.', 0, [u'Hello', u', ', u'world', u'.', u'']),
      (b'(\\W+)', b'Hello, world.', 0, [b'Hello', b', ', b'world', b'.', b'']),
      (u'(\\W+)', u'Hello, world.', 1, [u'Hello', u', ', u'world.']),
      (b'(\\W+)', b'Hello, world.', 1, [b'Hello', b', ', b'world.']),
  )
  def test_split(self, pattern, text, maxsplit, expected_pieces):
    """Compares the split() result with the expected pieces."""
    # maxsplit goes by keyword: Python 3.13 deprecates passing it to
    # re.split() positionally.
    pieces = self.MODULE.split(pattern, text, maxsplit=maxsplit)
    self.assertListEqual(expected_pieces, pieces)

  @parameterized.parameters(
      (u'\\w+', upper, u'Hello, world.', -1, u'Hello, world.', 0),
      (b'\\w+', upper, b'Hello, world.', -1, b'Hello, world.', 0),
      (u'\\w+', upper, u'Hello, world.', 0, u'HELLO, WORLD.', 2),
      (b'\\w+', upper, b'Hello, world.', 0, b'HELLO, WORLD.', 2),
      (u'\\w+', upper, u'Hello, world.', 1, u'HELLO, world.', 1),
      (b'\\w+', upper, b'Hello, world.', 1, b'HELLO, world.', 1),
      (u'\\w+', u'MEEP', u'Hello, world.', -1, u'Hello, world.', 0),
      (b'\\w+', b'MEEP', b'Hello, world.', -1, b'Hello, world.', 0),
      (u'\\w+', u'MEEP', u'Hello, world.', 0, u'MEEP, MEEP.', 2),
      (b'\\w+', b'MEEP', b'Hello, world.', 0, b'MEEP, MEEP.', 2),
      (u'\\w+', u'MEEP', u'Hello, world.', 1, u'MEEP, world.', 1),
      (b'\\w+', b'MEEP', b'Hello, world.', 1, b'MEEP, world.', 1),
      (u'\\\\', u'\\\\\\\\', u'Hello,\\world.', 0, u'Hello,\\\\world.', 1),
      (b'\\\\', b'\\\\\\\\', b'Hello,\\world.', 0, b'Hello,\\\\world.', 1),
  )
  def test_subn_sub(self, pattern, repl, text, count, expected_joined_pieces,
                    expected_numsplit):
    """Checks that subn() and sub() agree with the expected replacement."""
    # count goes by keyword: Python 3.13 deprecates passing it to re.subn()
    # and re.sub() positionally.
    joined_pieces, numsplit = self.MODULE.subn(
        pattern, repl, text, count=count)
    self.assertEqual(expected_joined_pieces, joined_pieces)
    self.assertEqual(expected_numsplit, numsplit)

    joined_pieces = self.MODULE.sub(pattern, repl, text, count=count)
    self.assertEqual(expected_joined_pieces, joined_pieces)
286
287
class Re2RegexpTest(ReRegexpTest):
  """Reruns the inherited ReRegexpTest cases with MODULE bound to re2.

  The methods defined here are additional cases for the re2 module only.
  """

  MODULE = re2

  def test_compile_with_latin1_encoding(self):
    """Expects a str pattern to be rejected when the encoding is LATIN1."""
    latin1 = re2.Options()
    latin1.encoding = re2.Options.Encoding.LATIN1
    mismatch = ('string type of pattern is str, '
                'but encoding specified in options is LATIN1')
    with self.assertRaisesRegex(re2.error, mismatch):
      re2.compile(u'.?', options=latin1)

    # ... whereas this is fine, of course.
    re2.compile(b'.?', options=latin1)

  @parameterized.parameters(
      (u'\\p{Lo}', u'\u0ca0_\u0ca0', [(0, 1), (2, 3)]),
      (b'\\p{Lo}', b'\xe0\xb2\xa0_\xe0\xb2\xa0', [(0, 3), (4, 7)]),
  )
  def test_finditer_with_utf8(self, pattern, text, expected_matches):
    """Compares finditer() spans for str text and for its UTF-8 bytes."""
    found = self.MODULE.finditer(pattern, text)
    spans = [each.span() for each in found]
    self.assertListEqual(expected_matches, spans)

  def test_purge(self):
    """Checks that purge() leaves the compile cache with no entries."""

    def cached_entries():
      return re2._Regexp._make.cache_info().currsize

    re2.compile('Goodbye, world.')
    self.assertGreater(cached_entries(), 0)
    re2.purge()
    self.assertEqual(cached_entries(), 0)
317
318
class Re2EscapeTest(parameterized.TestCase):
  """Escaping tests that are run against the re2 module alone.

  re2 and Python disagree on the escaping of some characters, so these
  cases make no attempt to verify consistency with the re module.
  """

  @parameterized.parameters(
      (u'a*b+c?', u'a\\*b\\+c\\?'),
      (b'a*b+c?', b'a\\*b\\+c\\?'),
  )
  def test_escape(self, pattern, expected_escaped):
    """Compares re2.escape(pattern) with the expected escaped form."""
    self.assertEqual(expected_escaped, re2.escape(pattern))
333
334
class ReMatchTest(parameterized.TestCase):
  """Match-object tests, run against the module bound to MODULE.

  MODULE is `re` here. A subclass can rebind MODULE to run the same cases
  against another module that exposes the same functions.
  """

  MODULE = re

  def test_expand(self):
    """Expands templates that use numeric and named group references."""
    match = self.MODULE.search(
        u'(?P<S>[\u2600-\u26ff]+).*?(?P<P>[^\\s\\w]+)', u'I \u2665 RE2!\n')

    # The same expansion, spelled three different ways.
    expanded = u'\u2665\n!'
    for template in (u'\\1\\n\\2', u'\\g<1>\\n\\g<2>', u'\\g<S>\\n\\g<P>'):
      self.assertEqual(expanded, match.expand(template))
    # Doubled backslashes stay literal and are not group references.
    self.assertEqual(u'\\1\\2\n\u2665!', match.expand(u'\\\\1\\\\2\\n\\1\\2'))

  def test_expand_with_octal(self):
    """Expands templates in which digits may form an octal escape."""
    match = self.MODULE.search(u'()()()()()()()()()(\\w+)', u'Hello, world.')

    # (expected expansion, template), checked in this order.
    cases = (
        (u'Hello\n', u'\\g<0>\\n'),
        (u'Hello\n', u'\\g<10>\\n'),
        (u'\x00\n', u'\\0\\n'),
        (u'\x00\n', u'\\00\\n'),
        (u'\x00\n', u'\\000\\n'),
        (u'\x000\n', u'\\0000\\n'),
        (u'\n', u'\\1\\n'),
        (u'Hello\n', u'\\10\\n'),
        (u'@\n', u'\\100\\n'),
        (u'@0\n', u'\\1000\\n'),
    )
    for expected, template in cases:
      self.assertEqual(expected, match.expand(template))

  def test_getitem_group_groups_groupdict(self):
    """Reads groups by index and by name through every accessor."""
    match = self.MODULE.search(
        u'(?P<S>[\u2600-\u26ff]+).*?(?P<P>[^\\s\\w]+)',
        u'Hello, world.\nI \u2665 RE2!\nGoodbye, world.\n')
    whole = u'\u2665 RE2!'
    heart = u'\u2665'
    bang = u'!'

    # Subscript access.
    self.assertEqual(whole, match[0])
    self.assertEqual(heart, match[1])
    self.assertEqual(bang, match[2])
    self.assertEqual(heart, match[u'S'])
    self.assertEqual(bang, match[u'P'])

    # group() with at most one argument.
    self.assertEqual(whole, match.group())
    self.assertEqual(whole, match.group(0))
    self.assertEqual(heart, match.group(1))
    self.assertEqual(bang, match.group(2))
    self.assertEqual(heart, match.group(u'S'))
    self.assertEqual(bang, match.group(u'P'))

    # Accessors that return several groups at once.
    self.assertTupleEqual((heart, bang), match.group(1, 2))
    self.assertTupleEqual((heart, bang), match.group(u'S', u'P'))
    self.assertTupleEqual((heart, bang), match.groups())
    self.assertDictEqual({u'S': heart, u'P': bang}, match.groupdict())

  def test_bogus_group_start_end_and_span(self):
    """Expects IndexError for group references that do not exist."""
    match = self.MODULE.search(
        u'(?P<S>[\u2600-\u26ff]+).*?(?P<P>[^\\s\\w]+)', u'I \u2665 RE2!\n')

    for bogus in (-1, 3, 'X'):
      self.assertRaises(IndexError, match.group, bogus)

    for accessor in (match.start, match.end, match.span):
      for bogus in (-1, 3):
        self.assertRaises(IndexError, accessor, bogus)

  @parameterized.parameters(
      (u'((a)(b))((c)(d))', u'foo bar qux', None, None),
      (u'(?P<one>(a)(b))((c)(d))', u'foo abcd qux', 4, None),
      (u'(?P<one>(a)(b))(?P<four>(c)(d))', u'foo abcd qux', 4, 'four'),
  )
  def test_lastindex_lastgroup(self, pattern, text, expected_lastindex,
                               expected_lastgroup):
    """Checks lastindex and lastgroup; no match is expected when lastindex is None."""
    match = self.MODULE.search(pattern, text)
    if expected_lastindex is None:
      self.assertIsNone(match)
      return
    self.assertEqual(expected_lastindex, match.lastindex)
    self.assertEqual(expected_lastgroup, match.lastgroup)
422
423
class Re2MatchTest(ReMatchTest):
  """Reruns the inherited ReMatchTest cases with MODULE bound to re2."""

  MODULE = re2
428
429
class SetTest(absltest.TestCase):
  """Tests the search, match and fullmatch flavours of re2.Set.

  Each test loads the same three patterns and matches 'Hello, world.'.
  """

  def _build(self, s):
    """Adds the shared patterns to `s`, compiles it and returns it.

    Args:
      s: a freshly constructed set (SearchSet, MatchSet or FullMatchSet).

    Returns:
      The same object, after Compile() has been called on it.
    """
    self.assertEqual(0, s.Add('\\d+'))
    self.assertEqual(1, s.Add('\\s+'))
    self.assertEqual(2, s.Add('\\w+'))
    # A pattern that cannot be parsed is expected to raise re2.error.
    self.assertRaises(re2.error, s.Add, '(MEEP')
    s.Compile()
    return s

  def test_search(self):
    """SearchSet: expects the '\\s+' and '\\w+' patterns to match."""
    s = self._build(re2.Set.SearchSet())
    # assertCountEqual compares the elements regardless of their order.
    self.assertCountEqual([1, 2], s.Match('Hello, world.'))

  def test_match(self):
    """MatchSet: expects only the '\\w+' pattern to match."""
    s = self._build(re2.Set.MatchSet())
    self.assertCountEqual([2], s.Match('Hello, world.'))

  def test_fullmatch(self):
    """FullMatchSet: expects Match() to return None for this text."""
    s = self._build(re2.Set.FullMatchSet())
    self.assertIsNone(s.Match('Hello, world.'))
458
459
class FilterTest(absltest.TestCase):
  """Tests for re2.Filter."""

  def test_match(self):
    """Adds three patterns, then checks Match() and re() on the filter."""
    f = re2.Filter()
    self.assertEqual(0, f.Add('Hello, \\w+\\.'))
    self.assertEqual(1, f.Add('\\w+, world\\.'))
    self.assertEqual(2, f.Add('Goodbye, \\w+\\.'))
    # A pattern that cannot be parsed is expected to raise re2.error.
    self.assertRaises(re2.error, f.Add, '(MEEP')
    f.Compile()
    # assertCountEqual compares the elements regardless of their order.
    self.assertCountEqual([0, 1], f.Match('Hello, world.', potential=True))
    self.assertCountEqual([0, 1], f.Match('HELLO, WORLD.', potential=True))
    self.assertCountEqual([0, 1], f.Match('Hello, world.'))
    self.assertIsNone(f.Match('HELLO, WORLD.'))

    self.assertRaises(IndexError, f.re, -1)
    self.assertRaises(IndexError, f.re, 3)
    self.assertEqual('Goodbye, \\w+\\.', f.re(2).pattern)
    # Verify whether the underlying RE2 object is usable.
    self.assertEqual(0, f.re(2).groups)

  def test_issue_484(self):
    """Expects re2.error when Match() is called before Compile()."""
    # Previously, the shim would dereference a null pointer and crash.
    f = re2.Filter()
    with self.assertRaisesRegex(re2.error,
                                r'Match\(\) called before compiling'):
      f.Match('')
486
487
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
490