Commit ee1a3d8e authored by ale

Improve error checking

Detect write errors (both on the database and to the WARC output) and
abort with an error message.

Also fix a bunch of harmless lint warnings.
parent b3d41948
Pipeline #1178 passed in 14 seconds
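The change follows one pattern throughout: write operations (to the crawl database and to the WARC output) now return errors, callers propagate them, and a new Must helper aborts the process where no recovery is possible. Below is a minimal, self-contained sketch of that pattern; the writeRecord function and its failure condition are hypothetical stand-ins for the real database/WARC writes, only the Must helper mirrors what this commit adds.

package main

import (
	"errors"
	"log"
)

// Must aborts the program with a message when we get an error we
// can't recover from (mirrors the helper added in this commit).
func Must(err error) {
	if err != nil {
		log.Fatalf("fatal error: %v", err)
	}
}

// writeRecord is a hypothetical stand-in for a database or WARC write
// that may fail.
func writeRecord(data []byte) error {
	if len(data) == 0 {
		return errors.New("refusing to write an empty record")
	}
	return nil
}

func main() {
	// Recoverable failures are returned to the caller and logged...
	if err := writeRecord(nil); err != nil {
		log.Printf("write failed, will retry: %v", err)
	}
	// ...while unrecoverable ones abort with a clear message.
	Must(writeRecord([]byte("payload")))
}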
@@ -39,39 +39,11 @@ type rawOutlink struct {
// GetLinks returns all the links found in a document. Currently only
// parses HTML pages and CSS stylesheets.
func GetLinks(resp *http.Response) ([]crawl.Outlink, error) {
-	var outlinks []rawOutlink
-	ctype := resp.Header.Get("Content-Type")
-	if strings.HasPrefix(ctype, "text/html") {
-		// Use goquery to extract links from the parsed HTML
-		// contents (query patterns are described in the
-		// linkMatches table).
-		doc, err := goquery.NewDocumentFromResponse(resp)
-		if err != nil {
-			return nil, err
-		}
-		for _, lm := range linkMatches {
-			doc.Find(fmt.Sprintf("%s[%s]", lm.tag, lm.attr)).Each(func(i int, s *goquery.Selection) {
-				val, _ := s.Attr(lm.attr)
-				outlinks = append(outlinks, rawOutlink{URL: val, Tag: lm.linkTag})
-			})
-		}
-	} else if strings.HasPrefix(ctype, "text/css") {
-		// Use a simple (and actually quite bad) regular
-		// expression to extract "url()" links from CSS.
-		if data, err := ioutil.ReadAll(resp.Body); err == nil {
-			for _, val := range urlcssRx.FindAllStringSubmatch(string(data), -1) {
-				outlinks = append(outlinks, rawOutlink{URL: val[1], Tag: crawl.TagRelated})
-			}
-		}
-	}
	// Parse outbound links relative to the request URI, and
	// return unique results.
	var result []crawl.Outlink
	links := make(map[string]crawl.Outlink)
-	for _, l := range outlinks {
+	for _, l := range extractLinks(resp) {
		// Skip data: URLs altogether.
		if strings.HasPrefix(l.URL, "data:") {
			continue
@@ -88,3 +60,46 @@ func GetLinks(resp *http.Response) ([]crawl.Outlink, error) {
	}
	return result, nil
}
+func extractLinks(resp *http.Response) []rawOutlink {
+	ctype := resp.Header.Get("Content-Type")
+	switch {
+	case strings.HasPrefix(ctype, "text/html"):
+		return extractLinksFromHTML(resp)
+	case strings.HasPrefix(ctype, "text/css"):
+		return extractLinksFromCSS(resp)
+	default:
+		return nil
+	}
+}
+func extractLinksFromHTML(resp *http.Response) []rawOutlink {
+	var outlinks []rawOutlink
+	// Use goquery to extract links from the parsed HTML
+	// contents (query patterns are described in the
+	// linkMatches table).
+	doc, err := goquery.NewDocumentFromReader(resp.Body)
+	if err != nil {
+		return nil
+	}
+	for _, lm := range linkMatches {
+		doc.Find(fmt.Sprintf("%s[%s]", lm.tag, lm.attr)).Each(func(i int, s *goquery.Selection) {
+			val, _ := s.Attr(lm.attr)
+			outlinks = append(outlinks, rawOutlink{URL: val, Tag: lm.linkTag})
+		})
+	}
+	return outlinks
+}
+func extractLinksFromCSS(resp *http.Response) []rawOutlink {
+	// Use a simple (and actually quite bad) regular
+	// expression to extract "url()" links from CSS.
+	var outlinks []rawOutlink
+	if data, err := ioutil.ReadAll(resp.Body); err == nil {
+		for _, val := range urlcssRx.FindAllStringSubmatch(string(data), -1) {
+			outlinks = append(outlinks, rawOutlink{URL: val[1], Tag: crawl.TagRelated})
+		}
+	}
+	return outlinks
+}
@@ -9,18 +9,18 @@ import (
var defaultClientTimeout = 60 * time.Second
-var DefaultClient *http.Client
// DefaultClient returns a http.Client suitable for crawling: does not
// follow redirects, accepts invalid TLS certificates, sets a
// reasonable timeout for requests.
+var DefaultClient *http.Client
func init() {
-	jar, _ := cookiejar.New(nil)
+	jar, _ := cookiejar.New(nil) // nolint
	DefaultClient = &http.Client{
		Timeout: defaultClientTimeout,
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
-				InsecureSkipVerify: true,
+				InsecureSkipVerify: true, // nolint
			},
		},
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
@@ -37,30 +37,24 @@ var (
	cpuprofile = flag.String("cpuprofile", "", "create cpu profile")
)
-func extractLinks(c *crawl.Crawler, u string, depth int, resp *http.Response, err error) error {
+func extractLinks(c *crawl.Crawler, u string, depth int, resp *http.Response, _ error) error {
	links, err := analysis.GetLinks(resp)
	if err != nil {
		return err
	}
	for _, link := range links {
-		c.Enqueue(link, depth+1)
+		if err := c.Enqueue(link, depth+1); err != nil {
+			return err
+		}
	}
	return nil
}
-type fakeCloser struct {
-	io.Reader
-}
-func (f *fakeCloser) Close() error {
-	return nil
-}
func hdr2str(h http.Header) []byte {
	var b bytes.Buffer
-	h.Write(&b)
+	h.Write(&b) // nolint
	return b.Bytes()
}
@@ -69,43 +63,58 @@ type warcSaveHandler struct {
	warcInfoID string
}
+func (h *warcSaveHandler) writeWARCRecord(typ, uri string, data []byte) error {
+	hdr := warc.NewHeader()
+	hdr.Set("WARC-Type", typ)
+	hdr.Set("WARC-Target-URI", uri)
+	hdr.Set("WARC-Warcinfo-ID", h.warcInfoID)
+	hdr.Set("Content-Length", strconv.Itoa(len(data)))
+	w, err := h.warc.NewRecord(hdr)
+	if err != nil {
+		return err
+	}
+	if _, err := w.Write(data); err != nil {
+		return err
+	}
+	return w.Close()
+}
func (h *warcSaveHandler) Handle(c *crawl.Crawler, u string, depth int, resp *http.Response, err error) error {
+	if err != nil {
+		return err
+	}
+	// Read the response body (so we can save it to the WARC
+	// output) and replace it with a buffer.
	data, derr := ioutil.ReadAll(resp.Body)
	if derr != nil {
-		return err
+		return derr
	}
-	resp.Body = &fakeCloser{bytes.NewReader(data)}
+	resp.Body = ioutil.NopCloser(bytes.NewReader(data))
-	// Dump the request.
+	// Dump the request to the WARC output.
	var b bytes.Buffer
-	resp.Request.Write(&b)
-	hdr := warc.NewHeader()
-	hdr.Set("WARC-Type", "request")
-	hdr.Set("WARC-Target-URI", resp.Request.URL.String())
-	hdr.Set("WARC-Warcinfo-ID", h.warcInfoID)
-	hdr.Set("Content-Length", strconv.Itoa(b.Len()))
-	w := h.warc.NewRecord(hdr)
-	w.Write(b.Bytes())
-	w.Close()
+	if werr := resp.Request.Write(&b); werr != nil {
+		return werr
+	}
+	if werr := h.writeWARCRecord("request", resp.Request.URL.String(), b.Bytes()); werr != nil {
+		return werr
+	}
	// Dump the response.
	statusLine := fmt.Sprintf("HTTP/1.1 %s", resp.Status)
	respPayload := bytes.Join([][]byte{
		[]byte(statusLine), hdr2str(resp.Header), data},
		[]byte{'\r', '\n'})
-	hdr = warc.NewHeader()
-	hdr.Set("WARC-Type", "response")
-	hdr.Set("WARC-Target-URI", resp.Request.URL.String())
-	hdr.Set("WARC-Warcinfo-ID", h.warcInfoID)
-	hdr.Set("Content-Length", strconv.Itoa(len(respPayload)))
-	w = h.warc.NewRecord(hdr)
-	w.Write(respPayload)
-	w.Close()
+	if werr := h.writeWARCRecord("response", resp.Request.URL.String(), respPayload); werr != nil {
+		return werr
+	}
	return extractLinks(c, u, depth, resp, err)
}
-func newWarcSaveHandler(w *warc.Writer) crawl.Handler {
+func newWarcSaveHandler(w *warc.Writer) (crawl.Handler, error) {
	info := strings.Join([]string{
		"Software: crawl/1.0\r\n",
		"Format: WARC File Format 1.0\r\n",
@@ -116,13 +125,18 @@ func newWarcSaveHandler(w *warc.Writer) crawl.Handler {
	hdr.Set("WARC-Type", "warcinfo")
	hdr.Set("WARC-Warcinfo-ID", hdr.Get("WARC-Record-ID"))
	hdr.Set("Content-Length", strconv.Itoa(len(info)))
-	hdrw := w.NewRecord(hdr)
-	io.WriteString(hdrw, info)
-	hdrw.Close()
+	hdrw, err := w.NewRecord(hdr)
+	if err != nil {
+		return nil, err
+	}
+	if _, err := io.WriteString(hdrw, info); err != nil {
+		return nil, err
+	}
+	hdrw.Close() // nolint
	return &warcSaveHandler{
		warc: w,
		warcInfoID: hdr.Get("WARC-Record-ID"),
-	}
+	}, nil
}
type crawlStats struct {
@@ -149,7 +163,7 @@ func (c *crawlStats) Dump() {
	c.lock.Lock()
	defer c.lock.Unlock()
	rate := float64(c.bytes) / time.Since(c.start).Seconds() / 1000
-	fmt.Fprintf(os.Stderr, "stats: downloaded %d bytes (%.4g KB/s), status: %v\n", c.bytes, rate, c.states)
+	fmt.Fprintf(os.Stderr, "stats: downloaded %d bytes (%.4g KB/s), status: %v\n", c.bytes, rate, c.states) // nolint
}
var stats *crawlStats
@@ -201,11 +215,6 @@ func main() {
		defer pprof.StopCPUProfile()
	}
-	outf, err := os.Create(*outputFile)
-	if err != nil {
-		log.Fatal(err)
-	}
	seeds := crawl.MustParseURLs(flag.Args())
	scope := crawl.AND(
		crawl.NewSchemeScope(strings.Split(*validSchemes, ",")),
@@ -217,10 +226,17 @@ func main() {
		scope = crawl.OR(scope, crawl.NewIncludeRelatedScope())
	}
+	outf, err := os.Create(*outputFile)
+	if err != nil {
+		log.Fatal(err)
+	}
	w := warc.NewWriter(outf)
-	defer w.Close()
-	saver := newWarcSaveHandler(w)
+	defer w.Close() // nolint
+	saver, err := newWarcSaveHandler(w)
+	if err != nil {
+		log.Fatal(err)
+	}
	crawler, err := crawl.NewCrawler(*dbPath, seeds, scope, crawl.FetcherFunc(fetch), crawl.NewRedirectHandler(saver))
	if err != nil {
@@ -240,13 +256,12 @@ func main() {
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	crawler.Run(*concurrency)
	crawler.Close()
	if signaled.Load().(bool) {
		os.Exit(1)
	}
	if !*keepDb {
-		os.RemoveAll(*dbPath)
+		os.RemoveAll(*dbPath) // nolint
	}
}
@@ -15,20 +15,25 @@ import (
)
var (
+	dbPath = flag.String("state", "crawldb", "crawl state database path")
	concurrency = flag.Int("c", 10, "concurrent workers")
	depth = flag.Int("depth", 10, "maximum link depth")
	validSchemes = flag.String("schemes", "http,https", "comma-separated list of allowed protocols")
)
func extractLinks(c *crawl.Crawler, u string, depth int, resp *http.Response, err error) error {
+	if err != nil {
+		return err
+	}
	links, err := analysis.GetLinks(resp)
	if err != nil {
		return err
	}
	for _, link := range links {
-		c.Enqueue(link, depth+1)
+		if err := c.Enqueue(link, depth+1); err != nil {
+			return err
+		}
	}
	return nil
@@ -49,4 +54,5 @@ func main() {
		log.Fatal(err)
	}
	crawler.Run(*concurrency)
+	crawler.Close()
}
@@ -144,10 +144,10 @@ type Crawler struct {
}
// Enqueue a (possibly new) URL for processing.
-func (c *Crawler) Enqueue(link Outlink, depth int) {
+func (c *Crawler) Enqueue(link Outlink, depth int) error {
	// See if it's in scope.
	if !c.scope.Check(link, depth) {
-		return
+		return nil
	}
	// Normalize the URL.
@@ -161,16 +161,20 @@ func (c *Crawler) Enqueue(link Outlink, depth int) {
	var info URLInfo
	ukey := []byte(fmt.Sprintf("url/%s", urlStr))
	if err := c.db.GetObj(ukey, &info); err == nil {
-		return
+		return nil
	}
	// Store the URL in the queue, and store an empty URLInfo to
	// make sure that subsequent calls to Enqueue with the same
	// URL will fail.
	wb := new(leveldb.Batch)
-	c.queue.Add(wb, urlStr, depth, time.Now())
-	c.db.PutObjBatch(wb, ukey, &info)
-	c.db.Write(wb, nil)
+	if err := c.queue.Add(wb, urlStr, depth, time.Now()); err != nil {
+		return err
+	}
+	if err := c.db.PutObjBatch(wb, ukey, &info); err != nil {
+		return err
+	}
+	return c.db.Write(wb, nil)
}
var scanInterval = 1 * time.Second
@@ -210,7 +214,7 @@ func (c *Crawler) urlHandler(queue <-chan queuePair) {
		// Ignore errors, we can work with an empty object.
		urlkey := []byte(fmt.Sprintf("url/%s", p.URL))
		var info URLInfo
-		c.db.GetObj(urlkey, &info)
+		c.db.GetObj(urlkey, &info) // nolint
		info.CrawledAt = time.Now()
		info.URL = p.URL
@@ -230,18 +234,17 @@ func (c *Crawler) urlHandler(queue <-chan queuePair) {
		wb := new(leveldb.Batch)
		if httpErr == nil {
-			respBody.Close()
+			respBody.Close() // nolint
			// Remove the URL from the queue if the fetcher was successful.
			c.queue.Release(wb, p)
		} else {
			log.Printf("error retrieving %s: %v", p.URL, httpErr)
-			c.queue.Retry(wb, p, 300*time.Second)
+			Must(c.queue.Retry(wb, p, 300*time.Second))
		}
-		c.db.PutObjBatch(wb, urlkey, &info)
-		c.db.Write(wb, nil)
+		Must(c.db.PutObjBatch(wb, urlkey, &info))
+		Must(c.db.Write(wb, nil))
	}
}
@@ -279,7 +282,9 @@ func NewCrawler(path string, seeds []*url.URL, scope Scope, f Fetcher, h Handler
	c.stopping.Store(false)
	// Recover active tasks.
-	c.queue.Recover()
+	if err := c.queue.Recover(); err != nil {
+		return nil, err
+	}
	return c, nil
}
@@ -289,7 +294,7 @@ func NewCrawler(path string, seeds []*url.URL, scope Scope, f Fetcher, h Handler
func (c *Crawler) Run(concurrency int) {
	// Load initial seeds into the queue.
	for _, u := range c.seeds {
-		c.Enqueue(Outlink{URL: u, Tag: TagPrimary}, 0)
+		Must(c.Enqueue(Outlink{URL: u, Tag: TagPrimary}, 0))
	}
	// Start some runners and wait until they're done.
@@ -313,7 +318,7 @@ func (c *Crawler) Stop() {
// Close the database and release resources associated with the crawler state.
func (c *Crawler) Close() {
-	c.db.Close()
+	c.db.Close() // nolint
}
type redirectHandler struct {
@@ -330,11 +335,11 @@ func (wrap *redirectHandler) Handle(c *Crawler, u string, depth int, resp *http.
	} else if resp.StatusCode > 300 && resp.StatusCode < 400 {
		location := resp.Header.Get("Location")
		if location != "" {
-			locationURL, err := resp.Request.URL.Parse(location)
-			if err != nil {
-				log.Printf("error parsing Location header: %v", err)
+			locationURL, uerr := resp.Request.URL.Parse(location)
+			if uerr != nil {
+				log.Printf("error parsing Location header: %v", uerr)
			} else {
-				c.Enqueue(Outlink{URL: locationURL, Tag: TagPrimary}, depth+1)
+				Must(c.Enqueue(Outlink{URL: locationURL, Tag: TagPrimary}, depth+1))
			}
		}
	} else {
@@ -348,3 +353,11 @@ func (wrap *redirectHandler) Handle(c *Crawler, u string, depth int, resp *http.
func NewRedirectHandler(wrap Handler) Handler {
	return &redirectHandler{wrap}
}
+// Must will abort the program with a message when we encounter an
+// error that we can't recover from.
+func Must(err error) {
+	if err != nil {
+		log.Fatalf("fatal error: %v", err)
+	}
+}
@@ -20,8 +20,7 @@ var (
	queuePrefix = []byte("queue")
	activePrefix = []byte("queue_active")
	queueKeySep = []byte{'/'}
-	queueKeySepP1 = []byte{'/' + 1}
)
type queuePair struct {
@@ -45,7 +44,9 @@ func (q *queue) Scan(ch chan<- queuePair) error {
			continue
		}
		p.key = iter.Key()
-		q.acquire(p)
+		if err := q.acquire(p); err != nil {
+			return err
+		}
		ch <- p
		n++
	}
@@ -57,19 +58,24 @@ func (q *queue) Scan(ch chan<- queuePair) error {
}
// Add an item to the pending work queue.
-func (q *queue) Add(wb *leveldb.Batch, urlStr string, depth int, when time.Time) {
+func (q *queue) Add(wb *leveldb.Batch, urlStr string, depth int, when time.Time) error {
	t := uint64(when.UnixNano())
	qkey := bytes.Join([][]byte{queuePrefix, encodeUint64(t), encodeUint64(uint64(rand.Int63()))}, queueKeySep)
-	q.db.PutObjBatch(wb, qkey, &queuePair{URL: urlStr, Depth: depth})
+	return q.db.PutObjBatch(wb, qkey, &queuePair{URL: urlStr, Depth: depth})
}
-func (q *queue) acquire(qp queuePair) {
+func (q *queue) acquire(qp queuePair) error {
	wb := new(leveldb.Batch)
-	q.db.PutObjBatch(wb, activeQueueKey(qp.key), qp)
+	if err := q.db.PutObjBatch(wb, activeQueueKey(qp.key), qp); err != nil {
+		return err
+	}
	wb.Delete(qp.key)
-	q.db.Write(wb, nil)
+	if err := q.db.Write(wb, nil); err != nil {
+		return err
+	}
	atomic.AddInt32(&q.numActive, 1)
+	return nil
}
// Release an item from the queue. Processing for this item is done.
@@ -79,16 +85,19 @@ func (q *queue) Release(wb *leveldb.Batch, qp queuePair) {
}
// Retry processing this item at a later time.
-func (q *queue) Retry(wb *leveldb.Batch, qp queuePair, delay time.Duration) {
+func (q *queue) Retry(wb *leveldb.Batch, qp queuePair, delay time.Duration) error {
	wb.Delete(activeQueueKey(qp.key))
-	q.Add(wb, qp.URL, qp.Depth, time.Now().Add(delay))
+	if err := q.Add(wb, qp.URL, qp.Depth, time.Now().Add(delay)); err != nil {
+		return err
+	}
	atomic.AddInt32(&q.numActive, -1)
+	return nil
}
// Recover moves all active tasks to the pending queue. To be
// called at startup to recover tasks that were active when the
// previous run terminated.
-func (q *queue) Recover() {
+func (q *queue) Recover() error {
	wb := new(leveldb.Batch)
	prefix := bytes.Join([][]byte{activePrefix, []byte{}}, queueKeySep)
@@ -100,11 +109,13 @@ func (q *queue) Recover() {
			continue
		}
		p.key = iter.Key()[len(activePrefix)+1:]
-		q.db.PutObjBatch(wb, p.key, &p)
+		if err := q.db.PutObjBatch(wb, p.key, &p); err != nil {
+			return err
+		}
		wb.Delete(iter.Key())
	}
-	q.db.Write(wb, nil)
+	return q.db.Write(wb, nil)
}
func encodeUint64(n uint64) []byte {
@@ -47,12 +47,17 @@ func (h Header) Get(key string) string {
}