Commit 54575419 authored by ale
Browse files

Use SQLite in a simpler and less error-prone way

The previous practice of keeping around prepared statements is
somewhat harmful, as they end up bound to the connection. Drop the
statementMap concept, and just use tx.Query() directly instead.

Reduce issues with concurrent write access to SQLite by running
db.SetMaxOpenConns(1).
parent 9468865e
Pipeline #16598 passed with stages
in 2 minutes and 3 seconds
......@@ -12,31 +12,22 @@ var analysisStatements = map[string]string{
}
// analysisService holds the database handle used by the analysis queries.
// NOTE(review): this is a diff rendering, so the old and new field lists
// appear together; per the commit message the stmts field (the map of
// prepared statements) is the part being dropped, leaving only db.
type analysisService struct {
db *sql.DB
stmts statementMap
db *sql.DB
}
// newAnalysisService wraps db in an analysisService.
// NOTE(review): old and new lines are interleaved here. Per the commit
// message, the newStatementMap call, its error return and the stmts field
// initializer are the removed side; the remaining code only stores db and
// always returns a nil error.
func newAnalysisService(db *sql.DB) (*analysisService, error) {
stmts, err := newStatementMap(db, analysisStatements)
if err != nil {
return nil, err
}
return &analysisService{
db: db,
stmts: stmts,
db: db,
}, nil
}
// Close closes the prepared statements held in d.stmts.
// NOTE(review): the commit message says the statementMap is dropped, so the
// body line is presumably on the removed side of this diff — confirm.
func (d *analysisService) Close() {
d.stmts.Close()
}
func (d *analysisService) CheckDevice(ctx context.Context, username string, deviceInfo *usermetadb.DeviceInfo) (bool, error) {
var seen bool
err := withReadonlyTX(d.db, func(tx *sql.Tx) error {
stmt := d.stmts.get(tx, "check_device_info")
defer stmt.Close()
err := stmt.QueryRow(username, deviceInfo.ID).Scan(&seen)
err := tx.QueryRow(analysisStatements["check_device_info"], username, deviceInfo.ID).Scan(&seen)
if err != nil && err != sql.ErrNoRows {
return err
}
......
......@@ -36,24 +36,16 @@ var lastloginDBStatements = map[string]string{
}
// lastloginDB holds the database handle used to store and read last-login
// entries.
// NOTE(review): this is a diff rendering, so the old and new field lists
// appear together; per the commit message the stmts field (the map of
// prepared statements) is the part being dropped, leaving only db.
type lastloginDB struct {
db *sql.DB
stmts statementMap
db *sql.DB
}
// newLastloginDB wraps db in a lastloginDB.
// NOTE(review): old and new lines are interleaved here. Per the commit
// message, the newStatementMap call, its error return and the stmts field
// initializer are the removed side; the remaining code only stores db and
// always returns a nil error.
func newLastloginDB(db *sql.DB) (*lastloginDB, error) {
stmts, err := newStatementMap(db, lastloginDBStatements)
if err != nil {
return nil, err
}
return &lastloginDB{
db: db,
stmts: stmts,
db: db,
}, nil
}
// Close closes the prepared statements held in l.stmts.
// NOTE(review): the commit message says the statementMap is dropped, so the
// body line is presumably on the removed side of this diff — confirm. The
// method itself is still called by BenchmarkLastLogin.
func (l *lastloginDB) Close() {
l.stmts.Close()
}
func (l *lastloginDB) AddLastLogin(ctx context.Context, entry *usermetadb.LastLoginEntry) error {
......@@ -65,20 +57,15 @@ func (l *lastloginDB) AddLastLogin(ctx context.Context, entry *usermetadb.LastLo
return err
}
stmt := "insert_or_replace_last_login"
args := []interface{}{
entry.Username,
entry.Service,
roundTimestamp(entry.Timestamp),
}
return retryBusy(ctx, func() error {
return withTX(l.db, func(tx *sql.Tx) error {
stmt := l.stmts.get(tx, stmt)
defer stmt.Close()
_, err := stmt.Exec(args...)
return err
})
return withTX(l.db, func(tx *sql.Tx) error {
_, err := tx.Exec(lastloginDBStatements["insert_or_replace_last_login"], args...)
return err
})
}
......@@ -100,9 +87,7 @@ func (l *lastloginDB) GetLastLogin(ctx context.Context, username string, service
var entries []*usermetadb.LastLoginEntry
err := withReadonlyTX(l.db, func(tx *sql.Tx) error {
stmt := l.stmts.get(tx, stmt)
defer stmt.Close()
rows, err := stmt.Query(args...)
rows, err := tx.Query(lastloginDBStatements[stmt], args...)
if err != nil {
return err
}
......
......@@ -27,6 +27,32 @@ func generateLastLoginEntries() []*usermetadb.LastLoginEntry {
return entries
}
// BenchmarkLastLogin measures parallel AddLastLogin calls against a
// throwaway on-disk database. The database file is removed when the
// benchmark returns.
func BenchmarkLastLogin(b *testing.B) {
	const dbPath = "bench.db"
	defer os.Remove(dbPath)

	db, err := openDB(dbPath)
	if err != nil {
		b.Fatal(err)
	}
	defer db.Close()

	store, err := newLastloginDB(db)
	if err != nil {
		b.Fatal(err)
	}
	defer store.Close()

	fixtures := generateLastLoginEntries()
	ctx := context.Background()

	b.Run("AddLastLogin", func(sub *testing.B) {
		sub.RunParallel(func(pb *testing.PB) {
			// Every worker cycles through the fixtures round-robin,
			// starting from the first entry.
			n := 0
			for pb.Next() {
				// The returned error is deliberately not checked.
				// nolint: errcheck
				store.AddLastLogin(ctx, fixtures[n%len(fixtures)])
				n++
			}
		})
	})
}
func TestLastlogin_LoginAdded(t *testing.T) {
defer os.Remove("test.db")
db, err := openDB("test.db")
......
......@@ -84,29 +84,24 @@ func (s *usernameSet) foreach(f func(string)) {
// Leisurely consolidate logs for users, but only when necessary, and
// proceeding at a slow pace. We wake up once per minute, and clean up
// the log history for users that have logged in since.
//
// compactLogs runs the "consolidate_userlog" statement with username and
// count as its arguments, inside a transaction of its own, and commits the
// transaction if the statement succeeds. It returns the first error from
// Begin, Exec or Commit.
//
// NOTE(review): this is a diff rendering; the signature taking stmts and
// the stmts.get/stmt.Exec lines are the old version, the tx.Exec line with
// userlogDBStatements is the new one.
func compactLogs(db *sql.DB, stmts statementMap, username string, count int) error {
func compactLogs(db *sql.DB, username string, count int) error {
tx, err := db.Begin()
if err != nil {
return err
}
stmt := stmts.get(tx, "consolidate_userlog")
defer stmt.Close()
if _, err := stmt.Exec(username, count); err != nil {
if _, err := tx.Exec(userlogDBStatements["consolidate_userlog"], username, count); err != nil {
// NOTE(review): this path returns without tx.Rollback() (runPruneStmt
// does roll back on error), so the transaction stays open. With the pool
// limited to one connection by db.SetMaxOpenConns(1), that presumably
// blocks every later query — TODO confirm and add a rollback.
return err
}
return tx.Commit()
}
func runPruneStmt(db *sql.DB, stmts statementMap, stmtName string, cutoff time.Time) error {
func runPruneStmt(db *sql.DB, stmtName string, cutoff time.Time) error {
// Run each batch deletion in its own transaction, just to be nice.
tx, err := db.Begin()
if err != nil {
return err
}
stmt := stmts.get(tx, stmtName)
defer stmt.Close()
_, err = stmt.Exec(cutoff)
stmt.Close()
_, err = tx.Exec(userlogDBStatements[stmtName], cutoff)
if err != nil {
tx.Rollback() // nolint
return err
......@@ -115,12 +110,12 @@ func runPruneStmt(db *sql.DB, stmts statementMap, stmtName string, cutoff time.T
}
// Remove old entries (both logs and devices) from the database.
func pruneLogs(db *sql.DB, stmts statementMap, pruneCutoffDays int) (rerr error) {
func pruneLogs(db *sql.DB, pruneCutoffDays int) (rerr error) {
cutoff := time.Now().AddDate(0, 0, -pruneCutoffDays)
// Always run both statements, even if the first returns an error.
for _, stmtName := range []string{"prune_userlog", "prune_device_info"} {
if err := runPruneStmt(db, stmts, stmtName, cutoff); err != nil {
if err := runPruneStmt(db, stmtName, cutoff); err != nil {
rerr = err
}
}
......
......@@ -20,18 +20,13 @@ func TestPruneLogs(t *testing.T) {
}
defer db.Close()
stmts, err := newStatementMap(db, userlogDBStatements)
if err != nil {
t.Fatal(err)
}
bulkLoadTestLogs(t, db)
n := countLogs(db)
if n == 0 {
t.Fatal("no logs loaded?")
}
if err := pruneLogs(db, stmts, 365); err != nil {
if err := pruneLogs(db, 365); err != nil {
t.Fatal("pruneLogs():", err)
}
n2 := countLogs(db)
......@@ -54,11 +49,6 @@ func TestCompactLogs(t *testing.T) {
}
defer db.Close()
stmts, err := newStatementMap(db, userlogDBStatements)
if err != nil {
t.Fatal(err)
}
bulkLoadTestLogs(t, db)
user1 := randomUsernames[0]
user2 := randomUsernames[1]
......@@ -71,7 +61,7 @@ func TestCompactLogs(t *testing.T) {
// Find a reasonable limit.
limit := n1 / 2
if err := compactLogs(db, stmts, user1, limit); err != nil {
if err := compactLogs(db, user1, limit); err != nil {
t.Fatal("compactLogs():", err)
}
......
......@@ -44,11 +44,15 @@ func BenchmarkServer_AddLog(b *testing.B) {
c, _ := client.New(&clientutil.BackendConfig{URL: httpSrv.URL})
entries := generateTestLogs(b.N, generateAllRandomDevices())
for _, e := range entries {
if err := c.AddLog(context.Background(), "", e); err != nil {
b.Fatalf("AddLog(%+v): %v", e, err)
b.Run("AddLog", func(b *testing.B) {
for i := 0; i < b.N; i++ {
e := entries[i%len(entries)]
if err := c.AddLog(context.Background(), "", e); err != nil {
b.Fatalf("AddLog(%+v): %v", e, err)
}
}
}
})
}
func TestServer_AddLastLogin(t *testing.T) {
......
package server
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"math/rand"
"strings"
"time"
migrate "github.com/golang-migrate/migrate/v4"
msqlite3 "github.com/golang-migrate/migrate/v4/database/sqlite3"
bindata "github.com/golang-migrate/migrate/v4/source/go_bindata"
sqlite3 "github.com/mattn/go-sqlite3"
"git.autistici.org/id/usermetadb/migrations"
)
......@@ -23,13 +17,12 @@ const dbDriver = "sqlite3"
func openDB(dburi string) (*sql.DB, error) {
if !strings.Contains(dburi, "?") {
// Tune the SQLite engine:
// - activate the shared cache
// - activate the WAL journal
// - increase busy_timeout
// - enable auto-vacuuming since we delete lots of data
// - disable syncing, to reduce disk write load at the expense
// of durability.
dburi += "?cache=shared&_busy_timeout=10000&_journal=WAL&_mutex=full&_sync=OFF&_auto_vacuum=incremental"
dburi += "?cache=shared&_journal=WAL&_sync=OFF&_auto_vacuum=incremental"
}
db, err := sql.Open(dbDriver, dburi)
......@@ -42,12 +35,15 @@ func openDB(dburi string) (*sql.DB, error) {
return nil, err
}
// Running database migrations closes the database. Re-open it.
db, err = sql.Open(dbDriver, dburi)
if err != nil {
return nil, err
}
// Limit the pool to a single connection.
// https://github.com/mattn/go-sqlite3/issues/209
db.SetMaxOpenConns(1)
return db, nil
}
......@@ -91,30 +87,6 @@ func runDatabaseMigrations(db *sql.DB) error {
return nil
}
// statementMap is a set of prepared statements keyed by the symbolic name
// they have in the package's statement tables.
type statementMap map[string]*sql.Stmt

// newStatementMap prepares every query in statements against db and returns
// the prepared statements under the same names.
//
// If any query fails to compile, the statements prepared so far are closed
// before returning, so a failed construction does not leak them. The
// returned error names the offending statement and wraps the underlying
// error, which stays reachable through errors.Is / errors.As.
func newStatementMap(db *sql.DB, statements map[string]string) (statementMap, error) {
	stmts := make(statementMap, len(statements))
	for name, qstr := range statements {
		stmt, err := db.Prepare(qstr)
		if err != nil {
			// Release whatever was already prepared: the caller only
			// gets nil back and has no way to close these itself.
			stmts.Close()
			return nil, fmt.Errorf("error compiling statement '%s': %w", name, err)
		}
		stmts[name] = stmt
	}
	return stmts, nil
}

// Close closes every statement in the map. Errors from individual
// statements are ignored.
func (m statementMap) Close() {
	for _, s := range m {
		s.Close() // nolint
	}
}

// get returns the statement registered under name, bound to transaction tx
// via tx.Stmt. The name is not validated: an unknown name passes a nil
// statement to tx.Stmt.
func (m statementMap) get(tx *sql.Tx, name string) *sql.Stmt {
	return tx.Stmt(m[name])
}
func withTX(db *sql.DB, f func(*sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
......@@ -135,31 +107,3 @@ func withReadonlyTX(db *sql.DB, f func(*sql.Tx) error) error {
defer tx.Rollback() // nolint
return f(tx)
}
// isBusy reports whether err is, or wraps, a sqlite3.Error whose result
// code is sqlite3.ErrBusy (SQLITE_BUSY, numeric value 5). Any other error,
// including nil, yields false.
func isBusy(err error) bool {
	var sqliteErr sqlite3.Error
	// errors.As also finds the driver error when a caller has wrapped it
	// with extra context, which a plain type switch would miss.
	if !errors.As(err, &sqliteErr) {
		return false
	}
	return sqliteErr.Code == sqlite3.ErrBusy
}
// defaultQueryTimeout bounds how long retryBusy keeps retrying when the
// caller's context carries no deadline of its own.
var defaultQueryTimeout = 3 * time.Second

// retryBusy calls f repeatedly for as long as it fails with an error that
// isBusy classifies as "database busy".
//
// Retries stop at the context's deadline or, if the context has none, after
// defaultQueryTimeout. The first result of f that is not a busy error
// (including nil) is returned as is. If the deadline passes first, a
// timeout error is returned.
func retryBusy(ctx context.Context, f func() error) error {
	deadline, ok := ctx.Deadline()
	if !ok {
		deadline = time.Now().Add(defaultQueryTimeout)
	}

	for time.Now().Before(deadline) {
		if err := f(); !isBusy(err) {
			return err
		}
		// Random sleep, max 1ms. The random fraction has to be scaled
		// before the conversion to time.Duration: converting
		// rand.Float64() (always < 1) directly truncates it to zero,
		// which turns this loop into a busy spin.
		time.Sleep(time.Duration(rand.Float64() * float64(time.Millisecond)))
	}
	return errors.New("query timed out waiting to lock the database")
}
......@@ -116,20 +116,13 @@ var userlogDBStatements = map[string]string{
// userlogDB stores user log entries and device information, and owns the
// background jobs that compact and prune them.
type userlogDB struct {
// db is the database handle all queries run against.
db *sql.DB
// stmts held the prepared statements.
// NOTE(review): per the commit message this field is on the removed side
// of the diff.
stmts statementMap
// cronjobs are the periodic jobs registered by newUserlogDB; Close stops
// each of them.
cronjobs []*cronJob
// pendingCompaction is the set of usernames the compaction job iterates
// over when it runs compactLogs.
pendingCompaction *usernameSet
}
func newUserlogDB(db *sql.DB, compactionKeepCount, maxAgeDays int) (*userlogDB, error) {
stmts, err := newStatementMap(db, userlogDBStatements)
if err != nil {
return nil, err
}
udb := &userlogDB{
db: db,
stmts: stmts,
db: db,
}
// Selectively add cronjobs if the related parameters are non-zero.
......@@ -138,7 +131,7 @@ func newUserlogDB(db *sql.DB, compactionKeepCount, maxAgeDays int) (*userlogDB,
udb.pendingCompaction = pc
udb.cronjobs = append(udb.cronjobs, newCron(func() {
pc.foreach(func(username string) {
if err := compactLogs(db, stmts, username, compactionKeepCount); err != nil {
if err := compactLogs(db, username, compactionKeepCount); err != nil {
log.Printf("error cleaning up logs for user %s: %v", username, err)
}
})
......@@ -146,7 +139,7 @@ func newUserlogDB(db *sql.DB, compactionKeepCount, maxAgeDays int) (*userlogDB,
}
if maxAgeDays > 0 {
udb.cronjobs = append(udb.cronjobs, newCron(func() {
if err := pruneLogs(db, stmts, maxAgeDays); err != nil {
if err := pruneLogs(db, maxAgeDays); err != nil {
log.Printf("error in log pruning: %v", err)
}
}, pruneInterval))
......@@ -159,7 +152,6 @@ func (u *userlogDB) Close() {
for _, j := range u.cronjobs {
j.Stop()
}
u.stmts.Close()
}
// Update or create an entry in the 'devices' table.
......@@ -167,13 +159,10 @@ func (u *userlogDB) updateDeviceInfo(tx *sql.Tx, username string, deviceInfo *us
now := roundTimestamp(time.Now().UTC())
var seen bool
stmt := u.stmts.get(tx, "check_device_info")
defer stmt.Close()
err := stmt.QueryRow(username, deviceInfo.ID).Scan(&seen)
err := tx.QueryRow(userlogDBStatements["check_device_info"], username, deviceInfo.ID).Scan(&seen)
if err == sql.ErrNoRows {
insStmt := u.stmts.get(tx, "insert_device_info")
defer insStmt.Close()
_, err = insStmt.Exec(
_, err = tx.Exec(
userlogDBStatements["insert_device_info"],
username,
deviceInfo.ID,
deviceInfo.RemoteZone,
......@@ -191,9 +180,8 @@ func (u *userlogDB) updateDeviceInfo(tx *sql.Tx, username string, deviceInfo *us
// legitimately happen that the other parameters
// change too (like when restoring a backup to another
// platform/os), so we update them too.
insStmt := u.stmts.get(tx, "update_device_info")
defer insStmt.Close()
_, err = insStmt.Exec(
_, err = tx.Exec(
userlogDBStatements["update_device_info"],
now,
deviceInfo.RemoteZone,
deviceInfo.UserAgent,
......@@ -236,18 +224,14 @@ func (u *userlogDB) AddLog(ctx context.Context, entry *usermetadb.LogEntry) erro
args = append(args, entry.DeviceInfo.Mobile)
}
if err := retryBusy(ctx, func() error {
return withTX(u.db, func(tx *sql.Tx) error {
if entry.DeviceInfo != nil {
if err := u.updateDeviceInfo(tx, entry.Username, entry.DeviceInfo); err != nil {
return err
}
if err := withTX(u.db, func(tx *sql.Tx) error {
if entry.DeviceInfo != nil {
if err := u.updateDeviceInfo(tx, entry.Username, entry.DeviceInfo); err != nil {
return err
}
stmt := u.stmts.get(tx, stmtName)
defer stmt.Close()
_, err := stmt.Exec(args...)
return err
})
}
_, err := tx.Exec(userlogDBStatements[stmtName], args...)
return err
}); err != nil {
return err
}
......@@ -270,9 +254,7 @@ func (u *userlogDB) GetUserLogs(ctx context.Context, username string, maxDays, l
var out []*usermetadb.LogEntry
err := withReadonlyTX(u.db, func(tx *sql.Tx) error {
stmt := u.stmts.get(tx, "get_user_logs")
defer stmt.Close()
rows, err := stmt.Query(username, cutoff, limit)
rows, err := tx.Query(userlogDBStatements["get_user_logs"], username, cutoff, limit)
if err != nil {
return err
}
......@@ -310,9 +292,7 @@ func (u *userlogDB) GetUserLogs(ctx context.Context, username string, maxDays, l
func (u *userlogDB) GetUserDevices(ctx context.Context, username string) ([]*usermetadb.MetaDeviceInfo, error) {
var out []*usermetadb.MetaDeviceInfo
err := withReadonlyTX(u.db, func(tx *sql.Tx) error {
stmt := u.stmts.get(tx, "devices_with_counts")
defer stmt.Close()
rows, err := stmt.Query(username)
rows, err := tx.Query(userlogDBStatements["devices_with_counts"], username)
if err != nil {
return err
}
......
......@@ -84,19 +84,20 @@ func generateTestLogs(n int, userDevices map[string][]*usermetadb.DeviceInfo) []
}
func bulkLoadTestLogs(t testing.TB, db *sql.DB) *usermetadb.LogEntry {
tx, _ := db.Begin()
tx, err := db.Begin()
if err != nil {
t.Fatalf("tx.Begin: %v", err)
}
stmt, err := db.Prepare(userlogDBStatements["insert_userlog_with_device_info"])
insertLogStmt, err := tx.Prepare(userlogDBStatements["insert_userlog_with_device_info"])
if err != nil {
t.Fatalf("insert_userlog: %v", err)
}
insertLogStmt := tx.Stmt(stmt)
defer insertLogStmt.Close()
stmt, err = db.Prepare(userlogDBStatements["insert_device_info"])
insertDeviceInfoStmt, err := tx.Prepare(userlogDBStatements["insert_device_info"])
if err != nil {
t.Fatalf("insert_device_info: %v", err)
}
insertDeviceInfoStmt := tx.Stmt(stmt)
defer insertDeviceInfoStmt.Close()
userDevices := generateAllRandomDevices()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment