Skip to content
Snippets Groups Projects
main.go 3.99 KiB
package main

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"log"
	"log/syslog"
	"os"
	"time"

	"git.autistici.org/ai3/go-common/sqlutil"
	"golang.org/x/crypto/ssh"
)

// Command-line flags and the process-wide logger.
var (
	// bootFlag selects the boot-time cleanup mode.
	bootFlag = flag.Bool("boot", false, "cleanup wtmp entries (run once at boot)")

	// dbPath is the location of the sqlite wtmp database.
	dbPath = flag.String("db", "/var/lib/ssh-key-wtmp/wtmp.db", "`path` to the sqlite wtmp database")

	// slog is the logger used for diagnostics and emitted log entries.
	// It starts as the standard logger and may be replaced in main.
	slog = log.Default()
)

// handleOpen builds a session record for a newly opened session, emits
// an "open" log entry for it and stores it through createSession.
//
// A failure to augment the session with location data is logged and
// otherwise ignored. Note that the log entry is emitted before the
// session is stored, so it is produced even when createSession fails.
func handleOpen(tx *sql.Tx, bTime int64, sid string, sshKey ssh.PublicKey, user, localUser string) error {
	fingerprint := ssh.FingerprintSHA256(sshKey)
	sess := &session{
		ID:                sid,
		BootTime:          bTime,
		User:              user,
		LocalUser:         localUser,
		SSHKeyFingerprint: fingerprint,
		StartedAt:         time.Now(),
	}

	// Location data is best-effort: report the problem and carry on.
	locErr := augmentLocation(kLocationConfig, envGetRemoteIP(), &sess.Location)
	if locErr != nil {
		slog.Printf("location error: %v", locErr)
	}

	emit(sess.logEntry(entryTypeOpen))

	return createSession(tx, sess)
}

// handleClose looks up the open session matching sid and bTime, closes
// it through closeSession and then emits a "close" log entry.
//
// The log entry is emitted only after the session has been saved
// successfully. Both failure paths wrap the underlying error so that
// callers can inspect the cause with errors.Is / errors.As.
func handleClose(tx *sql.Tx, bTime int64, sid string) error {
	s, err := getOpenSession(tx, sid, bTime)
	if err != nil {
		// Keep the root cause attached instead of discarding it.
		return fmt.Errorf("could not retrieve existing open session %s: %w", sid, err)
	}

	if err := closeSession(tx, s, false); err != nil {
		return fmt.Errorf("could not save session: %w", err)
	}

	emit(s.logEntry(entryTypeClose))

	return nil
}

// withDB opens the database at *dbPath, runs f through sqlutil.WithTx
// with a background context, and closes the database before returning.
// It returns the error from opening the database or from WithTx.
func withDB(f func(tx *sql.Tx) error) error {
	conn, openErr := openDB(*dbPath)
	if openErr != nil {
		return openErr
	}
	defer conn.Close()

	ctx := context.Background()
	return sqlutil.WithTx(ctx, conn, f)
}

func cleanupAfterReboot(bTime int64) error {
	return withDB(func(tx *sql.Tx) error {
		pending, err := getPendingSessions(tx, bTime)
		if err != nil {
			return err
		}
		for _, s := range pending {
			if err := closeSession(tx, s, true); err != nil {
				return err
			}
		}
		return nil
	})
}

// run implements the pam-exec hook. It reads the session ID from the
// environment and then dispatches on the PAM_TYPE variable:
//
//   - unset/empty:   an error is returned;
//   - close_session: the session is closed via handleClose;
//   - open_session:  the SSH key used to log in is resolved against the
//     local user's authorized keys and the session is recorded via
//     handleOpen. Logins reported as errNonPublicKeyLogin are ignored;
//   - anything else: nothing is done.
func run(bTime int64) error {
	sid, err := envGetSessionID()
	if err != nil {
		return err
	}

	pamType := getenv("PAM_TYPE")
	if pamType == "" {
		return errors.New("PAM_TYPE unset")
	}
	if pamType == "close_session" {
		return withDB(func(tx *sql.Tx) error {
			return handleClose(tx, bTime, sid)
		})
	}
	if pamType != "open_session" {
		// Other PAM hook types are of no interest.
		return nil
	}

	sshKey, err := envGetSSHAuthInfo()
	switch {
	case errors.Is(err, errNonPublicKeyLogin):
		// Do nothing.
		return nil
	case err != nil:
		return err
	}

	localUser, err := envGetUser()
	if err != nil {
		return err
	}

	keyPaths, err := userAuthorizedKeysFile(localUser)
	if err != nil {
		return err
	}

	_, user, err := findKeyInAuthorizedKeys(keyPaths, sshKey)
	if err != nil {
		return err
	}

	return withDB(func(tx *sql.Tx) error {
		return handleOpen(tx, bTime, sid, sshKey, user, localUser)
	})
}

// query fetches sessions from the database. With an empty arg it asks
// getLatestSessions for up to 100 sessions; otherwise arg is passed to
// getUserSessions. The result and the error from withDB are returned
// together.
func query(arg string) ([]*session, error) {
	var result []*session

	fetch := func(tx *sql.Tx) error {
		var fetchErr error
		if arg != "" {
			result, fetchErr = getUserSessions(tx, arg)
		} else {
			result, fetchErr = getLatestSessions(tx, 100)
		}
		return fetchErr
	}

	err := withDB(fetch)
	return result, err
}

// printSessions writes one line per session to standard output, using
// each session's formatString rendering for bTime.
func printSessions(sessions []*session, bTime int64) {
	for i := range sessions {
		line := sessions[i].formatString(bTime)
		fmt.Println(line)
	}
}

// emit serializes obj as JSON and logs it through slog, prefixed with
// the "@cee:" cookie. If serialization fails, the error is logged and
// nothing else is emitted.
func emit(obj any) {
	payload, marshalErr := json.Marshal(obj)
	if marshalErr != nil {
		slog.Printf("error serializing JSON object: %v", marshalErr)
		return
	}

	slog.Printf("@cee:%s", payload)
}

// main selects one of three modes of operation:
//
//   - with -boot, it runs the post-reboot cleanup and exits;
//   - with PAM_TYPE set in the environment, it acts as the pam-exec
//     hook, logging to syslog when a syslog logger can be created;
//   - otherwise it acts as a command-line query tool and prints the
//     matching sessions.
func main() {
	slog.SetFlags(0)
	flag.Parse()

	bTime, err := getBootTime()
	if err != nil {
		slog.Fatal(err)
	}

	if *bootFlag {
		if cleanupErr := cleanupAfterReboot(bTime); cleanupErr != nil {
			slog.Fatal(cleanupErr)
		}
		return
	}

	// Autodetect mode of operation: pam-exec hook or cli tool.
	// A non-empty PAM_TYPE means we are running as the pam-exec hook.
	if os.Getenv("PAM_TYPE") != "" {
		// Switch to syslog when available; otherwise keep the current logger.
		sysLogger, logErr := syslog.NewLogger(syslog.LOG_AUTH|syslog.LOG_INFO, 0)
		if logErr == nil {
			slog = sysLogger
		}

		// Lack of a SSH_CONNECTION env var means we're being
		// invoked for a non-ssh service, ignore it.
		runErr := run(bTime)
		if runErr != nil && !errors.Is(runErr, errNoSSHConnection) {
			slog.Fatal(runErr)
		}
		os.Exit(0)
	}

	// Cli tool (query).
	sessions, err := query(flag.Arg(0))
	if err != nil {
		slog.Fatal(err)
	}

	printSessions(sessions, bTime)
}