1package migrate23import (4 "context"5 "database/sql"6 "errors"7 "fmt"89 "github.com/charmbracelet/log/v2"10 "github.com/charmbracelet/soft-serve/pkg/db"11)1213// MigrateFunc is a function that executes a migration.14type MigrateFunc func(ctx context.Context, tx *db.Tx) error // nolint:revive1516// Migration is a struct that contains the name of the migration and the17// function to execute it.18type Migration struct {19 Version int6420 Name string21 Migrate MigrateFunc22 Rollback MigrateFunc23}2425// Migrations is a database model to store migrations.26type Migrations struct {27 ID int64 `db:"id"`28 Name string `db:"name"`29 Version int64 `db:"version"`30}3132func (Migrations) schema(driverName string) string {33 switch driverName {34 case "sqlite3", "sqlite":35 return `CREATE TABLE IF NOT EXISTS migrations (36 id INTEGER PRIMARY KEY AUTOINCREMENT,37 name TEXT NOT NULL,38 version INTEGER NOT NULL UNIQUE39 );40 `41 case "postgres":42 return `CREATE TABLE IF NOT EXISTS migrations (43 id SERIAL PRIMARY KEY,44 name TEXT NOT NULL,45 version INTEGER NOT NULL UNIQUE46 );47 `48 case "mysql":49 return `CREATE TABLE IF NOT EXISTS migrations (50 id INT NOT NULL AUTO_INCREMENT,51 name TEXT NOT NULL,52 version INT NOT NULL,53 UNIQUE (version),54 PRIMARY KEY (id)55 );56 `57 default:58 panic("unknown driver")59 }60}6162// Migrate runs the migrations.63func Migrate(ctx context.Context, dbx *db.DB) error {64 logger := log.FromContext(ctx).WithPrefix("migrate")65 return dbx.TransactionContext(ctx, func(tx *db.Tx) error {66 if !hasTable(tx, "migrations") {67 if _, err := tx.Exec(Migrations{}.schema(tx.DriverName())); err != nil {68 return err69 }70 }7172 var migrs Migrations73 if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {74 if !errors.Is(err, sql.ErrNoRows) {75 return err76 }77 }7879 for _, m := range migrations {80 if m.Version <= migrs.Version {81 continue82 }8384 logger.Infof("running migration %d. %s", m.Version, m.Name)85 if err := m.Migrate(ctx, tx); err != nil {86 return err87 }8889 if _, err := tx.Exec(tx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil {90 return err91 }92 }9394 return nil95 })96}9798// Rollback rolls back a migration.99func Rollback(ctx context.Context, dbx *db.DB) error {100 logger := log.FromContext(ctx).WithPrefix("migrate")101 return dbx.TransactionContext(ctx, func(tx *db.Tx) error {102 var migrs Migrations103 if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {104 if !errors.Is(err, sql.ErrNoRows) {105 return fmt.Errorf("there are no migrations to rollback: %w", err)106 }107 }108109 if migrs.Version == 0 || len(migrations) < int(migrs.Version) {110 return fmt.Errorf("there are no migrations to rollback")111 }112113 m := migrations[migrs.Version-1]114 logger.Infof("rolling back migration %d. %s", m.Version, m.Name)115 if err := m.Rollback(ctx, tx); err != nil {116 return err117 }118119 if _, err := tx.Exec(tx.Rebind("DELETE FROM migrations WHERE version = ?"), migrs.Version); err != nil {120 return err121 }122123 return nil124 })125}126127func hasTable(tx *db.Tx, tableName string) bool {128 var query string129 switch tx.DriverName() {130 case "sqlite3", "sqlite":131 query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"132 case "postgres":133 fallthrough134 case "mysql":135 query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = ?"136 }137138 query = tx.Rebind(query)139 var name string140 err := tx.Get(&name, query, tableName)141 return err == nil142}