Skip to content
Snippets Groups Projects
Commit 671884f2 authored by ale's avatar ale
Browse files

Add a test for the Proxy

parent a9c941a3
No related branches found
No related tags found
No related merge requests found
......@@ -41,9 +41,8 @@ const (
testLoginServer = "login.example.com"
)
func makeAuthRequest(t testing.TB, baseUri, path, service, domain string, priv []byte) []byte {
c := newTestHTTPClient()
resp, err := c.Get(baseUri + path)
func makeAuthRequest(t testing.TB, c *http.Client, base, path, service, domain string, priv []byte) []byte {
resp, err := c.Get(base + path)
if err != nil {
t.Fatalf("Get(%s): %v", path, err)
}
......@@ -75,7 +74,7 @@ func makeAuthRequest(t testing.TB, baseUri, path, service, domain string, priv [
destURL := "https://" + testHost + "/test"
u.Set("d", destURL)
u.Set("t", signed)
resp, err = c.Get(baseUri + "/sso_login?" + u.Encode())
resp, err = c.Get(base + "/sso_login?" + u.Encode())
if err != nil {
t.Fatal("Get(/sso_login):", err)
}
......@@ -88,7 +87,7 @@ func makeAuthRequest(t testing.TB, baseUri, path, service, domain string, priv [
}
// Finally, requesting the original URL should work now.
resp, err = c.Get(baseUri + path)
resp, err = c.Get(base + path)
if err != nil {
t.Fatalf("Get(%s, post-auth): %v", path, err)
}
......@@ -123,7 +122,8 @@ func TestSSOWrapper(t *testing.T) {
defer srv.Close()
// Request a sample URL.
data := string(makeAuthRequest(t, srv.URL, "/test", testService, testDomain, priv))
c := newTestHTTPClient()
data := string(makeAuthRequest(t, c, srv.URL, "/test", testService, testDomain, priv))
if data != "OK" {
t.Fatalf("Get() returned bad data: %s", data)
}
......
package proxy
import (
"crypto/rand"
"crypto/tls"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"os"
"testing"
"time"
sso "git.autistici.org/id/go-sso"
"github.com/gorilla/securecookie"
"golang.org/x/crypto/ed25519"
)
func createTestServer(t testing.TB) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
io.WriteString(w, "OK")
}))
}
func TestProxy(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
ioutil.WriteFile(tmpdir+"/public.key", pub, 0644)
targetSrv := createTestServer(t)
targetURL, _ := url.Parse(targetSrv.URL)
defer targetSrv.Close()
config := &Config{
SessionAuthKey: string(securecookie.GenerateRandomKey(64)),
SessionEncKey: string(securecookie.GenerateRandomKey(32)),
SSOLoginServerURL: "https://login.example.com/",
SSOPublicKeyFile: tmpdir + "/public.key",
SSODomain: "example.com",
Backends: []*Backend{
&Backend{
Host: "test.example.com",
Upstream: []string{targetURL.Host},
},
},
}
proxy, err := NewProxy(config)
if err != nil {
t.Fatal(err)
}
proxySrv := httptest.NewTLSServer(proxy)
defer proxySrv.Close()
c := newTestHTTPClient(proxySrv.URL)
data := string(makeAuthRequest(t, c, "https://test.example.com", "/", "test.example.com/", "example.com", priv))
if data != "OK" {
t.Fatalf("bad response: %s", data)
}
}
// Create a http.Client locked to a specific address - no matter the
// URL, the underlying transport will make a connection to the server
// specified in uri.
func newTestHTTPClient(uri string) *http.Client {
u, _ := url.Parse(uri)
addr := u.Host
jar, _ := cookiejar.New(nil)
return &http.Client{
Jar: jar,
Transport: &http.Transport{
Dial: func(n, _ string) (net.Conn, error) {
return net.Dial(n, addr)
},
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
// This client will not follow redirects.
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
}
func makeAuthRequest(t testing.TB, c *http.Client, base, path, service, domain string, priv []byte) []byte {
resp, err := c.Get(base + path)
if err != nil {
t.Fatalf("Get(%s): %v", path, err)
}
if resp.StatusCode != http.StatusFound {
t.Fatalf("Get(%s) expected 302, got %d", path, resp.StatusCode)
}
loc, err := url.Parse(resp.Header.Get("Location"))
if err != nil {
t.Fatalf("Get(%s) redirects to unparsable URL %s: %v", path, resp.Header.Get("Location"), err)
}
// if loc.Host != testLoginServer {
// t.Fatalf("Get(%s) got bad redirect: %s", path, loc)
// }
resp.Body.Close()
// Sign a ticket, pretending we are the SSO server, then make
// a new request to the sso_login endpoint.
signer, err := sso.NewSigner(priv)
if err != nil {
t.Fatal(err)
}
nonce := loc.Query().Get("n")
tkt := sso.NewTicket("user", service, domain, nonce, nil, 300*time.Second)
signed, err := signer.Sign(tkt)
if err != nil {
t.Fatal("Sign():", err)
}
u := make(url.Values)
destURL := base + path
u.Set("d", destURL)
u.Set("t", signed)
resp, err = c.Get(base + "/sso_login?" + u.Encode())
if err != nil {
t.Fatal("Get(/sso_login):", err)
}
if resp.StatusCode != http.StatusFound {
t.Fatalf("Get(/sso_login) expected 302, got %d", resp.StatusCode)
}
resp.Body.Close()
if s := resp.Header.Get("Location"); s != destURL {
t.Fatalf("Get(/sso_login) redirects to unexpected location %s", s)
}
// Finally, requesting the original URL should work now.
resp, err = c.Get(base + path)
if err != nil {
t.Fatalf("Get(%s, post-auth): %v", path, err)
}
if resp.StatusCode != 200 {
t.Fatalf("Get(%s, post-auth) expected 200, got %d", path, resp.StatusCode)
}
data, _ := ioutil.ReadAll(resp.Body)
resp.Body.Close()
return data
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment