Skip to content
Snippets Groups Projects
env_test.go 2.06 KiB
Newer Older
ale's avatar
ale committed
package main

import (
	"errors"
	"os"
	"testing"
)

// withTestEnv runs f with the package-level getenv hook replaced by a lookup
// into env, so the code under test sees exactly the variables in env (keys
// that are absent read as ""). The hook that was installed before the call is
// restored afterwards rather than unconditionally resetting it to os.Getenv.
// This keeps nested calls correct: the inner call puts back the outer fake
// environment. Because the restore is deferred, it also runs when f exits via
// t.Fatal (runtime.Goexit) or a panic.
func withTestEnv(env map[string]string, f func()) {
	prev := getenv
	defer func() {
		getenv = prev
	}()

	getenv = func(key string) string {
		return env[key]
	}

	f()
}

// TestEnv_GetSSHAuthInfo checks that envGetSSHAuthInfo returns no error when
// SSH_AUTH_INFO_0 describes a "publickey" login with an ssh-ed25519 key.
func TestEnv_GetSSHAuthInfo(t *testing.T) {
	fakeEnv := map[string]string{
		"SSH_AUTH_INFO_0": "publickey ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKHGU3+iLYiepGfenTlieDiv4hY4r0PlZa+mnb7NoKkc",
	}
	check := func() {
		if _, gotErr := envGetSSHAuthInfo(); gotErr != nil {
			t.Fatalf("envGetSSHAuthInfo: %v", gotErr)
		}
	}
	withTestEnv(fakeEnv, check)
}

// TestEnv_GetSSHAuthInfo_NonPublicKey checks that when SSH_AUTH_INFO_0
// records a "password" login, envGetSSHAuthInfo returns an error matching
// errNonPublicKeyLogin (compared with errors.Is, so wrapping is allowed).
func TestEnv_GetSSHAuthInfo_NonPublicKey(t *testing.T) {
	fakeEnv := map[string]string{"SSH_AUTH_INFO_0": "password abcdef"}
	withTestEnv(fakeEnv, func() {
		_, gotErr := envGetSSHAuthInfo()
		if errors.Is(gotErr, errNonPublicKeyLogin) {
			return
		}
		t.Fatalf("envGetSSHAuthInfo: unexpected error %v", gotErr)
	})
}

// TestEnv_GetUser checks that envGetUser returns the value of PAM_USER
// without error.
func TestEnv_GetUser(t *testing.T) {
	fakeEnv := map[string]string{"PAM_USER": "user1"}
	withTestEnv(fakeEnv, func() {
		user, gotErr := envGetUser()
		if gotErr != nil {
			t.Fatalf("envGetUser: u=%v err=%v", user, gotErr)
		}
		if user != "user1" {
			t.Fatalf("envGetUser: u=%v err=%v", user, gotErr)
		}
	})
}

// TestEnv_GetSessionID checks envGetSessionID across combinations of
// SSH_CONNECTION and XDG_SESSION_ID. Each case pins the exact ID string and
// the expected error. A nil err field means no error is expected, because
// errors.Is(err, nil) reports whether err is nil. Every case runs as a named
// subtest, so failures are reported per case together with the expected values.
func TestEnv_GetSessionID(t *testing.T) {
	for _, td := range []struct {
		name string
		env  map[string]string
		id   string
		err  error
	}{
		{
			name: "ssh_connection_and_xdg_session_id",
			env: map[string]string{
				"SSH_CONNECTION": "127.0.0.1 10023",
				"XDG_SESSION_ID": "abcdef",
			},
			id: "7b272ca2f26f5e5f132bd84e94a774b2",
		},
		{
			name: "ssh_connection_only",
			env: map[string]string{
				"SSH_CONNECTION": "127.0.0.1 10023",
			},
			id: "ed83f217c70180f5ae9886a7cf2518d4",
		},
		{
			name: "missing_ssh_connection",
			env: map[string]string{
				"XDG_SESSION_ID": "abcdef",
			},
			err: errNoSSHConnection,
		},
	} {
		// The subtest body runs synchronously (no t.Parallel), so capturing
		// the loop variable td is safe on all Go versions.
		t.Run(td.name, func(t *testing.T) {
			withTestEnv(td.env, func() {
				id, err := envGetSessionID()
				if !errors.Is(err, td.err) {
					t.Errorf("env %v: envGetSessionID() error = %v, want %v", td.env, err, td.err)
				}
				if id != td.id {
					t.Errorf("env %v: envGetSessionID() id = %q, want %q", td.env, id, td.id)
				}
			})
		})
	}
}

func TestEnv_GetRemoteIP(t *testing.T) {
	for _, env := range []map[string]string{
		{
			"SSH_CONNECTION": "127.0.0.1 10023",
		},
		{
			"PAM_RHOST": "127.0.0.1",
		},
	} {
		withTestEnv(env, func() {
			ip := envGetRemoteIP()
			if ip != "127.0.0.1" {
				t.Fatalf("unexpected ip for %v - %s", env, ip)
			}
		})
	}
}