1/*2Maddy Mail Server - Composable all-in-one email server.3Copyright © 2019-2020 Max Mazurov <fox.cpp@disroot.org>, Maddy Mail Server contributors45This program is free software: you can redistribute it and/or modify6it under the terms of the GNU General Public License as published by7the Free Software Foundation, either version 3 of the License, or8(at your option) any later version.910This program is distributed in the hope that it will be useful,11but WITHOUT ANY WARRANTY; without even the implied warranty of12MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the13GNU General Public License for more details.1415You should have received a copy of the GNU General Public License16along with this program. If not, see <https://www.gnu.org/licenses/>.17*/1819package table2021import (22 "context"23 "database/sql"24 "fmt"25 "strings"2627 "github.com/foxcpp/maddy/framework/config"28 "github.com/foxcpp/maddy/framework/module"29 _ "github.com/lib/pq"30)3132type SQL struct {33 modName string34 instName string3536 namedArgs bool3738 db *sql.DB39 lookup *sql.Stmt40 add *sql.Stmt41 list *sql.Stmt42 set *sql.Stmt43 del *sql.Stmt44}4546func NewSQL(modName, instName string, _, _ []string) (module.Module, error) {47 return &SQL{48 modName: modName,49 instName: instName,50 }, nil51}5253func (s *SQL) Name() string {54 return s.modName55}5657func (s *SQL) InstanceName() string {58 return s.instName59}6061func (s *SQL) Init(cfg *config.Map) error {62 var (63 driver string64 initQueries []string65 dsnParts []string66 lookupQuery string6768 addQuery string69 listQuery string70 removeQuery string71 setQuery string72 )73 cfg.StringList("init", false, false, nil, &initQueries)74 cfg.String("driver", false, true, "", &driver)75 cfg.StringList("dsn", false, true, nil, &dsnParts)76 cfg.Bool("named_args", false, false, &s.namedArgs)7778 cfg.String("lookup", false, true, "", &lookupQuery)7980 cfg.String("add", false, false, "", &addQuery)81 cfg.String("list", false, false, "", &listQuery)82 cfg.String("del", false, false, "", &removeQuery)83 cfg.String("set", false, false, "", &setQuery)84 if _, err := cfg.Process(); err != nil {85 return err86 }8788 if driver == "postgres" && s.namedArgs {89 return config.NodeErr(cfg.Block, "PostgreSQL driver does not support named_args")90 }9192 db, err := sql.Open(driver, strings.Join(dsnParts, " "))93 if err != nil {94 return config.NodeErr(cfg.Block, "failed to open db: %v", err)95 }96 s.db = db9798 for _, init := range initQueries {99 if _, err := db.Exec(init); err != nil {100 return config.NodeErr(cfg.Block, "init query failed: %v", err)101 }102 }103104 s.lookup, err = db.Prepare(lookupQuery)105 if err != nil {106 return config.NodeErr(cfg.Block, "failed to prepare lookup query: %v", err)107 }108 if addQuery != "" {109 s.add, err = db.Prepare(addQuery)110 if err != nil {111 return config.NodeErr(cfg.Block, "failed to prepare add query: %v", err)112 }113 }114 if listQuery != "" {115 s.list, err = db.Prepare(listQuery)116 if err != nil {117 return config.NodeErr(cfg.Block, "failed to prepare list query: %v", err)118 }119 }120 if setQuery != "" {121 s.set, err = db.Prepare(setQuery)122 if err != nil {123 return config.NodeErr(cfg.Block, "failed to prepare set query: %v", err)124 }125 }126 if removeQuery != "" {127 s.del, err = db.Prepare(removeQuery)128 if err != nil {129 return config.NodeErr(cfg.Block, "failed to prepare del query: %v", err)130 }131 }132133 return nil134}135136func (s *SQL) Close() error {137 s.lookup.Close()138 return s.db.Close()139}140141func (s *SQL) Lookup(ctx context.Context, val string) (string, bool, error) {142 var (143 repl string144 row *sql.Row145 )146 if s.namedArgs {147 row = s.lookup.QueryRowContext(ctx, sql.Named("key", val))148 } else {149 row = s.lookup.QueryRowContext(ctx, val)150 }151 if err := row.Scan(&repl); err != nil {152 if err == sql.ErrNoRows {153 return "", false, nil154 }155 return "", false, fmt.Errorf("%s: lookup %s: %w", s.modName, val, err)156 }157 return repl, true, nil158}159160func (s *SQL) LookupMulti(ctx context.Context, val string) ([]string, error) {161 var (162 repl []string163 rows *sql.Rows164 err error165 )166 if s.namedArgs {167 rows, err = s.lookup.QueryContext(ctx, sql.Named("key", val))168 } else {169 rows, err = s.lookup.QueryContext(ctx, val)170 }171 if err != nil {172 return nil, fmt.Errorf("%s; lookup %s: %w", s.modName, val, err)173 }174 for rows.Next() {175 var res string176 if err := rows.Scan(&res); err != nil {177 return nil, fmt.Errorf("%s; lookup %s: %w", s.modName, val, err)178 }179 repl = append(repl, res)180 }181 if err := rows.Err(); err != nil {182 return nil, fmt.Errorf("%s; lookup %s: %w", s.modName, val, err)183 }184 return repl, nil185}186187func (s *SQL) Keys() ([]string, error) {188 if s.list == nil {189 return nil, fmt.Errorf("%s: table is not mutable (no 'list' query)", s.modName)190 }191192 rows, err := s.list.Query()193 if err != nil {194 return nil, fmt.Errorf("%s: list: %w", s.modName, err)195 }196 defer rows.Close()197 var list []string198 for rows.Next() {199 var key string200 if err := rows.Scan(&key); err != nil {201 return nil, fmt.Errorf("%s: list: %w", s.modName, err)202 }203 list = append(list, key)204 }205 return list, nil206}207208func (s *SQL) RemoveKey(k string) error {209 if s.del == nil {210 return fmt.Errorf("%s: table is not mutable (no 'del' query)", s.modName)211 }212213 var err error214 if s.namedArgs {215 _, err = s.del.Exec(sql.Named("key", k))216 } else {217 _, err = s.del.Exec(k)218 }219 if err != nil {220 return fmt.Errorf("%s: del %s: %w", s.modName, k, err)221 }222 return nil223}224225func (s *SQL) SetKey(k, v string) error {226 if s.set == nil {227 return fmt.Errorf("%s: table is not mutable (no 'set' query)", s.modName)228 }229 if s.add == nil {230 return fmt.Errorf("%s: table is not mutable (no 'add' query)", s.modName)231 }232233 var args []interface{}234 if s.namedArgs {235 args = []interface{}{sql.Named("key", k), sql.Named("value", v)}236 } else {237 args = []interface{}{k, v}238 }239240 if _, err := s.add.Exec(args...); err != nil {241 if _, err := s.set.Exec(args...); err != nil {242 return fmt.Errorf("%s: add %s: %w", s.modName, k, err)243 }244 return nil245 }246 return nil247}248249func init() {250 module.Register("table.sql_query", NewSQL)251}