1// Copyright 2022 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// Test that Resolver.Dial can be a func returning an in-memory net.Conn 6// speaking DNS. 7 8package net 9 10import ( 11 "bytes" 12 "context" 13 "errors" 14 "fmt" 15 "reflect" 16 "slices" 17 "testing" 18 "time" 19 20 "golang.org/x/net/dns/dnsmessage" 21) 22 23func TestResolverDialFunc(t *testing.T) { 24 r := &Resolver{ 25 PreferGo: true, 26 Dial: newResolverDialFunc(&resolverDialHandler{ 27 StartDial: func(network, address string) error { 28 t.Logf("StartDial(%q, %q) ...", network, address) 29 return nil 30 }, 31 Question: func(h dnsmessage.Header, q dnsmessage.Question) { 32 t.Logf("Header: %+v for %q (type=%v, class=%v)", h, 33 q.Name.String(), q.Type, q.Class) 34 }, 35 // TODO: add test without HandleA* hooks specified at all, that Go 36 // doesn't issue retries; map to something terminal. 37 HandleA: func(w AWriter, name string) error { 38 w.AddIP([4]byte{1, 2, 3, 4}) 39 w.AddIP([4]byte{5, 6, 7, 8}) 40 return nil 41 }, 42 HandleAAAA: func(w AAAAWriter, name string) error { 43 w.AddIP([16]byte{1: 1, 15: 15}) 44 w.AddIP([16]byte{2: 2, 14: 14}) 45 return nil 46 }, 47 HandleSRV: func(w SRVWriter, name string) error { 48 w.AddSRV(1, 2, 80, "foo.bar.") 49 w.AddSRV(2, 3, 81, "bar.baz.") 50 return nil 51 }, 52 }), 53 } 54 ctx := context.Background() 55 const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld." 56 57 t.Run("LookupIP", func(t *testing.T) { 58 ips, err := r.LookupIP(ctx, "ip", fakeDomain) 59 if err != nil { 60 t.Fatal(err) 61 } 62 if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !reflect.DeepEqual(got, want) { 63 t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want) 64 } 65 }) 66 67 t.Run("LookupSRV", func(t *testing.T) { 68 _, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain) 69 if err != nil { 70 t.Fatal(err) 71 } 72 want := []*SRV{ 73 { 74 Target: "foo.bar.", 75 Port: 80, 76 Priority: 1, 77 Weight: 2, 78 }, 79 { 80 Target: "bar.baz.", 81 Port: 81, 82 Priority: 2, 83 Weight: 3, 84 }, 85 } 86 if !reflect.DeepEqual(got, want) { 87 t.Errorf("wrong result. got:") 88 for _, r := range got { 89 t.Logf(" - %+v", r) 90 } 91 } 92 }) 93} 94 95func sortedIPStrings(ips []IP) []string { 96 ret := make([]string, len(ips)) 97 for i, ip := range ips { 98 ret[i] = ip.String() 99 } 100 slices.Sort(ret) 101 return ret 102} 103 104func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) { 105 return func(ctx context.Context, network, address string) (Conn, error) { 106 a := &resolverFuncConn{ 107 h: h, 108 network: network, 109 address: address, 110 ttl: 10, // 10 second default if unset 111 } 112 if h.StartDial != nil { 113 if err := h.StartDial(network, address); err != nil { 114 return nil, err 115 } 116 } 117 return a, nil 118 } 119} 120 121type resolverDialHandler struct { 122 // StartDial, if non-nil, is called when Go first calls Resolver.Dial. 123 // Any error returned aborts the dial and is returned unwrapped. 124 StartDial func(network, address string) error 125 126 Question func(dnsmessage.Header, dnsmessage.Question) 127 128 // err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2). 129 // A nil error means success. 130 HandleA func(w AWriter, name string) error 131 HandleAAAA func(w AAAAWriter, name string) error 132 HandleSRV func(w SRVWriter, name string) error 133} 134 135type ResponseWriter struct{ a *resolverFuncConn } 136 137func (w ResponseWriter) header() dnsmessage.ResourceHeader { 138 q := w.a.q 139 return dnsmessage.ResourceHeader{ 140 Name: q.Name, 141 Type: q.Type, 142 Class: q.Class, 143 TTL: w.a.ttl, 144 } 145} 146 147// SetTTL sets the TTL for subsequent written resources. 148// Once a resource has been written, SetTTL calls are no-ops. 149// That is, it can only be called at most once, before anything 150// else is written. 151func (w ResponseWriter) SetTTL(seconds uint32) { 152 // ... intention is last one wins and mutates all previously 153 // written records too, but that's a little annoying. 154 // But it's also annoying if the requirement is it needs to be set 155 // last. 156 // And it's also annoying if it's possible for users to set 157 // different TTLs per Answer. 158 if w.a.wrote { 159 return 160 } 161 w.a.ttl = seconds 162 163} 164 165type AWriter struct{ ResponseWriter } 166 167func (w AWriter) AddIP(v4 [4]byte) { 168 w.a.wrote = true 169 err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4}) 170 if err != nil { 171 panic(err) 172 } 173} 174 175type AAAAWriter struct{ ResponseWriter } 176 177func (w AAAAWriter) AddIP(v6 [16]byte) { 178 w.a.wrote = true 179 err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6}) 180 if err != nil { 181 panic(err) 182 } 183} 184 185type SRVWriter struct{ ResponseWriter } 186 187// AddSRV adds a SRV record. The target name must end in a period and 188// be 63 bytes or fewer. 189func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error { 190 targetName, err := dnsmessage.NewName(target) 191 if err != nil { 192 return err 193 } 194 w.a.wrote = true 195 err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{ 196 Priority: priority, 197 Weight: weight, 198 Port: port, 199 Target: targetName, 200 }) 201 if err != nil { 202 panic(err) // internal fault, not user 203 } 204 return nil 205} 206 207var ( 208 ErrNotExist = errors.New("name does not exist") // maps to RCode3, NXDOMAIN 209 ErrRefused = errors.New("refused") // maps to RCode5, REFUSED 210) 211 212type resolverFuncConn struct { 213 h *resolverDialHandler 214 network string 215 address string 216 builder *dnsmessage.Builder 217 q dnsmessage.Question 218 ttl uint32 219 wrote bool 220 221 rbuf bytes.Buffer 222} 223 224func (*resolverFuncConn) Close() error { return nil } 225func (*resolverFuncConn) LocalAddr() Addr { return someaddr{} } 226func (*resolverFuncConn) RemoteAddr() Addr { return someaddr{} } 227func (*resolverFuncConn) SetDeadline(t time.Time) error { return nil } 228func (*resolverFuncConn) SetReadDeadline(t time.Time) error { return nil } 229func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil } 230 231func (a *resolverFuncConn) Read(p []byte) (n int, err error) { 232 return a.rbuf.Read(p) 233} 234 235func (a *resolverFuncConn) Write(packet []byte) (n int, err error) { 236 if len(packet) < 2 { 237 return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet)) 238 } 239 reqLen := int(packet[0])<<8 | int(packet[1]) 240 req := packet[2:] 241 if len(req) != reqLen { 242 return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req)) 243 } 244 245 var parser dnsmessage.Parser 246 h, err := parser.Start(req) 247 if err != nil { 248 // TODO: hook 249 return 0, err 250 } 251 q, err := parser.Question() 252 hadQ := (err == nil) 253 if err == nil && a.h.Question != nil { 254 a.h.Question(h, q) 255 } 256 if err != nil && err != dnsmessage.ErrSectionDone { 257 return 0, err 258 } 259 260 resh := h 261 resh.Response = true 262 resh.Authoritative = true 263 if hadQ { 264 resh.RCode = dnsmessage.RCodeSuccess 265 } else { 266 resh.RCode = dnsmessage.RCodeNotImplemented 267 } 268 a.rbuf.Grow(514) 269 a.rbuf.WriteByte('X') // reserved header for beu16 length 270 a.rbuf.WriteByte('Y') // reserved header for beu16 length 271 builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh) 272 a.builder = &builder 273 if hadQ { 274 a.q = q 275 a.builder.StartQuestions() 276 err := a.builder.Question(q) 277 if err != nil { 278 return 0, fmt.Errorf("Question: %w", err) 279 } 280 a.builder.StartAnswers() 281 switch q.Type { 282 case dnsmessage.TypeA: 283 if a.h.HandleA != nil { 284 resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String())) 285 } 286 case dnsmessage.TypeAAAA: 287 if a.h.HandleAAAA != nil { 288 resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String())) 289 } 290 case dnsmessage.TypeSRV: 291 if a.h.HandleSRV != nil { 292 resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String())) 293 } 294 } 295 } 296 tcpRes, err := builder.Finish() 297 if err != nil { 298 return 0, fmt.Errorf("Finish: %w", err) 299 } 300 301 n = len(tcpRes) - 2 302 tcpRes[0] = byte(n >> 8) 303 tcpRes[1] = byte(n) 304 a.rbuf.Write(tcpRes[2:]) 305 306 return len(packet), nil 307} 308 309type someaddr struct{} 310 311func (someaddr) Network() string { return "unused" } 312func (someaddr) String() string { return "unused-someaddr" } 313 314func mapRCode(err error) dnsmessage.RCode { 315 switch err { 316 case nil: 317 return dnsmessage.RCodeSuccess 318 case ErrNotExist: 319 return dnsmessage.RCodeNameError 320 case ErrRefused: 321 return dnsmessage.RCodeRefused 322 default: 323 return dnsmessage.RCodeServerFailure 324 } 325} 326