Commit 4014dc50 authored by ale's avatar ale

Merge branch 'tcp' into 'master'

Add support for TCP/SSL

See merge request !7
parents da156e68 eb4feb3c
Pipeline #8996 passed with stages
in 1 minute and 48 seconds
......@@ -21,10 +21,11 @@ backends such as Memcached for short-term storage, and
anonymized user activity data. For this reason, it is recommended to
install an auth-server on every host.
It listens for authorization requests over a UNIX socket. UNIX
permissions should be used to control access to the socket if
necessary. Clients speak a custom simple line-based attribute/value
protocol, and can send multiple requests over the same connection.
The authentication protocol is a simple line-based text protocol. The
auth-server can listen on a UNIX or TCP socket: in the first case,
filesystem permissions should be used to control access to the socket,
while in the second case there is support for SSL, with optional
checks on the provided client certificates.
## Services
......@@ -370,6 +371,9 @@ add specific users to it easily.
The daemon can run either standalone or be socket-activated by
systemd, which is what the Debian package does.
Check out the output of *auth-server --help* for documentation on how
to configure the listening sockets.
## Wire protocol
The rationale behind the wire protocol ("why not http?") is twofold:
......
......@@ -3,15 +3,19 @@ package client
import (
"context"
"net"
"net/textproto"
"strings"
"github.com/cenkalti/backoff"
"go.opencensus.io/trace"
"git.autistici.org/id/auth"
"git.autistici.org/id/auth/lineproto"
)
var DefaultSocketPath = "/run/auth/socket"
var (
DefaultSocketPath = "/run/auth/socket"
DefaultPoolSize = 3
)
type Client interface {
Authenticate(context.Context, *auth.Request) (*auth.Response, error)
......@@ -20,12 +24,20 @@ type Client interface {
type socketClient struct {
socketPath string
codec auth.Codec
pool *Pool
}
func New(socketPath string) Client {
return &socketClient{
socketPath: socketPath,
codec: auth.DefaultCodec,
pool: NewPool(func() (*lineproto.Conn, error) {
c, err := net.Dial("unix", socketPath)
if err != nil {
return nil, err
}
return lineproto.NewConn(c, ""), nil
}, DefaultPoolSize),
}
}
......@@ -47,6 +59,8 @@ func (c *socketClient) Authenticate(ctx context.Context, req *auth.Request) (*au
resp, err = c.doAuthenticate(sctx, req)
if err == nil {
return nil
} else if strings.Contains(err.Error(), "use of closed network connection") {
return err
} else if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
return netErr
}
......@@ -59,52 +73,45 @@ func (c *socketClient) Authenticate(ctx context.Context, req *auth.Request) (*au
}
func (c *socketClient) doAuthenticate(ctx context.Context, req *auth.Request) (*auth.Response, error) {
// Create the connection outside of the timed goroutine, so
// that we can call Close() on exit regardless of the reason:
// this way, when a timeout occurs or the context is canceled,
// the pending request terminates immediately.
conn, err := textproto.Dial("unix", c.socketPath)
if err != nil {
return nil, err
}
defer conn.Close()
// Make space in the channel for at least one element, or we
// will leak a goroutine whenever the authentication request
// times out.
done := make(chan error, 1)
var resp auth.Response
go func() {
defer close(done)
// Write the auth command to the connection.
if err := conn.PrintfLine("auth %s", string(c.codec.Encode(req))); err != nil {
done <- err
return
err := c.pool.WithConn(func(conn *lineproto.Conn) error {
// Make space in the channel for at least one element, or we
// will leak a goroutine whenever the authentication request
// times out.
done := make(chan error, 1)
go func() {
defer close(done)
// Write the auth command to the connection.
if err := conn.WriteLine([]byte("auth "), c.codec.Encode(req)); err != nil {
done <- err
return
}
// Read the response.
line, err := conn.ReadLine()
if err != nil {
done <- err
return
}
if err := c.codec.Decode(line, &resp); err != nil {
done <- err
return
}
done <- nil
}()
// Wait for the call to terminate, or the context to time out,
// whichever happens first.
select {
case err := <-done:
return err
case <-ctx.Done():
return ctx.Err()
}
// Read the response.
line, err := conn.ReadLineBytes()
if err != nil {
done <- err
return
}
if err := c.codec.Decode(line, &resp); err != nil {
done <- err
return
}
done <- nil
}()
// Wait for the call to terminate, or the context to time out,
// whichever happens first.
select {
case err := <-done:
return &resp, err
case <-ctx.Done():
return nil, ctx.Err()
}
})
return &resp, err
}
func responseToTraceStatus(resp *auth.Response, err error) trace.Status {
......
package client
import (
"git.autistici.org/id/auth/lineproto"
)
type PoolDialer func() (*lineproto.Conn, error)
type Pool struct {
ch chan *lineproto.Conn
dialer PoolDialer
}
func NewPool(dialer PoolDialer, size int) *Pool {
return &Pool{
ch: make(chan *lineproto.Conn, size),
dialer: dialer,
}
}
func (p *Pool) WithConn(f func(*lineproto.Conn) error) error {
// Acquire a connection.
var conn *lineproto.Conn
select {
case conn = <-p.ch:
default:
var err error
conn, err = p.dialer()
if err != nil {
return err
}
}
// Run the function and inspect its return value.
err := f(conn)
if err != nil {
conn.Close()
} else {
select {
case p.ch <- conn:
default:
conn.Close()
}
}
return err
}
package main
import (
"bytes"
"context"
"crypto/tls"
"errors"
"flag"
"log"
"net"
"os"
"os/signal"
"syscall"
common "git.autistici.org/ai3/go-common"
"git.autistici.org/id/auth"
"git.autistici.org/id/auth/client"
"git.autistici.org/id/auth/lineproto"
"github.com/coreos/go-systemd/activation"
"github.com/coreos/go-systemd/daemon"
"golang.org/x/sync/errgroup"
)
var (
socketPath = flag.String("socket", "", "`path` to the UNIX socket to listen on")
systemdSocketActivation = flag.Bool("systemd-socket", false, "use SystemD socket activation")
upstreamAddr = flag.String("upstream", "", "upstream address (host:port)")
sslCert = flag.String("ssl-cert", "", "SSL certificate `file`")
sslKey = flag.String("ssl-key", "", "SSL private key `file`")
sslCA = flag.String("ssl-ca", "", "SSL CA `file` (enables client TLS)")
)
func buildTLSConfig() (*tls.Config, error) {
if *sslCA == "" {
return nil, nil
}
cas, err := common.LoadCA(*sslCA)
if err != nil {
return nil, err
}
tlsConf := &tls.Config{
RootCAs: cas,
}
if *sslCert != "" && *sslKey != "" {
cert, err := tls.LoadX509KeyPair(*sslCert, *sslKey)
if err != nil {
return nil, err
}
tlsConf.Certificates = []tls.Certificate{cert}
}
return tlsConf, nil
}
type proxyServer struct {
codec auth.Codec
pool *client.Pool
}
var (
authCmd = []byte("auth ")
quitCmd = []byte("quit")
)
func newProxyServer(upstream string, tlsConfig *tls.Config) *proxyServer {
var dialer func() (*lineproto.Conn, error)
if tlsConfig != nil {
dialer = func() (*lineproto.Conn, error) {
c, err := tls.Dial("tcp", upstream, tlsConfig)
if err != nil {
return nil, err
}
return lineproto.NewConn(c, ""), nil
}
} else {
dialer = func() (*lineproto.Conn, error) {
c, err := net.Dial("tcp", upstream)
if err != nil {
return nil, err
}
return lineproto.NewConn(c, ""), nil
}
}
return &proxyServer{
codec: auth.DefaultCodec,
pool: client.NewPool(dialer, client.DefaultPoolSize),
}
}
func (s *proxyServer) ServeLine(ctx context.Context, lw lineproto.LineResponseWriter, line []byte) error {
if bytes.HasPrefix(line, quitCmd) {
return lineproto.ErrCloseConnection
}
if bytes.HasPrefix(line, authCmd) {
var resp []byte
err := s.pool.WithConn(func(conn *lineproto.Conn) (err error) {
if err = conn.WriteLine(line); err != nil {
return
}
resp, err = conn.ReadLine()
return
})
if err != nil {
return err
}
return lw.WriteLine(resp)
}
return errors.New("syntax error")
}
type genericServer interface {
Serve() error
Close()
}
func runServer(ctx context.Context, srv genericServer) error {
go func() {
<-ctx.Done()
if ctx.Err() == context.Canceled {
srv.Close()
}
}()
daemon.SdNotify(false, "READY=1") // nolint: errcheck
return srv.Serve()
}
func main() {
log.SetFlags(0)
flag.Parse()
syscall.Umask(007)
tlsConfig, err := buildTLSConfig()
if err != nil {
log.Fatal(err)
}
proxySrv := newProxyServer(*upstreamAddr, tlsConfig)
srv := lineproto.NewLineServer(proxySrv)
outerCtx, cancel := context.WithCancel(context.Background())
g, ctx := errgroup.WithContext(outerCtx)
var servers []genericServer
if *systemdSocketActivation {
ll, err := activation.Listeners()
if err != nil {
log.Fatal(err)
}
for _, l := range ll {
servers = append(servers, lineproto.NewServer("systemd", l, srv))
}
}
if *socketPath != "" {
l, err := lineproto.NewUNIXSocketListener(*socketPath)
if err != nil {
log.Fatal(err)
}
servers = append(servers, lineproto.NewServer("unix", l, srv))
}
if len(servers) == 0 {
log.Fatal("no sockets available for listening")
}
sigCh := make(chan os.Signal, 1)
go func() {
<-sigCh
log.Printf("terminating")
cancel()
}()
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
for _, s := range servers {
func(s genericServer) {
g.Go(func() error {
return runServer(ctx, s)
})
}(s)
}
err = g.Wait()
if err != nil && err != context.Canceled {
log.Fatal(err)
}
}
package main
import (
"context"
"flag"
"fmt"
"log"
"net"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"git.autistici.org/ai3/go-common/serverutil"
"git.autistici.org/ai3/go-common/tracing"
"git.autistici.org/ai3/go-common/unix"
"github.com/coreos/go-systemd/activation"
"github.com/coreos/go-systemd/daemon"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
"git.autistici.org/id/auth/lineproto"
"git.autistici.org/id/auth/server"
)
......@@ -23,7 +29,13 @@ var (
socketPath = flag.String("socket", "/run/auth/socket", "`path` to the UNIX socket to listen on")
systemdSocketActivation = flag.Bool("systemd-socket", false, "use SystemD socket activation")
httpAddr = flag.String("http-addr", "", "if not nil, bind an HTTP server to this `addr` for Prometheus metrics")
tcpAddr = flag.String("addr", "", "listen on this TCP address")
requestTimeout = flag.Duration("timeout", 5*time.Second, "timeout for incoming requests")
sslCert = flag.String("ssl-cert", "", "SSL certificate `file`")
sslKey = flag.String("ssl-key", "", "SSL private key `file`")
sslCA = flag.String("ssl-ca", "", "SSL CA `file` (requires client certificates)")
sslACLs = flag.String("ssl-acl", "", "SSL access control lists (comma-separated list of regexps matching CN)")
)
func usage() {
......@@ -35,6 +47,52 @@ Known options:
flag.PrintDefaults()
}
type genericServer interface {
Serve() error
Close()
}
func runServer(ctx context.Context, srv genericServer) error {
go func() {
<-ctx.Done()
if ctx.Err() == context.Canceled {
srv.Close()
}
}()
daemon.SdNotify(false, "READY=1") // nolint: errcheck
return srv.Serve()
}
type metricsServer struct {
*http.Server
}
func newMetricsServer() *metricsServer {
h := http.NewServeMux()
h.Handle("/metrics", promhttp.Handler())
return &metricsServer{
Server: &http.Server{
Addr: *httpAddr,
Handler: h,
ReadTimeout: 10 * time.Second,
IdleTimeout: 30 * time.Second,
WriteTimeout: 10 * time.Second,
},
}
}
func (s *metricsServer) Serve() error {
err := s.Server.ListenAndServe()
if err == http.ErrServerClosed {
err = nil
}
return err
}
func (s *metricsServer) Close() {
s.Server.Close() // nolint: errcheck
}
func main() {
log.SetFlags(0)
flag.Usage = usage
......@@ -54,42 +112,86 @@ func main() {
tracing.Init()
outerCtx, cancel := context.WithCancel(context.Background())
g, ctx := errgroup.WithContext(outerCtx)
if *httpAddr != "" {
h := http.NewServeMux()
h.Handle("/metrics", promhttp.Handler())
go func() {
log.Fatal(http.ListenAndServe(*httpAddr, h))
}()
g.Go(func() error {
return runServer(ctx, newMetricsServer())
})
}
srv := unix.NewLineServer(server.NewSocketServer(authSrv))
srv := lineproto.NewLineServer(server.NewSocketServer(authSrv))
srv.RequestTimeout = *requestTimeout
var sockSrv *unix.SocketServer
var servers []genericServer
// For legacy reasons, systemd-socket takes precedence over socket.
if *systemdSocketActivation {
sockSrv, err = unix.NewSystemdSocketServer(srv)
} else {
sockSrv, err = unix.NewUNIXSocketServer(*socketPath, srv)
ll, err := activation.Listeners()
if err != nil {
log.Fatal(err)
}
for _, l := range ll {
servers = append(servers, lineproto.NewServer("systemd", l, srv))
}
} else if *socketPath != "" {
l, err := lineproto.NewUNIXSocketListener(*socketPath)
if err != nil {
log.Fatal(err)
}
servers = append(servers, lineproto.NewServer("unix", l, srv))
}
if err != nil {
log.Fatalf("error: %v", err)
if *tcpAddr != "" {
var llsrv lineproto.Handler = srv
if *sslCert != "" {
var acls []string
if *sslACLs != "" {
acls = strings.Split(*sslACLs, ",")
}
var err error
llsrv, err = newTLSServer(
llsrv,
&serverutil.TLSServerConfig{
Cert: *sslCert,
Key: *sslKey,
CA: *sslCA,
},
acls,
)
if err != nil {
log.Fatal(err)
}
}
l, err := net.Listen("tcp", *tcpAddr)
if err != nil {
log.Fatal(err)
}
servers = append(servers, lineproto.NewServer("tcp", l, llsrv))
}
if len(servers) == 0 {
log.Fatal("no sockets available for listening")
}
done := make(chan struct{})
sigCh := make(chan os.Signal, 1)
go func() {
<-sigCh
log.Printf("terminating")
sockSrv.Close()
close(done)
cancel()
}()
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
log.Printf("starting")
daemon.SdNotify(false, "READY=1")
if err := sockSrv.Serve(); err != nil {
log.Fatal(err)
for _, s := range servers {
func(s genericServer) {
g.Go(func() error {
return runServer(ctx, s)
})
}(s)
}
<-done
err = g.Wait()
if err != nil && err != context.Canceled {
log.Fatal(err)
}
}
package main
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"log"
"regexp"