1// Copyright 2022 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Test that Resolver.Dial can be a func returning an in-memory net.Conn
6// speaking DNS.
7
8package net
9
10import (
11	"bytes"
12	"context"
13	"errors"
14	"fmt"
15	"reflect"
16	"slices"
17	"testing"
18	"time"
19
20	"golang.org/x/net/dns/dnsmessage"
21)
22
23func TestResolverDialFunc(t *testing.T) {
24	r := &Resolver{
25		PreferGo: true,
26		Dial: newResolverDialFunc(&resolverDialHandler{
27			StartDial: func(network, address string) error {
28				t.Logf("StartDial(%q, %q) ...", network, address)
29				return nil
30			},
31			Question: func(h dnsmessage.Header, q dnsmessage.Question) {
32				t.Logf("Header: %+v for %q (type=%v, class=%v)", h,
33					q.Name.String(), q.Type, q.Class)
34			},
35			// TODO: add test without HandleA* hooks specified at all, that Go
36			// doesn't issue retries; map to something terminal.
37			HandleA: func(w AWriter, name string) error {
38				w.AddIP([4]byte{1, 2, 3, 4})
39				w.AddIP([4]byte{5, 6, 7, 8})
40				return nil
41			},
42			HandleAAAA: func(w AAAAWriter, name string) error {
43				w.AddIP([16]byte{1: 1, 15: 15})
44				w.AddIP([16]byte{2: 2, 14: 14})
45				return nil
46			},
47			HandleSRV: func(w SRVWriter, name string) error {
48				w.AddSRV(1, 2, 80, "foo.bar.")
49				w.AddSRV(2, 3, 81, "bar.baz.")
50				return nil
51			},
52		}),
53	}
54	ctx := context.Background()
55	const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld."
56
57	t.Run("LookupIP", func(t *testing.T) {
58		ips, err := r.LookupIP(ctx, "ip", fakeDomain)
59		if err != nil {
60			t.Fatal(err)
61		}
62		if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !reflect.DeepEqual(got, want) {
63			t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want)
64		}
65	})
66
67	t.Run("LookupSRV", func(t *testing.T) {
68		_, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain)
69		if err != nil {
70			t.Fatal(err)
71		}
72		want := []*SRV{
73			{
74				Target:   "foo.bar.",
75				Port:     80,
76				Priority: 1,
77				Weight:   2,
78			},
79			{
80				Target:   "bar.baz.",
81				Port:     81,
82				Priority: 2,
83				Weight:   3,
84			},
85		}
86		if !reflect.DeepEqual(got, want) {
87			t.Errorf("wrong result. got:")
88			for _, r := range got {
89				t.Logf("  - %+v", r)
90			}
91		}
92	})
93}
94
95func sortedIPStrings(ips []IP) []string {
96	ret := make([]string, len(ips))
97	for i, ip := range ips {
98		ret[i] = ip.String()
99	}
100	slices.Sort(ret)
101	return ret
102}
103
104func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) {
105	return func(ctx context.Context, network, address string) (Conn, error) {
106		a := &resolverFuncConn{
107			h:       h,
108			network: network,
109			address: address,
110			ttl:     10, // 10 second default if unset
111		}
112		if h.StartDial != nil {
113			if err := h.StartDial(network, address); err != nil {
114				return nil, err
115			}
116		}
117		return a, nil
118	}
119}
120
121type resolverDialHandler struct {
122	// StartDial, if non-nil, is called when Go first calls Resolver.Dial.
123	// Any error returned aborts the dial and is returned unwrapped.
124	StartDial func(network, address string) error
125
126	Question func(dnsmessage.Header, dnsmessage.Question)
127
128	// err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2).
129	// A nil error means success.
130	HandleA    func(w AWriter, name string) error
131	HandleAAAA func(w AAAAWriter, name string) error
132	HandleSRV  func(w SRVWriter, name string) error
133}
134
135type ResponseWriter struct{ a *resolverFuncConn }
136
137func (w ResponseWriter) header() dnsmessage.ResourceHeader {
138	q := w.a.q
139	return dnsmessage.ResourceHeader{
140		Name:  q.Name,
141		Type:  q.Type,
142		Class: q.Class,
143		TTL:   w.a.ttl,
144	}
145}
146
147// SetTTL sets the TTL for subsequent written resources.
148// Once a resource has been written, SetTTL calls are no-ops.
149// That is, it can only be called at most once, before anything
150// else is written.
151func (w ResponseWriter) SetTTL(seconds uint32) {
152	// ... intention is last one wins and mutates all previously
153	// written records too, but that's a little annoying.
154	// But it's also annoying if the requirement is it needs to be set
155	// last.
156	// And it's also annoying if it's possible for users to set
157	// different TTLs per Answer.
158	if w.a.wrote {
159		return
160	}
161	w.a.ttl = seconds
162
163}
164
165type AWriter struct{ ResponseWriter }
166
167func (w AWriter) AddIP(v4 [4]byte) {
168	w.a.wrote = true
169	err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4})
170	if err != nil {
171		panic(err)
172	}
173}
174
175type AAAAWriter struct{ ResponseWriter }
176
177func (w AAAAWriter) AddIP(v6 [16]byte) {
178	w.a.wrote = true
179	err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6})
180	if err != nil {
181		panic(err)
182	}
183}
184
185type SRVWriter struct{ ResponseWriter }
186
187// AddSRV adds a SRV record. The target name must end in a period and
188// be 63 bytes or fewer.
189func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error {
190	targetName, err := dnsmessage.NewName(target)
191	if err != nil {
192		return err
193	}
194	w.a.wrote = true
195	err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{
196		Priority: priority,
197		Weight:   weight,
198		Port:     port,
199		Target:   targetName,
200	})
201	if err != nil {
202		panic(err) // internal fault, not user
203	}
204	return nil
205}
206
// ErrNotExist is returned by a resolverDialHandler hook to report that the
// queried name does not exist. mapRCode translates it to RCode 3
// (NXDOMAIN).
var ErrNotExist = errors.New("name does not exist")

