Skip to content
Snippets Groups Projects
Commit 30d27a52 authored by ale's avatar ale
Browse files

Update ldap pkg to use the opentelemetry API

parent 01b14cb7
Branches
No related tags found
No related merge requests found
......@@ -7,10 +7,13 @@ import (
"net/url"
"time"
"git.autistici.org/ai3/go-common/tracing"
"github.com/cenkalti/backoff/v4"
"github.com/go-ldap/ldap/v3"
"github.com/prometheus/client_golang/prometheus"
"go.opencensus.io/trace"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// Parameters that define the exponential backoff algorithm used.
......@@ -153,18 +156,20 @@ func NewConnectionPool(uri, bindDN, bindPw string, cacheSize int) (*ConnectionPo
}, nil
}
func (p *ConnectionPool) doRequest(ctx context.Context, name string, attrs []trace.Attribute, fn func(*ldap.Conn) error) error {
func (p *ConnectionPool) doRequest(ctx context.Context, name string, attrs []attribute.KeyValue, fn func(*ldap.Conn) error) error {
// Tracing: initialize a new client span.
sctx, span := trace.StartSpan(ctx, name,
trace.WithSpanKind(trace.SpanKindClient))
defer span.End()
var span trace.Span
if tracing.Enabled {
ctx, span = tracing.Tracer.Start(ctx, name)
defer span.End()
if len(attrs) > 0 {
span.AddAttributes(attrs...)
if len(attrs) > 0 {
span.SetAttributes(attrs...)
}
}
rerr := backoff.Retry(func() error {
conn, err := p.Get(sctx)
conn, err := p.Get(ctx)
if err != nil {
// Here conn is nil, so we don't need to Release it.
if isTemporaryLDAPError(err) {
......@@ -173,7 +178,7 @@ func (p *ConnectionPool) doRequest(ctx context.Context, name string, attrs []tra
return backoff.Permanent(err)
}
if deadline, ok := sctx.Deadline(); ok {
if deadline, ok := ctx.Deadline(); ok {
conn.SetTimeout(time.Until(deadline))
}
......@@ -186,7 +191,9 @@ func (p *ConnectionPool) doRequest(ctx context.Context, name string, attrs []tra
}, backoff.WithContext(newExponentialBackOff(), ctx))
// Tracing: set the final status.
span.SetStatus(errorToTraceStatus(rerr))
if span != nil {
span.SetStatus(errorToTraceStatus(rerr))
}
requestsCounter.WithLabelValues(name).Inc()
if rerr != nil {
requestErrors.WithLabelValues(name).Inc()
......@@ -199,10 +206,10 @@ func (p *ConnectionPool) doRequest(ctx context.Context, name string, attrs []tra
// on temporary errors.
func (p *ConnectionPool) Search(ctx context.Context, searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) {
var result *ldap.SearchResult
err := p.doRequest(ctx, "ldap.Search", []trace.Attribute{
trace.StringAttribute("ldap.base", searchRequest.BaseDN),
trace.StringAttribute("ldap.filter", searchRequest.Filter),
trace.Int64Attribute("ldap.scope", int64(searchRequest.Scope)),
err := p.doRequest(ctx, "ldap.Search", []attribute.KeyValue{
attribute.String("ldap.base", searchRequest.BaseDN),
attribute.String("ldap.filter", searchRequest.Filter),
attribute.Int64("ldap.scope", int64(searchRequest.Scope)),
}, func(conn *ldap.Conn) (cerr error) {
result, cerr = conn.Search(searchRequest)
return
......@@ -212,8 +219,8 @@ func (p *ConnectionPool) Search(ctx context.Context, searchRequest *ldap.SearchR
// Modify issues a ModifyRequest to the LDAP server.
func (p *ConnectionPool) Modify(ctx context.Context, modifyRequest *ldap.ModifyRequest) error {
return p.doRequest(ctx, "ldap.Modify", []trace.Attribute{
trace.StringAttribute("ldap.dn", modifyRequest.DN),
return p.doRequest(ctx, "ldap.Modify", []attribute.KeyValue{
attribute.String("ldap.dn", modifyRequest.DN),
}, func(conn *ldap.Conn) error {
return conn.Modify(modifyRequest)
})
......@@ -221,8 +228,8 @@ func (p *ConnectionPool) Modify(ctx context.Context, modifyRequest *ldap.ModifyR
// Add issues an AddRequest to the LDAP server.
func (p *ConnectionPool) Add(ctx context.Context, addRequest *ldap.AddRequest) error {
return p.doRequest(ctx, "ldap.Add", []trace.Attribute{
trace.StringAttribute("ldap.dn", addRequest.DN),
return p.doRequest(ctx, "ldap.Add", []attribute.KeyValue{
attribute.String("ldap.dn", addRequest.DN),
}, func(conn *ldap.Conn) error {
return conn.Add(addRequest)
})
......@@ -262,16 +269,16 @@ func isProtocolError(err error) bool {
return false
}
func errorToTraceStatus(err error) trace.Status {
func errorToTraceStatus(err error) (codes.Code, string) {
switch err {
case nil:
return trace.Status{Code: trace.StatusCodeOK, Message: "OK"}
return codes.Ok, "OK"
case context.Canceled:
return trace.Status{Code: trace.StatusCodeCancelled, Message: "CANCELED"}
return codes.Error, "CANCELED"
case context.DeadlineExceeded:
return trace.Status{Code: trace.StatusCodeDeadlineExceeded, Message: "DEADLINE_EXCEEDED"}
return codes.Error, "DEADLINE_EXCEEDED"
default:
return trace.Status{Code: trace.StatusCodeUnknown, Message: err.Error()}
return codes.Error, err.Error()
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment