From f20eaef0ac3be1a169a10c926b72aa2ac9505a61 Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Thu, 1 Aug 2019 13:37:10 +0100
Subject: [PATCH] Add SSO group ACLs to the SAML-bridge provider configuration

One can now restrict providers to specific SSO groups.
---
 saml/saml.go | 77 ++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 62 insertions(+), 15 deletions(-)

diff --git a/saml/saml.go b/saml/saml.go
index a75dffd..42090d9 100644
--- a/saml/saml.go
+++ b/saml/saml.go
@@ -10,6 +10,7 @@ import (
 	"fmt"
 	"io"
 	"io/ioutil"
+	"log"
 	"net/http"
 	"net/url"
 	"os"
@@ -24,6 +25,13 @@ import (
 	"git.autistici.org/id/go-sso/httpsso"
 )
 
+type serviceProvider struct {
+	Descriptor string   `yaml:"descriptor"`
+	SSOGroups  []string `yaml:"sso_groups"`
+
+	parsed *saml.EntityDescriptor
+}
+
 type Config struct {
 	BaseURL string `yaml:"base_url"`
 
@@ -41,8 +49,8 @@ type Config struct {
 	SSODomain         string `yaml:"sso_domain"`
 
 	// Service provider config.
-	ServiceProviders       []string `yaml:"service_providers"`
-	parsedServiceProviders map[string]*saml.EntityDescriptor
+	ServiceProviders   []*serviceProvider `yaml:"service_providers"`
+	serviceProviderMap map[string]*serviceProvider
 }
 
 // Sanity checks for the configuration.
@@ -71,9 +79,9 @@ func (c *Config) check() error {
 }
 
 func (c *Config) loadServiceProviders() error {
-	c.parsedServiceProviders = make(map[string]*saml.EntityDescriptor)
-	for _, path := range c.ServiceProviders {
-		data, err := ioutil.ReadFile(path)
+	c.serviceProviderMap = make(map[string]*serviceProvider)
+	for _, sp := range c.ServiceProviders {
+		data, err := ioutil.ReadFile(sp.Descriptor)
 		if err != nil {
 			return err
 		}
@@ -81,17 +89,26 @@ func (c *Config) loadServiceProviders() error {
 		if err := xml.Unmarshal(data, &ent); err != nil {
 			return err
 		}
-		c.parsedServiceProviders[ent.EntityID] = &ent
+		sp.parsed = &ent
+		c.serviceProviderMap[ent.EntityID] = sp
 	}
 	return nil
 }
 
 func (c *Config) GetServiceProvider(r *http.Request, serviceProviderID string) (*saml.EntityDescriptor, error) {
-	srv, ok := c.parsedServiceProviders[serviceProviderID]
+	sp, ok := c.serviceProviderMap[serviceProviderID]
 	if !ok {
 		return nil, os.ErrNotExist
 	}
-	return srv, nil
+	return sp.parsed, nil
+}
+
+func (c *Config) GetSSOGroups(serviceProviderID string) []string {
+	sp, ok := c.serviceProviderMap[serviceProviderID]
+	if !ok {
+		return nil
+	}
+	return sp.SSOGroups
 }
 
 // Read users from a YAML-encoded file, in a format surprisingly
@@ -106,11 +123,12 @@ type userInfo struct {
 }
 
 type userFileBackend struct {
-	users map[string]userInfo
+	config *Config
+	users  map[string]userInfo
 }
 
-func newUserFileBackend(path string) (*userFileBackend, error) {
-	data, err := ioutil.ReadFile(path)
+func newUserFileBackend(config *Config) (*userFileBackend, error) {
+	data, err := ioutil.ReadFile(config.UsersFile)
 	if err != nil {
 		return nil, err
 	}
@@ -122,22 +140,51 @@ func newUserFileBackend(path string) (*userFileBackend, error) {
 	for _, u := range userList {
 		users[u.Name] = u
 	}
-	return &userFileBackend{users}, nil
+	return &userFileBackend{
+		config: config,
+		users:  users,
+	}, nil
+}
+
+// Nice little O(N^2) algorithm right there...
+func matchGroups(user, exp []string) bool {
+	if exp == nil {
+		return true
+	}
+	for _, ug := range user {
+		for _, eg := range exp {
+			if ug == eg {
+				return true
+			}
+		}
+	}
+	return false
 }
 
 func (b *userFileBackend) GetSession(w http.ResponseWriter, r *http.Request, req *saml.IdpAuthnRequest) *saml.Session {
-	// The request should have the X-Authenticated-User header.
-	username := r.Header.Get("X-Authenticated-User")
+	// Check for authentication by verifying the SSO username. We
+	// also need to be able to retrieve user information from the
+	// backend, to match SSO. Group membership, if enabled in our
+	// configuration, is also verified at this stage.
+	username := httpsso.Username(r)
 	if username == "" {
 		http.Error(w, "No user found", http.StatusInternalServerError)
 		return nil
 	}
+
+	if !matchGroups(httpsso.Groups(r), b.config.GetSSOGroups(req.ServiceProviderMetadata.ID)) {
+		http.Error(w, "Forbidden (bad group)", http.StatusForbidden)
+		return nil
+	}
+
 	user, ok := b.users[username]
 	if !ok {
 		http.Error(w, "User not found", http.StatusInternalServerError)
 		return nil
 	}
 
+	log.Printf("successfully authenticated session for username=%s, provider=%s", username, req.ServiceProviderMetadata.ID)
+
 	return &saml.Session{
 		ID:             base64.StdEncoding.EncodeToString(randomBytes(32)),
 		CreateTime:     saml.TimeNow(),
@@ -186,7 +233,7 @@ func NewSAMLIDP(config *Config) (http.Handler, error) {
 		svc += "/"
 	}
 
-	users, err := newUserFileBackend(config.UsersFile)
+	users, err := newUserFileBackend(config)
 	if err != nil {
 		return nil, err
 	}
-- 
GitLab