Skip to content
Snippets Groups Projects
Commit fd18f2d8 authored by ale's avatar ale
Browse files

Do not close connections on protocol errors; add instrumentation

parent 29b9513a
Branches
No related tags found
No related merge requests found
......@@ -8,6 +8,7 @@ import (
"time"
"github.com/cenkalti/backoff"
"github.com/prometheus/client_golang/prometheus"
"go.opencensus.io/trace"
"gopkg.in/ldap.v3"
)
......@@ -46,6 +47,8 @@ type ConnectionPool struct {
var defaultConnectTimeout = 5 * time.Second
func (p *ConnectionPool) connect(ctx context.Context) (*ldap.Conn, error) {
connectionsCounter.Inc()
// Dial the connection with a timeout, if the context has a
// deadline (as it should). If the context does not have a
// deadline, we set a default timeout.
......@@ -56,6 +59,7 @@ func (p *ConnectionPool) connect(ctx context.Context) (*ldap.Conn, error) {
c, err := net.DialTimeout(p.network, p.addr, time.Until(deadline))
if err != nil {
connectionErrors.Inc()
return nil, err
}
......@@ -65,12 +69,13 @@ func (p *ConnectionPool) connect(ctx context.Context) (*ldap.Conn, error) {
if p.bindDN != "" {
conn.SetTimeout(time.Until(deadline))
if _, err = conn.SimpleBind(ldap.NewSimpleBindRequest(p.bindDN, p.bindPw, nil)); err != nil {
connectionErrors.Inc()
conn.Close()
return nil, err
}
}
return conn, err
return conn, nil
}
// Get a fresh connection from the pool.
......@@ -88,7 +93,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*ldap.Conn, error) {
// Release a used connection onto the pool.
func (p *ConnectionPool) Release(conn *ldap.Conn, err error) {
// Connections that failed should not be reused.
if err != nil {
if err != nil && !isProtocolError(err) {
conn.Close()
return
}
......@@ -182,6 +187,10 @@ func (p *ConnectionPool) doRequest(ctx context.Context, name string, attrs []tra
// Tracing: set the final status.
span.SetStatus(errorToTraceStatus(rerr))
requestsCounter.WithLabelValues(name).Inc()
if err != nil {
requestErrors.WithLabelValues(name).Inc()
}
return rerr
}
......@@ -242,6 +251,17 @@ func isTemporaryLDAPError(err error) bool {
}
}
// Return true if the error is protocol-level, i.e. we have not left
// the LDAP connection in a problematic state. This relies on the
// explicit numeric values of the ResultCode attribute in ldap.Error.
func isProtocolError(err error) bool {
if ldapErr, ok := err.(*ldap.Error); ok {
// All protocol-level errors have values < 200.
return ldapErr.ResultCode < ldap.ErrorNetwork
}
return false
}
func errorToTraceStatus(err error) trace.Status {
switch err {
case nil:
......@@ -254,3 +274,22 @@ func errorToTraceStatus(err error) trace.Status {
return trace.Status{Code: trace.StatusCodeUnknown, Message: err.Error()}
}
}
var (
connectionsCounter = prometheus.NewCounter(prometheus.CounterOpts{
Name: "ldap_connections_total",
Help: "Counter of new LDAP connections.",
})
connectionErrors = prometheus.NewCounter(prometheus.CounterOpts{
Name: "ldap_connection_errors_total",
Help: "Counter of LDAP connection errors.",
})
requestsCounter = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "ldap_requests_total",
Help: "Counter of LDAP requests.",
}, []string{"method"})
requestErrors = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "ldap_errors_total",
Help: "Counter of LDAP requests.",
}, []string{"method"})
)
......@@ -71,13 +71,13 @@ func newConnFailDelayServer(t testing.TB) *tcpServer {
func newFakeBindOnlyLDAPServer(t testing.TB) *tcpServer {
return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) {
var b [1024]byte
c.Read(b[:])
c.Read(b[:]) // nolint: errcheck
resp := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
resp.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 1, "MessageID"))
resp.AppendChild(ber.NewSequence("Description"))
c.Write(resp.Bytes())
c.Write(resp.Bytes()) // nolint: errcheck
}))
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment