Skip to content
Snippets Groups Projects
Commit 37c649a8 authored by ale's avatar ale
Browse files

Allow setting DNS overrides using the --resolve option

parent cf35cce6
No related branches found
No related tags found
No related merge requests found
package crawl
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/http/cookiejar"
"time"
......@@ -9,14 +11,19 @@ import (
var defaultClientTimeout = 60 * time.Second
// DefaultClient returns a http.Client suitable for crawling: does not
// follow redirects, accepts invalid TLS certificates, sets a
// DefaultClient points at a shared http.Client suitable for crawling:
// does not follow redirects, accepts invalid TLS certificates, sets a
// reasonable timeout for requests.
var DefaultClient *http.Client
func init() {
DefaultClient = NewHTTPClient()
}
// NewHTTPClient returns an http.Client suitable for crawling.
func NewHTTPClient() *http.Client {
jar, _ := cookiejar.New(nil) // nolint
DefaultClient = &http.Client{
return &http.Client{
Timeout: defaultClientTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
......@@ -29,3 +36,33 @@ func init() {
Jar: jar,
}
}
// NewHTTPClientWithDNSOverride returns an http.Client suitable for
// crawling, with some additional DNS overrides.
func NewHTTPClientWithDNSOverride(dnsMap map[string]string) *http.Client {
jar, _ := cookiejar.New(nil) // nolint
dialer := new(net.Dialer)
transport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if override, ok := dnsMap[host]; ok {
addr = net.JoinHostPort(override, port)
}
return dialer.DialContext(ctx, network, addr)
},
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true, // nolint
},
}
return &http.Client{
Timeout: defaultClientTimeout,
Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
Jar: jar,
}
}
......@@ -5,6 +5,7 @@ package main
import (
"bufio"
"bytes"
"errors"
"flag"
"fmt"
"io"
......@@ -38,12 +39,27 @@ var (
warcFileSizeMB = flag.Int("output-max-size", 100, "maximum output WARC file size (in MB) when using patterns")
cpuprofile = flag.String("cpuprofile", "", "create cpu profile")
dnsMap = dnsMapFlag(make(map[string]string))
excludes []*regexp.Regexp
httpClient *http.Client
)
func init() {
flag.Var(&excludesFlag{}, "exclude", "exclude regex URL patterns")
flag.Var(&excludesFileFlag{}, "exclude-from-file", "load exclude regex URL patterns from a file")
flag.Var(dnsMap, "resolve", "set DNS overrides (in hostname=addr format)")
stats = &crawlStats{
states: make(map[int]int),
start: time.Now(),
}
go func() {
for range time.Tick(10 * time.Second) {
stats.Dump()
}
}()
}
type excludesFlag struct{}
......@@ -82,6 +98,19 @@ func (f *excludesFileFlag) Set(s string) error {
return nil
}
type dnsMapFlag map[string]string
func (f dnsMapFlag) String() string { return "" }
func (f dnsMapFlag) Set(s string) error {
parts := strings.Split(s, "=")
if len(parts) != 2 {
return errors.New("value not in host=addr format")
}
f[parts[0]] = parts[1]
return nil
}
func extractLinks(p crawl.Publisher, u string, depth int, resp *http.Response, _ error) error {
links, err := analysis.GetLinks(resp)
if err != nil {
......@@ -217,26 +246,13 @@ func (c *crawlStats) Dump() {
var stats *crawlStats
func fetch(urlstr string) (*http.Response, error) {
resp, err := crawl.DefaultClient.Get(urlstr)
resp, err := httpClient.Get(urlstr)
if err == nil {
stats.Update(resp)
}
return resp, err
}
func init() {
stats = &crawlStats{
states: make(map[int]int),
start: time.Now(),
}
go func() {
for range time.Tick(10 * time.Second) {
stats.Dump()
}
}()
}
type byteCounter struct {
io.ReadCloser
}
......@@ -298,6 +314,8 @@ func main() {
log.Fatal(err)
}
httpClient = crawl.NewHTTPClientWithDNSOverride(dnsMap)
crawler, err := crawl.NewCrawler(
*dbPath,
seeds,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment