maddy

Fork https://github.com/foxcpp/maddy

git clone git://git.lin.moe/go/maddy.git

  1/*
  2Maddy Mail Server - Composable all-in-one email server.
  3Copyright © 2019-2020 Max Mazurov <fox.cpp@disroot.org>, Maddy Mail Server contributors
  4
  5This program is free software: you can redistribute it and/or modify
  6it under the terms of the GNU General Public License as published by
  7the Free Software Foundation, either version 3 of the License, or
  8(at your option) any later version.
  9
 10This program is distributed in the hope that it will be useful,
 11but WITHOUT ANY WARRANTY; without even the implied warranty of
 12MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 13GNU General Public License for more details.
 14
 15You should have received a copy of the GNU General Public License
 16along with this program.  If not, see <https://www.gnu.org/licenses/>.
 17*/
 18
// Package limits provides a module object that can be used to restrict the
// concurrency and rate of the message flow globally or on a per-IP,
// per-source or per-destination basis.
 22//
 23// Note, all domain inputs are interpreted with the assumption they are already
 24// normalized.
 25//
 26// Low-level components are available in the limiters/ subpackage.
 27package limits
 28
 29import (
 30	"context"
 31	"net"
 32	"strconv"
 33	"time"
 34
 35	"github.com/foxcpp/maddy/framework/config"
 36	"github.com/foxcpp/maddy/framework/module"
 37	"github.com/foxcpp/maddy/internal/limits/limiters"
 38)
 39
 40type Group struct {
 41	instName string
 42
 43	global limiters.MultiLimit
 44	ip     *limiters.BucketSet // BucketSet of MultiLimit
 45	source *limiters.BucketSet // BucketSet of MultiLimit
 46	dest   *limiters.BucketSet // BucketSet of MultiLimit
 47}
 48
 49func New(_, instName string, _, _ []string) (module.Module, error) {
 50	return &Group{
 51		instName: instName,
 52	}, nil
 53}
 54
 55func (g *Group) Init(cfg *config.Map) error {
 56	var (
 57		globalL []limiters.L
 58		ipL     []func() limiters.L
 59		sourceL []func() limiters.L
 60		destL   []func() limiters.L
 61	)
 62
 63	for _, child := range cfg.Block.Children {
 64		if len(child.Args) < 1 {
 65			return config.NodeErr(child, "at least two arguments are required")
 66		}
 67
 68		var (
 69			ctor func() limiters.L
 70			err  error
 71		)
 72		switch kind := child.Args[0]; kind {
 73		case "rate":
 74			ctor, err = rateCtor(child, child.Args[1:])
 75		case "concurrency":
 76			ctor, err = concurrencyCtor(child, child.Args[1:])
 77		default:
 78			return config.NodeErr(child, "unknown limit kind: %v", kind)
 79		}
 80		if err != nil {
 81			return err
 82		}
 83
 84		switch scope := child.Name; scope {
 85		case "all":
 86			globalL = append(globalL, ctor())
 87		case "ip":
 88			ipL = append(ipL, ctor)
 89		case "source":
 90			sourceL = append(sourceL, ctor)
 91		case "destination":
 92			destL = append(destL, ctor)
 93		default:
 94			return config.NodeErr(child, "unknown limit scope: %v", scope)
 95		}
 96	}
 97
 98	// 20010 is slightly higher than the default max. recipients count in
 99	// endpoint/smtp.
100	g.global = limiters.MultiLimit{Wrapped: globalL}
101	if len(ipL) != 0 {
102		g.ip = limiters.NewBucketSet(func() limiters.L {
103			l := make([]limiters.L, 0, len(ipL))
104			for _, ctor := range ipL {
105				l = append(l, ctor())
106			}
107			return &limiters.MultiLimit{Wrapped: l}
108		}, 1*time.Minute, 20010)
109	}
110	if len(sourceL) != 0 {
111		g.source = limiters.NewBucketSet(func() limiters.L {
112			l := make([]limiters.L, 0, len(sourceL))
113			for _, ctor := range sourceL {
114				l = append(l, ctor())
115			}
116			return &limiters.MultiLimit{Wrapped: l}
117		}, 1*time.Minute, 20010)
118	}
119	if len(destL) != 0 {
120		g.dest = limiters.NewBucketSet(func() limiters.L {
121			l := make([]limiters.L, 0, len(sourceL))
122			for _, ctor := range sourceL {
123				l = append(l, ctor())
124			}
125			return &limiters.MultiLimit{Wrapped: l}
126		}, 1*time.Minute, 20010)
127	}
128
129	return nil
130}
131
132func rateCtor(node config.Node, args []string) (func() limiters.L, error) {
133	period := 1 * time.Second
134	burst := 0
135
136	switch len(args) {
137	case 2:
138		var err error
139		period, err = time.ParseDuration(args[1])
140		if err != nil {
141			return nil, config.NodeErr(node, "%v", err)
142		}
143		fallthrough
144	case 1:
145		var err error
146		burst, err = strconv.Atoi(args[0])
147		if err != nil {
148			return nil, config.NodeErr(node, "%v", err)
149		}
150	case 0:
151		return nil, config.NodeErr(node, "at least burst size is needed")
152	default:
153		return nil, config.NodeErr(node, "too many arguments")
154	}
155
156	return func() limiters.L {
157		return limiters.NewRate(burst, period)
158	}, nil
159}
160
161func concurrencyCtor(node config.Node, args []string) (func() limiters.L, error) {
162	if len(args) != 1 {
163		return nil, config.NodeErr(node, "max concurrency value is needed")
164	}
165	max, err := strconv.Atoi(args[0])
166	if err != nil {
167		return nil, config.NodeErr(node, "%v", err)
168	}
169	return func() limiters.L {
170		return limiters.NewSemaphore(max)
171	}, nil
172}
173
174func (g *Group) TakeMsg(ctx context.Context, addr net.IP, sourceDomain string) error {
175	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
176	defer cancel()
177
178	if err := g.global.TakeContext(ctx); err != nil {
179		return err
180	}
181
182	if g.ip != nil {
183		if err := g.ip.TakeContext(ctx, addr.String()); err != nil {
184			g.global.Release()
185			return err
186		}
187	}
188	if g.source != nil {
189		if err := g.source.TakeContext(ctx, sourceDomain); err != nil {
190			g.global.Release()
191			g.ip.Release(addr.String())
192			return err
193		}
194	}
195	return nil
196}
197
198func (g *Group) TakeDest(ctx context.Context, domain string) error {
199	if g.dest == nil {
200		return nil
201	}
202	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
203	defer cancel()
204	return g.dest.TakeContext(ctx, domain)
205}
206
207func (g *Group) ReleaseMsg(addr net.IP, sourceDomain string) {
208	g.global.Release()
209	if g.ip != nil {
210		g.ip.Release(addr.String())
211	}
212	if g.source != nil {
213		g.source.Release(sourceDomain)
214	}
215}
216
217func (g *Group) ReleaseDest(domain string) {
218	if g.dest == nil {
219		return
220	}
221	g.dest.Release(domain)
222}
223
224func (g *Group) Name() string {
225	return "limits"
226}
227
228func (g *Group) InstanceName() string {
229	return g.instName
230}
231
232func init() {
233	module.Register("limits", New)
234}