Skip to content
Snippets Groups Projects
Select Git revision
  • master default protected
  • redirect-307
2 results

http.go

Blame
  • Forked from id / go-sso
    146 commits behind the upstream repository.
    http.go 16.17 KiB
    package server
    
    //go:generate python sri.py templates/*.html
    //go:generate go-bindata --nocompress --pkg server static/... templates/...
    
    import (
    	"bytes"
    	"context"
    	"encoding/gob"
    	"encoding/json"
    	"fmt"
    	"html/template"
    	"io"
    	"io/ioutil"
    	"log"
    	"net/http"
    	"net/url"
    	"os"
    	"strings"
    	"time"
    
    	assetfs "github.com/elazarl/go-bindata-assetfs"
    	"github.com/gorilla/csrf"
    	"github.com/gorilla/mux"
    	"github.com/gorilla/sessions"
    
    	"git.autistici.org/id/auth"
    	authclient "git.autistici.org/id/auth/client"
    	ksclient "git.autistici.org/id/keystore/client"
    
    	"git.autistici.org/id/go-sso/httputil"
    	"git.autistici.org/id/go-sso/server/device"
    )
    
    // authSessionKey is the name of the session (obtained from the Server's
    // authSessionStore) whose "data" value holds the user's *authSession.
    const authSessionKey = "_auth"
    
    // authSession is the per-user SSO state kept in the session cookie,
    // stored under the "data" key of the authSessionKey session. withAuth
    // only accepts it while Valid() reports true (Valid presumably comes from
    // the embedded httputil.ExpiringSession — confirm there). The type is
    // registered with encoding/gob in init, and only exported fields are
    // encoded by gob.
    type authSession struct {
    	*httputil.ExpiringSession

    	// User name and other information (like group membership).
    	// UserInfo may be nil; the code in this file checks for that.
    	Username string
    	UserInfo *auth.UserInfo

    	// Services the user has logged in to from this session, in the order
    	// they were first added (see AddService). handleLogout turns these
    	// into the list of service logout URLs.
    	Services []string
    }
    
    // AddService records that the user logged in to service from this
    // session. Each service appears at most once; the first-seen order is
    // kept.
    func (s *authSession) AddService(service string) {
    	known := false
    	for i := range s.Services {
    		if s.Services[i] == service {
    			known = true
    			break
    		}
    	}
    	if !known {
    		s.Services = append(s.Services, service)
    	}
    }
    
    // By default, make users log in again after (almost) one day. New replaces
    // this with Config.AuthSessionLifetimeSeconds when that is positive. The
    // resulting lifetime is also the TTL passed to the keystore when it is
    // unlocked at login.
    var defaultAuthSessionLifetime = 20 * time.Hour
    
    // newAuthSession builds the authSession for a freshly authenticated user.
    // ttl is handed to httputil.NewExpiringSession, which bounds how long the
    // session is considered valid.
    func newAuthSession(ttl time.Duration, username string, userinfo *auth.UserInfo) *authSession {
    	sess := &authSession{
    		Username: username,
    		UserInfo: userinfo,
    	}
    	sess.ExpiringSession = httputil.NewExpiringSession(ttl)
    	return sess
    }
    
    // Register *authSession with encoding/gob: it is stored as a value in the
    // session's interface-typed Values map, which gob can only encode for
    // registered concrete types.
    func init() {
    	gob.Register(new(authSession))
    }
    
    // serviceLoginCallback returns the URL of the sso_login handler on the
    // target service, carrying the signed token (t) and the final destination
    // (d) as query parameters. service is concatenated directly with the
    // handler name, so it is expected to end with "/" (e.g.
    // "host.example.com/") — NOTE(review): confirm against LoginService.
    func serviceLoginCallback(service, destination, token string) string {
    	query := url.Values{}
    	query.Set("d", destination)
    	query.Set("t", token)
    	return fmt.Sprintf("https://%s%s?%s", service, "sso_login", query.Encode())
    }
    
    // serviceLogoutCallback returns the URL of the sso_logout handler on the
    // target service. As with serviceLoginCallback, service is expected to
    // end with "/".
    func serviceLogoutCallback(service string) string {
    	return fmt.Sprintf("https://%s%s", service, "sso_logout")
    }
    
    // Server for the SSO protocol. Provides the HTTP interface to a
    // LoginService. Build one with New and serve the result of Handler.
    type Server struct {
    	// Cookie-based store holding the authSession, and the lifetime given
    	// to new auth sessions (also used as the keystore unlock TTL).
    	authSessionStore    sessions.Store
    	authSessionLifetime time.Duration
    	// Interactive login flow; it calls loginCallback on success.
    	loginHandler        *loginHandler
    	loginService        *LoginService
    	// Optional keystore client (nil when disabled) and the groups whose
    	// members get their keystore unlocked at login (empty means everyone).
    	keystore            ksclient.Client
    	keystoreGroups      []string
    	// When non-nil, the login/logout pages are wrapped in CSRF protection.
    	csrfSecret          []byte
    	renderer            *renderer
    	// URL prefix without a trailing slash, prepended to every route.
    	urlPrefix           string
    	// Where to send users who reach the homepage without SSO parameters;
    	// when empty, an error page is shown instead.
    	homepageRedirectURL string

    	// User-configurable static data that we serve from memory.
    	siteLogo    *staticContent
    	siteFavicon *staticContent
    }
    
    // sl2bl converts a list of strings into a list of byte slices (used for
    // the session secrets). Empty or nil input yields a nil result.
    func sl2bl(sl []string) [][]byte {
    	if len(sl) == 0 {
    		return nil
    	}
    	out := make([][]byte, len(sl))
    	for i := range sl {
    		out[i] = []byte(sl[i])
    	}
    	return out
    }
    
    // New returns a new Server that exposes loginService over HTTP, configured
    // from config.
    //
    // It sets up the cookie session store (keyed by config.SessionSecrets), the
    // optional CSRF secret, the auth session lifetime, the optional site logo
    // and favicon, the optional keystore client and the device manager, and
    // finally wires up the login handler. It returns an error if a branding
    // file cannot be loaded or if the keystore client or device manager cannot
    // be created.
    func New(loginService *LoginService, authClient authclient.Client, config *Config) (*Server, error) {
    	prefix := strings.TrimRight(config.URLPrefix, "/")
    	secrets := sl2bl(config.SessionSecrets)

    	// Session cookies are hidden from scripts, HTTPS-only and scoped to our
    	// URL prefix. MaxAge 0 makes them browser-session cookies; the stored
    	// authSession carries its own TTL (see newAuthSession).
    	store := sessions.NewCookieStore(secrets...)
    	store.Options = &sessions.Options{
    		Path:     prefix + "/",
    		MaxAge:   0,
    		Secure:   true,
    		HttpOnly: true,
    	}

    	lifetime := defaultAuthSessionLifetime
    	if config.AuthSessionLifetimeSeconds > 0 {
    		lifetime = time.Duration(config.AuthSessionLifetimeSeconds) * time.Second
    	}

    	rend := newRenderer(config)
    	srv := &Server{
    		authSessionStore:    store,
    		authSessionLifetime: lifetime,
    		loginService:        loginService,
    		renderer:            rend,
    		urlPrefix:           prefix,
    		homepageRedirectURL: config.HomepageRedirectURL,
    	}
    	if config.CSRFSecret != "" {
    		srv.csrfSecret = []byte(config.CSRFSecret)
    	}

    	// Optional branding, loaded into memory once.
    	if path := config.SiteLogo; path != "" {
    		logo, err := loadStaticContent(path)
    		if err != nil {
    			return nil, err
    		}
    		srv.siteLogo = logo
    	}
    	if path := config.SiteFavicon; path != "" {
    		icon, err := loadStaticContent(path)
    		if err != nil {
    			return nil, err
    		}
    		srv.siteFavicon = icon
    	}

    	if ksConfig := config.KeyStore; ksConfig != nil {
    		client, err := ksclient.New(ksConfig)
    		if err != nil {
    			return nil, err
    		}
    		log.Printf("keystore client enabled")
    		srv.keystore = client
    		srv.keystoreGroups = config.KeyStoreEnableGroups
    	}

    	devices, err := device.New(config.DeviceManager)
    	if err != nil {
    		return nil, err
    	}
    	srv.loginHandler = newLoginHandler(srv.loginCallback, devices, authClient,
    		config.AuthService, config.U2FAppID, config.URLPrefix,
    		rend, secrets...)

    	return srv, nil
    }
    
    // inAnyGroups reports whether groups and ref share at least one element,
    // i.e. whether a user with the given groups belongs to any of the
    // reference groups. Either list being empty yields false.
    func inAnyGroups(groups, ref []string) bool {
    	for _, g := range groups {
    		for _, r := range ref {
    			if r == g {
    				return true
    			}
    		}
    	}
    	return false
    }
    
    // maybeUnlockKeystore opens the user's keystore with the password that was
    // just used to authenticate, with a TTL equal to the auth session lifetime.
    // The boolean result reports whether an unlock was attempted; the error is
    // the one returned by the keystore, if any.
    //
    // Nothing happens when no keystore client is configured. When
    // keystore_enable_groups is set, the keystore is only unlocked for users
    // whose userinfo is known and who belong to at least one of those groups.
    // When it is not set, the keystore is unlocked for every user.
    //
    // The shard is taken from userinfo whenever it is available, so that it
    // matches the shard handleLogout uses when closing the keystore.
    // Previously it was left empty when no groups were configured, so Open
    // and Close could address different shards.
    func (h *Server) maybeUnlockKeystore(ctx context.Context, username, password string, userinfo *auth.UserInfo) (bool, error) {
    	if h.keystore == nil {
    		return false, nil
    	}

    	if len(h.keystoreGroups) > 0 {
    		if userinfo == nil || !inAnyGroups(userinfo.Groups, h.keystoreGroups) {
    			return false, nil
    		}
    	}

    	var shard string
    	if userinfo != nil {
    		shard = userinfo.Shard
    	}
    	ttl := int(h.authSessionLifetime.Seconds())
    	return true, h.keystore.Open(ctx, shard, username, password, ttl)
    }
    
    // loginCallback runs after a user has successfully authenticated. It
    // unlocks the user's keystore when applicable (a failure aborts the login),
    // logs the successful login, and stores a new authSession in the session
    // cookie.
    func (h *Server) loginCallback(w http.ResponseWriter, req *http.Request, username, password string, userinfo *auth.UserInfo) error {
    	// The keystore is unlocked with the password just used, for as long as
    	// the authenticated session lasts.
    	unlocked, err := h.maybeUnlockKeystore(req.Context(), username, password, userinfo)
    	if err != nil {
    		log.Printf("failed to unlock keystore for user %s: %v", username, err)
    		return err
    	}

    	suffix := ""
    	if unlocked {
    		suffix = " (key unlocked)"
    	}
    	log.Printf("successful login for user %s%s", username, suffix)

    	authSess := newAuthSession(h.authSessionLifetime, username, userinfo)
    	// The Get error is ignored: the "data" entry is overwritten with the
    	// fresh authSession regardless.
    	httpSession, _ := h.authSessionStore.Get(req, authSessionKey) // nolint
    	httpSession.Values["data"] = authSess
    	return httpSession.Save(req, w)
    }
    
    // redirectToLogin sends the browser (302 Found) to the login URL that the
    // login handler builds for this request.
    func (h *Server) redirectToLogin(w http.ResponseWriter, req *http.Request) {
    	target := h.loginHandler.makeLoginURL(req)
    	http.Redirect(w, req, target, http.StatusFound)
    }
    
    // withAuth wraps f so that it only runs for requests carrying a valid
    // authSession: decoded without error and with Valid() reporting true. For
    // any other request, the "data" entry is removed, the session cookie is
    // expired (MaxAge -1), and authFail is called instead.
    //
    // A cookie that cannot be decoded is treated like a missing session
    // rather than a server error. This covers cookies signed with rotated
    // session secrets and corrupted cookies. Otherwise the user would receive
    // a 500 on every request until the browser discarded the cookie; the
    // cookie is cleared here instead. This matches loginCallback and
    // handleLogout, which also ignore Get errors.
    func (h *Server) withAuth(f func(http.ResponseWriter, *http.Request, *authSession), authFail func(http.ResponseWriter, *http.Request)) http.Handler {
    	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    		httpSession, err := h.authSessionStore.Get(req, authSessionKey)
    		if err != nil {
    			if httpSession == nil {
    				// The store could not provide even a fresh session, so
    				// there is nothing to clear or fall back to.
    				http.Error(w, err.Error(), http.StatusInternalServerError)
    				return
    			}
    			log.Printf("discarding undecodable session: %v", err)
    		}

    		// Only trust the session contents if decoding succeeded.
    		if err == nil {
    			if session, ok := httpSession.Values["data"].(*authSession); ok && session.Valid() {
    				f(w, req, session)
    				return
    			}
    		}

    		// Missing, undecodable or expired session: expire the cookie and
    		// hand over to authFail.
    		httpSession.Options.MaxAge = -1
    		delete(httpSession.Values, "data")
    		if err := httpSession.Save(req, w); err != nil {
    			log.Printf("error saving session: %v", err)
    		}
    		authFail(w, req)
    	})
    }
    
    // handleHomepage is the token signing endpoint. For an authenticated user
    // it signs a token carrying the user's identity for the requested service,
    // records the service in the session, and redirects the browser to the
    // service's sso_login callback with that token.
    //
    // Request parameters: s (service), d (destination), n (nonce) and the
    // optional g (comma-separated groups the service is interested in).
    func (h *Server) handleHomepage(w http.ResponseWriter, req *http.Request, session *authSession) {
    	username := session.Username
    	service := req.FormValue("s")
    	destination := req.FormValue("d")
    	nonce := req.FormValue("n")
    	rawGroups := req.FormValue("g")

    	// Restrict the requested groups to those the user actually has. This
    	// check is only a courtesy to the user (we can fail early with a clear
    	// error). g is untrusted at this point, so the real group ACL must be
    	// enforced by the destination service.
    	var groups []string
    	if rawGroups != "" && session.UserInfo != nil {
    		groups = intersectGroups(strings.Split(rawGroups, ","), session.UserInfo.Groups)
    		if len(groups) == 0 {
    			http.Error(w, "Forbidden", http.StatusForbidden)
    			return
    		}
    	}

    	// Without a service and destination, the user most likely navigated here
    	// directly rather than through an SSO redirect. Send them to the
    	// configured homepage if there is one, otherwise explain the problem.
    	if service == "" || destination == "" {
    		if h.homepageRedirectURL == "" {
    			http.Error(w, "You are not supposed to reach this page directly. Use the back button in your browser instead.", http.StatusBadRequest)
    			return
    		}
    		http.Redirect(w, req, h.homepageRedirectURL, http.StatusFound)
    		return
    	}

    	token, err := h.loginService.Authorize(username, service, destination, nonce, groups)
    	if err != nil {
    		log.Printf("auth error: %v: user=%s service=%s destination=%s nonce=%s groups=%s", err, username, service, destination, nonce, rawGroups)
    		http.Error(w, err.Error(), http.StatusBadRequest)
    		return
    	}

    	// Remember the service so the logout page can list its logout URL, then
    	// persist the sessions loaded during this request.
    	session.AddService(service)
    	if err := sessions.Save(req, w); err != nil {
    		log.Printf("error saving session: %v", err)
    	}

    	http.Redirect(w, req, serviceLoginCallback(service, destination, token), http.StatusFound)
    }
    
    // alreadyLoggedOut answers a logout request that has no valid session
    // with 400 Bad Request.
    func (h *Server) alreadyLoggedOut(w http.ResponseWriter, req *http.Request) {
    	const msg = "You do not seem to be logged in"
    	http.Error(w, msg, http.StatusBadRequest)
    }
    
    // logoutServiceInfo describes one service the user is being logged out
    // of. handleLogout passes a list of these to the logout template both as
    // Go values and JSON-encoded.
    type logoutServiceInfo struct {
    	// URL of the service's sso_logout handler (serviceLogoutCallback).
    	URL  string `json:"url"`
    	// Service identifier, as recorded in authSession.Services.
    	Name string `json:"name"`
    }
    
    // handleLogout ends the user's SSO session. It clears the session cookie,
    // closes the user's keystore (when one is configured) and renders
    // logout.html. The page receives the logout URL of every service used
    // during this session.
    func (h *Server) handleLogout(w http.ResponseWriter, req *http.Request, session *authSession) {
    	var services []logoutServiceInfo
    	for _, name := range session.Services {
    		services = append(services, logoutServiceInfo{Name: name, URL: serviceLogoutCallback(name)})
    	}
    	servicesJSON, _ := json.Marshal(services) // nolint

    	// Drop the local session. Errors are deliberately ignored: the user is
    	// being logged out either way.
    	httpSession, _ := h.authSessionStore.Get(req, authSessionKey) // nolint
    	delete(httpSession.Values, "data")
    	httpSession.Options.MaxAge = -1
    	httpSession.Save(req, w) // nolint

    	// Wipe the user's key from the keystore.
    	if h.keystore != nil {
    		shard := ""
    		if info := session.UserInfo; info != nil {
    			shard = info.Shard
    		}
    		if err := h.keystore.Close(req.Context(), shard, session.Username); err != nil {
    			log.Printf("failed to wipe keystore for user %s: %v", session.Username, err)
    		}
    	}

    	// The logout page gets the looser CSP, which allows remote images and
    	// connections.
    	w.Header().Set("Content-Security-Policy", logoutContentSecurityPolicy)

    	page, err := h.renderer.Render(req, "logout.html", map[string]interface{}{
    		"Services":             services,
    		"ServicesJSON":         string(servicesJSON),
    		"IncludeLogoutScripts": true,
    	})
    	if err != nil {
    		log.Printf("template error in logout(): %v", err)
    		http.Error(w, err.Error(), http.StatusInternalServerError)
    		return
    	}
    	w.Write(page) // nolint
    }
    
    // handleExchange trades a token issued to one service (cur_tkt, cur_svc,
    // cur_nonce) for a new token for another service (new_svc, new_nonce),
    // using LoginService.Exchange.
    //
    // It replies 403 when the exchange is not authorized and 400 on any other
    // error. Otherwise it writes the new token as text/plain. It is served
    // without a session cookie check and without CSRF protection (see
    // Handler).
    func (h *Server) handleExchange(w http.ResponseWriter, req *http.Request) {
    	var (
    		curToken   = req.FormValue("cur_tkt")
    		curService = req.FormValue("cur_svc")
    		curNonce   = req.FormValue("cur_nonce")
    		newService = req.FormValue("new_svc")
    		newNonce   = req.FormValue("new_nonce")
    	)

    	token, err := h.loginService.Exchange(curToken, curService, curNonce, newService, newNonce)
    	if err == ErrUnauthorized {
    		log.Printf("unauthorized exchange request (%s -> %s)", curService, newService)
    		http.Error(w, "Forbidden", http.StatusForbidden)
    		return
    	}
    	if err != nil {
    		log.Printf("exchange error (%s -> %s): %v", curService, newService, err)
    		http.Error(w, err.Error(), http.StatusBadRequest)
    		return
    	}

    	w.Header().Set("Content-Type", "text/plain")
    	io.WriteString(w, token) // nolint
    }
    
    // urlFor returns path with the server's URL prefix prepended. The prefix
    // has no trailing slash (New trims it), and callers in this file pass
    // paths that start with "/".
    func (h *Server) urlFor(path string) string {
    	prefixed := h.urlPrefix + path
    	return prefixed
    }
    
    // Handler returns the http.Handler for the SSO server application.
    //
    // Routing is done by a gorilla/mux.Router because sessions depend on it.
    // Routes are matched in registration order:
    //   1. the optional site logo and favicon;
    //   2. the embedded static assets;
    //   3. a catch-all for every other path under the URL prefix.
    //
    // The URL prefix is not removed with a StripPrefix in front of
    // everything, because handlers need the full request URL. Instead, urlFor
    // prepends it to every registered path.
    func (h *Server) Handler() http.Handler {
    	router := mux.NewRouter()

    	// Operator-supplied branding, served from well-known URLs.
    	if h.siteLogo != nil {
    		router.Handle(h.urlFor("/img/site_logo"), h.siteLogo)
    	}
    	if h.siteFavicon != nil {
    		router.Handle(h.urlFor("/favicon.ico"), h.siteFavicon)
    	}

    	// Embedded static assets, available to anyone.
    	assetPrefix := h.urlFor("/static/")
    	assets := &assetfs.AssetFS{
    		Asset:     Asset,
    		AssetDir:  AssetDir,
    		AssetInfo: AssetInfo,
    		Prefix:    "static",
    	}
    	router.PathPrefix(assetPrefix).Handler(http.StripPrefix(assetPrefix, http.FileServer(assets)))

    	// Interactive IDP pages (login, logout), optionally CSRF-protected.
    	idpMux := http.NewServeMux()
    	idpMux.Handle(h.urlFor("/login"), h.loginHandler)
    	idpMux.Handle(h.urlFor("/logout"), h.withAuth(h.handleLogout, h.alreadyLoggedOut))
    	var idp http.Handler = idpMux
    	if h.csrfSecret != nil {
    		idp = csrf.Protect(h.csrfSecret)(idp)
    	}

    	// SSO provider endpoints (GET on the root path, and /exchange) bypass
    	// CSRF protection. Dispatch is done by hand because the standard
    	// ServeMux treats "/" as a catch-all pattern.
    	sso := h.withAuth(h.handleHomepage, h.redirectToLogin)
    	dispatch := func(w http.ResponseWriter, r *http.Request) {
    		if r.Method == "GET" && r.URL.Path == h.urlFor("/") {
    			sso.ServeHTTP(w, r)
    			return
    		}
    		if r.URL.Path == h.urlFor("/exchange") {
    			h.handleExchange(w, r)
    			return
    		}
    		idp.ServeHTTP(w, r)
    	}

    	// Every user-facing route gets cache-busting and security headers.
    	router.PathPrefix(h.urlFor("/")).Handler(withDynamicHeaders(http.HandlerFunc(dispatch)))

    	return router
    }
    
    // A relatively strict CSP, applied by withDynamicHeaders to every
    // user-facing page unless a handler sets its own.
    const contentSecurityPolicy = "default-src 'none'; img-src 'self' data:; script-src 'self'; style-src 'self'; connect-src 'self';"

    // Slightly looser CSP for the logout page: it needs to load remote
    // images, and connect-src is also unrestricted. This is presumably so
    // the page can reach the services' logout URLs — confirm against
    // logout.html.
    const logoutContentSecurityPolicy = "default-src 'none'; img-src *; script-src 'self'; style-src 'self'; connect-src *;"
    
    // withDynamicHeaders wraps h so that every response carries headers
    // suited to dynamic, user-specific pages:
    //   - caching is disabled;
    //   - framing is denied (clickjacking protection);
    //   - the legacy XSS filter and nosniff are enabled;
    //   - the default Content-Security-Policy is applied.
    // The headers are set before h runs, so h can still override them;
    // handleLogout does this for the CSP.
    func withDynamicHeaders(h http.Handler) http.Handler {
    	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    		hdr := w.Header()
    		hdr.Set("Pragma", "no-cache")
    		hdr.Set("Cache-Control", "no-store")
    		hdr.Set("Expires", "-1")
    		// X-Frame-Options only understands DENY and SAMEORIGIN. Any other
    		// value (e.g. "NONE") is ignored by browsers, which would leave the
    		// login pages frameable.
    		hdr.Set("X-Frame-Options", "DENY")
    		hdr.Set("X-XSS-Protection", "1; mode=block")
    		hdr.Set("X-Content-Type-Options", "nosniff")
    		if hdr.Get("Content-Security-Policy") == "" {
    			hdr.Set("Content-Security-Policy", contentSecurityPolicy)
    		}
    		h.ServeHTTP(w, r)
    	})
    }
    
    // parseEmbeddedTemplates builds a template set from every file in the
    // embedded "templates" asset directory (see bindata.go). Each template is
    // named after its file name and can use the "json" helper function. Any
    // failure is fatal (log.Fatalf), since the server cannot run without its
    // templates.
    func parseEmbeddedTemplates() *template.Template {
    	root := template.New("").Funcs(template.FuncMap{"json": toJSON})

    	names, err := AssetDir("templates")
    	if err != nil {
    		log.Fatalf("no asset dir for templates: %v", err)
    	}

    	for _, name := range names {
    		src, err := Asset("templates/" + name)
    		if err != nil {
    			log.Fatalf("could not read embedded template %s: %v", name, err)
    		}
    		if _, err = root.New(name).Parse(string(src)); err != nil {
    			log.Fatalf("error parsing template %s: %v", name, err)
    		}
    	}
    	return root
    }
    
    // renderer executes the embedded HTML templates, adding site-wide values
    // (URL prefix, branding, account recovery link, CSRF field) to the data of
    // every rendered page (see Render).
    type renderer struct {
    	tpl *template.Template

    	// urlPrefix has no trailing slash. The remaining fields are copied
    	// verbatim from the corresponding Config settings.
    	urlPrefix          string
    	siteName           string
    	siteLogo           string
    	siteFavicon        string
    	accountRecoveryURL string
    }
    
    // newRenderer parses the embedded templates and captures the site-wide
    // settings from config that every rendered page receives.
    func newRenderer(config *Config) *renderer {
    	r := new(renderer)
    	r.tpl = parseEmbeddedTemplates()
    	r.urlPrefix = strings.TrimRight(config.URLPrefix, "/")
    	r.siteName = config.SiteName
    	r.siteLogo = config.SiteLogo
    	r.siteFavicon = config.SiteFavicon
    	r.accountRecoveryURL = config.AccountRecoveryURL
    	return r
    }
    
    // Render executes the template templateName and returns the output.
    //
    // Before rendering, it adds these common page values to data, overwriting
    // any existing entries with the same keys: CSRFField, URLPrefix,
    // AccountRecoveryURL, SiteName, SiteLogo and SiteFavicon. A non-nil data
    // map is modified in place. data may be nil, in which case a fresh map is
    // used.
    //
    // Output is buffered, so nothing is returned if template execution fails
    // midway.
    func (r *renderer) Render(req *http.Request, templateName string, data map[string]interface{}) ([]byte, error) {
    	if data == nil {
    		// Assigning into a nil map would panic.
    		data = make(map[string]interface{})
    	}
    	data["CSRFField"] = csrf.TemplateField(req)
    	data["URLPrefix"] = r.urlPrefix
    	data["AccountRecoveryURL"] = r.accountRecoveryURL
    	data["SiteName"] = r.siteName
    	data["SiteLogo"] = r.siteLogo
    	data["SiteFavicon"] = r.siteFavicon

    	var buf bytes.Buffer
    	if err := r.tpl.ExecuteTemplate(&buf, templateName, data); err != nil {
    		return nil, err
    	}
    	return buf.Bytes(), nil
    }
    
    // staticContent is a small file (such as the site logo or favicon) that is
    // read into memory once and then served from there.
    type staticContent struct {
    	modtime time.Time
    	name    string
    	data    []byte
    }

    // loadStaticContent reads the file at path into memory. It also records
    // the file's modification time, so that conditional requests can be
    // answered. It returns the Stat or read error, if any.
    func loadStaticContent(path string) (*staticContent, error) {
    	var (
    		info os.FileInfo
    		body []byte
    		err  error
    	)
    	if info, err = os.Stat(path); err != nil {
    		return nil, err
    	}
    	if body, err = ioutil.ReadFile(path); err != nil { // #nosec -- path comes from the server configuration
    		return nil, err
    	}
    	return &staticContent{data: body, modtime: info.ModTime(), name: path}, nil
    }

    // ServeHTTP serves the in-memory data through http.ServeContent. That
    // function takes care of Range and If-Modified-Since requests and derives
    // Content-Type from the file name's extension, sniffing the data as a
    // fallback.
    func (c *staticContent) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    	reader := bytes.NewReader(c.data)
    	http.ServeContent(w, req, c.name, c.modtime, reader)
    }