1package dns23import (4 "context"5 "fmt"6 "net"7 "reflect"8 "strconv"9 "testing"10 "time"1112 "github.com/foxcpp/maddy/framework/log"13 "github.com/miekg/dns"14)1516type TestSrvAction int1718const (19 TestSrvTimeout TestSrvAction = iota20 TestSrvServfail21 TestSrvNoAddr22 TestSrvOk23)2425func (a TestSrvAction) String() string {26 switch a {27 case TestSrvTimeout:28 return "SrvTimeout"29 case TestSrvServfail:30 return "SrvServfail"31 case TestSrvNoAddr:32 return "SrvNoAddr"33 case TestSrvOk:34 return "SrvOk"35 default:36 panic("wtf action")37 }38}3940type IPAddrTestServer struct {41 udpServ dns.Server42 aAction TestSrvAction43 aAD bool44 aaaaAction TestSrvAction45 aaaaAD bool46}4748func (s *IPAddrTestServer) Run() {49 pconn, err := net.ListenPacket("udp4", "127.0.0.1:0")50 if err != nil {51 panic(err)52 }53 s.udpServ.PacketConn = pconn54 s.udpServ.Handler = s55 go s.udpServ.ActivateAndServe() //nolint:errcheck56}5758func (s *IPAddrTestServer) Close() {59 s.udpServ.PacketConn.Close()60}6162func (s *IPAddrTestServer) Addr() *net.UDPAddr {63 return s.udpServ.PacketConn.LocalAddr().(*net.UDPAddr)64}6566func (s *IPAddrTestServer) ServeDNS(w dns.ResponseWriter, m *dns.Msg) {67 q := m.Question[0]6869 var (70 act TestSrvAction71 ad bool72 )73 switch q.Qtype {74 case dns.TypeA:75 act = s.aAction76 ad = s.aAD77 case dns.TypeAAAA:78 act = s.aaaaAction79 ad = s.aaaaAD80 default:81 panic("wtf qtype")82 }8384 reply := new(dns.Msg)85 reply.SetReply(m)86 reply.RecursionAvailable = true87 reply.AuthenticatedData = ad8889 switch act {90 case TestSrvTimeout:91 return // no nobody heard from him since...92 case TestSrvServfail:93 reply.Rcode = dns.RcodeServerFailure94 case TestSrvNoAddr:95 case TestSrvOk:96 switch q.Qtype {97 case dns.TypeA:98 reply.Answer = append(reply.Answer, &dns.A{99 Hdr: dns.RR_Header{100 Name: q.Name,101 Rrtype: dns.TypeA,102 Class: dns.ClassINET,103 Ttl: 9999,104 },105 A: net.ParseIP("127.0.0.1"),106 })107 case dns.TypeAAAA:108 reply.Answer = append(reply.Answer, &dns.AAAA{109 Hdr: dns.RR_Header{110 Name: q.Name,111 Rrtype: dns.TypeAAAA,112 Class: dns.ClassINET,113 Ttl: 9999,114 },115 AAAA: net.ParseIP("::1"),116 })117 }118 }119120 if err := w.WriteMsg(reply); err != nil {121 panic(err)122 }123}124125func TestExtResolver_AuthLookupIPAddr(t *testing.T) {126 // AuthLookupIPAddr has a rather convoluted logic for combined A/AAAA127 // lookups that return the best-effort result and also has some nuanced in128 // AD flag handling for use in DANE algorithms.129130 // Silence log messages about disregarded I/O errors.131 log.DefaultLogger.Out = nil132133 test := func(aAct, aaaaAct TestSrvAction, aAD, aaaaAD, ad bool, addrs []net.IP, err bool) {134 t.Helper()135 t.Run(fmt.Sprintln(aAct, aaaaAct, aAD, aaaaAD), func(t *testing.T) {136 t.Helper()137138 s := IPAddrTestServer{}139 s.aAction = aAct140 s.aaaaAction = aaaaAct141 s.aAD = aAD142 s.aaaaAD = aaaaAD143 s.Run()144 defer s.Close()145 res := ExtResolver{146 cl: new(dns.Client),147 Cfg: &dns.ClientConfig{148 Servers: []string{"127.0.0.1"},149 Port: strconv.Itoa(s.Addr().Port),150 Timeout: 1,151 },152 }153 res.cl.Dialer = &net.Dialer{154 Timeout: 500 * time.Millisecond,155 }156157 ctx, cancel := context.WithCancel(context.Background())158 defer cancel()159160 actualAd, actualAddrs, actualErr := res.AuthLookupIPAddr(ctx, "maddy.test")161 if (actualErr != nil) != err {162 t.Fatal("actualErr:", actualErr, "expectedErr:", err)163 }164 if actualAd != ad {165 t.Error("actualAd:", actualAd, "expectedAd:", ad)166 }167 ipAddrs := make([]net.IPAddr, 0, len(addrs))168 if len(addrs) == 0 {169 ipAddrs = nil // lookup returns nil addrs for error cases170 }171 for _, a := range addrs {172 ipAddrs = append(ipAddrs, net.IPAddr{IP: a, Zone: ""})173 }174 if !reflect.DeepEqual(actualAddrs, ipAddrs) {175 t.Logf("actualAddrs: %#+v", actualAddrs)176 t.Logf("addrs: %#+v", ipAddrs)177 t.Fail()178 }179 })180 }181182 test(TestSrvOk, TestSrvOk, true, true, true, []net.IP{net.ParseIP("::1"), net.ParseIP("127.0.0.1").To4()}, false)183 test(TestSrvOk, TestSrvOk, true, false, true, []net.IP{net.ParseIP("127.0.0.1").To4()}, false)184 test(TestSrvOk, TestSrvOk, false, true, false, []net.IP{net.ParseIP("::1"), net.ParseIP("127.0.0.1").To4()}, false)185 test(TestSrvOk, TestSrvOk, false, false, false, []net.IP{net.ParseIP("::1"), net.ParseIP("127.0.0.1").To4()}, false)186 test(TestSrvOk, TestSrvTimeout, true, true, true, []net.IP{net.ParseIP("127.0.0.1").To4()}, false)187 test(TestSrvOk, TestSrvServfail, true, true, true, []net.IP{net.ParseIP("127.0.0.1").To4()}, false)188 test(TestSrvOk, TestSrvNoAddr, true, true, true, []net.IP{net.ParseIP("127.0.0.1").To4()}, false)189 test(TestSrvNoAddr, TestSrvOk, true, true, true, []net.IP{net.ParseIP("::1")}, false)190 test(TestSrvServfail, TestSrvServfail, true, true, false, nil, true)191192 // actualAd is false, we don't want to risk reporting positive AD result if193 // something is wrong with IPv4 lookup.194 test(TestSrvTimeout, TestSrvOk, true, true, false, []net.IP{net.ParseIP("::1")}, false)195 test(TestSrvServfail, TestSrvOk, true, true, false, []net.IP{net.ParseIP("::1")}, false)196}