xref: /aosp_15_r20/external/zstd/contrib/match_finders/zstd_edist.c (revision 01826a4963a0d8a59bc3812d29bdf0fb76416722)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under both the BSD-style license (found in the
6  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7  * in the COPYING file in the root directory of this source tree).
8  * You may select, at your option, one of the above-listed licenses.
9  */
10 
11 /*-*************************************
12 *  Dependencies
13 ***************************************/
14 
15 /* Currently relies on qsort when combining contiguous matches. This can probably
16  * be avoided but would require changes to the algorithm. The qsort is far from
17  * the bottleneck in this algorithm even for medium sized files so it's probably
18  * not worth trying to address */
19 #include <stdlib.h>
20 #include <assert.h>
21 
22 #include "zstd_edist.h"
23 #include "mem.h"
24 
25 /*-*************************************
26 *  Constants
27 ***************************************/
28 
/* Sentinel ("infinity") written into backward-diagonal entries that have
 * not been computed yet. Also seeds the minimum search of the
 * 'too expensive' heuristic. Fully parenthesized so the macro behaves as
 * a single operand wherever it is expanded. */
#define ZSTD_EDIST_DIAG_MAX ((S32)(1 << 30))

/* How large should a snake be to be considered a 'big' snake.
 * For an explanation of what a 'snake' is with respect to the
 * edit distance matrix, see the linked paper in zstd_edist.h */
#define ZSTD_EDIST_SNAKE_THRESH 20

/* After how many iterations should we start to use the heuristic
 * based on 'big' snakes */
#define ZSTD_EDIST_SNAKE_ITER_THRESH 200

/* After how many iterations should we just give up and take
 * the best available edit script for this round */
#define ZSTD_EDIST_EXPENSIVE_THRESH 1024
44 
45 /*-*************************************
46 *  Structures
47 ***************************************/
48 
/* One match between the dictionary and the source. Matches are first
 * recorded with matchLength == 1 and later merged into longer runs. */
typedef struct {
    U32 dictIdx;      /* Start position of the match in the dictionary */
    U32 srcIdx;       /* Start position of the match in the source */
    U32 matchLength;  /* Number of consecutive matching bytes */
} ZSTD_eDist_match;
54 
/* Working state shared by all steps of one ZSTD_eDist_genSequences() call */
typedef struct {
    const BYTE* dict;            /* Dictionary bytes */
    const BYTE* src;             /* Source bytes */
    size_t dictSize;             /* Size of 'dict' in bytes */
    size_t srcSize;              /* Size of 'src' in bytes */
    S32* forwardDiag;            /* Entries of the forward diagonal stored here */
    S32* backwardDiag;           /* Entries of the backward diagonal stored here.
                                  *   Note: this buffer and the 'forwardDiag' buffer
                                  *   are contiguous. See ZSTD_eDist_genSequences */
    ZSTD_eDist_match* matches;   /* Accumulate matches of length 1 in this buffer.
                                  *   In a subsequent post-processing step, we combine
                                  *   contiguous matches. */
    U32 nbMatches;               /* Number of valid entries in 'matches' */
} ZSTD_eDist_state;
69 
/* Result of ZSTD_eDist_diag(): the point at which a region is split into
 * a low and a high sub-problem, and how each half should be searched. */
typedef struct {
    S32 dictMid;           /* Split position in the dictionary */
    S32 srcMid;            /* Split position in the source */
    int lowUseHeuristics;  /* Should we use heuristics for the low part */
    int highUseHeuristics; /* Should we use heuristics for the high part */
} ZSTD_eDist_partition;
76 
77 /*-*************************************
78 *  Internal
79 ***************************************/
80 
/* ZSTD_eDist_diag() :
 * Picks the point (partition->dictMid, partition->srcMid) at which
 * ZSTD_eDist_compare() splits the region dict[dictLow, dictHigh) x
 * src[srcLow, srcHigh) into a low and a high sub-problem.
 *
 * A diagonal is identified by (dictIdx - srcIdx). forwardDiag[d] holds the
 * furthest dictIdx reached on diagonal d by the search growing from
 * (dictLow, srcLow); backwardDiag[d] holds the lowest dictIdx reached on
 * diagonal d by the search growing back from (dictHigh, srcHigh). Both
 * searches advance by one step per iteration until they overlap on a
 * diagonal or, when useHeuristics is non-zero, until one of the heuristics
 * below picks a split early.
 *
 * All four fields of 'partition' are written before returning. */