// ErrRefused is returned by a resolverDialHandler hook to refuse the
// query. mapRCode translates it to RCode 5 (REFUSED).
var ErrRefused = errors.New("refused")
211
212type resolverFuncConn struct {
213	h       *resolverDialHandler
214	network string
215	address string
216	builder *dnsmessage.Builder
217	q       dnsmessage.Question
218	ttl     uint32
219	wrote   bool
220
221	rbuf bytes.Buffer
222}
223
224func (*resolverFuncConn) Close() error                       { return nil }
225func (*resolverFuncConn) LocalAddr() Addr                    { return someaddr{} }
226func (*resolverFuncConn) RemoteAddr() Addr                   { return someaddr{} }
227func (*resolverFuncConn) SetDeadline(t time.Time) error      { return nil }
228func (*resolverFuncConn) SetReadDeadline(t time.Time) error  { return nil }
229func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil }
230
231func (a *resolverFuncConn) Read(p []byte) (n int, err error) {
232	return a.rbuf.Read(p)
233}
234
235func (a *resolverFuncConn) Write(packet []byte) (n int, err error) {
236	if len(packet) < 2 {
237		return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet))
238	}
239	reqLen := int(packet[0])<<8 | int(packet[1])
240	req := packet[2:]
241	if len(req) != reqLen {
242		return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req))
243	}
244
245	var parser dnsmessage.Parser
246	h, err := parser.Start(req)
247	if err != nil {
248		// TODO: hook
249		return 0, err
250	}
251	q, err := parser.Question()
252	hadQ := (err == nil)
253	if err == nil && a.h.Question != nil {
254		a.h.Question(h, q)
255	}
256	if err != nil && err != dnsmessage.ErrSectionDone {
257		return 0, err
258	}
259
260	resh := h
261	resh.Response = true
262	resh.Authoritative = true
263	if hadQ {
264		resh.RCode = dnsmessage.RCodeSuccess
265	} else {
266		resh.RCode = dnsmessage.RCodeNotImplemented
267	}
268	a.rbuf.Grow(514)
269	a.rbuf.WriteByte('X') // reserved header for beu16 length
270	a.rbuf.WriteByte('Y') // reserved header for beu16 length
271	builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh)
272	a.builder = &builder
273	if hadQ {
274		a.q = q
275		a.builder.StartQuestions()
276		err := a.builder.Question(q)
277		if err != nil {
278			return 0, fmt.Errorf("Question: %w", err)
279		}
280		a.builder.StartAnswers()
281		switch q.Type {
282		case dnsmessage.TypeA:
283			if a.h.HandleA != nil {
284				resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String()))
285			}
286		case dnsmessage.TypeAAAA:
287			if a.h.HandleAAAA != nil {
288				resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String()))
289			}
290		case dnsmessage.TypeSRV:
291			if a.h.HandleSRV != nil {
292				resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String()))
293			}
294		}
295	}
296	tcpRes, err := builder.Finish()
297	if err != nil {
298		return 0, fmt.Errorf("Finish: %w", err)
299	}
300
301	n = len(tcpRes) - 2
302	tcpRes[0] = byte(n >> 8)
303	tcpRes[1] = byte(n)
304	a.rbuf.Write(tcpRes[2:])
305
306	return len(packet), nil
307}
308
// someaddr is the placeholder Addr reported by resolverFuncConn for both
// ends of the in-memory connection.
type someaddr struct{}

// Network implements Addr with a fixed placeholder name.
func (someaddr) Network() string {
	return "unused"
}

// String implements Addr with a fixed placeholder text.
func (someaddr) String() string {
	return "unused-someaddr"
}
313
314func mapRCode(err error) dnsmessage.RCode {
315	switch err {
316	case nil:
317		return dnsmessage.RCodeSuccess
318	case ErrNotExist:
319		return dnsmessage.RCodeNameError
320	case ErrRefused:
321		return dnsmessage.RCodeRefused
322	default:
323		return dnsmessage.RCodeServerFailure
324	}
325}
326