diff --git a/ldap/pool.go b/ldap/pool.go new file mode 100644 index 0000000000000000000000000000000000000000..908cc9973fa4adcedbdd69e2a2aea921465d777c --- /dev/null +++ b/ldap/pool.go @@ -0,0 +1,113 @@ +package ldaputil + +import ( + "context" + "errors" + "net/url" + + "gopkg.in/ldap.v2" +) + +// ConnectionPool provides a goroutine-safe pool of long-lived LDAP +// connections that will reconnect on errors. +type ConnectionPool struct { + network string + addr string + bindDN string + bindPw string + + c chan *ldap.Conn +} + +func (p *ConnectionPool) connect() (*ldap.Conn, error) { + conn, err := ldap.Dial(p.network, p.addr) + if err != nil { + return nil, err + } + + if err = conn.Bind(p.bindDN, p.bindPw); err != nil { + conn.Close() + return nil, err + } + + return conn, err +} + +// Get a fresh connection from the pool. +func (p *ConnectionPool) Get(ctx context.Context) (*ldap.Conn, error) { + select { + case conn := <-p.c: + return conn, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// Release a used connection onto the pool. +func (p *ConnectionPool) Release(conn *ldap.Conn, err error) { + // We assume that if we get an ErrorNetwork, then we need to reconnect. + for err != nil && ldap.IsErrorWithCode(err, ldap.ErrorNetwork) { + if conn != nil { + conn.Close() + } + conn, err = p.connect() + } + p.c <- conn +} + +// Close all connections. Not implemented yet. +func (p *ConnectionPool) Close() {} + +// Parse a LDAP URI into network and address strings suitable for +// ldap.Dial. +func parseLDAPURI(uri string) (string, string, error) { + u, err := url.Parse(uri) + if err != nil { + return "", "", err + } + + network := "tcp" + addr := "localhost:389" + switch u.Scheme { + case "ldap": + if u.Host != "" { + addr = u.Host + } + case "ldapi": + network = "unix" + addr = u.Path + default: + return "", "", errors.New("unsupported scheme") + } + + return network, addr, nil +} + +// NewConnectionPool creates a new pool of LDAP connections to the +// specified server, using the provided bind credentials. The pool +// will contain numConns connections. +func NewConnectionPool(uri, bindDN, bindPw string, numConns int) (*ConnectionPool, error) { + network, addr, err := parseLDAPURI(uri) + if err != nil { + return nil, err + } + + p := &ConnectionPool{ + c: make(chan *ldap.Conn, numConns), + network: network, + addr: addr, + bindDN: bindDN, + bindPw: bindPw, + } + + for i := 0; i < numConns; i++ { + conn, err := p.connect() + if err != nil { + p.Close() + return nil, err + } + p.c <- conn + } + + return p, nil +}