Skip to content
Snippets Groups Projects
Commit 9cc39d18 authored by ale's avatar ale
Browse files

Do not unnecessarily prepare SQL statements

parent 87570256
No related branches found
No related tags found
No related merge requests found
......@@ -8,26 +8,8 @@ import (
ippb "git.autistici.org/ai3/tools/iprep/proto"
)
var statements = map[string]string{
"insert_event": `
INSERT INTO events (ip, event_type, count, timestamp) VALUES (?, ?, ?, ?)
`,
"scan_by_ip": `
SELECT
event_type, SUM(count) AS sum
FROM
events
WHERE
ip = ?
AND timestamp > ?
GROUP BY
event_type
`,
}
type DB struct {
db *sql.DB
stmts StatementMap
}
func Open(path string) (*DB, error) {
......@@ -40,15 +22,8 @@ func Open(path string) (*DB, error) {
return nil, err
}
stmts, err := NewStatementMap(db, statements)
if err != nil {
db.Close()
return nil, err
}
return &DB{
db: db,
stmts: stmts,
}, nil
}
......@@ -61,7 +36,14 @@ func (db *DB) AddAggregate(aggr *ippb.Aggregate) error {
if err != nil {
return err
}
stmt := db.stmts.Get(tx, "insert_event")
stmt, err := tx.Prepare(
"INSERT INTO events (ip, event_type, count, timestamp) VALUES (?, ?, ?, ?)")
if err != nil {
tx.Rollback()
return err
}
now := time.Now()
for _, bt := range aggr.ByType {
......@@ -90,7 +72,11 @@ func (db *DB) ScanIP(startTime time.Time, ip string) (map[string]int64, error) {
}
defer tx.Rollback() // nolint
rows, err := db.stmts.Get(tx, "scan_by_ip").Query(ip, startTime)
rows, err := tx.Query(
`SELECT event_type, SUM(count) AS sum FROM events
WHERE ip = ? AND timestamp > ?
GROUP BY event_type`,
ip, startTime)
if err != nil {
return nil, err
}
......
......@@ -3,7 +3,6 @@ package sqlite
import (
"database/sql"
"fmt"
"log"
migrate "github.com/golang-migrate/migrate/v4"
......@@ -71,33 +70,3 @@ func runDatabaseMigrations(db *sql.DB) error {
}
return nil
}
// A StatementMap holds named compiled statements.
type StatementMap map[string]*sql.Stmt
// NewStatementMap compiles the given named statements and returns a
// new StatementMap.
func NewStatementMap(db *sql.DB, statements map[string]string) (StatementMap, error) {
stmts := make(map[string]*sql.Stmt)
for name, qstr := range statements {
stmt, err := db.Prepare(qstr)
if err != nil {
return nil, fmt.Errorf("error compiling statement '%s': %v", name, err)
}
stmts[name] = stmt
}
return StatementMap(stmts), nil
}
// Close all resources associated with the StatementMap.
func (m StatementMap) Close() {
for _, s := range m {
s.Close() // nolint
}
}
// Get a named compiled statement (or nil if not found), and associate
// it with the given transaction.
func (m StatementMap) Get(tx *sql.Tx, name string) *sql.Stmt {
return tx.Stmt(m[name])
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment