diff --git a/go.mod b/go.mod index 4d73b1cfc7a2581c90c67f2b59747e4321ba0c91..a03b4d52dc61cda17fd727b473286e62814b48b4 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/gofrs/flock v0.8.0 // indirect github.com/google/go-cmp v0.5.8 github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 + github.com/mattn/go-sqlite3 v1.14.7 github.com/miscreant/miscreant.go v0.0.0-20200214223636-26d376326b75 github.com/prometheus/client_golang v1.12.2 github.com/russross/blackfriday/v2 v2.1.0 diff --git a/go.sum b/go.sum index 2c44a2fafcb09d765e137f71a3ac95292549e775..14ab45000eed1178bfe6e052d9e9f0802aa2d0ef 100644 --- a/go.sum +++ b/go.sum @@ -552,6 +552,7 @@ github.com/mattn/go-runewidth v0.0.12 h1:Y41i/hVW3Pgwr8gV+J23B9YEY0zxjptBuCWEaxm github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= github.com/mattn/go-shellwords v1.0.10/go.mod h1:EZzvwXDESEeg03EKmM+RmDnNOPKG4lLtQsUlTZDWQ8Y= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA= github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-zglob v0.0.1/go.mod h1:9fxibJccNxU2cnpIKLRRFA7zX7qhkJIQWBb449FYHOo= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= diff --git a/sqlutil/db.go b/sqlutil/db.go new file mode 100644 index 0000000000000000000000000000000000000000..e860d709f338aa89f8b9fb2c7b7ebbfc649e3250 --- /dev/null +++ b/sqlutil/db.go @@ -0,0 +1,125 @@ +package sqlutil + +import ( + "context" + "database/sql" + "fmt" + "log" + "strings" + + _ "github.com/mattn/go-sqlite3" +) + +// DebugMigrations can be set to true to dump statements to stderr. +var DebugMigrations = true + +const defaultOptions = "?cache=shared&_busy_timeout=10000&_journal=WAL&_sync=OFF" + +// OpenDB opens a SQLite database and runs the database migrations. +func OpenDB(dburi string, migrations []func(*sql.Tx) error) (*sql.DB, error) { + // Add some sqlite3-specific parameters if none are specified. + if !strings.Contains(dburi, "?") { + dburi += defaultOptions + } + + db, err := sql.Open("sqlite3", dburi) + if err != nil { + return nil, err + } + + // Limit the pool to a single connection. + // https://github.com/mattn/go-sqlite3/issues/209 + db.SetMaxOpenConns(1) + + if err = migrate(db, migrations); err != nil { + db.Close() // nolint + return nil, err + } + + return db, nil +} + +// Fetch legacy (golang-migrate/migrate/v4) schema version. +func getLegacyMigrationVersion(tx *sql.Tx) (int, error) { + var version int + if err := tx.QueryRow(`SELECT version FROM schema_migrations ORDER BY version DESC LIMIT 1`).Scan(&version); err != nil { + return 0, err + } + return version, nil +} + +func migrate(db *sql.DB, migrations []func(*sql.Tx) error) error { + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("DB migration begin transaction: %w", err) + } + defer tx.Rollback() // nolint: errcheck + + var idx int + if err = tx.QueryRow("PRAGMA user_version").Scan(&idx); err != nil { + return fmt.Errorf("getting latest applied migration: %w", err) + } + if idx == 0 { + if legacyIdx, err := getLegacyMigrationVersion(tx); err == nil { + idx = legacyIdx + } + } + + if idx == len(migrations) { + // Already fully migrated, nothing needed. + } else if idx > len(migrations) { + return fmt.Errorf("database is at version %d, which is more recent than this binary understands", idx) + } + + for i, f := range migrations[idx:] { + if err := f(tx); err != nil { + return fmt.Errorf("migration to version %d failed: %w", i+1, err) + } + } + + // For some reason, ? substitution doesn't work in PRAGMA + // statements, sqlite reports a parse error. + if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version=%d", len(migrations))); err != nil { + return fmt.Errorf("recording new DB version: %w", err) + } + + return tx.Commit() +} + +// Statement for migrations, executes one or more SQL statements. +func Statement(idl ...string) func(*sql.Tx) error { + return func(tx *sql.Tx) error { + for _, stmt := range idl { + if DebugMigrations { + log.Printf("db migration: executing: %s", stmt) + } + if _, err := tx.Exec(stmt); err != nil { + return err + } + } + return nil + } +} + +func WithTx(ctx context.Context, db *sql.DB, f func(*sql.Tx) error) error { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + + if err := f(tx); err != nil { + tx.Rollback() // nolint + return err + } + + return tx.Commit() +} + +func WithReadonlyTx(ctx context.Context, db *sql.DB, f func(*sql.Tx) error) error { + tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + if err != nil { + return err + } + defer tx.Rollback() // nolint + return f(tx) +} diff --git a/sqlutil/query_builder.go b/sqlutil/query_builder.go new file mode 100644 index 0000000000000000000000000000000000000000..2d1e30c030ff8a989c88ac9a4e0f63047d09b22c --- /dev/null +++ b/sqlutil/query_builder.go @@ -0,0 +1,47 @@ +package sqlutil + +import ( + "database/sql" + "strings" +) + +// QueryBuilder is a very simple programmatic query builder, to +// simplify the operation of adding WHERE and ORDER BY clauses +// programatically. +type QueryBuilder struct { + base string + tail string + where []string + args []interface{} +} + +// NewQuery returns a query builder starting with the given base query. +func NewQuery(s string) *QueryBuilder { + return &QueryBuilder{base: s} +} + +// OrderBy adds an ORDER BY clause. +func (q *QueryBuilder) OrderBy(s string) *QueryBuilder { + q.tail += " ORDER BY " + q.tail += s + return q +} + +// Where adds a WHERE clause with associated argument(s). +func (q *QueryBuilder) Where(clause string, args ...interface{}) *QueryBuilder { + q.where = append(q.where, clause) + q.args = append(q.args, args...) + return q +} + +// Query executes the resulting query in the given transaction. +func (q *QueryBuilder) Query(tx *sql.Tx) (*sql.Rows, error) { + s := q.base + if len(q.where) > 0 { + s += " WHERE " + s += strings.Join(q.where, " AND ") + } + s += q.tail + + return tx.Query(s, q.args...) +}