Skip to content
Snippets Groups Projects
Commit 4f7bac42 authored by ale's avatar ale
Browse files

Merge branch 'sql' into 'master'

Add sqlutil package

See merge request !66
parents c7650f35 2eff7a10
No related branches found
No related tags found
1 merge request!66Add sqlutil package
......@@ -16,6 +16,7 @@ require (
github.com/gofrs/flock v0.8.0 // indirect
github.com/google/go-cmp v0.5.8
github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40
github.com/mattn/go-sqlite3 v1.14.14
github.com/miscreant/miscreant.go v0.0.0-20200214223636-26d376326b75
github.com/prometheus/client_golang v1.12.2
github.com/russross/blackfriday/v2 v2.1.0
......
package sqlutil
import (
"context"
"database/sql"
"fmt"
"log"
"strings"
_ "github.com/mattn/go-sqlite3"
)
// DebugMigrations can be set to true to dump statements to stderr.
var DebugMigrations bool
const defaultOptions = "?cache=shared&_busy_timeout=10000&_journal=WAL&_sync=OFF"
type sqlOptions struct {
migrations []func(*sql.Tx) error
sqlopts string
}
type Option func(*sqlOptions)
func WithMigrations(migrations []func(*sql.Tx) error) Option {
return func(opts *sqlOptions) {
opts.migrations = migrations
}
}
func WithSqliteOptions(sqlopts string) Option {
return func(opts *sqlOptions) {
opts.sqlopts = sqlopts
}
}
// OpenDB opens a SQLite database and runs the database migrations.
func OpenDB(dburi string, options ...Option) (*sql.DB, error) {
var opts sqlOptions
opts.sqlopts = defaultOptions
for _, o := range options {
o(&opts)
}
// Add sqlite3-specific parameters if none are already
// specified in the connection URI.
if !strings.Contains(dburi, "?") {
dburi += opts.sqlopts
}
db, err := sql.Open("sqlite3", dburi)
if err != nil {
return nil, err
}
// Limit the pool to a single connection.
// https://github.com/mattn/go-sqlite3/issues/209
db.SetMaxOpenConns(1)
if err = migrate(db, opts.migrations); err != nil {
db.Close() // nolint
return nil, err
}
return db, nil
}
// Fetch legacy (golang-migrate/migrate/v4) schema version.
func getLegacyMigrationVersion(tx *sql.Tx) (int, error) {
var version int
if err := tx.QueryRow(`SELECT version FROM schema_migrations ORDER BY version DESC LIMIT 1`).Scan(&version); err != nil {
return 0, err
}
return version, nil
}
func migrate(db *sql.DB, migrations []func(*sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("DB migration begin transaction: %w", err)
}
defer tx.Rollback() // nolint: errcheck
var idx int
if err = tx.QueryRow("PRAGMA user_version").Scan(&idx); err != nil {
return fmt.Errorf("getting latest applied migration: %w", err)
}
if idx == 0 {
if legacyIdx, err := getLegacyMigrationVersion(tx); err == nil {
idx = legacyIdx
}
}
if idx == len(migrations) {
// Already fully migrated, nothing needed.
return nil
} else if idx > len(migrations) {
return fmt.Errorf("database is at version %d, which is more recent than this binary understands", idx)
}
for i, f := range migrations[idx:] {
if err := f(tx); err != nil {
return fmt.Errorf("migration to version %d failed: %w", i+1, err)
}
}
if n := len(migrations); n > 0 {
// For some reason, ? substitution doesn't work in PRAGMA
// statements, sqlite reports a parse error.
if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version=%d", n)); err != nil {
return fmt.Errorf("recording new DB version: %w", err)
}
log.Printf("db migration: upgraded schema version %d -> %d", idx, n)
}
return tx.Commit()
}
// Statement for migrations, executes one or more SQL statements.
func Statement(idl ...string) func(*sql.Tx) error {
return func(tx *sql.Tx) error {
for _, stmt := range idl {
if DebugMigrations {
log.Printf("db migration: executing: %s", stmt)
}
if _, err := tx.Exec(stmt); err != nil {
return err
}
}
return nil
}
}
func WithTx(ctx context.Context, db *sql.DB, f func(*sql.Tx) error) error {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
if err := f(tx); err != nil {
tx.Rollback() // nolint
return err
}
return tx.Commit()
}
func WithReadonlyTx(ctx context.Context, db *sql.DB, f func(*sql.Tx) error) error {
tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
return err
}
defer tx.Rollback() // nolint
return f(tx)
}
package sqlutil
import (
"context"
"database/sql"
"io/ioutil"
"os"
"testing"
)
func init() {
DebugMigrations = true
}
func TestOpenDB(t *testing.T) {
dir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
db, err := OpenDB(dir + "/test.db")
if err != nil {
t.Fatal(err)
}
db.Close()
}
func getTestValue(db *sql.DB, id int) (out string, err error) {
err = WithReadonlyTx(context.Background(), db, func(tx *sql.Tx) error {
return tx.QueryRow("SELECT value FROM test WHERE id=?", id).Scan(&out)
})
return
}
func checkTestValue(t *testing.T, db *sql.DB) {
value, err := getTestValue(db, 1)
if err != nil {
t.Fatal(err)
}
if value != "test" {
t.Fatalf("got bad value '%s', expected 'test'", value)
}
}
func TestOpenDB_Migrations_MultipleStatements(t *testing.T) {
dir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
db, err := OpenDB(dir+"/test.db", WithMigrations([]func(*sql.Tx) error{
Statement("CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, value TEXT)"),
Statement("CREATE INDEX idx_test_value ON test(value)"),
Statement("INSERT INTO test (id, value) VALUES (1, 'test')"),
}))
if err != nil {
t.Fatal(err)
}
defer db.Close()
checkTestValue(t, db)
}
func TestOpenDB_Migrations_SingleStatement(t *testing.T) {
dir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
db, err := OpenDB(dir+"/test.db", WithMigrations([]func(*sql.Tx) error{
Statement(
"CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, value TEXT)",
"CREATE INDEX idx_test_value ON test(value)",
"INSERT INTO test (id, value) VALUES (1, 'test')",
),
}))
if err != nil {
t.Fatal(err)
}
defer db.Close()
checkTestValue(t, db)
}
func TestOpenDB_Migrations_Versions(t *testing.T) {
dir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
migrations := []func(*sql.Tx) error{
Statement("CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, value TEXT)"),
Statement("CREATE INDEX idx_test_value ON test(value)"),
}
db, err := OpenDB(dir+"/test.db", WithMigrations(migrations))
if err != nil {
t.Fatal("first open: ", err)
}
db.Close()
migrations = append(migrations, Statement("INSERT INTO test (id, value) VALUES (1, 'test')"))
db, err = OpenDB(dir+"/test.db", WithMigrations(migrations))
if err != nil {
t.Fatal("second open: ", err)
}
defer db.Close()
checkTestValue(t, db)
}
func TestOpenDB_Write(t *testing.T) {
dir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
db, err := OpenDB(dir+"/test.db", WithMigrations([]func(*sql.Tx) error{
Statement(
"CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, value TEXT)",
"CREATE INDEX idx_test_value ON test(value)",
),
}))
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = WithTx(context.Background(), db, func(tx *sql.Tx) error {
_, err := tx.Exec("INSERT INTO test (id, value) VALUES (?, ?)", 1, "test")
return err
})
if err != nil {
t.Fatalf("INSERT error: %v", err)
}
checkTestValue(t, db)
}
func TestOpenDB_Migrations_Legacy(t *testing.T) {
dir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
db, err := sql.Open("sqlite3", dir+"/test.db")
if err != nil {
t.Fatal(err)
}
for _, stmt := range []string{
"CREATE TABLE schema_migrations (version uint64,dirty bool)",
"INSERT INTO schema_migrations (version, dirty) VALUES (2, 0)",
"CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, value TEXT)",
"CREATE INDEX idx_test_value ON test(value)",
} {
if _, err := db.Exec(stmt); err != nil {
t.Fatalf("statement '%s': %v", stmt, err)
}
}
db.Close()
migrations := []func(*sql.Tx) error{
Statement("CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, value TEXT)"),
Statement("CREATE INDEX idx_test_value ON test(value)"),
Statement("INSERT INTO test (id, value) VALUES (1, 'test')"),
}
db, err = OpenDB(dir+"/test.db", WithMigrations(migrations))
if err != nil {
t.Fatal("first open: ", err)
}
defer db.Close()
checkTestValue(t, db)
}
package sqlutil
import (
"database/sql"
"strings"
)
// QueryBuilder is a very simple programmatic query builder, to
// simplify the operation of adding WHERE and ORDER BY clauses
// programatically.
type QueryBuilder struct {
base string
tail string
where []string
args []interface{}
}
// NewQuery returns a query builder starting with the given base query.
func NewQuery(s string) *QueryBuilder {
return &QueryBuilder{base: s}
}
// OrderBy adds an ORDER BY clause.
func (q *QueryBuilder) OrderBy(s string) *QueryBuilder {
q.tail += " ORDER BY "
q.tail += s
return q
}
// Where adds a WHERE clause with associated argument(s).
func (q *QueryBuilder) Where(clause string, args ...interface{}) *QueryBuilder {
q.where = append(q.where, clause)
q.args = append(q.args, args...)
return q
}
// Query executes the resulting query in the given transaction.
func (q *QueryBuilder) Query(tx *sql.Tx) (*sql.Rows, error) {
s := q.base
if len(q.where) > 0 {
s += " WHERE "
s += strings.Join(q.where, " AND ")
}
s += q.tail
return tx.Query(s, q.args...)
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment