From ddecd3760f89ef9ef7765bf445024c402ab4ad5a Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Tue, 19 Dec 2017 09:05:37 +0000
Subject: [PATCH] Exit gracefully on signals

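Install handlers for SIGINT and SIGTERM in cmd/crawl and add a
Crawler.Stop method, so that an interrupted crawl can shut down its
workers cleanly instead of being killed mid-run. The fixed 2s queue
scan ticker is replaced by a stoppable ticker that also selects on a
new stop channel, and the workers check a stopping flag so they do not
keep handling entries from the (now buffered) queue channel after a
shutdown has been requested. When the exit is triggered by a signal,
the crawl database is kept even when it would otherwise be removed,
and the process exits with status 1.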
---
 cmd/crawl/crawl.go | 19 +++++++++++++++++++
 crawler.go         | 38 +++++++++++++++++++++++++++++++++-----
 2 files changed, 52 insertions(+), 5 deletions(-)

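Note for reviewers: the fragment below sketches how the new
Crawler.Stop entry point is driven (this is roughly what cmd/crawl now
does). The names c and concurrency are illustrative stand-ins for an
already-constructed *crawl.Crawler and the worker count; they do not
appear in the patch.

    sigCh := make(chan os.Signal, 1)
    signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
    go func() {
        <-sigCh
        // Stop sets the stopping flag and closes stopCh, so the
        // queue scanner and the workers wind down and Run returns.
        c.Stop()
    }()
    c.Run(concurrency) // returns once the queue is empty or Stop is called
    c.Close()
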
diff --git a/cmd/crawl/crawl.go b/cmd/crawl/crawl.go
index abf2b42..e7e8582 100644
--- a/cmd/crawl/crawl.go
+++ b/cmd/crawl/crawl.go
@@ -11,11 +11,13 @@ import (
 	"log"
 	"net/http"
 	"os"
+	"os/signal"
 	"runtime/pprof"
 	"strconv"
 	"strings"
 	"sync"
 	"sync/atomic"
+	"syscall"
 	"time"
 
 	"git.autistici.org/ale/crawl"
@@ -224,9 +226,26 @@ func main() {
 	if err != nil {
 		log.Fatal(err)
 	}
+
+	// Set up signal handlers so we can terminate gracefully if possible.
+	var signaled atomic.Value
+	signaled.Store(false)
+	sigCh := make(chan os.Signal, 1)
+	go func() {
+		<-sigCh
+		log.Printf("exiting due to signal")
+		signaled.Store(true)
+		crawler.Stop()
+	}()
+	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
+
 	crawler.Run(*concurrency)
 
 	crawler.Close()
+
+	if signaled.Load().(bool) {
+		os.Exit(1)
+	}
 	if !*keepDb {
 		os.RemoveAll(*dbPath)
 	}
diff --git a/crawler.go b/crawler.go
index b3c4a7b..aef628f 100644
--- a/crawler.go
+++ b/crawler.go
@@ -10,6 +10,7 @@ import (
 	"net/http"
 	"net/url"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/PuerkitoBio/purell"
@@ -136,6 +137,9 @@ type Crawler struct {
 	fetcher Fetcher
 	handler Handler
 
+	stopCh   chan bool
+	stopping atomic.Value
+
 	enqueueMx sync.Mutex
 }
 
@@ -169,16 +173,25 @@ func (c *Crawler) Enqueue(link Outlink, depth int) {
 	c.db.Write(wb, nil)
 }
 
+var scanInterval = 1 * time.Second
+
 // Scan the queue for URLs until there are no more.
 func (c *Crawler) process() <-chan queuePair {
-	ch := make(chan queuePair)
+	ch := make(chan queuePair, 100)
 	go func() {
-		for range time.Tick(2 * time.Second) {
-			if err := c.queue.Scan(ch); err != nil {
-				break
+		t := time.NewTicker(scanInterval)
+		defer t.Stop()
+		defer close(ch)
+		for {
+			select {
+			case <-t.C:
+				if err := c.queue.Scan(ch); err != nil {
+					return
+				}
+			case <-c.stopCh:
+				return
 			}
 		}
-		close(ch)
 	}()
 	return ch
 }
@@ -186,6 +199,13 @@ func (c *Crawler) process() <-chan queuePair {
 // Main worker loop.
 func (c *Crawler) urlHandler(queue <-chan queuePair) {
 	for p := range queue {
+		// The stop flag has to short-circuit the queue channel
+		// (which is buffered), or we would keep handling queued
+		// URLs for a while after Stop has been called.
+		if c.stopping.Load().(bool) {
+			return
+		}
+
 		// Retrieve the URLInfo object from the crawl db.
 		// Ignore errors, we can work with an empty object.
 		urlkey := []byte(fmt.Sprintf("url/%s", p.URL))
@@ -254,7 +274,9 @@ func NewCrawler(path string, seeds []*url.URL, scope Scope, f Fetcher, h Handler
 		handler: h,
 		seeds:   seeds,
 		scope:   scope,
+		stopCh:  make(chan bool),
 	}
+	c.stopping.Store(false)
 
 	// Recover active tasks.
 	c.queue.Recover()
@@ -283,6 +305,12 @@ func (c *Crawler) Run(concurrency int) {
 	wg.Wait()
 }
 
+// Stop requests a graceful shutdown of a running crawl, causing Run to return.
+func (c *Crawler) Stop() {
+	c.stopping.Store(true)
+	close(c.stopCh)
+}
+
 // Close the database and release resources associated with the crawler state.
 func (c *Crawler) Close() {
 	c.db.Close()
-- 
GitLab