package client

import (
	"context"
	"errors"
	"net"
	"strings"

	"github.com/cenkalti/backoff/v4"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/codes"
	"go.opentelemetry.io/otel/trace"

	"git.autistici.org/id/auth"
	"git.autistici.org/id/auth/lineproto"
)

// DefaultSocketPath is the default filesystem location of the auth
// server's UNIX socket.
var DefaultSocketPath = "/run/auth/socket"

// Client sends authentication requests to an auth server.
type Client interface {
	// Authenticate submits req and returns the server's response.
	Authenticate(ctx context.Context, req *auth.Request) (*auth.Response, error)
}

// socketClient is the Client implementation returned by New. It dials a
// fresh connection to the server socket for every request.
type socketClient struct {
	socketPath string     // filesystem path of the server's UNIX socket
	codec      auth.Codec // encodes requests and decodes responses
}

// New returns a Client that talks to the auth server over the UNIX
// socket at socketPath, encoding messages with auth.DefaultCodec.
func New(socketPath string) Client {
	c := &socketClient{codec: auth.DefaultCodec}
	c.socketPath = socketPath
	return c
}

// Authenticate sends req to the auth server and returns its response.
//
// The call is wrapped in a client-side tracing span tagged with the
// request's service and username. Attempts that fail because the
// connection was closed, or with a network error flagged as temporary,
// are retried with exponential backoff until ctx is done or the backoff
// policy gives up; any other error aborts immediately. On error the
// returned response is always nil.
func (c *socketClient) Authenticate(ctx context.Context, req *auth.Request) (*auth.Response, error) {
	// Create a tracing span for the authentication request.
	ctx, span := otel.GetTracerProvider().Tracer("auth-client").Start(
		ctx, "auth-server.Authenticate", trace.WithSpanKind(trace.SpanKindClient), trace.WithAttributes(
			attribute.String("auth.service", req.Service),
			attribute.String("auth.username", req.Username),
		))
	defer span.End()

	var resp *auth.Response
	err := backoff.Retry(func() error {
		var err error
		resp, err = c.doAuthenticate(ctx, req)
		if err == nil || isRetriableError(err) {
			return err
		}
		return backoff.Permanent(err)
	}, backoff.WithContext(backoff.NewExponentialBackOff(), ctx))
	if err != nil {
		// Never hand back a response left over from a failed attempt.
		resp = nil
	}

	setSpanStatus(span, resp, err)

	return resp, err
}

// isRetriableError reports whether an error from doAuthenticate is worth
// retrying: the connection was closed underneath us, or the network layer
// reports the failure as temporary.
func isRetriableError(err error) bool {
	// The string match is kept as a fallback for errors that carry the
	// message without wrapping net.ErrClosed.
	if errors.Is(err, net.ErrClosed) || strings.Contains(err.Error(), "use of closed network connection") {
		return true
	}
	var netErr net.Error
	// Temporary is deprecated, but it is the classification this client
	// has always retried on; errors.As also finds wrapped net.Errors.
	//lint:ignore SA1019 preserving existing retry semantics
	return errors.As(err, &netErr) && netErr.Temporary()
}

// doAuthenticate performs a single request/response exchange with the
// server. It dials the socket (honoring ctx), writes an "auth " command
// followed by the encoded request, then reads and decodes one response
// line.
//
// The exchange runs in its own goroutine so it can be abandoned when ctx
// is done. On any error, including ctx expiry, the returned response is nil.
func (c *socketClient) doAuthenticate(ctx context.Context, req *auth.Request) (*auth.Response, error) {
	var d net.Dialer
	nc, err := d.DialContext(ctx, "unix", c.socketPath)
	if err != nil {
		return nil, err
	}
	conn := lineproto.NewConn(nc, "unix")
	// NOTE(review): presumably closing conn also closes nc, which unblocks
	// an abandoned goroutine stuck in WriteLine/ReadLine — confirm in lineproto.
	defer conn.Close()

	// Buffered so the goroutine can always deliver its single result and
	// exit, even after we have stopped waiting for it (no goroutine leak).
	done := make(chan error, 1)

	// resp is written only by the goroutine. It is read below only after
	// receiving from done, which orders the goroutine's writes before the
	// read. On the ctx.Done() path it is never touched, avoiding a data race.
	var resp auth.Response
	go func() {
		// Write the auth command to the connection.
		if err := conn.WriteLine([]byte("auth "), c.codec.Encode(req)); err != nil {
			done <- err
			return
		}

		// Read and decode the response.
		line, err := conn.ReadLine()
		if err != nil {
			done <- err
			return
		}
		if err := c.codec.Decode(line, &resp); err != nil {
			done <- err
			return
		}
		done <- nil
	}()

	// Wait for the call to terminate, or the context to be done,
	// whichever happens first.
	select {
	case err := <-done:
		if err != nil {
			return nil, err
		}
		return &resp, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// setSpanStatus records the outcome of an authentication call on span.
//
// When err is nil the server answered and resp must be non-nil; the span
// status then reflects resp.Status (only auth.StatusOK maps to codes.Ok).
// When err is non-nil, cancellation and deadline expiry get fixed
// descriptions and any other error uses its own message; the error is
// also recorded on the span. errors.Is is used so that wrapped context
// errors (e.g. from a dial) are still classified correctly.
func setSpanStatus(span trace.Span, resp *auth.Response, err error) {
	if err == nil {
		switch resp.Status {
		case auth.StatusOK:
			span.SetStatus(codes.Ok, "OK")
		case auth.StatusInsufficientCredentials:
			span.SetStatus(codes.Error, "Insufficient Credentials")
		default:
			span.SetStatus(codes.Error, "Authentication Failure")
		}
		return
	}

	switch {
	case errors.Is(err, context.Canceled):
		span.SetStatus(codes.Error, "CANCELED")
	case errors.Is(err, context.DeadlineExceeded):
		span.SetStatus(codes.Error, "DEADLINE_EXCEEDED")
	default:
		span.SetStatus(codes.Error, err.Error())
	}

	span.RecordError(err)
}