Commit 284580fc authored by ale's avatar ale

Refactor the line protocol package

Stop using net/textproto which is way more complex than what we need,
instead just use bufio directly.

Create a new 'lineproto' package with Servers and LineServers. While
this initially duplicates most of the ai3/go-common/unix API, the
duplication will allow us to deprecate the other interface once this
one is complete.
parent c977307b
Pipeline #8377 passed with stages
in 1 minute and 42 seconds
......@@ -5,6 +5,7 @@ import (
"flag"
"fmt"
"log"
"net"
"net/http"
"os"
"os/signal"
......@@ -14,11 +15,12 @@ import (
"git.autistici.org/ai3/go-common/serverutil"
"git.autistici.org/ai3/go-common/tracing"
"git.autistici.org/ai3/go-common/unix"
"github.com/coreos/go-systemd/activation"
"github.com/coreos/go-systemd/daemon"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
"git.autistici.org/id/auth/lineproto"
"git.autistici.org/id/auth/server"
)
......@@ -50,14 +52,13 @@ type genericServer interface {
Close()
}
func runServer(ctx context.Context, name string, srv genericServer) error {
func runServer(ctx context.Context, srv genericServer) error {
go func() {
<-ctx.Done()
if ctx.Err() == context.Canceled {
srv.Close()
}
}()
log.Printf("starting %s", name)
daemon.SdNotify(false, "READY=1") // nolint: errcheck
return srv.Serve()
}
......@@ -97,10 +98,6 @@ func main() {
flag.Usage = usage
flag.Parse()
if *socketPath == "" && *tcpAddr == "" && !*systemdSocketActivation {
log.Fatal("no listening sockets configured")
}
syscall.Umask(007)
config, err := server.LoadConfig(*configPath)
......@@ -120,56 +117,57 @@ func main() {
if *httpAddr != "" {
g.Go(func() error {
return runServer(ctx, "HTTP listener (metrics)", newMetricsServer())
return runServer(ctx, newMetricsServer())
})
}
srv := unix.NewLineServer(server.NewSocketServer(authSrv))
srv := lineproto.NewLineServer(server.NewSocketServer(authSrv))
srv.RequestTimeout = *requestTimeout
if *socketPath != "" || *systemdSocketActivation {
g.Go(func() error {
var sockSrv *unix.SocketServer
var servers []genericServer
if *socketPath != "" {
l, err := lineproto.NewUNIXSocketListener(*socketPath)
if err != nil {
log.Fatal(err)
}
servers = append(servers, lineproto.NewServer("unix", l, srv))
}
if *systemdSocketActivation {
ll, err := activation.Listeners()
if err != nil {
log.Fatal(err)
}
for _, l := range ll {
servers = append(servers, lineproto.NewServer("systemd", l, srv))
}
}
if *tcpAddr != "" {
var llsrv lineproto.Handler = srv
if *sslCert != "" {
var err error
if *systemdSocketActivation {
sockSrv, err = unix.NewSystemdSocketServer(srv)
} else {
sockSrv, err = unix.NewUNIXSocketServer(*socketPath, srv)
}
llsrv, err = newTLSServer(
llsrv,
&serverutil.TLSServerConfig{
Cert: *sslCert,
Key: *sslKey,
CA: *sslCA,
},
strings.Split(*sslACLs, ","),
)
if err != nil {
return err
log.Fatal(err)
}
}
return runServer(ctx, "unix socket listener", sockSrv)
})
l, err := net.Listen("tcp", *tcpAddr)
if err != nil {
log.Fatal(err)
}
servers = append(servers, lineproto.NewServer("tcp", l, llsrv))
}
if *tcpAddr != "" {
g.Go(func() error {
var h unix.Handler = srv
if *sslCert != "" {
sslH, err := newTLSServer(
h,
&serverutil.TLSServerConfig{
Cert: *sslCert,
Key: *sslKey,
CA: *sslCA,
},
strings.Split(*sslACLs, ","),
)
if err != nil {
return err
}
h = sslH
}
tcpSrv, err := newTCPServer(*tcpAddr, h)
if err != nil {
return err
}
return runServer(ctx, "tcp listener", tcpSrv)
})
if len(servers) == 0 {
log.Fatal("no sockets available for listening")
}
sigCh := make(chan os.Signal, 1)
......@@ -180,6 +178,14 @@ func main() {
}()
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
for _, s := range servers {
func(s genericServer) {
g.Go(func() error {
return runServer(ctx, s)
})
}(s)
}
err = g.Wait()
if err != nil && err != context.Canceled {
log.Fatal(err)
......
......@@ -6,51 +6,20 @@ import (
"errors"
"fmt"
"log"
"net"
"regexp"
"git.autistici.org/ai3/go-common/serverutil"
"git.autistici.org/ai3/go-common/unix"
"git.autistici.org/id/auth/lineproto"
)
type tcpServer struct {
l net.Listener
h unix.Handler
}
func newTCPServer(addr string, h unix.Handler) (*tcpServer, error) {
l, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
return &tcpServer{l: l, h: h}, nil
}
func (s *tcpServer) Serve() error {
for {
conn, err := s.l.Accept()
if err != nil {
return err
}
go func() {
s.h.ServeConnection(conn)
conn.Close()
}()
}
}
func (s *tcpServer) Close() {
s.l.Close()
}
type tlsServer struct {
h unix.Handler
h lineproto.Handler
config *tls.Config
acl []*regexp.Regexp
}
func newTLSServer(h unix.Handler, config *serverutil.TLSServerConfig, acls []string) (*tlsServer, error) {
func newTLSServer(h lineproto.Handler, config *serverutil.TLSServerConfig, acls []string) (*tlsServer, error) {
// Use serverutil.TLSServerConfig to build a tls.Config.
tlsConf, err := config.TLSConfig()
if err != nil {
......@@ -87,7 +56,7 @@ func (s *tlsServer) aclCheck(certs []*x509.Certificate) error {
return errors.New("access denied")
}
func (s *tlsServer) ServeConnection(conn net.Conn) {
func (s *tlsServer) ServeConnection(conn *lineproto.Conn) {
sconn := tls.Server(conn, s.config)
defer sconn.Close()
......@@ -101,5 +70,5 @@ func (s *tlsServer) ServeConnection(conn net.Conn) {
return
}
s.h.ServeConnection(sconn)
s.h.ServeConnection(lineproto.NewConn(sconn, conn.ServerName))
}
package lineproto
import (
"bufio"
"net"
)
type Reader struct {
r *bufio.Reader
}
func (r *Reader) ReadLine() ([]byte, error) {
var line []byte
for {
l, more, err := r.r.ReadLine()
if err != nil {
return nil, err
}
// Avoid the copy if the first call produced a full line.
if line == nil && !more {
return l, nil
}
line = append(line, l...)
if !more {
break
}
}
return line, nil
}
type Writer struct {
w *bufio.Writer
}
var crlf = []byte("\r\n")
func (w *Writer) WriteLine(args ...[]byte) error {
for _, arg := range args {
_, err := w.w.Write(arg)
if err != nil {
return err
}
}
_, err := w.w.Write(crlf)
if err != nil {
return err
}
return w.w.Flush()
}
type Conn struct {
net.Conn
*Reader
*Writer
ServerName string
}
func NewConn(c net.Conn, name string) *Conn {
return &Conn{
Conn: c,
Reader: &Reader{r: bufio.NewReader(c)},
Writer: &Writer{w: bufio.NewWriter(c)},
ServerName: name,
}
}
package unix
package lineproto
import (
"bufio"
"container/list"
"context"
"errors"
"io"
"log"
"net"
"net/textproto"
"os"
"sync"
"sync/atomic"
"time"
"github.com/coreos/go-systemd/activation"
"github.com/prometheus/client_golang/prometheus"
"github.com/theckman/go-flock"
)
// Handler for UNIX socket server connections.
type Handler interface {
ServeConnection(c net.Conn)
}
// SocketServer accepts connections on a UNIX socket, speaking the
// line-based wire protocol, and dispatches incoming requests to the
// wrapped Server.
type SocketServer struct {
l net.Listener
lock *flock.Flock
closing atomic.Value
wg sync.WaitGroup
handler Handler
// Keep track of active connections so we can shut them down
// on Close.
connMx sync.Mutex
conns list.List
}
func newServer(l net.Listener, lock *flock.Flock, h Handler) *SocketServer {
s := &SocketServer{
l: l,
lock: lock,
handler: h,
}
s.closing.Store(false)
return s
}
// NewUNIXSocketServer returns a new SocketServer listening on the given path.
func NewUNIXSocketServer(socketPath string, h Handler) (*SocketServer, error) {
// The simplest workflow is: create a new socket, remove it on
// exit. However, if the program crashes, the socket might
// stick around and prevent the next execution from starting
// successfully. We could remove it before starting, but that
// would be dangerous if another instance was listening on
// that socket. So we wrap socket access with a file lock.
lock := flock.New(socketPath + ".lock")
locked, err := lock.TryLock()
if err != nil {
return nil, err
}
if !locked {
return nil, errors.New("socket is locked by another process")
}
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
// Always try to unlink the socket before creating it.
os.Remove(socketPath) // nolint
l, err := net.ListenUnix("unix", addr)
if err != nil {
return nil, err
}
return newServer(l, lock, h), nil
}
// NewSystemdSocketServer uses systemd socket activation, receiving
// the open socket as a file descriptor on exec.
func NewSystemdSocketServer(h Handler) (*SocketServer, error) {
listeners, err := activation.Listeners()
if err != nil {
return nil, err
}
// Our server loop implies a single listener, so find
// the first one passed by systemd and ignore all others.
// TODO: listen on all fds.
for _, l := range listeners {
if l != nil {
return newServer(l, nil, h), nil
}
}
return nil, errors.New("no available sockets found")
}
// Close the socket listener and release all associated resources.
// Waits for active connections to terminate before returning.
func (s *SocketServer) Close() {
s.closing.Store(true)
// Close the listener to stop incoming connections.
s.l.Close() // nolint
// Close all active connections (this will return an error to
// the client if the connection is not idle).
s.connMx.Lock()
for el := s.conns.Front(); el != nil; el = el.Next() {
el.Value.(net.Conn).Close() // nolint
}
s.connMx.Unlock()
s.wg.Wait()
if s.lock != nil {
s.lock.Unlock() // nolint
}
}
func (s *SocketServer) isClosing() bool {
return s.closing.Load().(bool)
}
// Serve connections.
func (s *SocketServer) Serve() error {
for {
conn, err := s.l.Accept()
if err != nil {
if s.isClosing() {
return nil
}
return err
}
s.wg.Add(1)
s.connMx.Lock()
connEl := s.conns.PushBack(conn)
s.connMx.Unlock()
go func() {
s.handler.ServeConnection(conn)
conn.Close() // nolint
if !s.isClosing() {
s.connMx.Lock()
s.conns.Remove(connEl)
s.connMx.Unlock()
}
s.wg.Done()
}()
}
}
// LineHandler is the handler for LineServer.
type LineHandler interface {
ServeLine(context.Context, LineResponseWriter, []byte) error
......@@ -169,12 +22,9 @@ var ErrCloseConnection = errors.New("close")
// LineResponseWriter writes a single-line response to the underlying
// connection.
type LineResponseWriter interface {
// WriteLine writes a response (which must include the
// line terminator).
WriteLine([]byte) error
// WriteLineCRLF writes a response and adds a line terminator.
WriteLineCRLF([]byte) error
// WriteLine writes a response as a single line (the line
// terminator is added by the function).
WriteLine(...[]byte) error
}
// LineServer implements a line-based text protocol. It satisfies the
......@@ -204,39 +54,14 @@ func NewLineServer(h LineHandler) *LineServer {
}
}
var crlf = []byte{'\r', '\n'}
type lrWriter struct {
*bufio.Writer
}
func (w *lrWriter) WriteLine(data []byte) error {
if _, err := w.Writer.Write(data); err != nil {
return err
}
return w.Writer.Flush()
}
func (w *lrWriter) WriteLineCRLF(data []byte) error {
if _, err := w.Writer.Write(data); err != nil {
return err
}
if _, err := w.Writer.Write(crlf); err != nil {
return err
}
return w.Writer.Flush()
}
// ServeConnection handles a new connection. It will accept multiple
// requests on the same connection (or not, depending on the client
// preference).
func (l *LineServer) ServeConnection(nc net.Conn) {
totalConnections.Inc()
c := textproto.NewConn(nc)
rw := &lrWriter{bufio.NewWriter(nc)}
func (l *LineServer) ServeConnection(c *Conn) {
totalConnections.WithLabelValues(c.ServerName).Inc()
for {
nc.SetReadDeadline(time.Now().Add(l.IdleTimeout)) // nolint
line, err := c.ReadLineBytes()
c.Conn.SetReadDeadline(time.Now().Add(l.IdleTimeout)) // nolint
line, err := c.ReadLine()
if err == io.EOF {
break
} else if err != nil {
......@@ -249,52 +74,51 @@ func (l *LineServer) ServeConnection(nc net.Conn) {
// connection to allow the full RequestTimeout time to
// generate the response.
start := time.Now()
nc.SetWriteDeadline(start.Add(l.RequestTimeout + l.WriteTimeout)) // nolint
c.Conn.SetWriteDeadline(start.Add(l.RequestTimeout + l.WriteTimeout)) // nolint
ctx, cancel := context.WithTimeout(context.Background(), l.RequestTimeout)
err = l.handler.ServeLine(ctx, rw, line)
err = l.handler.ServeLine(ctx, c, line)
elapsedMs := time.Since(start).Nanoseconds() / 1000000
requestLatencyHist.Observe(float64(elapsedMs))
requestLatencyHist.WithLabelValues(c.ServerName).
Observe(float64(elapsedMs))
cancel()
// Close the connection on error, or on empty response.
if err != nil {
totalRequests.With(prometheus.Labels{
"status": "error",
}).Inc()
totalRequests.WithLabelValues(c.ServerName, "error").Inc()
if err != ErrCloseConnection {
log.Printf("request error: %v", err)
}
break
}
totalRequests.With(prometheus.Labels{
"status": "ok",
}).Inc()
totalRequests.WithLabelValues(c.ServerName, "ok").Inc()
}
}
// Instrumentation metrics.
var (
totalConnections = prometheus.NewCounter(
totalConnections = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "unix_connections_total",
Name: "lineproto_connections_total",
Help: "Total number of connections.",
},
[]string{"listener"},
)
totalRequests = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "unix_requests_total",
Name: "lineproto_requests_total",
Help: "Total number of requests.",
},
[]string{"status"},
[]string{"listener", "status"},
)
// Histogram buckets are tuned for the low-milliseconds range
// (the largest bucket sits at ~1s).
requestLatencyHist = prometheus.NewHistogram(
requestLatencyHist = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Name: "unix_requests_latency_ms",
Name: "lineproto_requests_latency_ms",
Help: "Latency of requests.",
Buckets: prometheus.ExponentialBuckets(5, 1.4142, 16),
},
[]string{"listener"},
)
)
......
package lineproto
import (
"container/list"
"net"
"sync"
"sync/atomic"
)
type Handler interface {
ServeConnection(c *Conn)
}
type Server struct {
Name string
l net.Listener
h Handler
// Keep track of active connections so we can shut them down
// on Close.
closing atomic.Value
wg sync.WaitGroup
connMx sync.Mutex
conns list.List
}
func NewServer(name string, l net.Listener, h Handler) *Server {
s := &Server{
Name: name,
l: l,
h: h,
}
s.closing.Store(false)