Commit d95488b4 authored by ale's avatar ale

Add support for TCP/SSL

parent e1eb24cc
Pipeline #8375 failed with stages
in 37 seconds
......@@ -21,10 +21,11 @@ backends such as Memcached for short-term storage, and
anonymized user activity data. For this reason, it is recommended to
install an auth-server on every host.
It listens for authorization requests over a UNIX socket. UNIX
permissions should be used to control access to the socket if
necessary. Clients speak a custom simple line-based attribute/value
protocol, and can send multiple requests over the same connection.
The authentication protocol is a simple line-based text protocol. The
auth-server can listen on a UNIX or TCP socket: in the first case,
filesystem permissions should be used to control access to the socket,
while in the second case there is support for SSL, with optional
checks on the provided client certificates.
## Services
......@@ -370,6 +371,9 @@ add specific users to it easily.
The daemon can run either standalone or be socket-activated by
systemd, which is what the Debian package does.
Check out the output of *auth-server --help* for documentation on how
to configure the listening sockets.
## Wire protocol
The rationale behind the wire protocol ("why not http?") is twofold:
......
package main
import (
"context"
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"git.autistici.org/ai3/go-common/serverutil"
"git.autistici.org/ai3/go-common/tracing"
"git.autistici.org/ai3/go-common/unix"
"github.com/coreos/go-systemd/daemon"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
"git.autistici.org/id/auth/server"
)
......@@ -23,7 +27,13 @@ var (
socketPath = flag.String("socket", "/run/auth/socket", "`path` to the UNIX socket to listen on")
systemdSocketActivation = flag.Bool("systemd-socket", false, "use SystemD socket activation")
httpAddr = flag.String("http-addr", "", "if not nil, bind an HTTP server to this `addr` for Prometheus metrics")
tcpAddr = flag.String("addr", "", "listen on this TCP address")
requestTimeout = flag.Duration("timeout", 5*time.Second, "timeout for incoming requests")
sslCert = flag.String("ssl-cert", "", "SSL certificate `file`")
sslKey = flag.String("ssl-key", "", "SSL private key `file`")
sslCA = flag.String("ssl-ca", "", "SSL CA `file` (requires client certificates)")
sslACLs = flag.String("ssl-acl", "", "SSL access control lists (comma-separated list of regexps matching CN)")
)
func usage() {
......@@ -35,11 +45,62 @@ Known options:
flag.PrintDefaults()
}
type genericServer interface {
Serve() error
Close()
}
func runServer(ctx context.Context, name string, srv genericServer) error {
go func() {
<-ctx.Done()
if ctx.Err() == context.Canceled {
srv.Close()
}
}()
log.Printf("starting %s", name)
daemon.SdNotify(false, "READY=1") // nolint: errcheck
return srv.Serve()
}
type metricsServer struct {
*http.Server
}
func newMetricsServer() *metricsServer {
h := http.NewServeMux()
h.Handle("/metrics", promhttp.Handler())
return &metricsServer{
Server: &http.Server{
Addr: *httpAddr,
Handler: h,
ReadTimeout: 10 * time.Second,
IdleTimeout: 30 * time.Second,
WriteTimeout: 10 * time.Second,
},
}
}
func (s *metricsServer) Serve() error {
err := s.Server.ListenAndServe()
if err == http.ErrServerClosed {
err = nil
}
return err
}
func (s *metricsServer) Close() {
s.Server.Close() // nolint: errcheck
}
func main() {
log.SetFlags(0)
flag.Usage = usage
flag.Parse()
if *socketPath == "" && *tcpAddr == "" && !*systemdSocketActivation {
log.Fatal("no listening sockets configured")
}
syscall.Umask(007)
config, err := server.LoadConfig(*configPath)
......@@ -54,42 +115,73 @@ func main() {
tracing.Init()
outerCtx, cancel := context.WithCancel(context.Background())
g, ctx := errgroup.WithContext(outerCtx)
if *httpAddr != "" {
h := http.NewServeMux()
h.Handle("/metrics", promhttp.Handler())
go func() {
log.Fatal(http.ListenAndServe(*httpAddr, h))
}()
g.Go(func() error {
return runServer(ctx, "HTTP listener (metrics)", newMetricsServer())
})
}
srv := unix.NewLineServer(server.NewSocketServer(authSrv))
srv.RequestTimeout = *requestTimeout
var sockSrv *unix.SocketServer
if *systemdSocketActivation {
sockSrv, err = unix.NewSystemdSocketServer(srv)
} else {
sockSrv, err = unix.NewUNIXSocketServer(*socketPath, srv)
if *socketPath != "" || *systemdSocketActivation {
g.Go(func() error {
var sockSrv *unix.SocketServer
var err error
if *systemdSocketActivation {
sockSrv, err = unix.NewSystemdSocketServer(srv)
} else {
sockSrv, err = unix.NewUNIXSocketServer(*socketPath, srv)
}
if err != nil {
return err
}
return runServer(ctx, "unix socket listener", sockSrv)
})
}
if err != nil {
log.Fatalf("error: %v", err)
if *tcpAddr != "" {
g.Go(func() error {
var h unix.Handler = srv
if *sslCert != "" {
sslH, err := newTLSServer(
h,
&serverutil.TLSServerConfig{
Cert: *sslCert,
Key: *sslKey,
CA: *sslCA,
},
strings.Split(*sslACLs, ","),
)
if err != nil {
return err
}
h = sslH
}
tcpSrv, err := newTCPServer(*tcpAddr, h)
if err != nil {
return err
}
return runServer(ctx, "tcp listener", tcpSrv)
})
}
done := make(chan struct{})
sigCh := make(chan os.Signal, 1)
go func() {
<-sigCh
log.Printf("terminating")
sockSrv.Close()
close(done)
cancel()
}()
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
log.Printf("starting")
daemon.SdNotify(false, "READY=1")
if err := sockSrv.Serve(); err != nil {
err = g.Wait()
if err != nil && err != context.Canceled {
log.Fatal(err)
}
<-done
}
package main
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"log"
"net"
"regexp"
"git.autistici.org/ai3/go-common/serverutil"
"git.autistici.org/ai3/go-common/unix"
)
type tcpServer struct {
l net.Listener
h unix.Handler
}
func newTCPServer(addr string, h unix.Handler) (*tcpServer, error) {
l, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
return &tcpServer{l: l, h: h}, nil
}
func (s *tcpServer) Serve() error {
for {
conn, err := s.l.Accept()
if err != nil {
return err
}
go func() {
s.h.ServeConnection(conn)
conn.Close()
}()
}
}
func (s *tcpServer) Close() {
s.l.Close()
}
type tlsServer struct {
h unix.Handler
config *tls.Config
acl []*regexp.Regexp
}
func newTLSServer(h unix.Handler, config *serverutil.TLSServerConfig, acls []string) (*tlsServer, error) {
// Use serverutil.TLSServerConfig to build a tls.Config.
tlsConf, err := config.TLSConfig()
if err != nil {
return nil, err
}
s := &tlsServer{
h: h,
config: tlsConf,
}
for _, acl := range acls {
rx, err := regexp.Compile(acl)
if err != nil {
return nil, fmt.Errorf("acl error: %v", err)
}
s.acl = append(s.acl, rx)
}
return s, nil
}
func (s *tlsServer) aclCheck(certs []*x509.Certificate) error {
if len(s.acl) == 0 {
return nil
}
if len(certs) < 1 {
return errors.New("no certificate was provided")
}
for _, acl := range s.acl {
if acl.MatchString(certs[0].Subject.CommonName) {
return nil
}
}
return errors.New("access denied")
}
func (s *tlsServer) ServeConnection(conn net.Conn) {
sconn := tls.Server(conn, s.config)
defer sconn.Close()
if err := sconn.Handshake(); err != nil {
log.Printf("tls handshake error: %v", err)
return
}
if err := s.aclCheck(sconn.ConnectionState().PeerCertificates); err != nil {
log.Printf("acl error: %v", err)
return
}
s.h.ServeConnection(sconn)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment