1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package http
6
7import (
8	"io"
9	"net/http/httptrace"
10	"net/http/internal/ascii"
11	"net/textproto"
12	"slices"
13	"strings"
14	"sync"
15	"time"
16
17	"golang.org/x/net/http/httpguts"
18)
19
// A Header represents the key-value pairs in an HTTP header.
// Each key maps to the slice of values recorded for that key.
//
// The keys should be in canonical form, as returned by
// [CanonicalHeaderKey]. The exported accessor methods canonicalize
// the key they are given; direct map access does not.
type Header map[string][]string
25
26// Add adds the key, value pair to the header.
27// It appends to any existing values associated with key.
28// The key is case insensitive; it is canonicalized by
29// [CanonicalHeaderKey].
30func (h Header) Add(key, value string) {
31	textproto.MIMEHeader(h).Add(key, value)
32}
33
34// Set sets the header entries associated with key to the
35// single element value. It replaces any existing values
36// associated with key. The key is case insensitive; it is
37// canonicalized by [textproto.CanonicalMIMEHeaderKey].
38// To use non-canonical keys, assign to the map directly.
39func (h Header) Set(key, value string) {
40	textproto.MIMEHeader(h).Set(key, value)
41}
42
43// Get gets the first value associated with the given key. If
44// there are no values associated with the key, Get returns "".
45// It is case insensitive; [textproto.CanonicalMIMEHeaderKey] is
46// used to canonicalize the provided key. Get assumes that all
47// keys are stored in canonical form. To use non-canonical keys,
48// access the map directly.
49func (h Header) Get(key string) string {
50	return textproto.MIMEHeader(h).Get(key)
51}
52
53// Values returns all values associated with the given key.
54// It is case insensitive; [textproto.CanonicalMIMEHeaderKey] is
55// used to canonicalize the provided key. To use non-canonical
56// keys, access the map directly.
57// The returned slice is not a copy.
58func (h Header) Values(key string) []string {
59	return textproto.MIMEHeader(h).Values(key)
60}
61
62// get is like Get, but key must already be in CanonicalHeaderKey form.
63func (h Header) get(key string) string {
64	if v := h[key]; len(v) > 0 {
65		return v[0]
66	}
67	return ""
68}
69
70// has reports whether h has the provided key defined, even if it's
71// set to 0-length slice.
72func (h Header) has(key string) bool {
73	_, ok := h[key]
74	return ok
75}
76
77// Del deletes the values associated with key.
78// The key is case insensitive; it is canonicalized by
79// [CanonicalHeaderKey].
80func (h Header) Del(key string) {
81	textproto.MIMEHeader(h).Del(key)
82}
83
84// Write writes a header in wire format.
85func (h Header) Write(w io.Writer) error {
86	return h.write(w, nil)
87}
88
89func (h Header) write(w io.Writer, trace *httptrace.ClientTrace) error {
90	return h.writeSubset(w, nil, trace)
91}
92
93// Clone returns a copy of h or nil if h is nil.
94func (h Header) Clone() Header {
95	if h == nil {
96		return nil
97	}
98
99	// Find total number of values.
100	nv := 0
101	for _, vv := range h {
102		nv += len(vv)
103	}
104	sv := make([]string, nv) // shared backing array for headers' values
105	h2 := make(Header, len(h))
106	for k, vv := range h {
107		if vv == nil {
108			// Preserve nil values. ReverseProxy distinguishes
109			// between nil and zero-length header values.
110			h2[k] = nil
111			continue
112		}
113		n := copy(sv, vv)
114		h2[k] = sv[:n:n]
115		sv = sv[n:]
116	}
117	return h2
118}
119
// timeFormats lists the layouts that ParseTime tries, in order.
var timeFormats = []string{
	TimeFormat,
	time.RFC850,
	time.ANSIC,
}
125
126// ParseTime parses a time header (such as the Date: header),
127// trying each of the three formats allowed by HTTP/1.1:
128// [TimeFormat], [time.RFC850], and [time.ANSIC].
129func ParseTime(text string) (t time.Time, err error) {
130	for _, layout := range timeFormats {
131		t, err = time.Parse(layout, text)
132		if err == nil {
133			return
134		}
135	}
136	return
137}
138
// headerNewlineToSpace replaces each "\n" and "\r" with a space.
// writeSubset applies it to header values before writing them.
var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ")
140
// stringWriter adapts a plain io.Writer so that it also offers
// a WriteString method.
type stringWriter struct {
	w io.Writer
}

// WriteString converts s to a byte slice and passes it to the
// wrapped writer's Write method, returning that call's results.
func (sw stringWriter) WriteString(s string) (n int, err error) {
	buf := []byte(s)
	return sw.w.Write(buf)
}
149
// keyValues pairs one header key with the values stored under it.
// sortedKeyValues builds a slice of these from a Header.
type keyValues struct {
	key    string
	values []string
}
154
// headerSorter contains a slice of keyValues sorted by keyValues.key.
// Instances are obtained from headerSorterPool so that the kvs
// backing array can be reused when its capacity is sufficient.
type headerSorter struct {
	kvs []keyValues
}
159
160var headerSorterPool = sync.Pool{
161	New: func() any { return new(headerSorter) },
162}
163
164// sortedKeyValues returns h's keys sorted in the returned kvs
165// slice. The headerSorter used to sort is also returned, for possible
166// return to headerSorterCache.
167func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) {
168	hs = headerSorterPool.Get().(*headerSorter)
169	if cap(hs.kvs) < len(h) {
170		hs.kvs = make([]keyValues, 0, len(h))
171	}
172	kvs = hs.kvs[:0]
173	for k, vv := range h {
174		if !exclude[k] {
175			kvs = append(kvs, keyValues{k, vv})
176		}
177	}
178	hs.kvs = kvs
179	slices.SortFunc(hs.kvs, func(a, b keyValues) int { return strings.Compare(a.key, b.key) })
180	return kvs, hs
181}
182
183// WriteSubset writes a header in wire format.
184// If exclude is not nil, keys where exclude[key] == true are not written.
185// Keys are not canonicalized before checking the exclude map.
186func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
187	return h.writeSubset(w, exclude, nil)
188}
189
190func (h Header) writeSubset(w io.Writer, exclude map[string]bool, trace *httptrace.ClientTrace) error {
191	ws, ok := w.(io.StringWriter)
192	if !ok {
193		ws = stringWriter{w}
194	}
195	kvs, sorter := h.sortedKeyValues(exclude)
196	var formattedVals []string
197	for _, kv := range kvs {
198		if !httpguts.ValidHeaderFieldName(kv.key) {
199			// This could be an error. In the common case of
200			// writing response headers, however, we have no good
201			// way to provide the error back to the server
202			// handler, so just drop invalid headers instead.
203			continue
204		}
205		for _, v := range kv.values {
206			v = headerNewlineToSpace.Replace(v)
207			v = textproto.TrimString(v)
208			for _, s := range []string{kv.key, ": ", v, "\r\n"} {
209				if _, err := ws.WriteString(s); err != nil {
210					headerSorterPool.Put(sorter)
211					return err
212				}
213			}
214			if trace != nil && trace.WroteHeaderField != nil {
215				formattedVals = append(formattedVals, v)
216			}
217		}
218		if trace != nil && trace.WroteHeaderField != nil {
219			trace.WroteHeaderField(kv.key, formattedVals)
220			formattedVals = nil
221		}
222	}
223	headerSorterPool.Put(sorter)
224	return nil
225}
226
// CanonicalHeaderKey returns the canonical format of the
// header key s. Canonicalization upper-cases the first letter
// and any letter following a hyphen, and lower-cases the rest.
// For example, the canonical key for "accept-encoding" is
// "Accept-Encoding".
// If s contains a space or invalid header field bytes, it is
// returned without modifications.
func CanonicalHeaderKey(s string) string {
	return textproto.CanonicalMIMEHeaderKey(s)
}
235
236// hasToken reports whether token appears with v, ASCII
237// case-insensitive, with space or comma boundaries.
238// token must be all lowercase.
239// v may contain mixed cased.
240func hasToken(v, token string) bool {
241	if len(token) > len(v) || token == "" {
242		return false
243	}
244	if v == token {
245		return true
246	}
247	for sp := 0; sp <= len(v)-len(token); sp++ {
248		// Check that first character is good.
249		// The token is ASCII, so checking only a single byte
250		// is sufficient. We skip this potential starting
251		// position if both the first byte and its potential
252		// ASCII uppercase equivalent (b|0x20) don't match.
253		// False positives ('^' => '~') are caught by EqualFold.
254		if b := v[sp]; b != token[0] && b|0x20 != token[0] {
255			continue
256		}
257		// Check that start pos is on a valid token boundary.
258		if sp > 0 && !isTokenBoundary(v[sp-1]) {
259			continue
260		}
261		// Check that end pos is on a valid token boundary.
262		if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) {
263			continue
264		}
265		if ascii.EqualFold(v[sp:sp+len(token)], token) {
266			return true
267		}
268	}
269	return false
270}
271
// isTokenBoundary reports whether b is a byte that may delimit
// a token: a space, a comma, or a horizontal tab.
func isTokenBoundary(b byte) bool {
	switch b {
	case ' ', ',', '\t':
		return true
	}
	return false
}
275