static void ZSTD_eDist_diag(ZSTD_eDist_state* state,
                    ZSTD_eDist_partition* partition,
                    S32 dictLow, S32 dictHigh, S32 srcLow,
                    S32 srcHigh, int useHeuristics)
{
    S32* const forwardDiag = state->forwardDiag;
    S32* const backwardDiag = state->backwardDiag;
    const BYTE* const dict = state->dict;
    const BYTE* const src = state->src;

    /* Range of diagonals that intersect this region */
    S32 const diagMin = dictLow - srcHigh;
    S32 const diagMax = dictHigh - srcLow;
    /* Diagonals on which the forward / backward searches start */
    S32 const forwardMid = dictLow - srcLow;
    S32 const backwardMid = dictHigh - srcHigh;

    S32 forwardMin = forwardMid;
    S32 forwardMax = forwardMid;
    S32 backwardMin = backwardMid;
    S32 backwardMax = backwardMid;
    /* Decides which of the two searches is the one that checks for overlap */
    int odd = (forwardMid - backwardMid) & 1;
    U32 iterations;

    forwardDiag[forwardMid] = dictLow;
    backwardDiag[backwardMid] = dictHigh;

    /* Main loop for updating diag entries. Unless useHeuristics is
     * set, this loop runs until the forward and backward searches
     * overlap, which yields the minimal edit script */
    for (iterations = 1;;iterations++) {
        S32 diag;
        int bigSnake = 0;

        /* Widen the forward range by one diagonal on each side (marking the
         * new outer neighbour as 'not reached'), or shrink it once the edge
         * of the region has been hit */
        if (forwardMin > diagMin) {
            forwardMin--;
            forwardDiag[forwardMin - 1] = -1;
        } else {
            forwardMin++;
        }

        if (forwardMax < diagMax) {
            forwardMax++;
            forwardDiag[forwardMax + 1] = -1;
        } else {
            forwardMax--;
        }

        for (diag = forwardMax; diag >= forwardMin; diag -= 2) {
            S32 dictIdx;
            S32 srcIdx;
            S32 low = forwardDiag[diag - 1];
            S32 high = forwardDiag[diag + 1];
            /* Start from whichever neighbouring diagonal got further */
            S32 dictIdx0 = low < high ? high : low + 1;

            /* Follow the snake: advance while the bytes are equal */
            for (dictIdx = dictIdx0, srcIdx = dictIdx0 - diag;
                dictIdx < dictHigh && srcIdx < srcHigh && dict[dictIdx] == src[srcIdx];
                dictIdx++, srcIdx++) continue;

            if (dictIdx - dictIdx0 > ZSTD_EDIST_SNAKE_THRESH)
                bigSnake = 1;

            forwardDiag[diag] = dictIdx;

            if (odd && backwardMin <= diag && diag <= backwardMax && backwardDiag[diag] <= dictIdx) {
                partition->dictMid = dictIdx;
                partition->srcMid = srcIdx;
                partition->lowUseHeuristics = 0;
                partition->highUseHeuristics = 0;
                return;
            }
        }

        /* Same widening / shrinking for the backward range, using
         * ZSTD_EDIST_DIAG_MAX as the 'not reached' marker */
        if (backwardMin > diagMin) {
            backwardMin--;
            backwardDiag[backwardMin - 1] = ZSTD_EDIST_DIAG_MAX;
        } else {
            backwardMin++;
        }

        if (backwardMax < diagMax) {
            backwardMax++;
            backwardDiag[backwardMax + 1] = ZSTD_EDIST_DIAG_MAX;
        } else {
            backwardMax--;
        }


        for (diag = backwardMax; diag >= backwardMin; diag -= 2) {
            S32 dictIdx;
            S32 srcIdx;
            S32 low = backwardDiag[diag - 1];
            S32 high = backwardDiag[diag + 1];
            S32 dictIdx0 = low < high ? low : high - 1;

            /* Follow the snake backwards while the preceding bytes are equal */
            for (dictIdx = dictIdx0, srcIdx = dictIdx0 - diag;
                dictLow < dictIdx && srcLow < srcIdx && dict[dictIdx - 1] == src[srcIdx - 1];
                dictIdx--, srcIdx--) continue;

            if (dictIdx0 - dictIdx > ZSTD_EDIST_SNAKE_THRESH)
                bigSnake = 1;

            backwardDiag[diag] = dictIdx;

            if (!odd && forwardMin <= diag && diag <= forwardMax && dictIdx <= forwardDiag[diag]) {
                partition->dictMid = dictIdx;
                partition->srcMid = srcIdx;
                partition->lowUseHeuristics = 0;
                partition->highUseHeuristics = 0;
                return;
            }
        }

        if (!useHeuristics)
            continue;

        /* Everything under this point is a heuristic. Using these will
         * substantially speed up the match finding. In some cases, taking
         * the total match finding time from several minutes to seconds.
         * Of course, the caveat is that the edit script found may no longer
         * be optimal */

        /* Big snake heuristic */
        if (iterations > ZSTD_EDIST_SNAKE_ITER_THRESH && bigSnake) {
            {
                S32 best = 0;

                for (diag = forwardMax; diag >= forwardMin; diag -= 2) {
                    S32 diagDiag = diag - forwardMid;
                    S32 absDiagDiag = diagDiag < 0 ? -diagDiag : diagDiag;
                    S32 dictIdx = forwardDiag[diag];
                    S32 srcIdx = dictIdx - diag;
                    S32 v = (dictIdx - dictLow) * 2 - diagDiag;

                    /* Compared in signed arithmetic so that a negative score
                     * can never pass the threshold. 'iterations' is below
                     * ZSTD_EDIST_EXPENSIVE_THRESH here, so the cast is safe. */
                    if (v > 12 * ((S32)iterations + absDiagDiag)) {
                        if (v > best
                          && dictLow + ZSTD_EDIST_SNAKE_THRESH <= dictIdx && dictIdx <= dictHigh
                          && srcLow + ZSTD_EDIST_SNAKE_THRESH <= srcIdx && srcIdx <= srcHigh) {
                            S32 k;
                            /* Accept only if the point ends a run of
                             * ZSTD_EDIST_SNAKE_THRESH equal bytes */
                            for (k = 1; dict[dictIdx - k] == src[srcIdx - k]; k++) {
                                if (k == ZSTD_EDIST_SNAKE_THRESH) {
                                    best = v;
                                    partition->dictMid = dictIdx;
                                    partition->srcMid = srcIdx;
                                    break;
                                }
                            }
                        }
                    }
                }

                if (best > 0) {
                    partition->lowUseHeuristics = 0;
                    partition->highUseHeuristics = 1;
                    return;
                }
            }

            {
                S32 best = 0;

                for (diag = backwardMax; diag >= backwardMin; diag -= 2) {
                    S32 diagDiag = diag - backwardMid;
                    S32 absDiagDiag = diagDiag < 0 ? -diagDiag : diagDiag;
                    S32 dictIdx = backwardDiag[diag];
                    S32 srcIdx = dictIdx - diag;
                    S32 v = (dictHigh - dictIdx) * 2 + diagDiag;

                    if (v > 12 * ((S32)iterations + absDiagDiag)) {
                        if (v > best
                          && dictLow < dictIdx && dictIdx <= dictHigh - ZSTD_EDIST_SNAKE_THRESH
                          && srcLow < srcIdx && srcIdx <= srcHigh - ZSTD_EDIST_SNAKE_THRESH) {
                            int k;
                            /* Accept only if the point starts a run of
                             * ZSTD_EDIST_SNAKE_THRESH equal bytes */
                            for (k = 0; dict[dictIdx + k] == src[srcIdx + k]; k++) {
                                if (k == ZSTD_EDIST_SNAKE_THRESH - 1) {
                                    best = v;
                                    partition->dictMid = dictIdx;
                                    partition->srcMid = srcIdx;
                                    break;
                                }
                            }
                        }
                    }
                }

                if (best > 0) {
                    partition->lowUseHeuristics = 1;
                    partition->highUseHeuristics = 0;
                    return;
                }
            }
        }

        /* More general 'too expensive' heuristic */
        if (iterations >= ZSTD_EDIST_EXPENSIVE_THRESH) {
            S32 forwardDictSrcBest;
            S32 forwardDictBest = 0;
            S32 backwardDictSrcBest;
            S32 backwardDictBest = 0;

            /* Forward point with the largest dictIdx + srcIdx, clamped to
             * the region */
            forwardDictSrcBest = -1;
            for (diag = forwardMax; diag >= forwardMin; diag -= 2) {
                S32 dictIdx = MIN(forwardDiag[diag], dictHigh);
                S32 srcIdx = dictIdx - diag;

                if (srcHigh < srcIdx) {
                    dictIdx = srcHigh + diag;
                    srcIdx = srcHigh;
                }

                if (forwardDictSrcBest < dictIdx + srcIdx) {
                    forwardDictSrcBest = dictIdx + srcIdx;
                    forwardDictBest = dictIdx;
                }
            }

            /* Backward point with the smallest dictIdx + srcIdx, clamped to
             * the region */
            backwardDictSrcBest = ZSTD_EDIST_DIAG_MAX;
            for (diag = backwardMax; diag >= backwardMin; diag -= 2) {
                S32 dictIdx = MAX(dictLow, backwardDiag[diag]);
                S32 srcIdx = dictIdx - diag;

                if (srcIdx < srcLow) {
                    dictIdx = srcLow + diag;
                    srcIdx = srcLow;
                }

                if (dictIdx + srcIdx < backwardDictSrcBest) {
                    backwardDictSrcBest = dictIdx + srcIdx;
                    backwardDictBest = dictIdx;
                }
            }

            /* Split at whichever search made more progress */
            if ((dictHigh + srcHigh) - backwardDictSrcBest < forwardDictSrcBest - (dictLow + srcLow)) {
                partition->dictMid = forwardDictBest;
                partition->srcMid = forwardDictSrcBest - forwardDictBest;
                partition->lowUseHeuristics = 0;
                partition->highUseHeuristics = 1;
            } else {
                partition->dictMid = backwardDictBest;
                partition->srcMid = backwardDictSrcBest - backwardDictBest;
                partition->lowUseHeuristics = 1;
                partition->highUseHeuristics = 0;
            }
            return;
        }
    }
}
324 
/* Appends the one-byte match (dictIdx, srcIdx) to state->matches and
 * bumps state->nbMatches. No capacity check is done here. */
static void ZSTD_eDist_insertMatch(ZSTD_eDist_state* state,
                    S32 const dictIdx, S32 const srcIdx)
{
    ZSTD_eDist_match* const slot = state->matches + state->nbMatches;

    slot->dictIdx = dictIdx;
    slot->srcIdx = srcIdx;
    slot->matchLength = 1;
    state->nbMatches += 1;
}
333 
/* ZSTD_eDist_compare() :
 * Records every one-byte match between dict[dictLow, dictHigh) and
 * src[srcLow, srcHigh) into 'state'. The common prefix and suffix are
 * consumed directly; what remains is split with ZSTD_eDist_diag() and
 * each half is handled recursively.
 * @return 1 if a recursive call returned non-zero, 0 otherwise. */
static int ZSTD_eDist_compare(ZSTD_eDist_state* state,
                    S32 dictLow, S32 dictHigh, S32 srcLow,
                    S32 srcHigh, int useHeuristics)
{
    const BYTE* const dictBytes = state->dict;
    const BYTE* const srcBytes = state->src;

    /* Consume the common prefix */
    for (; dictLow < dictHigh && srcLow < srcHigh
           && dictBytes[dictLow] == srcBytes[srcLow];
         dictLow++, srcLow++) {
        ZSTD_eDist_insertMatch(state, dictLow, srcLow);
    }

    /* Consume the common suffix */
    for (; dictLow < dictHigh && srcLow < srcHigh
           && dictBytes[dictHigh - 1] == srcBytes[srcHigh - 1];
         dictHigh--, srcHigh--) {
        ZSTD_eDist_insertMatch(state, dictHigh - 1, srcHigh - 1);
    }

    /* The low and high ends touch. A regular diff would report the
     * leftover bytes here; we only care about matches.
     * Note: to compute the edit distance, accumulate one unit of cost per
     *   leftover byte in the two cases below. */
    if (dictLow == dictHigh) {
        /* src[srcLow, srcHigh) would be insertions into dict */
        return 0;
    }
    if (srcLow == srcHigh) {
        /* dict[dictLow, dictHigh) would be deletions from dict */
        return 0;
    }

    {
        ZSTD_eDist_partition split;
        split.dictMid = 0;
        split.srcMid = 0;
        ZSTD_eDist_diag(state, &split, dictLow, dictHigh,
            srcLow, srcHigh, useHeuristics);

        /* Low half first, then high half */
        if (ZSTD_eDist_compare(state, dictLow, split.dictMid,
              srcLow, split.srcMid, split.lowUseHeuristics)) {
            return 1;
        }
        if (ZSTD_eDist_compare(state, split.dictMid, dictHigh,
              split.srcMid, srcHigh, split.highUseHeuristics)) {
            return 1;
        }
        return 0;
    }
}
389 
ZSTD_eDist_matchComp(const void * p,const void * q)390 static int ZSTD_eDist_matchComp(const void* p, const void* q)
391 {
392     S32 const l = ((ZSTD_eDist_match*)p)->srcIdx;
393     S32 const r = ((ZSTD_eDist_match*)q)->srcIdx;
394     return (l - r);
395 }
396 
397 /* The matches from the approach above will all be of the form
398  * (dictIdx, srcIdx, 1). This method combines contiguous matches
399  * of length MINMATCH or greater. Matches less than MINMATCH
400  * are discarded */
ZSTD_eDist_combineMatches(ZSTD_eDist_state * state)401 static void ZSTD_eDist_combineMatches(ZSTD_eDist_state* state)
402 {
403     /* Create a new buffer to put the combined matches into
404      * and memcpy to state->matches after */
405     ZSTD_eDist_match* combinedMatches =
406         ZSTD_malloc(state->nbMatches * sizeof(ZSTD_eDist_match),
407         ZSTD_defaultCMem);
408 
409     U32 nbCombinedMatches = 1;
410     size_t i;
411 
412     /* Make sure that the srcIdx and dictIdx are in sorted order.
413      * The combination step won't work otherwise */
414     qsort(state->matches, state->nbMatches, sizeof(ZSTD_eDist_match), ZSTD_eDist_matchComp);
415 
416     memcpy(combinedMatches, state->matches, sizeof(ZSTD_eDist_match));
417     for (i = 1; i < state->nbMatches; i++) {
418         ZSTD_eDist_match const match = state->matches[i];
419         ZSTD_eDist_match const combinedMatch =
420             combinedMatches[nbCombinedMatches - 1];
421         if (combinedMatch.srcIdx + combinedMatch.matchLength == match.srcIdx &&
422           combinedMatch.dictIdx + combinedMatch.matchLength == match.dictIdx) {
423             combinedMatches[nbCombinedMatches - 1].matchLength++;
424         } else {
425             /* Discard matches that are less than MINMATCH */
426             if (combinedMatches[nbCombinedMatches - 1].matchLength < MINMATCH) {
427                 nbCombinedMatches--;
428             }
429 
430             memcpy(combinedMatches + nbCombinedMatches,
431                 state->matches + i, sizeof(ZSTD_eDist_match));
432             nbCombinedMatches++;
433         }
434     }
435     memcpy(state->matches, combinedMatches, nbCombinedMatches * sizeof(ZSTD_eDist_match));
436     state->nbMatches = nbCombinedMatches;
437     ZSTD_free(combinedMatches, ZSTD_defaultCMem);
438 }
439 
/* Writes one sequence per entry of state->matches into 'sequences'.
 *  - litLength  : bytes of src between the end of the previous match
 *                 (or the start of src) and this match
 *  - offset     : (srcIdx + dictSize) - dictIdx
 *  - matchLength: copied from the match
 * Only these three fields of each sequence are written.
 * @return the number of sequences written (== state->nbMatches). */
static size_t ZSTD_eDist_convertMatchesToSequences(ZSTD_Sequence* sequences,
    ZSTD_eDist_state* state)
{
    const ZSTD_eDist_match* const matchTable = state->matches;
    size_t const matchCount = state->nbMatches;
    size_t const dictLen = state->dictSize;
    U32 prevSrcEnd = 0;   /* end of the previous match in src */
    size_t n;

    for (n = 0; n < matchCount; n++) {
        const ZSTD_eDist_match* const cur = matchTable + n;
        ZSTD_Sequence* const seq = sequences + n;

        seq->offset = (U32)((cur->srcIdx + dictLen) - cur->dictIdx);
        seq->litLength = cur->srcIdx - prevSrcEnd;
        seq->matchLength = cur->matchLength;

        prevSrcEnd = cur->srcIdx + cur->matchLength;
    }
    return matchCount;
}
461 
462 /*-*************************************
463 *  Internal utils
464 ***************************************/
465 
/* Hamming distance: number of positions in [0, n) at which the
 * buffers 'a' and 'b' hold different bytes. */
static size_t ZSTD_eDist_hamingDist(const BYTE* const a,
                        const BYTE* const b, size_t n)
{
    size_t mismatches = 0;
    size_t pos = n;

    while (pos-- > 0) {
        if (a[pos] != b[pos])
            mismatches++;
    }
    return mismatches;
}
475 
476 /* This is a pretty naive recursive implementation that should only
477  * be used for quick tests obviously. Don't try and run this on a
478  * GB file or something. There are faster implementations. Use those
479  * if you need to run it for large files. */
ZSTD_eDist_levenshteinDist(const BYTE * const s,size_t const sn,const BYTE * const t,size_t const tn)480 static size_t ZSTD_eDist_levenshteinDist(const BYTE* const s,
481                         size_t const sn, const BYTE* const t,
482                         size_t const tn)
483 {
484     size_t a, b, c;
485 
486     if (!sn)
487         return tn;
488     if (!tn)
489         return sn;
490 
491     if (s[sn - 1] == t[tn - 1])
492         return ZSTD_eDist_levenshteinDist(
493             s, sn - 1, t, tn - 1);
494 
495     a = ZSTD_eDist_levenshteinDist(s, sn - 1, t, tn - 1);
496     b = ZSTD_eDist_levenshteinDist(s, sn, t, tn - 1);
497     c = ZSTD_eDist_levenshteinDist(s, sn - 1, t, tn);
498 
499     if (a > b)
500         a = b;
501     if (a > c)
502         a = c;
503 
504     return a + 1;
505 }
506 
/* Debug helper: asserts that every match lies entirely inside both
 * buffers and that the matched bytes really are equal.
 * A match may end exactly at the end of a buffer, so the bound checks are
 * inclusive. The checks are written as subtractions so they cannot wrap.
 * Compiles to a no-op when NDEBUG is defined. */
static void ZSTD_eDist_validateMatches(ZSTD_eDist_match* matches,
                        size_t const nbMatches, const BYTE* const dict,
                        size_t const dictSize, const BYTE* const src,
                        size_t const srcSize)
{
    size_t i;
    for (i = 0; i < nbMatches; i++) {
        size_t const dictIdx = matches[i].dictIdx;
        size_t const srcIdx = matches[i].srcIdx;
        size_t const matchLength = matches[i].matchLength;

        assert(matchLength <= dictSize && dictIdx <= dictSize - matchLength);
        assert(matchLength <= srcSize && srcIdx <= srcSize - matchLength);
        assert(!memcmp(dict + dictIdx, src + srcIdx, matchLength));

        /* Only used by the asserts above */
        (void)dictIdx;
        (void)srcIdx;
        (void)matchLength;
    }
    (void)matches;
    (void)dict;
    (void)dictSize;
    (void)src;
    (void)srcSize;
}
524 
525 /*-*************************************
526 *  API
527 ***************************************/
528 
/* ZSTD_eDist_genSequences() :
 * Finds the matches between 'dict' and 'src' with the edit-distance match
 * finder and writes them to 'sequences'.
 *
 * @param sequences     output buffer; one sequence is written per match
 *                      found, and at most one match is recorded per byte
 *                      of src. Sizing is the caller's responsibility.
 * @param dict,dictSize dictionary buffer and its size in bytes
 * @param src,srcSize   source buffer and its size in bytes
 * @param useHeuristics non-zero to allow the speed heuristics of
 *                      ZSTD_eDist_diag() (faster, possibly non-optimal)
 * @return number of sequences written. 0 is returned when there is
 *         nothing to match, when the inputs are too large for the 32-bit
 *         index space used internally, or when an allocation fails. */
size_t ZSTD_eDist_genSequences(ZSTD_Sequence* sequences,
                        const void* dict, size_t dictSize,
                        const void* src, size_t srcSize,
                        int useHeuristics)
{
    /* Positions are stored as S32 and ZSTD_EDIST_DIAG_MAX is used as
     * 'infinity' for position sums, so dictSize + srcSize must stay below it */
    size_t const sizeLimit = (size_t)ZSTD_EDIST_DIAG_MAX;
    size_t const maxAlloc = (size_t)-1;
    size_t nbSequences = 0;
    size_t nbDiags;
    S32* buffer;
    ZSTD_eDist_state state;

    /* An empty input cannot produce a match */
    if (dictSize == 0 || srcSize == 0)
        return 0;
    if (dictSize >= sizeLimit || srcSize >= sizeLimit - dictSize)
        return 0;

    /* Diagonals range over [-srcSize, dictSize], plus one guard entry on
     * each side */
    nbDiags = dictSize + srcSize + 3;

    /* Reject sizes whose allocation size would wrap around */
    if (nbDiags > maxAlloc / (2 * sizeof(S32))
     || srcSize > maxAlloc / sizeof(ZSTD_eDist_match))
        return 0;

    buffer = ZSTD_malloc(nbDiags * 2 * sizeof(S32), ZSTD_defaultCMem);
    state.matches = ZSTD_malloc(srcSize * sizeof(ZSTD_eDist_match), ZSTD_defaultCMem);

    if (buffer != NULL && state.matches != NULL) {
        state.dict = (const BYTE*)dict;
        state.src = (const BYTE*)src;
        state.dictSize = dictSize;
        state.srcSize = srcSize;
        /* Both diagonal arrays live in 'buffer'; each pointer is shifted so
         * that it can be indexed with diagonals as low as -(srcSize + 1) */
        state.forwardDiag = buffer + (srcSize + 1);
        state.backwardDiag = buffer + nbDiags + (srcSize + 1);
        state.nbMatches = 0;

        ZSTD_eDist_compare(&state, 0, (S32)dictSize, 0, (S32)srcSize, useHeuristics);

        /* With no match there is nothing to combine or convert */
        if (state.nbMatches > 0) {
            ZSTD_eDist_combineMatches(&state);
            nbSequences = ZSTD_eDist_convertMatchesToSequences(sequences, &state);
        }
    }

    if (buffer != NULL)
        ZSTD_free(buffer, ZSTD_defaultCMem);
    if (state.matches != NULL)
        ZSTD_free(state.matches, ZSTD_defaultCMem);

    return nbSequences;
}
559