Commit c856639d authored by ale's avatar ale
Browse files

Update ldap.v2 and ai3/go-common deps

parent 6fb20c24
......@@ -5,7 +5,6 @@ import (
"errors"
"io/ioutil"
"strings"
"time"
ldaputil "git.autistici.org/ai3/go-common/ldap"
"gopkg.in/ldap.v2"
......@@ -229,18 +228,7 @@ func (b *ldapBackend) GetUser(ctx context.Context, spec *BackendSpec, name strin
return nil, false
}
conn, err := b.pool.Get(ctx)
if err != nil {
return nil, false
}
// Try to turn the context deadline into a LDAP connection timeout...
if deadline, ok := ctx.Deadline(); ok {
conn.SetTimeout(time.Until(deadline))
}
result, err := conn.Search(serviceConfig.searchRequest(name))
b.pool.Release(conn, err)
result, err := b.pool.Search(ctx, serviceConfig.searchRequest(name))
if err != nil {
return nil, false
}
......
......@@ -2,7 +2,6 @@ package clientutil
import (
"errors"
"net"
"net/http"
"time"
......@@ -18,14 +17,37 @@ func NewExponentialBackOff() *backoff.ExponentialBackOff {
return b
}
// Retry operation op until it succeeds according to the backoff policy b.
// A temporary (retriable) error is something that has a Temporary method.
type tempError interface {
Temporary() bool
}
type tempErrorWrapper struct {
error
}
func (t tempErrorWrapper) Temporary() bool { return true }
// TempError makes a temporary (retriable) error out of a normal error.
func TempError(err error) error {
return tempErrorWrapper{err}
}
// Retry operation op until it succeeds according to the backoff
// policy b.
//
// Note that this function reverses the error semantics of
// backoff.Operation: all errors are permanent unless explicitly
// marked as temporary (i.e. they have a Temporary() method that
// returns true). This is to better align with the errors returned by
// the net package.
func Retry(op backoff.Operation, b backoff.BackOff) error {
innerOp := func() error {
err := op()
if err == nil {
return err
}
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
if tmpErr, ok := err.(tempError); ok && tmpErr.Temporary() {
return err
}
return backoff.Permanent(err)
......@@ -33,7 +55,7 @@ func Retry(op backoff.Operation, b backoff.BackOff) error {
return backoff.Retry(innerOp, b)
}
var errHTTPBackOff = errors.New("temporary http error")
var errHTTPBackOff = TempError(errors.New("temporary http error"))
func isStatusTemporary(code int) bool {
switch code {
......@@ -46,7 +68,8 @@ func isStatusTemporary(code int) bool {
// RetryHTTPDo retries an HTTP request until it succeeds, according to
// the backoff policy b. It will retry on temporary network errors and
// upon receiving specific temporary HTTP errors.
// upon receiving specific temporary HTTP errors. It will use the
// context associated with the HTTP request object.
func RetryHTTPDo(client *http.Client, req *http.Request, b backoff.BackOff) (*http.Response, error) {
var resp *http.Response
op := func() error {
......@@ -64,6 +87,6 @@ func RetryHTTPDo(client *http.Client, req *http.Request, b backoff.BackOff) (*ht
return err
}
err := Retry(op, b)
err := Retry(op, backoff.WithContext(b, req.Context()))
return resp, err
}
......@@ -3,7 +3,9 @@ package ldaputil
import (
"context"
"errors"
"net"
"net/url"
"time"
"gopkg.in/ldap.v2"
)
......@@ -19,13 +21,27 @@ type ConnectionPool struct {
c chan *ldap.Conn
}
func (p *ConnectionPool) connect() (*ldap.Conn, error) {
conn, err := ldap.Dial(p.network, p.addr)
var defaultConnectTimeout = 5 * time.Second
func (p *ConnectionPool) connect(ctx context.Context) (*ldap.Conn, error) {
// Dial the connection with a timeout, if the context has a
// deadline (as it should). If the context does not have a
// deadline, we set a default timeout.
deadline, ok := ctx.Deadline()
if !ok {
deadline = time.Now().Add(defaultConnectTimeout)
}
c, err := net.DialTimeout(p.network, p.addr, time.Until(deadline))
if err != nil {
return nil, err
}
if err = conn.Bind(p.bindDN, p.bindPw); err != nil {
conn := ldap.NewConn(c, false)
conn.Start()
conn.SetTimeout(time.Until(deadline))
if _, err = conn.SimpleBind(ldap.NewSimpleBindRequest(p.bindDN, p.bindPw, nil)); err != nil {
conn.Close()
return nil, err
}
......@@ -35,24 +51,31 @@ func (p *ConnectionPool) connect() (*ldap.Conn, error) {
// Get a fresh connection from the pool.
func (p *ConnectionPool) Get(ctx context.Context) (*ldap.Conn, error) {
// Grab a connection from the cache, or create a new one if
// there are no available connections.
select {
case conn := <-p.c:
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
default:
return p.connect(ctx)
}
}
// Release a used connection onto the pool.
func (p *ConnectionPool) Release(conn *ldap.Conn, err error) {
// We assume that if we get an ErrorNetwork, then we need to reconnect.
for err != nil && ldap.IsErrorWithCode(err, ldap.ErrorNetwork) {
if conn != nil {
conn.Close()
}
conn, err = p.connect()
// Connections that failed should not be reused.
if err != nil {
conn.Close()
return
}
// Return the connection to the cache, or close it if it's
// full.
select {
case p.c <- conn:
default:
conn.Close()
}
p.c <- conn
}
// Close all connections. Not implemented yet.
......@@ -85,29 +108,18 @@ func parseLDAPURI(uri string) (string, string, error) {
// NewConnectionPool creates a new pool of LDAP connections to the
// specified server, using the provided bind credentials. The pool
// will contain numConns connections.
func NewConnectionPool(uri, bindDN, bindPw string, numConns int) (*ConnectionPool, error) {
// will cache at most cacheSize connections.
func NewConnectionPool(uri, bindDN, bindPw string, cacheSize int) (*ConnectionPool, error) {
network, addr, err := parseLDAPURI(uri)
if err != nil {
return nil, err
}
p := &ConnectionPool{
c: make(chan *ldap.Conn, numConns),
return &ConnectionPool{
c: make(chan *ldap.Conn, cacheSize),
network: network,
addr: addr,
bindDN: bindDN,
bindPw: bindPw,
}
for i := 0; i < numConns; i++ {
conn, err := p.connect()
if err != nil {
p.Close()
return nil, err
}
p.c <- conn
}
return p, nil
}, nil
}
package ldaputil
import (
"context"
"time"
"github.com/cenkalti/backoff"
"gopkg.in/ldap.v2"
"git.autistici.org/ai3/go-common/clientutil"
)
// Treat all errors as potential network-level issues, except for a
// whitelist of LDAP protocol level errors that we know are benign.
func isTemporaryLDAPError(err error) bool {
ldapErr, ok := err.(*ldap.Error)
if !ok {
return true
}
switch ldapErr.ResultCode {
case ldap.ErrorNetwork:
return true
default:
return false
}
}
// Search performs the given search request. It will retry the request
// on temporary errors.
func (p *ConnectionPool) Search(ctx context.Context, searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) {
var result *ldap.SearchResult
err := clientutil.Retry(func() error {
conn, err := p.Get(ctx)
if err != nil {
if isTemporaryLDAPError(err) {
return clientutil.TempError(err)
}
return err
}
if deadline, ok := ctx.Deadline(); ok {
conn.SetTimeout(time.Until(deadline))
}
result, err = conn.Search(searchRequest)
if err != nil && isTemporaryLDAPError(err) {
p.Release(conn, nil)
return clientutil.TempError(err)
}
p.Release(conn, err)
return err
}, backoff.WithContext(clientutil.NewExponentialBackOff(), ctx))
return result, err
}
.PHONY: default install build test quicktest fmt vet lint
GO_VERSION := $(shell go version | cut -d' ' -f3 | cut -d. -f2)
# Only use the `-race` flag on newer versions of Go
IS_OLD_GO := $(shell test $(GO_VERSION) -le 2 && echo true)
ifeq ($(IS_OLD_GO),true)
RACE_FLAG :=
else
RACE_FLAG := -race -cpu 1,2,4
endif
default: fmt vet lint build quicktest
install:
......@@ -9,7 +19,7 @@ build:
go build -v ./...
test:
go test -v -cover ./...
go test -v $(RACE_FLAG) -cover ./...
quicktest:
go test ./...
......
// +build go1.4
package ldap
import (
"sync/atomic"
)
// For compilers that support it, we just use the underlying sync/atomic.Value
// type.
type atomicValue struct {
atomic.Value
}
// +build !go1.4
package ldap
import (
"sync"
)
// This is a helper type that emulates the use of the "sync/atomic.Value"
// struct that's available in Go 1.4 and up.
type atomicValue struct {
value interface{}
lock sync.RWMutex
}
func (av *atomicValue) Store(val interface{}) {
av.lock.Lock()
av.value = val
av.lock.Unlock()
}
func (av *atomicValue) Load() interface{} {
av.lock.RLock()
ret := av.value
av.lock.RUnlock()
return ret
}
......@@ -11,6 +11,7 @@ import (
"log"
"net"
"sync"
"sync/atomic"
"time"
"gopkg.in/asn1-ber.v1"
......@@ -82,20 +83,18 @@ const (
type Conn struct {
conn net.Conn
isTLS bool
isClosing bool
closeErr error
closing uint32
closeErr atomicValue
isStartingTLS bool
Debug debugging
chanConfirm chan bool
chanConfirm chan struct{}
messageContexts map[int64]*messageContext
chanMessage chan *messagePacket
chanMessageID chan int64
wgSender sync.WaitGroup
wgClose sync.WaitGroup
once sync.Once
outstandingRequests uint
messageMutex sync.Mutex
requestTimeout time.Duration
requestTimeout int64
}
var _ Client = &Conn{}
......@@ -142,7 +141,7 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
func NewConn(conn net.Conn, isTLS bool) *Conn {
return &Conn{
conn: conn,
chanConfirm: make(chan bool),
chanConfirm: make(chan struct{}),
chanMessageID: make(chan int64),
chanMessage: make(chan *messagePacket, 10),
messageContexts: map[int64]*messageContext{},
......@@ -158,12 +157,22 @@ func (l *Conn) Start() {
l.wgClose.Add(1)
}
// isClosing returns whether or not we're currently closing.
func (l *Conn) isClosing() bool {
return atomic.LoadUint32(&l.closing) == 1
}
// setClosing sets the closing value to true
func (l *Conn) setClosing() bool {
return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
}
// Close closes the connection.
func (l *Conn) Close() {
l.once.Do(func() {
l.isClosing = true
l.wgSender.Wait()
l.messageMutex.Lock()
defer l.messageMutex.Unlock()
if l.setClosing() {
l.Debug.Printf("Sending quit message and waiting for confirmation")
l.chanMessage <- &messagePacket{Op: MessageQuit}
<-l.chanConfirm
......@@ -171,27 +180,25 @@ func (l *Conn) Close() {
l.Debug.Printf("Closing network connection")
if err := l.conn.Close(); err != nil {
log.Print(err)
log.Println(err)
}
l.wgClose.Done()
})
}
l.wgClose.Wait()
}
// SetTimeout sets the time after a request is sent that a MessageTimeout triggers
func (l *Conn) SetTimeout(timeout time.Duration) {
if timeout > 0 {
l.requestTimeout = timeout
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
}
}
// Returns the next available messageID
func (l *Conn) nextMessageID() int64 {
if l.chanMessageID != nil {
if messageID, ok := <-l.chanMessageID; ok {
return messageID
}
if messageID, ok := <-l.chanMessageID; ok {
return messageID
}
return 0
}
......@@ -258,7 +265,7 @@ func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
}
func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
if l.isClosing {
if l.isClosing() {
return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
}
l.messageMutex.Lock()
......@@ -297,7 +304,7 @@ func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags)
func (l *Conn) finishMessage(msgCtx *messageContext) {
close(msgCtx.done)
if l.isClosing {
if l.isClosing() {
return
}
......@@ -316,12 +323,12 @@ func (l *Conn) finishMessage(msgCtx *messageContext) {
}
func (l *Conn) sendProcessMessage(message *messagePacket) bool {
if l.isClosing {
l.messageMutex.Lock()
defer l.messageMutex.Unlock()
if l.isClosing() {
return false
}
l.wgSender.Add(1)
l.chanMessage <- message
l.wgSender.Done()
return true
}
......@@ -333,15 +340,14 @@ func (l *Conn) processMessages() {
for messageID, msgCtx := range l.messageContexts {
// If we are closing due to an error, inform anyone who
// is waiting about the error.
if l.isClosing && l.closeErr != nil {
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr})
if l.isClosing() && l.closeErr.Load() != nil {
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
}
l.Debug.Printf("Closing channel for MessageID %d", messageID)
close(msgCtx.responses)
delete(l.messageContexts, messageID)
}
close(l.chanMessageID)
l.chanConfirm <- true
close(l.chanConfirm)
}()
......@@ -350,11 +356,7 @@ func (l *Conn) processMessages() {
select {
case l.chanMessageID <- messageID:
messageID++
case message, ok := <-l.chanMessage:
if !ok {
l.Debug.Printf("Shutting down - message channel is closed")
return
}
case message := <-l.chanMessage:
switch message.Op {
case MessageQuit:
l.Debug.Printf("Shutting down - quit message received")
......@@ -377,14 +379,15 @@ func (l *Conn) processMessages() {
l.messageContexts[message.MessageID] = message.Context
// Add timeout if defined
if l.requestTimeout > 0 {
requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
if requestTimeout > 0 {
go func() {
defer func() {
if err := recover(); err != nil {
log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
}
}()
time.Sleep(l.requestTimeout)
time.Sleep(requestTimeout)
timeoutMessage := &messagePacket{
Op: MessageTimeout,
MessageID: message.MessageID,
......@@ -397,7 +400,7 @@ func (l *Conn) processMessages() {
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
} else {
log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing)
log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing())
ber.PrintPacket(message.Packet)
}
case MessageTimeout:
......@@ -439,8 +442,8 @@ func (l *Conn) reader() {
packet, err := ber.ReadPacket(l.conn)
if err != nil {
// A read error is expected here if we are closing the connection...
if !l.isClosing {
l.closeErr = fmt.Errorf("unable to read LDAP response packet: %s", err)
if !l.isClosing() {
l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
l.Debug.Printf("reader error: %s", err.Error())
}
return
......
......@@ -6,7 +6,7 @@ import (
"gopkg.in/asn1-ber.v1"
)
// debbuging type
// debugging type
// - has a Printf method to write the debug output
type debugging bool
......
......@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// File contains DN parsing functionallity
// File contains DN parsing functionality
//
// https://tools.ietf.org/html/rfc4514
//
......@@ -52,7 +52,7 @@ import (
"fmt"
"strings"
ber "gopkg.in/asn1-ber.v1"
"gopkg.in/asn1-ber.v1"
)
// AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514
......@@ -143,6 +143,9 @@ func ParseDN(str string) (*DN, error) {
}
} else if char == ',' || char == '+' {
// We're done with this RDN or value, push it
if len(attribute.Type) == 0 {
return nil, errors.New("incomplete type, value pair")
}
attribute.Value = stringFromBuffer()
rdn.Attributes = append(rdn.Attributes, attribute)
attribute = new(AttributeTypeAndValue)
......
......@@ -97,6 +97,13 @@ var LDAPResultCodeMap = map[uint8]string{
LDAPResultObjectClassModsProhibited: "Object Class Mods Prohibited",
LDAPResultAffectsMultipleDSAs: "Affects Multiple DSAs",
LDAPResultOther: "Other",
ErrorNetwork: "Network Error",
ErrorFilterCompile: "Filter Compile Error",
ErrorFilterDecompile: "Filter Decompile Error",
ErrorDebugging: "Debugging Error",
ErrorUnexpectedMessage: "Unexpected Message",
ErrorUnexpectedResponse: "Unexpected Response",
}
func getLDAPResultCode(packet *ber.Packet) (code uint8, description string) {
......
......@@ -82,7 +82,10 @@ func CompileFilter(filter string) (*ber.Packet, error) {
if err != nil {
return nil, err
}
if pos != len(filter) {
switch {
case pos > len(filter):
return nil, NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
case pos < len(filter):
return nil, NewError(ErrorFilterCompile, errors.New("ldap: finished compiling filter with extra at end: "+fmt.Sprint(filter[pos:])))
}
return packet, nil
......
......@@ -9,7 +9,7 @@ import (
"io/ioutil"
"os"
ber "gopkg.in/asn1-ber.v1"
"gopkg.in/asn1-ber.v1"
)