package fn_operations

import (
	"database/sql"
	"embed"
	"fmt"
	"path"
	"sort"
	"strconv"
	"strings"
	"time"
)

// migrationsFS holds the *.sql files from the migrations directory, embedded
// into the binary at build time.
//
//go:embed migrations/*.sql
var migrationsFS embed.FS

// migrationTableSQL creates the bookkeeping table that records which
// migrations have been applied. It is safe to run repeatedly because of
// IF NOT EXISTS.
const migrationTableSQL = `
CREATE TABLE IF NOT EXISTS schema_migrations (
    version INTEGER PRIMARY KEY,
    name TEXT NOT NULL,
    applied_at TEXT NOT NULL
);`

// migrate applies pending migrations to the database.
//
// It ensures the schema_migrations table exists, reads the highest recorded
// version, and then applies every embedded migration whose version is greater
// than that value, in ascending version order. It stops at the first failure
// and returns that error; migrations committed before the failure stay
// applied.
//
// Because an empty schema_migrations table reports version 0, a migration
// file whose version is 0 or negative is never applied.
func migrate(conn *sql.DB) error {
	if _, err := conn.Exec(migrationTableSQL); err != nil {
		return fmt.Errorf("creating schema_migrations table: %w", err)
	}

	current, err := currentVersion(conn)
	if err != nil {
		return err
	}

	files, err := listMigrations()
	if err != nil {
		return err
	}

	for _, mf := range files {
		if mf.version <= current {
			continue // already recorded as applied
		}
		if err := applyMigration(conn, mf); err != nil {
			return err
		}
	}
	return nil
}

// applyMigration runs a single migration file and records it in
// schema_migrations, both inside one transaction, so the bookkeeping row is
// only committed together with the migration's statements.
//
// NOTE(review): the whole file is passed to a single Exec call; whether
// multi-statement scripts and DDL are transactional depends on the driver and
// database in use — confirm for the target database.
func applyMigration(conn *sql.DB, mf migrationFile) error {
	// embed.FS always uses forward slashes, hence path rather than filepath.
	content, err := migrationsFS.ReadFile(path.Join("migrations", mf.filename))
	if err != nil {
		return fmt.Errorf("reading migration %s: %w", mf.filename, err)
	}

	tx, err := conn.Begin()
	if err != nil {
		return fmt.Errorf("beginning transaction for migration %d: %w", mf.version, err)
	}

	if _, err := tx.Exec(string(content)); err != nil {
		return rollback(tx, fmt.Errorf("applying migration %s: %w", mf.filename, err))
	}

	if _, err := tx.Exec(
		"INSERT INTO schema_migrations (version, name, applied_at) VALUES (?, ?, ?)",
		mf.version, mf.filename, time.Now().UTC().Format(time.RFC3339),
	); err != nil {
		return rollback(tx, fmt.Errorf("recording migration %s: %w", mf.filename, err))
	}

	// A failed Commit leaves the transaction finished, so no rollback is
	// attempted here.
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("committing migration %s: %w", mf.filename, err)
	}
	return nil
}

// rollback aborts tx and returns cause. If the rollback itself fails, that
// failure is appended to the message while cause stays wrapped, so errors.Is
// and errors.As on the result still see the original error.
func rollback(tx *sql.Tx, cause error) error {
	if rbErr := tx.Rollback(); rbErr != nil {
		return fmt.Errorf("%w (rollback also failed: %v)", cause, rbErr)
	}
	return cause
}

// currentVersion returns the highest version recorded in schema_migrations,
// or 0 when the table has no rows.
func currentVersion(conn *sql.DB) (int, error) {
	var v int
	err := conn.QueryRow("SELECT COALESCE(MAX(version), 0) FROM schema_migrations").Scan(&v)
	if err != nil {
		return 0, fmt.Errorf("reading current schema version: %w", err)
	}
	return v, nil
}

// migrationFile describes one embedded migration script.
type migrationFile struct {
	version  int    // numeric prefix of the file name
	filename string // base name of the file inside the migrations directory
}

// listMigrations returns the embedded migrations sorted by ascending version.
//
// Only regular entries named "<version>_<anything>.sql" are considered, where
// <version> must parse with strconv.Atoi; every other entry is skipped
// silently. Two files sharing the same version are reported as an error up
// front, because the second one could never be recorded in schema_migrations
// (version is its primary key) and would otherwise fail only after the first
// had already been committed.
func listMigrations() ([]migrationFile, error) {
	entries, err := migrationsFS.ReadDir("migrations")
	if err != nil {
		return nil, fmt.Errorf("reading migrations directory: %w", err)
	}

	var files []migrationFile
	for _, e := range entries {
		if e.IsDir() || !strings.HasSuffix(e.Name(), ".sql") {
			continue
		}
		parts := strings.SplitN(e.Name(), "_", 2)
		if len(parts) < 2 {
			continue // no "_" separator, not a migration file name
		}
		v, err := strconv.Atoi(parts[0])
		if err != nil {
			continue // prefix is not a number, not a migration file name
		}
		files = append(files, migrationFile{version: v, filename: e.Name()})
	}

	// Tie-break on filename so the order (and the duplicate error below) is
	// deterministic; sort.Slice alone is not stable.
	sort.Slice(files, func(i, j int) bool {
		if files[i].version != files[j].version {
			return files[i].version < files[j].version
		}
		return files[i].filename < files[j].filename
	})

	for i := 1; i < len(files); i++ {
		if files[i].version == files[i-1].version {
			return nil, fmt.Errorf(
				"duplicate migration version %d: %s and %s",
				files[i].version, files[i-1].filename, files[i].filename,
			)
		}
	}
	return files, nil
}