1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package http
6
7import "math"
8
9// A routingIndex optimizes conflict detection by indexing patterns.
10//
11// The basic idea is to rule out patterns that cannot conflict with a given
12// pattern because they have a different literal in a corresponding segment.
13// See the comments in [routingIndex.possiblyConflictingPatterns] for more details.
14type routingIndex struct {
15	// map from a particular segment position and value to all registered patterns
16	// with that value in that position.
17	// For example, the key {1, "b"} would hold the patterns "/a/b" and "/a/b/c"
18	// but not "/a", "b/a", "/a/c" or "/a/{x}".
19	segments map[routingIndexKey][]*pattern
20	// All patterns that end in a multi wildcard (including trailing slash).
21	// We do not try to be clever about indexing multi patterns, because there
22	// are unlikely to be many of them.
23	multis []*pattern
24}
25
// routingIndexKey names one bucket of routingIndex.segments: a segment
// position paired with the text a pattern has at that position.
type routingIndexKey struct {
	// pos is the 0-based index of the segment within the pattern.
	pos int
	// s is the segment's literal text; the empty string stands for a wildcard.
	s string
}
30
31func (idx *routingIndex) addPattern(pat *pattern) {
32	if pat.lastSegment().multi {
33		idx.multis = append(idx.multis, pat)
34	} else {
35		if idx.segments == nil {
36			idx.segments = map[routingIndexKey][]*pattern{}
37		}
38		for pos, seg := range pat.segments {
39			key := routingIndexKey{pos: pos, s: ""}
40			if !seg.wild {
41				key.s = seg.s
42			}
43			idx.segments[key] = append(idx.segments[key], pat)
44		}
45	}
46}
47
48// possiblyConflictingPatterns calls f on all patterns that might conflict with
49// pat. If f returns a non-nil error, possiblyConflictingPatterns returns immediately
50// with that error.
51//
52// To be correct, possiblyConflictingPatterns must include all patterns that
53// might conflict. But it may also include patterns that cannot conflict.
54// For instance, an implementation that returns all registered patterns is correct.
55// We use this fact throughout, simplifying the implementation by returning more
56// patterns that we might need to.
57func (idx *routingIndex) possiblyConflictingPatterns(pat *pattern, f func(*pattern) error) (err error) {
58	// Terminology:
59	//   dollar pattern: one ending in "{$}"
60	//   multi pattern: one ending in a trailing slash or "{x...}" wildcard
61	//   ordinary pattern: neither of the above
62
63	// apply f to all the pats, stopping on error.
64	apply := func(pats []*pattern) error {
65		if err != nil {
66			return err
67		}
68		for _, p := range pats {
69			err = f(p)
70			if err != nil {
71				return err
72			}
73		}
74		return nil
75	}
76
77	// Our simple indexing scheme doesn't try to prune multi patterns; assume
78	// any of them can match the argument.
79	if err := apply(idx.multis); err != nil {
80		return err
81	}
82	if pat.lastSegment().s == "/" {
83		// All paths that a dollar pattern matches end in a slash; no paths that
84		// an ordinary pattern matches do. So only other dollar or multi
85		// patterns can conflict with a dollar pattern. Furthermore, conflicting
86		// dollar patterns must have the {$} in the same position.
87		return apply(idx.segments[routingIndexKey{s: "/", pos: len(pat.segments) - 1}])
88	}
89	// For ordinary and multi patterns, the only conflicts can be with a multi,
90	// or a pattern that has the same literal or a wildcard at some literal
91	// position.
92	// We could intersect all the possible matches at each position, but we
93	// do something simpler: we find the position with the fewest patterns.
94	var lmin, wmin []*pattern
95	min := math.MaxInt
96	hasLit := false
97	for i, seg := range pat.segments {
98		if seg.multi {
99			break
100		}
101		if !seg.wild {
102			hasLit = true
103			lpats := idx.segments[routingIndexKey{s: seg.s, pos: i}]
104			wpats := idx.segments[routingIndexKey{s: "", pos: i}]
105			if sum := len(lpats) + len(wpats); sum < min {
106				lmin = lpats
107				wmin = wpats
108				min = sum
109			}
110		}
111	}
112	if hasLit {
113		apply(lmin)
114		apply(wmin)
115		return err
116	}
117
118	// This pattern is all wildcards.
119	// Check it against everything.
120	for _, pats := range idx.segments {
121		apply(pats)
122	}
123	return err
124}
125