diff --git a/httpsso/handler.go b/httpsso/handler.go index 93f37302ad847177c7a0df12a6f112d3347825fd..0bdde27c976f9e16d35f888f476ff88a0aa0e3ec 100644 --- a/httpsso/handler.go +++ b/httpsso/handler.go @@ -1,6 +1,7 @@ package httpsso import ( + "context" "crypto/rand" "encoding/gob" "encoding/hex" @@ -23,6 +24,44 @@ type authSession struct { Auth bool Username string + Groups []string +} + +type authSessionKeyType int + +var authSessionKey authSessionKeyType = 42 + +func getCurrentAuthSession(req *http.Request) *authSession { + s, ok := req.Context().Value(authSessionKey).(*authSession) + if !ok { + return nil + } + return s +} + +// Authenticated returns true if the user is successfully +// authenticated, in the call trace following SSOWrapper.Wrap. +func Authenticated(req *http.Request) bool { + if s := getCurrentAuthSession(req); s != nil { + return s.Auth + } + return false +} + +// Username of the currently authenticated user. +func Username(req *http.Request) string { + if s := getCurrentAuthSession(req); s != nil && s.Auth { + return s.Username + } + return "" +} + +// Groups returns the group list for the currently authenticated user. +func Groups(req *http.Request) []string { + if s := getCurrentAuthSession(req); s != nil && s.Auth { + return s.Groups + } + return nil } var authSessionLifetime = 1 * time.Hour @@ -83,7 +122,8 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http. if auth, ok := session.Values["a"].(*authSession); ok && auth.Valid() && auth.Auth { req.Header.Set("X-Authenticated-User", auth.Username) - h.ServeHTTP(w, req) + ctx := context.WithValue(req.Context(), authSessionKey, auth) + h.ServeHTTP(w, req.WithContext(ctx)) return } @@ -117,6 +157,7 @@ func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, sessi ExpiringSession: httputil.NewExpiringSession(authSessionLifetime), Auth: true, Username: tkt.User, + Groups: tkt.Groups, } if err := sessions.Save(req, w); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -137,7 +178,7 @@ func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, sess w.Header().Set("Access-Control-Allow-Origin", s.serverOrigin) w.Header().Set("Access-Control-Allow-Credentials", "true") } - io.WriteString(w, "OK") + io.WriteString(w, "OK") // nolint } // Redirect to the SSO server. diff --git a/httpsso/handler_test.go b/httpsso/handler_test.go index cccec5f508e45ab8cfb0ebddd3509dbbc3309c11..7cb1f65332bc155c0751322d55db56e5a47a0653 100644 --- a/httpsso/handler_test.go +++ b/httpsso/handler_test.go @@ -8,6 +8,7 @@ import ( "net/http/cookiejar" "net/http/httptest" "net/url" + "strings" "testing" "time" @@ -41,6 +42,8 @@ const ( testLoginServer = "login.example.com" ) +var testGroups = []string{"group1", "group2"} + func makeAuthRequest(t testing.TB, c *http.Client, base, path, service, domain string, priv []byte) []byte { resp, err := c.Get(base + path) if err != nil { @@ -65,13 +68,13 @@ func makeAuthRequest(t testing.TB, c *http.Client, base, path, service, domain s t.Fatal(err) } nonce := loc.Query().Get("n") - tkt := sso.NewTicket("user", service, domain, nonce, nil, 300*time.Second) + tkt := sso.NewTicket("user", service, domain, nonce, testGroups, 300*time.Second) signed, err := signer.Sign(tkt) if err != nil { t.Fatal("Sign():", err) } u := make(url.Values) - destURL := "https://" + testHost + "/test" + destURL := "https://" + testHost + path u.Set("d", destURL) u.Set("t", signed) resp, err = c.Get(base + "/sso_login?" + u.Encode()) @@ -108,8 +111,11 @@ func TestSSOWrapper(t *testing.T) { // Build a test app - note that we want to use a gorilla Mux // here, otherwise cookie-based sessions won't work. m := mux.NewRouter() + m.HandleFunc("/test/groups", func(w http.ResponseWriter, req *http.Request) { + io.WriteString(w, strings.Join(Groups(req), ",")) // nolint + }) m.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { - io.WriteString(w, "OK") + io.WriteString(w, "OK") // nolint }) w, err := NewSSOWrapper("https://"+testLoginServer+"/", pub, testDomain, securecookie.GenerateRandomKey(64), securecookie.GenerateRandomKey(32)) @@ -118,13 +124,20 @@ func TestSSOWrapper(t *testing.T) { } // Start a local test https server. - srv := httptest.NewTLSServer(w.Wrap(m, testService, nil)) + srv := httptest.NewTLSServer(w.Wrap(m, testService, testGroups)) defer srv.Close() // Request a sample URL. c := newTestHTTPClient() data := string(makeAuthRequest(t, c, srv.URL, "/test", testService, testDomain, priv)) if data != "OK" { - t.Fatalf("Get() returned bad data: %s", data) + t.Fatalf("Get(/test) returned bad data: %s", data) + } + + // Another URL, clean client, verify context values. + c = newTestHTTPClient() + data = string(makeAuthRequest(t, c, srv.URL, "/test/groups", testService, testDomain, priv)) + if data != "group1,group2" { + t.Fatalf("Get(/test/groups) returned bad data: %s", data) } }