From 37c649a8b693ba65a59eab5de3c01bf212f791ad Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Sun, 23 Aug 2020 16:53:40 +0100
Subject: [PATCH] Allow setting DNS overrides using the --resolve option

---
 client.go          | 43 ++++++++++++++++++++++++++++++++++++++++---
 cmd/crawl/crawl.go | 46 ++++++++++++++++++++++++++++++++--------------
 2 files changed, 72 insertions(+), 17 deletions(-)

diff --git a/client.go b/client.go
index 45736f5..c028e42 100644
--- a/client.go
+++ b/client.go
@@ -1,7 +1,9 @@
 package crawl
 
 import (
+	"context"
 	"crypto/tls"
+	"net"
 	"net/http"
 	"net/http/cookiejar"
 	"time"
@@ -9,14 +11,19 @@ import (
 
 var defaultClientTimeout = 60 * time.Second
 
-// DefaultClient returns a http.Client suitable for crawling: does not
-// follow redirects, accepts invalid TLS certificates, sets a
+// DefaultClient points at a shared http.Client suitable for crawling:
+// does not follow redirects, accepts invalid TLS certificates, sets a
 // reasonable timeout for requests.
 var DefaultClient *http.Client
 
 func init() {
+	DefaultClient = NewHTTPClient()
+}
+
+// NewHTTPClient returns an http.Client suitable for crawling.
+func NewHTTPClient() *http.Client {
 	jar, _ := cookiejar.New(nil) // nolint
-	DefaultClient = &http.Client{
+	return &http.Client{
 		Timeout: defaultClientTimeout,
 		Transport: &http.Transport{
 			TLSClientConfig: &tls.Config{
@@ -29,3 +36,33 @@ func init() {
 		Jar: jar,
 	}
 }
+
+// NewHTTPClientWithDNSOverride returns an http.Client suitable for
+// crawling, with some additional DNS overrides.
+func NewHTTPClientWithDNSOverride(dnsMap map[string]string) *http.Client {
+	jar, _ := cookiejar.New(nil) // nolint
+	dialer := new(net.Dialer)
+	transport := &http.Transport{
+		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+			host, port, err := net.SplitHostPort(addr)
+			if err != nil {
+				return nil, err
+			}
+			if override, ok := dnsMap[host]; ok {
+				addr = net.JoinHostPort(override, port)
+			}
+			return dialer.DialContext(ctx, network, addr)
+		},
+		TLSClientConfig: &tls.Config{
+			InsecureSkipVerify: true, // nolint
+		},
+	}
+	return &http.Client{
+		Timeout:   defaultClientTimeout,
+		Transport: transport,
+		CheckRedirect: func(req *http.Request, via []*http.Request) error {
+			return http.ErrUseLastResponse
+		},
+		Jar: jar,
+	}
+}
diff --git a/cmd/crawl/crawl.go b/cmd/crawl/crawl.go
index 93506ac..b54b999 100644
--- a/cmd/crawl/crawl.go
+++ b/cmd/crawl/crawl.go
@@ -5,6 +5,7 @@ package main
 import (
 	"bufio"
 	"bytes"
+	"errors"
 	"flag"
 	"fmt"
 	"io"
@@ -38,12 +39,27 @@ var (
 	warcFileSizeMB = flag.Int("output-max-size", 100, "maximum output WARC file size (in MB) when using patterns")
 	cpuprofile     = flag.String("cpuprofile", "", "create cpu profile")
 
+	dnsMap   = dnsMapFlag(make(map[string]string))
 	excludes []*regexp.Regexp
+
+	httpClient *http.Client
 )
 
 func init() {
 	flag.Var(&excludesFlag{}, "exclude", "exclude regex URL patterns")
 	flag.Var(&excludesFileFlag{}, "exclude-from-file", "load exclude regex URL patterns from a file")
+	flag.Var(dnsMap, "resolve", "set DNS overrides (in hostname=addr format)")
+
+	stats = &crawlStats{
+		states: make(map[int]int),
+		start:  time.Now(),
+	}
+
+	go func() {
+		for range time.Tick(10 * time.Second) {
+			stats.Dump()
+		}
+	}()
 }
 
 type excludesFlag struct{}
@@ -82,6 +98,19 @@ func (f *excludesFileFlag) Set(s string) error {
 	return nil
 }
 
+type dnsMapFlag map[string]string
+
+func (f dnsMapFlag) String() string { return "" }
+
+func (f dnsMapFlag) Set(s string) error {
+	parts := strings.Split(s, "=")
+	if len(parts) != 2 {
+		return errors.New("value not in host=addr format")
+	}
+	f[parts[0]] = parts[1]
+	return nil
+}
+
 func extractLinks(p crawl.Publisher, u string, depth int, resp *http.Response, _ error) error {
 	links, err := analysis.GetLinks(resp)
 	if err != nil {
@@ -217,26 +246,13 @@ func (c *crawlStats) Dump() {
 var stats *crawlStats
 
 func fetch(urlstr string) (*http.Response, error) {
-	resp, err := crawl.DefaultClient.Get(urlstr)
+	resp, err := httpClient.Get(urlstr)
 	if err == nil {
 		stats.Update(resp)
 	}
 	return resp, err
 }
 
-func init() {
-	stats = &crawlStats{
-		states: make(map[int]int),
-		start:  time.Now(),
-	}
-
-	go func() {
-		for range time.Tick(10 * time.Second) {
-			stats.Dump()
-		}
-	}()
-}
-
 type byteCounter struct {
 	io.ReadCloser
 }
@@ -298,6 +314,8 @@ func main() {
 		log.Fatal(err)
 	}
 
+	httpClient = crawl.NewHTTPClientWithDNSOverride(dnsMap)
+
 	crawler, err := crawl.NewCrawler(
 		*dbPath,
 		seeds,
-- 
GitLab