Skip to content
Snippets Groups Projects
Select Git revision
1 result Searching

localauditd.go

Blame
  • localauditd.go 5.82 KiB
    package main
    
    import (
    	"bytes"
    	"crypto/tls"
    	"encoding/hex"
    	"errors"
    	"flag"
    	"fmt"
    	"io"
    	"io/ioutil"
    	"log"
    	"net"
    	"net/http"
    	"os"
    	"path"
    	"strings"
    	"time"
    
    	"git.autistici.org/ai/audit"
    )
    
    // TLS material used to authenticate this daemon to the main audit server.
    var (
    	sslCa   = flag.String("ssl-ca", "/etc/ai/internal_ca.pem", "SSL CA file")
    	sslCert = flag.String("ssl-cert", "/etc/ai/localhost_internal.pem", "SSL certificate file")
    	sslKey  = flag.String("ssl-key", "/etc/ai/localhost_internal.key", "SSL private key file")
    )

    // Where messages are received (local socket), queued (spool) and sent (server).
    var (
    	spoolDir   = flag.String("spool-dir", "/var/spool/audit/incoming", "Path to the spool directory")
    	serverUrl  = flag.String("server", "https://logs.m.investici.org:1717", "URL for the main audit server")
    	socketPath = flag.String("socket", "/var/run/audit/local", "Path to the local socket to listen on")
    )
    
    // localServer is the local store-and-forward audit message server: it
    // accepts messages on a UNIX socket, tries to forward each one to the
    // remote audit server, and keeps it in a local spool when forwarding
    // fails so that a background flusher can retry it later.
    type localServer struct {
    	spool *Spool       // local queue for messages that could not be forwarded
    	http  *http.Client // client used to POST messages to the remote server
    	url   string       // full endpoint URL that messages are POSTed to
    }
    
    // dialTimeout connects to addr on the given network, giving up if the
    // connection cannot be established within 30 seconds. Its signature
    // matches the Dial hook of http.Transport.
    func dialTimeout(network, addr string) (net.Conn, error) {
    	dialer := net.Dialer{Timeout: 30 * time.Second}
    	return dialer.Dial(network, addr)
    }
    
    // newLocalServer builds a localServer that forwards messages to serverUrl
    // over HTTPS (configured by tlsConf) and spools undeliverable messages in
    // the directory spoolPath. It terminates the process if the spool
    // directory is not writable, and it starts the background flusher
    // goroutine before returning.
    func newLocalServer(spoolPath string, serverUrl string, tlsConf *tls.Config) *localServer {
    	// Interface to the local spool.
    	spool := NewSpool(spoolPath)
    	if !spool.IsWritable() {
    		log.Fatal("spool path is not writable")
    	}

    	// The transport bounds only the connection phase (dialTimeout). The
    	// client-wide Timeout bounds the whole request/response exchange:
    	// without it, a server that accepts the connection but never answers
    	// would block the calling goroutine (a client handler or the
    	// background flusher) forever.
    	timeoutTransport := &http.Transport{
    		Dial:            dialTimeout,
    		TLSClientConfig: tlsConf,
    	}
    	httpClient := &http.Client{
    		Transport: timeoutTransport,
    		Timeout:   60 * time.Second,
    	}

    	l := &localServer{
    		spool: spool,
    		http:  httpClient,
    		url:   serverUrl,
    	}
    	go l.backgroundFlusher()
    	return l
    }
    
    // Serve listens on the UNIX socket at path and handles every accepted
    // connection in its own goroutine. It runs until the listener fails, and
    // then returns a non-nil error. Returning an error instead of calling
    // log.Fatal lets the deferred Close run and leaves the exit policy to
    // the caller.
    func (l *localServer) Serve(path string) error {
    	// Remove a stale socket left by a previous run, or the bind would
    	// fail. The error is deliberately ignored: usually the file simply
    	// does not exist.
    	os.Remove(path)

    	// Bind to the specified UNIX socket.
    	uaddr, err := net.ResolveUnixAddr("unix", path)
    	if err != nil {
    		return err
    	}
    	s, err := net.ListenUnix("unix", uaddr)
    	if err != nil {
    		return err
    	}
    	defer s.Close()

    	// Accept connections and handle them.
    	for {
    		conn, err := s.Accept()
    		if err != nil {
    			return fmt.Errorf("accept on %s: %w", path, err)
    		}
    		go l.handleConnection(conn)
    	}
    }
    
    // handleConnection serves a single client. It processes the request and
    // answers with one status line, "OK" on success or "ERR <reason>" on
    // failure. The connection is always closed afterwards.
    func (l *localServer) handleConnection(conn net.Conn) {
    	defer conn.Close()

    	reqErr := l.handleRequest(conn)
    	if reqErr == nil {
    		io.WriteString(conn, "OK\n")
    		return
    	}
    	fmt.Fprintf(conn, "ERR %s\n", reqErr.Error())
    }
    
    // handleRequest reads the entire message from conn, until the client
    // closes its side, and tries to deliver it to the remote server. If
    // delivery fails, the message is put in the local spool instead. Only a
    // read failure or a spool failure is reported back; a forwarding
    // failure that was successfully spooled counts as success.
    func (l *localServer) handleRequest(conn net.Conn) error {
    	msg, readErr := ioutil.ReadAll(conn)
    	if readErr != nil {
    		return readErr
    	}
    	if l.forward(msg) == nil {
    		return nil
    	}
    	// Forwarding failed: fall back to store-and-forward.
    	return l.store(msg)
    }
    
    // store saves data in the local spool. The background flusher will try
    // to send it again later.
    func (l *localServer) store(data []byte) error {
    	sp := l.spool
    	return sp.Add(data)
    }
    
    // forward POSTs data, as JSON, to the remote audit server.
    //
    // A nil return means the message needs no further handling. Either the
    // server accepted it (any 2xx status), or the server rejected the
    // message itself as malformed (400). In the 400 case a retry could never
    // succeed, so the message is logged and dropped to avoid an infinite
    // retry loop. Every other outcome (transport error, timeout, auth error,
    // 5xx, ...) is returned as an error so the caller can spool and retry.
    func (l *localServer) forward(data []byte) error {
    	resp, err := l.http.Post(l.url, "application/json", bytes.NewBuffer(data))
    	if err != nil {
    		return err
    	}
    	defer resp.Body.Close()
    	// Drain the body so the Transport can reuse the connection. The copy
    	// error is ignored because only the status code matters here.
    	io.Copy(ioutil.Discard, resp.Body)

    	switch {
    	case resp.StatusCode >= 200 && resp.StatusCode < 300:
    		return nil
    	case resp.StatusCode == http.StatusBadRequest:
    		log.Printf("dropping message rejected by server: %s", resp.Status)
    		return nil
    	default:
    		return fmt.Errorf("HTTP Error: %s", resp.Status)
    	}
    }
    
    // backgroundFlusher runs forever and periodically tries to deliver the
    // entries waiting in the local spool. After a clean pass it waits 30
    // minutes. After a failed pass it logs the error and retries sooner,
    // after 5 minutes.
    func (l *localServer) backgroundFlusher() {
    	for {
    		delay := 1800 * time.Second
    		if err := l.spool.Flush(l.forward); err != nil {
    			log.Printf("flush() failed: %s", err)
    			delay = 300 * time.Second
    		}
    		time.Sleep(delay)
    	}
    }
    
    // Spool is a directory-backed queue of pending messages, stored one
    // message per file.
    type Spool struct {
    	path string
    }

    // NewSpool returns a Spool that keeps its entries in the directory path.
    // The directory is neither created nor checked here; see IsWritable.
    func NewSpool(path string) *Spool {
    	s := new(Spool)
    	s.path = path
    	return s
    }
    
    // IsWritable reports whether a file can be created in the spool
    // directory. It writes a small probe file named ".write_test" and
    // removes it again whether or not the write succeeded.
    func (s *Spool) IsWritable() bool {
    	probe := path.Join(s.path, ".write_test")
    	writeErr := ioutil.WriteFile(probe, []byte("ok"), 0600)
    	os.Remove(probe)
    	return writeErr == nil
    }
    
    // Add stores data as a new entry in the spool.
    //
    // The entry is first written to a hidden temporary file (Flush skips
    // names starting with "."). That file is synced to disk and then renamed
    // to its final name, so Flush never sees a partially written entry.
    func (s *Spool) Add(data []byte) error {
    	id := hex.EncodeToString(audit.NewUniqueId(time.Now()))
    	filename := path.Join(s.path, id)
    	tmpname := path.Join(s.path, fmt.Sprintf(".%s.tmp", id))
    	// Best-effort cleanup on every failure path. After a successful
    	// rename the temporary file no longer exists and this is a no-op.
    	defer os.Remove(tmpname)

    	f, err := os.OpenFile(tmpname, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
    	if err != nil {
    		return err
    	}
    	if _, err := f.Write(data); err != nil {
    		f.Close()
    		return err
    	}
    	// Flush the contents to stable storage before the rename publishes
    	// the entry. Otherwise the rename could publish an entry whose data
    	// is still only in the page cache, and a crash would leave it empty.
    	if err := f.Sync(); err != nil {
    		f.Close()
    		return err
    	}
    	if err := f.Close(); err != nil {
    		return err
    	}
    	return os.Rename(tmpname, filename)
    }
    
    // Flush tries to deliver every spool entry by calling fn on its contents.
    // Entries for which fn succeeds are removed. Entries for which fn fails
    // are kept so a later Flush can retry them. Hidden files (names starting
    // with ".", such as Add's temporary files) are skipped, and entries that
    // cannot be read are logged and left in place.
    //
    // Flush returns an error if the spool directory cannot be opened or
    // listed. It also stops and returns an error once fn has failed more
    // than maxErrs times in a single scan.
    func (s *Spool) Flush(fn func([]byte) error) error {
    	dir, err := os.Open(s.path)
    	if err != nil {
    		return err
    	}
    	defer dir.Close()

    	// Scan the directory in batches; abort after too many failures.
    	const batchSize = 100
    	const maxErrs = 5
    	errs := 0
    	for {
    		filenames, readErr := dir.Readdirnames(batchSize)

    		for _, filename := range filenames {
    			if strings.HasPrefix(filename, ".") {
    				continue
    			}
    			entryPath := path.Join(s.path, filename)
    			data, err := ioutil.ReadFile(entryPath)
    			if err != nil {
    				log.Printf("error reading %s: %s", filename, err)
    				continue
    			}
    			if err := fn(data); err != nil {
    				// Keep the entry: it has not been delivered and will be
    				// retried on the next Flush.
    				errs++
    				if errs > maxErrs {
    					log.Printf("aborting scan: too many errors")
    					return errors.New("too many errors")
    				}
    				continue
    			}
    			if err := os.Remove(entryPath); err != nil {
    				// The entry was delivered but will be sent again later.
    				log.Printf("error removing %s: %s", filename, err)
    			}
    		}

    		// With a positive count, Readdirnames reports the end of the
    		// directory as io.EOF; any other error is a real failure.
    		if readErr == io.EOF {
    			return nil
    		}
    		if readErr != nil {
    			return readErr
    		}
    	}
    }
    
    // main parses flags, builds the TLS client configuration, and runs the
    // local audit server on the configured UNIX socket until it fails.
    func main() {
    	flag.Parse()

    	tlsConf := audit.TLSClientAuthConfigWithCerts(*sslCa, *sslCert, *sslKey)
    	// Trim trailing slashes so that "-server https://host:port/" does not
    	// produce a "//api/1/write" path, which servers may redirect or reject.
    	endpoint := strings.TrimRight(*serverUrl, "/") + "/api/1/write"
    	locald := newLocalServer(*spoolDir, endpoint, tlsConf)
    	log.Fatal(locald.Serve(*socketPath))
    }