maddy

Fork of https://github.com/foxcpp/maddy

git clone git://git.lin.moe/go/maddy.git

  1/*
  2Maddy Mail Server - Composable all-in-one email server.
  3Copyright © 2019-2020 Max Mazurov <fox.cpp@disroot.org>, Maddy Mail Server contributors
  4
  5This program is free software: you can redistribute it and/or modify
  6it under the terms of the GNU General Public License as published by
  7the Free Software Foundation, either version 3 of the License, or
  8(at your option) any later version.
  9
 10This program is distributed in the hope that it will be useful,
 11but WITHOUT ANY WARRANTY; without even the implied warranty of
 12MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 13GNU General Public License for more details.
 14
 15You should have received a copy of the GNU General Public License
 16along with this program.  If not, see <https://www.gnu.org/licenses/>.
 17*/
 18
 19package tls
 20
 21import (
 22	"crypto/tls"
 23	"errors"
 24	"fmt"
 25	"path/filepath"
 26	"sync"
 27	"time"
 28
 29	"github.com/foxcpp/maddy/framework/config"
 30	"github.com/foxcpp/maddy/framework/hooks"
 31	"github.com/foxcpp/maddy/framework/log"
 32	"github.com/foxcpp/maddy/framework/module"
 33)
 34
 35type FileLoader struct {
 36	instName   string
 37	inlineArgs []string
 38	certPaths  []string
 39	keyPaths   []string
 40	log        log.Logger
 41
 42	certs     []tls.Certificate
 43	certsLock sync.RWMutex
 44
 45	reloadTick *time.Ticker
 46	stopTick   chan struct{}
 47}
 48
 49func NewFileLoader(_, instName string, _, inlineArgs []string) (module.Module, error) {
 50	return &FileLoader{
 51		instName:   instName,
 52		inlineArgs: inlineArgs,
 53		log:        log.Logger{Name: "tls.loader.file", Debug: log.DefaultLogger.Debug},
 54		stopTick:   make(chan struct{}),
 55	}, nil
 56}
 57
 58func (f *FileLoader) Init(cfg *config.Map) error {
 59	cfg.StringList("certs", false, false, nil, &f.certPaths)
 60	cfg.StringList("keys", false, false, nil, &f.keyPaths)
 61	if _, err := cfg.Process(); err != nil {
 62		return err
 63	}
 64
 65	if len(f.certPaths) != len(f.keyPaths) {
 66		return errors.New("tls.loader.file: mismatch in certs and keys count")
 67	}
 68
 69	if len(f.inlineArgs)%2 != 0 {
 70		return errors.New("tls.loader.file: odd amount of arguments")
 71	}
 72	for i := 0; i < len(f.inlineArgs); i += 2 {
 73		f.certPaths = append(f.certPaths, f.inlineArgs[i])
 74		f.keyPaths = append(f.keyPaths, f.inlineArgs[i+1])
 75	}
 76
 77	for _, certPath := range f.certPaths {
 78		if !filepath.IsAbs(certPath) {
 79			return fmt.Errorf("tls.loader.file: only absolute paths allowed in certificate paths: sorry :(")
 80		}
 81	}
 82
 83	if err := f.loadCerts(); err != nil {
 84		return err
 85	}
 86
 87	hooks.AddHook(hooks.EventReload, func() {
 88		f.log.Println("reloading certificates")
 89		if err := f.loadCerts(); err != nil {
 90			f.log.Error("reload failed", err)
 91		}
 92	})
 93
 94	f.reloadTick = time.NewTicker(time.Minute)
 95	go f.reloadTicker()
 96	return nil
 97}
 98
 99func (f *FileLoader) Close() error {
100	f.reloadTick.Stop()
101	f.stopTick <- struct{}{}
102	return nil
103}
104
105func (f *FileLoader) Name() string {
106	return "tls.loader.file"
107}
108
109func (f *FileLoader) InstanceName() string {
110	return f.instName
111}
112
113func (f *FileLoader) reloadTicker() {
114	for {
115		select {
116		case <-f.reloadTick.C:
117			f.log.Debugln("reloading certs")
118			if err := f.loadCerts(); err != nil {
119				f.log.Error("reload failed", err)
120			}
121		case <-f.stopTick:
122			return
123		}
124	}
125}
126
127func (f *FileLoader) loadCerts() error {
128	if len(f.certPaths) != len(f.keyPaths) {
129		return errors.New("mismatch in certs and keys count")
130	}
131
132	if len(f.certPaths) == 0 {
133		return errors.New("tls.loader.file: at least one certificate required")
134	}
135
136	certs := make([]tls.Certificate, 0, len(f.certPaths))
137
138	for i := range f.certPaths {
139		certPath := f.certPaths[i]
140		keyPath := f.keyPaths[i]
141
142		cert, err := tls.LoadX509KeyPair(certPath, keyPath)
143		if err != nil {
144			return fmt.Errorf("failed to load %s and %s: %v", certPath, keyPath, err)
145		}
146		certs = append(certs, cert)
147	}
148
149	f.certsLock.Lock()
150	defer f.certsLock.Unlock()
151	f.certs = certs
152
153	return nil
154}
155
156func (f *FileLoader) ConfigureTLS(c *tls.Config) error {
157	// Loader function replaces only the whole slice.
158	f.certsLock.RLock()
159	defer f.certsLock.RUnlock()
160
161	c.Certificates = f.certs
162	return nil
163}
164
165func init() {
166	var _ module.TLSLoader = &FileLoader{}
167	module.Register("tls.loader.file", NewFileLoader)
168}