mlisting

Mailing list service

git clone git://git.lin.moe/go/mlisting.git

 1package sqlite
 2
 3import (
 4	"context"
 5	"database/sql"
 6	"embed"
 7	"fmt"
 8	"io"
 9)
10
11var (
12	//go:embed *.sql
13	SQLFiles   embed.FS
14	Migrations []func(context.Context, *sql.Tx) error
15)
16
17func init() {
18	Migrations = []func(context.Context, *sql.Tx) error{
19		fromFile("000_init.sql"),
20		fromFile("001_request_table.sql"),
21	}
22}
23
24func fromFile(fname string) func(context.Context, *sql.Tx) error {
25	sql_file, err := SQLFiles.Open(fname)
26	if err != nil {
27		panic(err)
28	}
29	defer sql_file.Close()
30
31	raw_sql, err := io.ReadAll(sql_file)
32	if err != nil {
33		panic(err)
34	}
35	_sql := string(raw_sql)
36
37	return func(ctx context.Context, tx *sql.Tx) error {
38		_, err := tx.ExecContext(ctx, _sql)
39		return err
40	}
41}
42
43func migrate(ctx context.Context, db *sql.DB) (err error) {
44	var (
45		version, maxVersion int64
46	)
47
48	db.QueryRowContext(ctx, "pragma user_version").Scan(&version)
49	maxVersion = int64(len(Migrations))
50
51	for v := version; v < maxVersion; v++ {
52		tx, err := db.BeginTx(ctx, nil)
53		if err != nil {
54			return err
55		}
56
57		if err = Migrations[v](ctx, tx); err != nil {
58			tx.Rollback()
59			return err
60		}
61		if _, err = tx.ExecContext(ctx, fmt.Sprintf("pragma user_version = %d", v+1)); err != nil {
62			tx.Rollback()
63			return err
64		}
65		if err = tx.Commit(); err != nil {
66			return err
67		}
68	}
69
70	return nil
71}