package infra import ( "database/sql" "fmt" "os" "path/filepath" "sort" "strings" ) const createMigrationsTable = ` CREATE TABLE IF NOT EXISTS _migrations ( version INTEGER PRIMARY KEY, name TEXT NOT NULL, up_sql TEXT NOT NULL, down_sql TEXT NOT NULL, applied_at TEXT NOT NULL DEFAULT (datetime('now')) )` // MigrationUp reads all .sql migration files from dir, creates the _migrations // table if it does not exist, and applies any pending migrations in version order. // Each migration runs in its own transaction. Returns the list of applied migrations. // If a migration fails, execution stops and the error is returned along with any // migrations that were successfully applied before the failure. func MigrationUp(db *sql.DB, dir string) ([]Migration, error) { // Ensure _migrations table exists if _, err := db.Exec(createMigrationsTable); err != nil { return nil, fmt.Errorf("migration_up: cannot create _migrations table: %w", err) } // Load files from directory allMigrations, err := loadMigrationsFromDir(dir) if err != nil { return nil, fmt.Errorf("migration_up: %w", err) } // Fetch already-applied versions applied, err := appliedVersions(db) if err != nil { return nil, fmt.Errorf("migration_up: %w", err) } // Filter pending migrations var pending []Migration for _, m := range allMigrations { if !applied[m.Version] { pending = append(pending, m) } } // Apply each pending migration in its own transaction var result []Migration for _, m := range pending { if err := applyMigration(db, m); err != nil { return result, fmt.Errorf("migration_up: applying version %d (%s): %w", m.Version, m.Name, err) } result = append(result, m) } return result, nil } // loadMigrationsFromDir reads and parses all .sql migration files from dir, // returning them sorted by version ascending. func loadMigrationsFromDir(dir string) ([]Migration, error) { entries, err := os.ReadDir(dir) if err != nil { return nil, fmt.Errorf("cannot read migrations directory %q: %w", dir, err) } var migrations []Migration for _, e := range entries { if e.IsDir() { continue } name := e.Name() if !strings.HasSuffix(strings.ToLower(name), ".sql") { continue } path := filepath.Join(dir, name) content, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("cannot read %q: %w", path, err) } m, err := MigrationParse(name, string(content)) if err != nil { return nil, fmt.Errorf("parse error in %q: %w", name, err) } migrations = append(migrations, m) } sort.Slice(migrations, func(i, j int) bool { return migrations[i].Version < migrations[j].Version }) return migrations, nil } // appliedVersions returns a set of version numbers already recorded in _migrations. func appliedVersions(db *sql.DB) (map[int]bool, error) { rows, err := db.Query("SELECT version FROM _migrations") if err != nil { return nil, fmt.Errorf("cannot query _migrations: %w", err) } defer rows.Close() applied := make(map[int]bool) for rows.Next() { var v int if err := rows.Scan(&v); err != nil { return nil, fmt.Errorf("scan version: %w", err) } applied[v] = true } return applied, rows.Err() } // applyMigration executes a migration's UpSQL within a transaction and records it // in _migrations. If UpSQL contains multiple statements, they are executed sequentially // using db.Exec (SQLite supports multiple statements via the C driver). func applyMigration(db *sql.DB, m Migration) error { tx, err := db.Begin() if err != nil { return fmt.Errorf("begin transaction: %w", err) } defer tx.Rollback() //nolint:errcheck // Execute the up SQL (may contain multiple statements) if _, err := tx.Exec(m.UpSQL); err != nil { return fmt.Errorf("exec up_sql: %w", err) } // Record the migration const insertSQL = `INSERT INTO _migrations (version, name, up_sql, down_sql) VALUES (?, ?, ?, ?)` if _, err := tx.Exec(insertSQL, m.Version, m.Name, m.UpSQL, m.DownSQL); err != nil { return fmt.Errorf("record migration: %w", err) } return tx.Commit() }