Commit 679daf44 authored by ale's avatar ale

Close pending connections on shutdown in the UNIX socket server

Since the idle time can be long, persistent connections could have
prevented a timely server shutdown.
parent 95125bd5
......@@ -2,6 +2,7 @@ package unix
import (
"bufio"
"container/list"
"context"
"errors"
"io"
......@@ -32,6 +33,11 @@ type SocketServer struct {
closing atomic.Value
wg sync.WaitGroup
handler Handler
// Keep track of active connections so we can shut them down
// on Close.
connMx sync.Mutex
conns list.List
}
func newServer(l net.Listener, lock *flock.Flock, h Handler) *SocketServer {
......@@ -99,7 +105,18 @@ func NewSystemdSocketServer(h Handler) (*SocketServer, error) {
// Waits for active connections to terminate before returning.
func (s *SocketServer) Close() {
s.closing.Store(true)
// Close the listener to stop incoming connections.
s.l.Close() // nolint
// Close all active connections (this will return an error to
// the client if the connection is not idle).
s.connMx.Lock()
for el := s.conns.Front(); el != nil; el = el.Next() {
el.Value.(net.Conn).Close() // nolint
}
s.connMx.Unlock()
s.wg.Wait()
if s.lock != nil {
s.lock.Unlock() // nolint
......@@ -120,10 +137,21 @@ func (s *SocketServer) Serve() error {
}
return err
}
s.wg.Add(1)
s.connMx.Lock()
connEl := s.conns.PushBack(conn)
s.connMx.Unlock()
go func() {
s.handler.ServeConnection(conn)
conn.Close() // nolint
if !s.isClosing() {
s.connMx.Lock()
s.conns.Remove(connEl)
s.connMx.Unlock()
}
s.wg.Done()
}()
}
......
package unix
import (
"context"
"io/ioutil"
"net/textproto"
"os"
"path/filepath"
"testing"
)
type fakeServer struct{}
func (f *fakeServer) ServeLine(_ context.Context, lw LineResponseWriter, _ []byte) error {
return lw.WriteLineCRLF([]byte("hello"))
}
func doRequests(socketPath string, n int) error {
conn, err := textproto.Dial("unix", socketPath)
if err != nil {
return err
}
defer conn.Close()
for i := 0; i < n; i++ {
if err := conn.PrintfLine("request"); err != nil {
return err
}
if _, err := conn.ReadLine(); err != nil {
return err
}
}
return nil
}
func doConcurrentRequests(socketPath string, n, conns int) error {
start := make(chan bool)
errCh := make(chan error, conns)
for i := 0; i < conns; i++ {
go func() {
<-start
errCh <- doRequests(socketPath, n)
}()
}
close(start)
var err error
for i := 0; i < conns; i++ {
if werr := <-errCh; werr != nil && err == nil {
err = werr
}
}
return err
}
func TestServer(t *testing.T) {
dir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
socketPath := filepath.Join(dir, "test.sock")
srv, err := NewUNIXSocketServer(socketPath, NewLineServer(&fakeServer{}))
if err != nil {
t.Fatalf("NewUNIXSocketServer: %v", err)
}
go srv.Serve() // nolint
defer srv.Close()
if err := doConcurrentRequests(socketPath, 1000, 10); err != nil {
t.Fatalf("request error: %v", err)
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment