Skip to content
Snippets Groups Projects
db.go 6.54 KiB
Newer Older
  • Learn to ignore specific revisions
  • ale's avatar
    ale committed
    package main
    
    import (
    	"database/sql"
    	"fmt"
    	"strings"
    	"time"
    
    	"git.autistici.org/ai3/go-common/sqlutil"
    )
    
    // location records where a session came from: an IP address plus,
    // when known, city and country names. Empty City/Country values are
    // omitted from the JSON form. How City/Country get filled in is not
    // visible here (presumably a GeoIP lookup elsewhere — confirm).
    type location struct {
    	IP      string `json:"ip"`
    	City    string `json:"city,omitempty"`
    	Country string `json:"country,omitempty"`
    }
    
    // session is one SSH login as stored in the sessions table.
    type session struct {
    	// ID and BootTime together identify an open session
    	// (closeSession and getOpenSession match on both).
    	ID       string
    	BootTime int64

    	// Both User and LocalUser are recorded for every login; presumably
    	// User is the identity tied to the SSH key and LocalUser the local
    	// account logged into — confirm against the code that fills them.
    	User              string
    	LocalUser         string
    	SSHKeyFingerprint string

    	// EndedAt stays zero until closeSession records an end time;
    	// EndedReboot holds closeSession's isReboot flag.
    	StartedAt   time.Time
    	EndedAt     time.Time
    	EndedReboot bool

    	Location location
    }
    
    // isActive reports whether the session is open in the boot identified
    // by bTime. The session must belong to that boot AND have no end time
    // recorded, so a session left open by a different boot is not active.
    func (s *session) isActive(bTime int64) bool {
    	if s.BootTime != bTime {
    		return false
    	}
    	return s.EndedAt.IsZero()
    }
    
    // Kinds of record that (*session).logEntry can produce.
    const (
    	entryTypeOpen  = 0 // session started
    	entryTypeClose = 1 // session ended
    )
    
    // logEntry converts the session into a logEntry record of the given
    // kind. entryTypeOpen yields an "ssh_open_session" record;
    // entryTypeClose yields an "ssh_close_session" record whose Duration is
    // EndedAt-StartedAt in whole seconds (fraction truncated). Any other
    // entryType leaves Type empty and Duration zero.
    func (s *session) logEntry(entryType int) *logEntry {
    	var (
    		kind    string
    		seconds int
    	)
    	switch {
    	case entryType == entryTypeOpen:
    		kind = "ssh_open_session"
    	case entryType == entryTypeClose:
    		kind = "ssh_close_session"
    		seconds = int(s.EndedAt.Sub(s.StartedAt).Seconds())
    	}

    	return &logEntry{
    		Type:              kind,
    		ID:                s.ID,
    		BootTime:          s.BootTime,
    		User:              s.User,
    		LocalUser:         s.LocalUser,
    		SSHKeyFingerprint: s.SSHKeyFingerprint,
    		Location:          s.Location,
    		Duration:          seconds,
    	}
    }
    
    // timeFmt is the time.Format layout for session start/end times:
    // abbreviated weekday and month, space-padded day, 24-hour HH:MM.
    const (
    	timeFmt = "Mon Jan _2 15:04"
    )
    
    // formatString renders the session as one fixed-width, human-readable
    // line: user, local user, SSH key fingerprint, location (IP plus any
    // known city/country), and start time, followed by the session state:
    //
    //   - "still logged in" when isActive(bTime) reports the session open;
    //   - "- reboot" when the session was closed because of a reboot, or
    //     was never closed but belongs to a boot other than bTime;
    //   - "- <end time> (<duration>)" otherwise, with the duration rounded
    //     to the nearest second.
    //
    // bTime is the boot time compared against the session's BootTime
    // (presumably the running system's boot time — confirm with callers).
    func (s *session) formatString(bTime int64) string {
    	locStr := s.Location.IP
    	var locAdd []string
    	if s.Location.City != "" {
    		locAdd = append(locAdd, s.Location.City)
    	}
    	if s.Location.Country != "" {
    		locAdd = append(locAdd, s.Location.Country)
    	}
    	if len(locAdd) > 0 {
    		locStr += fmt.Sprintf(" (%s)", strings.Join(locAdd, ", "))
    	}

    	// Columns shared by every variant of the line.
    	prefix := fmt.Sprintf(
    		"%-15s %-15s %-30s %-30s %-16s",
    		s.User,
    		s.LocalUser,
    		s.SSHKeyFingerprint,
    		locStr,
    		s.StartedAt.Format(timeFmt),
    	)

    	switch {
    	case s.isActive(bTime):
    		return prefix + "   still logged in"
    	case s.EndedReboot || s.EndedAt.IsZero():
    		// A session from an earlier boot that was closed normally still
    		// has a real end time and is shown with it below. Only sessions
    		// terminated by a reboot (EndedReboot), or never closed at all
    		// (zero EndedAt while not active implies BootTime != bTime), are
    		// reported as "reboot".
    		return prefix + " - reboot"
    	default:
    		duration := s.EndedAt.Sub(s.StartedAt).Round(time.Second)
    		return prefix + fmt.Sprintf(" - %-16s (%s)", s.EndedAt.Format(timeFmt), duration)
    	}
    }
    
    // logEntry is the JSON-serializable record built by (*session).logEntry
    // for session open/close events. Type is "ssh_open_session" or
    // "ssh_close_session". Duration is the session length in whole seconds;
    // it is set only for close records and is omitted from the JSON when
    // zero. Where the record is written is not visible here (presumably the
    // caller marshals it to a log stream — confirm).
    type logEntry struct {
    	Type              string   `json:"log_type"`
    	ID                string   `json:"session_id"`
    	BootTime          int64    `json:"btime"`
    	User              string   `json:"user"`
    	LocalUser         string   `json:"local_user"`
    	SSHKeyFingerprint string   `json:"ssh_key_fp"`
    	Location          location `json:"location"`
    	Duration          int      `json:"duration,omitempty"`
    }
    
    // migrations is the schema definition handed to sqlutil.WithMigrations
    // by openDB. NOTE(review): presumably sqlutil applies each entry once,
    // in order, and records which ones have run. If so, never edit an
    // existing entry; append a new one instead.
    //
    // The table does not have a primary key (except for the implicit
    // rowid), because we can't guarantee the uniqueness of session_id
    // except for knowing that there can't be two "open" sessions with the
    // same ID at the same time.
    //
    // idx_sessions_id_btime serves the (session_id, btime) lookups done by
    // closeSession and getOpenSession. idx_sessions_user serves
    // getUserSessions. SMALLTEXT is not a standard SQL type; assuming the
    // backend is SQLite (the rowid remark suggests so), it is accepted and
    // gets TEXT affinity.
    var migrations = []func(*sql.Tx) error{
    	sqlutil.Statement(`
    CREATE TABLE sessions (
        session_id SMALLTEXT NOT NULL,
        btime INTEGER NOT NULL,
        user SMALLTEXT NOT NULL,
        local_user SMALLTEXT NOT NULL,
        ssh_key_fp TEXT NOT NULL,
        started_at DATETIME NOT NULL,
        ended_at DATETIME,
        ended_reboot BOOLEAN,
        location_ip SMALLTEXT,
        location_country SMALLTEXT,
        location_city SMALLTEXT
    );
    `, `
    CREATE INDEX idx_sessions_id_btime ON sessions(session_id, btime)
    `, `
    CREATE INDEX idx_sessions_user ON sessions(user)
    `),
    }
    
    // openDB opens the session database at path, passing this file's
    // schema migrations to sqlutil.
    func openDB(path string) (*sql.DB, error) {
    	withSchema := sqlutil.WithMigrations(migrations)
    	return sqlutil.OpenDB(path, withSchema)
    }
    
    // createSession inserts a new row for s. ended_at and ended_reboot are
    // not written here; they stay NULL until closeSession fills them in.
    func createSession(tx *sql.Tx, s *session) error {
    	const insertSQL = `
    		INSERT INTO sessions
    		  (session_id, btime, user, local_user, ssh_key_fp, started_at, location_ip, location_country, location_city)
    		VALUES
    		  (?, ?, ?, ?, ?, ?, ?, ?, ?)
    `
    	args := []any{
    		s.ID,
    		s.BootTime,
    		s.User,
    		s.LocalUser,
    		s.SSHKeyFingerprint,
    		s.StartedAt,
    		s.Location.IP,
    		s.Location.Country,
    		s.Location.City,
    	}
    	_, err := tx.Exec(insertSQL, args...)
    	return err
    }
    
    // closeSession marks s as ended now with the given reboot flag, then
    // writes both values to the matching open row: same session_id and
    // btime, with ended_at still NULL.
    //
    // s is updated in memory before the UPDATE runs, so it carries the new
    // values even if Exec fails. Matching zero rows is not reported as an
    // error.
    func closeSession(tx *sql.Tx, s *session, isReboot bool) error {
    	s.EndedAt, s.EndedReboot = time.Now(), isReboot

    	const updateSQL = `
    		UPDATE sessions SET
    		  ended_at=?, ended_reboot=?
    		WHERE
    		  session_id = ? AND btime = ? AND ended_at IS NULL
    `
    	_, err := tx.Exec(updateSQL, s.EndedAt, s.EndedReboot, s.ID, s.BootTime)
    	return err
    }
    
    // scanSource is the one method scanSessionRow needs from a result row.
    // Both *sql.Row and *sql.Rows provide a matching Scan method.
    type scanSource interface {
    	Scan(dest ...any) error
    }
    
    // sessionFields lists the columns selected by scanSessionQuery. The
    // order must match the destinations passed to src.Scan() in
    // scanSessionRow.
    var sessionFields = []string{
    	"session_id",
    	"btime",
    	"user",
    	"local_user",
    	"ssh_key_fp",
    	"started_at",
    	"ended_at",
    	"ended_reboot",
    	"location_ip",
    	"location_country",
    	"location_city",
    }
    
    func scanSessionQuery(where string) string {
    	return fmt.Sprintf("SELECT %s FROM sessions %s", strings.Join(sessionFields, ", "), where)
    }
    
    // scanSessionRow reads one row shaped by scanSessionQuery from src into
    // a new session. The nullable columns (ended_at, ended_reboot and the
    // location_* fields) are scanned through sql.Null* wrappers and copied
    // only when non-NULL, so NULL leaves the Go zero value in place.
    func scanSessionRow(src scanSource) (*session, error) {
    	s := new(session)
    	var (
    		ended             sql.NullTime
    		reboot            sql.NullBool
    		ip, country, city sql.NullString
    	)

    	// Destination order must follow sessionFields.
    	err := src.Scan(
    		&s.ID, &s.BootTime, &s.User, &s.LocalUser, &s.SSHKeyFingerprint, &s.StartedAt,
    		&ended, &reboot, &ip, &country, &city,
    	)
    	if err != nil {
    		return nil, err
    	}

    	if ended.Valid {
    		s.EndedAt = ended.Time
    	}
    	if reboot.Valid {
    		s.EndedReboot = reboot.Bool
    	}
    	if ip.Valid {
    		s.Location.IP = ip.String
    	}
    	if country.Valid {
    		s.Location.Country = country.String
    	}
    	if city.Valid {
    		s.Location.City = city.String
    	}
    	return s, nil
    }
    
    // scanSessionsRows drains rows, converting each one with scanSessionRow.
    // A per-row scan error aborts immediately with a nil slice. Otherwise
    // the collected sessions are returned together with rows.Err(). rows is
    // not closed here; the callers in this file defer rows.Close().
    func scanSessionsRows(rows *sql.Rows) ([]*session, error) {
    	var sessions []*session
    	for rows.Next() {
    		sess, err := scanSessionRow(rows)
    		if err != nil {
    			return nil, err
    		}
    		sessions = append(sessions, sess)
    	}
    	iterErr := rows.Err()
    	return sessions, iterErr
    }
    
    // getOpenSession loads the still-open session (ended_at IS NULL) with
    // the given session ID and boot time. When no row matches, the error
    // from Row.Scan is returned (sql.ErrNoRows).
    func getOpenSession(tx *sql.Tx, sessionID string, bTime int64) (*session, error) {
    	query := scanSessionQuery(`
    		WHERE session_id = ? AND btime = ? AND ended_at IS NULL
    `)
    	row := tx.QueryRow(query, sessionID, bTime)
    	return scanSessionRow(row)
    }
    
    // getLatestSessions returns at most n sessions, most recently started
    // first.
    func getLatestSessions(tx *sql.Tx, n int) ([]*session, error) {
    	query := scanSessionQuery(`
    		ORDER BY started_at DESC LIMIT ?
    `)
    	rows, err := tx.Query(query, n)
    	if err != nil {
    		return nil, err
    	}
    	defer rows.Close()
    	return scanSessionsRows(rows)
    }
    
    // getUserSessions returns every session whose user column equals user,
    // most recently started first.
    func getUserSessions(tx *sql.Tx, user string) ([]*session, error) {
    	query := scanSessionQuery(`
    		WHERE user = ?
    		ORDER BY started_at DESC
    `)
    	rows, err := tx.Query(query, user)
    	if err != nil {
    		return nil, err
    	}
    	defer rows.Close()
    	return scanSessionsRows(rows)
    }
    
    // getPendingSessions returns sessions that are still open in the
    // database (ended_at IS NULL) but belong to a boot other than bTime,
    // i.e. sessions left open by another boot.
    func getPendingSessions(tx *sql.Tx, bTime int64) ([]*session, error) {
    	query := scanSessionQuery(`
    		WHERE ended_at IS NULL AND btime != ?
    `)
    	rows, err := tx.Query(query, bTime)
    	if err != nil {
    		return nil, err
    	}
    	defer rows.Close()
    	return scanSessionsRows(rows)
    }