package sqlutil import ( "context" "database/sql" "fmt" "log" "strings" _ "github.com/mattn/go-sqlite3" ) // DebugMigrations can be set to true to dump statements to stderr. var DebugMigrations bool const defaultOptions = "?cache=shared&_busy_timeout=10000&_journal=WAL&_sync=OFF" type sqlOptions struct { migrations []func(*sql.Tx) error sqlopts string } type Option func(*sqlOptions) func WithMigrations(migrations []func(*sql.Tx) error) Option { return func(opts *sqlOptions) { opts.migrations = migrations } } func WithSqliteOptions(sqlopts string) Option { return func(opts *sqlOptions) { opts.sqlopts = sqlopts } } // OpenDB opens a SQLite database and runs the database migrations. func OpenDB(dburi string, options ...Option) (*sql.DB, error) { var opts sqlOptions opts.sqlopts = defaultOptions for _, o := range options { o(&opts) } // Add sqlite3-specific parameters if none are already // specified in the connection URI. if !strings.Contains(dburi, "?") { dburi += opts.sqlopts } db, err := sql.Open("sqlite3", dburi) if err != nil { return nil, err } // Limit the pool to a single connection. // https://github.com/mattn/go-sqlite3/issues/209 db.SetMaxOpenConns(1) if err = migrate(db, opts.migrations); err != nil { db.Close() // nolint return nil, err } return db, nil } // Fetch legacy (golang-migrate/migrate/v4) schema version. 
// It returns the highest version stored in the schema_migrations table;
// any failure (for example a missing or empty table) is returned as an error.
//
// NOTE(review): golang-migrate may also record a "dirty" flag for partially
// applied migrations; it is not consulted here — confirm that is intended.
func getLegacyMigrationVersion(tx *sql.Tx) (int, error) {
	var version int
	if err := tx.QueryRow(`SELECT version FROM schema_migrations ORDER BY version DESC LIMIT 1`).Scan(&version); err != nil {
		return 0, err
	}
	return version, nil
}

// migrate brings the schema of db up to date by applying, inside a single
// transaction, every entry of migrations that has not been applied yet.
//
// The current version is read from PRAGMA user_version; when that is 0, the
// legacy golang-migrate version is used instead if one can be read. The entry
// at index i upgrades the schema from version i to version i+1, and on success
// user_version is set to len(migrations). A negative version, or one newer
// than len(migrations), is rejected with an error. Any failure leaves the
// transaction uncommitted (it is rolled back by the deferred Rollback).
func migrate(db *sql.DB, migrations []func(*sql.Tx) error) error {
	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("DB migration begin transaction: %w", err)
	}
	// After a successful Commit this is a no-op returning sql.ErrTxDone.
	defer tx.Rollback() // nolint: errcheck

	var idx int
	if err := tx.QueryRow("PRAGMA user_version").Scan(&idx); err != nil {
		return fmt.Errorf("getting latest applied migration: %w", err)
	}
	if idx == 0 {
		// Best effort: a database without user_version may still carry a
		// legacy schema_migrations table. Any error (e.g. the table does
		// not exist) just means there is no legacy version to use.
		if legacyIdx, err := getLegacyMigrationVersion(tx); err == nil {
			idx = legacyIdx
		}
	}

	n := len(migrations)
	switch {
	case idx < 0:
		// Would otherwise panic when slicing migrations[idx:] below.
		return fmt.Errorf("database is at invalid version %d", idx)
	case idx > n:
		return fmt.Errorf("database is at version %d, which is more recent than this binary understands", idx)
	case idx == n:
		// Already fully migrated, nothing needed.
		return nil
	}

	for i, f := range migrations[idx:] {
		// i is relative to idx, so this step produces version idx+i+1.
		if err := f(tx); err != nil {
			return fmt.Errorf("migration to version %d failed: %w", idx+i+1, err)
		}
	}

	// For some reason, ? substitution doesn't work in PRAGMA
	// statements, sqlite reports a parse error.
	if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version=%d", n)); err != nil {
		return fmt.Errorf("recording new DB version: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("committing DB migration: %w", err)
	}
	// Logged only after the commit so a failed commit is not reported as an upgrade.
	log.Printf("db migration: upgraded schema version %d -> %d", idx, n)
	return nil
}

// Statement for migrations, executes one or more SQL statements.
func Statement(idl ...string) func(*sql.Tx) error { return func(tx *sql.Tx) error { for _, stmt := range idl { if DebugMigrations { log.Printf("db migration: executing: %s", stmt) } if _, err := tx.Exec(stmt); err != nil { return err } } return nil } } func WithTx(ctx context.Context, db *sql.DB, f func(*sql.Tx) error) error { tx, err := db.BeginTx(ctx, nil) if err != nil { return err } if err := f(tx); err != nil { tx.Rollback() // nolint return err } return tx.Commit() } func WithReadonlyTx(ctx context.Context, db *sql.DB, f func(*sql.Tx) error) error { tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) if err != nil { return err } defer tx.Rollback() // nolint return f(tx) }