From ff2091ae90c96651cda5035851a74875c3ece0ab Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Fri, 1 Apr 2022 13:51:56 +0100
Subject: [PATCH] Support hosting multiple domains more easily

---
 build-static-bundle.sh |  2 +-
 ldap.go                | 19 +++++++++++++------
 server.go              |  4 ++--
 static.go              |  6 ++++--
 4 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/build-static-bundle.sh b/build-static-bundle.sh
index 596aa79..e8d0994 100755
--- a/build-static-bundle.sh
+++ b/build-static-bundle.sh
@@ -20,7 +20,7 @@ gpg --homedir $tmpdir --list-options show-only-fpr-mbox -k '*' \
            hash=$(echo $mbox | /usr/lib/gnupg/gpg-wks-client --print-wkd-hash | awk '{print $1}')
            domain=${mbox##*@}
            key=$(base64 -w0 < ${tmpdir}/openpgpkeys/${domain}/hu/${hash})
-           echo "${comma}  \"${hash}\": {\"Addr\": \"${mbox}\", \"Data\": \"${key}\"}"
+           echo "${comma}  \"${hash}@${domain}\": {\"Addr\": \"${mbox}\", \"Data\": \"${key}\"}"
            comma=,
        done)
 echo '}'
diff --git a/ldap.go b/ldap.go
index 897d510..62adf6c 100644
--- a/ldap.go
+++ b/ldap.go
@@ -37,10 +37,10 @@ func NewLDAPStorage(uri, bindDN, bindPw, baseDN, filter string) (Storage, error)
 	}
 
 	if filter == "" {
-		filter = fmt.Sprintf("(%s=%%s)", hashLDAPAttr)
+		filter = fmt.Sprintf("(%s=%%h@%%d)", hashLDAPAttr)
 	}
-	if !strings.Contains(filter, "%s") {
-		return nil, errors.New("filter expression does not contain literal '%s' token")
+	if !strings.Contains(filter, "%h") {
+		return nil, errors.New("filter expression does not contain literal '%h' token")
 	}
 
 	return &ldapStorage{
@@ -50,8 +50,15 @@ func NewLDAPStorage(uri, bindDN, bindPw, baseDN, filter string) (Storage, error)
 	}, nil
 }
 
-func (s *ldapStorage) Lookup(ctx context.Context, hash string) (*Key, error) {
-	filter := fmt.Sprintf(s.filter, ldap.EscapeFilter(hash))
+// Replace '%h' (for hash) and '%d' (for domain) tokens in the
+// configured LDAP filter string, and return the result.
+func (s *ldapStorage) filterExpr(hash, domain string) string {
+	// This is only safe because domain can't contain %h.
+	f := strings.Replace(s.filter, "%d", ldap.EscapeFilter(domain), -1)
+	return strings.Replace(f, "%h", ldap.EscapeFilter(hash), -1)
+}
+
+func (s *ldapStorage) Lookup(ctx context.Context, hash, domain string) (*Key, error) {
 	req := ldap.NewSearchRequest(
 		s.baseDN,
 		ldap.ScopeWholeSubtree,
@@ -59,7 +66,7 @@ func (s *ldapStorage) Lookup(ctx context.Context, hash string) (*Key, error) {
 		0,
 		0,
 		false,
-		filter,
+		s.filterExpr(hash, domain),
 		[]string{mailLDAPAttr, keyLDAPAttr},
 		nil,
 	)
diff --git a/server.go b/server.go
index 137372a..a64b259 100644
--- a/server.go
+++ b/server.go
@@ -84,7 +84,7 @@ type Storage interface {
 	// information. The special ErrNotFound error can be used to
 	// indicate that no key was found, as opposed to a generic
 	// backend error.
-	Lookup(context.Context, string) (*Key, error)
+	Lookup(context.Context, string, string) (*Key, error)
 }
 
 // Server for the WKD protocol.
@@ -138,7 +138,7 @@ func (s *Server) serveDiscovery(w http.ResponseWriter, r *http.Request, request
 
 	// Go through available storages until one returns a valid key.
 	for _, storage := range s.storages {
-		key, err = storage.Lookup(r.Context(), request.Hash)
+		key, err = storage.Lookup(r.Context(), request.Hash, request.Domain)
 		if err == nil {
 			break
 		}
diff --git a/static.go b/static.go
index c3f0f2d..f6fbd2e 100644
--- a/static.go
+++ b/static.go
@@ -3,6 +3,7 @@ package wkd
 import (
 	"context"
 	"encoding/json"
+	"fmt"
 	"os"
 )
 
@@ -30,8 +31,9 @@ func NewStaticStorage(path string) (Storage, error) {
 	return &staticStorage{keys: keys}, nil
 }
 
-func (s *staticStorage) Lookup(ctx context.Context, hash string) (*Key, error) {
-	if key, ok := s.keys[hash]; ok {
+func (s *staticStorage) Lookup(ctx context.Context, hash, domain string) (*Key, error) {
+	hashAddr := fmt.Sprintf("%s@%s", hash, domain)
+	if key, ok := s.keys[hashAddr]; ok {
 		return key, nil
 	}
 	return nil, ErrNotFound
-- 
GitLab