xref: /aosp_15_r20/external/cronet/third_party/re2/src/python/re2.py (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1# Copyright 2019 The RE2 Authors.  All Rights Reserved.
2# Use of this source code is governed by a BSD-style
3# license that can be found in the LICENSE file.
4r"""A drop-in replacement for the re module.
5
6It uses RE2 under the hood, of course, so various PCRE features
7(e.g. backreferences, look-around assertions) are not supported.
8See https://github.com/google/re2/wiki/Syntax for the canonical
9reference, but known syntactic "gotchas" relative to Python are:
10
11  * PCRE supports \Z and \z; RE2 supports \z; Python supports \z,
12    but calls it \Z. You must rewrite \Z to \z in pattern strings.
13
14Known differences between this module's API and the re module's API:
15
16  * The error class does not provide any error information as attributes.
17  * The Options class replaces the re module's flags with RE2's options as
18    gettable/settable properties. Please see re2.h for their documentation.
19  * The pattern string and the input string do not have to be the same type.
20    Any str will be encoded to UTF-8.
21  * The pattern string cannot be str if the options specify Latin-1 encoding.
22
23This module's LRU cache contains a maximum of 128 regular expression objects.
24Each regular expression object's underlying RE2 object uses a maximum of 8MiB
25of memory (by default). Hence, this module's LRU cache uses a maximum of 1GiB
26of memory (by default), but in most cases, it should use much less than that.
27"""
28
29import codecs
30import functools
31import itertools
32
33import _re2
34
35
# C++ exceptions surface in Python through pybind11 as _re2.Error; errors
# raised by this module reuse that very class so callers catch one type.
error = _re2.Error
39
40
class Options(_re2.RE2.Options):
  """RE2 compilation options, used instead of the re module's flags.

  Every entry of NAMES is a gettable/settable property inherited from
  RE2::Options; see re2.h for what each one means.
  """

  __slots__ = ()

  # The order is significant: compile() and _Regexp.__setstate__() pack
  # option values into tuples in this order, and _Regexp._make() unpacks
  # them by zipping against NAMES.
  NAMES = (
      'max_mem',
      'encoding',
      'posix_syntax',
      'longest_match',
      'log_errors',
      'literal',
      'never_nl',
      'dot_nl',
      'never_capture',
      'case_sensitive',
      'perl_classes',
      'word_boundary',
      'one_line',
  )
60
61
def compile(pattern, options=None):
  """Return a compiled regular expression for pattern.

  Compiled objects are cached (see _Regexp._make), so equal patterns with
  equal options share one object.

  Args:
    pattern: a str or bytes pattern, or an already compiled _Regexp.
    options: an Options instance. Defaults to Options() for new patterns.
      Must not be given when pattern is already compiled.

  Raises:
    error: if options are given together with a compiled pattern, or if the
      pattern cannot be compiled.
  """
  if isinstance(pattern, _Regexp):
    if options:
      raise error('pattern is already compiled, so '
                  'options may not be specified')
    # Keep the compiled pattern's own options. Falling through to the
    # default Options() would silently reset e.g. case_sensitive=False.
    options = pattern.options
    pattern = pattern._pattern
  options = options or Options()
  # A tuple of plain option values is hashable and serves as the cache key.
  values = tuple(getattr(options, name) for name in Options.NAMES)
  return _Regexp._make(pattern, values)
71
72
def search(pattern, text, options=None):
  """Return the first match of pattern anywhere in text, or None."""
  regexp = compile(pattern, options=options)
  return regexp.search(text)
75
76
def match(pattern, text, options=None):
  """Return a match of pattern anchored at the start of text, or None."""
  regexp = compile(pattern, options=options)
  return regexp.match(text)
79
80
def fullmatch(pattern, text, options=None):
  """Return a match of pattern spanning the whole of text, or None."""
  regexp = compile(pattern, options=options)
  return regexp.fullmatch(text)
83
84
def finditer(pattern, text, options=None):
  """Return an iterator over the non-overlapping matches of pattern."""
  regexp = compile(pattern, options=options)
  return regexp.finditer(text)
87
88
def findall(pattern, text, options=None):
  """Return a list of the non-overlapping matches of pattern in text."""
  regexp = compile(pattern, options=options)
  return regexp.findall(text)
91
92
def split(pattern, text, maxsplit=0, options=None):
  """Split text around occurrences of pattern.

  A positive maxsplit limits the number of splits, 0 means no limit, and a
  negative value means no splitting at all.
  """
  regexp = compile(pattern, options=options)
  return regexp.split(text, maxsplit)
95
96
def subn(pattern, repl, text, count=0, options=None):
  """Return (new_text, number_of_substitutions) after replacing matches."""
  regexp = compile(pattern, options=options)
  return regexp.subn(repl, text, count)
99
100
def sub(pattern, repl, text, count=0, options=None):
  """Return text with (up to count) matches of pattern replaced by repl."""
  regexp = compile(pattern, options=options)
  return regexp.sub(repl, text, count)
103
104
def _encode(t):
  """Return the UTF-8 bytes of the str t."""
  utf8 = 'utf-8'
  return t.encode(utf8)
107
108
def _decode(b):
  """Return the str decoded from the UTF-8 bytes b."""
  utf8 = 'utf-8'
  return b.decode(utf8)
111
112
def escape(pattern):
  """Return pattern with all regular expression metacharacters quoted.

  bytes are quoted directly. A str is quoted through its UTF-8 encoding and
  the result is decoded back to str.
  """
  if not isinstance(pattern, str):
    return _re2.RE2.QuoteMeta(pattern)
  return _decode(_re2.RE2.QuoteMeta(_encode(pattern)))
122
123
def purge():
  """Empty the LRU cache of compiled regular expressions."""
  _Regexp._make.cache_clear()
126
127
# Anchoring modes understood by RE2: UNANCHORED (search), ANCHOR_START
# (match) and ANCHOR_BOTH (fullmatch).
_Anchor = _re2.RE2.Anchor
# Span reported when there is no match or a group did not participate.
_NULL_SPAN = (-1, -1)
130
131
class _Regexp(object):
  """A compiled regular expression, mirroring the re module's Pattern API.

  Instances are normally created by compile(), which caches them through
  _make(). str patterns are handed to RE2 as UTF-8; bytes patterns are
  handed over unchanged. For str text, every position and span exposed by
  this class is a str offset, not a byte offset.
  """

  __slots__ = ('_pattern', '_regexp')

  @classmethod
  @functools.lru_cache(typed=True)
  def _make(cls, pattern, values):
    """Return a (possibly cached) _Regexp for pattern.

    Args:
      pattern: str or bytes pattern.
      values: one option value per name in Options.NAMES, in the same order.
        It is part of the LRU cache key.
    """
    options = Options()
    for name, value in zip(Options.NAMES, values):
      setattr(options, name, value)
    return cls(pattern, options)

  def __init__(self, pattern, options):
    """Compile pattern with options.

    Raises:
      error: if pattern is str while options request Latin-1 encoding, or if
        RE2 fails to compile the pattern.
    """
    self._pattern = pattern
    if isinstance(self._pattern, str):
      if options.encoding == Options.Encoding.LATIN1:
        raise error('string type of pattern is str, but '
                    'encoding specified in options is LATIN1')
      encoded_pattern = _encode(self._pattern)
      self._regexp = _re2.RE2(encoded_pattern, options)
    else:
      self._regexp = _re2.RE2(self._pattern, options)
    if not self._regexp.ok():
      raise error(self._regexp.error())

  def __getstate__(self):
    # Pickle as (pattern, {option name: value}); the RE2 object is rebuilt
    # from these on unpickling.
    options = {name: getattr(self.options, name) for name in Options.NAMES}
    return self._pattern, options

  def __setstate__(self, state):
    # Rebuild through the cache so unpickled copies share compiled objects.
    pattern, options = state
    values = tuple(options[name] for name in Options.NAMES)
    other = _Regexp._make(pattern, values)
    self._pattern = other._pattern
    self._regexp = other._regexp

  def _match(self, anchor, text, pos=None, endpos=None):
    """Yield successive non-overlapping matches of this regexp in text.

    Args:
      anchor: an _Anchor value passed through to RE2.
      text: str or bytes to search.
      pos, endpos: optional bounds, clamped to [0, len(text)]. Nothing is
        yielded when pos > endpos.

    Yields:
      _Match objects, all reporting the caller's pos and endpos. After a
      non-empty match the search resumes at its end, so an empty match may
      follow it. After an empty match the search resumes one character
      (str) or byte (bytes) past it, so that match is never reported twice.
    """
    pos = 0 if pos is None else max(0, min(pos, len(text)))
    endpos = len(text) if endpos is None else max(0, min(endpos, len(text)))
    if pos > endpos:
      return
    if isinstance(text, str):
      encoded_text = _encode(text)
      encoded_pos = _re2.CharLenToBytes(encoded_text, 0, pos)
      if endpos == len(text):
        # This is the common case.
        encoded_endpos = len(encoded_text)
      else:
        encoded_endpos = encoded_pos + _re2.CharLenToBytes(
            encoded_text, encoded_pos, endpos - pos)
      decoded_offsets = {0: 0}
      last_offset = 0
      while True:
        spans = self._regexp.Match(anchor, encoded_text, encoded_pos,
                                   encoded_endpos)
        if spans[0] == _NULL_SPAN:
          break

        # This algorithm is linear in the length of encoded_text. Specifically,
        # no matter how many groups there are for a given regular expression or
        # how many iterations through the loop there are for a given generator,
        # this algorithm uses a single, straightforward pass over encoded_text.
        offsets = sorted(set(itertools.chain(*spans)))
        if offsets[0] == -1:
          offsets = offsets[1:]
        # Discard the rest of the items because they are useless now - and we
        # could accumulate one item per str offset in the pathological case!
        decoded_offsets = {last_offset: decoded_offsets[last_offset]}
        for offset in offsets:
          decoded_offsets[offset] = (
              decoded_offsets[last_offset] +
              _re2.BytesToCharLen(encoded_text, last_offset, offset))
          last_offset = offset

        def decode(span):
          if span == _NULL_SPAN:
            return span
          return decoded_offsets[span[0]], decoded_offsets[span[1]]

        decoded_spans = [decode(span) for span in spans]
        yield _Match(self, text, pos, endpos, decoded_spans)
        match_start, match_end = spans[0]
        if match_start != match_end:
          encoded_pos = match_end
        elif match_end == encoded_endpos:
          break
        else:
          # An empty match can lie beyond encoded_pos (e.g. '$' or a word
          # boundary), so step one character past the match itself, not past
          # encoded_pos; otherwise it would be found and yielded again.
          encoded_pos = match_end + _re2.CharLenToBytes(
              encoded_text, match_end, 1)
    else:
      # Use a separate cursor so that every _Match reports the caller's pos.
      search_pos = pos
      while True:
        spans = self._regexp.Match(anchor, text, search_pos, endpos)
        if spans[0] == _NULL_SPAN:
          break
        yield _Match(self, text, pos, endpos, spans)
        match_start, match_end = spans[0]
        if match_start != match_end:
          search_pos = match_end
        elif match_end == endpos:
          break
        else:
          # Step one byte past the empty match so it is not found again.
          search_pos = match_end + 1

  def search(self, text, pos=None, endpos=None):
    """Return the first match anywhere in text[pos:endpos], or None."""
    return next(self._match(_Anchor.UNANCHORED, text, pos, endpos), None)

  def match(self, text, pos=None, endpos=None):
    """Return a match anchored at pos, or None."""
    return next(self._match(_Anchor.ANCHOR_START, text, pos, endpos), None)

  def fullmatch(self, text, pos=None, endpos=None):
    """Return a match spanning exactly text[pos:endpos], or None."""
    return next(self._match(_Anchor.ANCHOR_BOTH, text, pos, endpos), None)

  def finditer(self, text, pos=None, endpos=None):
    """Return an iterator over all non-overlapping matches."""
    return self._match(_Anchor.UNANCHORED, text, pos, endpos)

  def findall(self, text, pos=None, endpos=None):
    """Return all non-overlapping matches as a list.

    Each item is the whole match when the regexp has no groups, the sole
    group when it has exactly one, and a tuple of all groups otherwise.
    Groups that did not participate contribute an empty str/bytes.
    """
    empty = type(text)()
    num_groups = self.groups
    items = []
    for match in self.finditer(text, pos, endpos):
      if not num_groups:
        item = match.group()
      elif num_groups == 1:
        item = match.groups(default=empty)[0]
      else:
        item = match.groups(default=empty)
      items.append(item)
    return items

  def _split(self, cb, text, maxsplit=0):
    """Split text around matches, splicing in cb(match) for each match.

    Returns:
      (pieces, number_of_matches_used). A negative maxsplit performs no
      splitting; 0 means unlimited.
    """
    if maxsplit < 0:
      return [text], 0
    elif maxsplit > 0:
      matchiter = itertools.islice(self.finditer(text), maxsplit)
    else:
      matchiter = self.finditer(text)
    pieces = []
    end = 0
    numsplit = 0
    for match in matchiter:
      pieces.append(text[end:match.start()])
      pieces.extend(cb(match))
      end = match.end()
      numsplit += 1
    pieces.append(text[end:])
    return pieces, numsplit

  def split(self, text, maxsplit=0):
    """Split text around matches, also returning the contents of any groups."""
    cb = lambda match: [match[group] for group in range(1, self.groups + 1)]
    pieces, _ = self._split(cb, text, maxsplit)
    return pieces

  def subn(self, repl, text, count=0):
    """Return (new_text, number_of_substitutions).

    repl is either a callable taking a match, or a template expanded with
    _Match.expand().
    """
    cb = lambda match: [repl(match) if callable(repl) else match.expand(repl)]
    empty = type(text)()
    pieces, numsplit = self._split(cb, text, count)
    joined_pieces = empty.join(pieces)
    return joined_pieces, numsplit

  def sub(self, repl, text, count=0):
    """Return text with (up to count) matches replaced by repl."""
    joined_pieces, _ = self.subn(repl, text, count)
    return joined_pieces

  @property
  def pattern(self):
    """The pattern as originally given (str or bytes)."""
    return self._pattern

  @property
  def options(self):
    """The options of the underlying RE2 object."""
    return self._regexp.options()

  @property
  def groups(self):
    """The number of capturing groups."""
    return self._regexp.NumberOfCapturingGroups()

  @property
  def groupindex(self):
    """Dict mapping group names to numbers (names are str for str patterns)."""
    groups = self._regexp.NamedCapturingGroups()
    if isinstance(self._pattern, str):
      decoded_groups = [(_decode(group), index) for group, index in groups]
      return dict(decoded_groups)
    else:
      return dict(groups)

  @property
  def programsize(self):
    return self._regexp.ProgramSize()

  @property
  def reverseprogramsize(self):
    return self._regexp.ReverseProgramSize()

  @property
  def programfanout(self):
    return self._regexp.ProgramFanout()

  @property
  def reverseprogramfanout(self):
    return self._regexp.ReverseProgramFanout()

  def possiblematchrange(self, maxlen):
    """Return the (min, max) pair computed by RE2::PossibleMatchRange.

    Raises:
      error: if RE2 cannot compute the range.
    """
    ok, min_match, max_match = self._regexp.PossibleMatchRange(maxlen)
    if not ok:
      raise error('failed to compute match range')
    return min_match, max_match
337
338
class _Match(object):
  """The result of a successful match, mirroring the re module's Match API.

  Spans are offsets into the searched text (str offsets for str text, byte
  offsets for bytes text); (-1, -1) marks a group that did not take part in
  the match.
  """

  __slots__ = ('_regexp', '_text', '_pos', '_endpos', '_spans')

  def __init__(self, regexp, text, pos, endpos, spans):
    self._regexp = regexp
    self._text = text
    self._pos = pos
    self._endpos = endpos
    self._spans = spans

  # Python prioritises three-digit octal numbers over group escapes.
  # For example, \100 should not be handled the same way as \g<10>0.
  _OCTAL_RE = compile('\\\\[0-7][0-7][0-7]')

  # Python supports \1 through \99 (inclusive) and \g<...> syntax.
  _GROUP_RE = compile('\\\\[1-9][0-9]?|\\\\g<\\w+>')

  @classmethod
  @functools.lru_cache(typed=True)
  def _split(cls, template):
    r"""Split template into alternating literal and group-reference pieces.

    Even indices hold literal text whose escapes are not yet processed; odd
    indices hold a group reference such as \1 or \g<name>. The result is
    cached, so callers must copy it before mutating it.
    """
    if isinstance(template, str):
      backslash = '\\'
    else:
      backslash = b'\\'
    empty = type(template)()
    pieces = [empty]
    index = template.find(backslash)
    while index != -1:
      piece, template = template[:index], template[index:]
      pieces[-1] += piece
      octal_match = cls._OCTAL_RE.match(template)
      group_match = cls._GROUP_RE.match(template)
      if (not octal_match) and group_match:
        index = group_match.end()
        piece, template = template[:index], template[index:]
        pieces.extend((piece, empty))
      else:
        # 2 isn't enough for \o, \x, \N, \u and \U escapes, but none of those
        # should contain backslashes, so break them here and then fix them at
        # the beginning of the next loop iteration or right before returning.
        index = 2
        piece, template = template[:index], template[index:]
        pieces[-1] += piece
      index = template.find(backslash)
    pieces[-1] += template
    return pieces

  @staticmethod
  def _unescape(piece):
    """Process the backslash escapes in a literal template piece."""
    if isinstance(piece, str):
      # Handing a str straight to unicode_escape_decode() makes it encode the
      # text as UTF-8 and then decode those bytes as Latin-1, garbling every
      # non-ASCII character. Encoding to Latin-1 first keeps such characters
      # intact; characters outside Latin-1 become \uXXXX-style escapes that
      # the codec turns back into the same characters.
      # NOTE: a backslash directly before a character outside Latin-1 is
      # still not preserved verbatim.
      unescaped, _ = codecs.unicode_escape_decode(
          piece.encode('latin-1', 'backslashreplace'))
    else:
      unescaped, _ = codecs.escape_decode(piece)
    return unescaped

  def expand(self, template):
    r"""Return template with its group references substituted.

    \1 through \99 and \g<name>/\g<number> are replaced by the corresponding
    group (empty if the group did not participate). Other backslash escapes
    are processed with the unicode_escape codec (str) or escape_decode
    (bytes).

    Raises:
      IndexError: if a referenced group does not exist.
    """
    empty = type(template)()
    # Make a copy so that we don't clobber the cached pieces!
    pieces = list(self._split(template))
    for index, piece in enumerate(pieces):
      if not index % 2:
        pieces[index] = self._unescape(piece)
      else:
        if len(piece) <= 3:  # \1 through \99 (inclusive)
          group = int(piece[1:])
        else:  # \g<...>
          group = piece[3:-1]
          try:
            group = int(group)
          except ValueError:
            pass
        pieces[index] = self.__getitem__(group) or empty
    return empty.join(pieces)

  def _group_number(self, group):
    """Map a group number or name to a validated group number.

    Raises:
      IndexError: for an unknown name or an out-of-range number.
    """
    if not isinstance(group, int):
      try:
        group = self._regexp.groupindex[group]
      except KeyError:
        raise IndexError('bad group name') from None
    if not 0 <= group <= self._regexp.groups:
      raise IndexError('bad group index')
    return group

  def __getitem__(self, group):
    """Return the text matched by group (number or name), or None."""
    span = self._spans[self._group_number(group)]
    if span == _NULL_SPAN:
      return None
    return self._text[span[0]:span[1]]

  def group(self, *groups):
    """Return one group (default 0), or a tuple when several are requested."""
    if not groups:
      groups = (0,)
    items = (self.__getitem__(group) for group in groups)
    return next(items) if len(groups) == 1 else tuple(items)

  def groups(self, default=None):
    """Return all capturing groups, using default for non-participants."""
    items = []
    for group in range(1, self._regexp.groups + 1):
      item = self.__getitem__(group)
      items.append(default if item is None else item)
    return tuple(items)

  def groupdict(self, default=None):
    """Return a dict of named groups, using default for non-participants."""
    items = []
    for group, index in self._regexp.groupindex.items():
      item = self.__getitem__(index)
      items.append((group, default) if item is None else (group, item))
    return dict(items)

  def start(self, group=0):
    """Return the start offset of group (number or name); -1 if unmatched."""
    return self._spans[self._group_number(group)][0]

  def end(self, group=0):
    """Return the end offset of group (number or name); -1 if unmatched."""
    return self._spans[self._group_number(group)][1]

  def span(self, group=0):
    """Return (start, end) of group (number or name); (-1, -1) if unmatched."""
    return self._spans[self._group_number(group)]

  @property
  def re(self):
    """The _Regexp that produced this match."""
    return self._regexp

  @property
  def string(self):
    """The text that was searched."""
    return self._text

  @property
  def pos(self):
    """The (clamped) pos passed to the search."""
    return self._pos

  @property
  def endpos(self):
    """The (clamped) endpos passed to the search."""
    return self._endpos

  @property
  def lastindex(self):
    """Number of the last matched group, or None if no group matched."""
    max_end = -1
    max_group = None
    # We look for the rightmost right parenthesis by keeping the first group
    # that ends at max_end because that is the leftmost/outermost group when
    # there are nested groups!
    for group in range(1, self._regexp.groups + 1):
      end = self._spans[group][1]
      if max_end < end:
        max_end = end
        max_group = group
    return max_group

  @property
  def lastgroup(self):
    """Name of the last matched group, or None if it is unnamed or absent."""
    max_group = self.lastindex
    if not max_group:
      return None
    for group, index in self._regexp.groupindex.items():
      if max_group == index:
        return group
    return None
498
499
class Set(object):
  """A Pythonic wrapper around RE2::Set.

  Construct one through SearchSet(), MatchSet() or FullMatchSet() to pick
  how patterns are anchored. Then Add() each pattern, Compile(), and call
  Match().
  """

  # A real one-element tuple: the former ('_set') was just the string '_set',
  # which only worked because __slots__ also accepts a bare string.
  __slots__ = ('_set',)

  def __init__(self, anchor, options=None):
    options = options or Options()
    self._set = _re2.Set(anchor, options)

  @classmethod
  def SearchSet(cls, options=None):
    """Return a Set whose patterns may match anywhere in the text."""
    return cls(_Anchor.UNANCHORED, options=options)

  @classmethod
  def MatchSet(cls, options=None):
    """Return a Set whose patterns are anchored at the start of the text."""
    return cls(_Anchor.ANCHOR_START, options=options)

  @classmethod
  def FullMatchSet(cls, options=None):
    """Return a Set whose patterns must match the whole text."""
    return cls(_Anchor.ANCHOR_BOTH, options=options)

  def Add(self, pattern):
    """Add pattern (str is encoded as UTF-8) and return its index.

    Raises:
      error: if the underlying set rejects the pattern.
    """
    if isinstance(pattern, str):
      encoded_pattern = _encode(pattern)
      index = self._set.Add(encoded_pattern)
    else:
      index = self._set.Add(pattern)
    if index == -1:
      raise error('failed to add %r to Set' % pattern)
    return index

  def Compile(self):
    """Compile the added patterns.

    Raises:
      error: if compilation fails.
    """
    if not self._set.Compile():
      raise error('failed to compile Set')

  def Match(self, text):
    """Return the underlying Set's matches for text, or None if it is empty."""
    if isinstance(text, str):
      encoded_text = _encode(text)
      matches = self._set.Match(encoded_text)
    else:
      matches = self._set.Match(text)
    return matches or None
542
543
class Filter(object):
  """A Pythonic wrapper around FilteredRE2.

  Each successfully added pattern is remembered, so that re() can hand back a
  _Regexp view of the corresponding compiled RE2 object.
  """

  __slots__ = ('_filter', '_patterns')

  def __init__(self):
    self._filter = _re2.Filter()
    self._patterns = []

  def Add(self, pattern, options=None):
    """Add pattern (str is encoded as UTF-8) and return its index.

    Raises:
      error: if the underlying filter rejects the pattern.
    """
    options = options or Options()
    raw_pattern = _encode(pattern) if isinstance(pattern, str) else pattern
    index = self._filter.Add(raw_pattern, options)
    if index == -1:
      raise error('failed to add %r to Filter' % pattern)
    self._patterns.append(pattern)
    return index

  def Compile(self):
    """Compile the filter.

    Raises:
      error: if compilation fails.
    """
    compiled = self._filter.Compile()
    if not compiled:
      raise error('failed to compile Filter')

  def Match(self, text, potential=False):
    """Return the underlying filter's matches for text, or None if empty."""
    raw_text = _encode(text) if isinstance(text, str) else text
    return self._filter.Match(raw_text, potential) or None

  def re(self, index):
    """Return a _Regexp backed by the filter's RE2 object for index.

    Raises:
      IndexError: if index does not refer to an added pattern.
    """
    if not 0 <= index < len(self._patterns):
      raise IndexError('bad index')
    # Bypass __init__: the RE2 object already exists inside the filter.
    view = object.__new__(_Regexp)
    view._pattern = self._patterns[index]
    view._regexp = self._filter.GetRE2(index)
    return view
584