maddy

Fork of https://github.com/foxcpp/maddy

git clone git://git.lin.moe/go/maddy.git

  1package ldap
  2
  3import (
  4	"context"
  5	"crypto/tls"
  6	"fmt"
  7	"net"
  8	"net/url"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"github.com/foxcpp/maddy/framework/config"
 14	tls2 "github.com/foxcpp/maddy/framework/config/tls"
 15	"github.com/foxcpp/maddy/framework/log"
 16	"github.com/foxcpp/maddy/framework/module"
 17	"github.com/go-ldap/ldap/v3"
 18)
 19
 20const modName = "auth.ldap"
 21
 22type Auth struct {
 23	instName string
 24
 25	urls           []string
 26	readBind       func(*ldap.Conn) error
 27	startls        bool
 28	tlsCfg         tls.Config
 29	dialer         *net.Dialer
 30	requestTimeout time.Duration
 31
 32	dnTemplate string
 33	// or
 34	baseDN         string
 35	filterTemplate string
 36
 37	conn     *ldap.Conn
 38	connLock sync.Mutex
 39
 40	log log.Logger
 41}
 42
 43func New(modName, instName string, _, inlineArgs []string) (module.Module, error) {
 44	return &Auth{
 45		instName: instName,
 46		log:      log.Logger{Name: modName},
 47		urls:     inlineArgs,
 48	}, nil
 49}
 50
 51func (a *Auth) Init(cfg *config.Map) error {
 52	a.dialer = &net.Dialer{}
 53
 54	cfg.Bool("debug", true, false, &a.log.Debug)
 55	cfg.Custom("tls_client", true, false, func() (interface{}, error) {
 56		return tls.Config{}, nil
 57	}, tls2.TLSClientBlock, &a.tlsCfg)
 58	cfg.Callback("urls", func(m *config.Map, node config.Node) error {
 59		a.urls = append(a.urls, node.Args...)
 60		return nil
 61	})
 62	cfg.Custom("bind", false, false, func() (interface{}, error) {
 63		return func(*ldap.Conn) error {
 64			return nil
 65		}, nil
 66	}, readBindDirective, &a.readBind)
 67	cfg.Bool("starttls", false, false, &a.startls)
 68	cfg.Duration("connect_timeout", false, false, time.Minute, &a.dialer.Timeout)
 69	cfg.Duration("request_timeout", false, false, time.Minute, &a.requestTimeout)
 70	cfg.String("dn_template", false, false, "", &a.dnTemplate)
 71	cfg.String("base_dn", false, false, "", &a.baseDN)
 72	cfg.String("filter", false, false, "", &a.filterTemplate)
 73	if _, err := cfg.Process(); err != nil {
 74		return err
 75	}
 76
 77	if a.dnTemplate == "" {
 78		if a.baseDN == "" {
 79			return fmt.Errorf("auth.ldap: base_dn not set")
 80		}
 81		if a.filterTemplate == "" {
 82			return fmt.Errorf("auth.ldap: filter not set")
 83		}
 84	} else {
 85		if a.baseDN != "" || a.filterTemplate != "" {
 86			return fmt.Errorf("auth.ldap: search directives set when dn_template is used")
 87		}
 88	}
 89
 90	if module.NoRun {
 91		return nil
 92	}
 93
 94	var err error
 95	a.conn, err = a.newConn()
 96	if err != nil {
 97		return fmt.Errorf("auth.ldap: %w", err)
 98	}
 99	return nil
100}
101
102func readBindDirective(c *config.Map, n config.Node) (interface{}, error) {
103	if len(n.Args) == 0 {
104		return nil, fmt.Errorf("auth.ldap: auth expects at least one argument")
105	}
106	switch n.Args[0] {
107	case "off":
108		return func(*ldap.Conn) error { return nil }, nil
109	case "unauth":
110		if len(n.Args) == 2 {
111			return func(c *ldap.Conn) error {
112				return c.UnauthenticatedBind(n.Args[1])
113			}, nil
114		}
115		return func(c *ldap.Conn) error {
116			return c.UnauthenticatedBind("")
117		}, nil
118	case "plain":
119		if len(n.Args) != 3 {
120			return nil, fmt.Errorf("auth.ldap: username and password expected for plaintext bind")
121		}
122		return func(c *ldap.Conn) error {
123			return c.Bind(n.Args[1], n.Args[2])
124		}, nil
125	case "external":
126		return (*ldap.Conn).ExternalBind, nil
127	}
128	return nil, fmt.Errorf("auth.ldap: unknown bind authentication: %v", n.Args[0])
129}
130
131func (a *Auth) Name() string {
132	return modName
133}
134
135func (a *Auth) InstanceName() string {
136	return a.instName
137}
138
139func (a *Auth) newConn() (*ldap.Conn, error) {
140	var (
141		conn   *ldap.Conn
142		tlsCfg *tls.Config
143	)
144	for _, u := range a.urls {
145		parsedURL, err := url.Parse(u)
146		if err != nil {
147			return nil, fmt.Errorf("auth.ldap: invalid server URL: %w", err)
148		}
149		hostname := parsedURL.Host
150		a.tlsCfg.ServerName = strings.Split(hostname, ":")[0]
151		tlsCfg = a.tlsCfg.Clone()
152
153		conn, err = ldap.DialURL(u, ldap.DialWithDialer(a.dialer), ldap.DialWithTLSConfig(tlsCfg))
154		if err != nil {
155			a.log.Error("cannot contact directory server", err, "url", u)
156			continue
157		}
158		break
159	}
160	if conn == nil {
161		return nil, fmt.Errorf("auth.ldap: all directory servers are unreachable")
162	}
163
164	if a.requestTimeout != 0 {
165		conn.SetTimeout(a.requestTimeout)
166	}
167
168	if a.startls {
169		if err := conn.StartTLS(tlsCfg); err != nil {
170			return nil, fmt.Errorf("auth.ldap: %w", err)
171		}
172	}
173
174	if err := a.readBind(conn); err != nil {
175		return nil, fmt.Errorf("auth.ldap: %w", err)
176	}
177
178	return conn, nil
179}
180
181func (a *Auth) getConn() (*ldap.Conn, error) {
182	a.connLock.Lock()
183	if a.conn == nil {
184		conn, err := a.newConn()
185		if err != nil {
186			a.connLock.Unlock()
187			return nil, err
188		}
189		a.conn = conn
190	}
191	if a.conn.IsClosing() {
192		a.conn.Close()
193		conn, err := a.newConn()
194		if err != nil {
195			a.connLock.Unlock()
196			return nil, err
197		}
198		a.conn = conn
199	}
200	return a.conn, nil
201}
202
203func (a *Auth) returnConn(conn *ldap.Conn) {
204	defer a.connLock.Unlock()
205	if err := a.readBind(conn); err != nil {
206		a.log.Error("failed to rebind for reading", err)
207		conn.Close()
208		a.conn = nil
209	}
210	if a.conn != conn {
211		a.conn.Close()
212	}
213	a.conn = conn
214}
215
216func (a *Auth) Lookup(_ context.Context, username string) (string, bool, error) {
217	conn, err := a.getConn()
218	if err != nil {
219		return "", false, err
220	}
221	defer a.returnConn(conn)
222
223	var userDN string
224	if a.dnTemplate != "" {
225		return "", false, fmt.Errorf("auth.ldap: lookups require search config but dn_template is used")
226	} else {
227		req := ldap.NewSearchRequest(
228			a.baseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases,
229			2, 0, false,
230			strings.ReplaceAll(a.filterTemplate, "{username}", username),
231			[]string{"dn"}, nil)
232		res, err := conn.Search(req)
233		if err != nil {
234			return "", false, fmt.Errorf("auth.ldap: search: %w", err)
235		}
236		if len(res.Entries) > 1 {
237			return "", false, fmt.Errorf("auth.ldap: too manu entries returned (%d)", len(res.Entries))
238		}
239		if len(res.Entries) == 0 {
240			return "", false, nil
241		}
242		userDN = res.Entries[0].DN
243	}
244
245	return userDN, true, nil
246}
247
248func (a *Auth) AuthPlain(username, password string) error {
249	conn, err := a.getConn()
250	if err != nil {
251		return err
252	}
253	defer a.returnConn(conn)
254
255	var userDN string
256	if a.dnTemplate != "" {
257		userDN = strings.ReplaceAll(a.dnTemplate, "{username}", username)
258	} else {
259		req := ldap.NewSearchRequest(
260			a.baseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases,
261			2, 0, false,
262			strings.ReplaceAll(a.filterTemplate, "{username}", username),
263			[]string{"dn"}, nil)
264		res, err := conn.Search(req)
265		if err != nil {
266			return fmt.Errorf("auth.ldap: search: %w", err)
267		}
268		if len(res.Entries) > 1 {
269			return fmt.Errorf("auth.ldap: too manu entries returned (%d)", len(res.Entries))
270		}
271		if len(res.Entries) == 0 {
272			return module.ErrUnknownCredentials
273		}
274		userDN = res.Entries[0].DN
275	}
276
277	if err := conn.Bind(userDN, password); err != nil {
278		return module.ErrUnknownCredentials
279	}
280
281	return nil
282}
283
284func init() {
285	var _ module.PlainAuth = &Auth{}
286	var _ module.Table = &Auth{}
287	module.Register(modName, New)
288	module.Register("table.ldap", New)
289}