Skip to content
Snippets Groups Projects
Commit bd34d614 authored by ale's avatar ale
Browse files

Add tests for SSOWrapper

parent 3640a003
No related branches found
No related tags found
No related merge requests found
package httpsso
import (
"crypto/tls"
"io"
"io/ioutil"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/gorilla/securecookie"
"golang.org/x/crypto/ed25519"
sso "git.autistici.org/id/go-sso"
)
func newTestHTTPClient() *http.Client {
jar, _ := cookiejar.New(nil)
return &http.Client{
Jar: jar,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
// This client will not follow redirects.
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
}
const (
testHost = "service.example.com"
testService = "service.example.com/"
testDomain = "example.com"
testLoginServer = "login.example.com"
)
func makeAuthRequest(t testing.TB, baseUri, path, service, domain string, priv []byte) []byte {
c := newTestHTTPClient()
resp, err := c.Get(baseUri + path)
if err != nil {
t.Fatalf("Get(%s): %v", path, err)
}
if resp.StatusCode != http.StatusFound {
t.Fatalf("Get(%s) expected 302, got %d", path, resp.StatusCode)
}
loc, err := url.Parse(resp.Header.Get("Location"))
if err != nil {
t.Fatalf("Get(%s) redirects to unparsable URL %s: %v", path, resp.Header.Get("Location"), err)
}
if loc.Host != testLoginServer {
t.Fatalf("Get(%s) got bad redirect: %s", path, loc)
}
resp.Body.Close()
// Sign a ticket, pretending we are the SSO server, then make
// a new request to the sso_login endpoint.
signer, err := sso.NewSigner(priv)
if err != nil {
t.Fatal(err)
}
nonce := loc.Query().Get("n")
tkt := sso.NewTicket("user", service, domain, nonce, nil, 300*time.Second)
signed, err := signer.Sign(tkt)
if err != nil {
t.Fatal("Sign():", err)
}
u := make(url.Values)
destURL := "https://" + testHost + "/test"
u.Set("d", destURL)
u.Set("t", signed)
resp, err = c.Get(baseUri + "/sso_login?" + u.Encode())
if err != nil {
t.Fatal("Get(/sso_login):", err)
}
if resp.StatusCode != http.StatusFound {
t.Fatalf("Get(/sso_login) expected 302, got %d", resp.StatusCode)
}
resp.Body.Close()
if s := resp.Header.Get("Location"); s != destURL {
t.Fatalf("Get(/sso_login) redirects to unexpected location %s", s)
}
// Finally, requesting the original URL should work now.
resp, err = c.Get(baseUri + path)
if err != nil {
t.Fatalf("Get(%s, post-auth): %v", path, err)
}
if resp.StatusCode != 200 {
t.Fatalf("Get(%s, post-auth) expected 200, got %d", path, resp.StatusCode)
}
data, _ := ioutil.ReadAll(resp.Body)
resp.Body.Close()
return data
}
func TestSSOWrapper(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(nil)
if err != nil {
t.Fatal(err)
}
// Build a test app - note that we want to use a gorilla Mux
// here, otherwise cookie-based sessions won't work.
m := mux.NewRouter()
m.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) {
io.WriteString(w, "OK")
})
w, err := NewSSOWrapper("https://"+testLoginServer+"/", pub, testDomain, securecookie.GenerateRandomKey(64), securecookie.GenerateRandomKey(32))
if err != nil {
t.Fatal("NewSSOWrapper():", err)
}
// Start a local test https server.
srv := httptest.NewTLSServer(w.Wrap(m, testService, nil))
defer srv.Close()
// Request a sample URL.
data := string(makeAuthRequest(t, srv.URL, "/test", testService, testDomain, priv))
if data != "OK" {
t.Fatalf("Get() returned bad data: %s", data)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment