Select Git revision
localauditd.go 5.82 KiB
package main
import (
"bytes"
"crypto/tls"
"encoding/hex"
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"path"
"strings"
"time"
"git.autistici.org/ai/audit"
)
// TLS files handed to audit.TLSClientAuthConfigWithCerts in main.
var sslCa = flag.String("ssl-ca", "/etc/ai/internal_ca.pem", "SSL CA file")
var sslCert = flag.String("ssl-cert", "/etc/ai/localhost_internal.pem", "SSL certificate file")
var sslKey = flag.String("ssl-key", "/etc/ai/localhost_internal.key", "SSL private key file")

// Directory where messages that could not be forwarded are queued.
var spoolDir = flag.String("spool-dir", "/var/spool/audit/incoming", "Path to the spool directory")

// Base URL of the remote audit server; main appends the write endpoint.
var serverUrl = flag.String("server", "https://logs.m.investici.org:1717", "URL for the main audit server")

// UNIX socket on which local clients submit messages.
var socketPath = flag.String("socket", "/var/run/audit/local", "Path to the local socket to listen on")
// localServer is a store-and-forward audit message server. Messages
// received locally are POSTed to the remote audit server, and written to
// the on-disk spool when that delivery fails.
type localServer struct {
	url   string       // remote endpoint messages are POSTed to
	http  *http.Client // client used for every request to url
	spool *Spool       // local queue for messages not yet delivered
}
// dialTimeout connects to addr on the given network, giving up after 30
// seconds. Its signature matches the http.Transport Dial hook.
func dialTimeout(network, addr string) (net.Conn, error) {
	const connectTimeout = 30 * time.Second
	dialer := net.Dialer{Timeout: connectTimeout}
	return dialer.Dial(network, addr)
}
// newLocalServer creates a localServer that spools to spoolPath and
// forwards messages to serverUrl, using tlsConf for the TLS connection.
// It terminates the process if the spool directory is not writable, and
// starts the background spool flusher before returning.
func newLocalServer(spoolPath string, serverUrl string, tlsConf *tls.Config) *localServer {
	// requestTimeout bounds a whole exchange with the remote server
	// (connect, TLS handshake, request and response body). Without it
	// only the dial phase is bounded, and a server that accepts but then
	// stalls would block client handlers and the flusher forever.
	const requestTimeout = 60 * time.Second

	// Interface to the local spool.
	spool := NewSpool(spoolPath)
	if !spool.IsWritable() {
		log.Fatal("spool path is not writable")
	}

	httpClient := &http.Client{
		Transport: &http.Transport{
			Dial:            dialTimeout,
			TLSClientConfig: tlsConf,
		},
		Timeout: requestTimeout,
	}

	l := &localServer{
		spool: spool,
		http:  httpClient,
		url:   serverUrl,
	}
	go l.backgroundFlusher()
	return l
}
// Serve listens on the UNIX socket at path and handles each accepted
// connection in its own goroutine. A leftover socket file at path is
// removed first. Serve only returns when binding or accepting fails,
// and the listener is closed before it returns.
func (l *localServer) Serve(path string) error {
	// Remove the UNIX socket, or we won't be able to bind successfully.
	// A missing file is expected here, so the error is deliberately
	// ignored.
	os.Remove(path)

	// Bind to the specified UNIX socket.
	uaddr, err := net.ResolveUnixAddr("unix", path)
	if err != nil {
		return err
	}
	s, err := net.ListenUnix("unix", uaddr)
	if err != nil {
		return err
	}
	defer s.Close()

	// Accept connections and handle them.
	for {
		conn, err := s.Accept()
		if err != nil {
			// Hand the failure back to the caller rather than exiting
			// here, so the deferred Close runs.
			return fmt.Errorf("accept on %s: %w", path, err)
		}
		go l.handleConnection(conn)
	}
}
// handleConnection serves one client connection. It replies "OK\n" when
// handleRequest succeeds, or "ERR <reason>\n" otherwise, and always
// closes the connection afterwards.
func (l *localServer) handleConnection(conn net.Conn) {
	defer conn.Close()

	reqErr := l.handleRequest(conn)
	if reqErr == nil {
		io.WriteString(conn, "OK\n")
		return
	}
	fmt.Fprintf(conn, "ERR %s\n", reqErr.Error())
}
// handleRequest reads a whole message from conn (everything up to EOF)
// and tries to deliver it to the remote audit server. If delivery fails,
// the message is stored in the local spool for a later retry.
//
// It returns an error only if reading fails, or if the message could be
// neither forwarded nor spooled. In the latter case the error mentions
// both causes.
func (l *localServer) handleRequest(conn net.Conn) error {
	data, err := ioutil.ReadAll(conn)
	if err != nil {
		return err
	}

	fwdErr := l.forward(data)
	if fwdErr == nil {
		return nil
	}
	// Record why the message is being spooled. Otherwise an unreachable
	// or misbehaving remote server goes unnoticed.
	log.Printf("forward failed, spooling message: %s", fwdErr)

	if err := l.store(data); err != nil {
		return fmt.Errorf("forward failed (%v) and spooling failed: %w", fwdErr, err)
	}
	return nil
}
// store queues data in the local spool, from which backgroundFlusher
// will later retry delivery.
func (l *localServer) store(data []byte) error {
	queue := l.spool
	return queue.Add(data)
}
// forward POSTs data as "application/json" to the remote audit server
// at l.url.
//
// It must distinguish temporary failures from permanent ones:
//   - Transport errors and unexpected status codes are temporary. They
//     are returned so the message is kept and retried.
//   - 400 Bad Request means the message itself was rejected. It is
//     reported as success, so a broken message is destroyed rather than
//     retried forever.
//   - Any 2xx status is a successful delivery.
func (l *localServer) forward(data []byte) error {
	// Upper bound on how much of a response body we read just to allow
	// connection reuse.
	const maxDrainBytes = 64 << 10

	resp, err := l.http.Post(l.url, "application/json", bytes.NewReader(data))
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	// Drain the body (best effort, bounded) so the transport can reuse
	// the keep-alive connection for the next message.
	io.Copy(ioutil.Discard, io.LimitReader(resp.Body, maxDrainBytes))

	switch {
	case resp.StatusCode >= 200 && resp.StatusCode < 300:
		return nil
	case resp.StatusCode == http.StatusBadRequest:
		// Problem with the message itself: pretend success so the
		// request is not retried later.
		return nil
	default:
		return fmt.Errorf("HTTP Error: %s", resp.Status)
	}
}
// backgroundFlusher loops forever, trying to deliver spooled messages
// through l.forward. After a failed flush it tries again in 5 minutes.
// After a successful one it waits 30 minutes.
func (l *localServer) backgroundFlusher() {
	const (
		retryAfterFailure = 300 * time.Second
		idleAfterSuccess  = 1800 * time.Second
	)
	for {
		flushErr := l.spool.Flush(l.forward)
		if flushErr == nil {
			time.Sleep(idleAfterSuccess)
			continue
		}
		log.Printf("flush() failed: %s", flushErr)
		time.Sleep(retryAfterFailure)
	}
}
// Spool is a directory-backed queue of pending messages, one file per
// entry.
type Spool struct {
	path string // directory holding the spooled entries
}

