Skip to content
Snippets Groups Projects
db_test.go 2.47 KiB
Newer Older
  • Learn to ignore specific revisions
  • ale's avatar
    ale committed
    package main
    
    import (
    	"database/sql"
    	"errors"
    	"fmt"
    	"os"
    	"strings"
    	"testing"
    	"time"
    )
    
    // withTestDB runs f with the package-level *dbPath pointing at a file
    // named test.db inside a fresh temporary directory (the database file
    // itself is not created here). Once f returns — or the test aborts via
    // t.Fatal inside f — the previous *dbPath value is restored and the
    // temporary directory is removed, so no state leaks into other tests.
    func withTestDB(t *testing.T, f func()) {
    	t.Helper()

    	dir, err := os.MkdirTemp("", "db_test-")
    	if err != nil {
    		t.Fatal(err)
    	}
    	defer func() {
    		// Cleanup failure should not fail the test, but should not be
    		// silently swallowed either.
    		if err := os.RemoveAll(dir); err != nil {
    			t.Logf("removing temporary directory %s: %v", dir, err)
    		}
    	}()

    	// *dbPath is global state; put it back when we are done.
    	oldPath := *dbPath
    	defer func() { *dbPath = oldPath }()
    	*dbPath = dir + "/test.db"

    	f()
    }
    
    // testSession is the session fixture shared by the tests in this file.
    // Only the identity fields and StartedAt (captured once, when the package
    // is initialized) are set; no other field is populated here.
    var testSession = &session{
    	StartedAt:         time.Now(),
    	ID:                "id1",
    	SSHKeyFingerprint: "fp",
    	User:              "user",
    	LocalUser:         "root",
    }
    
    // TestSession_Format checks that formatting the fixture session (with 0
    // passed to formatString) reports it as still logged in.
    func TestSession_Format(t *testing.T) {
    	const want = "still logged in"
    	if got := testSession.formatString(0); !strings.Contains(got, want) {
    		t.Fatalf("bad string for active session: %s", got)
    	}
    }
    
    // TestDB_Session walks the fixture session through its lifecycle against
    // a scratch database: create it, read it back as the open session for its
    // ID and close it, verify it is no longer reported as open, and finally
    // check that the session-listing helpers each return exactly that session.
    func TestDB_Session(t *testing.T) {
    	withTestDB(t, func() {
    		// Create a new session.
    		err := withDB(func(tx *sql.Tx) error {
    			return createSession(tx, testSession)
    		})
    		if err != nil {
    			t.Fatal(err)
    		}

    		// Read the session back, check it is active, then close it.
    		err = withDB(func(tx *sql.Tx) error {
    			s, err := getOpenSession(tx, testSession.ID, 0)
    			if err != nil {
    				return err
    			}

    			// StartedAt is not compared: its round-trip precision depends
    			// on how the database stores timestamps.
    			if s.ID != testSession.ID ||
    				s.User != testSession.User ||
    				s.LocalUser != testSession.LocalUser ||
    				s.SSHKeyFingerprint != testSession.SSHKeyFingerprint {
    				return fmt.Errorf("bad getOpenSession result: %+v", s)
    			}

    			if !s.isActive(0) {
    				return errors.New("getOpenSession result is not active")
    			}

    			return closeSession(tx, s, false)
    		})
    		if err != nil {
    			t.Fatal(err)
    		}

    		// The closed session must no longer be reported as open.
    		err = withDB(func(tx *sql.Tx) error {
    			s, err := getOpenSession(tx, testSession.ID, 0)
    			switch {
    			case err == nil:
    				// Formatting a nil error with %w would print "%!w(<nil>)".
    				return fmt.Errorf("getOpenSession returned a closed session: %+v", s)
    			case !errors.Is(err, sql.ErrNoRows):
    				// errors.Is also matches an ErrNoRows wrapped by the helper.
    				return fmt.Errorf("unexpected getOpenSession error: %w", err)
    			}
    			if s != nil {
    				return errors.New("getOpenSession result is not nil")
    			}
    			return nil
    		})
    		if err != nil {
    			t.Fatal(err)
    		}

    		// Both listing helpers should still return the (closed) session.
    		for _, td := range []struct {
    			name string
    			f    func(*sql.Tx) ([]*session, error)
    		}{
    			{
    				"getLatestSessions", func(tx *sql.Tx) ([]*session, error) { return getLatestSessions(tx, 10) },
    			},
    			{
    				"getUserSessions", func(tx *sql.Tx) ([]*session, error) { return getUserSessions(tx, testSession.User) },
    			},
    		} {
    			// Subtests run synchronously, so capturing td is safe on any
    			// Go version.
    			t.Run(td.name, func(t *testing.T) {
    				err := withDB(func(tx *sql.Tx) error {
    					sessions, err := td.f(tx)
    					if err != nil {
    						return err
    					}
    					if len(sessions) != 1 {
    						return fmt.Errorf("unexpected number of sessions: %d", len(sessions))
    					}
    					if sessions[0].ID != testSession.ID {
    						return fmt.Errorf("%s returned unexpected result: %+v", td.name, sessions[0])
    					}
    					return nil
    				})
    				if err != nil {
    					t.Fatal(err)
    				}
    			})
    		}
    	})
    }