Skip to content
Snippets Groups Projects
Commit ff2091ae authored by ale's avatar ale
Browse files

Support hosting multiple domains more easily

parent a1c18acb
No related branches found
No related tags found
No related merge requests found
...@@ -20,7 +20,7 @@ gpg --homedir $tmpdir --list-options show-only-fpr-mbox -k '*' \ ...@@ -20,7 +20,7 @@ gpg --homedir $tmpdir --list-options show-only-fpr-mbox -k '*' \
hash=$(echo $mbox | /usr/lib/gnupg/gpg-wks-client --print-wkd-hash | awk '{print $1}') hash=$(echo $mbox | /usr/lib/gnupg/gpg-wks-client --print-wkd-hash | awk '{print $1}')
domain=${mbox##*@} domain=${mbox##*@}
key=$(base64 -w0 < ${tmpdir}/openpgpkeys/${domain}/hu/${hash}) key=$(base64 -w0 < ${tmpdir}/openpgpkeys/${domain}/hu/${hash})
echo "${comma} \"${hash}\": {\"Addr\": \"${mbox}\", \"Data\": \"${key}\"}" echo "${comma} \"${hash}@${domain}\": {\"Addr\": \"${mbox}\", \"Data\": \"${key}\"}"
comma=, comma=,
done) done)
echo '}' echo '}'
......
...@@ -37,10 +37,10 @@ func NewLDAPStorage(uri, bindDN, bindPw, baseDN, filter string) (Storage, error) ...@@ -37,10 +37,10 @@ func NewLDAPStorage(uri, bindDN, bindPw, baseDN, filter string) (Storage, error)
} }
if filter == "" { if filter == "" {
filter = fmt.Sprintf("(%s=%%s)", hashLDAPAttr) filter = fmt.Sprintf("(%s=%%h@%%d)", hashLDAPAttr)
} }
if !strings.Contains(filter, "%s") { if !strings.Contains(filter, "%h") {
return nil, errors.New("filter expression does not contain literal '%s' token") return nil, errors.New("filter expression does not contain literal '%h' token")
} }
return &ldapStorage{ return &ldapStorage{
...@@ -50,8 +50,15 @@ func NewLDAPStorage(uri, bindDN, bindPw, baseDN, filter string) (Storage, error) ...@@ -50,8 +50,15 @@ func NewLDAPStorage(uri, bindDN, bindPw, baseDN, filter string) (Storage, error)
}, nil }, nil
} }
func (s *ldapStorage) Lookup(ctx context.Context, hash string) (*Key, error) { // Replace '%h' (for hash) and '%d' (for domain) tokens in the
filter := fmt.Sprintf(s.filter, ldap.EscapeFilter(hash)) // configured LDAP filter string, and return the result.
func (s *ldapStorage) filterExpr(hash, domain string) string {
// This is only safe because domain can't contain %h.
f := strings.Replace(s.filter, "%d", ldap.EscapeFilter(domain), -1)
return strings.Replace(f, "%h", ldap.EscapeFilter(hash), -1)
}
func (s *ldapStorage) Lookup(ctx context.Context, hash, domain string) (*Key, error) {
req := ldap.NewSearchRequest( req := ldap.NewSearchRequest(
s.baseDN, s.baseDN,
ldap.ScopeWholeSubtree, ldap.ScopeWholeSubtree,
...@@ -59,7 +66,7 @@ func (s *ldapStorage) Lookup(ctx context.Context, hash string) (*Key, error) { ...@@ -59,7 +66,7 @@ func (s *ldapStorage) Lookup(ctx context.Context, hash string) (*Key, error) {
0, 0,
0, 0,
false, false,
filter, s.filterExpr(hash, domain),
[]string{mailLDAPAttr, keyLDAPAttr}, []string{mailLDAPAttr, keyLDAPAttr},
nil, nil,
) )
......
...@@ -84,7 +84,7 @@ type Storage interface { ...@@ -84,7 +84,7 @@ type Storage interface {
// information. The special ErrNotFound error can be used to // information. The special ErrNotFound error can be used to
// indicate that no key was found, as opposed to a generic // indicate that no key was found, as opposed to a generic
// backend error. // backend error.
Lookup(context.Context, string) (*Key, error) Lookup(context.Context, string, string) (*Key, error)
} }
// Server for the WKD protocol. // Server for the WKD protocol.
...@@ -138,7 +138,7 @@ func (s *Server) serveDiscovery(w http.ResponseWriter, r *http.Request, request ...@@ -138,7 +138,7 @@ func (s *Server) serveDiscovery(w http.ResponseWriter, r *http.Request, request
// Go through available storages until one returns a valid key. // Go through available storages until one returns a valid key.
for _, storage := range s.storages { for _, storage := range s.storages {
key, err = storage.Lookup(r.Context(), request.Hash) key, err = storage.Lookup(r.Context(), request.Hash, request.Domain)
if err == nil { if err == nil {
break break
} }
......
...@@ -3,6 +3,7 @@ package wkd ...@@ -3,6 +3,7 @@ package wkd
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"os" "os"
) )
...@@ -30,8 +31,9 @@ func NewStaticStorage(path string) (Storage, error) { ...@@ -30,8 +31,9 @@ func NewStaticStorage(path string) (Storage, error) {
return &staticStorage{keys: keys}, nil return &staticStorage{keys: keys}, nil
} }
func (s *staticStorage) Lookup(ctx context.Context, hash string) (*Key, error) { func (s *staticStorage) Lookup(ctx context.Context, hash, domain string) (*Key, error) {
if key, ok := s.keys[hash]; ok { hashAddr := fmt.Sprintf("%s@%s", hash, domain)
if key, ok := s.keys[hashAddr]; ok {
return key, nil return key, nil
} }
return nil, ErrNotFound return nil, ErrNotFound
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment