Skip to content
Snippets Groups Projects
Commit 41657e75 authored by ale's avatar ale
Browse files

Add common LDAP connection pooling code

parent 52b3a6dd
No related branches found
No related tags found
No related merge requests found
package ldaputil
import (
"context"
"errors"
"net/url"
"gopkg.in/ldap.v2"
)
// ConnectionPool provides a goroutine-safe pool of long-lived LDAP
// connections that will reconnect on errors.
type ConnectionPool struct {
network string
addr string
bindDN string
bindPw string
c chan *ldap.Conn
}
func (p *ConnectionPool) connect() (*ldap.Conn, error) {
conn, err := ldap.Dial(p.network, p.addr)
if err != nil {
return nil, err
}
if err = conn.Bind(p.bindDN, p.bindPw); err != nil {
conn.Close()
return nil, err
}
return conn, err
}
// Get a fresh connection from the pool.
func (p *ConnectionPool) Get(ctx context.Context) (*ldap.Conn, error) {
select {
case conn := <-p.c:
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
// Release a used connection onto the pool.
func (p *ConnectionPool) Release(conn *ldap.Conn, err error) {
// We assume that if we get an ErrorNetwork, then we need to reconnect.
for err != nil && ldap.IsErrorWithCode(err, ldap.ErrorNetwork) {
if conn != nil {
conn.Close()
}
conn, err = p.connect()
}
p.c <- conn
}
// Close all connections. Not implemented yet.
func (p *ConnectionPool) Close() {}
// Parse a LDAP URI into network and address strings suitable for
// ldap.Dial.
func parseLDAPURI(uri string) (string, string, error) {
u, err := url.Parse(uri)
if err != nil {
return "", "", err
}
network := "tcp"
addr := "localhost:389"
switch u.Scheme {
case "ldap":
if u.Host != "" {
addr = u.Host
}
case "ldapi":
network = "unix"
addr = u.Path
default:
return "", "", errors.New("unsupported scheme")
}
return network, addr, nil
}
// NewConnectionPool creates a new pool of LDAP connections to the
// specified server, using the provided bind credentials. The pool
// will contain numConns connections.
func NewConnectionPool(uri, bindDN, bindPw string, numConns int) (*ConnectionPool, error) {
network, addr, err := parseLDAPURI(uri)
if err != nil {
return nil, err
}
p := &ConnectionPool{
c: make(chan *ldap.Conn, numConns),
network: network,
addr: addr,
bindDN: bindDN,
bindPw: bindPw,
}
for i := 0; i < numConns; i++ {
conn, err := p.connect()
if err != nil {
p.Close()
return nil, err
}
p.c <- conn
}
return p, nil
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment