package infra import ( "database/sql" "fmt" "io/fs" "path" "sort" "strconv" "strings" "time" ) const createSchemaMigrationsTable = ` CREATE TABLE IF NOT EXISTS schema_migrations ( version INTEGER PRIMARY KEY, name TEXT NOT NULL, applied_at TEXT NOT NULL );` // ApplyVersionedMigrations applies pending SQLite migrations from fsys, tracking // applied versions in a schema_migrations table. Each migration runs in its // own transaction; on error the tx is rolled back and the function returns. // // Migration filenames must be NNN_name.sql (e.g. 001_init.sql, // 002_add_users.sql). The numeric prefix is the version. Files without a // numeric prefix or with non-.sql extensions are skipped. // // dir is the directory inside fsys containing the migrations (e.g. // "migrations"). Idempotent: migrations whose version <= current are skipped. func ApplyVersionedMigrations(conn *sql.DB, fsys fs.FS, dir string) error { if _, err := conn.Exec(createSchemaMigrationsTable); err != nil { return fmt.Errorf("apply_versioned_migrations: create schema_migrations: %w", err) } current, err := versionedMigrationsCurrentVersion(conn) if err != nil { return err } files, err := versionedMigrationsList(fsys, dir) if err != nil { return err } for _, mf := range files { if mf.version <= current { continue } content, err := fs.ReadFile(fsys, path.Join(dir, mf.filename)) if err != nil { return fmt.Errorf("apply_versioned_migrations: read %s: %w", mf.filename, err) } tx, err := conn.Begin() if err != nil { return fmt.Errorf("apply_versioned_migrations: begin tx for %s: %w", mf.filename, err) } if _, err := tx.Exec(string(content)); err != nil { tx.Rollback() //nolint:errcheck return fmt.Errorf("apply_versioned_migrations: exec %s: %w", mf.filename, err) } if _, err := tx.Exec( "INSERT INTO schema_migrations (version, name, applied_at) VALUES (?, ?, ?)", mf.version, mf.filename, time.Now().UTC().Format(time.RFC3339), ); err != nil { tx.Rollback() //nolint:errcheck return fmt.Errorf("apply_versioned_migrations: record %s: %w", mf.filename, err) } if err := tx.Commit(); err != nil { return fmt.Errorf("apply_versioned_migrations: commit %s: %w", mf.filename, err) } } return nil } // versionedMigrationsCurrentVersion returns MAX(version) from schema_migrations, // or 0 if the table is empty. func versionedMigrationsCurrentVersion(conn *sql.DB) (int, error) { var v int err := conn.QueryRow("SELECT COALESCE(MAX(version), 0) FROM schema_migrations").Scan(&v) if err != nil { return 0, fmt.Errorf("apply_versioned_migrations: read current version: %w", err) } return v, nil } type versionedMigrationFile struct { version int filename string } // versionedMigrationsList reads dir from fsys and returns .sql files with a // numeric NNN_ prefix, sorted by version ascending. func versionedMigrationsList(fsys fs.FS, dir string) ([]versionedMigrationFile, error) { entries, err := fs.ReadDir(fsys, dir) if err != nil { return nil, fmt.Errorf("apply_versioned_migrations: read dir %q: %w", dir, err) } var files []versionedMigrationFile for _, e := range entries { if e.IsDir() || !strings.HasSuffix(e.Name(), ".sql") { continue } parts := strings.SplitN(e.Name(), "_", 2) if len(parts) < 2 { continue } v, err := strconv.Atoi(parts[0]) if err != nil { continue } files = append(files, versionedMigrationFile{version: v, filename: e.Name()}) } sort.Slice(files, func(i, j int) bool { return files[i].version < files[j].version }) return files, nil }