package proxy

import (
	"context"
	"crypto/rand"
	"crypto/tls"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/http/cookiejar"
	"net/http/httptest"
	"net/url"
	"os"
	"path/filepath"
	"testing"
	"time"

	sso "git.autistici.org/id/go-sso"
	"github.com/gorilla/securecookie"
	"golang.org/x/crypto/ed25519"
)

// createTestServer starts a plain HTTP server that answers every request
// with the body "OK". The caller is responsible for closing it.
func createTestServer(t testing.TB) *httptest.Server {
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		// A write error here would only mean the client went away; the
		// client side of the test reports any resulting failure.
		io.WriteString(w, "OK")
	}))
}

// TestProxy runs a complete SSO login flow against a proxy that fronts a
// single backend ("test.example.com" -> createTestServer) and checks that,
// once authenticated, the backend's response is passed through unchanged.
func TestProxy(t *testing.T) {
	tmpdir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(tmpdir)

	// The proxy verifies SSO tickets with this public key; the test signs
	// tickets with the matching private key (see makeAuthRequest).
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	pubKeyPath := filepath.Join(tmpdir, "public.key")
	if err := ioutil.WriteFile(pubKeyPath, pub, 0644); err != nil {
		t.Fatal(err)
	}

	targetSrv := createTestServer(t)
	defer targetSrv.Close()
	targetURL, err := url.Parse(targetSrv.URL)
	if err != nil {
		t.Fatal(err)
	}

	config := &Config{
		SessionAuthKey:    string(securecookie.GenerateRandomKey(64)),
		SessionEncKey:     string(securecookie.GenerateRandomKey(32)),
		SSOLoginServerURL: "https://login.example.com/",
		SSOPublicKeyFile:  pubKeyPath,
		SSODomain:         "example.com",
		Backends: []*Backend{
			{
				Host:     "test.example.com",
				Upstream: []string{targetURL.Host},
			},
		},
	}

	proxy, err := NewProxy(config)
	if err != nil {
		t.Fatal(err)
	}
	proxySrv := httptest.NewTLSServer(proxy)
	defer proxySrv.Close()

	c := newTestHTTPClient(proxySrv.URL)
	data := string(makeAuthRequest(t, c, "https://test.example.com", "/", "test.example.com/", "example.com", priv))
	if data != "OK" {
		t.Fatalf("bad response: %s", data)
	}
}

// newTestHTTPClient returns an http.Client locked to a specific address:
// no matter the request URL, the underlying transport connects to the host
// in uri. This lets tests address virtual hostnames (e.g. test.example.com)
// while actually talking to an httptest server.
//
// The client keeps cookies in a jar across requests, skips TLS certificate
// verification (httptest TLS servers use a self-signed certificate), does
// not follow redirects, and bounds each request with a timeout so that a
// misbehaving server fails the test instead of hanging it.
func newTestHTTPClient(uri string) *http.Client {
	u, err := url.Parse(uri)
	if err != nil {
		panic("newTestHTTPClient: invalid URL " + uri + ": " + err.Error())
	}
	addr := u.Host

	// cookiejar.New with nil options never returns an error.
	jar, _ := cookiejar.New(nil)

	var dialer net.Dialer
	return &http.Client{
		Jar:     jar,
		Timeout: 10 * time.Second,
		Transport: &http.Transport{
			// Ignore the address derived from the request URL and always
			// connect to the server under test.
			DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
				return dialer.DialContext(ctx, network, addr)
			},
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: true,
			},
		},
		// This client will not follow redirects: callers inspect the
		// 302 responses themselves.
		CheckRedirect: func(*http.Request, []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}
}

// makeAuthRequest drives the SSO login flow through the proxy and returns
// the body of the final, authenticated response. It:
//
//  1. GETs base+path and expects a 302; the nonce is taken from the "n"
//     query parameter of the redirect Location.
//  2. Signs a ticket for user "user" on the given service and domain with
//     priv (pretending to be the SSO server), then GETs base+"/sso_login"
//     with the destination in "d" and the ticket in "t", expecting a 302
//     back to base+path.
//  3. GETs base+path again and expects a 200.
//
// c must not follow redirects and must keep cookies between requests (see
// newTestHTTPClient). Any unexpected response fails the test.
func makeAuthRequest(t testing.TB, c *http.Client, base, path, service, domain string, priv []byte) []byte {
	t.Helper()

	// Step 1: unauthenticated request, expect a redirect to the login server.
	resp, err := c.Get(base + path)
	if err != nil {
		t.Fatalf("Get(%s): %v", path, err)
	}
	// Only the status and headers are needed; close now so a failed
	// check below cannot leak the body.
	resp.Body.Close()
	if resp.StatusCode != http.StatusFound {
		t.Fatalf("Get(%s) expected 302, got %d", path, resp.StatusCode)
	}
	loc, err := url.Parse(resp.Header.Get("Location"))
	if err != nil {
		t.Fatalf("Get(%s) redirects to unparsable URL %s: %v", path, resp.Header.Get("Location"), err)
	}
	// TODO: also check that loc points at the configured login server;
	// that requires passing its host in from the caller.

	// Step 2: sign a ticket, pretending we are the SSO server, then make
	// a new request to the sso_login endpoint.
	signer, err := sso.NewSigner(priv)
	if err != nil {
		t.Fatal(err)
	}
	nonce := loc.Query().Get("n")
	tkt := sso.NewTicket("user", service, domain, nonce, nil, 300*time.Second)
	signed, err := signer.Sign(tkt)
	if err != nil {
		t.Fatal("Sign():", err)
	}

	destURL := base + path
	params := make(url.Values)
	params.Set("d", destURL)
	params.Set("t", signed)
	resp, err = c.Get(base + "/sso_login?" + params.Encode())
	if err != nil {
		t.Fatal("Get(/sso_login):", err)
	}
	resp.Body.Close()
	if resp.StatusCode != http.StatusFound {
		t.Fatalf("Get(/sso_login) expected 302, got %d", resp.StatusCode)
	}
	if s := resp.Header.Get("Location"); s != destURL {
		t.Fatalf("Get(/sso_login) redirects to unexpected location %s", s)
	}

	// Step 3: requesting the original URL should work now.
	resp, err = c.Get(base + path)
	if err != nil {
		t.Fatalf("Get(%s, post-auth): %v", path, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("Get(%s, post-auth) expected 200, got %d", path, resp.StatusCode)
	}
	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Get(%s, post-auth): reading body: %v", path, err)
	}
	return data
}