mlisting

Mailing list service

git clone git://git.lin.moe/go/mlisting.git

 1package sqlite
 2
 3import (
 4	"context"
 5	"database/sql"
 6	"embed"
 7	"fmt"
 8	"io"
 9)
10
11var (
12	//go:embed *.sql
13	SQLFiles   embed.FS
14	Migrations []func(context.Context, *sql.Tx) error
15)
16
17func init() {
18	Migrations = []func(context.Context, *sql.Tx) error{
19		fromFile("000_init.sql"),
20		fromFile("001_request_table.sql"),
21		migrateAddMessageInReplyTo,
22	}
23}
24
25func fromFile(fname string) func(context.Context, *sql.Tx) error {
26	sql_file, err := SQLFiles.Open(fname)
27	if err != nil {
28		panic(err)
29	}
30	defer sql_file.Close()
31
32	raw_sql, err := io.ReadAll(sql_file)
33	if err != nil {
34		panic(err)
35	}
36	_sql := string(raw_sql)
37
38	return func(ctx context.Context, tx *sql.Tx) error {
39		_, err := tx.ExecContext(ctx, _sql)
40		return err
41	}
42}
43
44func migrate(ctx context.Context, db *sql.DB) (err error) {
45	var (
46		version, maxVersion int64
47	)
48
49	db.QueryRowContext(ctx, "pragma user_version").Scan(&version)
50	maxVersion = int64(len(Migrations))
51
52	for v := version; v < maxVersion; v++ {
53		tx, err := db.BeginTx(ctx, nil)
54		if err != nil {
55			return err
56		}
57
58		if err = Migrations[v](ctx, tx); err != nil {
59			tx.Rollback()
60			return err
61		}
62		if _, err = tx.ExecContext(ctx, fmt.Sprintf("pragma user_version = %d", v+1)); err != nil {
63			tx.Rollback()
64			return err
65		}
66		if err = tx.Commit(); err != nil {
67			return err
68		}
69	}
70
71	return nil
72}