diff --git a/db/sqlite/driver.go b/db/sqlite/driver.go index 2b00d00c70d3a94ea2f5e5f931d3386537a0c334..39b006f2a7fadec42fda3cde13b655fcbac1ec9c 100644 --- a/db/sqlite/driver.go +++ b/db/sqlite/driver.go @@ -8,26 +8,8 @@ import ( ippb "git.autistici.org/ai3/tools/iprep/proto" ) -var statements = map[string]string{ - "insert_event": ` -INSERT INTO events (ip, event_type, count, timestamp) VALUES (?, ?, ?, ?) -`, - "scan_by_ip": ` -SELECT - event_type, SUM(count) AS sum -FROM - events -WHERE - ip = ? - AND timestamp > ? -GROUP BY - event_type -`, -} - type DB struct { - db *sql.DB - stmts StatementMap + db *sql.DB } func Open(path string) (*DB, error) { @@ -40,15 +22,8 @@ func Open(path string) (*DB, error) { return nil, err } - stmts, err := NewStatementMap(db, statements) - if err != nil { - db.Close() - return nil, err - } - return &DB{ - db: db, - stmts: stmts, + db: db, }, nil } @@ -61,7 +36,14 @@ func (db *DB) AddAggregate(aggr *ippb.Aggregate) error { if err != nil { return err } - stmt := db.stmts.Get(tx, "insert_event") + + stmt, err := tx.Prepare( + "INSERT INTO events (ip, event_type, count, timestamp) VALUES (?, ?, ?, ?)") + if err != nil { + tx.Rollback() + return err + } + now := time.Now() for _, bt := range aggr.ByType { @@ -90,7 +72,11 @@ func (db *DB) ScanIP(startTime time.Time, ip string) (map[string]int64, error) { } defer tx.Rollback() // nolint - rows, err := db.stmts.Get(tx, "scan_by_ip").Query(ip, startTime) + rows, err := tx.Query( + `SELECT event_type, SUM(count) AS sum FROM events + WHERE ip = ? AND timestamp > ? + GROUP BY event_type`, + ip, startTime) if err != nil { return nil, err } diff --git a/db/sqlite/sql.go b/db/sqlite/sql.go index fcbb7acd4c08f992628525eeacb536435cc22319..9110d61719628962e0a2855e69cfc5a16231f5e2 100644 --- a/db/sqlite/sql.go +++ b/db/sqlite/sql.go @@ -3,7 +3,6 @@ package sqlite import ( "database/sql" - "fmt" "log" migrate "github.com/golang-migrate/migrate/v4" @@ -71,33 +70,3 @@ func runDatabaseMigrations(db *sql.DB) error { } return nil } - -// A StatementMap holds named compiled statements. -type StatementMap map[string]*sql.Stmt - -// NewStatementMap compiles the given named statements and returns a -// new StatementMap. -func NewStatementMap(db *sql.DB, statements map[string]string) (StatementMap, error) { - stmts := make(map[string]*sql.Stmt) - for name, qstr := range statements { - stmt, err := db.Prepare(qstr) - if err != nil { - return nil, fmt.Errorf("error compiling statement '%s': %v", name, err) - } - stmts[name] = stmt - } - return StatementMap(stmts), nil -} - -// Close all resources associated with the StatementMap. -func (m StatementMap) Close() { - for _, s := range m { - s.Close() // nolint - } -} - -// Get a named compiled statement (or nil if not found), and associate -// it with the given transaction. -func (m StatementMap) Get(tx *sql.Tx, name string) *sql.Stmt { - return tx.Stmt(m[name]) -}