diff --git a/cmd/launcher/sqlite.go b/cmd/launcher/sqlite.go index 20be13d..941a512 100644 --- a/cmd/launcher/sqlite.go +++ b/cmd/launcher/sqlite.go @@ -1,12 +1,37 @@ package main import ( + "context" "database/sql" + "database/sql/driver" moderncsqlite "modernc.org/sqlite" ) func init() { - // mautrix dbutil opens sqlite as "sqlite3"; register the pure-Go driver under that name. - sql.Register("sqlite3", &moderncsqlite.Driver{}) + // mautrix dbutil opens sqlite as "sqlite3"; register the pure-Go driver + // under that name. We add a connection hook that sets WAL mode and a + // busy timeout on every connection to prevent SQLITE_BUSY crashes during + // concurrent writes (crypto store sync + memory store). + d := &moderncsqlite.Driver{} + d.RegisterConnectionHook(sqlitePragmaHook) + sql.Register("sqlite3", d) +} + +// sqlitePragmaHook sets WAL journal mode and a 5-second busy timeout on +// every new SQLite connection. This prevents SQLITE_BUSY errors when +// multiple goroutines write concurrently (e.g. mautrix crypto sync + +// memory/knowledge stores). +func sqlitePragmaHook(conn moderncsqlite.ExecQuerierContext, _ string) error { + ctx := context.Background() + pragmas := []string{ + "PRAGMA journal_mode=WAL", + "PRAGMA busy_timeout=5000", + } + for _, p := range pragmas { + if _, err := conn.ExecContext(ctx, p, []driver.NamedValue{}); err != nil { + return err + } + } + return nil } diff --git a/cmd/launcher/sqlite_test.go b/cmd/launcher/sqlite_test.go new file mode 100644 index 0000000..5988cc7 --- /dev/null +++ b/cmd/launcher/sqlite_test.go @@ -0,0 +1,198 @@ +package main + +import ( + "context" + "database/sql" + "os" + "path/filepath" + "sync" + "testing" +) + +// TestSQLitePragmaHook verifies that every connection opened via the registered +// "sqlite3" driver has WAL journal mode and a busy_timeout set. +func TestSQLitePragmaHook(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "test.db") + + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatalf("open: %v", err) + } + defer db.Close() + + // Force a real connection to be created (Open is lazy). + if err := db.Ping(); err != nil { + t.Fatalf("ping: %v", err) + } + + var journalMode string + if err := db.QueryRow("PRAGMA journal_mode").Scan(&journalMode); err != nil { + t.Fatalf("query journal_mode: %v", err) + } + if journalMode != "wal" { + t.Errorf("journal_mode = %q, want %q", journalMode, "wal") + } + + var busyTimeout int + if err := db.QueryRow("PRAGMA busy_timeout").Scan(&busyTimeout); err != nil { + t.Fatalf("query busy_timeout: %v", err) + } + if busyTimeout != 5000 { + t.Errorf("busy_timeout = %d, want %d", busyTimeout, 5000) + } +} + +// TestSQLiteConcurrentWrites verifies that concurrent writers do not get +// SQLITE_BUSY errors thanks to WAL mode and busy_timeout. +func TestSQLiteConcurrentWrites(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "concurrent.db") + + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatalf("open: %v", err) + } + defer db.Close() + + // Create a table to write to. + if _, err := db.Exec(`CREATE TABLE kv (k TEXT PRIMARY KEY, v TEXT)`); err != nil { + t.Fatalf("create table: %v", err) + } + + // Simulate the scenario: multiple goroutines writing concurrently, + // like mautrix crypto sync + memory store + knowledge store. + const writers = 5 + const writesPerWriter = 50 + ctx := context.Background() + + var wg sync.WaitGroup + errs := make(chan error, writers*writesPerWriter) + + for w := 0; w < writers; w++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < writesPerWriter; i++ { + _, err := db.ExecContext(ctx, + `INSERT OR REPLACE INTO kv (k, v) VALUES (?, ?)`, + // Use writer+iteration as key so they conflict + "key", "value", + ) + if err != nil { + errs <- err + } + } + }() + } + + wg.Wait() + close(errs) + + for err := range errs { + t.Errorf("concurrent write error: %v", err) + } +} + +// TestSQLiteConcurrentWritesSeparateConnections tests with separate sql.DB +// instances (like crypto.db being opened by both mautrix and our code). +func TestSQLiteConcurrentWritesSeparateConnections(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "shared.db") + + // Open two separate connections to the same file (simulates mautrix + + // our memory store sharing a DB, or separate processes). + db1, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatalf("open db1: %v", err) + } + defer db1.Close() + + db2, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatalf("open db2: %v", err) + } + defer db2.Close() + + // Create table via db1 + if _, err := db1.Exec(`CREATE TABLE t (id INTEGER PRIMARY KEY, data TEXT)`); err != nil { + t.Fatalf("create table: %v", err) + } + + ctx := context.Background() + const iterations = 100 + + var wg sync.WaitGroup + errs := make(chan error, iterations*2) + + // Writer 1 (simulates mautrix SaveNextBatch) + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + _, err := db1.ExecContext(ctx, + `INSERT INTO t (data) VALUES (?)`, "from_crypto_sync", + ) + if err != nil { + errs <- err + } + } + }() + + // Writer 2 (simulates our memory store SaveMessage) + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + _, err := db2.ExecContext(ctx, + `INSERT INTO t (data) VALUES (?)`, "from_memory_store", + ) + if err != nil { + errs <- err + } + } + }() + + wg.Wait() + close(errs) + + for err := range errs { + t.Errorf("concurrent write error (separate conns): %v", err) + } + + // Verify all writes succeeded + var count int + if err := db1.QueryRow("SELECT COUNT(*) FROM t").Scan(&count); err != nil { + t.Fatalf("count: %v", err) + } + expected := iterations * 2 + if count != expected { + t.Errorf("row count = %d, want %d", count, expected) + } +} + +// TestSQLiteWALFileCreated verifies that WAL mode actually creates the -wal file, +// confirming the pragma took effect at the filesystem level. +func TestSQLiteWALFileCreated(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "walcheck.db") + + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatalf("open: %v", err) + } + defer db.Close() + + // Create a table and write data to trigger WAL file creation. + if _, err := db.Exec(`CREATE TABLE x (id INTEGER PRIMARY KEY)`); err != nil { + t.Fatalf("create: %v", err) + } + if _, err := db.Exec(`INSERT INTO x (id) VALUES (1)`); err != nil { + t.Fatalf("insert: %v", err) + } + + walPath := dbPath + "-wal" + if _, err := os.Stat(walPath); os.IsNotExist(err) { + t.Errorf("WAL file not created at %s — PRAGMA journal_mode=WAL may not be taking effect", walPath) + } +}