// Package client sends authentication requests to the auth server over
// a UNIX socket, using the line-oriented protocol from package lineproto.
package client

import (
	"context"
	"errors"
	"net"
	"strings"

	"github.com/cenkalti/backoff/v4"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/codes"
	"go.opentelemetry.io/otel/trace"

	"git.autistici.org/id/auth"
	"git.autistici.org/id/auth/lineproto"
)

var (
	// DefaultSocketPath is the default path of the auth server's UNIX
	// socket. New does not apply it automatically; callers pass it in.
	DefaultSocketPath = "/run/auth/socket"
)

// Client sends authentication requests to the auth server.
type Client interface {
	Authenticate(context.Context, *auth.Request) (*auth.Response, error)
}

// socketClient is the UNIX-socket implementation of Client. It opens a
// fresh connection for every authentication attempt.
type socketClient struct {
	socketPath string
	codec      auth.Codec
}

// Compile-time check that socketClient satisfies Client.
var _ Client = (*socketClient)(nil)

// New returns a Client that dials the UNIX socket at socketPath for each
// request and uses auth.DefaultCodec to encode requests and decode
// responses.
func New(socketPath string) Client {
	return &socketClient{
		socketPath: socketPath,
		codec:      auth.DefaultCodec,
	}
}

// Authenticate sends req to the auth server and returns its response.
//
// The call is wrapped in a client-kind tracing span annotated with the
// request's service and username. Attempts that fail with an error
// accepted by isRetriable are retried with exponential backoff until ctx
// is done; any other error aborts immediately. On error the returned
// response is nil.
func (c *socketClient) Authenticate(ctx context.Context, req *auth.Request) (*auth.Response, error) {
	// Create a tracing span for the authentication request.
	ctx, span := otel.GetTracerProvider().Tracer("auth-client").Start(
		ctx,
		"auth-server.Authenticate",
		trace.WithSpanKind(trace.SpanKindClient),
		trace.WithAttributes(
			attribute.String("auth.service", req.Service),
			attribute.String("auth.username", req.Username),
		),
	)
	defer span.End()

	var resp *auth.Response
	err := backoff.Retry(func() error {
		var err error
		resp, err = c.doAuthenticate(ctx, req)
		switch {
		case err == nil:
			return nil
		case isRetriable(err):
			return err
		default:
			// Stop retrying; backoff.Retry returns the wrapped error.
			return backoff.Permanent(err)
		}
	}, backoff.WithContext(backoff.NewExponentialBackOff(), ctx))

	setSpanStatus(span, resp, err)
	return resp, err
}

// isRetriable reports whether an error from a single authentication
// attempt should be retried by the backoff loop in Authenticate.
func isRetriable(err error) bool {
	// The string match is kept alongside errors.Is in case an intermediate
	// layer re-wraps the error without %w.
	if errors.Is(err, net.ErrClosed) || strings.Contains(err.Error(), "use of closed network connection") {
		return true
	}
	// errors.As (rather than a plain type assertion) also catches network
	// errors that were wrapped on the way up. net.Error.Temporary is
	// deprecated but is kept to preserve the existing retry policy.
	var netErr net.Error
	return errors.As(err, &netErr) && netErr.Temporary() //nolint:staticcheck
}

// doAuthenticate performs one request/response exchange: it dials the
// socket, writes the "auth " command followed by the encoded request,
// reads a single response line and decodes it.
//
// It returns as soon as either the exchange finishes or ctx is done. On
// any error the returned response is nil.
func (c *socketClient) doAuthenticate(ctx context.Context, req *auth.Request) (*auth.Response, error) {
	var d net.Dialer
	nc, err := d.DialContext(ctx, "unix", c.socketPath)
	if err != nil {
		return nil, err
	}
	conn := lineproto.NewConn(nc, "unix")
	// Closing the connection on return also unblocks the goroutine below
	// when we bail out early because ctx is done.
	defer conn.Close()

	// Make space in the channel for one element so the goroutine can
	// always deliver its result and exit, even after we stopped waiting
	// (otherwise it would leak whenever the request times out).
	done := make(chan error, 1)
	resp := new(auth.Response)
	go func() {
		done <- func() error {
			// Write the auth command to the connection.
			if err := conn.WriteLine([]byte("auth "), c.codec.Encode(req)); err != nil {
				return err
			}
			// Read and decode the response.
			line, err := conn.ReadLine()
			if err != nil {
				return err
			}
			return c.codec.Decode(line, resp)
		}()
	}()

	// Wait for the call to terminate, or the context to time out,
	// whichever happens first.
	select {
	case err := <-done:
		// Receiving from done orders us after the goroutine's last write
		// to resp, so it is safe to hand out.
		if err != nil {
			return nil, err
		}
		return resp, nil
	case <-ctx.Done():
		// The goroutine may still be decoding into resp: never return it.
		return nil, ctx.Err()
	}
}

// setSpanStatus records the outcome of an Authenticate call on span.
// resp is only inspected when err is nil.
func setSpanStatus(span trace.Span, resp *auth.Response, err error) {
	if err == nil {
		switch resp.Status {
		case auth.StatusOK:
			span.SetStatus(codes.Ok, "OK")
		case auth.StatusInsufficientCredentials:
			span.SetStatus(codes.Error, "Insufficient Credentials")
		default:
			span.SetStatus(codes.Error, "Authentication Failure")
		}
		return
	}

	// errors.Is so that wrapped context errors are still classified.
	switch {
	case errors.Is(err, context.Canceled):
		span.SetStatus(codes.Error, "CANCELED")
	case errors.Is(err, context.DeadlineExceeded):
		span.SetStatus(codes.Error, "DEADLINE_EXCEEDED")
	default:
		span.SetStatus(codes.Error, err.Error())
	}
	span.RecordError(err)
}