main.go 3.99 KiB
package main
import (
"context"
"database/sql"
"encoding/json"
"errors"
"flag"
"fmt"
"log"
"log/syslog"
"os"
"time"
"git.autistici.org/ai3/go-common/sqlutil"
"golang.org/x/crypto/ssh"
)
var (
	// bootFlag selects the cleanup mode handled by cleanupAfterReboot,
	// intended to be run once at boot.
	bootFlag = flag.Bool("boot", false, "cleanup wtmp entries (run once at boot)")
	// dbPath is the location of the sqlite database opened by withDB.
	dbPath = flag.String("db", "/var/lib/ssh-key-wtmp/wtmp.db", "`path` to the sqlite wtmp database")
)

// slog is the logger used for diagnostics and for the records written by
// emit. It starts as the standard library's default logger; main may replace
// it with a syslog-backed logger when running as a PAM hook.
// NOTE(review): the name shadows the standard log/slog package if that is
// ever imported into this package.
var slog = log.Default()
// handleOpen records the start of a session authenticated with sshKey.
//
// It builds a session record for sid tied to the boot time bTime. user is
// the identity run obtains from findKeyInAuthorizedKeys for this key, and
// localUser is the account reported by envGetUser. The function:
//   - attaches best-effort location data for the remote IP;
//   - emits an "open" log entry;
//   - stores the record inside tx.
//
// Errors from createSession are wrapped with the session ID.
func handleOpen(tx *sql.Tx, bTime int64, sid string, sshKey ssh.PublicKey, user, localUser string) error {
	s := &session{
		ID:                sid,
		BootTime:          bTime,
		User:              user,
		LocalUser:         localUser,
		SSHKeyFingerprint: ssh.FingerprintSHA256(sshKey),
		StartedAt:         time.Now(),
	}

	// A failed location lookup is only logged; the session is still recorded.
	if err := augmentLocation(kLocationConfig, envGetRemoteIP(), &s.Location); err != nil {
		slog.Printf("location error: %v", err)
	}

	// NOTE(review): the open entry is emitted before the row is written, so it
	// is logged even if createSession (or the later commit) fails. Kept as-is;
	// confirm this is the intended audit behavior.
	emit(s.logEntry(entryTypeOpen))

	if err := createSession(tx, s); err != nil {
		return fmt.Errorf("could not create session %s: %w", sid, err)
	}
	return nil
}
// handleClose looks up the open session sid for boot time bTime inside tx and
// passes it to closeSession. It emits a "close" log entry only after
// closeSession succeeds.
//
// Errors from both the lookup and the update are wrapped with %w, so callers
// can inspect the underlying cause with errors.Is / errors.As.
func handleClose(tx *sql.Tx, bTime int64, sid string) error {
	s, err := getOpenSession(tx, sid, bTime)
	if err != nil {
		// NOTE(review): run records nothing for non-public-key logins, so a
		// close_session for such a login presumably fails here. Confirm what
		// getOpenSession returns in that case (sql.ErrNoRows?) and whether it
		// should be treated as a no-op.
		return fmt.Errorf("could not retrieve existing open session %s: %w", sid, err)
	}
	if err := closeSession(tx, s, false); err != nil {
		return fmt.Errorf("could not save session: %w", err)
	}
	emit(s.logEntry(entryTypeClose))
	return nil
}
// withDB opens the database at *dbPath and hands f to sqlutil.WithTx, which
// runs it against a transaction on that database. The database handle is
// closed before withDB returns.
//
// It returns the error from openDB if opening fails; otherwise it returns
// whatever sqlutil.WithTx reports.
func withDB(f func(tx *sql.Tx) error) error {
	ctx := context.Background()

	conn, openErr := openDB(*dbPath)
	if openErr != nil {
		return openErr
	}
	defer conn.Close()

	return sqlutil.WithTx(ctx, conn, f)
}
func cleanupAfterReboot(bTime int64) error {
return withDB(func(tx *sql.Tx) error {
pending, err := getPendingSessions(tx, bTime)
if err != nil {
return err
}
for _, s := range pending {
if err := closeSession(tx, s, true); err != nil {
return err
}
}
return nil
})
}
// run handles a single pam_exec invocation, dispatching on PAM_TYPE:
//
//   - "open_session": only public-key logins are tracked. For those, the key
//     is looked up in the local user's authorized keys files and a session
//     is recorded via handleOpen. Other login methods return nil.
//   - "close_session": the session for this session ID and boot time is
//     closed via handleClose.
//   - "" (unset): reported as an error.
//
// Any other PAM_TYPE value is ignored.
func run(bTime int64) error {
	sid, err := envGetSessionID()
	if err != nil {
		return err
	}

	pamType := getenv("PAM_TYPE")
	switch pamType {
	case "open_session":
		sshKey, keyErr := envGetSSHAuthInfo()
		switch {
		case errors.Is(keyErr, errNonPublicKeyLogin):
			// Not a public-key login: nothing to record.
			return nil
		case keyErr != nil:
			return keyErr
		}

		localUser, userErr := envGetUser()
		if userErr != nil {
			return userErr
		}

		keyPaths, pathErr := userAuthorizedKeysFile(localUser)
		if pathErr != nil {
			return pathErr
		}

		_, user, findErr := findKeyInAuthorizedKeys(keyPaths, sshKey)
		if findErr != nil {
			return findErr
		}

		return withDB(func(tx *sql.Tx) error {
			return handleOpen(tx, bTime, sid, sshKey, user, localUser)
		})

	case "close_session":
		return withDB(func(tx *sql.Tx) error {
			return handleClose(tx, bTime, sid)
		})

	case "":
		return errors.New("PAM_TYPE unset")
	}

	return nil
}
// query fetches the sessions shown in CLI mode, inside a withDB transaction.
// With an empty arg it returns getLatestSessions with a limit of 100;
// otherwise it returns getUserSessions for arg.
func query(arg string) ([]*session, error) {
	var result []*session
	err := withDB(func(tx *sql.Tx) error {
		var qerr error
		switch arg {
		case "":
			result, qerr = getLatestSessions(tx, 100)
		default:
			result, qerr = getUserSessions(tx, arg)
		}
		return qerr
	})
	return result, err
}
// printSessions prints each session on its own line to standard output.
// Each line is produced by the session's formatString method, called with
// bTime.
func printSessions(sessions []*session, bTime int64) {
	for i := range sessions {
		line := sessions[i].formatString(bTime)
		fmt.Println(line)
	}
}
// emit marshals obj to JSON and writes it through slog, prefixed with
// "@cee:". If marshalling fails, only the error is logged.
func emit(obj any) {
	payload, err := json.Marshal(obj)
	if err == nil {
		slog.Printf("@cee:%s", payload)
		return
	}
	slog.Printf("error serializing JSON object: %v", err)
}
// main runs in one of three modes:
//
//   - with -boot, it runs cleanupAfterReboot and exits;
//   - when PAM_TYPE is set in the environment, it acts as a pam-exec hook
//     and handles the session event via run;
//   - otherwise it is a CLI tool that prints sessions, optionally filtered by
//     the first positional argument.
//
// Fatal errors are reported through slog.
func main() {
	slog.SetFlags(0)
	flag.Parse()

	// Autodetect mode of operation: pam-exec hook or cli tool.
	// If PAM_TYPE is defined, we're in the pam-exec hook. Switch to syslog
	// before anything can fail, so that every fatal error in hook mode
	// (including a getBootTime failure) goes to syslog rather than stderr.
	// If syslog is unavailable, the default logger is kept.
	pamHook := os.Getenv("PAM_TYPE") != ""
	if pamHook {
		if l, err := syslog.NewLogger(syslog.LOG_AUTH|syslog.LOG_INFO, 0); err == nil {
			slog = l
		}
	}

	bTime, err := getBootTime()
	if err != nil {
		slog.Fatal(err)
	}

	if *bootFlag {
		if err := cleanupAfterReboot(bTime); err != nil {
			slog.Fatal(err)
		}
		return
	}

	if pamHook {
		// Lack of a SSH_CONNECTION env var means we're being
		// invoked for a non-ssh service, ignore it.
		if err := run(bTime); err != nil && !errors.Is(err, errNoSSHConnection) {
			slog.Fatal(err)
		}
		os.Exit(0)
	}

	// Cli tool (query).
	sessions, err := query(flag.Arg(0))
	if err != nil {
		slog.Fatal(err)
	}
	printSessions(sessions, bTime)
}