handler.go 5.09 KiB
package httpsso
import (
"crypto/rand"
"encoding/gob"
"encoding/hex"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/gorilla/sessions"
"git.autistici.org/id/go-sso"
"git.autistici.org/id/go-sso/httputil"
)
// authSession is the authentication state kept in the local "sso"
// session under the key "a". The type is registered with encoding/gob in
// this file's init function.
type authSession struct {
	// Embedded expiry tracking: Wrap only trusts the session while
	// Valid() reports true.
	*httputil.ExpiringSession

	// Auth is set to true by handleLogin once a ticket has validated.
	Auth bool
	// Username is the user named in the validated ticket; Wrap forwards
	// it to the protected handler in the X-Authenticated-User header.
	Username string
}
// authSessionLifetime is the expiry passed to NewExpiringSession when a
// login succeeds; presumably once it elapses Valid() turns false and the
// user is sent back to the SSO server.
var authSessionLifetime = time.Hour

func init() {
	// The cookie store presumably serializes session values with gob, so
	// the concrete type stored under "a" must be registered up front.
	gob.Register(new(authSession))
}
// SSOWrapper protects http handlers with single-sign-on authentication.
// Build one with NewSSOWrapper and protect individual handlers with Wrap;
// every call to Wrap gets its own cookie store scoped to the service path.
type SSOWrapper struct {
	// v validates the login tickets handed back by the SSO server.
	v sso.Validator
	// Key material for the local session cookie store, passed to
	// sessions.NewCookieStore in this order (authentication, encryption).
	sessionAuthKey []byte
	sessionEncKey  []byte
	// serverURL is the SSO login endpoint unauthenticated users are
	// redirected to.
	serverURL string
	// serverOrigin is the scheme://host part of serverURL, used for the
	// CORS headers on the logout endpoint; when empty, no CORS headers
	// are sent.
	serverOrigin string
}
// NewSSOWrapper creates an SSOWrapper that authenticates users against
// the login service at serverURL. pkey and domain configure the ticket
// validator (its error, if any, is returned unchanged); sessionAuthKey and
// sessionEncKey are the keys for the local session cookie store.
func NewSSOWrapper(serverURL string, pkey []byte, domain string, sessionAuthKey, sessionEncKey []byte) (*SSOWrapper, error) {
	validator, err := sso.NewValidator(pkey, domain)
	if err != nil {
		return nil, err
	}

	wrapper := &SSOWrapper{
		v:              validator,
		sessionAuthKey: sessionAuthKey,
		sessionEncKey:  sessionEncKey,
		serverURL:      serverURL,
	}
	// Derived once here so handleLogout does not re-parse the URL per request.
	wrapper.serverOrigin = originFromURL(serverURL)
	return wrapper, nil
}
// Wrap returns a handler that enforces SSO authentication and access
// control in front of h. Only a simple group-based ACL is supported: the
// required groups are forwarded to the SSO server and to the ticket
// validator.
//
// Below the service's URL path two names are reserved: "sso_login" (the
// callback that completes a login) and "sso_logout". Every other request
// reaches h only with a valid, authenticated session, in which case the
// X-Authenticated-User request header is set (overwriting any
// client-supplied value) to the user name. Otherwise the client is
// redirected to the SSO server.
func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.Handler {
	prefix := pathFromService(service)

	cookieStore := sessions.NewCookieStore(s.sessionAuthKey, s.sessionEncKey)
	cookieStore.Options = &sessions.Options{
		Path:     prefix,
		MaxAge:   0,
		Secure:   true,
		HttpOnly: true,
	}

	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		// The error is discarded on purpose: this relies on the store still
		// returning a usable (fresh) session when the cookie cannot be
		// decoded. NOTE(review): confirm against gorilla/sessions.
		session, _ := cookieStore.Get(req, "sso")

		action := strings.TrimPrefix(req.URL.Path, prefix)
		if action == "sso_login" {
			s.handleLogin(w, req, session, service, groups)
			return
		}
		if action == "sso_logout" {
			s.handleLogout(w, req, session)
			return
		}

		auth, ok := session.Values["a"].(*authSession)
		if !ok || !auth.Valid() || !auth.Auth {
			s.redirectToLogin(w, req, session, service, groups)
			return
		}
		req.Header.Set("X-Authenticated-User", auth.Username)
		h.ServeHTTP(w, req)
	})
}
// handleLogin completes the SSO flow. It validates the ticket in form
// value "t" against the nonce that redirectToLogin stored in the session,
// then marks the session as authenticated for authSessionLifetime. On
// success it redirects to the destination in form value "d".
//
// Missing nonces and invalid tickets are answered with 400; failures to
// save the session are answered with 500.
func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, session *sessions.Session, service string, groups []string) {
	t := req.FormValue("t")
	// NOTE(review): d is used as the redirect target without checking it
	// against the request host; confirm the SSO server constrains it.
	d := req.FormValue("d")

	// Pop the nonce from the session: it is good for one attempt only.
	nonce, ok := session.Values["nonce"].(string)
	if !ok || nonce == "" {
		log.Printf("got login request without nonce")
		http.Error(w, "Missing nonce", http.StatusBadRequest)
		return
	}
	delete(session.Values, "nonce")

	tkt, err := s.v.Validate(t, nonce, service, groups)
	if err != nil {
		// Do not write the ticket itself to the log: it is a signed
		// credential, and log readers could potentially replay it.
		log.Printf("ticket validation error for service %s: %v", service, err)
		// Persist the nonce removal even though the login failed, so the
		// consumed nonce does not stay in the client's cookie. Best
		// effort only: the request is being rejected anyway.
		if saveErr := sessions.Save(req, w); saveErr != nil {
			log.Printf("error saving session after failed login: %v", saveErr)
		}
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	// Authenticate the user.
	session.Values["a"] = &authSession{
		ExpiringSession: httputil.NewExpiringSession(authSessionLifetime),
		Auth:            true,
		Username:        tkt.User,
	}
	if err := sessions.Save(req, w); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	http.Redirect(w, req, d, http.StatusFound)
}
// handleLogout discards the local session by expiring its cookie, then
// replies with a plain-text "OK". When the SSO server origin is known,
// CORS headers let that origin call this endpoint with credentials
// (presumably so the SSO server can log the user out of every service).
// A failure to save the session is reported as a 500.
func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, session *sessions.Session) {
	// A negative MaxAge asks the cookie store to delete the cookie on save.
	session.Options.MaxAge = -1
	if err := sessions.Save(req, w); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	hdr := w.Header()
	hdr.Set("Content-Type", "text/plain")
	if origin := s.serverOrigin; origin != "" {
		hdr.Set("Access-Control-Allow-Origin", origin)
		hdr.Set("Access-Control-Allow-Credentials", "true")
	}
	// The body is informational only; a write error has nowhere to go.
	_, _ = io.WriteString(w, "OK")
}
// redirectToLogin sends the client to the SSO server (302 Found). It first
// stores a fresh random nonce in the local session, which handleLogin
// later passes to the ticket validator. The server receives these
// parameters:
//   - "s": the service;
//   - "d": the URL to return to (the current request URL forced to https);
//   - "n": the nonce;
//   - "g": the required groups, comma-separated.
//
// If the session cannot be saved, the request fails with 500.
func (s *SSOWrapper) redirectToLogin(w http.ResponseWriter, req *http.Request, session *sessions.Session, service string, groups []string) {
	// Generate a random nonce and store it in the local session.
	nonce := makeUniqueNonce()
	session.Values["nonce"] = nonce
	if err := sessions.Save(req, w); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	v := make(url.Values)
	v.Set("s", service)
	v.Set("d", getFullURL(req, "https").String())
	v.Set("n", nonce)
	v.Set("g", strings.Join(groups, ","))

	// serverURL may already carry a query string; extend it rather than
	// appending a second '?', which would corrupt the parameters.
	sep := "?"
	switch {
	case strings.HasSuffix(s.serverURL, "?"), strings.HasSuffix(s.serverURL, "&"):
		sep = ""
	case strings.Contains(s.serverURL, "?"):
		sep = "&"
	}
	http.Redirect(w, req, s.serverURL+sep+v.Encode(), http.StatusFound)
}
// pathFromService returns the URL path part of a service specification
// (presumably "host/path/"): everything from the first '/' onward, or ""
// when the specification contains no '/'. The result starts with a slash
// whenever it is non-empty. It ends with a slash only if service itself
// does; the caller relies on that trailing slash when matching
// "sso_login"/"sso_logout".
func pathFromService(service string) string {
	if idx := strings.Index(service, "/"); idx >= 0 {
		return service[idx:]
	}
	return ""
}
// getFullURL returns an absolute URL for req using the given scheme and
// the request's Host. This is needed because the URL field of a server-side
// net/http.Request normally only carries the path and query. The result is
// a copy; req.URL itself is left untouched.
func getFullURL(req *http.Request, scheme string) *url.URL {
	full := new(url.URL)
	*full = *req.URL
	full.Scheme, full.Host = scheme, req.Host
	return full
}
// makeUniqueNonce returns 8 bytes read from crypto/rand, hex-encoded as a
// 16-character string. It panics if the random source cannot be read.
func makeUniqueNonce() string {
	const size = 8
	buf := make([]byte, size)
	if _, err := io.ReadFull(rand.Reader, buf); err != nil {
		panic(err)
	}
	return hex.EncodeToString(buf)
}
// originFromURL returns the origin ("scheme://host[:port]") of the URL s,
// dropping the path, query and fragment. The host is lowercased to match
// how browsers serialize the Origin header. It returns "" when s cannot be
// parsed or lacks a scheme or host (e.g. a bare path or "host/path").
// Callers treat "" as "origin unknown" rather than emitting a malformed
// value such as "://".
func originFromURL(s string) string {
	parsed, err := url.Parse(s)
	if err != nil || parsed.Scheme == "" || parsed.Host == "" {
		return ""
	}
	// url.Parse already lowercases the scheme; the host keeps its case.
	return fmt.Sprintf("%s://%s", parsed.Scheme, strings.ToLower(parsed.Host))
}