From 86a36cf5da88919ee7d9ec12d1a92043b16fcc9c Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Sun, 10 Dec 2017 11:04:55 +0000
Subject: [PATCH] Add a robust Search method to the LDAP ConnectionPool

Also simplifies handling of the connection pool itself, switching to
using deferred connects, and adds a test suite for misbehaving LDAP
servers to exercise failure modes.
---
 ldap/pool.go      |  68 ++++++++++++++----------
 ldap/pool_test.go | 131 ++++++++++++++++++++++++++++++++++++++++++++++
 ldap/search.go    |  54 +++++++++++++++++++
 3 files changed, 225 insertions(+), 28 deletions(-)
 create mode 100644 ldap/pool_test.go
 create mode 100644 ldap/search.go

diff --git a/ldap/pool.go b/ldap/pool.go
index 908cc99..6d8093e 100644
--- a/ldap/pool.go
+++ b/ldap/pool.go
@@ -3,7 +3,9 @@ package ldaputil
 import (
 	"context"
 	"errors"
+	"net"
 	"net/url"
+	"time"
 
 	"gopkg.in/ldap.v2"
 )
@@ -19,13 +21,27 @@ type ConnectionPool struct {
 	c chan *ldap.Conn
 }
 
-func (p *ConnectionPool) connect() (*ldap.Conn, error) {
-	conn, err := ldap.Dial(p.network, p.addr)
+var defaultConnectTimeout = 5 * time.Second
+
+func (p *ConnectionPool) connect(ctx context.Context) (*ldap.Conn, error) {
+	// Dial the connection with a timeout, if the context has a
+	// deadline (as it should). If the context does not have a
+	// deadline, we set a default timeout.
+	deadline, ok := ctx.Deadline()
+	if !ok {
+		deadline = time.Now().Add(defaultConnectTimeout)
+	}
+
+	c, err := net.DialTimeout(p.network, p.addr, time.Until(deadline))
 	if err != nil {
 		return nil, err
 	}
 
-	if err = conn.Bind(p.bindDN, p.bindPw); err != nil {
+	conn := ldap.NewConn(c, false)
+	conn.Start()
+
+	conn.SetTimeout(time.Until(deadline))
+	if _, err = conn.SimpleBind(ldap.NewSimpleBindRequest(p.bindDN, p.bindPw, nil)); err != nil {
 		conn.Close()
 		return nil, err
 	}
@@ -35,24 +51,31 @@ func (p *ConnectionPool) connect() (*ldap.Conn, error) {
 
 // Get a fresh connection from the pool.
 func (p *ConnectionPool) Get(ctx context.Context) (*ldap.Conn, error) {
+	// Grab a connection from the cache, or create a new one if
+	// there are no available connections.
 	select {
 	case conn := <-p.c:
 		return conn, nil
-	case <-ctx.Done():
-		return nil, ctx.Err()
+	default:
+		return p.connect(ctx)
 	}
 }
 
 // Release a used connection onto the pool.
 func (p *ConnectionPool) Release(conn *ldap.Conn, err error) {
-	// We assume that if we get an ErrorNetwork, then we need to reconnect.
-	for err != nil && ldap.IsErrorWithCode(err, ldap.ErrorNetwork) {
-		if conn != nil {
-			conn.Close()
-		}
-		conn, err = p.connect()
+	// Connections that failed should not be reused.
+	if err != nil {
+		conn.Close()
+		return
+	}
+
+	// Return the connection to the cache, or close it if it's
+	// full.
+	select {
+	case p.c <- conn:
+	default:
+		conn.Close()
 	}
-	p.c <- conn
 }
 
 // Close all connections. Not implemented yet.
@@ -85,29 +108,18 @@ func parseLDAPURI(uri string) (string, string, error) {
 
 // NewConnectionPool creates a new pool of LDAP connections to the
 // specified server, using the provided bind credentials. The pool
-// will contain numConns connections.
-func NewConnectionPool(uri, bindDN, bindPw string, numConns int) (*ConnectionPool, error) {
+// will cache at most cacheSize connections.
+func NewConnectionPool(uri, bindDN, bindPw string, cacheSize int) (*ConnectionPool, error) {
 	network, addr, err := parseLDAPURI(uri)
 	if err != nil {
 		return nil, err
 	}
 
-	p := &ConnectionPool{
-		c:       make(chan *ldap.Conn, numConns),
+	return &ConnectionPool{
+		c:       make(chan *ldap.Conn, cacheSize),
 		network: network,
 		addr:    addr,
 		bindDN:  bindDN,
 		bindPw:  bindPw,
-	}
-
-	for i := 0; i < numConns; i++ {
-		conn, err := p.connect()
-		if err != nil {
-			p.Close()
-			return nil, err
-		}
-		p.c <- conn
-	}
-
-	return p, nil
+	}, nil
 }
diff --git a/ldap/pool_test.go b/ldap/pool_test.go
new file mode 100644
index 0000000..a0b2d59
--- /dev/null
+++ b/ldap/pool_test.go
@@ -0,0 +1,131 @@
+package ldaputil
+
+import (
+	"context"
+	"log"
+	"net"
+	"testing"
+	"time"
+
+	"gopkg.in/asn1-ber.v1"
+	"gopkg.in/ldap.v2"
+)
+
+type tcpHandler interface {
+	Handle(net.Conn)
+}
+
+type tcpHandlerFunc func(net.Conn)
+
+func (f tcpHandlerFunc) Handle(c net.Conn) { f(c) }
+
+// Base TCP server type (to build fake LDAP servers).
+type tcpServer struct {
+	l       net.Listener
+	handler tcpHandler
+}
+
+func newTCPServer(t testing.TB, handler tcpHandler) *tcpServer {
+	l, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatal("Listen():", err)
+	}
+	log.Printf("started new tcp server on %s", l.Addr().String())
+	s := &tcpServer{l: l, handler: handler}
+	go s.serve()
+	return s
+}
+
+func (s *tcpServer) serve() {
+	for {
+		conn, err := s.l.Accept()
+		if err != nil {
+			return
+		}
+		go func(c net.Conn) {
+			s.handler.Handle(c)
+			c.Close()
+		}(conn)
+	}
+}
+
+func (s *tcpServer) Addr() string {
+	return s.l.Addr().String()
+}
+
+func (s *tcpServer) Close() {
+	s.l.Close()
+}
+
+// A test server that will close all incoming connections right away.
+func newConnFailServer(t testing.TB) *tcpServer {
+	return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) {}))
+}
+
+// A test server that will close all connections after a 1s delay.
+func newConnFailDelayServer(t testing.TB) *tcpServer {
+	return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) { time.Sleep(1 * time.Second) }))
+}
+
+// A fake LDAP server that will read a request and return a protocol error.
+func newFakeBindOnlyLDAPServer(t testing.TB) *tcpServer {
+	return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) {
+		var b [1024]byte
+		c.Read(b[:])
+
+		resp := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+		resp.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 1, "MessageID"))
+		resp.AppendChild(ber.NewSequence("Description"))
+
+		c.Write(resp.Bytes())
+	}))
+}
+
+func TestConnectionPool_ConnFail(t *testing.T) {
+	runSearchQueries(t, newConnFailServer(t))
+}
+
+func TestConnectionPool_ConnFailDelay(t *testing.T) {
+	runSearchQueries(t, newConnFailDelayServer(t))
+}
+
+func TestConnectionPool_PortClosed(t *testing.T) {
+	srv := newConnFailServer(t)
+	srv.Close()
+	runSearchQueries(t, srv)
+}
+
+func TestConnectionPool_BindOnly(t *testing.T) {
+	runSearchQueries(t, newFakeBindOnlyLDAPServer(t))
+}
+
+func runSearchQueries(t testing.TB, srv *tcpServer) {
+	defer srv.Close()
+	ldapURI := "ldap://" + srv.Addr()
+
+	p, err := NewConnectionPool(ldapURI, "user", "password", 10)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer p.Close()
+
+	for i := 0; i < 5; i++ {
+		ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+		_, err := p.Search(ctx, ldap.NewSearchRequest(
+			"o=Anarchy",
+			ldap.ScopeWholeSubtree,
+			ldap.NeverDerefAliases,
+			0,
+			0,
+			false,
+			"(objectClass=*)",
+			[]string{"dn"},
+			nil,
+		))
+		cancel()
+		log.Printf("%d: %v", i, err)
+		if err == nil {
+			t.Error("weird, no error on Search")
+		}
+	}
+}
diff --git a/ldap/search.go b/ldap/search.go
new file mode 100644
index 0000000..872f6fe
--- /dev/null
+++ b/ldap/search.go
@@ -0,0 +1,54 @@
+package ldaputil
+
+import (
+	"context"
+	"time"
+
+	"github.com/cenkalti/backoff"
+	"gopkg.in/ldap.v2"
+
+	"git.autistici.org/ai3/go-common/clientutil"
+)
+
+// Treat all errors as potential network-level issues, except for a
+// whitelist of LDAP protocol level errors that we know are benign.
+func isTemporaryLDAPError(err error) bool {
+	ldapErr, ok := err.(*ldap.Error)
+	if !ok {
+		return true
+	}
+	switch ldapErr.ResultCode {
+	case ldap.ErrorNetwork:
+		return true
+	default:
+		return false
+	}
+}
+
+// Search performs the given search request. It will retry the request
+// on temporary errors.
+func (p *ConnectionPool) Search(ctx context.Context, searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) {
+	var result *ldap.SearchResult
+	err := clientutil.Retry(func() error {
+		conn, err := p.Get(ctx)
+		if err != nil {
+			if isTemporaryLDAPError(err) {
+				return clientutil.TempError(err)
+			}
+			return err
+		}
+
+		if deadline, ok := ctx.Deadline(); ok {
+			conn.SetTimeout(time.Until(deadline))
+		}
+
+		result, err = conn.Search(searchRequest)
+		if err != nil && isTemporaryLDAPError(err) {
+			p.Release(conn, nil)
+			return clientutil.TempError(err)
+		}
+		p.Release(conn, err)
+		return err
+	}, backoff.WithContext(clientutil.NewExponentialBackOff(), ctx))
+	return result, err
+}
-- 
GitLab