Skip to content
Snippets Groups Projects
Commit 679daf44 authored by ale's avatar ale
Browse files

Close pending connections on shutdown in the UNIX socket server

Since the idle time can be long, persistent connections could have
prevented a timely server shutdown.
parent 95125bd5
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@ package unix
import (
"bufio"
"container/list"
"context"
"errors"
"io"
......@@ -32,6 +33,11 @@ type SocketServer struct {
closing atomic.Value
wg sync.WaitGroup
handler Handler
// Keep track of active connections so we can shut them down
// on Close.
connMx sync.Mutex
conns list.List
}
func newServer(l net.Listener, lock *flock.Flock, h Handler) *SocketServer {
......@@ -99,7 +105,18 @@ func NewSystemdSocketServer(h Handler) (*SocketServer, error) {
// Waits for active connections to terminate before returning.
func (s *SocketServer) Close() {
s.closing.Store(true)
// Close the listener to stop incoming connections.
s.l.Close() // nolint
// Close all active connections (this will return an error to
// the client if the connection is not idle).
s.connMx.Lock()
for el := s.conns.Front(); el != nil; el = el.Next() {
el.Value.(net.Conn).Close() // nolint
}
s.connMx.Unlock()
s.wg.Wait()
if s.lock != nil {
s.lock.Unlock() // nolint
......@@ -120,10 +137,21 @@ func (s *SocketServer) Serve() error {
}
return err
}
s.wg.Add(1)
s.connMx.Lock()
connEl := s.conns.PushBack(conn)
s.connMx.Unlock()
go func() {
s.handler.ServeConnection(conn)
conn.Close() // nolint
if !s.isClosing() {
s.connMx.Lock()
s.conns.Remove(connEl)
s.connMx.Unlock()
}
s.wg.Done()
}()
}
......
package unix
import (
"context"
"io/ioutil"
"net/textproto"
"os"
"path/filepath"
"testing"
)
type fakeServer struct{}
func (f *fakeServer) ServeLine(_ context.Context, lw LineResponseWriter, _ []byte) error {
return lw.WriteLineCRLF([]byte("hello"))
}
func doRequests(socketPath string, n int) error {
conn, err := textproto.Dial("unix", socketPath)
if err != nil {
return err
}
defer conn.Close()
for i := 0; i < n; i++ {
if err := conn.PrintfLine("request"); err != nil {
return err
}
if _, err := conn.ReadLine(); err != nil {
return err
}
}
return nil
}
func doConcurrentRequests(socketPath string, n, conns int) error {
start := make(chan bool)
errCh := make(chan error, conns)
for i := 0; i < conns; i++ {
go func() {
<-start
errCh <- doRequests(socketPath, n)
}()
}
close(start)
var err error
for i := 0; i < conns; i++ {
if werr := <-errCh; werr != nil && err == nil {
err = werr
}
}
return err
}
func TestServer(t *testing.T) {
dir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
socketPath := filepath.Join(dir, "test.sock")
srv, err := NewUNIXSocketServer(socketPath, NewLineServer(&fakeServer{}))
if err != nil {
t.Fatalf("NewUNIXSocketServer: %v", err)
}
go srv.Serve() // nolint
defer srv.Close()
if err := doConcurrentRequests(socketPath, 1000, 10); err != nil {
t.Fatalf("request error: %v", err)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment