Skip to content
Snippets Groups Projects
Commit 86a36cf5 authored by ale's avatar ale
Browse files

Add a robust Search method to the LDAP ConnectionPool

Also simplifies handling of the connection pool itself, switching to
using deferred connects, and adds a test suite for misbehaving LDAP
servers to exercise failure modes.
parent 1fd8c2e2
Branches
No related tags found
No related merge requests found
......@@ -3,7 +3,9 @@ package ldaputil
import (
"context"
"errors"
"net"
"net/url"
"time"
"gopkg.in/ldap.v2"
)
......@@ -19,13 +21,27 @@ type ConnectionPool struct {
c chan *ldap.Conn
}
func (p *ConnectionPool) connect() (*ldap.Conn, error) {
conn, err := ldap.Dial(p.network, p.addr)
var defaultConnectTimeout = 5 * time.Second
func (p *ConnectionPool) connect(ctx context.Context) (*ldap.Conn, error) {
// Dial the connection with a timeout, if the context has a
// deadline (as it should). If the context does not have a
// deadline, we set a default timeout.
deadline, ok := ctx.Deadline()
if !ok {
deadline = time.Now().Add(defaultConnectTimeout)
}
c, err := net.DialTimeout(p.network, p.addr, time.Until(deadline))
if err != nil {
return nil, err
}
if err = conn.Bind(p.bindDN, p.bindPw); err != nil {
conn := ldap.NewConn(c, false)
conn.Start()
conn.SetTimeout(time.Until(deadline))
if _, err = conn.SimpleBind(ldap.NewSimpleBindRequest(p.bindDN, p.bindPw, nil)); err != nil {
conn.Close()
return nil, err
}
......@@ -35,24 +51,31 @@ func (p *ConnectionPool) connect() (*ldap.Conn, error) {
// Get a fresh connection from the pool.
func (p *ConnectionPool) Get(ctx context.Context) (*ldap.Conn, error) {
// Grab a connection from the cache, or create a new one if
// there are no available connections.
select {
case conn := <-p.c:
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
default:
return p.connect(ctx)
}
}
// Release a used connection onto the pool.
func (p *ConnectionPool) Release(conn *ldap.Conn, err error) {
// We assume that if we get an ErrorNetwork, then we need to reconnect.
for err != nil && ldap.IsErrorWithCode(err, ldap.ErrorNetwork) {
if conn != nil {
conn.Close()
}
conn, err = p.connect()
// Connections that failed should not be reused.
if err != nil {
conn.Close()
return
}
// Return the connection to the cache, or close it if it's
// full.
select {
case p.c <- conn:
default:
conn.Close()
}
p.c <- conn
}
// Close all connections. Not implemented yet.
......@@ -85,29 +108,18 @@ func parseLDAPURI(uri string) (string, string, error) {
// NewConnectionPool creates a new pool of LDAP connections to the
// specified server, using the provided bind credentials. The pool
// will contain numConns connections.
func NewConnectionPool(uri, bindDN, bindPw string, numConns int) (*ConnectionPool, error) {
// will cache at most cacheSize connections.
func NewConnectionPool(uri, bindDN, bindPw string, cacheSize int) (*ConnectionPool, error) {
network, addr, err := parseLDAPURI(uri)
if err != nil {
return nil, err
}
p := &ConnectionPool{
c: make(chan *ldap.Conn, numConns),
return &ConnectionPool{
c: make(chan *ldap.Conn, cacheSize),
network: network,
addr: addr,
bindDN: bindDN,
bindPw: bindPw,
}
for i := 0; i < numConns; i++ {
conn, err := p.connect()
if err != nil {
p.Close()
return nil, err
}
p.c <- conn
}
return p, nil
}, nil
}
package ldaputil
import (
"context"
"log"
"net"
"testing"
"time"
"gopkg.in/asn1-ber.v1"
"gopkg.in/ldap.v2"
)
type tcpHandler interface {
Handle(net.Conn)
}
type tcpHandlerFunc func(net.Conn)
func (f tcpHandlerFunc) Handle(c net.Conn) { f(c) }
// Base TCP server type (to build fake LDAP servers).
type tcpServer struct {
l net.Listener
handler tcpHandler
}
func newTCPServer(t testing.TB, handler tcpHandler) *tcpServer {
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal("Listen():", err)
}
log.Printf("started new tcp server on %s", l.Addr().String())
s := &tcpServer{l: l, handler: handler}
go s.serve()
return s
}
func (s *tcpServer) serve() {
for {
conn, err := s.l.Accept()
if err != nil {
return
}
go func(c net.Conn) {
s.handler.Handle(c)
c.Close()
}(conn)
}
}
func (s *tcpServer) Addr() string {
return s.l.Addr().String()
}
func (s *tcpServer) Close() {
s.l.Close()
}
// A test server that will close all incoming connections right away.
func newConnFailServer(t testing.TB) *tcpServer {
return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) {}))
}
// A test server that will close all connections after a 1s delay.
func newConnFailDelayServer(t testing.TB) *tcpServer {
return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) { time.Sleep(1 * time.Second) }))
}
// A fake LDAP server that will read a request and return a protocol error.
func newFakeBindOnlyLDAPServer(t testing.TB) *tcpServer {
return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) {
var b [1024]byte
c.Read(b[:])
resp := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
resp.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 1, "MessageID"))
resp.AppendChild(ber.NewSequence("Description"))
c.Write(resp.Bytes())
}))
}
func TestConnectionPool_ConnFail(t *testing.T) {
runSearchQueries(t, newConnFailServer(t))
}
func TestConnectionPool_ConnFailDelay(t *testing.T) {
runSearchQueries(t, newConnFailDelayServer(t))
}
func TestConnectionPool_PortClosed(t *testing.T) {
srv := newConnFailServer(t)
srv.Close()
runSearchQueries(t, srv)
}
func TestConnectionPool_BindOnly(t *testing.T) {
runSearchQueries(t, newFakeBindOnlyLDAPServer(t))
}
func runSearchQueries(t testing.TB, srv *tcpServer) {
defer srv.Close()
ldapURI := "ldap://" + srv.Addr()
p, err := NewConnectionPool(ldapURI, "user", "password", 10)
if err != nil {
t.Fatal(err)
}
defer p.Close()
for i := 0; i < 5; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
_, err := p.Search(ctx, ldap.NewSearchRequest(
"o=Anarchy",
ldap.ScopeWholeSubtree,
ldap.NeverDerefAliases,
0,
0,
false,
"(objectClass=*)",
[]string{"dn"},
nil,
))
cancel()
log.Printf("%d: %v", i, err)
if err == nil {
t.Error("weird, no error on Search")
}
}
}
package ldaputil
import (
"context"
"time"
"github.com/cenkalti/backoff"
"gopkg.in/ldap.v2"
"git.autistici.org/ai3/go-common/clientutil"
)
// Treat all errors as potential network-level issues, except for a
// whitelist of LDAP protocol level errors that we know are benign.
func isTemporaryLDAPError(err error) bool {
ldapErr, ok := err.(*ldap.Error)
if !ok {
return true
}
switch ldapErr.ResultCode {
case ldap.ErrorNetwork:
return true
default:
return false
}
}
// Search performs the given search request. It will retry the request
// on temporary errors.
func (p *ConnectionPool) Search(ctx context.Context, searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) {
var result *ldap.SearchResult
err := clientutil.Retry(func() error {
conn, err := p.Get(ctx)
if err != nil {
if isTemporaryLDAPError(err) {
return clientutil.TempError(err)
}
return err
}
if deadline, ok := ctx.Deadline(); ok {
conn.SetTimeout(time.Until(deadline))
}
result, err = conn.Search(searchRequest)
if err != nil && isTemporaryLDAPError(err) {
p.Release(conn, nil)
return clientutil.TempError(err)
}
p.Release(conn, err)
return err
}, backoff.WithContext(clientutil.NewExponentialBackOff(), ctx))
return result, err
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment