Skip to content
Snippets Groups Projects
db_test.go 2.48 KiB
Newer Older
ale's avatar
ale committed
package main

import (
	"database/sql"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"
)

// withTestDB runs f with the package-level *dbPath pointing at a "test.db"
// file inside a freshly created temporary directory (presumably the path the
// DB helpers such as withDB open — confirm against their definitions).
//
// When f returns, *dbPath is restored to its previous value and the temporary
// directory is removed, so later tests never see a path to a deleted file.
// A failure to create the directory aborts the test via t.Fatal.
func withTestDB(t *testing.T, f func()) {
	t.Helper()

	dir, err := os.MkdirTemp("", "")
	if err != nil {
		t.Fatal(err)
	}
	// Best-effort cleanup: a leftover temp dir should not fail the test.
	defer os.RemoveAll(dir)

	// Defers run LIFO, so the path is restored before the directory is removed.
	prevPath := *dbPath
	defer func() { *dbPath = prevPath }()

	*dbPath = filepath.Join(dir, "test.db")

	f()
}

// testSession is the session fixture shared by the tests in this file.
// StartedAt is captured once, when the package-level variable is initialised.
var testSession = &session{
	StartedAt:         time.Now(),
	SSHKeyFingerprint: "fp",
	LocalUser:         "root",
	User:              "user",
	ID:                "id1",
}

// TestSession_Format checks that formatString(0) on the shared fixture
// reports the session as "still logged in". The meaning of the 0 argument is
// defined by formatString, which is not part of this file.
func TestSession_Format(t *testing.T) {
	const wantFragment = "still logged in"

	if formatted := testSession.formatString(0); !strings.Contains(formatted, wantFragment) {
		t.Fatalf("bad string for active session: %s", formatted)
	}
}

func TestDB_Session(t *testing.T) {
	withTestDB(t, func() {
		err := withDB(func(tx *sql.Tx) error {
			return createSession(tx, testSession)
		})
		if err != nil {
			t.Fatal(err)
		}

		// Create a new session.
		err = withDB(func(tx *sql.Tx) error {
			s, err := getOpenSession(tx, "id1", 0)
			if err != nil {
				return err
			}

			if s.ID != testSession.ID ||
				s.User != testSession.User ||
				s.LocalUser != testSession.LocalUser ||
				s.SSHKeyFingerprint != testSession.SSHKeyFingerprint {
				return fmt.Errorf("bad getOpenSession result: %+v", s)
			}

			if !s.isActive(0) {
				return errors.New("getOpenSession result is not active")
			}

			return closeSession(tx, s, false)
		})
		if err != nil {
			t.Fatal(err)
		}

		// Close the active session.
		err = withDB(func(tx *sql.Tx) error {
			s, err := getOpenSession(tx, "id1", 0)
			if !errors.Is(err, sql.ErrNoRows) {
ale's avatar
ale committed
				return fmt.Errorf("unexpected getOpenSession error: %w", err)
			}
			if s != nil {
				return errors.New("getOpenSession result is not nil")
			}
			return nil
		})
		if err != nil {
			t.Fatal(err)
		}

		// Check getLatestSessions / getUserSessions.
		for _, td := range []struct {
			name string
			f    func(*sql.Tx) ([]*session, error)
		}{
			{
				"getLatestSessions", func(tx *sql.Tx) ([]*session, error) { return getLatestSessions(tx, 10) },
			},
			{
				"getUserSessions", func(tx *sql.Tx) ([]*session, error) { return getUserSessions(tx, "user") },
			},
		} {
			if err := withDB(func(tx *sql.Tx) error {
				sessions, err := td.f(tx)
				if err != nil {
					return err
				}
				if len(sessions) != 1 {
					return fmt.Errorf("unexpected number of sessions: %d", len(sessions))
				}
				if sessions[0].ID != testSession.ID {
					return fmt.Errorf("getLatestSessions returned unexpected result: %+v", sessions[0])
				}
				return nil
			}); err != nil {
				t.Fatalf("%s: %v", td.name, err)
			}
		}
	})
}