package main import ( "context" "database/sql" "encoding/json" "errors" "flag" "fmt" "log" "log/syslog" "os" "time" "git.autistici.org/ai3/go-common/sqlutil" "golang.org/x/crypto/ssh" ) var ( bootFlag = flag.Bool("boot", false, "cleanup wtmp entries (run once at boot)") dbPath = flag.String("db", "/var/lib/ssh-wtmp/wtmp.db", "`path` to the sqlite wtmp database") ) var slog *log.Logger = log.Default() func handleOpen(tx *sql.Tx, bTime int64, sid string, sshKey ssh.PublicKey, user, localUser string) error { s := &session{ ID: sid, BootTime: bTime, User: user, LocalUser: localUser, SSHKeyFingerprint: ssh.FingerprintSHA256(sshKey), StartedAt: time.Now(), } s.Location.IP = envGetRemoteIP() if err := augmentLocation(&s.Location); err != nil { slog.Printf("location error: %v", err) } emit(s.logEntry(entryTypeOpen)) if err := createSession(tx, s); err != nil { return err } return nil } func handleClose(tx *sql.Tx, bTime int64, sid string) error { s, err := getOpenSession(tx, sid, bTime) if err != nil { return fmt.Errorf("could not retrieve existing open session %s", sid) } if err := closeSession(tx, s, false); err != nil { return fmt.Errorf("could not save session: %w", err) } emit(s.logEntry(entryTypeClose)) return nil } func withDB(f func(tx *sql.Tx) error) error { db, err := openDB(*dbPath) if err != nil { return err } defer db.Close() return sqlutil.WithTx(context.Background(), db, f) } func cleanupAfterReboot(bTime int64) error { return withDB(func(tx *sql.Tx) error { pending, err := getPendingSessions(tx, bTime) if err != nil { return err } for _, s := range pending { if err := closeSession(tx, s, true); err != nil { return err } } return nil }) } func run(bTime int64) error { sid, err := envGetSessionID() if err != nil { return err } switch getenv("PAM_TYPE") { case "open_session": sshKey, err := envGetSSHAuthInfo() if errors.Is(err, errNonPublicKeyLogin) { // Do nothing. return nil } else if err != nil { return err } localUser, err := envGetUser() if err != nil { return err } keyPaths, err := userAuthorizedKeysFile(localUser) if err != nil { return err } _, user, err := findKeyInAuthorizedKeys(keyPaths, sshKey) if err != nil { return err } return withDB(func(tx *sql.Tx) error { return handleOpen(tx, bTime, sid, sshKey, user, localUser) }) case "close_session": return withDB(func(tx *sql.Tx) error { return handleClose(tx, bTime, sid) }) case "": return errors.New("PAM_TYPE unset") default: return nil } } func query(arg string) ([]*session, error) { var sessions []*session err := withDB(func(tx *sql.Tx) (err error) { if arg == "" { sessions, err = getLatestSessions(tx, 100) } else { sessions, err = getUserSessions(tx, arg) } return }) return sessions, err } func printSessions(sessions []*session, bTime int64) { for _, s := range sessions { fmt.Println(s.formatString(bTime)) } } func emit(obj any) { enc, err := json.Marshal(obj) if err != nil { slog.Printf("error serializing JSON object: %v", err) return } slog.Printf("@cee:%s", enc) } func main() { slog.SetFlags(0) flag.Parse() bTime, err := getBootTime() if err != nil { slog.Fatal(err) } if *bootFlag { if err := cleanupAfterReboot(bTime); err != nil { slog.Fatal(err) } return } // Autodetect mode of operation: pam-exec hook or cli tool. // If PAM_TYPE is defined, we're in the pam-exec hook. if s := os.Getenv("PAM_TYPE"); s != "" { if l, err := syslog.NewLogger(syslog.LOG_AUTH|syslog.LOG_INFO, 0); err == nil { slog = l } err := run(bTime) if err != nil { slog.Fatal(err) } os.Exit(0) } // Cli tool (query). sessions, err := query(flag.Arg(0)) if err != nil { slog.Fatal(err) } printSessions(sessions, bTime) }