package unix

import (
	"context"
	"fmt"
	"io/ioutil"
	"net/textproto"
	"os"
	"path/filepath"
	"testing"
)

// fakeServer is a stateless test handler: whatever line it receives, it
// ignores the request and answers with the fixed reply "hello".
type fakeServer struct{}

// ServeLine writes the "hello" reply through lw.WriteLineCRLF and returns
// whatever error that write reports. The context and the request bytes are
// not used.
func (*fakeServer) ServeLine(_ context.Context, lw LineResponseWriter, _ []byte) error {
	reply := []byte("hello")
	return lw.WriteLineCRLF(reply)
}

// doRequests dials the UNIX socket at socketPath with net/textproto and
// performs n sequential round trips over that single connection. Each trip
// sends the line "request" and reads back one reply line. The reply's
// content is not inspected.
//
// It returns the first error encountered, or nil once all n round trips
// complete. A dial error is returned unchanged. Write and read errors are
// wrapped with the failing stage and the zero-based request index. That
// makes a failure from a concurrent run attributable, and the original
// cause stays reachable via errors.Is / errors.As.
func doRequests(socketPath string, n int) error {
	conn, err := textproto.Dial("unix", socketPath)
	if err != nil {
		return err
	}
	defer conn.Close()

	for i := 0; i < n; i++ {
		if err := conn.PrintfLine("request"); err != nil {
			return fmt.Errorf("sending request %d: %w", i, err)
		}
		if _, err := conn.ReadLine(); err != nil {
			return fmt.Errorf("reading response to request %d: %w", i, err)
		}
	}
	return nil
}

// doConcurrentRequests launches conns goroutines, each running
// doRequests(socketPath, n) on its own connection. All goroutines wait on
// a shared gate channel that is closed only after every one has been
// started, so they begin as close to simultaneously as the scheduler
// allows. It waits for all of them to finish and returns the first non-nil
// error in the order results arrive, or nil if every worker succeeded.
func doConcurrentRequests(socketPath string, n, conns int) error {
	gate := make(chan struct{})
	// Buffered so no worker blocks on send regardless of collection order.
	results := make(chan error, conns)

	for c := 0; c < conns; c++ {
		go func() {
			<-gate
			results <- doRequests(socketPath, n)
		}()
	}
	close(gate)

	var first error
	for remaining := conns; remaining > 0; remaining-- {
		got := <-results
		if first == nil {
			first = got
		}
	}
	return first
}

func TestServer(t *testing.T) {
	dir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(dir)

	socketPath := filepath.Join(dir, "test.sock")
	srv, err := NewUNIXSocketServer(socketPath, NewLineServer(&fakeServer{}))
	if err != nil {
		t.Fatalf("NewUNIXSocketServer: %v", err)
	}
	go srv.Serve() // nolint
	defer srv.Close()

	if err := doConcurrentRequests(socketPath, 1000, 10); err != nil {
		t.Fatalf("request error: %v", err)
	}
}