Skip to content
Snippets Groups Projects

Refactor the login handler

Merged ale requested to merge better-login into master
1 file
+ 23
9
Compare changes
  • Side-by-side
  • Inline
  • 6387bf4c
    Block default favicon requests · 6387bf4c
    ale authored
    If we don't, they will trigger the login handler and invalidate the
    current session (if any), which prevents the user from being able to
    log in.
+ 183
315
package server
//go:generate python sri.py templates/*.html
//go:generate go run scripts/sri.go --package server --output sri_map.go static
//go:generate go-bindata --nocompress --pkg server static/... templates/...
import (
"bytes"
"context"
"encoding/gob"
"encoding/json"
"fmt"
"html/template"
"io"
"io/ioutil"
"log"
"net/http"
"net/url"
"os"
"strings"
"time"
assetfs "github.com/elazarl/go-bindata-assetfs"
"github.com/gorilla/csrf"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/rs/cors"
"git.autistici.org/id/auth"
authclient "git.autistici.org/id/auth/client"
ksclient "git.autistici.org/id/keystore/client"
"git.autistici.org/id/go-sso/httputil"
"git.autistici.org/id/go-sso/server/device"
"git.autistici.org/id/go-sso/server/httputil"
"git.autistici.org/id/go-sso/server/login"
)
const authSessionKey = "_auth"
type authSession struct {
*httputil.ExpiringSession
// User name and other information (like group membership).
Username string
UserInfo *auth.UserInfo
// Services the user has logged in to from this session.
Services []string
}
// AddService adds a service to the current session (if it's not
// already there).
func (s *authSession) AddService(service string) {
for _, svc := range s.Services {
if svc == service {
return
}
}
s.Services = append(s.Services, service)
}
// By default, make users log in again after (almost) one day.
var defaultAuthSessionLifetime = 20 * time.Hour
func newAuthSession(ttl time.Duration, username string, userinfo *auth.UserInfo) *authSession {
return &authSession{
ExpiringSession: httputil.NewExpiringSession(ttl),
Username: username,
UserInfo: userinfo,
}
}
// A relatively strict CSP.
const contentSecurityPolicy = "default-src 'none'; img-src 'self' data:; script-src 'self'; style-src 'self'; connect-src 'self';"
func init() {
gob.Register(&authSession{})
}
// Slightly looser CSP for the logout page: it needs to load remote
// images.
const logoutContentSecurityPolicy = "default-src 'none'; img-src *; script-src 'self'; style-src 'self'; connect-src *;"
// Returns the URL of the login handler on the target service.
func serviceLoginCallback(service, destination, token string) string {
@@ -88,105 +51,126 @@ func serviceLogoutCallback(service string) string {
// Server for the SSO protocol. Provides the HTTP interface to a
// LoginService.
type Server struct {
authSessionStore sessions.Store
authSessionLifetime time.Duration
loginHandler *loginHandler
authSessionLifetime int
loginService *LoginService
keystore ksclient.Client
keystoreGroups []string
csrfSecret []byte
renderer *renderer
renderer *httputil.Renderer
urlPrefix string
homepageRedirectURL string
allowedOrigins []string
// User-configurable static data that we serve from memory.
siteLogo *staticContent
siteFavicon *staticContent
}
func sl2bl(sl []string) [][]byte {
var out [][]byte
for _, s := range sl {
out = append(out, []byte(s))
}
return out
handler http.Handler
}
// New returns a new Server.
func New(loginService *LoginService, authClient authclient.Client, config *Config) (*Server, error) {
urlPrefix := strings.TrimRight(config.URLPrefix, "/")
sessionSecrets := sl2bl(config.SessionSecrets)
store := sessions.NewCookieStore(sessionSecrets...)
store.Options = &sessions.Options{
HttpOnly: true,
Secure: true,
MaxAge: 0,
Path: urlPrefix + "/",
}
renderer := newRenderer(config)
s := &Server{
authSessionLifetime: defaultAuthSessionLifetime,
authSessionStore: store,
renderer := httputil.NewRenderer(
parseEmbeddedTemplates(),
map[string]interface{}{
"URLPrefix": urlPrefix,
"AccountRecoveryURL": config.AccountRecoveryURL,
"SiteName": config.SiteName,
"SiteLogo": config.SiteLogo,
"SiteFavicon": config.SiteFavicon,
},
)
h := &Server{
loginService: loginService,
urlPrefix: urlPrefix,
homepageRedirectURL: config.HomepageRedirectURL,
allowedOrigins: config.AllowedCORSOrigins,
authSessionLifetime: config.AuthSessionLifetimeSeconds,
renderer: renderer,
}
if config.CSRFSecret != "" {
s.csrfSecret = []byte(config.CSRFSecret)
}
if config.AuthSessionLifetimeSeconds > 0 {
s.authSessionLifetime = time.Duration(config.AuthSessionLifetimeSeconds) * time.Second
}
if config.SiteLogo != "" {
siteLogo, err := loadStaticContent(config.SiteLogo)
if config.KeyStore != nil {
ks, err := ksclient.New(config.KeyStore)
if err != nil {
return nil, err
}
s.siteLogo = siteLogo
log.Printf("keystore client enabled")
h.keystore = ks
h.keystoreGroups = config.KeyStoreEnableGroups
}
if config.SiteFavicon != "" {
siteFavicon, err := loadStaticContent(config.SiteFavicon)
devMgr, err := device.New(config.DeviceManager, urlPrefix)
if err != nil {
return nil, err
}
s.siteFavicon = siteFavicon
}
if config.KeyStore != nil {
ks, err := ksclient.New(config.KeyStore)
// The root HTTP handler. If a URL prefix is set, we can't
// just add a StripPrefix in front of everything, as the
// handlers need access to the actual full request URL, so we
// just inject the prefix everywhere.
root := http.NewServeMux()
// If we have customized content, serve it from well-known URLs.
if config.SiteLogo != "" {
siteLogo, err := httputil.LoadStaticContent(config.SiteLogo)
if err != nil {
return nil, err
}
log.Printf("keystore client enabled")
s.keystore = ks
s.keystoreGroups = config.KeyStoreEnableGroups
root.Handle(h.urlFor("/img/site_logo"), siteLogo)
}
devMgr, err := device.New(config.DeviceManager)
if config.SiteFavicon != "" {
siteFavicon, err := httputil.LoadStaticContent(config.SiteFavicon)
if err != nil {
return nil, err
}
s.loginHandler = newLoginHandler(s.loginCallback, devMgr, authClient,
config.AuthService, config.U2FAppID, config.URLPrefix,
renderer, sessionSecrets...)
return s, nil
root.Handle(h.urlFor("/favicon.ico"), siteFavicon)
} else if urlPrefix == "" {
// Block default favicon requests (created by error pages, or
// if we don't set a custom favicon) *before* the login
// handler runs, or it will invalidate the session!
root.HandleFunc(h.urlFor("/favicon.ico"), func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
})
}
func inAnyGroups(groups, ref []string) bool {
for _, rr := range ref {
for _, gg := range groups {
if gg == rr {
return true
}
}
// Serve static content to anyone.
staticPath := h.urlFor("/static/")
root.Handle(staticPath, http.StripPrefix(staticPath, http.FileServer(&assetfs.AssetFS{
Asset: Asset,
AssetDir: AssetDir,
AssetInfo: AssetInfo,
Prefix: "static",
})))
// Add the /exchange endpoint (which does not use the normal
// HTTP-based login workflow).
root.HandleFunc(h.urlFor("/exchange"), h.handleExchange)
// Build the main IDP application router, wrap it with a login
// handler, optional CSRF protection, custom HTTP headers,
// etc.
mainh := http.NewServeMux()
mainh.HandleFunc("/logout", h.handleLogout)
mainh.HandleFunc("/", h.handleGrantTicket)
loginh := login.New(mainh, devMgr, authClient,
config.AuthService, config.U2FAppID, urlPrefix,
config.HomepageRedirectURL, renderer, h.loginCallback,
sl2bl(config.SessionSecrets),
time.Duration(config.AuthSessionLifetimeSeconds)*time.Second)
apph := httputil.WithDynamicHeaders(loginh, contentSecurityPolicy)
if config.CSRFSecret != "" {
apph = csrf.Protect([]byte(config.CSRFSecret))(apph)
}
return false
// Add CORS headers on the main IDP endpoints.
corsp := cors.New(cors.Options{
AllowedOrigins: config.AllowedCORSOrigins,
AllowedHeaders: []string{"*"},
AllowCredentials: true,
MaxAge: 86400,
})
apph = corsp.Handler(apph)
root.Handle(h.urlFor("/"), apph)
h.handler = root
return h, nil
}
// We unlock the keystore if the following conditions are met:
@@ -206,17 +190,17 @@ func (h *Server) maybeUnlockKeystore(ctx context.Context, username, password str
}
shard = userinfo.Shard
}
return true, h.keystore.Open(ctx, shard, username, password, int(h.authSessionLifetime.Seconds()))
return true, h.keystore.Open(ctx, shard, username, password, h.authSessionLifetime)
}
func (h *Server) loginCallback(w http.ResponseWriter, req *http.Request, username, password string, userinfo *auth.UserInfo) error {
// Open the keystore for this user with the password used to
// authenticate. Set the TTL to the duration of the
// authenticated session.
decrypted, err := h.maybeUnlockKeystore(req.Context(), username, password, userinfo)
// Callback called by the login handler whenever a user successfully
// logs in. We use it to unlock the keystore with the user's password.
func (h *Server) loginCallback(ctx context.Context, username, password string, userinfo *auth.UserInfo) error {
// Open the keystore for this user, with the same password
// used to authenticate.
decrypted, err := h.maybeUnlockKeystore(ctx, username, password, userinfo)
if err != nil {
log.Printf("failed to unlock keystore for user %s: %v", username, err)
return err
return fmt.Errorf("failed to unlock keystore for user %s: %v", username, err)
}
var kmsg string
@@ -224,43 +208,20 @@ func (h *Server) loginCallback(w http.ResponseWriter, req *http.Request, usernam
kmsg = " (key unlocked)"
}
log.Printf("successful login for user %s%s", username, kmsg)
// Create cookie-based session for the authenticated user.
session := newAuthSession(h.authSessionLifetime, username, userinfo)
httpSession, _ := h.authSessionStore.Get(req, authSessionKey) // nolint
httpSession.Values["data"] = session
return httpSession.Save(req, w)
}
func (h *Server) redirectToLogin(w http.ResponseWriter, req *http.Request) {
http.Redirect(w, req, h.loginHandler.makeLoginURL(req), http.StatusFound)
}
func (h *Server) withAuth(f func(http.ResponseWriter, *http.Request, *authSession), authFail func(http.ResponseWriter, *http.Request)) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
httpSession, err := h.authSessionStore.Get(req, authSessionKey)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
session, ok := httpSession.Values["data"].(*authSession)
if ok && session.Valid() {
f(w, req, session)
return
}
httpSession.Options.MaxAge = -1
delete(httpSession.Values, "data")
if err := httpSession.Save(req, w); err != nil {
log.Printf("error saving session: %v", err)
}
authFail(w, req)
})
return nil
}
// Token signing handler. Authorizes an authenticated user to a service by
// signing a token with the user's identity. The client is redirected back to
// the original service, with the signed token.
func (h *Server) handleHomepage(w http.ResponseWriter, req *http.Request, session *authSession) {
func (h *Server) handleGrantTicket(w http.ResponseWriter, req *http.Request) {
// We need this check here because this handler is usually
// mounted at the application root.
if req.URL.Path != h.urlFor("/") {
http.NotFound(w, req)
return
}
// Extract the authorization request parameters from the HTTP
// request query args.
//
@@ -268,7 +229,13 @@ func (h *Server) handleHomepage(w http.ResponseWriter, req *http.Request, sessio
// it is a POST request redirected from a 307, so we do not
// call req.FormValue() but look directly into request.URL
// instead.
username := session.Username
auth, ok := login.GetAuth(req.Context())
if !ok {
http.Error(w, "No valid session", http.StatusBadRequest)
return
}
username := auth.Username
service := req.URL.Query().Get("s")
destination := req.URL.Query().Get("d")
nonce := req.URL.Query().Get("n")
@@ -292,8 +259,8 @@ func (h *Server) handleHomepage(w http.ResponseWriter, req *http.Request, sessio
var groups []string
if groupsStr != "" {
reqGroups := strings.Split(groupsStr, ",")
if len(reqGroups) > 0 && session.UserInfo != nil {
groups = intersectGroups(reqGroups, session.UserInfo.Groups)
if len(reqGroups) > 0 && auth.UserInfo != nil {
groups = intersectGroups(reqGroups, auth.UserInfo.Groups)
// We only make this check here as a convenience to
// the user (we may be able to show a nicer UI): the
// actual group ACL must be applied on the destination
@@ -314,28 +281,32 @@ func (h *Server) handleHomepage(w http.ResponseWriter, req *http.Request, sessio
return
}
session.AddService(service)
if err := sessions.Save(req, w); err != nil {
log.Printf("error saving session: %v", err)
}
// Record the service in the session.
auth.AddService(service)
// Redirect to service callback.
callbackURL := serviceLoginCallback(service, destination, token)
http.Redirect(w, req, callbackURL, http.StatusFound)
}
func (h *Server) alreadyLoggedOut(w http.ResponseWriter, req *http.Request) {
http.Error(w, "You do not seem to be logged in", http.StatusBadRequest)
}
type logoutServiceInfo struct {
URL string `json:"url"`
Name string `json:"name"`
}
func (h *Server) handleLogout(w http.ResponseWriter, req *http.Request, session *authSession) {
// Logout handler. We generate a page that triggers child logout
// requests to all the services the user is logged in to.
func (h *Server) handleLogout(w http.ResponseWriter, req *http.Request) {
auth, ok := login.GetAuth(req.Context())
if !ok {
http.Error(w, "No valid session", http.StatusBadRequest)
return
}
//
var svcs []logoutServiceInfo
for _, svc := range session.Services {
for _, svc := range auth.Services {
svcs = append(svcs, logoutServiceInfo{
Name: svc,
URL: serviceLogoutCallback(svc),
@@ -349,32 +320,20 @@ func (h *Server) handleLogout(w http.ResponseWriter, req *http.Request, session
"IncludeLogoutScripts": true,
}
// Clear the local session. Ignore errors.
httpSession, _ := h.authSessionStore.Get(req, authSessionKey) // nolint
delete(httpSession.Values, "data")
httpSession.Options.MaxAge = -1
httpSession.Save(req, w) // nolint
// Close the keystore.
if h.keystore != nil {
var shard string
if session.UserInfo != nil {
shard = session.UserInfo.Shard
if auth.UserInfo != nil {
shard = auth.UserInfo.Shard
}
if err := h.keystore.Close(req.Context(), shard, session.Username); err != nil {
log.Printf("failed to wipe keystore for user %s: %v", session.Username, err)
if err := h.keystore.Close(req.Context(), shard, auth.Username); err != nil {
// This is not a fatal error.
log.Printf("warning: failed to wipe keystore for user %s: %v", auth.Username, err)
}
}
w.Header().Set("Content-Security-Policy", logoutContentSecurityPolicy)
body, err := h.renderer.Render(req, "logout.html", data)
if err != nil {
log.Printf("template error in logout(): %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Write(body) // nolint
h.renderer.Render(w, req, "logout.html", data)
}
func (h *Server) handleExchange(w http.ResponseWriter, req *http.Request) {
@@ -406,97 +365,14 @@ func (h *Server) urlFor(path string) string {
// Handler returns the http.Handler for the SSO server application.
func (h *Server) Handler() http.Handler {
// The root HTTP handler. This must be a gorilla/mux.Router since
// sessions depend on it.
//
// If a URL prefix is set, we can't just add a StripPrefix in
// front of everything, as the handlers need access to the
// actual full request URL, so we just inject the prefix
// everywhere.
root := mux.NewRouter()
// If we have customized content, serve it from well-known URLs.
if h.siteLogo != nil {
root.Handle(h.urlFor("/img/site_logo"), h.siteLogo)
}
if h.siteFavicon != nil {
root.Handle(h.urlFor("/favicon.ico"), h.siteFavicon)
}
// Serve static content to anyone.
staticPath := h.urlFor("/static/")
root.PathPrefix(staticPath).Handler(http.StripPrefix(staticPath, http.FileServer(&assetfs.AssetFS{
Asset: Asset,
AssetDir: AssetDir,
AssetInfo: AssetInfo,
Prefix: "static",
})))
// Build the main IDP application router, with optional CSRF
// protection.
m := http.NewServeMux()
m.Handle(h.urlFor("/login"), h.loginHandler)
m.Handle(h.urlFor("/logout"), h.withAuth(h.handleLogout, h.alreadyLoggedOut))
idph := http.Handler(m)
if h.csrfSecret != nil {
idph = csrf.Protect(h.csrfSecret)(idph)
}
// Add CORS headers on the main SSO API endpoint.
c := cors.New(cors.Options{
AllowedOrigins: h.allowedOrigins,
AllowedHeaders: []string{"*"},
AllowCredentials: true,
MaxAge: 86400,
})
// Add the SSO provider endpoints (root path and /exchange),
// which do not need CSRF. We use a HandlerFunc to bypass the
// '/' dispatch semantics of the standard http.ServeMux.
ssoh := c.Handler(h.withAuth(h.handleHomepage, h.redirectToLogin))
userh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == h.urlFor("/"):
ssoh.ServeHTTP(w, r)
case r.URL.Path == h.urlFor("/exchange"):
h.handleExchange(w, r)
default:
idph.ServeHTTP(w, r)
}
})
// User-facing routes require cache-busting and CSP headers.
root.PathPrefix(h.urlFor("/")).Handler(withDynamicHeaders(c.Handler(userh)))
return root
}
// A relatively strict CSP.
const contentSecurityPolicy = "default-src 'none'; img-src 'self' data:; script-src 'self'; style-src 'self'; connect-src 'self';"
// Slightly looser CSP for the logout page: it needs to load remote
// images.
const logoutContentSecurityPolicy = "default-src 'none'; img-src *; script-src 'self'; style-src 'self'; connect-src *;"
func withDynamicHeaders(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Expires", "-1")
w.Header().Set("X-Frame-Options", "NONE")
w.Header().Set("X-XSS-Protection", "1; mode=block")
w.Header().Set("X-Content-Type-Options", "nosniff")
if w.Header().Get("Content-Security-Policy") == "" {
w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
}
h.ServeHTTP(w, r)
})
return h.handler
}
// Parse the templates that are embedded with the binary (in bindata.go).
func parseEmbeddedTemplates() *template.Template {
root := template.New("").Funcs(template.FuncMap{
"json": toJSON,
"SRI": sriIntegrity,
})
files, err := AssetDir("templates")
if err != nil {
@@ -514,64 +390,56 @@ func parseEmbeddedTemplates() *template.Template {
return root
}
type renderer struct {
tpl *template.Template
urlPrefix string
siteName string
siteLogo string
siteFavicon string
accountRecoveryURL string
func sl2bl(sl []string) [][]byte {
var out [][]byte
for _, s := range sl {
out = append(out, []byte(s))
}
return out
}
func newRenderer(config *Config) *renderer {
return &renderer{
tpl: parseEmbeddedTemplates(),
urlPrefix: strings.TrimRight(config.URLPrefix, "/"),
accountRecoveryURL: config.AccountRecoveryURL,
siteName: config.SiteName,
siteLogo: config.SiteLogo,
siteFavicon: config.SiteFavicon,
// Returns true if the intersection of the sets isn't empty (in O(N^2)
// time).
func inAnyGroups(groups, ref []string) bool {
for _, rr := range ref {
for _, gg := range groups {
if gg == rr {
return true
}
}
func (r *renderer) Render(req *http.Request, templateName string, data map[string]interface{}) ([]byte, error) {
data["CSRFField"] = csrf.TemplateField(req)
data["URLPrefix"] = r.urlPrefix
data["AccountRecoveryURL"] = r.accountRecoveryURL
data["SiteName"] = r.siteName
data["SiteLogo"] = r.siteLogo
data["SiteFavicon"] = r.siteFavicon
var buf bytes.Buffer
if err := r.tpl.ExecuteTemplate(&buf, templateName, data); err != nil {
return nil, err
}
return buf.Bytes(), nil
return false
}
type staticContent struct {
modtime time.Time
name string
data []byte
// Returns the intersection of two string lists (in O(N^2) time).
func intersectGroups(a, b []string) []string {
var out []string
for _, aa := range a {
for _, bb := range b {
if aa == bb {
out = append(out, aa)
break
}
}
func loadStaticContent(path string) (*staticContent, error) {
stat, err := os.Stat(path)
if err != nil {
return nil, err
}
data, err := ioutil.ReadFile(path) // #nosec
return out
}
// Template helper function that encodes its input as JSON.
func toJSON(obj interface{}) string {
data, err := json.Marshal(obj)
if err != nil {
return nil, err
return ""
}
return &staticContent{
name: path,
modtime: stat.ModTime(),
data: data,
}, nil
return string(data)
}
func (c *staticContent) ServeHTTP(w http.ResponseWriter, req *http.Request) {
http.ServeContent(w, req, c.name, c.modtime, bytes.NewReader(c.data))
// Return an integrity= attribute for the given URI (which should be
// supplied without an eventual prefix).
func sriIntegrity(uri string) template.HTML {
sri, ok := sriMap[uri]
if !ok {
return template.HTML("")
}
return template.HTML(fmt.Sprintf(" integrity=\"%s\"", sri))
}
Loading