package main

import (
	"database/sql"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"
)

// withTestDB points the package-level *dbPath at a database file inside a
// fresh temporary directory and runs f. When f returns, the previous value of
// *dbPath is restored (so later tests do not inherit a path whose directory
// has been deleted) and the temporary directory is removed.
func withTestDB(t *testing.T, f func()) {
	t.Helper()

	dir, err := os.MkdirTemp("", "")
	if err != nil {
		t.Fatal(err)
	}
	// Best-effort cleanup: a failure to remove the temp directory is ignored
	// rather than failing the test.
	defer os.RemoveAll(dir)

	oldPath := *dbPath
	defer func() { *dbPath = oldPath }()

	*dbPath = filepath.Join(dir, "test.db")
	f()
}

// testSession is the fixture session shared by the tests in this file.
var testSession = &session{
	ID:                "id1",
	User:              "user",
	LocalUser:         "root",
	SSHKeyFingerprint: "fp",
	StartedAt:         time.Now(),
}

// TestSession_Format checks that formatString reports the fixture session,
// which has no end recorded in the fields set above, as "still logged in".
func TestSession_Format(t *testing.T) {
	s := testSession.formatString(0)
	if !strings.Contains(s, "still logged in") {
		t.Fatalf("bad string for active session: %s", s)
	}
}

// TestDB_Session exercises the session lifecycle against a temporary
// database: create a session, read it back as open and active, close it,
// verify it is no longer returned as open, and verify the listing queries
// still return it.
func TestDB_Session(t *testing.T) {
	withTestDB(t, func() {
		// Create a new session.
		err := withDB(func(tx *sql.Tx) error {
			return createSession(tx, testSession)
		})
		if err != nil {
			t.Fatal(err)
		}

		// The new session must be returned as open and active; then close it.
		err = withDB(func(tx *sql.Tx) error {
			s, err := getOpenSession(tx, testSession.ID, 0)
			if err != nil {
				return err
			}
			if s.ID != testSession.ID ||
				s.User != testSession.User ||
				s.LocalUser != testSession.LocalUser ||
				s.SSHKeyFingerprint != testSession.SSHKeyFingerprint {
				return fmt.Errorf("bad getOpenSession result: %+v", s)
			}
			if !s.isActive(0) {
				return errors.New("getOpenSession result is not active")
			}
			return closeSession(tx, s, false)
		})
		if err != nil {
			t.Fatal(err)
		}

		// Once closed, the session must no longer be found as open.
		err = withDB(func(tx *sql.Tx) error {
			s, err := getOpenSession(tx, testSession.ID, 0)
			if err == nil {
				// Reported explicitly: wrapping a nil error with %w would
				// print "%!w(<nil>)" instead of a useful message.
				return fmt.Errorf("getOpenSession returned a closed session: %+v", s)
			}
			if !errors.Is(err, sql.ErrNoRows) {
				return fmt.Errorf("unexpected getOpenSession error: %w", err)
			}
			if s != nil {
				return errors.New("getOpenSession result is not nil")
			}
			return nil
		})
		if err != nil {
			t.Fatal(err)
		}

		// Both listing queries must still return exactly the one session
		// created above. Each runs as its own subtest so a failure in one
		// does not hide the result of the other.
		for _, tc := range []struct {
			name string
			f    func(*sql.Tx) ([]*session, error)
		}{
			{
				"getLatestSessions",
				func(tx *sql.Tx) ([]*session, error) { return getLatestSessions(tx, 10) },
			},
			{
				"getUserSessions",
				func(tx *sql.Tx) ([]*session, error) { return getUserSessions(tx, testSession.User) },
			},
		} {
			t.Run(tc.name, func(t *testing.T) {
				err := withDB(func(tx *sql.Tx) error {
					sessions, err := tc.f(tx)
					if err != nil {
						return err
					}
					if len(sessions) != 1 {
						return fmt.Errorf("unexpected number of sessions: %d", len(sessions))
					}
					if sessions[0].ID != testSession.ID {
						return fmt.Errorf("%s returned unexpected result: %+v", tc.name, sessions[0])
					}
					return nil
				})
				if err != nil {
					t.Fatal(err)
				}
			})
		}
	})
}