package node

import (
	"context"
	"math/rand"
	"net"
	"strings"
	"time"

	"github.com/miekg/dns"
)

// TTL values (in seconds, as used in dns.RR_Header.Ttl) and response limits:
//
//   - zoneTTL is used for the zone's SOA record (its TTL, Expire and Minttl).
//   - nsTTL is used for the NS records of the zone.
//   - recordTTL is used for the A/AAAA records pointing at nodes.
//   - maxRecords caps the number of addresses returned in a single A/AAAA
//     response.
const (
	zoneTTL   = 21600
	nsTTL     = 3600
	recordTTL = 5

	maxRecords = 5
)

// newDNSHandler builds the dns.Handler serving the zone rooted at origin.
//
// The origin is normalized to end with a dot before being used both as
// the zone name and as the pattern under which the zone is registered on
// the returned mux. Address records are sourced from the load balancer of
// node n; nameservers is passed through to the zone for its NS records.
func newDNSHandler(n *Node, origin string, nameservers []string) dns.Handler {
	fqdn := origin
	if !strings.HasSuffix(fqdn, ".") {
		fqdn += "."
	}

	zone := newDNSZone(n.lb, fqdn, nameservers)

	router := dns.NewServeMux()
	router.Handle(fqdn, zone)
	return router
}

// dnsZone serves DNS records for our zone (it is registered as a
// dns.Handler by newDNSHandler).
//
// Fields:
//   - lb: load balancer whose nodes provide the addresses returned for
//     A/AAAA queries.
//   - soa: the SOA record returned for SOA queries at the zone apex.
//   - nameservers: names returned as NS records at the zone apex.
//   - origin: name of the zone; it is appended to relative names when
//     building record headers.
//   - originNumParts: number of parts of origin, used to strip the origin
//     from incoming query names.
type dnsZone struct {
	lb             *loadBalancer
	soa            *dns.SOA
	nameservers    []string
	origin         string
	originNumParts int
}

// newDNSZone creates the handler state for the zone named origin.
//
// lb supplies the nodes used to answer address queries and nameservers
// lists the names published as NS records. A SOA record is synthesized
// for the zone, with a serial set to the Unix time at construction.
//
// The number of labels in origin is counted with dns.SplitDomainName,
// the same function getQuestionName uses to split query names, so that
// the two always agree. Counting with strings.Split would see an extra
// empty element after the trailing dot of a fully-qualified origin
// ("example.com." -> 3 instead of 2), which makes sub-names collapse to
// the apex and apex queries slice with a negative bound.
func newDNSZone(lb *loadBalancer, origin string, nameservers []string) *dnsZone {
	// Create a SOA record for the zone. Some entries will be bogus.
	soa := &dns.SOA{
		Hdr: dns.RR_Header{
			Name:   origin,
			Rrtype: dns.TypeSOA,
			Class:  dns.ClassINET,
			Ttl:    zoneTTL,
		},
		Ns:      "ns1." + origin,
		Mbox:    "hostmaster." + origin,
		Serial:  uint32(time.Now().Unix()),
		Refresh: 43200,
		Retry:   3600,
		Expire:  uint32(zoneTTL),
		Minttl:  uint32(zoneTTL),
	}

	return &dnsZone{
		lb:             lb,
		soa:            soa,
		nameservers:    nameservers,
		origin:         origin,
		originNumParts: len(dns.SplitDomainName(origin)),
	}
}

// ServeDNS implements dns.Handler for the zone. Only the first question
// of the request is considered. Responses are:
//
//   - SOA at the zone apex: the zone's SOA record.
//   - NS at the zone apex: one NS record per configured nameserver.
//   - A/AAAA at the apex or at "stream": a random selection of node
//     addresses of the matching family (possibly empty).
//   - A/AAAA for any other name: NXDOMAIN.
//   - ANY (other than the cases above): SERVFAIL.
//   - everything else: NXDOMAIN.
//   - a request without any question: FORMERR.
//
// Successful replies are marked authoritative. Errors from writing the
// response are deliberately ignored (best effort).
func (d *dnsZone) ServeDNS(w dns.ResponseWriter, req *dns.Msg) {
	m := new(dns.Msg)
	ednsFromRequest(req, m)

	// The mux only dispatches requests that carry a question, but guard
	// anyway so that direct use of the handler cannot panic on the index
	// below.
	if len(req.Question) == 0 {
		m.SetRcode(req, dns.RcodeFormatError)
		w.WriteMsg(m) //nolint
		return
	}

	// Only consider the first question.
	q := req.Question[0]
	name := d.getQuestionName(q)

	rcode := dns.RcodeSuccess
	switch {
	case name == "" && q.Qtype == dns.TypeSOA:
		m.Answer = append(m.Answer, d.soa)

	case name == "" && q.Qtype == dns.TypeNS:
		for _, ns := range d.nameservers {
			m.Answer = append(m.Answer, &dns.NS{
				Hdr: dns.RR_Header{
					Name:   d.withOrigin(name),
					Rrtype: dns.TypeNS,
					Class:  dns.ClassINET,
					Ttl:    nsTTL,
				},
				Ns: ns,
			})
		}

	case q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA:
		// Address records exist only at the apex and at "stream".
		if name != "" && name != "stream" {
			rcode = dns.RcodeNameError
			break
		}
		for _, ip := range d.getNodeIPs(q) {
			var rec dns.RR
			if q.Qtype == dns.TypeAAAA {
				rec = d.newAAAA(name, ip)
			} else {
				rec = d.newA(name, ip)
			}
			m.Answer = append(m.Answer, rec)
		}

	case q.Qtype == dns.TypeANY:
		rcode = dns.RcodeServerFailure

	default:
		rcode = dns.RcodeNameError
	}

	if rcode != dns.RcodeSuccess {
		m.SetRcode(req, rcode)
		w.WriteMsg(m) //nolint
		return
	}

	m.SetReply(req)
	m.Authoritative = true
	w.WriteMsg(m) //nolint
}