// NewSpool returns a Spool rooted at the directory path. The directory
// is neither created nor checked here.
func NewSpool(path string) *Spool {
	s := new(Spool)
	s.path = path
	return s
}
// IsWritable reports whether a file can be created in the spool
// directory. It writes a small probe file, ".write_test", and removes it
// again whether or not the write succeeded.
func (s *Spool) IsWritable() bool {
	probe := path.Join(s.path, ".write_test")
	writeErr := ioutil.WriteFile(probe, []byte("ok"), 0600)
	os.Remove(probe)
	return writeErr == nil
}
// Add stores data as a new entry in the spool.
//
// The entry is first written to a hidden temporary file (".<id>.tmp",
// which Flush skips because of its leading dot). The file is synced to
// stable storage, then renamed to its final name. As a result Flush never
// sees a partially written entry, and a crash right after Add returns
// cannot leave an empty or truncated entry behind.
func (s *Spool) Add(data []byte) error {
	id := hex.EncodeToString(audit.NewUniqueId(time.Now()))
	filename := path.Join(s.path, id)
	tmpname := path.Join(s.path, fmt.Sprintf(".%s.tmp", id))
	// After a successful rename tmpname no longer exists and this is a
	// no-op. On failure it removes the partial temporary file.
	defer os.Remove(tmpname)

	if err := writeFileSync(tmpname, data, 0600); err != nil {
		return err
	}
	return os.Rename(tmpname, filename)
}

