xref: /aosp_15_r20/external/cronet/third_party/re2/src/re2/prog.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2007 The RE2 Authors.  All Rights Reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4 
5 // Compiled regular expression representation.
6 // Tested by compile_test.cc
7 
8 #include "re2/prog.h"
9 
10 #if defined(__AVX2__)
11 #include <immintrin.h>
12 #ifdef _MSC_VER
13 #include <intrin.h>
14 #endif
15 #endif
16 #include <stdint.h>
17 #include <string.h>
18 #include <algorithm>
19 #include <memory>
20 #include <utility>
21 
22 #include "absl/base/macros.h"
23 #include "absl/strings/str_format.h"
24 #include "util/logging.h"
25 #include "re2/bitmap256.h"
26 
27 namespace re2 {
28 
29 // Constructors per Inst opcode
30 
InitAlt(uint32_t out,uint32_t out1)31 void Prog::Inst::InitAlt(uint32_t out, uint32_t out1) {
32   DCHECK_EQ(out_opcode_, 0);
33   set_out_opcode(out, kInstAlt);
34   out1_ = out1;
35 }
36 
InitByteRange(int lo,int hi,int foldcase,uint32_t out)37 void Prog::Inst::InitByteRange(int lo, int hi, int foldcase, uint32_t out) {
38   DCHECK_EQ(out_opcode_, 0);
39   set_out_opcode(out, kInstByteRange);
40   lo_ = lo & 0xFF;
41   hi_ = hi & 0xFF;
42   hint_foldcase_ = foldcase&1;
43 }
44 
InitCapture(int cap,uint32_t out)45 void Prog::Inst::InitCapture(int cap, uint32_t out) {
46   DCHECK_EQ(out_opcode_, 0);
47   set_out_opcode(out, kInstCapture);
48   cap_ = cap;
49 }
50 
InitEmptyWidth(EmptyOp empty,uint32_t out)51 void Prog::Inst::InitEmptyWidth(EmptyOp empty, uint32_t out) {
52   DCHECK_EQ(out_opcode_, 0);
53   set_out_opcode(out, kInstEmptyWidth);
54   empty_ = empty;
55 }
56 
InitMatch(int32_t id)57 void Prog::Inst::InitMatch(int32_t id) {
58   DCHECK_EQ(out_opcode_, 0);
59   set_opcode(kInstMatch);
60   match_id_ = id;
61 }
62 
InitNop(uint32_t out)63 void Prog::Inst::InitNop(uint32_t out) {
64   DCHECK_EQ(out_opcode_, 0);
65   set_opcode(kInstNop);
66 }
67 
InitFail()68 void Prog::Inst::InitFail() {
69   DCHECK_EQ(out_opcode_, 0);
70   set_opcode(kInstFail);
71 }
72 
Dump()73 std::string Prog::Inst::Dump() {
74   switch (opcode()) {
75     default:
76       return absl::StrFormat("opcode %d", static_cast<int>(opcode()));
77 
78     case kInstAlt:
79       return absl::StrFormat("alt -> %d | %d", out(), out1_);
80 
81     case kInstAltMatch:
82       return absl::StrFormat("altmatch -> %d | %d", out(), out1_);
83 
84     case kInstByteRange:
85       return absl::StrFormat("byte%s [%02x-%02x] %d -> %d",
86                              foldcase() ? "/i" : "",
87                              lo_, hi_, hint(), out());
88 
89     case kInstCapture:
90       return absl::StrFormat("capture %d -> %d", cap_, out());
91 
92     case kInstEmptyWidth:
93       return absl::StrFormat("emptywidth %#x -> %d",
94                              static_cast<int>(empty_), out());
95 
96     case kInstMatch:
97       return absl::StrFormat("match! %d", match_id());
98 
99     case kInstNop:
100       return absl::StrFormat("nop -> %d", out());
101 
102     case kInstFail:
103       return absl::StrFormat("fail");
104   }
105 }
106 
Prog()107 Prog::Prog()
108   : anchor_start_(false),
109     anchor_end_(false),
110     reversed_(false),
111     did_flatten_(false),
112     did_onepass_(false),
113     start_(0),
114     start_unanchored_(0),
115     size_(0),
116     bytemap_range_(0),
117     prefix_foldcase_(false),
118     prefix_size_(0),
119     list_count_(0),
120     bit_state_text_max_size_(0),
121     dfa_mem_(0),
122     dfa_first_(NULL),
123     dfa_longest_(NULL) {
124 }
125 
~Prog()126 Prog::~Prog() {
127   DeleteDFA(dfa_longest_);
128   DeleteDFA(dfa_first_);
129   if (prefix_foldcase_)
130     delete[] prefix_dfa_;
131 }
132 
133 typedef SparseSet Workq;
134 
AddToQueue(Workq * q,int id)135 static inline void AddToQueue(Workq* q, int id) {
136   if (id != 0)
137     q->insert(id);
138 }
139 
ProgToString(Prog * prog,Workq * q)140 static std::string ProgToString(Prog* prog, Workq* q) {
141   std::string s;
142   for (Workq::iterator i = q->begin(); i != q->end(); ++i) {
143     int id = *i;
144     Prog::Inst* ip = prog->inst(id);
145     s += absl::StrFormat("%d. %s\n", id, ip->Dump());
146     AddToQueue(q, ip->out());
147     if (ip->opcode() == kInstAlt || ip->opcode() == kInstAltMatch)
148       AddToQueue(q, ip->out1());
149   }
150   return s;
151 }
152 
FlattenedProgToString(Prog * prog,int start)153 static std::string FlattenedProgToString(Prog* prog, int start) {
154   std::string s;
155   for (int id = start; id < prog->size(); id++) {
156     Prog::Inst* ip = prog->inst(id);
157     if (ip->last())
158       s += absl::StrFormat("%d. %s\n", id, ip->Dump());
159     else
160       s += absl::StrFormat("%d+ %s\n", id, ip->Dump());
161   }
162   return s;
163 }
164 
Dump()165 std::string Prog::Dump() {
166   if (did_flatten_)
167     return FlattenedProgToString(this, start_);
168 
169   Workq q(size_);
170   AddToQueue(&q, start_);
171   return ProgToString(this, &q);
172 }
173 
DumpUnanchored()174 std::string Prog::DumpUnanchored() {
175   if (did_flatten_)
176     return FlattenedProgToString(this, start_unanchored_);
177 
178   Workq q(size_);
179   AddToQueue(&q, start_unanchored_);
180   return ProgToString(this, &q);
181 }
182 
DumpByteMap()183 std::string Prog::DumpByteMap() {
184   std::string map;
185   for (int c = 0; c < 256; c++) {
186     int b = bytemap_[c];
187     int lo = c;
188     while (c < 256-1 && bytemap_[c+1] == b)
189       c++;
190     int hi = c;
191     map += absl::StrFormat("[%02x-%02x] -> %d\n", lo, hi, b);
192   }
193   return map;
194 }
195 
196 // Is ip a guaranteed match at end of text, perhaps after some capturing?
IsMatch(Prog * prog,Prog::Inst * ip)197 static bool IsMatch(Prog* prog, Prog::Inst* ip) {
198   for (;;) {
199     switch (ip->opcode()) {
200       default:
201         LOG(DFATAL) << "Unexpected opcode in IsMatch: " << ip->opcode();
202         return false;
203 
204       case kInstAlt:
205       case kInstAltMatch:
206       case kInstByteRange:
207       case kInstFail:
208       case kInstEmptyWidth:
209         return false;
210 
211       case kInstCapture:
212       case kInstNop:
213         ip = prog->inst(ip->out());
214         break;
215 
216       case kInstMatch:
217         return true;
218     }
219   }
220 }
221 
222 // Peep-hole optimizer.
Optimize()223 void Prog::Optimize() {
224   Workq q(size_);
225 
226   // Eliminate nops.  Most are taken out during compilation
227   // but a few are hard to avoid.
228   q.clear();
229   AddToQueue(&q, start_);
230   for (Workq::iterator i = q.begin(); i != q.end(); ++i) {
231     int id = *i;
232 
233     Inst* ip = inst(id);
234     int j = ip->out();
235     Inst* jp;
236     while (j != 0 && (jp=inst(j))->opcode() == kInstNop) {
237       j = jp->out();
238     }
239     ip->set_out(j);
240     AddToQueue(&q, ip->out());
241 
242     if (ip->opcode() == kInstAlt) {
243       j = ip->out1();
244       while (j != 0 && (jp=inst(j))->opcode() == kInstNop) {
245         j = jp->out();
246       }
247       ip->out1_ = j;
248       AddToQueue(&q, ip->out1());
249     }
250   }
251 
252   // Insert kInstAltMatch instructions
253   // Look for
254   //   ip: Alt -> j | k
255   //	  j: ByteRange [00-FF] -> ip
256   //    k: Match
257   // or the reverse (the above is the greedy one).
258   // Rewrite Alt to AltMatch.
259   q.clear();
260   AddToQueue(&q, start_);
261   for (Workq::iterator i = q.begin(); i != q.end(); ++i) {
262     int id = *i;
263     Inst* ip = inst(id);
264     AddToQueue(&q, ip->out());
265     if (ip->opcode() == kInstAlt)
266       AddToQueue(&q, ip->out1());
267 
268     if (ip->opcode() == kInstAlt) {
269       Inst* j = inst(ip->out());
270       Inst* k = inst(ip->out1());
271       if (j->opcode() == kInstByteRange && j->out() == id &&
272           j->lo() == 0x00 && j->hi() == 0xFF &&
273           IsMatch(this, k)) {
274         ip->set_opcode(kInstAltMatch);
275         continue;
276       }
277       if (IsMatch(this, j) &&
278           k->opcode() == kInstByteRange && k->out() == id &&
279           k->lo() == 0x00 && k->hi() == 0xFF) {
280         ip->set_opcode(kInstAltMatch);
281       }
282     }
283   }
284 }
285 
EmptyFlags(absl::string_view text,const char * p)286 uint32_t Prog::EmptyFlags(absl::string_view text, const char* p) {
287   int flags = 0;
288 
289   // ^ and \A
290   if (p == text.data())
291     flags |= kEmptyBeginText | kEmptyBeginLine;
292   else if (p[-1] == '\n')
293     flags |= kEmptyBeginLine;
294 
295   // $ and \z
296   if (p == text.data() + text.size())
297     flags |= kEmptyEndText | kEmptyEndLine;
298   else if (p < text.data() + text.size() && p[0] == '\n')
299     flags |= kEmptyEndLine;
300 
301   // \b and \B
302   if (p == text.data() && p == text.data() + text.size()) {
303     // no word boundary here
304   } else if (p == text.data()) {
305     if (IsWordChar(p[0]))
306       flags |= kEmptyWordBoundary;
307   } else if (p == text.data() + text.size()) {
308     if (IsWordChar(p[-1]))
309       flags |= kEmptyWordBoundary;
310   } else {
311     if (IsWordChar(p[-1]) != IsWordChar(p[0]))
312       flags |= kEmptyWordBoundary;
313   }
314   if (!(flags & kEmptyWordBoundary))
315     flags |= kEmptyNonWordBoundary;
316 
317   return flags;
318 }
319 
320 // ByteMapBuilder implements a coloring algorithm.
321 //
322 // The first phase is a series of "mark and merge" batches: we mark one or more
323 // [lo-hi] ranges, then merge them into our internal state. Batching is not for
324 // performance; rather, it means that the ranges are treated indistinguishably.
325 //
326 // Internally, the ranges are represented using a bitmap that stores the splits
327 // and a vector that stores the colors; both of them are indexed by the ranges'
328 // last bytes. Thus, in order to merge a [lo-hi] range, we split at lo-1 and at
329 // hi (if not already split), then recolor each range in between. The color map
330 // (i.e. from the old color to the new color) is maintained for the lifetime of
331 // the batch and so underpins this somewhat obscure approach to set operations.
332 //
333 // The second phase builds the bytemap from our internal state: we recolor each
334 // range, then store the new color (which is now the byte class) in each of the
335 // corresponding array elements. Finally, we output the number of byte classes.
336 class ByteMapBuilder {
337  public:
ByteMapBuilder()338   ByteMapBuilder() {
339     // Initial state: the [0-255] range has color 256.
340     // This will avoid problems during the second phase,
341     // in which we assign byte classes numbered from 0.
342     splits_.Set(255);
343     colors_[255] = 256;
344     nextcolor_ = 257;
345   }
346 
347   void Mark(int lo, int hi);
348   void Merge();
349   void Build(uint8_t* bytemap, int* bytemap_range);
350 
351  private:
352   int Recolor(int oldcolor);
353 
354   Bitmap256 splits_;
355   int colors_[256];
356   int nextcolor_;
357   std::vector<std::pair<int, int>> colormap_;
358   std::vector<std::pair<int, int>> ranges_;
359 
360   ByteMapBuilder(const ByteMapBuilder&) = delete;
361   ByteMapBuilder& operator=(const ByteMapBuilder&) = delete;
362 };
363 
Mark(int lo,int hi)364 void ByteMapBuilder::Mark(int lo, int hi) {
365   DCHECK_GE(lo, 0);
366   DCHECK_GE(hi, 0);
367   DCHECK_LE(lo, 255);
368   DCHECK_LE(hi, 255);
369   DCHECK_LE(lo, hi);
370 
371   // Ignore any [0-255] ranges. They cause us to recolor every range, which
372   // has no effect on the eventual result and is therefore a waste of time.
373   if (lo == 0 && hi == 255)
374     return;
375 
376   ranges_.emplace_back(lo, hi);
377 }
378 
Merge()379 void ByteMapBuilder::Merge() {
380   for (std::vector<std::pair<int, int>>::const_iterator it = ranges_.begin();
381        it != ranges_.end();
382        ++it) {
383     int lo = it->first-1;
384     int hi = it->second;
385 
386     if (0 <= lo && !splits_.Test(lo)) {
387       splits_.Set(lo);
388       int next = splits_.FindNextSetBit(lo+1);
389       colors_[lo] = colors_[next];
390     }
391     if (!splits_.Test(hi)) {
392       splits_.Set(hi);
393       int next = splits_.FindNextSetBit(hi+1);
394       colors_[hi] = colors_[next];
395     }
396 
397     int c = lo+1;
398     while (c < 256) {
399       int next = splits_.FindNextSetBit(c);
400       colors_[next] = Recolor(colors_[next]);
401       if (next == hi)
402         break;
403       c = next+1;
404     }
405   }
406   colormap_.clear();
407   ranges_.clear();
408 }
409 
Build(uint8_t * bytemap,int * bytemap_range)410 void ByteMapBuilder::Build(uint8_t* bytemap, int* bytemap_range) {
411   // Assign byte classes numbered from 0.
412   nextcolor_ = 0;
413 
414   int c = 0;
415   while (c < 256) {
416     int next = splits_.FindNextSetBit(c);
417     uint8_t b = static_cast<uint8_t>(Recolor(colors_[next]));
418     while (c <= next) {
419       bytemap[c] = b;
420       c++;
421     }
422   }
423 
424   *bytemap_range = nextcolor_;
425 }
426 
Recolor(int oldcolor)427 int ByteMapBuilder::Recolor(int oldcolor) {
428   // Yes, this is a linear search. There can be at most 256
429   // colors and there will typically be far fewer than that.
430   // Also, we need to consider keys *and* values in order to
431   // avoid recoloring a given range more than once per batch.
432   std::vector<std::pair<int, int>>::const_iterator it =
433       std::find_if(colormap_.begin(), colormap_.end(),
434                    [=](const std::pair<int, int>& kv) -> bool {
435                      return kv.first == oldcolor || kv.second == oldcolor;
436                    });
437   if (it != colormap_.end())
438     return it->second;
439   int newcolor = nextcolor_;
440   nextcolor_++;
441   colormap_.emplace_back(oldcolor, newcolor);
442   return newcolor;
443 }
444 
ComputeByteMap()445 void Prog::ComputeByteMap() {
446   // Fill in bytemap with byte classes for the program.
447   // Ranges of bytes that are treated indistinguishably
448   // will be mapped to a single byte class.
449   ByteMapBuilder builder;
450 
451   // Don't repeat the work for ^ and $.
452   bool marked_line_boundaries = false;
453   // Don't repeat the work for \b and \B.
454   bool marked_word_boundaries = false;
455 
456   for (int id = 0; id < size(); id++) {
457     Inst* ip = inst(id);
458     if (ip->opcode() == kInstByteRange) {
459       int lo = ip->lo();
460       int hi = ip->hi();
461       builder.Mark(lo, hi);
462       if (ip->foldcase() && lo <= 'z' && hi >= 'a') {
463         int foldlo = lo;
464         int foldhi = hi;
465         if (foldlo < 'a')
466           foldlo = 'a';
467         if (foldhi > 'z')
468           foldhi = 'z';
469         if (foldlo <= foldhi) {
470           foldlo += 'A' - 'a';
471           foldhi += 'A' - 'a';
472           builder.Mark(foldlo, foldhi);
473         }
474       }
475       // If this Inst is not the last Inst in its list AND the next Inst is
476       // also a ByteRange AND the Insts have the same out, defer the merge.
477       if (!ip->last() &&
478           inst(id+1)->opcode() == kInstByteRange &&
479           ip->out() == inst(id+1)->out())
480         continue;
481       builder.Merge();
482     } else if (ip->opcode() == kInstEmptyWidth) {
483       if (ip->empty() & (kEmptyBeginLine|kEmptyEndLine) &&
484           !marked_line_boundaries) {
485         builder.Mark('\n', '\n');
486         builder.Merge();
487         marked_line_boundaries = true;
488       }
489       if (ip->empty() & (kEmptyWordBoundary|kEmptyNonWordBoundary) &&
490           !marked_word_boundaries) {
491         // We require two batches here: the first for ranges that are word
492         // characters, the second for ranges that are not word characters.
493         for (bool isword : {true, false}) {
494           int j;
495           for (int i = 0; i < 256; i = j) {
496             for (j = i + 1; j < 256 &&
497                             Prog::IsWordChar(static_cast<uint8_t>(i)) ==
498                                 Prog::IsWordChar(static_cast<uint8_t>(j));
499                  j++)
500               ;
501             if (Prog::IsWordChar(static_cast<uint8_t>(i)) == isword)
502               builder.Mark(i, j - 1);
503           }
504           builder.Merge();
505         }
506         marked_word_boundaries = true;
507       }
508     }
509   }
510 
511   builder.Build(bytemap_, &bytemap_range_);
512 
513   if ((0)) {  // For debugging, use trivial bytemap.
514     LOG(ERROR) << "Using trivial bytemap.";
515     for (int i = 0; i < 256; i++)
516       bytemap_[i] = static_cast<uint8_t>(i);
517     bytemap_range_ = 256;
518   }
519 }
520 
521 // Prog::Flatten() implements a graph rewriting algorithm.
522 //
523 // The overall process is similar to epsilon removal, but retains some epsilon
524 // transitions: those from Capture and EmptyWidth instructions; and those from
525 // nullable subexpressions. (The latter avoids quadratic blowup in transitions
526 // in the worst case.) It might be best thought of as Alt instruction elision.
527 //
528 // In conceptual terms, it divides the Prog into "trees" of instructions, then
529 // traverses the "trees" in order to produce "lists" of instructions. A "tree"
530 // is one or more instructions that grow from one "root" instruction to one or
531 // more "leaf" instructions; if a "tree" has exactly one instruction, then the
532 // "root" is also the "leaf". In most cases, a "root" is the successor of some
533 // "leaf" (i.e. the "leaf" instruction's out() returns the "root" instruction)
534 // and is considered a "successor root". A "leaf" can be a ByteRange, Capture,
535 // EmptyWidth or Match instruction. However, this is insufficient for handling
536 // nested nullable subexpressions correctly, so in some cases, a "root" is the
537 // dominator of the instructions reachable from some "successor root" (i.e. it
538 // has an unreachable predecessor) and is considered a "dominator root". Since
539 // only Alt instructions can be "dominator roots" (other instructions would be
540 // "leaves"), only Alt instructions are required to be marked as predecessors.
541 //
542 // Dividing the Prog into "trees" comprises two passes: marking the "successor
543 // roots" and the predecessors; and marking the "dominator roots". Sorting the
544 // "successor roots" by their bytecode offsets enables iteration in order from
545 // greatest to least during the second pass; by working backwards in this case
546 // and flooding the graph no further than "leaves" and already marked "roots",
547 // it becomes possible to mark "dominator roots" without doing excessive work.
548 //
549 // Traversing the "trees" is just iterating over the "roots" in order of their
550 // marking and flooding the graph no further than "leaves" and "roots". When a
551 // "leaf" is reached, the instruction is copied with its successor remapped to
552 // its "root" number. When a "root" is reached, a Nop instruction is generated
553 // with its successor remapped similarly. As each "list" is produced, its last
554 // instruction is marked as such. After all of the "lists" have been produced,
555 // a pass over their instructions remaps their successors to bytecode offsets.
Flatten()556 void Prog::Flatten() {
557   if (did_flatten_)
558     return;
559   did_flatten_ = true;
560 
561   // Scratch structures. It's important that these are reused by functions
562   // that we call in loops because they would thrash the heap otherwise.
563   SparseSet reachable(size());
564   std::vector<int> stk;
565   stk.reserve(size());
566 
567   // First pass: Marks "successor roots" and predecessors.
568   // Builds the mapping from inst-ids to root-ids.
569   SparseArray<int> rootmap(size());
570   SparseArray<int> predmap(size());
571   std::vector<std::vector<int>> predvec;
572   MarkSuccessors(&rootmap, &predmap, &predvec, &reachable, &stk);
573 
574   // Second pass: Marks "dominator roots".
575   SparseArray<int> sorted(rootmap);
576   std::sort(sorted.begin(), sorted.end(), sorted.less);
577   for (SparseArray<int>::const_iterator i = sorted.end() - 1;
578        i != sorted.begin();
579        --i) {
580     if (i->index() != start_unanchored() && i->index() != start())
581       MarkDominator(i->index(), &rootmap, &predmap, &predvec, &reachable, &stk);
582   }
583 
584   // Third pass: Emits "lists". Remaps outs to root-ids.
585   // Builds the mapping from root-ids to flat-ids.
586   std::vector<int> flatmap(rootmap.size());
587   std::vector<Inst> flat;
588   flat.reserve(size());
589   for (SparseArray<int>::const_iterator i = rootmap.begin();
590        i != rootmap.end();
591        ++i) {
592     flatmap[i->value()] = static_cast<int>(flat.size());
593     EmitList(i->index(), &rootmap, &flat, &reachable, &stk);
594     flat.back().set_last();
595     // We have the bounds of the "list", so this is the
596     // most convenient point at which to compute hints.
597     ComputeHints(&flat, flatmap[i->value()], static_cast<int>(flat.size()));
598   }
599 
600   list_count_ = static_cast<int>(flatmap.size());
601   for (int i = 0; i < kNumInst; i++)
602     inst_count_[i] = 0;
603 
604   // Fourth pass: Remaps outs to flat-ids.
605   // Counts instructions by opcode.
606   for (int id = 0; id < static_cast<int>(flat.size()); id++) {
607     Inst* ip = &flat[id];
608     if (ip->opcode() != kInstAltMatch)  // handled in EmitList()
609       ip->set_out(flatmap[ip->out()]);
610     inst_count_[ip->opcode()]++;
611   }
612 
613 #if !defined(NDEBUG)
614   // Address a `-Wunused-but-set-variable' warning from Clang 13.x.
615   size_t total = 0;
616   for (int i = 0; i < kNumInst; i++)
617     total += inst_count_[i];
618   CHECK_EQ(total, flat.size());
619 #endif
620 
621   // Remap start_unanchored and start.
622   if (start_unanchored() == 0) {
623     DCHECK_EQ(start(), 0);
624   } else if (start_unanchored() == start()) {
625     set_start_unanchored(flatmap[1]);
626     set_start(flatmap[1]);
627   } else {
628     set_start_unanchored(flatmap[1]);
629     set_start(flatmap[2]);
630   }
631 
632   // Finally, replace the old instructions with the new instructions.
633   size_ = static_cast<int>(flat.size());
634   inst_ = PODArray<Inst>(size_);
635   memmove(inst_.data(), flat.data(), size_*sizeof inst_[0]);
636 
637   // Populate the list heads for BitState.
638   // 512 instructions limits the memory footprint to 1KiB.
639   if (size_ <= 512) {
640     list_heads_ = PODArray<uint16_t>(size_);
641     // 0xFF makes it more obvious if we try to look up a non-head.
642     memset(list_heads_.data(), 0xFF, size_*sizeof list_heads_[0]);
643     for (int i = 0; i < list_count_; ++i)
644       list_heads_[flatmap[i]] = i;
645   }
646 
647   // BitState allocates a bitmap of size list_count_ * (text.size()+1)
648   // for tracking pairs of possibilities that it has already explored.
649   const size_t kBitStateBitmapMaxSize = 256*1024;  // max size in bits
650   bit_state_text_max_size_ = kBitStateBitmapMaxSize / list_count_ - 1;
651 }
652 
MarkSuccessors(SparseArray<int> * rootmap,SparseArray<int> * predmap,std::vector<std::vector<int>> * predvec,SparseSet * reachable,std::vector<int> * stk)653 void Prog::MarkSuccessors(SparseArray<int>* rootmap,
654                           SparseArray<int>* predmap,
655                           std::vector<std::vector<int>>* predvec,
656                           SparseSet* reachable, std::vector<int>* stk) {
657   // Mark the kInstFail instruction.
658   rootmap->set_new(0, rootmap->size());
659 
660   // Mark the start_unanchored and start instructions.
661   if (!rootmap->has_index(start_unanchored()))
662     rootmap->set_new(start_unanchored(), rootmap->size());
663   if (!rootmap->has_index(start()))
664     rootmap->set_new(start(), rootmap->size());
665 
666   reachable->clear();
667   stk->clear();
668   stk->push_back(start_unanchored());
669   while (!stk->empty()) {
670     int id = stk->back();
671     stk->pop_back();
672   Loop:
673     if (reachable->contains(id))
674       continue;
675     reachable->insert_new(id);
676 
677     Inst* ip = inst(id);
678     switch (ip->opcode()) {
679       default:
680         LOG(DFATAL) << "unhandled opcode: " << ip->opcode();
681         break;
682 
683       case kInstAltMatch:
684       case kInstAlt:
685         // Mark this instruction as a predecessor of each out.
686         for (int out : {ip->out(), ip->out1()}) {
687           if (!predmap->has_index(out)) {
688             predmap->set_new(out, static_cast<int>(predvec->size()));
689             predvec->emplace_back();
690           }
691           (*predvec)[predmap->get_existing(out)].emplace_back(id);
692         }
693         stk->push_back(ip->out1());
694         id = ip->out();
695         goto Loop;
696 
697       case kInstByteRange:
698       case kInstCapture:
699       case kInstEmptyWidth:
700         // Mark the out of this instruction as a "root".
701         if (!rootmap->has_index(ip->out()))
702           rootmap->set_new(ip->out(), rootmap->size());
703         id = ip->out();
704         goto Loop;
705 
706       case kInstNop:
707         id = ip->out();
708         goto Loop;
709 
710       case kInstMatch:
711       case kInstFail:
712         break;
713     }
714   }
715 }
716 
MarkDominator(int root,SparseArray<int> * rootmap,SparseArray<int> * predmap,std::vector<std::vector<int>> * predvec,SparseSet * reachable,std::vector<int> * stk)717 void Prog::MarkDominator(int root, SparseArray<int>* rootmap,
718                          SparseArray<int>* predmap,
719                          std::vector<std::vector<int>>* predvec,
720                          SparseSet* reachable, std::vector<int>* stk) {
721   reachable->clear();
722   stk->clear();
723   stk->push_back(root);
724   while (!stk->empty()) {
725     int id = stk->back();
726     stk->pop_back();
727   Loop:
728     if (reachable->contains(id))
729       continue;
730     reachable->insert_new(id);
731 
732     if (id != root && rootmap->has_index(id)) {
733       // We reached another "tree" via epsilon transition.
734       continue;
735     }
736 
737     Inst* ip = inst(id);
738     switch (ip->opcode()) {
739       default:
740         LOG(DFATAL) << "unhandled opcode: " << ip->opcode();
741         break;
742 
743       case kInstAltMatch:
744       case kInstAlt:
745         stk->push_back(ip->out1());
746         id = ip->out();
747         goto Loop;
748 
749       case kInstByteRange:
750       case kInstCapture:
751       case kInstEmptyWidth:
752         break;
753 
754       case kInstNop:
755         id = ip->out();
756         goto Loop;
757 
758       case kInstMatch:
759       case kInstFail:
760         break;
761     }
762   }
763 
764   for (SparseSet::const_iterator i = reachable->begin();
765        i != reachable->end();
766        ++i) {
767     int id = *i;
768     if (predmap->has_index(id)) {
769       for (int pred : (*predvec)[predmap->get_existing(id)]) {
770         if (!reachable->contains(pred)) {
771           // id has a predecessor that cannot be reached from root!
772           // Therefore, id must be a "root" too - mark it as such.
773           if (!rootmap->has_index(id))
774             rootmap->set_new(id, rootmap->size());
775         }
776       }
777     }
778   }
779 }
780 
EmitList(int root,SparseArray<int> * rootmap,std::vector<Inst> * flat,SparseSet * reachable,std::vector<int> * stk)781 void Prog::EmitList(int root, SparseArray<int>* rootmap,
782                     std::vector<Inst>* flat,
783                     SparseSet* reachable, std::vector<int>* stk) {
784   reachable->clear();
785   stk->clear();
786   stk->push_back(root);
787   while (!stk->empty()) {
788     int id = stk->back();
789     stk->pop_back();
790   Loop:
791     if (reachable->contains(id))
792       continue;
793     reachable->insert_new(id);
794 
795     if (id != root && rootmap->has_index(id)) {
796       // We reached another "tree" via epsilon transition. Emit a kInstNop
797       // instruction so that the Prog does not become quadratically larger.
798       flat->emplace_back();
799       flat->back().set_opcode(kInstNop);
800       flat->back().set_out(rootmap->get_existing(id));
801       continue;
802     }
803 
804     Inst* ip = inst(id);
805     switch (ip->opcode()) {
806       default:
807         LOG(DFATAL) << "unhandled opcode: " << ip->opcode();
808         break;
809 
810       case kInstAltMatch:
811         flat->emplace_back();
812         flat->back().set_opcode(kInstAltMatch);
813         flat->back().set_out(static_cast<int>(flat->size()));
814         flat->back().out1_ = static_cast<uint32_t>(flat->size())+1;
815         ABSL_FALLTHROUGH_INTENDED;
816 
817       case kInstAlt:
818         stk->push_back(ip->out1());
819         id = ip->out();
820         goto Loop;
821 
822       case kInstByteRange:
823       case kInstCapture:
824       case kInstEmptyWidth:
825         flat->emplace_back();
826         memmove(&flat->back(), ip, sizeof *ip);
827         flat->back().set_out(rootmap->get_existing(ip->out()));
828         break;
829 
830       case kInstNop:
831         id = ip->out();
832         goto Loop;
833 
834       case kInstMatch:
835       case kInstFail:
836         flat->emplace_back();
837         memmove(&flat->back(), ip, sizeof *ip);
838         break;
839     }
840   }
841 }
842 
843 // For each ByteRange instruction in [begin, end), computes a hint to execution
844 // engines: the delta to the next instruction (in flat) worth exploring iff the
845 // current instruction matched.
846 //
847 // Implements a coloring algorithm related to ByteMapBuilder, but in this case,
848 // colors are instructions and recoloring ranges precisely identifies conflicts
849 // between instructions. Iterating backwards over [begin, end) is guaranteed to
850 // identify the nearest conflict (if any) with only linear complexity.
void Prog::ComputeHints(std::vector<Inst>* flat, int begin, int end) {
  // Partition of the byte space [0-255] into intervals: a set bit at b in
  // `splits` means an interval ends at byte value b. colors[b] holds the id
  // of the nearest conflicting instruction for the interval ending at b;
  // colors[] is only meaningful at positions where `splits` has a set bit.
  Bitmap256 splits;
  int colors[256];

  // True iff splits/colors currently reflect byte ranges seen so far and
  // must be reset at the next barrier.
  bool dirty = false;
  for (int id = end; id >= begin; --id) {
    if (id == end ||
        (*flat)[id].opcode() != kInstByteRange) {
      // A non-ByteRange instruction (or the sentinel position `end`) acts as
      // a barrier: reset the partition to a single interval [0-255] colored
      // with this id.
      if (dirty) {
        dirty = false;
        splits.Clear();
      }
      splits.Set(255);
      colors[255] = id;
      // At this point, the [0-255] range is colored with id.
      // Thus, hints cannot point beyond id; and if id == end,
      // hints that would have pointed to id will be 0 instead.
      continue;
    }
    dirty = true;

    // We recolor the [lo-hi] range with id. Note that first ratchets backwards
    // from end to the nearest conflict (if any) during recoloring.
    int first = end;
    auto Recolor = [&](int lo, int hi) {
      // Like ByteMapBuilder, we split at lo-1 and at hi.
      --lo;

      // Materialize a split at lo (if in range) and at hi, inheriting the
      // color of the enclosing interval so the partition stays consistent.
      if (0 <= lo && !splits.Test(lo)) {
        splits.Set(lo);
        int next = splits.FindNextSetBit(lo+1);
        colors[lo] = colors[next];
      }
      if (!splits.Test(hi)) {
        splits.Set(hi);
        int next = splits.FindNextSetBit(hi+1);
        colors[hi] = colors[next];
      }

      // Walk the intervals covering [lo+1, hi], remembering the nearest
      // conflict seen so far and recoloring each interval with id.
      int c = lo+1;
      while (c < 256) {
        int next = splits.FindNextSetBit(c);
        // Ratchet backwards...
        first = std::min(first, colors[next]);
        // Recolor with id - because it's the new nearest conflict!
        colors[next] = id;
        if (next == hi)
          break;
        c = next+1;
      }
    };

    Inst* ip = &(*flat)[id];
    int lo = ip->lo();
    int hi = ip->hi();
    Recolor(lo, hi);
    // For case-insensitive ranges, also recolor the uppercase image of the
    // lowercase letters within [lo-hi].
    if (ip->foldcase() && lo <= 'z' && hi >= 'a') {
      int foldlo = lo;
      int foldhi = hi;
      if (foldlo < 'a')
        foldlo = 'a';
      if (foldhi > 'z')
        foldhi = 'z';
      if (foldlo <= foldhi) {
        foldlo += 'A' - 'a';
        foldhi += 'A' - 'a';
        Recolor(foldlo, foldhi);
      }
    }

    // Store the hint - the delta to the nearest conflict, capped at 32767 so
    // it fits in 15 bits - in the upper bits of hint_foldcase_; bit 0 remains
    // the foldcase flag.
    if (first != end) {
      uint16_t hint = static_cast<uint16_t>(std::min(first - id, 32767));
      ip->hint_foldcase_ |= hint<<1;
    }
  }
}
927 
// The final state will always be this, which frees up a register for the hot
// loop and thus avoids the spilling that can occur when building with Clang.
static const size_t kShiftDFAFinal = 9;

// Builds the Shift DFA transition table for `prefix`. The prefix is taken by
// value (not const reference) deliberately: this function clobbers it, so a
// temporary is convenient.
static uint64_t* BuildShiftDFA(std::string prefix) {
  // Remember the original length before the prefix gets clobbered below.
  const size_t len = prefix.size();

  // Build the NFA reachability table, indexed by input byte: bit i+1 is set
  // for the byte at position i of the prefix. Given a bitfield of current
  // NFA states, the successor set is - for this specific purpose - always
  // ((curr << 1) | 1), so intersecting that with the reachability bitfield
  // steps the NFA over one input byte.
  // Credits for this technique: the Hyperscan paper by Geoff Langdale et al.
  uint16_t reach[256]{};
  for (size_t i = 0; i < len; ++i)
    reach[static_cast<uint8_t>(prefix[i])] |= 1 << (i+1);
  // Bit 0 is the `\C*?` self-loop that makes the search unanchored.
  for (int b = 0; b < 256; ++b)
    reach[b] |= 1;

  // Maps each DFA state to its NFA state bitfield; the reverse mapping is a
  // plain linear search when transitions are recorded below. The "Shift DFA"
  // technique caps us at ten states in a uint64_t, so with one slot reserved
  // for the initial state, at most nine bytes of the prefix are used - which
  // is also why uint16_t suffices for the NFA bitfields.
  uint16_t nfa_of[kShiftDFAFinal+1]{};
  nfa_of[0] = 1;
  for (size_t curr = 0; curr < len; ++curr) {
    uint8_t b = prefix[curr];
    uint16_t stepped = reach[b] & ((nfa_of[curr] << 1) | 1);
    size_t succ = curr+1;
    if (succ == len)
      succ = kShiftDFAFinal;
    nfa_of[succ] = stepped;
  }

  // Deduplicate the bytes of the prefix so that transition recording handles
  // each distinct byte exactly once. This clobbers the prefix, but it is no
  // longer needed at this point.
  std::sort(prefix.begin(), prefix.end());
  prefix.erase(std::unique(prefix.begin(), prefix.end()), prefix.end());

  // Build the DFA table, indexed by input byte. Each element is effectively
  // a packed array of 6-bit values, each already multiplied by six so that
  // the hot loop needs no extra multiply, mask or shift.
  // Credits for this technique: "Shift-based DFAs" on GitHub by Per Vognsen.
  uint64_t* dfa = new uint64_t[256]{};
  // Record a transition out of every state for every distinct prefix byte;
  // all other input bytes fall back to the initial state because the table
  // is zero-initialized.
  for (size_t curr = 0; curr < len; ++curr) {
    for (uint8_t b : prefix) {
      uint16_t stepped = reach[b] & ((nfa_of[curr] << 1) | 1);
      // Reverse-map the NFA bitfield to its DFA state by linear search.
      size_t succ = 0;
      while (nfa_of[succ] != stepped)
        ++succ;
      dfa[b] |= static_cast<uint64_t>(succ * 6) << (curr * 6);
      // Also record the transition for the uppercase form of any ASCII
      // letter. ASCII letters are guaranteed to be lowercase at this point
      // because that's how the parser normalises them. #FunFact: 'k' and 's'
      // match U+212A and U+017F, respectively, so they won't occur here when
      // using UTF-8 encoding because the parser will emit character classes.
      if ('a' <= b && b <= 'z')
        dfa[b - ('a' - 'A')] |= static_cast<uint64_t>(succ * 6) << (curr * 6);
    }
  }
  // Make the final state self-looping ("saturate"), which matters for
  // performance: the hot loop checks for a match only at the end of each
  // iteration, so the match signal must persist until it gets checked.
  for (int b = 0; b < 256; ++b)
    dfa[b] |= static_cast<uint64_t>(kShiftDFAFinal * 6) << (kShiftDFAFinal * 6);

  return dfa;
}
1012 
ConfigurePrefixAccel(const std::string & prefix,bool prefix_foldcase)1013 void Prog::ConfigurePrefixAccel(const std::string& prefix,
1014                                 bool prefix_foldcase) {
1015   prefix_foldcase_ = prefix_foldcase;
1016   prefix_size_ = prefix.size();
1017   if (prefix_foldcase_) {
1018     // Use PrefixAccel_ShiftDFA().
1019     // ... and no more than nine bytes of the prefix. (See above for details.)
1020     prefix_size_ = std::min(prefix_size_, kShiftDFAFinal);
1021     prefix_dfa_ = BuildShiftDFA(prefix.substr(0, prefix_size_));
1022   } else if (prefix_size_ != 1) {
1023     // Use PrefixAccel_FrontAndBack().
1024     prefix_front_ = prefix.front();
1025     prefix_back_ = prefix.back();
1026   } else {
1027     // Use memchr(3).
1028     prefix_front_ = prefix.front();
1029   }
1030 }
1031 
// Runs the Shift DFA over [data, data+size); returns a pointer to the first
// byte of the first (case-folded) occurrence of the prefix, or NULL.
const void* Prog::PrefixAccel_ShiftDFA(const void* data, size_t size) {
  if (size < prefix_size_)
    return NULL;

  // The DFA state, premultiplied by six (see BuildShiftDFA()); only its low
  // six bits are meaningful, as the shift amount for the next transition.
  uint64_t curr = 0;

  // At the time of writing, rough benchmarks on a Broadwell machine showed
  // that this unroll factor (i.e. eight) achieves a speedup factor of two.
  if (size >= 8) {
    const uint8_t* p = reinterpret_cast<const uint8_t*>(data);
    const uint8_t* endp = p + (size&~7);
    do {
      // Load eight bytes and their transition words up front; the loads are
      // independent of each other, so only the shift chain below is serial.
      uint8_t b0 = p[0];
      uint8_t b1 = p[1];
      uint8_t b2 = p[2];
      uint8_t b3 = p[3];
      uint8_t b4 = p[4];
      uint8_t b5 = p[5];
      uint8_t b6 = p[6];
      uint8_t b7 = p[7];

      uint64_t next0 = prefix_dfa_[b0];
      uint64_t next1 = prefix_dfa_[b1];
      uint64_t next2 = prefix_dfa_[b2];
      uint64_t next3 = prefix_dfa_[b3];
      uint64_t next4 = prefix_dfa_[b4];
      uint64_t next5 = prefix_dfa_[b5];
      uint64_t next6 = prefix_dfa_[b6];
      uint64_t next7 = prefix_dfa_[b7];

      // Eight chained transitions: state(k) = next(k) >> (state(k-1) & 63).
      uint64_t curr0 = next0 >> (curr  & 63);
      uint64_t curr1 = next1 >> (curr0 & 63);
      uint64_t curr2 = next2 >> (curr1 & 63);
      uint64_t curr3 = next3 >> (curr2 & 63);
      uint64_t curr4 = next4 >> (curr3 & 63);
      uint64_t curr5 = next5 >> (curr4 & 63);
      uint64_t curr6 = next6 >> (curr5 & 63);
      uint64_t curr7 = next7 >> (curr6 & 63);

      // The final state saturates (see BuildShiftDFA()), so if any of the
      // eight steps reached it, curr7 still signals the match here.
      if ((curr7 & 63) == kShiftDFAFinal * 6) {
        // At the time of writing, using the same masking subexpressions from
        // the preceding lines caused Clang to clutter the hot loop computing
        // them - even though they aren't actually needed for shifting! Hence
        // these rewritten conditions, which achieve a speedup factor of two.
        // ((curr7-currK) & 63) == 0 iff currK's low six bits equal curr7's,
        // i.e. iff currK had already reached the final state. The match ends
        // just after p[k], so it began at p+k+1-prefix_size_.
        if (((curr7-curr0) & 63) == 0) return p+1-prefix_size_;
        if (((curr7-curr1) & 63) == 0) return p+2-prefix_size_;
        if (((curr7-curr2) & 63) == 0) return p+3-prefix_size_;
        if (((curr7-curr3) & 63) == 0) return p+4-prefix_size_;
        if (((curr7-curr4) & 63) == 0) return p+5-prefix_size_;
        if (((curr7-curr5) & 63) == 0) return p+6-prefix_size_;
        if (((curr7-curr6) & 63) == 0) return p+7-prefix_size_;
        if (((curr7-curr7) & 63) == 0) return p+8-prefix_size_;
      }

      curr = curr7;
      p += 8;
    } while (p != endp);
    // Fall through to the byte-at-a-time loop for the remaining 0-7 bytes,
    // carrying the DFA state across the boundary.
    data = p;
    size = size&7;
  }

  const uint8_t* p = reinterpret_cast<const uint8_t*>(data);
  const uint8_t* endp = p + size;
  while (p != endp) {
    uint8_t b = *p++;
    uint64_t next = prefix_dfa_[b];
    curr = next >> (curr & 63);
    if ((curr & 63) == kShiftDFAFinal * 6)
      return p-prefix_size_;
  }
  return NULL;
}
1104 
#if defined(__AVX2__)
// Finds the least significant non-zero bit in n.
static int FindLSBSet(uint32_t n) {
  DCHECK_NE(n, 0);
#if defined(__GNUC__)
  // GCC/Clang: count-trailing-zeros builtin.
  return __builtin_ctz(n);
#elif defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))
  // MSVC on x86/x64: BSF intrinsic.
  unsigned long c;
  _BitScanForward(&c, n);
  return static_cast<int>(c);
#else
  // Portable fallback: binary search for the lowest set bit by shifting it
  // toward the top bit and counting how far it travelled.
  int pos = 31;
  for (int step = 16; step != 0; step >>= 1) {
    uint32_t shifted = n << step;
    if (shifted != 0) {
      n = shifted;
      pos -= step;
    }
  }
  return pos;
#endif
}
#endif
1128 
// Finds the first position in [data, data+size) whose byte equals
// prefix_front_ AND whose byte prefix_size_-1 later equals prefix_back_;
// returns NULL if there is no such candidate position.
const void* Prog::PrefixAccel_FrontAndBack(const void* data, size_t size) {
  DCHECK_GE(prefix_size_, 2);
  if (size < prefix_size_)
    return NULL;
  // Don't bother searching the last prefix_size_-1 bytes for prefix_front_.
  // This also means that probing for prefix_back_ doesn't go out of bounds.
  size -= prefix_size_-1;

#if defined(__AVX2__)
  // Use AVX2 to look for prefix_front_ and prefix_back_ 32 bytes at a time.
  if (size >= sizeof(__m256i)) {
    // fp scans for the front byte; bp scans prefix_size_-1 bytes ahead for
    // the back byte, so equal lane positions correspond to one candidate.
    const __m256i* fp = reinterpret_cast<const __m256i*>(
        reinterpret_cast<const char*>(data));
    const __m256i* bp = reinterpret_cast<const __m256i*>(
        reinterpret_cast<const char*>(data) + prefix_size_-1);
    const __m256i* endfp = fp + size/sizeof(__m256i);
    const __m256i f_set1 = _mm256_set1_epi8(prefix_front_);
    const __m256i b_set1 = _mm256_set1_epi8(prefix_back_);
    do {
      const __m256i f_loadu = _mm256_loadu_si256(fp++);
      const __m256i b_loadu = _mm256_loadu_si256(bp++);
      const __m256i f_cmpeq = _mm256_cmpeq_epi8(f_set1, f_loadu);
      const __m256i b_cmpeq = _mm256_cmpeq_epi8(b_set1, b_loadu);
      // testz of the two comparison masks is an early exit when no lane has
      // both the front and the back byte matching.
      const int fb_testz = _mm256_testz_si256(f_cmpeq, b_cmpeq);
      if (fb_testz == 0) {  // ZF: 1 means zero, 0 means non-zero.
        // AND the masks and take the lowest set lane: the first candidate.
        const __m256i fb_and = _mm256_and_si256(f_cmpeq, b_cmpeq);
        const int fb_movemask = _mm256_movemask_epi8(fb_and);
        const int fb_ctz = FindLSBSet(fb_movemask);
        // fp was post-incremented, so fp-1 is the block just examined.
        return reinterpret_cast<const char*>(fp-1) + fb_ctz;
      }
    } while (fp != endfp);
    // Fall through to the scalar loop for the remaining size%32 bytes.
    data = fp;
    size = size%sizeof(__m256i);
  }
#endif

  // Scalar fallback: memchr(3) for the front byte, then probe the byte at
  // offset prefix_size_-1 for the back byte; repeat until a hit or NULL.
  const char* p0 = reinterpret_cast<const char*>(data);
  for (const char* p = p0;; p++) {
    DCHECK_GE(size, static_cast<size_t>(p-p0));
    p = reinterpret_cast<const char*>(memchr(p, prefix_front_, size - (p-p0)));
    if (p == NULL || p[prefix_size_-1] == prefix_back_)
      return p;
  }
}
1173 
1174 }  // namespace re2
1175