Commit 4d70b167 authored by ale's avatar ale

Refactor the login handler

The login handler is now a simpler, standalone http.Handler
wrapper. The separation between the SSO application and the login
handler is now fairly complete.

The login handler no longer forces the user to a specific workflow via
session cookies, but it works on a request-by-request basis instead,
which makes the "back" button works as expected (allowing the user to
bail out of a broken 2FA process, for example).

Session handling has been simplified as well: there is a single
session for authentication and login state, which should remove the
opportunity for session synchronization errors.
parent 6d3a620e
Pipeline #5345 passed with stages
in 3 minutes and 18 seconds
......@@ -20,8 +20,6 @@ import (
)
type authSession struct {
*httputil.ExpiringSession
Auth bool
Username string
Groups []string
......@@ -29,7 +27,7 @@ type authSession struct {
type authSessionKeyType int
var authSessionKey authSessionKeyType = 42
const authSessionKey authSessionKeyType = 0
func getCurrentAuthSession(req *http.Request) *authSession {
s, ok := req.Context().Value(authSessionKey).(*authSession)
......@@ -64,7 +62,7 @@ func Groups(req *http.Request) []string {
return nil
}
var authSessionLifetime = 1 * time.Hour
var defaultAuthSessionTTL = 1 * time.Hour
func init() {
gob.Register(&authSession{})
......@@ -77,6 +75,8 @@ type SSOWrapper struct {
sessionEncKey []byte
serverURL string
serverOrigin string
TTL time.Duration
}
// NewSSOWrapper returns a new SSOWrapper that will authenticate users
......@@ -93,6 +93,7 @@ func NewSSOWrapper(serverURL string, pkey []byte, domain string, sessionAuthKey,
serverOrigin: originFromURL(serverURL),
sessionAuthKey: sessionAuthKey,
sessionEncKey: sessionEncKey,
TTL: defaultAuthSessionTTL,
}, nil
}
......@@ -109,7 +110,7 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
session, _ := store.Get(req, "sso")
session, _ := httputil.GetExpiringSession(req, store, "sso", s.TTL)
switch strings.TrimPrefix(req.URL.Path, svcPath) {
case "sso_login":
......@@ -119,11 +120,11 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.
s.handleLogout(w, req, session)
default:
if auth, ok := session.Values["a"].(*authSession); ok && auth.Valid() && auth.Auth {
if auth, ok := session.Values["a"].(*authSession); ok && auth.Auth {
req.Header.Set("X-Authenticated-User", auth.Username)
ctx := context.WithValue(req.Context(), authSessionKey, auth)
h.ServeHTTP(w, req.WithContext(ctx))
req = req.WithContext(context.WithValue(req.Context(), authSessionKey, auth))
h.ServeHTTP(w, req)
return
}
......@@ -132,7 +133,7 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.
})
}
func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, session *sessions.Session, service string, groups []string) {
func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, session *httputil.ExpiringSession, service string, groups []string) {
t := req.FormValue("t")
d := req.FormValue("d")
......@@ -154,21 +155,25 @@ func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, sessi
// Authenticate the user.
session.Values["a"] = &authSession{
ExpiringSession: httputil.NewExpiringSession(authSessionLifetime),
Auth: true,
Username: tkt.User,
Groups: tkt.Groups,
Auth: true,
Username: tkt.User,
Groups: tkt.Groups,
}
if err := sessions.Save(req, w); err != nil {
log.Printf("error saving SSO session: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
http.Redirect(w, req, d, http.StatusFound)
}
func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, session *sessions.Session) {
func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, session *httputil.ExpiringSession) {
// Delete the auth session.
session.Options.MaxAge = -1
delete(session.Values, "sso")
if err := sessions.Save(req, w); err != nil {
log.Printf("error saving SSO session: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
......@@ -182,11 +187,13 @@ func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, sess
}
// Redirect to the SSO server.
func (s *SSOWrapper) redirectToLogin(w http.ResponseWriter, req *http.Request, session *sessions.Session, service string, groups []string) {
func (s *SSOWrapper) redirectToLogin(w http.ResponseWriter, req *http.Request, session *httputil.ExpiringSession, service string, groups []string) {
// Generate a random nonce and store it in the local session.
nonce := makeUniqueNonce()
session.Values["nonce"] = nonce
if err := sessions.Save(req, w); err != nil {
log.Printf("error saving SSO session: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
......
package httputil
import "net/http"
// WithDynamicHeaders wraps an http.Handler with cache-busting and
// security-related headers appropriate for a user-facing dynamic
// application. The 'csp' argument sets a default
// Content-Security-Policy.
func WithDynamicHeaders(h http.Handler, csp string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hdr := w.Header()
hdr.Set("Pragma", "no-cache")
hdr.Set("Cache-Control", "no-store")
hdr.Set("Expires", "-1")
hdr.Set("X-Frame-Options", "NONE")
hdr.Set("X-XSS-Protection", "1; mode=block")
hdr.Set("X-Content-Type-Options", "nosniff")
if csp != "" && hdr.Get("Content-Security-Policy") == "" {
hdr.Set("Content-Security-Policy", csp)
}
h.ServeHTTP(w, r)
})
}
package httputil
import (
"bytes"
"html/template"
"io"
"log"
"net/http"
"strconv"
"github.com/gorilla/csrf"
)
// A Renderer just renders HTML templates with some common context
// variables. Context is represented as a map[string]interface{}, to
// allow the merge operation.
type Renderer struct {
tpl *template.Template
vars map[string]interface{}
}
// NewRenderer creates a new Renderer with the provided templates and
// default variables.
func NewRenderer(tpl *template.Template, vars map[string]interface{}) *Renderer {
return &Renderer{
tpl: tpl,
vars: vars,
}
}
// Render the named HTML template to 'w'.
func (r *Renderer) Render(w http.ResponseWriter, req *http.Request, templateName string, data map[string]interface{}) {
// Merge default variables with the ones passed in 'data',
// without modifying either. Always populate the CRSFField
// variable with the current CSRF token.
vars := make(map[string]interface{})
vars["CSRFField"] = csrf.TemplateField(req)
for k, v := range r.vars {
vars[k] = v
}
for k, v := range data {
vars[k] = v
}
// Render the template into a buffer, to prevent returning
// half-rendered templates when there is an error.
var buf bytes.Buffer
if err := r.tpl.ExecuteTemplate(&buf, templateName, data); err != nil {
log.Printf("template rendering error for %s: %v", req.URL.String(), err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Write our response to the client.
w.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
io.Copy(w, &buf) // nolint
}
......@@ -2,7 +2,11 @@ package httputil
import (
"encoding/gob"
"log"
"net/http"
"time"
"github.com/gorilla/sessions"
)
// ExpiringSession is a session with server-side expiration check.
......@@ -12,24 +16,85 @@ import (
// browser for the latter, but we enforce time-based expiration on the
// server.
type ExpiringSession struct {
Expiry time.Time
*sessions.Session
}
// NewExpiringSession returns a session that is valid for the given
// duration.
func NewExpiringSession(ttl time.Duration) *ExpiringSession {
return &ExpiringSession{
Expiry: time.Now().Add(ttl),
// GetExpiringSession wraps a Session (obtained from 'store') with
// an ExpiringSession. If it's invalid or expired, a new empty Session
// will be created with an expiration time set using 'ttl'.
func GetExpiringSession(req *http.Request, store sessions.Store, key string, ttl time.Duration) (*ExpiringSession, error) {
now := time.Now()
// An error here just means that we failed to decode the
// existing session for some reason. A new session will always
// be returned, so we just pass along the error to the caller
// (so it can be logged).
s, err := store.Get(req, key)
// See if we have a valid session first.
if !s.IsNew {
if exp, ok := s.Values["_exp"].(time.Time); ok && now.Before(exp) {
return &ExpiringSession{Session: s}, err
}
// We can't call sessions.NewSession() because that
// won't register the session with the Registry, so it
// won't be sent with the response. Wipe the data
// instead.
for k := range s.Values {
delete(s.Values, k)
}
}
// The session is either invalid or expired, create a new
// blank one containing no data.
expiry := now.Add(ttl)
s.Values["_exp"] = expiry
return &ExpiringSession{Session: s}, err
}
// Valid returns true if the session has not expired yet.
// It can be called with a nil receiver.
func (e *ExpiringSession) Valid() bool {
return e != nil && time.Now().Before(e.Expiry)
// Wrapper for an http.ResponseWriter that ensures all tracked
// sessions are saved before the request body is sent.
//
// We have to duplicate the logic to call WriteHeader on the first
// Write, otherwise the underlying ResponseWriter won't call our
// WriteHeader function but its own instead.
type sessionResponseWriter struct {
http.ResponseWriter
headerWritten bool
req *http.Request
}
func init() {
gob.Register(&ExpiringSession{})
func (w *sessionResponseWriter) WriteHeader(statusCode int) {
if statusCode >= 200 && statusCode < 400 {
if err := sessions.Save(w.req, w.ResponseWriter); err != nil {
log.Printf("error saving sessions: %v", err)
}
}
w.ResponseWriter.WriteHeader(statusCode)
w.headerWritten = true
}
func (w *sessionResponseWriter) Write(b []byte) (int, error) {
if !w.headerWritten {
w.WriteHeader(http.StatusOK)
}
return w.ResponseWriter.Write(b)
}
// NewSessionResponseWriter returns a wrapped http.ResponseWriter that
// will always remember to save the Gorilla sessions before writing
// the response body.
func NewSessionResponseWriter(w http.ResponseWriter, req *http.Request) http.ResponseWriter {
return &sessionResponseWriter{
ResponseWriter: w,
req: req,
}
}
func init() {
// Register time.Time with encoding/gob, to ensure that the
// ExpiringSession timestamp can be serialized.
var t time.Time
gob.Register(t)
}
package httputil
import (
"bytes"
"encoding/gob"
"reflect"
"net/http"
"testing"
"time"
"github.com/gorilla/sessions"
)
type mySession struct {
Data string
}
func init() {
gob.Register(&mySession{})
}
func TestExpiringSession(t *testing.T) {
type mySession struct {
*ExpiringSession
Data string
}
s := &mySession{
ExpiringSession: NewExpiringSession(60 * time.Second),
Data: "data",
}
store := sessions.NewCookieStore()
req, _ := http.NewRequest("GET", "http://localhost/", nil)
var buf bytes.Buffer
if err := gob.NewEncoder(&buf).Encode(s); err != nil {
t.Fatal("encode:", err)
httpsess, err := GetExpiringSession(req, store, "testkey", 60*time.Second)
if err != nil {
t.Errorf("store.Get error: %v", err)
}
var s2 mySession
if err := gob.NewDecoder(&buf).Decode(&s2); err != nil {
t.Fatal("decode:", err)
}
if !reflect.DeepEqual(s.Data, s2.Data) {
t.Fatalf("sessions differ: %+v vs %+v", s, &s2)
if _, ok := httpsess.Values["mykey"].(*mySession); ok {
t.Fatal("got a session without any data")
}
}
package httputil
import (
"bytes"
"io/ioutil"
"net/http"
"os"
"time"
)
// StaticContent is an http.Handler that serves in-memory data as if
// it were a static file.
type StaticContent struct {
modtime time.Time
name string
data []byte
}
// LoadStaticContent creates a StaticContent by loading data from a file.
func LoadStaticContent(path string) (*StaticContent, error) {
stat, err := os.Stat(path)
if err != nil {
return nil, err
}
data, err := ioutil.ReadFile(path) // #nosec
if err != nil {
return nil, err
}
return &StaticContent{
name: path,
modtime: stat.ModTime(),
data: data,
}, nil
}
func (c *StaticContent) ServeHTTP(w http.ResponseWriter, req *http.Request) {
http.ServeContent(w, req, c.name, c.modtime, bytes.NewReader(c.data))
}
......@@ -10,6 +10,7 @@ import (
"strings"
"git.autistici.org/id/auth"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
"github.com/mssola/user_agent"
)
......@@ -35,7 +36,7 @@ type Config struct {
}
// New returns a new Manager with the given configuration.
func New(config *Config) (*Manager, error) {
func New(config *Config, urlPrefix string) (*Manager, error) {
if config == nil {
config = &Config{}
}
......@@ -45,9 +46,15 @@ func New(config *Config) (*Manager, error) {
log.Printf("Warning: GeoIP disabled: %v", err)
}
// This should only happen in tests.
if config.AuthKey == "" {
log.Printf("Warning: device_manager.auth_key unset, generating temporary random secrets")
config.AuthKey = string(securecookie.GenerateRandomKey(64))
}
return &Manager{
geodb: geodb,
store: newStore([]byte(config.AuthKey)),
store: newStore([]byte(config.AuthKey), urlPrefix),
}, nil
}
......
......@@ -4,11 +4,11 @@ import "github.com/gorilla/sessions"
const aVeryLongTimeInSeconds = 10 * 365 * 86400
func newStore(authKey []byte) sessions.Store {
func newStore(authKey []byte, urlPrefix string) sessions.Store {
// No encryption, long-term lifetime cookie.
store := sessions.NewCookieStore(authKey, nil)
store.Options = &sessions.Options{
Path: "/",
Path: urlPrefix + "/",
HttpOnly: true,
Secure: true,
MaxAge: aVeryLongTimeInSeconds,
......
......@@ -4,25 +4,20 @@ package server
//go:generate go-bindata --nocompress --pkg server static/... templates/...
import (
"bytes"
"context"
"encoding/gob"
"encoding/json"
"fmt"
"html/template"
"io"
"io/ioutil"
"log"
"net/http"
"net/url"
"os"
"strings"
"time"
assetfs "github.com/elazarl/go-bindata-assetfs"
"github.com/gorilla/csrf"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/rs/cors"
"git.autistici.org/id/auth"
......@@ -31,46 +26,15 @@ import (
"git.autistici.org/id/go-sso/httputil"
"git.autistici.org/id/go-sso/server/device"
"git.autistici.org/id/go-sso/server/login"
)
const authSessionKey = "_auth"
type authSession struct {
*httputil.ExpiringSession
// User name and other information (like group membership).
Username string
UserInfo *auth.UserInfo
// Services the user has logged in to from this session.
Services []string
}
// AddService adds a service to the current session (if it's not
// already there).
func (s *authSession) AddService(service string) {
for _, svc := range s.Services {
if svc == service {
return
}
}
s.Services = append(s.Services, service)
}
// By default, make users log in again after (almost) one day.
var defaultAuthSessionLifetime = 20 * time.Hour
func newAuthSession(ttl time.Duration, username string, userinfo *auth.UserInfo) *authSession {
return &authSession{
ExpiringSession: httputil.NewExpiringSession(ttl),
Username: username,
UserInfo: userinfo,
}
}
// A relatively strict CSP.
const contentSecurityPolicy = "default-src 'none'; img-src 'self' data:; script-src 'self'; style-src 'self'; connect-src 'self';"
func init() {
gob.Register(&authSession{})
}
// Slightly looser CSP for the logout page: it needs to load remote
// images.
const logoutContentSecurityPolicy = "default-src 'none'; img-src *; script-src 'self'; style-src 'self'; connect-src *;"
// Returns the URL of the login handler on the target service.
func serviceLoginCallback(service, destination, token string) string {
......@@ -88,105 +52,122 @@ func serviceLogoutCallback(service string) string {
// Server for the SSO protocol. Provides the HTTP interface to a
// LoginService.
type Server struct {
authSessionStore sessions.Store
authSessionLifetime time.Duration
loginHandler *loginHandler
authSessionLifetime int
loginService *LoginService
keystore ksclient.Client
keystoreGroups []string
csrfSecret []byte
renderer *renderer
renderer *httputil.Renderer
urlPrefix string
homepageRedirectURL string
allowedOrigins []string
// User-configurable static data that we serve from memory.
siteLogo *staticContent
siteFavicon *staticContent
}
func sl2bl(sl []string) [][]byte {
var out [][]byte
for _, s := range sl {
out = append(out, []byte(s))
}
return out
handler http.Handler
}
// New returns a new Server.
func New(loginService *LoginService, authClient authclient.Client, config *Config) (*Server, error) {
urlPrefix := strings.TrimRight(config.URLPrefix, "/")
sessionSecrets := sl2bl(config.SessionSecrets)
store := sessions.NewCookieStore(sessionSecrets...)
store.Options = &sessions.Options{
HttpOnly: true,
Secure: true,
MaxAge: 0,
Path: urlPrefix + "/",
}
renderer := newRenderer(config)
s := &Server{
authSessionLifetime: defaultAuthSessionLifetime,
authSessionStore: store,
renderer := httputil.NewRenderer(
parseEmbeddedTemplates(),
map[string]interface{}{
"URLPrefix": urlPrefix,
"AccountRecoveryURL": config.AccountRecoveryURL,
"SiteName": config.SiteName,
"SiteLogo": config.SiteLogo,
"SiteFavicon": config.SiteFavicon,
},
)
h := &Server{
loginService: loginService,
urlPrefix: urlPrefix,
homepageRedirectURL: config.HomepageRedirectURL,
allowedOrigins: config.AllowedCORSOrigins,
authSessionLifetime: config.AuthSessionLifetimeSeconds,
renderer: renderer,
}
if config.CSRFSecret != "" {
s.csrfSecret = []byte(config.CSRFSecret)
if config.KeyStore != nil {
ks, err := ksclient.New(config.KeyStore)
if err != nil {
return nil, err
}
log.Printf("keystore client enabled")
h.keystore = ks
h.keystoreGroups = config.KeyStoreEnableGroups
}
if config.AuthSessionLifetimeSeconds > 0 {