Skip to content
Snippets Groups Projects
Commit 3a0bd89b authored by ale's avatar ale
Browse files

Add module to run socket-based servers

parent dd6697d2
No related branches found
No related tags found
No related merge requests found
package unix
import (
"context"
"errors"
"io"
"log"
"net"
"net/textproto"
"os"
"sync"
"sync/atomic"
"time"
"github.com/coreos/go-systemd/activation"
"github.com/theckman/go-flock"
)
// Handler for UNIX socket server connections.
type Handler interface {
ServeConnection(c net.Conn)
}
// HandlerFunc is a function adapter for Handler.
type HandlerFunc func(net.Conn)
func (f HandlerFunc) ServeConnection(c net.Conn) { f(c) }
// SocketServer accepts connections on a UNIX socket, speaking the
// line-based wire protocol, and dispatches incoming requests to the
// wrapped Server.
type SocketServer struct {
l net.Listener
lock *flock.Flock
closing atomic.Value
wg sync.WaitGroup
handler Handler
}
func newServer(l net.Listener, lock *flock.Flock, h Handler) *SocketServer {
s := &SocketServer{
l: l,
lock: lock,
handler: h,
}
s.closing.Store(false)
return s
}
// NewUNIXSocketServer returns a new SocketServer listening on the given path.
func NewUNIXSocketServer(socketPath string, h Handler) (*SocketServer, error) {
// The simplest workflow is: create a new socket, remove it on
// exit. However, if the program crashes, the socket might
// stick around and prevent the next execution from starting
// successfully. We could remove it before starting, but that
// would be dangerous if another instance was listening on
// that socket. So we wrap socket access with a file lock.
lock := flock.NewFlock(socketPath + ".lock")
locked, err := lock.TryLock()
if err != nil {
return nil, err
}
if !locked {
return nil, errors.New("socket is locked by another process")
}
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
// Always try to unlink the socket before creating it.
os.Remove(socketPath)
l, err := net.ListenUnix("unix", addr)
if err != nil {
return nil, err
}
return newServer(l, lock, h), nil
}
// NewSystemdSocketServer uses systemd socket activation, receiving
// the open socket as a file descriptor on exec.
func NewSystemdSocketServer(h Handler) (*SocketServer, error) {
listeners, err := activation.Listeners(false)
if err != nil {
return nil, err
}
// Our server loop implies a single listener, so find
// the first one passed by systemd and ignore all others.
// TODO: listen on all fds.
for _, l := range listeners {
if l != nil {
return newServer(l, nil, h), nil
}
}
return nil, errors.New("no available sockets found")
}
// Close the socket listener and release all associated resources.
// Waits for active connections to terminate before returning.
func (s *SocketServer) Close() {
s.closing.Store(true)
s.l.Close()
s.wg.Wait()
if s.lock != nil {
s.lock.Unlock()
}
}
func (s *SocketServer) isClosing() bool {
return s.closing.Load().(bool)
}
// Serve connections.
func (s *SocketServer) Serve() error {
for {
conn, err := s.l.Accept()
if err != nil {
if s.isClosing() {
return nil
}
return err
}
s.wg.Add(1)
go func() {
s.handler.ServeConnection(conn)
conn.Close()
s.wg.Done()
}()
}
}
// LineHandler is the handler for LineServer.
type LineHandler interface {
ServeLine(context.Context, []byte) (string, error)
}
// LineServer implements a line-based text protocol. It satisfies the
// Handler interface.
type LineServer struct {
handler LineHandler
IdleTimeout time.Duration
WriteTimeout time.Duration
RequestTimeout time.Duration
}
var (
defaultIdleTimeout = 600 * time.Second
defaultWriteTimeout = 10 * time.Second
defaultRequestTimeout = 30 * time.Second
)
// NewLineServer returns a new LineServer with the given handler and
// default I/O timeouts.
func NewLineServer(h LineHandler) *LineServer {
return &LineServer{
handler: h,
IdleTimeout: defaultIdleTimeout,
WriteTimeout: defaultWriteTimeout,
RequestTimeout: defaultRequestTimeout,
}
}
func (l *LineServer) ServeConnection(nc net.Conn) {
c := textproto.NewConn(nc)
for {
nc.SetReadDeadline(time.Now().Add(l.IdleTimeout))
line, err := c.ReadLineBytes()
if err == io.EOF {
break
} else if err != nil {
log.Printf("client error: %v", err)
break
}
// Create a context for the request and call the handler with it.
ctx, cancel := context.WithTimeout(context.Background(), l.RequestTimeout)
response, err := l.handler.ServeLine(ctx, line)
cancel()
// Close the connection on error, or on empty response.
if response != "" {
nc.SetWriteDeadline(time.Now().Add(l.WriteTimeout))
c.PrintfLine(response)
}
if err != nil {
log.Printf("request error: %v", err)
break
}
if response == "" {
break
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment