diff --git a/ldap/pool.go b/ldap/pool.go index c77d06177cb1b13ed8c7b71bc3facf07a015adc4..560d639b7091334e83d290b5c0ac1ff56b1261ee 100644 --- a/ldap/pool.go +++ b/ldap/pool.go @@ -7,6 +7,8 @@ import ( "net/url" "time" + "git.autistici.org/ai3/go-common/clientutil" + "github.com/cenkalti/backoff" "gopkg.in/ldap.v2" ) @@ -125,3 +127,77 @@ func NewConnectionPool(uri, bindDN, bindPw string, cacheSize int) (*ConnectionPo bindPw: bindPw, }, nil } + +func (p *ConnectionPool) doRequest(ctx context.Context, fn func(*ldap.Conn) error) error { + return clientutil.Retry(func() error { + conn, err := p.Get(ctx) + if err != nil { + // Here conn is nil, so we don't need to Release it. + if isTemporaryLDAPError(err) { + return clientutil.TempError(err) + } + return err + } + + if deadline, ok := ctx.Deadline(); ok { + conn.SetTimeout(time.Until(deadline)) + } + + err = fn(conn) + if err != nil && isTemporaryLDAPError(err) { + p.Release(conn, err) + return clientutil.TempError(err) + } + p.Release(conn, err) + return err + }, backoff.WithContext(clientutil.NewExponentialBackOff(), ctx)) +} + +// Search performs the given search request. It will retry the request +// on temporary errors. +func (p *ConnectionPool) Search(ctx context.Context, searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) { + var result *ldap.SearchResult + err := p.doRequest(ctx, func(conn *ldap.Conn) error { + var err error + result, err = conn.Search(searchRequest) + return err + }) + return result, err +} + +// Modify issues a ModifyRequest to the LDAP server. +func (p *ConnectionPool) Modify(ctx context.Context, modifyRequest *ldap.ModifyRequest) error { + return p.doRequest(ctx, func(conn *ldap.Conn) error { + return conn.Modify(modifyRequest) + }) +} + +// Add issues an AddRequest to the LDAP server. +func (p *ConnectionPool) Add(ctx context.Context, addRequest *ldap.AddRequest) error { + return p.doRequest(ctx, func(conn *ldap.Conn) error { + return conn.Add(addRequest) + }) +} + +// Interface matched by net.Error. +type hasTemporary interface { + Temporary() bool +} + +// Treat network errors as temporary. Other errors are permanent by +// default. +func isTemporaryLDAPError(err error) bool { + switch v := err.(type) { + case *ldap.Error: + switch v.ResultCode { + case ldap.ErrorNetwork: + return true + default: + return false + } + case hasTemporary: + return v.Temporary() + default: + return false + } +} diff --git a/ldap/search.go b/ldap/search.go deleted file mode 100644 index db29ba092ea0cb1e1be3db3a753c75c7e1188f3f..0000000000000000000000000000000000000000 --- a/ldap/search.go +++ /dev/null @@ -1,63 +0,0 @@ -package ldaputil - -import ( - "context" - "time" - - "github.com/cenkalti/backoff" - "gopkg.in/ldap.v2" - - "git.autistici.org/ai3/go-common/clientutil" -) - -// Interface matched by net.Error. -type hasTemporary interface { - Temporary() bool -} - -// Treat network errors as temporary. Other errors are permanent by -// default. -func isTemporaryLDAPError(err error) bool { - switch v := err.(type) { - case *ldap.Error: - switch v.ResultCode { - case ldap.ErrorNetwork: - return true - default: - return false - } - case hasTemporary: - return v.Temporary() - default: - return false - } -} - -// Search performs the given search request. It will retry the request -// on temporary errors. -func (p *ConnectionPool) Search(ctx context.Context, searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) { - var result *ldap.SearchResult - err := clientutil.Retry(func() error { - conn, err := p.Get(ctx) - if err != nil { - // Here conn is nil, so we don't need to Release it. - if isTemporaryLDAPError(err) { - return clientutil.TempError(err) - } - return err - } - - if deadline, ok := ctx.Deadline(); ok { - conn.SetTimeout(time.Until(deadline)) - } - - result, err = conn.Search(searchRequest) - if err != nil && isTemporaryLDAPError(err) { - p.Release(conn, err) - return clientutil.TempError(err) - } - p.Release(conn, err) - return err - }, backoff.WithContext(clientutil.NewExponentialBackOff(), ctx)) - return result, err -}