diff --git a/sqlutil/db.go b/sqlutil/db.go index b6bc7dbe042804cddcd07a43f57fe4c01ecc703d..531683440c453fa43dca7c0c129f2173f34addbc 100644 --- a/sqlutil/db.go +++ b/sqlutil/db.go @@ -15,11 +15,37 @@ var DebugMigrations bool const defaultOptions = "?cache=shared&_busy_timeout=10000&_journal=WAL&_sync=OFF" +type sqlOptions struct { + migrations []func(*sql.Tx) error + sqlopts string +} + +type Option func(*sqlOptions) + +func WithMigrations(migrations []func(*sql.Tx) error) Option { + return func(opts *sqlOptions) { + opts.migrations = migrations + } +} + +func WithSqliteOptions(sqlopts string) Option { + return func(opts *sqlOptions) { + opts.sqlopts = sqlopts + } +} + // OpenDB opens a SQLite database and runs the database migrations. -func OpenDB(dburi string, migrations []func(*sql.Tx) error) (*sql.DB, error) { - // Add some sqlite3-specific parameters if none are specified. +func OpenDB(dburi string, options ...Option) (*sql.DB, error) { + var opts sqlOptions + opts.sqlopts = defaultOptions + for _, o := range options { + o(&opts) + } + + // Add sqlite3-specific parameters if none are already + // specified in the connection URI. if !strings.Contains(dburi, "?") { - dburi += defaultOptions + dburi += opts.sqlopts } db, err := sql.Open("sqlite3", dburi) @@ -31,7 +57,7 @@ func OpenDB(dburi string, migrations []func(*sql.Tx) error) (*sql.DB, error) { // https://github.com/mattn/go-sqlite3/issues/209 db.SetMaxOpenConns(1) - if err = migrate(db, migrations); err != nil { + if err = migrate(db, opts.migrations); err != nil { db.Close() // nolint return nil, err } diff --git a/sqlutil/db_test.go b/sqlutil/db_test.go index cc52635147fbf5c5a7a8a5701106672ce1fe8d4e..a91b7e2c66be6259070469890272ab512b888850 100644 --- a/sqlutil/db_test.go +++ b/sqlutil/db_test.go @@ -19,7 +19,7 @@ func TestOpenDB(t *testing.T) { } defer os.RemoveAll(dir) - db, err := OpenDB(dir+"/test.db", nil) + db, err := OpenDB(dir + "/test.db") if err != nil { t.Fatal(err) } @@ -50,11 +50,11 @@ func TestOpenDB_Migrations_MultipleStatements(t *testing.T) { } defer os.RemoveAll(dir) - db, err := OpenDB(dir+"/test.db", []func(*sql.Tx) error{ + db, err := OpenDB(dir+"/test.db", WithMigrations([]func(*sql.Tx) error{ Statement("CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, value TEXT)"), Statement("CREATE INDEX idx_test_value ON test(value)"), Statement("INSERT INTO test (id, value) VALUES (1, 'test')"), - }) + })) if err != nil { t.Fatal(err) } @@ -70,13 +70,13 @@ func TestOpenDB_Migrations_SingleStatement(t *testing.T) { } defer os.RemoveAll(dir) - db, err := OpenDB(dir+"/test.db", []func(*sql.Tx) error{ + db, err := OpenDB(dir+"/test.db", WithMigrations([]func(*sql.Tx) error{ Statement( "CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, value TEXT)", "CREATE INDEX idx_test_value ON test(value)", "INSERT INTO test (id, value) VALUES (1, 'test')", ), - }) + })) if err != nil { t.Fatal(err) } @@ -97,14 +97,14 @@ func TestOpenDB_Migrations_Versions(t *testing.T) { Statement("CREATE INDEX idx_test_value ON test(value)"), } - db, err := OpenDB(dir+"/test.db", migrations) + db, err := OpenDB(dir+"/test.db", WithMigrations(migrations)) if err != nil { t.Fatal("first open: ", err) } db.Close() migrations = append(migrations, Statement("INSERT INTO test (id, value) VALUES (1, 'test')")) - db, err = OpenDB(dir+"/test.db", migrations) + db, err = OpenDB(dir+"/test.db", WithMigrations(migrations)) if err != nil { t.Fatal("second open: ", err) } @@ -120,12 +120,12 @@ func TestOpenDB_Write(t *testing.T) { } defer os.RemoveAll(dir) - db, err := OpenDB(dir+"/test.db", []func(*sql.Tx) error{ + db, err := OpenDB(dir+"/test.db", WithMigrations([]func(*sql.Tx) error{ Statement( "CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, value TEXT)", "CREATE INDEX idx_test_value ON test(value)", ), - }) + })) if err != nil { t.Fatal(err) }