// getNodeIPs returns the addresses to publish for question q.
//
// It gathers the addresses of every node known to the load balancer,
// keeping only those that filterIPByProto selects for the requested
// record type (the flag passed is true for AAAA queries), then shuffles
// them and returns at most maxRecords entries.
func (d *dnsZone) getNodeIPs(q dns.Question) []net.IP {
	wantAAAA := q.Qtype == dns.TypeAAAA

	var out []net.IP
	for _, node := range d.lb.getNodes() {
		matching := filterIPByProto(node.parsedAddrs, wantAAAA)
		for _, addr := range matching {
			out = append(out, addr.ip)
		}
	}

	// Randomize the order so that the truncation below picks a different
	// subset on each call.
	rand.Shuffle(len(out), func(a, b int) {
		out[a], out[b] = out[b], out[a]
	})

	if len(out) <= maxRecords {
		return out
	}
	return out[:maxRecords]
}

// getQuestionName strips the origin from the query name and returns the
// remaining leading labels, lowercased and joined with dots (no trailing
// dot). A query for the origin itself yields the empty string.
//
// A name with fewer labels than the origin also yields the empty string
// rather than slicing with a negative bound, which would panic inside the
// request handler. Such names are not expected when the handler is
// reached through a mux registered on the origin.
func (d *dnsZone) getQuestionName(q dns.Question) string {
	labels := dns.SplitDomainName(q.Name)
	n := len(labels) - d.originNumParts
	if n <= 0 {
		return ""
	}
	return strings.ToLower(strings.Join(labels[:n], "."))
}

// withOrigin qualifies a zone-relative name with the zone origin. The
// empty name stands for the zone apex and maps to the origin itself.
func (d *dnsZone) withOrigin(name string) string {
	switch name {
	case "":
		return d.origin
	default:
		return name + "." + d.origin
	}
}

// newA builds an A record for the zone-relative name pointing at ip,
// with class IN and a TTL of recordTTL.
func (d *dnsZone) newA(name string, ip net.IP) dns.RR {
	rr := new(dns.A)
	rr.Hdr = dns.RR_Header{
		Name:   d.withOrigin(name),
		Rrtype: dns.TypeA,
		Class:  dns.ClassINET,
		Ttl:    recordTTL,
	}
	rr.A = ip
	return rr
}

// newAAAA builds an AAAA record for the zone-relative name pointing at
// ip, with class IN and a TTL of recordTTL.
func (d *dnsZone) newAAAA(name string, ip net.IP) dns.RR {
	hdr := dns.RR_Header{
		Name:   d.withOrigin(name),
		Rrtype: dns.TypeAAAA,
		Class:  dns.ClassINET,
		Ttl:    recordTTL,
	}
	return &dns.AAAA{Hdr: hdr, AAAA: ip}
}

// ednsFromRequest mirrors EDNS0 support from the query onto the reply m.
//
// If req carries an OPT record in its additional section, a skeleton OPT
// record is added to m with a UDP size of 4096 and the DO bit copied from
// the first OPT record found. Without an OPT record in req, m is left
// untouched.
func ednsFromRequest(req, m *dns.Msg) {
	for _, extra := range req.Extra {
		if extra.Header().Rrtype != dns.TypeOPT {
			continue
		}
		opt := extra.(*dns.OPT)
		m.SetEdns0(4096, opt.Do())
		return
	}
}

// Wrapper to make the dns.Server match the genericServer interface.
//
// The embedded *dns.Server does the actual serving; name is the label
// reported by Name().
type dnsServer struct {
	*dns.Server
	name string
}

// newDNSServer wraps a dns.Server listening on addr over the given
// network proto and dispatching requests to h. The name is only used to
// identify the server through Name().
func newDNSServer(name, addr, proto string, h dns.Handler) *dnsServer {
	srv := &dns.Server{
		Addr:    addr,
		Net:     proto,
		Handler: h,
	}
	return &dnsServer{Server: srv, name: name}
}

// Name returns the name given to the server at construction time.
func (s *dnsServer) Name() string {
	return s.name
}

// Serve starts the wrapped dns.Server and returns whatever its
// ListenAndServe call returns.
func (s *dnsServer) Serve() error {
	srv := s.Server
	return srv.ListenAndServe()
}

// Stop shuts down the wrapped dns.Server, giving it up to one second to
// complete. The error returned by the shutdown is deliberately ignored.
func (s *dnsServer) Stop() {
	const grace = 1 * time.Second

	ctx, cancel := context.WithTimeout(context.Background(), grace)
	defer cancel()

	s.Server.ShutdownContext(ctx) //nolint
}