diff --git a/cmd/sso-proxy/main.go b/cmd/sso-proxy/main.go new file mode 100644 index 0000000000000000000000000000000000000000..505a601cf245f63d67b29398c49ef13151bcbe43 --- /dev/null +++ b/cmd/sso-proxy/main.go @@ -0,0 +1,81 @@ +package main + +import ( + "context" + "flag" + "io/ioutil" + "log" + "net/http" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "gopkg.in/yaml.v2" + + "git.autistici.org/id/go-sso/proxy" +) + +var ( + addr = flag.String("addr", ":5003", "address to listen on") + configFile = flag.String("config", "/etc/sso/proxy.yml", "path of config file") +) + +func loadConfig() (*proxy.Configuration, error) { + // Read YAML config. + data, err := ioutil.ReadFile(*configFile) + if err != nil { + return nil, err + } + var config proxy.Configuration + if err := yaml.Unmarshal(data, &config); err != nil { + return nil, err + } + return &config, nil +} + +// Set defaults for command-line flags using variables from the environment. +func setFlagDefaultsFromEnv() { + flag.VisitAll(func(f *flag.Flag) { + envVar := "SSOPROXY_" + strings.ToUpper(strings.Replace(f.Name, "-", "_", -1)) + if value := os.Getenv(envVar); value != "" { + f.DefValue = value + f.Value.Set(value) + } + }) +} + +func main() { + setFlagDefaultsFromEnv() + flag.Parse() + + config, err := loadConfig() + if err != nil { + log.Fatal(err) + } + + h, err := proxy.NewProxy(config) + if err != nil { + log.Fatal(err) + } + + srv := &http.Server{ + Addr: *addr, + Handler: h, + } + + sigCh := make(chan os.Signal, 1) + go func() { + <-sigCh + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = srv.Shutdown(ctx) + _ = srv.Close() + }() + signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) + + if err := srv.ListenAndServe(); err != nil { + log.Fatal(err) + } +} diff --git a/debian/control b/debian/control index 540ba3694b302937f3d0d850f2953d24c4fb4a0a..d011af6a3c72a2943b684b9c52dd9653dcf4ec36 100644 --- a/debian/control +++ b/debian/control @@ -10,3 +10,9 @@ Architecture: any Depends: ${shlibs:Depends}, ${misc:Depends}, auth-server Description: Single-Sign-On server. Single-Sign-On server, integrated with git.autistici.org/id/auth. + +Package: sso-proxy +Architecture: any +Depends: ${shlibs:Depends}, ${misc:Depends} +Description: Single-Sign-On HTTP proxy. + Single-Sign-On HTTP proxy. diff --git a/debian/rules b/debian/rules index f6dd9d5e741b1cbed836c1e76b1ccd3a40d7c8cf..e940d36ce92728bcec9f6237f1959b0919e4c898 100755 --- a/debian/rules +++ b/debian/rules @@ -8,7 +8,7 @@ export DH_GOLANG_EXCLUDES = vendor dh $@ --with systemd --with golang --buildsystem golang override_dh_install: - rm -fr $(CURDIR)/debian/sso-server/usr/share/gocode + rm -fr $(CURDIR)/debian/tmp/usr/share/gocode dh_install override_dh_systemd_enable: diff --git a/debian/sso-proxy.default b/debian/sso-proxy.default new file mode 100644 index 0000000000000000000000000000000000000000..3f67963124e7b2fd6c5c48389a26f23b23e21769 --- /dev/null +++ b/debian/sso-proxy.default @@ -0,0 +1 @@ +ADDR=:5003 diff --git a/debian/sso-proxy.install b/debian/sso-proxy.install new file mode 100644 index 0000000000000000000000000000000000000000..19078fd89eed587c3af21189aedc2ac2bfceee8b --- /dev/null +++ b/debian/sso-proxy.install @@ -0,0 +1 @@ +usr/bin/sso-proxy diff --git a/debian/sso-proxy.postinst b/debian/sso-proxy.postinst new file mode 100755 index 0000000000000000000000000000000000000000..73768438ac8522fd4dbf496e61d47f5f3dd66ae6 --- /dev/null +++ b/debian/sso-proxy.postinst @@ -0,0 +1,16 @@ +#!/bin/sh + +set -e + +case "$1" in +configure) + addgroup --system --quiet sso-proxy + adduser --system --no-create-home --home /run/sso-proxy \ + --disabled-password --disabled-login \ + --quiet --ingroup sso-proxy sso-proxy + ;; +esac + +#DEBHELPER# + +exit 0 diff --git a/debian/sso-proxy.service b/debian/sso-proxy.service new file mode 100644 index 0000000000000000000000000000000000000000..382c65ae8a32fab8abb8934c00bbcb2fb7836834 --- /dev/null +++ b/debian/sso-proxy.service @@ -0,0 +1,13 @@ +[Unit] +Description=SSO Proxy + +[Service] +User=sso-proxy +Group=sso-proxy +EnvironmentFile=-/etc/default/sso-proxy +ExecStart=/usr/bin/sso-proxy --addr $ADDR +Restart=always + +[Install] +WantedBy=multi-user.target + diff --git a/debian/sso-server.install b/debian/sso-server.install new file mode 100644 index 0000000000000000000000000000000000000000..91096c7b0ea2d10cf3ccaa67711120f2957141d0 --- /dev/null +++ b/debian/sso-server.install @@ -0,0 +1 @@ +usr/bin/sso-server diff --git a/debian/postinst b/debian/sso-server.postinst similarity index 100% rename from debian/postinst rename to debian/sso-server.postinst diff --git a/httpsso/handler.go b/httpsso/handler.go new file mode 100644 index 0000000000000000000000000000000000000000..75a61ebedd412c260f3817eceaf4c8b498cc46b7 --- /dev/null +++ b/httpsso/handler.go @@ -0,0 +1,168 @@ +package httpsso + +import ( + "encoding/gob" + "encoding/hex" + "io" + "math/rand" + "net/http" + "net/url" + "strings" + "time" + + "github.com/gorilla/sessions" + + "git.autistici.org/id/go-sso" + "git.autistici.org/id/go-sso/httputil" +) + +type authSession struct { + *httputil.ExpiringSession + + Auth bool + Username string +} + +var authSessionLifetime = 1 * time.Hour + +func init() { + gob.Register(&authSession{}) +} + +// SSOWrapper protects http handlers with single-sign-on authentication. +type SSOWrapper struct { + v sso.Validator + sessionAuthKey []byte + sessionEncKey []byte + serverURL string +} + +// NewSSOWrapper returns a new SSOWrapper that will authenticate users +// on the specified login service. +func NewSSOWrapper(serverURL string, pkey []byte, domain string, sessionAuthKey, sessionEncKey []byte) (*SSOWrapper, error) { + v, err := sso.NewValidator(pkey, domain) + if err != nil { + return nil, err + } + + return &SSOWrapper{ + v: v, + serverURL: serverURL, + sessionAuthKey: sessionAuthKey, + sessionEncKey: sessionEncKey, + }, nil +} + +// Wrap a http.Handler with authentication and access control. +// Currently only a simple form of group-based ACLs is supported. +func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.Handler { + svcPath := pathFromService(service) + store := sessions.NewCookieStore(s.sessionAuthKey, s.sessionEncKey) + store.Options = &sessions.Options{ + HttpOnly: true, + Secure: true, + MaxAge: 0, + Path: svcPath, + } + + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + session, _ := store.Get(req, "sso") + + switch strings.TrimPrefix(req.URL.Path, svcPath) { + case "sso_login": + s.handleLogin(w, req, session, service, groups) + + case "sso_logout": + s.handleLogout(w, req, session) + + default: + if auth, ok := session.Values["a"].(*authSession); ok && auth.Valid() && auth.Auth { + req.Header.Set("X-Authenticated-User", auth.Username) + + h.ServeHTTP(w, req) + return + } + + s.redirectToLogin(w, req, session, service, groups) + } + }) +} + +func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, session *sessions.Session, service string, groups []string) { + t := req.FormValue("t") + d := req.FormValue("d") + + // Pop the nonce from the session. + nonce, ok := session.Values["nonce"].(string) + if !ok || nonce == "" { + http.Error(w, "Missing nonce", http.StatusBadRequest) + return + } + delete(session.Values, "nonce") + + tkt, err := s.v.Validate(t, nonce, service, groups) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // Authenticate the user. + session.Values["a"] = &authSession{ + ExpiringSession: httputil.NewExpiringSession(authSessionLifetime), + Auth: true, + Username: tkt.User, + } + if err := sessions.Save(req, w); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + http.Redirect(w, req, d, http.StatusFound) +} + +func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, session *sessions.Session) { + session.Options.MaxAge = -1 + if err := sessions.Save(req, w); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "text/plain") + io.WriteString(w, "OK") +} + +// Redirect to the SSO server. +func (s *SSOWrapper) redirectToLogin(w http.ResponseWriter, req *http.Request, session *sessions.Session, service string, groups []string) { + // Generate a random nonce and store it in the local session. + nonce := makeUniqueNonce() + session.Values["nonce"] = nonce + if err := sessions.Save(req, w); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + v := make(url.Values) + v.Set("s", service) + v.Set("d", req.URL.String()) + v.Set("n", nonce) + v.Set("g", strings.Join(groups, ",")) + loginURL := s.serverURL + "?" + v.Encode() + http.Redirect(w, req, loginURL, http.StatusFound) +} + +// Extract the URL path from the service specification. The result +// will have both a leading and a trailing slash. +func pathFromService(service string) string { + i := strings.IndexRune(service, '/') + if i < 0 { + return "" + } + return service[i:] +} + +func makeUniqueNonce() string { + var b [8]byte + if _, err := rand.Read(b[:]); err != nil { + panic(err) + } + return hex.EncodeToString(b[:]) +} diff --git a/httputil/session.go b/httputil/session.go new file mode 100644 index 0000000000000000000000000000000000000000..f1847ba98e563371f0188afd191550a224ac98e5 --- /dev/null +++ b/httputil/session.go @@ -0,0 +1,35 @@ +package httputil + +import ( + "encoding/gob" + "time" +) + +// ExpiringSession is a session with server-side expiration check. +// Session data is saved in signed, encrypted cookies in the +// browser. We'd like these cookies to expire when a certain amount of +// time passes, or when the user closes the browser. We trust the +// browser for the latter, but we enforce time-based expiration on the +// server. +type ExpiringSession struct { + Expiry time.Time +} + +// NewExpiringSession returns a session that is valid for the given +// duration. +func NewExpiringSession(ttl time.Duration) *ExpiringSession { + return &ExpiringSession{ + Expiry: time.Now().Add(ttl), + } +} + +// Valid returns true if the session has not expired yet. +// It can be called with a nil receiver. +func (e *ExpiringSession) Valid() bool { + return e != nil && time.Now().Before(e.Expiry) +} + +func init() { + gob.Register(&ExpiringSession{}) +} + diff --git a/server/util_test.go b/httputil/session_test.go similarity index 87% rename from server/util_test.go rename to httputil/session_test.go index 3b8c773c91bca06f0b13bbf31284bc275df80a29..df4ca4696c4d6d4b2322353a5acf2707a320dac7 100644 --- a/server/util_test.go +++ b/httputil/session_test.go @@ -1,4 +1,4 @@ -package server +package httputil import ( "bytes" @@ -14,7 +14,7 @@ func TestExpiringSession(t *testing.T) { Data string } s := &mySession{ - ExpiringSession: newExpiringSession(60 * time.Second), + ExpiringSession: NewExpiringSession(60 * time.Second), Data: "data", } diff --git a/proxy/proxy.go b/proxy/proxy.go new file mode 100644 index 0000000000000000000000000000000000000000..f7aa72f06f4529ffc8ec3c178d02546dfb7363d3 --- /dev/null +++ b/proxy/proxy.go @@ -0,0 +1,182 @@ +package proxy + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" + "net/url" + + "github.com/gorilla/mux" + + "git.autistici.org/id/go-sso/httpsso" +) + +// Backend defines a single-host HTTP proxy to a set of upstream +// backends. +type Backend struct { + Host string `yaml:"host"` + Upstream []string `yaml:"upstream"` + ClientTLSConfig *TLSConfig `yaml:"client_tls"` + //ServerTLSConfig *TLSConfig `yaml:"server_tls"` + + AllowedGroups []string `yaml:"allowed_groups"` +} + +func (b *Backend) newHandler(ssow *httpsso.SSOWrapper) (http.Handler, error) { + // Setup upstream connections. + if len(b.Upstream) < 1 { + return nil, errors.New("no backends specified") + } + + u := &url.URL{Scheme: "http", Host: b.Host} + if b.ClientTLSConfig != nil { + u.Scheme = "https" + } + + proxy := httputil.NewSingleHostReverseProxy(u) + + var tlsConfig *tls.Config + if b.ClientTLSConfig != nil { + var err error + tlsConfig, err = b.ClientTLSConfig.toClientConfig() + if err != nil { + return nil, err + } + } + proxy.Transport = newTransport(b.Upstream, tlsConfig) + + h := ssow.Wrap(proxy, b.Host+"/", b.AllowedGroups) + return h, nil +} + +// TLSConfig defines the TLS parameters for a client connection. +type TLSConfig struct { + Cert string `yaml:"cert"` + Key string `yaml:"key"` + CA string `yaml:"ca"` +} + +func (c *TLSConfig) toClientConfig() (*tls.Config, error) { + cert, err := tls.LoadX509KeyPair(c.Cert, c.Key) + if err != nil { + return nil, err + } + + cas, err := loadCA(c.CA) + if err != nil { + return nil, err + } + + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: cas, + }, nil +} + +func loadCA(path string) (*x509.CertPool, error) { + data, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + cas := x509.NewCertPool() + cas.AppendCertsFromPEM(data) + return cas, nil +} + +// func buildServerTLSConfig(config *Configuration) (*tls.Config, error) { +// var certs []tls.Certificate +// for _, b := range config.Backends { +// cert, err := tls.LoadX509KeyPair(b.ServerTLSConfig.Cert, b.ServerTLSConfig.Key) +// if err != nil { +// return nil, err +// } +// certs = append(certs, cert) +// } + +// c := &tls.Config{ +// Certificates: certs, +// } + +// if config.CA != "" { +// cas, err := loadCA(config.CA) +// if err != nil { +// return nil, err +// } +// c.ClientAuth = tls.RequireAndVerifyClientCert +// c.ClientCAs = cas +// } + +// c.BuildNameToCertificate() + +// return c, nil +// } + +// Configuration for the proxy. +type Configuration struct { + SessionAuthKey []byte `yaml:"session_auth_key"` + SessionEncKey []byte `yaml:"session_enc_key"` + CA string `yaml:"ca"` + + SSOLoginServerURL string `yaml:"sso_server_url"` + SSOPublicKeyFile string `yaml:"sso_public_key_file"` + SSODomain string `yaml:"sso_domain"` + + Backends []*Backend `yaml:"backends"` +} + +// Sanity checks for the configuration. +func (c *Configuration) check() error { + switch len(c.SessionAuthKey) { + case 32, 64: + case 0: + return errors.New("session_auth_key is empty") + default: + return errors.New("session_auth_key must be a random string of 32 or 64 bytes") + } + switch len(c.SessionEncKey) { + case 16, 24, 43: + case 0: + return errors.New("session_enc_key is empty") + default: + return errors.New("session_enc_key must be a random string of 16, 24 or 32 bytes") + } + if c.SSOLoginServerURL == "" { + return errors.New("sso_server_url is empty") + } + if c.SSODomain == "" { + return errors.New("sso_domain is empty") + } + return nil +} + +// NewProxy builds a SSO-protected multi-host handler with the +// specified configuration. +func NewProxy(config *Configuration) (http.Handler, error) { + if err := config.check(); err != nil { + return nil, err + } + + pkey, err := ioutil.ReadFile(config.SSOPublicKeyFile) + if err != nil { + return nil, err + } + + w, err := httpsso.NewSSOWrapper(config.SSOLoginServerURL, pkey, config.SSODomain, config.SessionAuthKey, config.SessionEncKey) + if err != nil { + return nil, err + } + + r := mux.NewRouter() + for _, b := range config.Backends { + h, err := b.newHandler(w) + if err != nil { + return nil, fmt.Errorf("error for host %s: %v", b.Host, err) + } + r.Host(b.Host).Handler(h) + } + return r, nil +} diff --git a/proxy/transport.go b/proxy/transport.go new file mode 100644 index 0000000000000000000000000000000000000000..38da38300762a10caddbe27ce27702954454dcc5 --- /dev/null +++ b/proxy/transport.go @@ -0,0 +1,128 @@ +package proxy + +import ( + "crypto/tls" + "errors" + "log" + "math/rand" + "net" + "net/http" + "sort" + "sync" +) + +func resolveIPs(hosts []string) []string { + var resolved []string + for _, hostport := range hosts { + host, port, err := net.SplitHostPort(hostport) + if err != nil { + log.Printf("error parsing %s: %v", hostport, err) + continue + } + hostIPs, err := net.LookupIP(host) + if err != nil { + log.Printf("error resolving %s: %v", host, err) + continue + } + for _, ip := range hostIPs { + resolved = append(resolved, net.JoinHostPort(ip.String(), port)) + } + } + return resolved +} + +type balancer struct { + mx sync.Mutex + ips []string + errs []uint64 +} + +func (b *balancer) incrError(index int) { + b.mx.Lock() + b.errs[index]++ + b.mx.Unlock() +} + +func (b *balancer) dial(network, addr string) (net.Conn, error) { + ips, err := b.pickIPs() + if err != nil { + return nil, err + } + for _, s := range ips { + conn, err := net.Dial(network, s.ip) + if err == nil { + return conn, nil + } + log.Printf("error connecting to %s: %v", s.ip, err) + b.incrError(s.index) + } + return nil, errors.New("all upstream connections failed") +} + +type ipScore struct { + ip string + score int + index int +} + +type ipScoreList []ipScore + +func (l ipScoreList) Len() int { return len(l) } +func (l ipScoreList) Swap(i, j int) { l[i], l[j] = l[j], l[i] } +func (l ipScoreList) Less(i, j int) bool { return l[i].score < l[j].score } + +func shuffleScores(scores []ipScore) { + for i, j := range rand.Perm(len(scores)) { + scores[i], scores[j] = scores[j], scores[i] + } +} + +const minErrs = 3 + +func (b *balancer) pickIPs() ([]ipScore, error) { + b.mx.Lock() + scores := make([]ipScore, len(b.ips)) + for i, ip := range b.ips { + score := 1 + if b.errs[i] > minErrs { + score *= 10 + } + scores[i] = ipScore{ip: ip, score: score, index: i} + } + b.mx.Unlock() + + sort.Sort(ipScoreList(scores)) + + // Iterate through the sorted list, shuffling groups of + // elements that have identical scores. + curScore := scores[0].score + head := 0 + for i := 1; i < len(scores); i++ { + if scores[i].score != curScore { + group := scores[head : i+1] + if len(group) > 1 { + shuffleScores(group) + } + head = i + 1 + } + } + group := scores[head:] + if len(group) > 1 { + shuffleScores(group) + } + + return scores, nil +} + +func newTransport(backends []string, tlsConf *tls.Config) http.RoundTripper { + ips := resolveIPs(backends) + b := &balancer{ + ips: ips, + errs: make([]uint64, len(ips)), + } + + return &http.Transport{ + Dial: b.dial, + TLSClientConfig: tlsConf, + } +} diff --git a/server/http.go b/server/http.go index 244b4ec87ab11817295fcfd62e6dd4521b26d2ae..daae41bece17c2dd0d0636dcc85e9a784bdfc58c 100644 --- a/server/http.go +++ b/server/http.go @@ -21,13 +21,14 @@ import ( "git.autistici.org/id/auth" authclient "git.autistici.org/id/auth/client" + "git.autistici.org/id/go-sso/httputil" "git.autistici.org/id/go-sso/server/device" ) const authSessionKey = "_auth" type authSession struct { - *ExpiringSession + *httputil.ExpiringSession // User name and other information (like group membership). Username string @@ -53,7 +54,7 @@ var defaultAuthSessionLifetime = 20 * time.Hour func newAuthSession(ttl time.Duration, username string, userinfo *auth.UserInfo) *authSession { return &authSession{ - ExpiringSession: newExpiringSession(ttl), + ExpiringSession: httputil.NewExpiringSession(ttl), Username: username, UserInfo: userinfo, } diff --git a/server/login.go b/server/login.go index 847bc2b4efe3bea8246cad15d1f57162c32ede88..f79848f762220ee78c39a00a169a285d6ce58994 100644 --- a/server/login.go +++ b/server/login.go @@ -17,13 +17,14 @@ import ( "git.autistici.org/id/auth" authclient "git.autistici.org/id/auth/client" + "git.autistici.org/id/go-sso/httputil" "git.autistici.org/id/go-sso/server/device" ) const loginSessionKey = "_login" type loginSession struct { - *ExpiringSession + *httputil.ExpiringSession State loginState @@ -46,7 +47,7 @@ var defaultLoginSessionLifetime = 300 * time.Second func newLoginSession() *loginSession { return &loginSession{ - ExpiringSession: newExpiringSession(defaultLoginSessionLifetime), + ExpiringSession: httputil.NewExpiringSession(defaultLoginSessionLifetime), State: loginStatePassword, } } diff --git a/server/util.go b/server/util.go index 179a60ff19e56b1e68951b97545c6567f5e09108..63439fd819cf81487954bc0d5cd574226f665471 100644 --- a/server/util.go +++ b/server/util.go @@ -1,35 +1,5 @@ package server -import ( - "encoding/gob" - "time" -) - -// ExpiringSession is a session with server-side expiration check. -// Session data is saved in signed, encrypted cookies in the -// browser. We'd like these cookies to expire when a certain amount of -// time passes, or when the user closes the browser. We trust the -// browser for the latter, but we enforce time-based expiration on the -// server. -type ExpiringSession struct { - Expiry time.Time -} - -func newExpiringSession(ttl time.Duration) *ExpiringSession { - return &ExpiringSession{ - Expiry: time.Now().Add(ttl), - } -} - -// Valid returns true if the session has not expired yet. -func (e *ExpiringSession) Valid() bool { - return e != nil && time.Now().Before(e.Expiry) -} - -func init() { - gob.Register(&ExpiringSession{}) -} - // Returns the intersection of two string lists (in O(N^2) time). func intersectGroups(a, b []string) []string { var out []string