package node

import (
	"context"
	"errors"
	"fmt"
	"log"
	"math/rand"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"git.autistici.org/ale/autoradio"
	"git.autistici.org/ale/autoradio/coordination/presence"
	"git.autistici.org/ale/autoradio/node/lbv2"
	pb "git.autistici.org/ale/autoradio/proto"
	"git.autistici.org/ale/autoradio/util"
	"go.etcd.io/etcd/clientv3"
)

type loadBalancer struct {
	lb *lbv2.LoadBalancer

	nodeMx sync.RWMutex
	nodes  []*nodeInfo
}

func newLoadBalancer(ctx context.Context, cli *clientv3.Client, nodeID string, statusMgr *statusManager, lbSpec string) (*loadBalancer, error) {
	lb, err := parseLoadBalancerSpec(nodeID, lbSpec)
	if err != nil {
		return nil, err
	}

	// Watch the authoritative list of peer nodes, and do not
	// return until that watch has been bootstrapped.
	publicPeers, peersReady := presence.WatchEndpointsReady(ctx, cli, autoradio.PublicEndpointPrefix)
	<-peersReady

	l := &loadBalancer{
		lb: lb,
	}

	go util.RunCron(ctx, 500*time.Millisecond, func(_ context.Context) {
		nodes := buildNodeList(publicPeers, statusMgr)
		l.nodeMx.Lock()
		l.nodes = nodes
		l.lb.Update(nodeList(nodes))
		l.nodeMx.Unlock()
	})

	return l, nil
}
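
// A minimal sketch of how a node might wire this up, assuming an etcd
// client, a statusManager and a request context are already available
// (the surrounding names here are hypothetical):
//
//	lb, err := newLoadBalancer(ctx, etcdClient, nodeID, statusMgr,
//		"listeners_available,weighted")
//	if err != nil {
//		return err
//	}
//	if n := lb.chooseNode(reqCtx); n != nil {
//		// Send the client to one of n's addresses.
//	}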

type ipPort struct {
	ip   net.IP
	port int
}

func (i ipPort) String() string {
	return net.JoinHostPort(i.ip.String(), strconv.Itoa(i.port))
}

// parseEndpointAddrs extracts the valid IP/port pairs from an
// endpoint's address list, silently skipping malformed entries.
func parseEndpointAddrs(ep *pb.Endpoint) []ipPort {
	out := make([]ipPort, 0, len(ep.Addrs))
	for _, addr := range ep.Addrs {
		host, portStr, err := net.SplitHostPort(addr)
		if err != nil {
			continue
		}
		port, err := strconv.Atoi(portStr)
		if err != nil {
			continue
		}
		ip := net.ParseIP(host)
		if ip == nil {
			continue
		}
		out = append(out, ipPort{ip, port})
	}
	return out
}
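
// For example, given Addrs of {"1.2.3.4:8000", "[2001:db8::1]:8000",
// "not-an-address"}, the result would contain the two valid ipPort
// entries (net.SplitHostPort understands the bracketed IPv6 form).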

// The main purpose of the Frontend is to aggregate presence and
// runtime status information from the Icecast nodes, and to provide
// the results as a list of available node candidates to pick a target
// from.
//
// We keep a list of nodeInfo objects, combining the node status with
// the public Icecast endpoint information (so we don't have to look
// it up and parse it again). This list is periodically updated,
// rather than computed on-demand, in order to unify the update
// mechanisms of presence and status (one synchronous, the other
// asynchronous). By doing so, however, we introduce a propagation
// delay in topology changes equal to the update interval.
type nodeInfo struct {
	ep          *pb.Endpoint
	status      *pb.Status
	parsedAddrs []ipPort
}

// Utilization returns the node utilization across the specified
// dimension. Implements the lbv2.Node interface.
func (n *nodeInfo) Utilization(dimension int) lbv2.NodeUtilization {
	if n.status == nil {
		return lbv2.NodeUtilization{Utilization: 1}
	}
	nl := int(n.status.NumListeners)
	var u float64
	switch dimension {
	case utilBandwidth:
		u = float64(n.status.CurBandwidth) / float64(n.status.MaxBandwidth)
	case utilListeners:
		u = float64(n.status.NumListeners) / float64(n.status.MaxListeners)
	}
	// A zero capacity yields NaN (the 0/0 case), which would
	// escape the clamping below; treat it as full utilization,
	// consistent with the +Inf case (which clamps to 1).
	if math.IsNaN(u) {
		u = 1
	}
	if u < 0 {
		u = 0
	}
	if u > 1 {
		u = 1
	}
	return lbv2.NodeUtilization{
		Utilization: u,
		Requests:    nl,
	}
}
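
// As a worked example: a node reporting CurBandwidth=50 and
// MaxBandwidth=200 has a bandwidth utilization of 0.25, while a node
// with no status report at all counts as fully utilized.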

// Name returns the node name. Implements the lbv2.Node interface.
func (n *nodeInfo) Name() string {
	return n.ep.Name
}

type nodeList []*nodeInfo

func (l nodeList) Len() int            { return len(l) }
func (l nodeList) Get(i int) lbv2.Node { return l[i] }

// Periodically assemble all the runtime data we have available from
// various sources, mangling it into a form that's usable by our
// public-facing load balancing algorithms.
//
// This code expects a 1:1 mapping between nodes and frontends,
// matched by their name: the load balancing algorithm deals with the
// utilization of Icecast nodes, but we send traffic to frontend
// nodes, so we map the Icecast node to the frontend node with the
// same name.
func buildNodeList(peers *presence.EndpointSet, status *statusManager) []*nodeInfo {
	// The authoritative source for the full list of targets is
	// the presence protocol. Then we merge that with status
	// information, where available.
	tmp := make(map[string]*nodeInfo)
	for _, ep := range peers.Endpoints() {
		tmp[ep.Name] = &nodeInfo{
			ep:          ep,
			parsedAddrs: parseEndpointAddrs(ep),
		}
	}
	for _, s := range status.getStatus() {
		if lbn, ok := tmp[s.Name]; ok {
			lbn.status = s
		}
	}

	out := make([]*nodeInfo, 0, len(tmp))
	for _, lbn := range tmp {
		out = append(out, lbn)
	}
	return out
}

// getNodes returns the current node list. The slice is replaced
// wholesale by the update cron and never modified in place, so it is
// safe to return without copying.
func (l *loadBalancer) getNodes() []*nodeInfo {
	l.nodeMx.RLock()
	defer l.nodeMx.RUnlock()
	return l.nodes
}

// chooseNode runs the load balancing algorithm over the current node
// list and returns the selected node, or nil if none could be chosen.
func (l *loadBalancer) chooseNode(ctx lbv2.RequestContext) *nodeInfo {
	result, err := l.lb.Choose(ctx)
	if err != nil {
		log.Printf("lbv2 failure: %v", err)
		return nil
	}
	// Use a checked type assertion so a nil result cannot panic.
	node, ok := result.(*nodeInfo)
	if !ok {
		return nil
	}
	return node
}

// Known dimensions for utilization.
const (
	utilBandwidth = iota
	utilListeners
)

type autoradioNodeFilterFunc func(lbv2.RequestContext, *nodeInfo) bool

func (f autoradioNodeFilterFunc) Score(ctx lbv2.RequestContext, n lbv2.Node) float64 {
	if f(ctx, n.(*nodeInfo)) {
		return 1
	}
	return 0
}

func icecastActiveFilter(ctx lbv2.RequestContext, n *nodeInfo) bool {
	return n.status != nil && n.status.IcecastOk
}

// NodeFilter that disables backends where Icecast is not running.
// Note that "disabled" means that the associated score is set to
// zero: if all nodes are disabled, the load balancing algorithm will
// still pick one (in fact we explicitly do not use any
// lbv2.ActiveNodesFilter).
func newIcecastActiveFilter() lbv2.NodeFilter {
	return lbv2.NodeScorerFilter(autoradioNodeFilterFunc(icecastActiveFilter))
}

// getIPProtos reports whether ips contains at least one IPv4 address
// and at least one IPv6 address, respectively.
func getIPProtos(ips []ipPort) (bool, bool) {
	hasV4 := false
	hasV6 := false
	for _, ipp := range ips {
		if ipp.ip.To4() == nil {
			hasV6 = true
		} else {
			hasV4 = true
		}
	}
	return hasV4, hasV6
}

func ipProtocolFilter(ctx lbv2.RequestContext, n *nodeInfo) bool {
	if ctx == nil {
		return true
	}
	addr := ctx.RemoteAddr()
	if addr == nil {
		return true
	}
	remoteV6 := addr.To4() == nil
	hasV4, hasV6 := getIPProtos(n.parsedAddrs)
	return (remoteV6 && hasV6) || (!remoteV6 && hasV4)
}

// NodeFilter that selects those backends having at least one IP
// address matching the request protocol (IPv4/IPv6).
func newIPProtocolFilter() lbv2.NodeFilter {
	return lbv2.NodeScorerFilter(autoradioNodeFilterFunc(ipProtocolFilter))
}

// Parse a string that specifies how to build a LoadBalancer. The
// string should consist of a list of comma-separated tokens, each
// identifying a specific filter or policy.
//
// Some filters will always be included in the resulting LoadBalancer
// and do not need to be specified explicitly (icecastActiveFilter and
// ipProtocolFilter).
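//
// For example, "listeners_available,weighted" filters out nodes that
// are at listener capacity, then picks among the remaining nodes with
// the weighted policy.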
func parseLoadBalancerSpec(nodeID, specstr string) (*lbv2.LoadBalancer, error) {
	lb := lbv2.New(nodeID)
	lb.AddFilter(newIcecastActiveFilter())
	lb.AddFilter(newIPProtocolFilter())

	var policy lbv2.Policy
	for _, spec := range strings.Split(specstr, ",") {
		switch spec {
		case "bandwidth_available":
			lb.AddFilter(lbv2.NewCapacityAvailableFilter(lb.GetPredictor(utilBandwidth)))
		case "listeners_available":
			lb.AddFilter(lbv2.NewCapacityAvailableFilter(lb.GetPredictor(utilListeners)))
		case "bandwidth_score":
			lb.AddFilter(lbv2.NewCapacityAvailableScorer(lb.GetPredictor(utilBandwidth)))
		case "listeners_score":
			lb.AddFilter(lbv2.NewCapacityAvailableScorer(lb.GetPredictor(utilListeners)))
		case "random":
			policy = lbv2.RandomPolicy
		case "weighted":
			policy = lbv2.WeightedPolicy
		case "best":
			policy = lbv2.HighestScorePolicy
		default:
			return nil, fmt.Errorf("unknown lb filter spec %q", spec)
		}
	}

	if policy == nil {
		return nil, errors.New("no lb policy specified")
	}
	lb.SetPolicy(policy)

	return lb, nil
}

// httpRequestContext wraps an http.Request to implement the
// lbv2.RequestContext interface.
type httpRequestContext struct {
	req *http.Request
}

func (r *httpRequestContext) RemoteAddr() net.IP {
	host, _, err := net.SplitHostPort(r.req.RemoteAddr)
	if err != nil {
		// http.Request.RemoteAddr normally has "host:port"
		// form, but fall back to parsing it as a bare IP.
		host = r.req.RemoteAddr
	}
	return net.ParseIP(host)
}
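
// A sketch of how an HTTP handler might use this wrapper together
// with chooseNode (the server type and redirect logic here are
// hypothetical):
//
//	func (s *server) serveRedirect(w http.ResponseWriter, req *http.Request) {
//		node := s.lb.chooseNode(&httpRequestContext{req})
//		if node == nil {
//			http.Error(w, "no nodes available", http.StatusServiceUnavailable)
//			return
//		}
//		// ... redirect the client to one of node's addresses ...
//	}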

// Filter a list of IP addresses by protocol.
func filterIPByProto(ips []ipPort, v6 bool) []ipPort {
	var candidates []ipPort
	for _, ipp := range ips {
		isIPv6 := (ipp.ip.To4() == nil)
		if isIPv6 == v6 {
			candidates = append(candidates, ipp)
		}
	}
	return candidates
}

// Pick a random IP for the specified proto.
func randomIPByProto(ips []ipPort, v6 bool) ipPort {
	candidates := filterIPByProto(ips, v6)
	if len(candidates) > 0 {
		return candidates[rand.Intn(len(candidates))]
	}
	return ipPort{}
}

// Select a random IP address from ips, with an IP protocol that
// matches remoteAddr.
func randomIPWithMatchingProtocol(ips []ipPort, remoteAddr string) string {
	var isV6 bool
	if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
		addr := net.ParseIP(host)
		isV6 = (addr != nil && addr.To4() == nil)
	}
	ipp := randomIPByProto(ips, isV6)
	if ipp.ip == nil {
		return ""
	}
	return ipp.String()
}
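
// For example, with remoteAddr "192.0.2.10:41000" (an IPv4 client)
// this returns one of the IPv4 "host:port" strings from ips, or the
// empty string if ips contains no IPv4 addresses.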