package submission

import (
	"context"
	"errors"
	"sync"
	"testing"
	"time"

	ippb "git.autistici.org/ai3/tools/iprep/proto"
)

// fakeClient is an in-memory fake handed to newSubmitter by these tests.
// It performs no I/O; it only tallies what Submit receives so the tests
// can inspect it afterwards.
type fakeClient struct {
	mx sync.Mutex // guards calls and count

	calls int   // number of Submit invocations
	count int64 // running sum of every Count value seen by Submit
}

// Close is a no-op: the fake holds no resources that need releasing.
func (f *fakeClient) Close() {
}

// Submit records one call and adds to the running total the Count of every
// event in evs, plus the Count of every per-IP entry of every per-type entry
// in aggr when aggr is non-nil. All bookkeeping happens under f.mx. It
// always returns nil.
func (f *fakeClient) Submit(_ context.Context, evs []*ippb.Event, aggr *ippb.Aggregate) error {
	f.mx.Lock()
	defer f.mx.Unlock()

	f.calls++
	for i := range evs {
		f.count += evs[i].Count
	}

	if aggr == nil {
		return nil
	}
	for _, byType := range aggr.ByType {
		for _, byIP := range byType.ByIp {
			f.count += byIP.Count
		}
	}
	return nil
}

// GetScore is not supported by the fake: it always returns a nil response
// and a "not implemented" error.
func (f *fakeClient) GetScore(_ context.Context, _ string) (*ippb.GetScoreResponse, error) {
	err := errors.New("not implemented")
	return nil, err
}

// GetAllScores is not supported by the fake: it always returns a nil
// channel and a "not implemented" error.
func (f *fakeClient) GetAllScores(_ context.Context, _ float32) (<-chan *ippb.GetScoreResponse, error) {
	var scores <-chan *ippb.GetScoreResponse
	return scores, errors.New("not implemented")
}

func rateDo(timeout time.Duration, qps float64, f func()) int {
	end := time.After(timeout)
	interval := 1 * time.Second
	interval /= time.Duration(qps)
	tick := time.NewTicker(interval)
	defer tick.Stop()
	n := 0
	for {
		select {
		case <-tick.C:
			f()
			n++
		case <-end:
			return n
		}
	}
}

// runTest marks t as parallel and drives a submitter backed by a fakeClient.
// For three seconds it adds "detection" events for 1.2.3.4 at the given qps,
// then closes the submitter.
//
// It returns the total event count the fake client received and the number
// of Submit calls it saw. opts defaults to a zero Options when nil. Its
// MaxDelay is overwritten with one second; note that this mutates the
// caller's struct. tag prefixes any failure message.
func runTest(t *testing.T, tag string, opts *Options, qps float64) (int, int) {
	t.Helper()
	t.Parallel()

	fc := new(fakeClient)
	if opts == nil {
		opts = new(Options)
	}
	opts.MaxDelay = 1 * time.Second
	s := newSubmitter(fc, opts)

	n := rateDo(3*time.Second, qps, func() {
		s.AddEvent(&ippb.Event{
			Ip:    "1.2.3.4",
			Type:  "detection",
			Count: 1,
		})
	})

	s.Close()

	fc.mx.Lock()
	rcvd := int(fc.count)
	calls := fc.calls
	fc.mx.Unlock()

	// Events still pending in the channel buffer when the queue is closed
	// may never reach the client. The fake client's counter, not n, is
	// therefore the number of delivered events, and all we can assert
	// against n is an upper bound.
	if rcvd > n {
		t.Fatalf("%s: received too many events: sent=%d, received=%d", tag, n, rcvd)
	}

	return rcvd, calls
}

func TestSubmitter_LowRate(t *testing.T) {
	// Verify that a low rate of events (one per second) gets sent through
	// largely unmolested: every received event should have gone out in
	// its own Submit call. The tag must match the qps actually used so
	// failure messages are not misleading.
	sent, calls := runTest(t, "qps=1", nil, 1)
	if sent != calls {
		t.Fatalf("wrong number of calls: %d (expected %d)", calls, sent)
	}
}

func TestSubmitter_HighRate(t *testing.T) {
	// A high qps rate gets batched into MaxDelay time chunks (set
	// MaxStored very high so it does not factor in).
	sent, calls := runTest(t, "qps=1000", &Options{
		MaxStored: 100000,
	}, 1000)
	expected := 3
	if calls != expected {
		t.Fatalf("sent=%d calls=%d, expected=%d", sent, calls, expected)
	}
}

func TestSubmitter_HighRate_Buffered(t *testing.T) {
	// Same as above but with smaller MaxStored.
	sent, calls := runTest(t, "qps=1000/buffered", &Options{
		MaxStored: 100,
	}, 1000)
	expected := 1 + sent/100
	if calls != expected {
		t.Fatalf("sent=%d calls=%d, expected=%d", sent, calls, expected)
	}
}