db.go 7.33 KiB
package main
import (
"database/sql"
"fmt"
"strings"
"time"
"git.autistici.org/ai3/go-common/sqlutil"
)
// location describes where a session connected from. It is stored on
// session (persisted in the location_* columns of the sessions table) and
// copied into logEntry records; empty fields are omitted from the JSON
// encoding.
type location struct {
// IP is the client address as recorded when the session was opened.
IP string `json:"ip,omitempty"`
// City and Country are free-form geolocation strings (may be empty).
City string `json:"city,omitempty"`
Country string `json:"country,omitempty"`
// ASN and ASNOrganization describe the network the client connected from
// (may be empty).
ASN string `json:"asn,omitempty"`
ASNOrganization string `json:"asn_org,omitempty"`
}
// session is one SSH login as stored in the sessions table: a row is
// inserted by createSession and finished by closeSession.
type session struct {
// ID identifies the session. It is not globally unique; only one open
// row per (ID, BootTime) is expected at a time.
ID string
// BootTime is compared against a caller-supplied boot time to tell
// sessions of the current boot apart from older ones.
// NOTE(review): presumably the system boot timestamp — confirm with callers.
BootTime int64
// User and LocalUser are stored verbatim; presumably the remote identity
// and the local account that was logged into — verify against callers.
User string
LocalUser string
// SSHKeyFingerprint identifies the key used to authenticate, as a string.
SSHKeyFingerprint string
// StartedAt is when the session was opened.
StartedAt time.Time
// EndedAt is the zero time while the session is open; closeSession sets
// it to the time the session was closed.
EndedAt time.Time
// EndedReboot is true when closeSession was called with isReboot set.
EndedReboot bool
Location location
}
// isActive reports whether s is still open in the current boot: it has
// no EndedAt timestamp AND its BootTime equals bTime. An unclosed session
// recorded under a different boot time is not considered active (it was
// left dangling by a previous boot).
func (s *session) isActive(bTime int64) bool {
	return s.EndedAt.IsZero() && s.BootTime == bTime
}
// Record kinds accepted by session.logEntry.
const (
entryTypeOpen = iota // session opened: log_type "ssh_open_session"
entryTypeClose // session closed: log_type "ssh_close_session", includes duration
)
// logEntry converts s into a structured log record. entryType picks the
// record kind:
//   - entryTypeOpen produces an "ssh_open_session" record.
//   - entryTypeClose produces an "ssh_close_session" record. It also carries
//     the session length (EndedAt - StartedAt) in seconds, truncated toward zero.
//
// For any other entryType the Type field is left empty.
func (s *session) logEntry(entryType int) *logEntry {
	entry := logEntry{
		ID:                s.ID,
		BootTime:          s.BootTime,
		User:              s.User,
		LocalUser:         s.LocalUser,
		SSHKeyFingerprint: s.SSHKeyFingerprint,
		Location:          s.Location,
	}
	if entryType == entryTypeOpen {
		entry.Type = "ssh_open_session"
	} else if entryType == entryTypeClose {
		elapsed := s.EndedAt.Sub(s.StartedAt)
		entry.Type = "ssh_close_session"
		entry.Duration = int(elapsed.Seconds())
	}
	return &entry
}
// timeFmt is the time.Format layout formatString uses for session start
// and end times: weekday, month, space-padded day, 24-hour hour:minute.
const timeFmt = "Mon Jan _2 15:04"
// formatLocation renders loc as a human-readable string for display.
//
// The non-empty details are joined with ", ": city, country, and ASN.
// The ASN is followed by " - organization" when the organization is known.
// When an IP is known it is printed first, with the details in parentheses
// if there are any. When only the IP is known, the bare IP is returned
// rather than dropping it. The result is "" only when loc is nil or
// carries no information at all.
func formatLocation(loc *location) string {
	if loc == nil {
		return ""
	}

	var details []string
	if loc.City != "" {
		details = append(details, loc.City)
	}
	if loc.Country != "" {
		details = append(details, loc.Country)
	}
	if loc.ASN != "" {
		asn := loc.ASN
		// Avoid a dangling " - " when the organization lookup failed.
		if loc.ASNOrganization != "" {
			asn += " - " + loc.ASNOrganization
		}
		details = append(details, asn)
	}
	detailStr := strings.Join(details, ", ")

	switch {
	case loc.IP == "":
		return detailStr
	case detailStr == "":
		return loc.IP
	default:
		return fmt.Sprintf("%s (%s)", loc.IP, detailStr)
	}
}
// formatString renders s as a single fixed-width text line. bTime is the
// current boot time, as accepted by isActive.
//
// How the session is rendered depends on its state:
//   - Active sessions end with "still logged in".
//   - Sessions closed because of a reboot (EndedReboot), or left unclosed by a
//     previous boot (no EndedAt), end with "- reboot". Their EndedAt, if any,
//     does not reflect when the user actually left.
//   - All other finished sessions show their end time and the duration,
//     truncated to whole seconds. This includes ones from earlier boots that
//     were closed normally.
func (s *session) formatString(bTime int64) string {
	locStr := formatLocation(&s.Location)
	if s.isActive(bTime) {
		return fmt.Sprintf(
			"%-15s %-15s %-30s %-30s %-16s still logged in",
			s.User,
			s.LocalUser,
			s.SSHKeyFingerprint,
			locStr,
			s.StartedAt.Format(timeFmt),
		)
	}

	var durationStr string
	if s.EndedAt.IsZero() || s.EndedReboot {
		// No meaningful end time: the session was cut short by a reboot.
		durationStr = "reboot"
	} else {
		// Truncate (exactly, without a float64 round-trip) to whole seconds.
		duration := s.EndedAt.Sub(s.StartedAt).Truncate(time.Second)
		durationStr = fmt.Sprintf("%-16s (%s)", s.EndedAt.Format(timeFmt), duration)
	}
	return fmt.Sprintf(
		"%-15s %-15s %-30s %-30s %-16s - %s",
		s.User,
		s.LocalUser,
		s.SSHKeyFingerprint,
		locStr,
		s.StartedAt.Format(timeFmt),
		durationStr,
	)
}
// logEntry is the JSON record built by session.logEntry when a session is
// opened or closed.
type logEntry struct {
// Type is "ssh_open_session" or "ssh_close_session".
Type string `json:"log_type"`
ID string `json:"session_id"`
BootTime int64 `json:"btime"`
User string `json:"user"`
LocalUser string `json:"local_user"`
SSHKeyFingerprint string `json:"ssh_key_fp"`
Location location `json:"location"`
// Duration is the session length in whole seconds; it is only set for
// close records and omitted from the JSON when zero.
Duration int `json:"duration,omitempty"`
}
// migrations is the ordered list of schema changes that openDB passes to
// sqlutil.WithMigrations. NOTE(review): sqlutil presumably records how
// many of these have already been applied. If so, shipped entries must
// never be edited or reordered; append a new sqlutil.Statement instead.
// Confirm against the sqlutil documentation.
//
// The table does not have a primary key (except for the implicit
// rowid), because we can't guarantee the uniqueness of session_id
// except for knowing that there can't be two "open" sessions with the
// same ID at the same time.
//
// SMALLTEXT is not a native SQLite type. Assuming the SQLite backend,
// such columns get TEXT affinity because the declared type contains "TEXT".
var migrations = []func(*sql.Tx) error{
// First migration: the sessions table, plus indexes for lookups by
// (session_id, btime) and by user.
sqlutil.Statement(`
CREATE TABLE sessions (
session_id SMALLTEXT NOT NULL,
btime INTEGER NOT NULL,
user SMALLTEXT NOT NULL,
local_user SMALLTEXT NOT NULL,
ssh_key_fp TEXT NOT NULL,
started_at DATETIME NOT NULL,
ended_at DATETIME,
ended_reboot BOOLEAN,
location_ip SMALLTEXT,
location_country SMALLTEXT,
location_city SMALLTEXT
);
`, `
CREATE INDEX idx_sessions_id_btime ON sessions(session_id, btime)
`, `
CREATE INDEX idx_sessions_user ON sessions(user)
`),
// Second migration: ASN information for the stored location.
sqlutil.Statement(`
ALTER TABLE sessions ADD COLUMN location_asn SMALLTEXT
`, `
ALTER TABLE sessions ADD COLUMN location_asn_org SMALLTEXT
`),
}
// openDB opens the session database at path through sqlutil.OpenDB,
// handing it the schema migrations defined in this file.
func openDB(path string) (*sql.DB, error) {
	withSchema := sqlutil.WithMigrations(migrations)
	return sqlutil.OpenDB(path, withSchema)
}
// createSession records the newly opened session s by inserting a row into
// the sessions table within tx. ended_at and ended_reboot are not set, so
// the row is matched by the "ended_at IS NULL" queries for open sessions.
// The Exec error, if any, is returned unchanged.
func createSession(tx *sql.Tx, s *session) error {
	const stmt = `
INSERT INTO sessions
(session_id, btime, user, local_user, ssh_key_fp, started_at,
location_ip, location_country, location_city, location_asn, location_asn_org)
VALUES
(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
	loc := s.Location
	args := []any{
		s.ID, s.BootTime, s.User, s.LocalUser, s.SSHKeyFingerprint, s.StartedAt,
		loc.IP, loc.Country, loc.City, loc.ASN, loc.ASNOrganization,
	}
	_, err := tx.Exec(stmt, args...)
	return err
}
// closeSession marks s as finished. It first sets s.EndedAt to the current
// time and s.EndedReboot to isReboot. It then writes those values to the
// open row for s: same session_id and btime, with ended_at still NULL.
//
// s is modified in memory before the UPDATE runs, so its fields change
// even if the query fails. No error is reported when no open row matches.
func closeSession(tx *sql.Tx, s *session, isReboot bool) error {
	s.EndedAt, s.EndedReboot = time.Now(), isReboot
	if _, err := tx.Exec(`
UPDATE sessions SET
ended_at=?, ended_reboot=?
WHERE
session_id = ? AND btime = ? AND ended_at IS NULL
`, s.EndedAt, s.EndedReboot, s.ID, s.BootTime); err != nil {
		return err
	}
	return nil
}
// scanSource is the single method shared by *sql.Row and *sql.Rows that
// scanSessionRow needs. It lets the same decoder serve single-row lookups
// and iterated result sets.
type scanSource interface {
	Scan(dest ...any) error
}
// sessionFields lists the sessions columns read by every SELECT built with
// scanSessionQuery. The order matters: it must match, one for one, the
// destinations passed to src.Scan() in scanSessionRow.
var sessionFields = []string{
	"session_id",
	"btime",
	"user",
	"local_user",
	"ssh_key_fp",
	"started_at",
	"ended_at",
	"ended_reboot",
	"location_ip",
	"location_country",
	"location_city",
	"location_asn",
	"location_asn_org",
}

// scanSessionQuery builds a SELECT of sessionFields from the sessions
// table. The caller's clause (WHERE / ORDER BY / LIMIT, possibly empty) is
// appended after a single space. The clause is inserted verbatim, so it
// must never contain untrusted data; bind values as query arguments.
func scanSessionQuery(where string) string {
	columns := strings.Join(sessionFields, ", ")
	return fmt.Sprintf("SELECT %s FROM sessions %s", columns, where)
}
// scanSessionRow decodes one row from src into a new session. The row's
// columns must be laid out as in sessionFields. NULL ended_at,
// ended_reboot and location_* columns leave the corresponding fields at
// their zero values. A Scan error is returned unchanged (for *sql.Row this
// includes sql.ErrNoRows) together with a nil session.
func scanSessionRow(src scanSource) (*session, error) {
	var (
		out                            session
		endedAt                        sql.NullTime
		endedReboot                    sql.NullBool
		ip, country, city, asn, asnOrg sql.NullString
	)
	err := src.Scan(
		&out.ID, &out.BootTime, &out.User, &out.LocalUser, &out.SSHKeyFingerprint, &out.StartedAt,
		&endedAt, &endedReboot,
		&ip, &country, &city, &asn, &asnOrg,
	)
	if err != nil {
		return nil, err
	}

	if endedAt.Valid {
		out.EndedAt = endedAt.Time
	}
	if endedReboot.Valid {
		out.EndedReboot = endedReboot.Bool
	}

	// NULL location columns become empty strings.
	orEmpty := func(ns sql.NullString) string {
		if ns.Valid {
			return ns.String
		}
		return ""
	}
	out.Location = location{
		IP:              orEmpty(ip),
		City:            orEmpty(city),
		Country:         orEmpty(country),
		ASN:             orEmpty(asn),
		ASNOrganization: orEmpty(asnOrg),
	}
	return &out, nil
}
// scanSessionsRows iterates over rows, decoding each one with
// scanSessionRow. On the first decode error it returns that error and a
// nil slice. Otherwise it returns the decoded sessions (nil if there were
// none) together with rows.Err(). Closing rows is the caller's job.
func scanSessionsRows(rows *sql.Rows) ([]*session, error) {
	var sessions []*session
	for rows.Next() {
		sess, err := scanSessionRow(rows)
		if err != nil {
			return nil, err
		}
		sessions = append(sessions, sess)
	}
	// Surface any error that terminated the iteration early.
	iterErr := rows.Err()
	return sessions, iterErr
}
// getOpenSession fetches the still-open session (ended_at IS NULL) with
// the given ID and boot time. If no such row exists, the error from
// (*sql.Row).Scan is returned, which is sql.ErrNoRows.
func getOpenSession(tx *sql.Tx, sessionID string, bTime int64) (*session, error) {
	query := scanSessionQuery(`
WHERE session_id = ? AND btime = ? AND ended_at IS NULL
`)
	row := tx.QueryRow(query, sessionID, bTime)
	return scanSessionRow(row)
}
// getLatestSessions returns at most n sessions, most recently started
// first.
func getLatestSessions(tx *sql.Tx, n int) ([]*session, error) {
	query := scanSessionQuery(`
ORDER BY started_at DESC LIMIT ?
`)
	rows, err := tx.Query(query, n)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	return scanSessionsRows(rows)
}
// getUserSessions returns every session whose user column equals user,
// most recently started first.
func getUserSessions(tx *sql.Tx, user string) ([]*session, error) {
	query := scanSessionQuery(`
WHERE user = ?
ORDER BY started_at DESC
`)
	rows, err := tx.Query(query, user)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	return scanSessionsRows(rows)
}
// getPendingSessions returns sessions that are still open (ended_at IS
// NULL) but whose btime differs from bTime. These are presumably sessions
// left dangling by an earlier boot.
func getPendingSessions(tx *sql.Tx, bTime int64) ([]*session, error) {
	query := scanSessionQuery(`
WHERE ended_at IS NULL AND btime != ?
`)
	rows, err := tx.Query(query, bTime)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	return scanSessionsRows(rows)
}