Skip to content
Snippets Groups Projects
db.go 7.33 KiB
package main

import (
	"database/sql"
	"fmt"
	"strings"
	"time"

	"git.autistici.org/ai3/go-common/sqlutil"
)

// location is the network and geographic information attached to a
// session. Every field is optional and is left out of the JSON encoding
// when empty.
type location struct {
	// IP address the rest of the information refers to.
	IP string `json:"ip,omitempty"`

	// Geographic attributes associated with the address.
	City    string `json:"city,omitempty"`
	Country string `json:"country,omitempty"`

	// Autonomous system number and the name of its organization.
	ASN             string `json:"asn,omitempty"`
	ASNOrganization string `json:"asn_org,omitempty"`
}

// session is one SSH login, mirroring a row of the sessions table.
type session struct {
	// ID of the session. It is not globally unique; rows are matched on
	// ID together with BootTime.
	ID string

	// btime value of the boot during which the session was opened.
	BootTime int64

	// User and local account of the login (NOTE(review): exact meaning
	// of User vs LocalUser is defined by the caller — confirm).
	User, LocalUser string

	// Fingerprint of the SSH key used to authenticate.
	SSHKeyFingerprint string

	// EndedAt stays the zero time while the session is open.
	StartedAt, EndedAt time.Time

	// EndedReboot is the isReboot flag passed to closeSession.
	EndedReboot bool

	Location location
}

// isActive reports whether the session is still open in the boot
// identified by bTime. Both conditions must hold: the session was started
// during that boot, and it has no end timestamp. A session left open by a
// different boot is never considered active.
func (s *session) isActive(bTime int64) bool {
	if s.BootTime != bTime {
		return false
	}
	return s.EndedAt.IsZero()
}

// Record kinds accepted by (*session).logEntry.
const (
	entryTypeOpen  = 0 // the session has just been opened
	entryTypeClose = 1 // the session has just been closed
)

// logEntry builds the log record describing this session.
//
// entryType selects the record kind:
//   - entryTypeOpen produces "ssh_open_session".
//   - entryTypeClose produces "ssh_close_session" and also sets Duration
//     to EndedAt-StartedAt in whole seconds, with the fraction dropped.
//   - Any other value leaves Type empty.
func (s *session) logEntry(entryType int) *logEntry {
	entry := logEntry{
		ID:                s.ID,
		BootTime:          s.BootTime,
		User:              s.User,
		LocalUser:         s.LocalUser,
		SSHKeyFingerprint: s.SSHKeyFingerprint,
		Location:          s.Location,
	}

	switch entryType {
	case entryTypeClose:
		entry.Type = "ssh_close_session"
		elapsed := s.EndedAt.Sub(s.StartedAt)
		entry.Duration = int(elapsed.Seconds())
	case entryTypeOpen:
		entry.Type = "ssh_open_session"
	}

	return &entry
}

// timeFmt is the time.Format layout used for start and end times in the
// session listing: weekday, month, space-padded day, then hours:minutes.
const (
	timeFmt = "Mon Jan _2 15:04"
)

// formatLocation renders loc as one human-readable string for the session
// listing.
//
// The descriptive parts (city, country, and "ASN - organization") are
// joined with ", ". When the IP is known, it leads and the descriptive parts
// follow in parentheses, e.g. "192.0.2.1 (Rome, IT, AS1 - Org)". An IP with
// no descriptive data is returned on its own. A nil or empty location
// yields "".
func formatLocation(loc *location) string {
	if loc == nil {
		return ""
	}

	var parts []string
	if loc.City != "" {
		parts = append(parts, loc.City)
	}
	if loc.Country != "" {
		parts = append(parts, loc.Country)
	}
	if loc.ASN != "" {
		asn := loc.ASN
		// Avoid a dangling " - " when the organization name is unknown.
		if loc.ASNOrganization != "" {
			asn = fmt.Sprintf("%s - %s", loc.ASN, loc.ASNOrganization)
		}
		parts = append(parts, asn)
	}

	details := strings.Join(parts, ", ")
	switch {
	case loc.IP == "":
		return details
	case details == "":
		// Keep the address even when no geo/ASN data was resolved.
		return loc.IP
	default:
		return fmt.Sprintf("%s (%s)", loc.IP, details)
	}
}

// formatString renders the session as one line of the textual listing.
// User, local user, SSH key fingerprint, location and start time come
// first, in fixed-width columns. They are followed by how the session
// ended, relative to the current boot bTime:
//
//   - "still logged in" if the session is open in the current boot;
//   - "reboot" if it was closed because of a reboot, or was left open by
//     an earlier boot and has not been closed yet;
//   - otherwise its end time and duration, truncated to whole seconds.
func (s *session) formatString(bTime int64) string {
	prefix := fmt.Sprintf(
		"%-15s %-15s %-30s %-30s %-16s",
		s.User,
		s.LocalUser,
		s.SSHKeyFingerprint,
		formatLocation(&s.Location),
		s.StartedAt.Format(timeFmt),
	)

	switch {
	case s.isActive(bTime):
		return prefix + "   still logged in"
	case s.EndedReboot || s.EndedAt.IsZero():
		// Decide on EndedReboot rather than on BootTime != bTime, so that
		// sessions closed normally before a later reboot keep their real
		// end time and duration. For a reboot-closed session, EndedAt is
		// the time closeSession ran after the fact, not when the user
		// left, so a duration would be misleading. A zero EndedAt here
		// means an earlier boot left the session open.
		return prefix + " - reboot"
	default:
		duration := s.EndedAt.Sub(s.StartedAt).Truncate(time.Second)
		return fmt.Sprintf("%s - %-16s (%s)", prefix, s.EndedAt.Format(timeFmt), duration)
	}
}

// logEntry is the JSON-serializable record describing a session being
// opened or closed. It is built by (*session).logEntry.
type logEntry struct {
	// Record kind: "ssh_open_session" or "ssh_close_session".
	Type string `json:"log_type"`

	// Session ID and the boot it belongs to.
	ID       string `json:"session_id"`
	BootTime int64  `json:"btime"`

	// Who logged in, as which local account, with which SSH key.
	User              string `json:"user"`
	LocalUser         string `json:"local_user"`
	SSHKeyFingerprint string `json:"ssh_key_fp"`

	Location location `json:"location"`

	// Session length in whole seconds. It is only filled in for close
	// records and is omitted from the JSON when zero.
	Duration int `json:"duration,omitempty"`
}

// migrations holds the schema history of the sessions database. openDB
// passes it to sqlutil.WithMigrations.
//
// NOTE(review): sqlutil presumably applies these entries in order and
// records how many have already run. If so, existing entries must never
// be edited or reordered; append a new entry for each schema change.
// TODO confirm against the sqlutil documentation.
//
// The sessions table has no primary key (except for the implicit rowid)
// because we can't guarantee the uniqueness of session_id. All we know is
// that there can't be two "open" sessions with the same ID at the same
// time.
var migrations = []func(*sql.Tx) error{
	// Initial schema. The (session_id, btime) index serves the lookups
	// and updates of open sessions; the user index serves per-user
	// listings.
	sqlutil.Statement(`
CREATE TABLE sessions (
    session_id SMALLTEXT NOT NULL,
    btime INTEGER NOT NULL,
    user SMALLTEXT NOT NULL,
    local_user SMALLTEXT NOT NULL,
    ssh_key_fp TEXT NOT NULL,
    started_at DATETIME NOT NULL,
    ended_at DATETIME,
    ended_reboot BOOLEAN,
    location_ip SMALLTEXT,
    location_country SMALLTEXT,
    location_city SMALLTEXT
);
`, `
CREATE INDEX idx_sessions_id_btime ON sessions(session_id, btime)
`, `
CREATE INDEX idx_sessions_user ON sessions(user)
`),
	// Add autonomous-system information to the session location.
	sqlutil.Statement(`
ALTER TABLE sessions ADD COLUMN location_asn SMALLTEXT
`, `
ALTER TABLE sessions ADD COLUMN location_asn_org SMALLTEXT
`),
}

// openDB opens the sessions database at path through sqlutil. It hands
// sqlutil the migrations list, presumably so the schema is created or
// upgraded on open.
func openDB(path string) (*sql.DB, error) {
	withSchema := sqlutil.WithMigrations(migrations)
	return sqlutil.OpenDB(path, withSchema)
}

// createSession inserts a new row for s into the sessions table within tx.
// ended_at and ended_reboot are not written and so stay NULL. A NULL
// ended_at is what the open-session queries look for.
func createSession(tx *sql.Tx, s *session) error {
	loc := s.Location
	if _, err := tx.Exec(`
		INSERT INTO sessions
		  (session_id, btime, user, local_user, ssh_key_fp, started_at,
		   location_ip, location_country, location_city, location_asn, location_asn_org)
		VALUES
		  (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`,
		s.ID, s.BootTime, s.User, s.LocalUser, s.SSHKeyFingerprint, s.StartedAt,
		loc.IP, loc.Country, loc.City, loc.ASN, loc.ASNOrganization,
	); err != nil {
		return err
	}
	return nil
}

// closeSession stamps s as ended now and records isReboot. It then updates
// the matching row: same session_id and btime, still open (ended_at IS
// NULL).
//
// s is modified before the UPDATE runs, so its EndedAt/EndedReboot change
// even if the query fails. An UPDATE that matches no row is not reported
// as an error.
func closeSession(tx *sql.Tx, s *session, isReboot bool) error {
	const stmt = `
		UPDATE sessions SET
		  ended_at=?, ended_reboot=?
		WHERE
		  session_id = ? AND btime = ? AND ended_at IS NULL
`
	s.EndedAt, s.EndedReboot = time.Now(), isReboot

	_, err := tx.Exec(stmt, s.EndedAt, s.EndedReboot, s.ID, s.BootTime)
	return err
}

// scanSource is the part of the *sql.Row and *sql.Rows API needed to
// decode one result row. It lets scanSessionRow serve both single-row
// and multi-row queries.
type scanSource interface {
	Scan(dest ...any) error
}

// sessionFields lists the columns read by every session query. The order
// must match the Scan destinations in scanSessionRow.
var sessionFields = []string{
	"session_id", "btime", "user", "local_user", "ssh_key_fp",
	"started_at", "ended_at", "ended_reboot",
	"location_ip", "location_country", "location_city",
	"location_asn", "location_asn_org",
}

// scanSessionQuery builds a SELECT of sessionFields from the sessions
// table, followed by where. where is any trailing SQL (WHERE, ORDER BY,
// LIMIT, ...) and is appended verbatim, without escaping, so it must be a
// fixed clause; values belong in placeholders.
func scanSessionQuery(where string) string {
	return "SELECT " + strings.Join(sessionFields, ", ") + " FROM sessions " + where
}

// scanSessionRow decodes one row, laid out as sessionFields, from src into
// a new session. NULL values in the ended_* and location_* columns leave
// the corresponding session fields at their zero values. On a Scan error
// it returns nil and that error.
func scanSessionRow(src scanSource) (*session, error) {
	var (
		s           session
		endedAt     sql.NullTime
		endedReboot sql.NullBool
		ip, city    sql.NullString
		country     sql.NullString
		asn, asnOrg sql.NullString
	)
	err := src.Scan(
		&s.ID, &s.BootTime, &s.User, &s.LocalUser, &s.SSHKeyFingerprint, &s.StartedAt,
		&endedAt, &endedReboot,
		&ip, &country, &city, &asn, &asnOrg,
	)
	if err != nil {
		return nil, err
	}

	// The sql.Null* wrappers carry their zero payload when the column is
	// NULL, so the payloads can be copied without checking Valid.
	s.EndedAt = endedAt.Time
	s.EndedReboot = endedReboot.Bool
	s.Location = location{
		IP:              ip.String,
		City:            city.String,
		Country:         country.String,
		ASN:             asn.String,
		ASNOrganization: asnOrg.String,
	}
	return &s, nil
}

// scanSessionsRows iterates over rows, decoding each row with
// scanSessionRow. On the first decoding error it returns no sessions and
// that error. Otherwise it returns the sessions in result order, along
// with whatever rows.Err reports. Closing rows is left to the caller.
func scanSessionsRows(rows *sql.Rows) ([]*session, error) {
	var sessions []*session
	for rows.Next() {
		sess, err := scanSessionRow(rows)
		if err != nil {
			return nil, err
		}
		sessions = append(sessions, sess)
	}
	if err := rows.Err(); err != nil {
		return sessions, err
	}
	return sessions, nil
}

// getOpenSession loads the still-open session (ended_at IS NULL) with the
// given ID and boot time. If there is no such row, the error is the one
// returned by (*sql.Row).Scan, i.e. sql.ErrNoRows.
func getOpenSession(tx *sql.Tx, sessionID string, bTime int64) (*session, error) {
	query := scanSessionQuery(`
		WHERE session_id = ? AND btime = ? AND ended_at IS NULL
`)
	row := tx.QueryRow(query, sessionID, bTime)
	return scanSessionRow(row)
}

// getLatestSessions returns up to n sessions (SQL LIMIT n), most recently
// started first.
func getLatestSessions(tx *sql.Tx, n int) ([]*session, error) {
	const clause = `
		ORDER BY started_at DESC LIMIT ?
`
	rows, err := tx.Query(scanSessionQuery(clause), n)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	return scanSessionsRows(rows)
}

// getUserSessions returns every session whose user column equals user,
// most recently started first.
func getUserSessions(tx *sql.Tx, user string) ([]*session, error) {
	query := scanSessionQuery(`
		WHERE user = ?
		ORDER BY started_at DESC
`)
	rows, err := tx.Query(query, user)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return scanSessionsRows(rows)
}

// getPendingSessions returns the sessions still marked open (ended_at IS
// NULL) that belong to a boot other than bTime. These are sessions an
// earlier boot left open without their end ever being recorded.
func getPendingSessions(tx *sql.Tx, bTime int64) ([]*session, error) {
	const clause = `
		WHERE ended_at IS NULL AND btime != ?
`
	rows, err := tx.Query(scanSessionQuery(clause), bTime)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return scanSessionsRows(rows)
}