From 11cab73f6d97bcc0e631aea6af2946a20e6eee8c Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Tue, 26 Jun 2018 18:55:07 +0100
Subject: [PATCH] Minor refactoring of LDAP query templates

Rename queryConfig to queryTemplate to better reflect its purpose, and
drop all cruft that had to do with config deserialization.
---
 backend/model.go      | 47 ++++++++++++++++++++-----------------
 backend/model_test.go | 24 +++++++++++++++++--
 backend/resources.go  | 54 +++++++++++++++++++++----------------------
 backend/util.go       | 54 +++++++++++++++----------------------------
 4 files changed, 93 insertions(+), 86 deletions(-)

diff --git a/backend/model.go b/backend/model.go
index 5f110243..e2db1b4c 100644
--- a/backend/model.go
+++ b/backend/model.go
@@ -31,8 +31,8 @@ const (
 type backend struct {
 	conn                ldapConn
 	baseDN              string
-	userQuery           *queryConfig
-	userResourceQueries []*queryConfig
+	userQuery           *queryTemplate
+	userResourceQueries []*queryTemplate
 	resources           *resourceRegistry
 }
 
@@ -74,22 +74,23 @@ func newLDAPBackendWithConn(conn ldapConn, base string) (*backend, error) {
 	return &backend{
 		conn:   conn,
 		baseDN: base,
-		userQuery: mustCompileQueryConfig(&queryConfig{
-			Base:  "uid=${user},ou=People," + base,
-			Scope: "base",
-		}),
-		userResourceQueries: []*queryConfig{
+		userQuery: &queryTemplate{
+			Base:   joinDN("uid=${user}", "ou=People", base),
+			Filter: "(objectClass=*)",
+			Scope:  ldap.ScopeBaseObject,
+		},
+		userResourceQueries: []*queryTemplate{
 			// Find all resources that are children of the main uid object.
-			mustCompileQueryConfig(&queryConfig{
-				Base:  "uid=${user},ou=People," + base,
-				Scope: "sub",
-			}),
+			&queryTemplate{
+				Base:  joinDN("uid=${user}", "ou=People", base),
+				Scope: ldap.ScopeWholeSubtree,
+			},
 			// Find mailing lists, which are nested under a different root.
-			mustCompileQueryConfig(&queryConfig{
-				Base:   "ou=Lists," + base,
+			&queryTemplate{
+				Base:   joinDN("ou=Lists", base),
 				Filter: "(&(objectClass=mailingList)(listOwner=${user}))",
-				Scope:  "one",
-			}),
+				Scope:  ldap.ScopeSingleLevel,
+			},
 		},
 		resources: rsrc,
 	}, nil
@@ -155,13 +156,16 @@ func (tx *backendTX) CreateUser(ctx context.Context, user *accountserver.User) e
 func (tx *backendTX) GetUser(ctx context.Context, username string) (*accountserver.User, error) {
 	// First of all, find the main user object, and just that one.
 	vars := map[string]string{"user": username}
-	result, err := tx.search(ctx, tx.backend.userQuery.searchRequest(vars, nil))
+	result, err := tx.search(ctx, tx.backend.userQuery.query(vars))
 	if err != nil {
 		if ldap.IsErrorWithCode(err, ldap.LDAPResultNoSuchObject) {
 			return nil, nil
 		}
 		return nil, err
 	}
+	if len(result.Entries) == 0 {
+		return nil, nil
+	}
 
 	user, err := newUser(result.Entries[0])
 	if err != nil {
@@ -172,8 +176,8 @@ func (tx *backendTX) GetUser(ctx context.Context, username string) (*accountserv
 	// object we just created.
 	// TODO: parallelize.
 	// TODO: add support for non-LDAP resource queries.
-	for _, query := range tx.backend.userResourceQueries {
-		result, err = tx.search(ctx, query.searchRequest(vars, nil))
+	for _, tpl := range tx.backend.userResourceQueries {
+		result, err = tx.search(ctx, tpl.query(vars))
 		if err != nil {
 			continue
 		}
@@ -300,16 +304,17 @@ func (tx *backendTX) SetResourcePassword(ctx context.Context, r *accountserver.R
 }
 
 func (tx *backendTX) hasResource(ctx context.Context, resourceType, resourceName string) (bool, error) {
-	query, err := tx.backend.resources.SearchQuery(resourceType)
+	tpl, err := tx.backend.resources.SearchQuery(resourceType)
 	if err != nil {
 		return false, err
 	}
 
 	// Make a quick LDAP search that only fetches the DN attribute.
-	result, err := tx.search(ctx, query.searchRequest(map[string]string{
+	tpl.Attrs = []string{"dn"}
+	result, err := tx.search(ctx, tpl.query(map[string]string{
 		"resource": resourceName,
 		"type":     resourceType,
-	}, []string{"dn"}))
+	}))
 	if err != nil {
 		if ldap.IsErrorWithCode(err, ldap.LDAPResultNoSuchObject) {
 			return false, nil
diff --git a/backend/model_test.go b/backend/model_test.go
index c99c8f62..151b6872 100644
--- a/backend/model_test.go
+++ b/backend/model_test.go
@@ -25,7 +25,7 @@ func startServerAndGetUser2(t testing.TB) (func(), accountserver.Backend, *accou
 	return startServerAndGetUserWithName(t, testUser2)
 }
 
-func startServerAndGetUserWithName(t testing.TB, username string) (func(), accountserver.Backend, *accountserver.User) {
+func startServer(t testing.TB) (func(), accountserver.Backend) {
 	stop := ldaptest.StartServer(t, &ldaptest.Config{
 		Dir:  "../ldaptest",
 		Port: testLDAPPort,
@@ -42,6 +42,12 @@ func startServerAndGetUserWithName(t testing.TB, username string) (func(), accou
 		t.Fatal("NewLDAPBackend", err)
 	}
 
+	return stop, b
+}
+
+func startServerAndGetUserWithName(t testing.TB, username string) (func(), accountserver.Backend, *accountserver.User) {
+	stop, b := startServer(t)
+
 	tx, _ := b.NewTransaction()
 	user, err := tx.GetUser(context.Background(), username)
 	if err != nil {
@@ -54,6 +60,20 @@ func startServerAndGetUserWithName(t testing.TB, username string) (func(), accou
 	return stop, b, user
 }
 
+func TestModel_GetUser_NotFound(t *testing.T) {
+	stop, b := startServer(t)
+	defer stop()
+
+	tx, _ := b.NewTransaction()
+	user, err := tx.GetUser(context.Background(), "wrong_user")
+	if err != nil {
+		t.Fatalf("GetUser(wrong_user) should have returned no error, got: %v", err)
+	}
+	if user != nil {
+		t.Fatal("GetUser(wrong_user) returned non-nil user")
+	}
+}
+
 func TestModel_GetUser(t *testing.T) {
 	stop, _, user := startServerAndGetUser(t)
 	defer stop()
@@ -217,7 +237,7 @@ func TestModel_HasAnyResource(t *testing.T) {
 		t.Fatal("HasAnyResource", err)
 	}
 	if !ok {
-		t.Fatal("could not find test resource")
+		t.Fatal("could not find test email resource")
 	}
 
 	// Request that should fail (bad resource type).
diff --git a/backend/resources.go b/backend/resources.go
index 3da6cea0..161246e7 100644
--- a/backend/resources.go
+++ b/backend/resources.go
@@ -16,7 +16,7 @@ type resourceHandler interface {
 	GetDN(accountserver.ResourceID) (string, error)
 	ToLDAP(*accountserver.Resource) []ldap.PartialAttribute
 	FromLDAP(*ldap.Entry) (*accountserver.Resource, error)
-	SearchQuery() *queryConfig
+	SearchQuery() *queryTemplate
 }
 
 // Registry for demultiplexing resource handling. Has a similar
@@ -101,9 +101,9 @@ func (reg *resourceRegistry) FromLDAPWithType(rsrcType string, entry *ldap.Entry
 	return
 }
 
-func (reg *resourceRegistry) SearchQuery(rsrcType string) (c *queryConfig, err error) {
+func (reg *resourceRegistry) SearchQuery(rsrcType string) (q *queryTemplate, err error) {
 	err = reg.dispatch(rsrcType, func(h resourceHandler) error {
-		c = h.SearchQuery()
+		q = h.SearchQuery()
 		return nil
 	})
 	return
@@ -177,12 +177,12 @@ func (h *emailResourceHandler) ToLDAP(rsrc *accountserver.Resource) []ldap.Parti
 	}
 }
 
-func (h *emailResourceHandler) SearchQuery() *queryConfig {
-	return mustCompileQueryConfig(&queryConfig{
+func (h *emailResourceHandler) SearchQuery() *queryTemplate {
+	return &queryTemplate{
 		Base:   joinDN("ou=People", h.baseDN),
 		Filter: "(&(objectClass=virtualMailUser)(mail=${resource}))",
-		Scope:  "sub",
-	})
+		Scope:  ldap.ScopeWholeSubtree,
+	}
 }
 
 // Mailing list resource.
@@ -222,12 +222,12 @@ func (h *mailingListResourceHandler) ToLDAP(rsrc *accountserver.Resource) []ldap
 	}
 }
 
-func (h *mailingListResourceHandler) SearchQuery() *queryConfig {
-	return mustCompileQueryConfig(&queryConfig{
+func (h *mailingListResourceHandler) SearchQuery() *queryTemplate {
+	return &queryTemplate{
 		Base:   joinDN("ou=Lists", h.baseDN),
 		Filter: "(&(objectClass=mailingList)(listName=${resource}))",
-		Scope:  "one",
-	})
+		Scope:  ldap.ScopeSingleLevel,
+	}
 }
 
 // Website (subsite) resource.
@@ -289,12 +289,12 @@ func (h *websiteResourceHandler) ToLDAP(rsrc *accountserver.Resource) []ldap.Par
 	}
 }
 
-func (h *websiteResourceHandler) SearchQuery() *queryConfig {
-	return mustCompileQueryConfig(&queryConfig{
+func (h *websiteResourceHandler) SearchQuery() *queryTemplate {
+	return &queryTemplate{
 		Base:   joinDN("ou=People", h.baseDN),
 		Filter: "(&(objectClass=subSite)(alias=${resource}))",
-		Scope:  "sub",
-	})
+		Scope:  ldap.ScopeWholeSubtree,
+	}
 }
 
 // Domain (virtual host) resource.
@@ -350,12 +350,12 @@ func (h *domainResourceHandler) ToLDAP(rsrc *accountserver.Resource) []ldap.Part
 	}
 }
 
-func (h *domainResourceHandler) SearchQuery() *queryConfig {
-	return mustCompileQueryConfig(&queryConfig{
+func (h *domainResourceHandler) SearchQuery() *queryTemplate {
+	return &queryTemplate{
 		Base:   joinDN("ou=People", h.baseDN),
 		Filter: "(&(objectClass=virtualHost)(cn=${resource}))",
-		Scope:  "sub",
-	})
+		Scope:  ldap.ScopeWholeSubtree,
+	}
 }
 
 // WebDAV (a.k.a. "ftp account") resource.
@@ -406,12 +406,12 @@ func (h *webdavResourceHandler) ToLDAP(rsrc *accountserver.Resource) []ldap.Part
 	}
 }
 
-func (h *webdavResourceHandler) SearchQuery() *queryConfig {
-	return mustCompileQueryConfig(&queryConfig{
+func (h *webdavResourceHandler) SearchQuery() *queryTemplate {
+	return &queryTemplate{
 		Base:   joinDN("ou=People", h.baseDN),
 		Filter: "(&(objectClass=ftpAccount)(ftpname=${resource}))",
-		Scope:  "sub",
-	})
+		Scope:  ldap.ScopeWholeSubtree,
+	}
 }
 
 // Databases are special: in LDAP, they encode their relation with a
@@ -514,12 +514,12 @@ func (h *databaseResourceHandler) ToLDAP(rsrc *accountserver.Resource) []ldap.Pa
 	}
 }
 
-func (h *databaseResourceHandler) SearchQuery() *queryConfig {
-	return mustCompileQueryConfig(&queryConfig{
+func (h *databaseResourceHandler) SearchQuery() *queryTemplate {
+	return &queryTemplate{
 		Base:   joinDN("ou=People", h.baseDN),
 		Filter: "(&(objectClass=dbMysql)(dbname=${resource}))",
-		Scope:  "sub",
-	})
+		Scope:  ldap.ScopeWholeSubtree,
+	}
 }
 
 func joinDN(parts ...string) string {
diff --git a/backend/util.go b/backend/util.go
index b7037cc4..b5098117 100644
--- a/backend/util.go
+++ b/backend/util.go
@@ -1,67 +1,47 @@
 package backend
 
 import (
-	"errors"
 	"os"
 
-	ldaputil "git.autistici.org/ai3/go-common/ldap"
 	"gopkg.in/ldap.v2"
 )
 
-// queryConfig holds the parameters for a single LDAP query.
-type queryConfig struct {
-	Base        string
-	Filter      string
-	Scope       string
-	parsedScope int
+// queryTemplate is the template for a single parametrized LDAP query.
+type queryTemplate struct {
+	Base   string
+	Filter string
+	Scope  int
+	Attrs  []string
 }
 
-func (q *queryConfig) validate() error {
-	if q.Base == "" {
-		return errors.New("empty search base")
+func (q *queryTemplate) query(vars map[string]string) *ldap.SearchRequest {
+	filter := q.Filter
+	if filter == "" {
+		filter = "(objectClass=*)"
 	}
-	// An empty filter is equivalent to objectClass=*.
-	if q.Filter == "" {
-		q.Filter = "(objectClass=*)"
-	}
-	q.parsedScope = ldap.ScopeWholeSubtree
-	if q.Scope != "" {
-		s, err := ldaputil.ParseScope(q.Scope)
-		if err != nil {
-			return err
-		}
-		q.parsedScope = s
-	}
-	return nil
-}
 
-func (q *queryConfig) searchRequest(vars map[string]string, attrs []string) *ldap.SearchRequest {
 	return ldap.NewSearchRequest(
 		replaceVars(q.Base, vars),
-		q.parsedScope,
+		q.Scope,
 		ldap.NeverDerefAliases,
 		0,
 		0,
 		false,
-		replaceVars(q.Filter, vars),
-		attrs,
+		replaceVars(filter, vars),
+		q.Attrs,
 		nil,
 	)
 }
 
-func mustCompileQueryConfig(q *queryConfig) *queryConfig {
-	if err := q.validate(); err != nil {
-		panic(err)
-	}
-	return q
-}
-
 func replaceVars(s string, vars map[string]string) string {
 	return os.Expand(s, func(k string) string {
 		return ldap.EscapeFilter(vars[k])
 	})
 }
 
+// LDAP string to boolean value. There is no schema defined for this,
+// we just match a bunch of plausible text truth values ("yes", "on",
+// "true", etc).
 func s2b(s string) bool {
 	switch s {
 	case "yes", "y", "on", "enabled", "true":
@@ -71,6 +51,7 @@ func s2b(s string) bool {
 	}
 }
 
+// Bool to LDAP value. Encoded as yes/no.
 func b2s(b bool) string {
 	if b {
 		return "yes"
@@ -87,6 +68,7 @@ func s2l(s string) []string {
 	return []string{s}
 }
 
+// Returns true if a LDAP object has the specified objectClass.
 func isObjectClass(entry *ldap.Entry, class string) bool {
 	classes := entry.GetAttributeValues("objectClass")
 	for _, c := range classes {
-- 
GitLab