package infra

import (
	"database/sql"
	"fmt"
	"strings"
)

// DBInsertBatch inserts multiple rows into table inside a single transaction,
// reusing one prepared statement for every row.
//
// columns names the target columns; every row must supply exactly
// len(columns) values, in the same order as columns. The table name and every
// column name must match validIdentifier because they are interpolated
// directly into the SQL text; the row values themselves are always passed as
// bind parameters. The statement uses "?" placeholders, so the driver behind
// db must accept that placeholder style.
//
// If table and columns are valid and rows is empty, it returns (0, nil)
// without touching the database. All rows are arity-checked before a
// transaction is opened, so malformed input never reaches the database.
//
// It returns the total number of rows affected as reported by the driver
// (best-effort; see the RowsAffected note below). On any error the returned
// count is 0 and, if a transaction was opened, it is rolled back, so no rows
// from the batch are committed.
func DBInsertBatch(db *sql.DB, table string, columns []string, rows [][]any) (int64, error) {
	if !validIdentifier.MatchString(table) {
		return 0, fmt.Errorf("db_insert_batch: invalid table name %q", table)
	}
	if len(columns) == 0 {
		return 0, fmt.Errorf("db_insert_batch: columns must not be empty")
	}
	if len(rows) == 0 {
		return 0, nil
	}
	for _, col := range columns {
		if !validIdentifier.MatchString(col) {
			return 0, fmt.Errorf("db_insert_batch: invalid column name %q", col)
		}
	}
	// Validate every row's arity up front so bad input is rejected before we
	// open a transaction, prepare a statement, or execute any earlier rows.
	for i, row := range rows {
		if len(row) != len(columns) {
			return 0, fmt.Errorf("db_insert_batch: row %d has %d values, expected %d", i, len(row), len(columns))
		}
	}

	placeholders := make([]string, len(columns))
	for i := range placeholders {
		placeholders[i] = "?"
	}
	// Identifiers are safe to interpolate only because they passed the
	// validIdentifier checks above; values go through bind parameters.
	query := fmt.Sprintf(
		"INSERT INTO %s (%s) VALUES (%s)",
		table,
		strings.Join(columns, ", "),
		strings.Join(placeholders, ", "),
	)

	tx, err := db.Begin()
	if err != nil {
		return 0, fmt.Errorf("db_insert_batch: begin tx: %w", err)
	}
	// Rollback after a successful Commit is a no-op (returns sql.ErrTxDone),
	// so deferring it unconditionally is safe and covers every error path.
	defer tx.Rollback() //nolint:errcheck

	stmt, err := tx.Prepare(query)
	if err != nil {
		return 0, fmt.Errorf("db_insert_batch: prepare: %w", err)
	}
	defer stmt.Close()

	var total int64
	for i, row := range rows {
		result, err := stmt.Exec(row...)
		if err != nil {
			return 0, fmt.Errorf("db_insert_batch: exec row %d: %w", i, err)
		}
		// Not every driver supports RowsAffected. The row was still inserted,
		// so treat the count as best-effort: only add it when the driver
		// reports it successfully, rather than failing the batch or adding an
		// unreliable value.
		if n, err := result.RowsAffected(); err == nil {
			total += n
		}
	}

	if err := tx.Commit(); err != nil {
		return 0, fmt.Errorf("db_insert_batch: commit: %w", err)
	}
	return total, nil
}