diff --git a/ldap/pool.go b/ldap/pool.go index 908cc9973fa4adcedbdd69e2a2aea921465d777c..6d8093e93dccd5d333133633ab8958092a171355 100644 --- a/ldap/pool.go +++ b/ldap/pool.go @@ -3,7 +3,9 @@ package ldaputil import ( "context" "errors" + "net" "net/url" + "time" "gopkg.in/ldap.v2" ) @@ -19,13 +21,27 @@ type ConnectionPool struct { c chan *ldap.Conn } -func (p *ConnectionPool) connect() (*ldap.Conn, error) { - conn, err := ldap.Dial(p.network, p.addr) +var defaultConnectTimeout = 5 * time.Second + +func (p *ConnectionPool) connect(ctx context.Context) (*ldap.Conn, error) { + // Dial the connection with a timeout, if the context has a + // deadline (as it should). If the context does not have a + // deadline, we set a default timeout. + deadline, ok := ctx.Deadline() + if !ok { + deadline = time.Now().Add(defaultConnectTimeout) + } + + c, err := net.DialTimeout(p.network, p.addr, time.Until(deadline)) if err != nil { return nil, err } - if err = conn.Bind(p.bindDN, p.bindPw); err != nil { + conn := ldap.NewConn(c, false) + conn.Start() + + conn.SetTimeout(time.Until(deadline)) + if _, err = conn.SimpleBind(ldap.NewSimpleBindRequest(p.bindDN, p.bindPw, nil)); err != nil { conn.Close() return nil, err } @@ -35,24 +51,31 @@ func (p *ConnectionPool) connect() (*ldap.Conn, error) { // Get a fresh connection from the pool. func (p *ConnectionPool) Get(ctx context.Context) (*ldap.Conn, error) { + // Grab a connection from the cache, or create a new one if + // there are no available connections. select { case conn := <-p.c: return conn, nil - case <-ctx.Done(): - return nil, ctx.Err() + default: + return p.connect(ctx) } } // Release a used connection onto the pool. func (p *ConnectionPool) Release(conn *ldap.Conn, err error) { - // We assume that if we get an ErrorNetwork, then we need to reconnect. - for err != nil && ldap.IsErrorWithCode(err, ldap.ErrorNetwork) { - if conn != nil { - conn.Close() - } - conn, err = p.connect() + // Connections that failed should not be reused. + if err != nil { + conn.Close() + return + } + + // Return the connection to the cache, or close it if it's + // full. + select { + case p.c <- conn: + default: + conn.Close() } - p.c <- conn } // Close all connections. Not implemented yet. @@ -85,29 +108,18 @@ func parseLDAPURI(uri string) (string, string, error) { // NewConnectionPool creates a new pool of LDAP connections to the // specified server, using the provided bind credentials. The pool -// will contain numConns connections. -func NewConnectionPool(uri, bindDN, bindPw string, numConns int) (*ConnectionPool, error) { +// will cache at most cacheSize connections. +func NewConnectionPool(uri, bindDN, bindPw string, cacheSize int) (*ConnectionPool, error) { network, addr, err := parseLDAPURI(uri) if err != nil { return nil, err } - p := &ConnectionPool{ - c: make(chan *ldap.Conn, numConns), + return &ConnectionPool{ + c: make(chan *ldap.Conn, cacheSize), network: network, addr: addr, bindDN: bindDN, bindPw: bindPw, - } - - for i := 0; i < numConns; i++ { - conn, err := p.connect() - if err != nil { - p.Close() - return nil, err - } - p.c <- conn - } - - return p, nil + }, nil } diff --git a/ldap/pool_test.go b/ldap/pool_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a0b2d59de527e1e0ecce41392297084ce17116e5 --- /dev/null +++ b/ldap/pool_test.go @@ -0,0 +1,131 @@ +package ldaputil + +import ( + "context" + "log" + "net" + "testing" + "time" + + "gopkg.in/asn1-ber.v1" + "gopkg.in/ldap.v2" +) + +type tcpHandler interface { + Handle(net.Conn) +} + +type tcpHandlerFunc func(net.Conn) + +func (f tcpHandlerFunc) Handle(c net.Conn) { f(c) } + +// Base TCP server type (to build fake LDAP servers). +type tcpServer struct { + l net.Listener + handler tcpHandler +} + +func newTCPServer(t testing.TB, handler tcpHandler) *tcpServer { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal("Listen():", err) + } + log.Printf("started new tcp server on %s", l.Addr().String()) + s := &tcpServer{l: l, handler: handler} + go s.serve() + return s +} + +func (s *tcpServer) serve() { + for { + conn, err := s.l.Accept() + if err != nil { + return + } + go func(c net.Conn) { + s.handler.Handle(c) + c.Close() + }(conn) + } +} + +func (s *tcpServer) Addr() string { + return s.l.Addr().String() +} + +func (s *tcpServer) Close() { + s.l.Close() +} + +// A test server that will close all incoming connections right away. +func newConnFailServer(t testing.TB) *tcpServer { + return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) {})) +} + +// A test server that will close all connections after a 1s delay. +func newConnFailDelayServer(t testing.TB) *tcpServer { + return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) { time.Sleep(1 * time.Second) })) +} + +// A fake LDAP server that will read a request and return a protocol error. +func newFakeBindOnlyLDAPServer(t testing.TB) *tcpServer { + return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) { + var b [1024]byte + c.Read(b[:]) + + resp := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + resp.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 1, "MessageID")) + resp.AppendChild(ber.NewSequence("Description")) + + c.Write(resp.Bytes()) + })) +} + +func TestConnectionPool_ConnFail(t *testing.T) { + runSearchQueries(t, newConnFailServer(t)) +} + +func TestConnectionPool_ConnFailDelay(t *testing.T) { + runSearchQueries(t, newConnFailDelayServer(t)) +} + +func TestConnectionPool_PortClosed(t *testing.T) { + srv := newConnFailServer(t) + srv.Close() + runSearchQueries(t, srv) +} + +func TestConnectionPool_BindOnly(t *testing.T) { + runSearchQueries(t, newFakeBindOnlyLDAPServer(t)) +} + +func runSearchQueries(t testing.TB, srv *tcpServer) { + defer srv.Close() + ldapURI := "ldap://" + srv.Addr() + + p, err := NewConnectionPool(ldapURI, "user", "password", 10) + if err != nil { + t.Fatal(err) + } + defer p.Close() + + for i := 0; i < 5; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + _, err := p.Search(ctx, ldap.NewSearchRequest( + "o=Anarchy", + ldap.ScopeWholeSubtree, + ldap.NeverDerefAliases, + 0, + 0, + false, + "(objectClass=*)", + []string{"dn"}, + nil, + )) + cancel() + log.Printf("%d: %v", i, err) + if err == nil { + t.Error("weird, no error on Search") + } + } +} diff --git a/ldap/search.go b/ldap/search.go new file mode 100644 index 0000000000000000000000000000000000000000..872f6fec3fdb2a8516f7e3da55dddeb972b358f5 --- /dev/null +++ b/ldap/search.go @@ -0,0 +1,54 @@ +package ldaputil + +import ( + "context" + "time" + + "github.com/cenkalti/backoff" + "gopkg.in/ldap.v2" + + "git.autistici.org/ai3/go-common/clientutil" +) + +// Treat all errors as potential network-level issues, except for a +// whitelist of LDAP protocol level errors that we know are benign. +func isTemporaryLDAPError(err error) bool { + ldapErr, ok := err.(*ldap.Error) + if !ok { + return true + } + switch ldapErr.ResultCode { + case ldap.ErrorNetwork: + return true + default: + return false + } +} + +// Search performs the given search request. It will retry the request +// on temporary errors. +func (p *ConnectionPool) Search(ctx context.Context, searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) { + var result *ldap.SearchResult + err := clientutil.Retry(func() error { + conn, err := p.Get(ctx) + if err != nil { + if isTemporaryLDAPError(err) { + return clientutil.TempError(err) + } + return err + } + + if deadline, ok := ctx.Deadline(); ok { + conn.SetTimeout(time.Until(deadline)) + } + + result, err = conn.Search(searchRequest) + if err != nil && isTemporaryLDAPError(err) { + p.Release(conn, nil) + return clientutil.TempError(err) + } + p.Release(conn, err) + return err + }, backoff.WithContext(clientutil.NewExponentialBackOff(), ctx)) + return result, err +}