diff --git a/go.mod b/go.mod index 4d73b1cfc7a2581c90c67f2b59747e4321ba0c91..6c3c91afc4bde522c8d0c17ef3765c9686a59772 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/gofrs/flock v0.8.0 // indirect github.com/google/go-cmp v0.5.8 github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 + github.com/mattn/go-sqlite3 v1.14.14 github.com/miscreant/miscreant.go v0.0.0-20200214223636-26d376326b75 github.com/prometheus/client_golang v1.12.2 github.com/russross/blackfriday/v2 v2.1.0 diff --git a/go.sum b/go.sum index 2c44a2fafcb09d765e137f71a3ac95292549e775..2cc4a8895f29f46e55bf0b02a2ac81325a573bc4 100644 --- a/go.sum +++ b/go.sum @@ -552,7 +552,10 @@ github.com/mattn/go-runewidth v0.0.12 h1:Y41i/hVW3Pgwr8gV+J23B9YEY0zxjptBuCWEaxm github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= github.com/mattn/go-shellwords v1.0.10/go.mod h1:EZzvwXDESEeg03EKmM+RmDnNOPKG4lLtQsUlTZDWQ8Y= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA= github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.14 h1:qZgc/Rwetq+MtyE18WhzjokPD93dNqLGNT3QJuLvBGw= +github.com/mattn/go-sqlite3 v1.14.14/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-zglob v0.0.1/go.mod h1:9fxibJccNxU2cnpIKLRRFA7zX7qhkJIQWBb449FYHOo= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= diff --git a/sqlutil/db.go b/sqlutil/db.go new file mode 100644 index 0000000000000000000000000000000000000000..531683440c453fa43dca7c0c129f2173f34addbc --- /dev/null +++ b/sqlutil/db.go @@ -0,0 +1,155 @@ +package sqlutil + +import ( + "context" + "database/sql" + "fmt" + "log" + "strings" + + _ "github.com/mattn/go-sqlite3" +) + +// DebugMigrations can be set to true to dump statements to stderr. +var DebugMigrations bool + +const defaultOptions = "?cache=shared&_busy_timeout=10000&_journal=WAL&_sync=OFF" + +type sqlOptions struct { + migrations []func(*sql.Tx) error + sqlopts string +} + +type Option func(*sqlOptions) + +func WithMigrations(migrations []func(*sql.Tx) error) Option { + return func(opts *sqlOptions) { + opts.migrations = migrations + } +} + +func WithSqliteOptions(sqlopts string) Option { + return func(opts *sqlOptions) { + opts.sqlopts = sqlopts + } +} + +// OpenDB opens a SQLite database and runs the database migrations. +func OpenDB(dburi string, options ...Option) (*sql.DB, error) { + var opts sqlOptions + opts.sqlopts = defaultOptions + for _, o := range options { + o(&opts) + } + + // Add sqlite3-specific parameters if none are already + // specified in the connection URI. + if !strings.Contains(dburi, "?") { + dburi += opts.sqlopts + } + + db, err := sql.Open("sqlite3", dburi) + if err != nil { + return nil, err + } + + // Limit the pool to a single connection. + // https://github.com/mattn/go-sqlite3/issues/209 + db.SetMaxOpenConns(1) + + if err = migrate(db, opts.migrations); err != nil { + db.Close() // nolint + return nil, err + } + + return db, nil +} + +// Fetch legacy (golang-migrate/migrate/v4) schema version. +func getLegacyMigrationVersion(tx *sql.Tx) (int, error) { + var version int + if err := tx.QueryRow(`SELECT version FROM schema_migrations ORDER BY version DESC LIMIT 1`).Scan(&version); err != nil { + return 0, err + } + return version, nil +} + +func migrate(db *sql.DB, migrations []func(*sql.Tx) error) error { + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("DB migration begin transaction: %w", err) + } + defer tx.Rollback() // nolint: errcheck + + var idx int + if err = tx.QueryRow("PRAGMA user_version").Scan(&idx); err != nil { + return fmt.Errorf("getting latest applied migration: %w", err) + } + if idx == 0 { + if legacyIdx, err := getLegacyMigrationVersion(tx); err == nil { + idx = legacyIdx + } + } + + if idx == len(migrations) { + // Already fully migrated, nothing needed. + return nil + } else if idx > len(migrations) { + return fmt.Errorf("database is at version %d, which is more recent than this binary understands", idx) + } + + for i, f := range migrations[idx:] { + if err := f(tx); err != nil { + return fmt.Errorf("migration to version %d failed: %w", i+1, err) + } + } + + if n := len(migrations); n > 0 { + // For some reason, ? substitution doesn't work in PRAGMA + // statements, sqlite reports a parse error. + if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version=%d", n)); err != nil { + return fmt.Errorf("recording new DB version: %w", err) + } + log.Printf("db migration: upgraded schema version %d -> %d", idx, n) + } + + return tx.Commit() +} + +// Statement for migrations, executes one or more SQL statements. +func Statement(idl ...string) func(*sql.Tx) error { + return func(tx *sql.Tx) error { + for _, stmt := range idl { + if DebugMigrations { + log.Printf("db migration: executing: %s", stmt) + } + if _, err := tx.Exec(stmt); err != nil { + return err + } + } + return nil + } +} + +func WithTx(ctx context.Context, db *sql.DB, f func(*sql.Tx) error) error { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + + if err := f(tx); err != nil { + tx.Rollback() // nolint + return err + } + + return tx.Commit() +} + +func WithReadonlyTx(ctx context.Context, db *sql.DB, f func(*sql.Tx) error) error { + tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + if err != nil { + return err + } + defer tx.Rollback() // nolint + return f(tx) +} diff --git a/sqlutil/db_test.go b/sqlutil/db_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e6da51c8e485570aa61c488b7f51ca266bc0dd60 --- /dev/null +++ b/sqlutil/db_test.go @@ -0,0 +1,181 @@ +package sqlutil + +import ( + "context" + "database/sql" + "io/ioutil" + "os" + "testing" +) + +func init() { + DebugMigrations = true +} + +func TestOpenDB(t *testing.T) { + dir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + db, err := OpenDB(dir + "/test.db") + if err != nil { + t.Fatal(err) + } + db.Close() +} + +func getTestValue(db *sql.DB, id int) (out string, err error) { + err = WithReadonlyTx(context.Background(), db, func(tx *sql.Tx) error { + return tx.QueryRow("SELECT value FROM test WHERE id=?", id).Scan(&out) + }) + return +} + +func checkTestValue(t *testing.T, db *sql.DB) { + value, err := getTestValue(db, 1) + if err != nil { + t.Fatal(err) + } + if value != "test" { + t.Fatalf("got bad value '%s', expected 'test'", value) + } +} + +func TestOpenDB_Migrations_MultipleStatements(t *testing.T) { + dir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + db, err := OpenDB(dir+"/test.db", WithMigrations([]func(*sql.Tx) error{ + Statement("CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, value TEXT)"), + Statement("CREATE INDEX idx_test_value ON test(value)"), + Statement("INSERT INTO test (id, value) VALUES (1, 'test')"), + })) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + checkTestValue(t, db) +} + +func TestOpenDB_Migrations_SingleStatement(t *testing.T) { + dir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + db, err := OpenDB(dir+"/test.db", WithMigrations([]func(*sql.Tx) error{ + Statement( + "CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, value TEXT)", + "CREATE INDEX idx_test_value ON test(value)", + "INSERT INTO test (id, value) VALUES (1, 'test')", + ), + })) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + checkTestValue(t, db) +} + +func TestOpenDB_Migrations_Versions(t *testing.T) { + dir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + migrations := []func(*sql.Tx) error{ + Statement("CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, value TEXT)"), + Statement("CREATE INDEX idx_test_value ON test(value)"), + } + + db, err := OpenDB(dir+"/test.db", WithMigrations(migrations)) + if err != nil { + t.Fatal("first open: ", err) + } + db.Close() + + migrations = append(migrations, Statement("INSERT INTO test (id, value) VALUES (1, 'test')")) + db, err = OpenDB(dir+"/test.db", WithMigrations(migrations)) + if err != nil { + t.Fatal("second open: ", err) + } + defer db.Close() + + checkTestValue(t, db) +} + +func TestOpenDB_Write(t *testing.T) { + dir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + db, err := OpenDB(dir+"/test.db", WithMigrations([]func(*sql.Tx) error{ + Statement( + "CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, value TEXT)", + "CREATE INDEX idx_test_value ON test(value)", + ), + })) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = WithTx(context.Background(), db, func(tx *sql.Tx) error { + _, err := tx.Exec("INSERT INTO test (id, value) VALUES (?, ?)", 1, "test") + return err + }) + if err != nil { + t.Fatalf("INSERT error: %v", err) + } + + checkTestValue(t, db) +} + +func TestOpenDB_Migrations_Legacy(t *testing.T) { + dir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + db, err := sql.Open("sqlite3", dir+"/test.db") + if err != nil { + t.Fatal(err) + } + for _, stmt := range []string{ + "CREATE TABLE schema_migrations (version uint64,dirty bool)", + "INSERT INTO schema_migrations (version, dirty) VALUES (2, 0)", + "CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, value TEXT)", + "CREATE INDEX idx_test_value ON test(value)", + } { + if _, err := db.Exec(stmt); err != nil { + t.Fatalf("statement '%s': %v", stmt, err) + } + } + db.Close() + + migrations := []func(*sql.Tx) error{ + Statement("CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, value TEXT)"), + Statement("CREATE INDEX idx_test_value ON test(value)"), + Statement("INSERT INTO test (id, value) VALUES (1, 'test')"), + } + + db, err = OpenDB(dir+"/test.db", WithMigrations(migrations)) + if err != nil { + t.Fatal("first open: ", err) + } + defer db.Close() + + checkTestValue(t, db) +} diff --git a/sqlutil/query_builder.go b/sqlutil/query_builder.go new file mode 100644 index 0000000000000000000000000000000000000000..2d1e30c030ff8a989c88ac9a4e0f63047d09b22c --- /dev/null +++ b/sqlutil/query_builder.go @@ -0,0 +1,47 @@ +package sqlutil + +import ( + "database/sql" + "strings" +) + +// QueryBuilder is a very simple programmatic query builder, to +// simplify the operation of adding WHERE and ORDER BY clauses +// programatically. +type QueryBuilder struct { + base string + tail string + where []string + args []interface{} +} + +// NewQuery returns a query builder starting with the given base query. +func NewQuery(s string) *QueryBuilder { + return &QueryBuilder{base: s} +} + +// OrderBy adds an ORDER BY clause. +func (q *QueryBuilder) OrderBy(s string) *QueryBuilder { + q.tail += " ORDER BY " + q.tail += s + return q +} + +// Where adds a WHERE clause with associated argument(s). +func (q *QueryBuilder) Where(clause string, args ...interface{}) *QueryBuilder { + q.where = append(q.where, clause) + q.args = append(q.args, args...) + return q +} + +// Query executes the resulting query in the given transaction. +func (q *QueryBuilder) Query(tx *sql.Tx) (*sql.Rows, error) { + s := q.base + if len(q.where) > 0 { + s += " WHERE " + s += strings.Join(q.where, " AND ") + } + s += q.tail + + return tx.Query(s, q.args...) +}