maddy

Fork of https://github.com/foxcpp/maddy

git clone git://git.lin.moe/go/maddy.git

  1package dns
  2
  3import (
  4	"context"
  5	"fmt"
  6	"net"
  7	"reflect"
  8	"strconv"
  9	"testing"
 10	"time"
 11
 12	"github.com/foxcpp/maddy/framework/log"
 13	"github.com/miekg/dns"
 14)
 15
// TestSrvAction selects how IPAddrTestServer answers a query of one record
// type.
type TestSrvAction int

const (
	TestSrvTimeout TestSrvAction = iota
	TestSrvServfail
	TestSrvNoAddr
	TestSrvOk
)

// String returns a short human-readable name for the action (used in subtest
// names). It panics with "wtf action" for values outside the declared set.
func (a TestSrvAction) String() string {
	names := [...]string{
		TestSrvTimeout:  "SrvTimeout",
		TestSrvServfail: "SrvServfail",
		TestSrvNoAddr:   "SrvNoAddr",
		TestSrvOk:       "SrvOk",
	}
	if a < 0 || int(a) >= len(names) {
		panic("wtf action")
	}
	return names[a]
}
 39
// IPAddrTestServer is a minimal scripted DNS server for tests: it answers A
// and AAAA queries according to the configured actions and AD flags.
type IPAddrTestServer struct {
	// udpServ is the underlying miekg/dns server; Run binds it to a UDP
	// socket on 127.0.0.1 and installs the IPAddrTestServer as its handler.
	udpServ dns.Server
	// aAction decides how A queries are answered.
	aAction TestSrvAction
	// aAD is copied into the AuthenticatedData flag of A replies.
	aAD bool
	// aaaaAction decides how AAAA queries are answered.
	aaaaAction TestSrvAction
	// aaaaAD is copied into the AuthenticatedData flag of AAAA replies.
	aaaaAD bool
}
 47
 48func (s *IPAddrTestServer) Run() {
 49	pconn, err := net.ListenPacket("udp4", "127.0.0.1:0")
 50	if err != nil {
 51		panic(err)
 52	}
 53	s.udpServ.PacketConn = pconn
 54	s.udpServ.Handler = s
 55	go s.udpServ.ActivateAndServe() //nolint:errcheck
 56}
 57
 58func (s *IPAddrTestServer) Close() {
 59	s.udpServ.PacketConn.Close()
 60}
 61
 62func (s *IPAddrTestServer) Addr() *net.UDPAddr {
 63	return s.udpServ.PacketConn.LocalAddr().(*net.UDPAddr)
 64}
 65
 66func (s *IPAddrTestServer) ServeDNS(w dns.ResponseWriter, m *dns.Msg) {
 67	q := m.Question[0]
 68
 69	var (
 70		act TestSrvAction
 71		ad  bool
 72	)
 73	switch q.Qtype {
 74	case dns.TypeA:
 75		act = s.aAction
 76		ad = s.aAD
 77	case dns.TypeAAAA:
 78		act = s.aaaaAction
 79		ad = s.aaaaAD
 80	default:
 81		panic("wtf qtype")
 82	}
 83
 84	reply := new(dns.Msg)
 85	reply.SetReply(m)
 86	reply.RecursionAvailable = true
 87	reply.AuthenticatedData = ad
 88
 89	switch act {
 90	case TestSrvTimeout:
 91		return // no nobody heard from him since...
 92	case TestSrvServfail:
 93		reply.Rcode = dns.RcodeServerFailure
 94	case TestSrvNoAddr:
 95	case TestSrvOk:
 96		switch q.Qtype {
 97		case dns.TypeA:
 98			reply.Answer = append(reply.Answer, &dns.A{
 99				Hdr: dns.RR_Header{
100					Name:   q.Name,
101					Rrtype: dns.TypeA,
102					Class:  dns.ClassINET,
103					Ttl:    9999,
104				},
105				A: net.ParseIP("127.0.0.1"),
106			})
107		case dns.TypeAAAA:
108			reply.Answer = append(reply.Answer, &dns.AAAA{
109				Hdr: dns.RR_Header{
110					Name:   q.Name,
111					Rrtype: dns.TypeAAAA,
112					Class:  dns.ClassINET,
113					Ttl:    9999,
114				},
115				AAAA: net.ParseIP("::1"),
116			})
117		}
118	}
119
120	if err := w.WriteMsg(reply); err != nil {
121		panic(err)
122	}
123}
124
125func TestExtResolver_AuthLookupIPAddr(t *testing.T) {
126	// AuthLookupIPAddr has a rather convoluted logic for combined A/AAAA
127	// lookups that return the best-effort result and also has some nuanced in
128	// AD flag handling for use in DANE algorithms.
129
130	// Silence log messages about disregarded I/O errors.
131	log.DefaultLogger.Out = nil
132
133	test := func(aAct, aaaaAct TestSrvAction, aAD, aaaaAD, ad bool, addrs []net.IP, err bool) {
134		t.Helper()
135		t.Run(fmt.Sprintln(aAct, aaaaAct, aAD, aaaaAD), func(t *testing.T) {
136			t.Helper()
137
138			s := IPAddrTestServer{}
139			s.aAction = aAct
140			s.aaaaAction = aaaaAct
141			s.aAD = aAD
142			s.aaaaAD = aaaaAD
143			s.Run()
144			defer s.Close()
145			res := ExtResolver{
146				cl: new(dns.Client),
147				Cfg: &dns.ClientConfig{
148					Servers: []string{"127.0.0.1"},
149					Port:    strconv.Itoa(s.Addr().Port),
150					Timeout: 1,
151				},
152			}
153			res.cl.Dialer = &net.Dialer{
154				Timeout: 500 * time.Millisecond,
155			}
156
157			ctx, cancel := context.WithCancel(context.Background())
158			defer cancel()
159
160			actualAd, actualAddrs, actualErr := res.AuthLookupIPAddr(ctx, "maddy.test")
161			if (actualErr != nil) != err {
162				t.Fatal("actualErr:", actualErr, "expectedErr:", err)
163			}
164			if actualAd != ad {
165				t.Error("actualAd:", actualAd, "expectedAd:", ad)
166			}
167			ipAddrs := make([]net.IPAddr, 0, len(addrs))
168			if len(addrs) == 0 {
169				ipAddrs = nil // lookup returns nil addrs for error cases
170			}
171			for _, a := range addrs {
172				ipAddrs = append(ipAddrs, net.IPAddr{IP: a, Zone: ""})
173			}
174			if !reflect.DeepEqual(actualAddrs, ipAddrs) {
175				t.Logf("actualAddrs: %#+v", actualAddrs)
176				t.Logf("addrs: %#+v", ipAddrs)
177				t.Fail()
178			}
179		})
180	}
181
182	test(TestSrvOk, TestSrvOk, true, true, true, []net.IP{net.ParseIP("::1"), net.ParseIP("127.0.0.1").To4()}, false)
183	test(TestSrvOk, TestSrvOk, true, false, true, []net.IP{net.ParseIP("127.0.0.1").To4()}, false)
184	test(TestSrvOk, TestSrvOk, false, true, false, []net.IP{net.ParseIP("::1"), net.ParseIP("127.0.0.1").To4()}, false)
185	test(TestSrvOk, TestSrvOk, false, false, false, []net.IP{net.ParseIP("::1"), net.ParseIP("127.0.0.1").To4()}, false)
186	test(TestSrvOk, TestSrvTimeout, true, true, true, []net.IP{net.ParseIP("127.0.0.1").To4()}, false)
187	test(TestSrvOk, TestSrvServfail, true, true, true, []net.IP{net.ParseIP("127.0.0.1").To4()}, false)
188	test(TestSrvOk, TestSrvNoAddr, true, true, true, []net.IP{net.ParseIP("127.0.0.1").To4()}, false)
189	test(TestSrvNoAddr, TestSrvOk, true, true, true, []net.IP{net.ParseIP("::1")}, false)
190	test(TestSrvServfail, TestSrvServfail, true, true, false, nil, true)
191
192	// actualAd is false, we don't want to risk reporting positive AD result if
193	// something is wrong with IPv4 lookup.
194	test(TestSrvTimeout, TestSrvOk, true, true, false, []net.IP{net.ParseIP("::1")}, false)
195	test(TestSrvServfail, TestSrvOk, true, true, false, []net.IP{net.ParseIP("::1")}, false)
196}