// writeFileSync behaves like ioutil.WriteFile, but also syncs the file
// contents to stable storage before closing it.
func writeFileSync(name string, data []byte, perm os.FileMode) error {
	f, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return err
	}
	if _, err := f.Write(data); err != nil {
		f.Close()
		return err
	}
	if err := f.Sync(); err != nil {
		f.Close()
		return err
	}
	return f.Close()
}
// Flush tries to deliver every entry in the spool by calling fn with the
// entry's contents. Only entries for which fn succeeds are removed;
// failed entries stay in the spool for a later attempt. Hidden files
// (names starting with "."), such as Add's temporary files, are skipped.
//
// The directory is scanned in batches of 100 names. The scan is aborted
// with a "too many errors" error once fn has failed more than 5 times.
// Flush also returns an error if the directory cannot be listed, or if
// any entry failed to flush, so the caller can schedule an earlier retry.
// Unreadable entries are logged and skipped.
func (s *Spool) Flush(fn func([]byte) error) error {
	dir, err := os.Open(s.path)
	if err != nil {
		return err
	}
	defer dir.Close()

	const (
		batchSize = 100
		maxErrs   = 5
	)
	errs := 0
	var lastErr error

	for {
		filenames, readErr := dir.Readdirnames(batchSize)

		// Process any names returned, even alongside an error.
		for _, filename := range filenames {
			if strings.HasPrefix(filename, ".") {
				continue
			}
			entryPath := path.Join(s.path, filename)
			data, err := ioutil.ReadFile(entryPath)
			if err != nil {
				log.Printf("error reading %s: %s", filename, err)
				continue
			}
			if err := fn(data); err != nil {
				errs++
				lastErr = err
				if errs > maxErrs {
					log.Printf("aborting scan: too many errors (last: %s)", err)
					return errors.New("too many errors")
				}
				// Delivery failed: keep the entry for the next flush.
				continue
			}
			if err := os.Remove(entryPath); err != nil {
				// The entry was delivered but remains on disk, so it
				// will be sent again on the next flush.
				log.Printf("error removing %s: %s", filename, err)
			}
		}

		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return readErr
		}
	}

	if errs > 0 {
		return fmt.Errorf("%d spool entries could not be flushed (last: %v)", errs, lastErr)
	}
	return nil
}
// main parses the command-line flags, builds the TLS client
// configuration and the local store-and-forward server, and serves the
// local socket until Serve fails.
func main() {
	flag.Parse()

	tlsConf := audit.TLSClientAuthConfigWithCerts(*sslCa, *sslCert, *sslKey)
	// Trim a trailing slash so that e.g. "https://host:1717/" does not
	// yield a "//api/1/write" path.
	writeURL := strings.TrimSuffix(*serverUrl, "/") + "/api/1/write"
	locald := newLocalServer(*spoolDir, writeURL, tlsConf)
	log.Fatal(locald.Serve(*socketPath))
}