Commit 7bc37e1c authored by ale's avatar ale
Browse files

Use ai3/go-common for the UNIX socket server code

parent dcf0b649
......@@ -9,9 +9,11 @@ import (
"os/signal"
"syscall"
"git.autistici.org/id/auth/server"
"git.autistici.org/ai3/go-common/unix"
"github.com/coreos/go-systemd/daemon"
"github.com/prometheus/client_golang/prometheus/promhttp"
"git.autistici.org/id/auth/server"
)
var (
......@@ -55,11 +57,13 @@ func main() {
}()
}
var sockSrv *server.SocketServer
srv := unix.NewLineServer(server.NewSocketServer(authSrv))
var sockSrv *unix.SocketServer
if *systemdSocketActivation {
sockSrv, err = server.NewSystemdSocketServer(authSrv)
sockSrv, err = unix.NewSystemdSocketServer(srv)
} else {
sockSrv, err = server.NewSocketServer(*socketPath, authSrv)
sockSrv, err = unix.NewUNIXSocketServer(*socketPath, srv)
}
if err != nil {
log.Fatalf("error: %v", err)
......@@ -77,7 +81,9 @@ func main() {
log.Printf("starting")
daemon.SdNotify(false, "READY=1")
sockSrv.Serve()
if err := sockSrv.Serve(); err != nil {
log.Fatal(err)
}
<-done
}
......@@ -5,182 +5,52 @@ import (
"context"
"errors"
"fmt"
"io"
"log"
"net"
"net/textproto"
"os"
"sync"
"time"
"git.autistici.org/id/auth"
"github.com/coreos/go-systemd/activation"
flock "github.com/theckman/go-flock"
)
// SocketServer accepts connections on a UNIX socket, speaking the
// line-based wire protocol, and dispatches incoming requests to the
// wrapped Server.
type SocketServer struct {
l net.Listener
lock *flock.Flock
closingMx sync.Mutex
closing bool
wg sync.WaitGroup
codec auth.Codec
auth *Server
requestTimeout time.Duration
codec auth.Codec
auth *Server
}
// NewSocketServer returns a new SocketServer listening on the given path.
func NewSocketServer(socketPath string, authServer *Server) (*SocketServer, error) {
// The simplest workflow is: create a new socket, remove it on
// exit. However, if the program crashes, the socket might
// stick around and prevent the next execution from starting
// successfully. We could remove it before starting, but that
// would be dangerous if another instance was listening on
// that socket. So we wrap socket access with a file lock.
lock := flock.NewFlock(socketPath + ".lock")
locked, err := lock.TryLock()
if err != nil {
return nil, err
}
if !locked {
return nil, errors.New("socket is locked by another process")
}
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
// Always try to unlink the socket before creating it.
os.Remove(socketPath)
l, err := net.ListenUnix("unix", addr)
if err != nil {
return nil, err
}
func NewSocketServer(authServer *Server) *SocketServer {
return &SocketServer{
l: l,
lock: lock,
auth: authServer,
codec: auth.DefaultCodec,
requestTimeout: 3 * time.Second,
}, nil
}
// NewSystemdSocketServer uses systemd socket activation, receiving
// the open socket as a file descriptor on exec.
func NewSystemdSocketServer(authServer *Server) (*SocketServer, error) {
listeners, err := activation.Listeners(false)
if err != nil {
return nil, err
}
// Our server loop implies a single listener, so find
// the first one passed by systemd and ignore all others.
// TODO: listen on all fds.
for _, l := range listeners {
if l != nil {
return &SocketServer{
l: l,
auth: authServer,
codec: auth.DefaultCodec,
requestTimeout: 3 * time.Second,
}, nil
}
auth: authServer,
codec: auth.DefaultCodec,
}
return nil, errors.New("no available sockets found")
}
// Close the socket listener and release all associated resources.
// Waits for active connections to terminate before returning.
func (s *SocketServer) Close() {
// The lock is there to ensure cross-goroutine synchronization.
s.closingMx.Lock()
s.closing = true
s.closingMx.Unlock()
s.l.Close()
s.wg.Wait()
if s.lock != nil {
s.lock.Unlock()
func (s *SocketServer) ServeLine(ctx context.Context, line []byte) (string, error) {
// Parse the incoming command. The only two known
// commands are 'auth' for an authentication request,
// and 'quit' to terminate the connection (closing the
// connection also works). We shut down the connection
// on any protocol error.
parts := bytes.SplitN(line, []byte{' '}, 2)
nargs := len(parts)
cmd := string(parts[0])
switch {
case nargs == 1 && cmd == "quit":
return "", nil
case nargs == 2 && cmd == "auth":
return s.handleAuth(ctx, parts[1])
default:
return "", errors.New("syntax error")
}
}
func (s *SocketServer) isClosing() bool {
s.closingMx.Lock()
defer s.closingMx.Unlock()
return s.closing
}
// Serve connections.
func (s *SocketServer) Serve() {
for {
conn, err := s.l.Accept()
if err != nil {
if !s.isClosing() {
log.Printf("Accept() error: %v", err)
}
return
}
s.wg.Add(1)
go func() {
s.handle(conn)
s.wg.Done()
}()
}
}
func (s *SocketServer) handle(nc net.Conn) {
defer nc.Close()
c := textproto.NewConn(nc)
for {
line, err := c.ReadLineBytes()
if err == io.EOF {
break
} else if err != nil {
log.Printf("client connection error: %v", err)
break
}
// Parse the incoming command. The only two known
// commands are 'auth' for an authentication request,
// and 'quit' to terminate the connection (closing the
// connection also works). We shut down the connection
// on any protocol error.
parts := bytes.SplitN(line, []byte{' '}, 2)
nargs := len(parts)
cmd := string(parts[0])
switch {
case nargs == 1 && cmd == "quit":
c.PrintfLine("bye")
return
case nargs == 2 && cmd == "auth":
if err := s.handleAuth(c, parts[1]); err != nil {
log.Printf("error in auth: %v", err)
return
}
default:
log.Printf("syntax error")
return
}
}
}
func (s *SocketServer) handleAuth(c *textproto.Conn, line []byte) error {
func (s *SocketServer) handleAuth(ctx context.Context, arg []byte) (string, error) {
var req auth.Request
if err := s.codec.Decode(line, &req); err != nil {
return fmt.Errorf("decoding error: %v", err)
if err := s.codec.Decode(arg, &req); err != nil {
return "", fmt.Errorf("decoding error: %v", err)
}
// Set a timeout for the request.
ctx, cancel := context.WithTimeout(context.Background(), s.requestTimeout)
defer cancel()
resp := s.auth.Authenticate(ctx, &req)
return c.PrintfLine(string(s.codec.Encode(resp)))
return string(s.codec.Encode(resp)), nil
}
......@@ -9,6 +9,7 @@ import (
"testing"
"time"
"git.autistici.org/ai3/go-common/unix"
"git.autistici.org/id/auth"
"git.autistici.org/id/auth/client"
)
......@@ -20,7 +21,7 @@ func TestAuthServer_UNIX(t *testing.T) {
})
defer s.Close()
ss, err := NewSocketServer(".socket", s.srv)
ss, err := unix.NewUNIXSocketServer(".socket", unix.NewLineServer(NewSocketServer(s.srv)))
if err != nil {
t.Fatal(err)
}
......@@ -44,7 +45,7 @@ func TestAuthServer_UNIX_ReuseSocket(t *testing.T) {
})
defer s.Close()
ss, err := NewSocketServer(".socket", s.srv)
ss, err := unix.NewUNIXSocketServer(".socket", unix.NewLineServer(NewSocketServer(s.srv)))
if err != nil {
t.Fatal(err)
}
......@@ -67,7 +68,7 @@ func runMany(t testing.TB, concurrency, count int, f func(string) error) {
})
defer s.Close()
ss, err := NewSocketServer(".socket", s.srv)
ss, err := unix.NewUNIXSocketServer(".socket", unix.NewLineServer(NewSocketServer(s.srv)))
if err != nil {
t.Fatal(err)
}
......
package unix
import (
"context"
"errors"
"io"
"log"
"net"
"net/textproto"
"os"
"sync"
"sync/atomic"
"time"
"github.com/coreos/go-systemd/activation"
"github.com/theckman/go-flock"
)
// Handler for UNIX socket server connections.
type Handler interface {
ServeConnection(c net.Conn)
}
// HandlerFunc is a function adapter for Handler.
type HandlerFunc func(net.Conn)
func (f HandlerFunc) ServeConnection(c net.Conn) { f(c) }
// SocketServer accepts connections on a UNIX socket, speaking the
// line-based wire protocol, and dispatches incoming requests to the
// wrapped Server.
type SocketServer struct {
l net.Listener
lock *flock.Flock
closing atomic.Value
wg sync.WaitGroup
handler Handler
}
func newServer(l net.Listener, lock *flock.Flock, h Handler) *SocketServer {
s := &SocketServer{
l: l,
lock: lock,
handler: h,
}
s.closing.Store(false)
return s
}
// NewUNIXSocketServer returns a new SocketServer listening on the given path.
func NewUNIXSocketServer(socketPath string, h Handler) (*SocketServer, error) {
// The simplest workflow is: create a new socket, remove it on
// exit. However, if the program crashes, the socket might
// stick around and prevent the next execution from starting
// successfully. We could remove it before starting, but that
// would be dangerous if another instance was listening on
// that socket. So we wrap socket access with a file lock.
lock := flock.NewFlock(socketPath + ".lock")
locked, err := lock.TryLock()
if err != nil {
return nil, err
}
if !locked {
return nil, errors.New("socket is locked by another process")
}
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
// Always try to unlink the socket before creating it.
os.Remove(socketPath)
l, err := net.ListenUnix("unix", addr)
if err != nil {
return nil, err
}
return newServer(l, lock, h), nil
}
// NewSystemdSocketServer uses systemd socket activation, receiving
// the open socket as a file descriptor on exec.
func NewSystemdSocketServer(h Handler) (*SocketServer, error) {
listeners, err := activation.Listeners(false)
if err != nil {
return nil, err
}
// Our server loop implies a single listener, so find
// the first one passed by systemd and ignore all others.
// TODO: listen on all fds.
for _, l := range listeners {
if l != nil {
return newServer(l, nil, h), nil
}
}
return nil, errors.New("no available sockets found")
}
// Close the socket listener and release all associated resources.
// Waits for active connections to terminate before returning.
func (s *SocketServer) Close() {
s.closing.Store(true)
s.l.Close()
s.wg.Wait()
if s.lock != nil {
s.lock.Unlock()
}
}
func (s *SocketServer) isClosing() bool {
return s.closing.Load().(bool)
}
// Serve connections.
func (s *SocketServer) Serve() error {
for {
conn, err := s.l.Accept()
if err != nil {
if s.isClosing() {
return nil
}
return err
}
s.wg.Add(1)
go func() {
s.handler.ServeConnection(conn)
conn.Close()
s.wg.Done()
}()
}
}
// LineHandler is the handler for LineServer.
type LineHandler interface {
ServeLine(context.Context, []byte) (string, error)
}
// LineServer implements a line-based text protocol. It satisfies the
// Handler interface.
type LineServer struct {
handler LineHandler
IdleTimeout time.Duration
WriteTimeout time.Duration
RequestTimeout time.Duration
}
var (
defaultIdleTimeout = 600 * time.Second
defaultWriteTimeout = 10 * time.Second
defaultRequestTimeout = 30 * time.Second
)
// NewLineServer returns a new LineServer with the given handler and
// default I/O timeouts.
func NewLineServer(h LineHandler) *LineServer {
return &LineServer{
handler: h,
IdleTimeout: defaultIdleTimeout,
WriteTimeout: defaultWriteTimeout,
RequestTimeout: defaultRequestTimeout,
}
}
func (l *LineServer) ServeConnection(nc net.Conn) {
c := textproto.NewConn(nc)
for {
nc.SetReadDeadline(time.Now().Add(l.IdleTimeout))
line, err := c.ReadLineBytes()
if err == io.EOF {
break
} else if err != nil {
log.Printf("client error: %v", err)
break
}
// Create a context for the request and call the handler with it.
ctx, cancel := context.WithTimeout(context.Background(), l.RequestTimeout)
response, err := l.handler.ServeLine(ctx, line)
cancel()
// Close the connection on error, or on empty response.
if response != "" {
nc.SetWriteDeadline(time.Now().Add(l.WriteTimeout))
c.PrintfLine(response)
}
if err != nil {
log.Printf("request error: %v", err)
break
}
if response == "" {
break
}
}
}
......@@ -5,8 +5,8 @@
{
"checksumSHA1": "raJx5BjBbVQG0ylGSjPpi+JvqjU=",
"path": "git.autistici.org/ai3/go-common",
"revision": "86a36cf5da88919ee7d9ec12d1a92043b16fcc9c",
"revisionTime": "2017-12-10T11:04:55Z"
"revision": "3a0bd89b95cb0c323a1e067f085f72467063ed31",
"revisionTime": "2017-12-16T11:26:05Z"
},
{
"checksumSHA1": "jFlhSIit/5+VAIUu1cc7EVVlw0M=",
......@@ -20,6 +20,12 @@
"revision": "86a36cf5da88919ee7d9ec12d1a92043b16fcc9c",
"revisionTime": "2017-12-10T11:04:55Z"
},
{
"checksumSHA1": "hi5IyuRelE5KfRjd1ZOZMe0U5h8=",
"path": "git.autistici.org/ai3/go-common/unix",
"revision": "3a0bd89b95cb0c323a1e067f085f72467063ed31",
"revisionTime": "2017-12-16T11:26:05Z"
},
{
"checksumSHA1": "7Kbb9vTjqcQhhxtSGpmp9rk6PUk=",
"path": "git.autistici.org/id/usermetadb",
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment