xref: /aosp_15_r20/external/licenseclassifier/stringclassifier/classifier.go (revision 46c4c49da23cae783fa41bf46525a6505638499a)
1// Copyright 2017 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package stringclassifier finds the nearest match between a string and a set of known values. It
16// uses the Levenshtein Distance (LD) algorithm to determine this. A match with a large LD is less
17// likely to be correct than one with a small LD. A confidence percentage is returned, which
18// indicates how confident the algorithm is that the match is correct. The higher the percentage,
19// the greater the confidence that the match is correct.
20//
21// Example Usage:
22//
23//	type Text struct {
24//	  Name string
25//	  Text string
26//	}
27//
28//	func NewClassifier(knownTexts []Text) (*stringclassifier.Classifier, error) {
//	  sc := stringclassifier.New(stringclassifier.DefaultConfidenceThreshold, stringclassifier.FlattenWhitespace)
30//	  for _, known := range knownTexts {
31//	    if err := sc.AddValue(known.Name, known.Text); err != nil {
32//	      return nil, err
33//	    }
34//	  }
35//	  return sc, nil
36//	}
37//
38//	func IdentifyTexts(sc *stringclassifier.Classifier, unknownTexts []*Text) {
39//	  for _, unknown := range unknownTexts {
40//	    m := sc.NearestMatch(unknown.Text)
41//	    log.Printf("The nearest match to %q is %q (confidence: %v)",
42//	      unknown.Name, m.Name, m.Confidence)
43//	  }
44//	}
45package stringclassifier
46
47import (
48	"fmt"
49	"log"
50	"math"
51	"regexp"
52	"sort"
53	"sync"
54
55	"github.com/google/licenseclassifier/stringclassifier/internal/pq"
56	"github.com/google/licenseclassifier/stringclassifier/searchset"
57	"github.com/sergi/go-diff/diffmatchpatch"
58)
59
60// The diff/match/patch algorithm.
61var dmp = diffmatchpatch.New()
62
// DefaultConfidenceThreshold is a suggested threshold for New. A candidate
// range inside an unknown string is only scored with the Levenshtein
// Distance algorithm when it covers at least this fraction of the known
// value's tokens. Ranges below the threshold are skipped.
const DefaultConfidenceThreshold float64 = 0.80

// defaultMinDiffRatio is the value New assigns to Classifier.MinDiffRatio.
const defaultMinDiffRatio float64 = 0.75
74
75// A Classifier matches a string to a set of known values.
76type Classifier struct {
77	muValues    sync.RWMutex
78	values      map[string]*knownValue
79	normalizers []NormalizeFunc
80	threshold   float64
81
82	// MinDiffRatio defines the minimum ratio of the length difference
83	// allowed to consider a known value a possible match. This is used as
84	// a performance optimization to eliminate values that are unlikely to
85	// be a match.
86	//
87	// For example, a value of 0.75 means that the shorter string must be
88	// at least 75% the length of the longer string to consider it a
89	// possible match.
90	//
91	// Setting this to 1.0 will require that strings are identical length.
92	// Setting this to 0 will consider all known values as possible
93	// matches.
94	MinDiffRatio float64
95}
96
// NormalizeFunc maps a string to a canonical form so that texts differing
// only in unimportant ways compare equal. A Classifier runs its
// NormalizeFuncs in the order they were passed to New.
type NormalizeFunc func(s string) string
99
100// New creates a new Classifier with the provided NormalizeFuncs. Each
101// NormalizeFunc is applied in order to a string before comparison.
102func New(threshold float64, funcs ...NormalizeFunc) *Classifier {
103	return &Classifier{
104		values:       make(map[string]*knownValue),
105		normalizers:  append([]NormalizeFunc(nil), funcs...),
106		threshold:    threshold,
107		MinDiffRatio: defaultMinDiffRatio,
108	}
109}
110
111// knownValue identifies a value in the corpus to match against.
112type knownValue struct {
113	key             string
114	normalizedValue string
115	reValue         *regexp.Regexp
116	set             *searchset.SearchSet
117}
118
// AddValue registers value under key so that it can be matched against.
//
// The value is normalized with the Classifier's NormalizeFuncs. The
// normalized text is then compiled as a regular expression, which
// MultipleMatch uses to find exact occurrences.
//
// An error is returned, and nothing is registered, if key is already in use
// or if the normalized text is not a valid regular expression.
func (c *Classifier) AddValue(key, value string) error {
	c.muValues.Lock()
	defer c.muValues.Unlock()
	if _, ok := c.values[key]; ok {
		return fmt.Errorf("value already registered with key %q", key)
	}
	norm := c.normalize(value)
	// Use Compile rather than MustCompile: the text comes from the caller,
	// and this function already reports failures through its error result.
	// Input such as an unbalanced "(" must not panic.
	re, err := regexp.Compile(norm)
	if err != nil {
		return fmt.Errorf("compiling value for key %q as a regular expression: %w", key, err)
	}
	c.values[key] = &knownValue{
		key:             key,
		normalizedValue: norm,
		reValue:         re,
	}
	return nil
}
135
// AddPrecomputedValue registers value under key together with its already
// deserialized SearchSet. The value is assumed to be normalized already, so
// no NormalizeFuncs are run.
//
// set may be nil; MultipleMatch then builds the search set on first use.
// An error is returned, and nothing is registered, if key is already in use
// or if value is not a valid regular expression.
func (c *Classifier) AddPrecomputedValue(key, value string, set *searchset.SearchSet) error {
	c.muValues.Lock()
	defer c.muValues.Unlock()
	if _, ok := c.values[key]; ok {
		return fmt.Errorf("value already registered with key %q", key)
	}
	// Compile before touching set, so that a failure leaves everything
	// unchanged. Use Compile rather than MustCompile to report an error
	// instead of panicking.
	re, err := regexp.Compile(value)
	if err != nil {
		return fmt.Errorf("compiling value for key %q as a regular expression: %w", key, err)
	}
	if set != nil {
		set.GenerateNodeList()
	}
	c.values[key] = &knownValue{
		key:             key,
		normalizedValue: value,
		reValue:         re,
		set:             set,
	}
	return nil
}
154
155// normalize a string by applying each of the registered NormalizeFuncs.
156func (c *Classifier) normalize(s string) string {
157	for _, fn := range c.normalizers {
158		s = fn(s)
159	}
160	return s
161}
162
163// Match identifies the result of matching a string against a knownValue.
164type Match struct {
165	Name       string  // Name of knownValue that was matched
166	Confidence float64 // Confidence percentage
167	Offset     int     // The offset into the unknown string the match was made
168	Extent     int     // The length from the offset into the unknown string
169}
170
171// Matches is a list of Match-es. This is here mainly so that the list can be
172// sorted.
173type Matches []*Match
174
175func (m Matches) Len() int      { return len(m) }
176func (m Matches) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
177func (m Matches) Less(i, j int) bool {
178	if math.Abs(m[j].Confidence-m[i].Confidence) < math.SmallestNonzeroFloat64 {
179		if m[i].Name == m[j].Name {
180			if m[i].Offset > m[j].Offset {
181				return false
182			}
183			if m[i].Offset == m[j].Offset {
184				return m[i].Extent > m[j].Extent
185			}
186			return true
187		}
188		return m[i].Name < m[j].Name
189	}
190	return m[i].Confidence > m[j].Confidence
191}
192
193// Names returns an unsorted slice of the names of the matched licenses.
194func (m Matches) Names() []string {
195	var names []string
196	for _, n := range m {
197		names = append(names, n.Name)
198	}
199	return names
200}
201
202// uniquify goes through the matches and removes any that are contained within
203// one with a higher confidence. This assumes that Matches is sorted.
204func (m Matches) uniquify() Matches {
205	type matchedRange struct {
206		offset, extent int
207	}
208
209	var matched []matchedRange
210	var matches Matches
211OUTER:
212	for _, match := range m {
213		for _, mr := range matched {
214			if match.Offset >= mr.offset && match.Offset <= mr.offset+mr.extent {
215				continue OUTER
216			}
217		}
218		matched = append(matched, matchedRange{match.Offset, match.Extent})
219		matches = append(matches, match)
220	}
221
222	return matches
223}
224
225// NearestMatch returns the name of the known value that most closely matches
226// the unknown string and a confidence percentage is returned indicating how
227// confident the classifier is in the result. A percentage of "1.0" indicates
228// an exact match, while a percentage of "0.0" indicates a complete mismatch.
229//
230// If the string is equidistant from multiple known values, it is undefined
231// which will be returned.
232func (c *Classifier) NearestMatch(s string) *Match {
233	pq := c.nearestMatch(s)
234	if pq.Len() == 0 {
235		return &Match{}
236	}
237	return pq.Pop().(*Match)
238}
239
// MultipleMatch finds which known values occur inside s. Unlike NearestMatch,
// it only scores the regions of s that are likely to match.
//
// Identical duplicate matches are removed, and the list is sorted (see
// Matches.Less). A match that starts inside a higher-ranked match is also
// dropped. Deciding which of the remaining matches are acceptable is left
// to the caller.
func (c *Classifier) MultipleMatch(s string) (matches Matches) {
	queue := c.multipleMatch(s)
	if queue == nil {
		return matches
	}

	seen := make(map[Match]struct{})
	for queue.Len() > 0 {
		v := queue.Pop().(*Match)
		if _, dup := seen[*v]; dup {
			continue
		}
		seen[*v] = struct{}{}
		matches = append(matches, v)
	}

	sort.Sort(matches)
	return matches.uniquify()
}
265
266// possibleMatch identifies a known value and it's diffRatio to a given string.
267type possibleMatch struct {
268	value     *knownValue
269	diffRatio float64
270}
271
272// likelyMatches is a slice of possibleMatches that can be sorted by their
273// diffRatio to a given string, such that the most likely matches (based on
274// length) are at the beginning.
275type likelyMatches []possibleMatch
276
277func (m likelyMatches) Len() int           { return len(m) }
278func (m likelyMatches) Less(i, j int) bool { return m[i].diffRatio > m[j].diffRatio }
279func (m likelyMatches) Swap(i, j int)      { m[i], m[j] = m[j], m[i] }
280
// nearestMatch normalizes unknown and compares it with every known value
// whose length ratio is at least c.MinDiffRatio. It returns a queue of Matches
// ordered by descending Confidence.
//
// An exact match short-circuits the search and is returned alone with
// confidence 1.0. Otherwise each candidate is scored concurrently from the
// Levenshtein Distance of its diff. Only scores above zero are queued. An
// input that normalizes to "" yields an empty queue.
func (c *Classifier) nearestMatch(unknown string) *pq.Queue {
	results := pq.NewQueue(func(x, y interface{}) bool {
		return x.(*Match).Confidence > y.(*Match).Confidence
	}, nil)

	unknown = c.normalize(unknown)
	if unknown == "" {
		return results
	}

	var candidates likelyMatches
	c.muValues.RLock()
	for _, v := range c.values {
		ratio := diffRatio(unknown, v.normalizedValue)
		if ratio < c.MinDiffRatio {
			continue
		}
		if v.normalizedValue == unknown {
			results.Push(&Match{Name: v.key, Confidence: 1.0, Offset: 0, Extent: len(unknown)})
			c.muValues.RUnlock()
			return results
		}
		candidates = append(candidates, possibleMatch{value: v, diffRatio: ratio})
	}
	c.muValues.RUnlock()
	sort.Sort(candidates)

	var (
		mu sync.Mutex // guards results while workers push into it
		wg sync.WaitGroup
	)
	wg.Add(len(candidates))
	for _, cand := range candidates {
		go func(name, known string) {
			defer wg.Done()
			diffs := dmp.DiffMain(unknown, known, true)
			distance := dmp.DiffLevenshtein(diffs)
			if conf := confidencePercentage(len(unknown), len(known), distance); conf > 0.0 {
				mu.Lock()
				results.Push(&Match{Name: name, Confidence: conf, Offset: 0, Extent: len(unknown)})
				mu.Unlock()
			}
		}(cand.value.key, cand.value.normalizedValue)
	}
	wg.Wait()
	return results
}
334
335// matcher finds all potential matches of "known" in "unknown". The results are
336// placed in "queue".
337type matcher struct {
338	unknown     *searchset.SearchSet
339	normUnknown string
340	threshold   float64
341
342	mu    sync.Mutex
343	queue *pq.Queue
344}
345
346// newMatcher creates a "matcher" object.
347func newMatcher(unknown string, threshold float64) *matcher {
348	return &matcher{
349		unknown:     searchset.New(unknown, searchset.DefaultGranularity),
350		normUnknown: unknown,
351		threshold:   threshold,
352		queue: pq.NewQueue(func(x, y interface{}) bool {
353			return x.(*Match).Confidence > y.(*Match).Confidence
354		}, nil),
355	}
356}
357
// findMatches finds every potential occurrence of known inside the unknown
// text and scores it. Candidates are produced in one of two ways:
//
//   - If known's regular expression matches the unknown text, only those
//     occurrences are used. Each one is converted to the range of unknown
//     tokens it overlaps.
//   - Otherwise searchset.FindPotentialMatches proposes candidate ranges.
//
// Each range that passes withinConfidenceThreshold is scored with levDist
// concurrently. Scores above zero are pushed onto m.queue for the caller to
// filter. known.set must be non-nil.
func (m *matcher) findMatches(known *knownValue) {
	var mrs []searchset.MatchRanges
	if all := known.reValue.FindAllStringIndex(m.normUnknown, -1); all != nil {
		// We found exact matches. Just use those!
		for _, a := range all {
			matchStart, matchEnd := a[0], a[1]
			if matchStart == matchEnd {
				// An empty regexp match covers no text and cannot be scored.
				continue
			}
			// Find the first and last tokens overlapping
			// [matchStart, matchEnd). This handles matches that lie within
			// a single token or do not begin on a token boundary. Assumes
			// Tokens are ordered by Offset, as the early break relies on.
			start, end := -1, -1
			for i, tok := range m.unknown.Tokens {
				if tok.Offset >= matchEnd {
					break
				}
				if tok.Offset+len(tok.Text) <= matchStart {
					continue
				}
				if start < 0 {
					start = i
				}
				end = i
			}
			if start < 0 {
				// No token overlaps the match (e.g. it is all whitespace).
				continue
			}

			mrs = append(mrs, searchset.MatchRanges{{
				SrcStart:    0,
				SrcEnd:      len(known.set.Tokens),
				TargetStart: start,
				TargetEnd:   end + 1,
			}})
		}
	} else {
		// No exact match. Perform a more thorough match.
		mrs = searchset.FindPotentialMatches(known.set, m.unknown)
	}

	var wg sync.WaitGroup
	for _, mr := range mrs {
		if !m.withinConfidenceThreshold(known.set, mr) {
			continue
		}

		wg.Add(1)
		go func(mr searchset.MatchRanges) {
			defer wg.Done()
			start, end := mr.TargetRange(m.unknown)
			conf := levDist(m.normUnknown[start:end], known.normalizedValue)
			if conf > 0.0 {
				m.mu.Lock()
				m.queue.Push(&Match{Name: known.key, Confidence: conf, Offset: start, Extent: end - start})
				m.mu.Unlock()
			}
		}(mr)
	}
	wg.Wait()
}
408
409// withinConfidenceThreshold returns the Confidence we have in the potential
410// match. It does this by calculating the ratio of what's matching to the
411// original known text.
412func (m *matcher) withinConfidenceThreshold(known *searchset.SearchSet, mr searchset.MatchRanges) bool {
413	return float64(mr.Size())/float64(len(known.Tokens)) >= m.threshold
414}
415
// multipleMatch normalizes unknown and searches it for every known value
// concurrently. It returns the queue of candidate Matches, ordered by
// descending Confidence, or nil if unknown normalizes to "".
//
// Known values registered without a search set get one built here on first
// use. It is stored on the knownValue for later calls.
func (c *Classifier) multipleMatch(unknown string) *pq.Queue {
	normUnknown := c.normalize(unknown)
	if normUnknown == "" {
		return nil
	}

	m := newMatcher(normUnknown, c.threshold)

	c.muValues.RLock()
	kvals := make([]*knownValue, 0, len(c.values))
	for _, known := range c.values {
		kvals = append(kvals, known)
	}
	c.muValues.RUnlock()

	var wg sync.WaitGroup
	wg.Add(len(kvals))
	for _, known := range kvals {
		go func(known *knownValue) {
			defer wg.Done()
			// A concurrent multipleMatch call may be assigning known.set, so
			// it is only inspected under muValues. It is assigned at most
			// once, while nil, under the write lock. Nothing in this file
			// reassigns it afterwards, so findMatches can read it unlocked
			// once this block has observed it.
			c.muValues.RLock()
			haveSet := known.set != nil
			c.muValues.RUnlock()
			if !haveSet {
				set := searchset.New(known.normalizedValue, searchset.DefaultGranularity)
				c.muValues.Lock()
				if known.set == nil {
					known.set = set
				}
				c.muValues.Unlock()
			}
			m.findMatches(known)
		}(known)
	}
	wg.Wait()
	return m.queue
}
451
// levDist scores how well unknown matches known, from 0.0 to 1.0, using the
// Levenshtein Distance of their diff.
//
// Only the leading diffs needed to rebuild known count toward the distance
// (see diffRangeEnd). The unknown length is measured with
// unknownTextLength. If either text is empty, the condition is logged and
// the score is 0.0.
func levDist(unknown, known string) float64 {
	if known == "" || unknown == "" {
		log.Printf("Zero-sized texts in Levenshtein Distance algorithm: known==%d, unknown==%d", len(known), len(unknown))
		return 0.0
	}

	diffs := dmp.DiffMain(unknown, known, false)
	// Diffs after known has been fully rebuilt would only make the distance
	// less meaningful, so they are excluded.
	relevant := diffs[:diffRangeEnd(known, diffs)]
	distance := dmp.DiffLevenshtein(relevant)
	ulen := unknownTextLength(unknown, diffs)
	return confidencePercentage(ulen, len(known), distance)
}
470
471// unknownTextLength returns the length of the unknown text based on the diff range.
472func unknownTextLength(unknown string, diffs []diffmatchpatch.Diff) int {
473	last := len(diffs) - 1
474	for ; last >= 0; last-- {
475		if diffs[last].Type == diffmatchpatch.DiffEqual {
476			break
477		}
478	}
479	ulen := 0
480	for i := 0; i < last+1; i++ {
481		switch diffs[i].Type {
482		case diffmatchpatch.DiffEqual, diffmatchpatch.DiffDelete:
483			ulen += len(diffs[i].Text)
484		}
485	}
486	return ulen
487}
488
489// diffRangeEnd returns the end index for the "Diff" objects that constructs
490// (or nearly constructs) the "known" value.
491func diffRangeEnd(known string, diffs []diffmatchpatch.Diff) (end int) {
492	var seen string
493	for end = 0; end < len(diffs); end++ {
494		if seen == known {
495			// Once we've constructed the "known" value, then we've
496			// reached the point in the diff list where more
497			// "Diff"s would just make the Levenshtein Distance
498			// less valid. There shouldn't be further "DiffEqual"
499			// nodes, because there's nothing further to match in
500			// the "known" text.
501			break
502		}
503		switch diffs[end].Type {
504		case diffmatchpatch.DiffEqual, diffmatchpatch.DiffInsert:
505			seen += diffs[end].Text
506		}
507	}
508	return end
509}
510
// confidencePercentage converts an edit distance into a confidence score:
// 1.0 means identical texts and 0.0 a complete mismatch.
//
// Two empty texts score 1.0. If exactly one is empty, or distance exceeds
// both lengths, the score is 0.0. Otherwise the score is
// 1 - distance/longer, where longer is the larger of ulen and klen.
func confidencePercentage(ulen, klen, distance int) float64 {
	switch {
	case ulen == 0 && klen == 0:
		return 1.0
	case ulen == 0, klen == 0:
		return 0.0
	}
	longer := klen
	if ulen > klen {
		longer = ulen
	}
	if distance > longer {
		return 0.0
	}
	return 1.0 - float64(distance)/float64(longer)
}
523
// diffRatio returns the byte length of the shorter of s1 and s2 divided by
// that of the longer one, i.e. how similar their lengths are. For example,
// diffRatio("abcd", "e") returns 0.25 because "e" is a quarter of the length
// of "abcd". Strings of equal length, including two empty strings, give 1.
func diffRatio(s1, s2 string) float64 {
	shorter, longer := len(s1), len(s2)
	if shorter > longer {
		shorter, longer = longer, shorter
	}
	if longer == 0 {
		return 1.0
	}
	return float64(shorter) / float64(longer)
}
539
// max returns the larger of a and b.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
546
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
553
554// wsRegexp is a regexp used to identify blocks of whitespace.
555var wsRegexp = regexp.MustCompile(`\s+`)
556
557// FlattenWhitespace will flatten contiguous blocks of whitespace down to a single space.
558var FlattenWhitespace NormalizeFunc = func(s string) string {
559	return wsRegexp.ReplaceAllString(s, " ")
560}
561