maddy

Fork https://github.com/foxcpp/maddy

git clone git://git.lin.moe/go/maddy.git

 1package proxy_protocol
 2
 3import (
 4	"crypto/tls"
 5	"net"
 6	"strings"
 7
 8	"github.com/c0va23/go-proxyprotocol"
 9	"github.com/foxcpp/maddy/framework/config"
10	tls2 "github.com/foxcpp/maddy/framework/config/tls"
11	"github.com/foxcpp/maddy/framework/log"
12)
13
// ProxyProtocol is the parsed form of a proxy_protocol configuration
// directive, consumed by NewListener.
type ProxyProtocol struct {
	// trust lists the networks whose TCP peers are allowed to supply a PROXY
	// header. An empty list means every TCP peer is trusted.
	trust []net.IPNet

	// tlsConfig, when non-nil, makes NewListener wrap the PROXY-aware
	// listener in a TLS listener using this configuration.
	tlsConfig *tls.Config
}
18
19func ProxyProtocolDirective(_ *config.Map, node config.Node) (interface{}, error) {
20	p := ProxyProtocol{}
21
22	childM := config.NewMap(nil, node)
23	var trustList []string
24
25	childM.StringList("trust", false, false, nil, &trustList)
26	childM.Custom("tls", true, false, nil, tls2.TLSDirective, &p.tlsConfig)
27
28	if _, err := childM.Process(); err != nil {
29		return nil, err
30	}
31
32	if len(node.Args) > 0 {
33		if trustList == nil {
34			trustList = make([]string, 0)
35		}
36		trustList = append(trustList, node.Args...)
37	}
38
39	for _, trust := range trustList {
40		if !strings.Contains(trust, "/") {
41			trust += "/32"
42		}
43		_, ipNet, err := net.ParseCIDR(trust)
44		if err != nil {
45			return nil, err
46		}
47		p.trust = append(p.trust, *ipNet)
48	}
49
50	return &p, nil
51}
52
53func NewListener(inner net.Listener, p *ProxyProtocol, logger log.Logger) net.Listener {
54	var listener net.Listener
55
56	sourceChecker := func(upstream net.Addr) (bool, error) {
57		if tcpAddr, ok := upstream.(*net.TCPAddr); ok {
58			if len(p.trust) == 0 {
59				return true, nil
60			}
61			for _, trusted := range p.trust {
62				if trusted.Contains(tcpAddr.IP) {
63					return true, nil
64				}
65			}
66		} else if _, ok := upstream.(*net.UnixAddr); ok {
67			// UNIX local socket connection, always trusted
68			return true, nil
69		}
70
71		logger.Printf("proxy_protocol: connection from untrusted source %s", upstream)
72		return false, nil
73	}
74
75	listener = proxyprotocol.NewDefaultListener(inner).
76		WithLogger(proxyprotocol.LoggerFunc(func(format string, v ...interface{}) {
77			logger.Debugf("proxy_protocol: "+format, v...)
78		})).
79		WithSourceChecker(sourceChecker)
80
81	if p.tlsConfig != nil {
82		listener = tls.NewListener(listener, p.tlsConfig)
83	}
84
85	return listener
86}