package storage

import (
	"database/sql"
	"embed"
	"fmt"
	"path"
	"sort"
	"strconv"
	"strings"
)

// migrationFS holds the SQL migration files embedded from migrations/*.sql.
//
//go:embed migrations/*.sql
var migrationFS embed.FS

// migration represents a single database migration.
type migration struct {
	Version int    // numeric prefix of the filename; always >= 1
	Name    string // full filename, used in error messages
	SQL     string // file contents, executed as-is
}

// Migrate applies any pending migrations to the database.
//
// It creates the schema_version table if needed, reads the currently recorded
// version, and applies every embedded migration with a higher version in
// ascending order. Each migration runs in its own transaction together with the
// version bump. If a migration fails, it is rolled back and Migrate returns
// immediately. All earlier migrations remain applied.
func Migrate(db *sql.DB) error {
	// Ensure the schema_version table exists.
	if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS schema_version (version INTEGER NOT NULL)`); err != nil {
		return fmt.Errorf("creating schema_version table: %w", err)
	}

	current, err := currentVersion(db)
	if err != nil {
		return fmt.Errorf("reading schema version: %w", err)
	}

	migrations, err := loadMigrations()
	if err != nil {
		return fmt.Errorf("loading migrations: %w", err)
	}

	// migrations is sorted ascending with unique versions, so comparing
	// against the starting version is enough to find the pending ones.
	for _, m := range migrations {
		if m.Version <= current {
			continue
		}
		if err := applyMigration(db, m); err != nil {
			return err
		}
	}
	return nil
}

// applyMigration executes m.SQL and records m.Version as the schema version in
// a single transaction. A failure therefore leaves neither the migration's
// changes nor a bumped version behind.
func applyMigration(db *sql.DB, m migration) error {
	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("begin migration %d: %w", m.Version, err)
	}
	// Rollback after a successful Commit is a no-op that returns
	// sql.ErrTxDone, so its error is deliberately ignored. Deferring it covers
	// every early-return path below.
	defer func() { _ = tx.Rollback() }()

	// NOTE(review): m.SQL is passed to a single Exec call; whether a file with
	// several statements executes fully depends on the database driver.
	if _, err := tx.Exec(m.SQL); err != nil {
		return fmt.Errorf("applying migration %d (%s): %w", m.Version, m.Name, err)
	}

	// Replace the stored version instead of updating it in place, so the table
	// ends up with exactly one row whether it started empty or with several.
	if _, err := tx.Exec(`DELETE FROM schema_version`); err != nil {
		return fmt.Errorf("clearing schema version: %w", err)
	}
	if _, err := tx.Exec(`INSERT INTO schema_version (version) VALUES (?)`, m.Version); err != nil {
		return fmt.Errorf("inserting schema version %d: %w", m.Version, err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit migration %d: %w", m.Version, err)
	}
	return nil
}

// currentVersion returns the highest version recorded in schema_version, or 0
// when the table is empty. An aggregate query always yields exactly one row,
// so there is no sql.ErrNoRows case. MAX keeps the result deterministic even if
// the table contains more than one row.
func currentVersion(db *sql.DB) (int, error) {
	var version int
	if err := db.QueryRow(`SELECT COALESCE(MAX(version), 0) FROM schema_version`).Scan(&version); err != nil {
		return 0, err
	}
	return version, nil
}

// loadMigrations reads every embedded *.sql file under migrations/ and returns
// them sorted by ascending version.
//
// Filenames must look like NNN_description.sql, where NNN parses as an integer
// >= 1. It returns an error for a malformed filename, a version below 1, or two
// files sharing a version. Migrate would otherwise skip those files without
// reporting anything.
func loadMigrations() ([]migration, error) {
	entries, err := migrationFS.ReadDir("migrations")
	if err != nil {
		return nil, fmt.Errorf("reading migrations dir: %w", err)
	}

	migrations := make([]migration, 0, len(entries))
	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || !strings.HasSuffix(name, ".sql") {
			continue
		}

		// Parse version from filename: NNN_description.sql
		parts := strings.SplitN(name, "_", 2)
		if len(parts) < 2 {
			return nil, fmt.Errorf("invalid migration filename: %s", name)
		}
		version, err := strconv.Atoi(parts[0])
		if err != nil {
			return nil, fmt.Errorf("parsing version from %s: %w", name, err)
		}
		// Migrate treats 0 as "nothing applied yet" and skips every version <=
		// the current one, so a version below 1 could never run.
		if version < 1 {
			return nil, fmt.Errorf("migration %s: version must be >= 1, got %d", name, version)
		}

		// embed.FS paths are always slash-separated, hence path (not filepath).
		data, err := migrationFS.ReadFile(path.Join("migrations", name))
		if err != nil {
			return nil, fmt.Errorf("reading migration %s: %w", name, err)
		}

		migrations = append(migrations, migration{
			Version: version,
			Name:    name,
			SQL:     string(data),
		})
	}

	sort.Slice(migrations, func(i, j int) bool {
		return migrations[i].Version < migrations[j].Version
	})

	// After one file of a duplicated version is applied, Migrate would skip the
	// other because its version is no longer greater than the current one.
	for i := 1; i < len(migrations); i++ {
		if migrations[i].Version == migrations[i-1].Version {
			return nil, fmt.Errorf("duplicate migration version %d: %s and %s",
				migrations[i].Version, migrations[i-1].Name, migrations[i].Name)
		}
	}
	return migrations, nil
}