// Variant of the base net/http/httputil ReverseProxy suitable for
// low-latency, long-lived connections.

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// HTTP reverse proxy handler

package node

import (
	"errors"
	"io"
	"log"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/prometheus/client_golang/prometheus"
)

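// copyHeader adds all the header values in src to dst.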
func copyHeader(dst, src http.Header) {
	for k, vv := range src {
		for _, v := range vv {
			dst.Add(k, v)
		}
	}
}

// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{
	"Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te", // canonicalized version of "TE"
	"Trailers",
	"Transfer-Encoding",
	"Upgrade",
}

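// wrappedWriter is implemented by ResponseWriter middlewares that wrap
// another http.ResponseWriter, allowing doIcecastProxy to unwrap them
// until it reaches one that supports http.Hijacker.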
type wrappedWriter interface {
	WrappedWriter() http.ResponseWriter
}

// doIcecastProxy proxies a request to the desired backend. Due to the
// way the Icecast protocol works, this just sends the initial
// (rewritten) HTTP/1.0 request, and then switches to a full
// bidirectional TCP proxy. The outbound request is built from the
// target host and path, plus any query string parameters and headers
// passed on from the original request. The additional streamName
// parameter is used for instrumentation.
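//
// A minimal usage sketch (hypothetical values; the real routing code
// lives elsewhere in this package):
//
//	target := &url.URL{Scheme: "http", Host: "127.0.0.1:8000", Path: "/stream.ogg"}
//	doIcecastProxy(w, r, target, "stream.ogg")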
func doIcecastProxy(rw http.ResponseWriter, req *http.Request, target *url.URL, streamName string) {
	outreq := new(http.Request)
	*outreq = *req // includes shallow copies of maps, but okay

	// Build an HTTP/1.0 request for the backend.
	outreq.URL.Scheme = target.Scheme
	outreq.URL.Host = target.Host
	//outreq.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
	outreq.URL.Path = target.Path
	outreq.Proto = "HTTP/1.0"
	outreq.ProtoMajor = 1
	outreq.ProtoMinor = 0
	outreq.Close = true

	// Remove hop-by-hop headers to the backend. This is modifying
	// the same underlying map from req (shallow copied above) so
	// we only copy it if necessary.
	copiedHeaders := false
	for _, h := range hopHeaders {
		if outreq.Header.Get(h) != "" {
			if !copiedHeaders {
				outreq.Header = make(http.Header)
				copyHeader(outreq.Header, req.Header)
				copiedHeaders = true
			}
			outreq.Header.Del(h)
		}
	}

	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		// If we aren't the first proxy, retain prior
		// X-Forwarded-For information as a comma+space
		// separated list and fold multiple headers into one.
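		// For example, an existing "X-Forwarded-For: 10.0.0.1"
		// header plus a client at 192.0.2.7 becomes
		// "X-Forwarded-For: 10.0.0.1, 192.0.2.7".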
		if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
			clientIP = strings.Join(prior, ", ") + ", " + clientIP
		}
		outreq.Header.Set("X-Forwarded-For", clientIP)
	}

	// Create the upstream connection and write the HTTP request
	// to it.
	upstream, err := net.Dial("tcp", outreq.URL.Host)
	if err != nil {
		log.Printf("http: proxy dial error: %v", err)
		rw.WriteHeader(http.StatusInternalServerError)
		proxyConnectErrs.WithLabelValues(streamName, target.Host).Inc()
		return
	}
	defer upstream.Close()
	if err := outreq.Write(upstream); err != nil {
		log.Printf("http: proxy request write error: %v", err)
		rw.WriteHeader(http.StatusInternalServerError)
		proxyConnectErrs.WithLabelValues(streamName, target.Host).Inc()
		return
	}

	// Hijack the request connection. We might need to unwrap the
	// layers of nested wrappedWriters until we find a
	// ResponseWriter that also implements the http.Hijacker
	// interface.
	var conn net.Conn
	for {
		if h, ok := rw.(http.Hijacker); ok {
			var err error
			conn, _, err = h.Hijack()
			if err != nil {
				log.Printf("http: proxy hijack error: %v", err)
				rw.WriteHeader(http.StatusInternalServerError)
				return
			}
			break
		} else if w, ok := rw.(wrappedWriter); ok {
			rw = w.WrappedWriter()
		} else {
			break
		}
	}
	if conn == nil {
		log.Printf("http: proxy error: could not find hijackable connection")
		rw.WriteHeader(http.StatusInternalServerError)
		return
	}
	defer conn.Close()
	if err := conn.SetDeadline(time.Time{}); err != nil {
		log.Printf("http: proxy setdeadline error: %v", err)
	}

	// Run two-way proxying.
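	// Both endpoints are plain TCP connections here; the type
	// assertions below would panic if, say, a TLS listener were
	// ever placed in front of this handler.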
	handleProxy(conn.(*net.TCPConn), upstream.(*net.TCPConn), streamName)
}

// copyStream copies data between two network connections. On Go 1.11
// and later this is quite fast, as io.CopyBuffer can use the splice()
// system call internally (in exchange we lose the ability to tell
// which connection is the source of an error).
func copyStream(tag string, out, in *net.TCPConn, promCounter prometheus.Counter, cntr *uint64) {
	buf := getBuf()
	defer releaseBuf(buf)

	// We used to do this in order to support half-closed connections.
	//defer in.CloseRead()   //nolint
	//defer out.CloseWrite() //nolint

	// Instead we do this and shut down the entire connection on error.
	// We end up calling Close() twice but that's not a huge problem.
	defer in.Close()  //nolint
	defer out.Close() //nolint

	// io.CopyBuffer already loops internally until EOF or error, so
	// a single call is enough; wrapping it in a loop would spin
	// forever once the source returns EOF with a nil error.
	n, err := io.CopyBuffer(out, in, buf)
	promCounter.Add(float64(n))
	if cntr != nil {
		atomic.AddUint64(cntr, uint64(n))
	}
	if err != nil && !isCloseError(err) {
		log.Printf("http: proxy error (%s): %v", tag, err)
	}
}

// isCloseError returns true for the error produced by operations on a
// connection that has already been closed locally. Checking for
// net.ErrClosed (available since Go 1.16) avoids the error-string
// matching whose problems are discussed in
// https://github.com/golang/go/issues/4373; an atomic 'closing' flag
// would be another alternative.
func isCloseError(err error) bool {
	return errors.Is(err, net.ErrClosed)
}

// handleProxy is a simple two-way TCP proxy that copies data in both
// directions and tears the whole connection down as soon as either
// direction fails.
func handleProxy(conn *net.TCPConn, upstream *net.TCPConn, streamName string) {
	var wg sync.WaitGroup
	wg.Add(2)
	streamListeners.WithLabelValues(streamName).Inc()

	// Instrument both directions of the stream, but let the
	// bandwidth estimator count only the bytes sent to the user.
	go func() {
		copyStream("upstream -> client", conn, upstream, streamSentBytes.WithLabelValues(streamName), &bwBytesSent)
		wg.Done()
	}()
	go func() {
		copyStream("client -> upstream", upstream, conn, streamRcvdBytes.WithLabelValues(streamName), nil)
		wg.Done()
	}()

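	// copyStream closes both connections on any error, so a failure
	// in either direction unblocks the copy running in the other
	// and this Wait cannot hang.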
	wg.Wait()
	streamListeners.WithLabelValues(streamName).Dec()
}

// Implementation of a simple buffer cache, to minimize large
// allocations at runtime.
const (
	bufSize     = 8192
	bufPoolSize = 512
)

var bufPool chan []byte

func init() {
	bufPool = make(chan []byte, bufPoolSize)
	for i := 0; i < bufPoolSize; i++ {
		bufPool <- make([]byte, bufSize)
	}
}

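// getBuf returns a buffer from the pool, falling back to a fresh
// allocation when the pool is empty (the select never blocks).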
func getBuf() (b []byte) {
	select {
	case b = <-bufPool:
	default:
		b = make([]byte, bufSize)
	}
	return
}

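// releaseBuf returns a buffer to the pool, dropping it on the floor
// for the garbage collector when the pool is already full. Together
// with getBuf this is the classic "leaky buffer" free list pattern
// from Effective Go.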
func releaseBuf(b []byte) {
	select {
	case bufPool <- b:
	default:
	}
}

// Simple bandwidth meter that keeps track of the current
// (approximate) rate of bytes sent through the proxy.
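// The rate is computed as a plain delta over each sampling interval:
// for example, 50,000,000 bytes sent during one 10-second tick yields
// a bwCurrent of 5,000,000 bytes/sec.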
var (
	bwBytesSent     uint64
	bwLastBytesSent uint64
	bwLastTS        time.Time

	bwMx      sync.Mutex
	bwCurrent float64
)

func init() {
	// Initialize the timestamp so that the first sample covers a
	// sensible interval rather than the time since the zero Time.
	bwLastTS = time.Now()
	tick := time.NewTicker(10 * time.Second)
	go func() {
		for t := range tick.C {
			bytesSent := atomic.LoadUint64(&bwBytesSent)
			bw := float64(bytesSent-bwLastBytesSent) / t.Sub(bwLastTS).Seconds()
			bwLastBytesSent = bytesSent
			bwLastTS = t

			bwMx.Lock()
			bwCurrent = bw
			bwMx.Unlock()
		}
	}()
}

// getCurrentBandwidthUsage returns the current usage (through the
// proxy) in bytes per second.
func getCurrentBandwidthUsage() float64 {
	bwMx.Lock()
	defer bwMx.Unlock()
	return bwCurrent
}