From 9446198149d34d9d3329a89b43eacf2ecd93b343 Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Mon, 19 Apr 2021 23:15:38 +0100
Subject: [PATCH] Implement IP masking in the database

Make it possible to apply an IP mask (different for v4 and v6) to all
addresses stored and retrieved from the database. This is transparent
to the API (you can still send and query individual IPs).

The default behavior becomes to aggregate away IPv6 addresses to a /64
network, while there is no change to IPv4 handling.
---
 cmd/iprep/server.go   |  8 +++++++-
 db/db.go              |  6 +++---
 db/db_test.go         |  5 +++--
 db/leveldb/leveldb.go | 22 +++++++++++++---------
 db/sqlite/driver.go   | 13 ++++++++-----
 proto/iprep.go        | 29 +++++++++++++++++++++++++++++
 proto/iprep_test.go   | 30 ++++++++++++++++++++++++++++++
 server/server.go      |  4 ++--
 8 files changed, 95 insertions(+), 22 deletions(-)

diff --git a/cmd/iprep/server.go b/cmd/iprep/server.go
index c770d2d..b1a6d6f 100644
--- a/cmd/iprep/server.go
+++ b/cmd/iprep/server.go
@@ -37,6 +37,8 @@ type serverCommand struct {
 	externalSrcDir string
 	tlsCert        string
 	tlsKey         string
+	maskIPv4Bits   int
+	maskIPv6Bits   int
 }
 
 func (c *serverCommand) Name() string     { return "server" }
@@ -58,6 +60,8 @@ func (c *serverCommand) SetFlags(f *flag.FlagSet) {
 
 	f.StringVar(&c.tlsCert, "tls-cert", "", "TLS certificate `path` (grpc only)")
 	f.StringVar(&c.tlsKey, "tls-key", "", "TLS private key `path` (grpc only)")
+	f.IntVar(&c.maskIPv4Bits, "mask-ipv4-bits", 32, "bits for masking IPv4 addrs")
+	f.IntVar(&c.maskIPv6Bits, "mask-ipv6-bits", 64, "bits for masking IPv6 addrs")
 }
 
 func (c *serverCommand) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
@@ -119,7 +123,9 @@ func (c *serverCommand) run(ctx context.Context) error {
 	if err != nil {
 		return err
 	}
-	srv, err := server.New(c.dbURI, c.scriptPath, srcs)
+
+	mask := ippb.NewIPMask(c.maskIPv4Bits, c.maskIPv6Bits)
+	srv, err := server.New(c.dbURI, c.scriptPath, srcs, mask)
 	if err != nil {
 		return err
 	}
diff --git a/db/db.go b/db/db.go
index 2b237fe..2ca49db 100644
--- a/db/db.go
+++ b/db/db.go
@@ -20,7 +20,7 @@ type DB interface {
 	Close()
 }
 
-func Open(path string) (DB, error) {
+func Open(path string, mask ippb.IPMask) (DB, error) {
 	u, err := url.Parse(path)
 	if err != nil {
 		return nil, err
@@ -31,9 +31,9 @@ func Open(path string) (DB, error) {
 
 	switch u.Scheme {
 	case "", "leveldb":
-		return leveldb.Open(u.Path)
+		return leveldb.Open(u.Path, mask)
 	case "sqlite":
-		return sqlite.Open(u.Path)
+		return sqlite.Open(u.Path, mask)
 	default:
 		return nil, fmt.Errorf("unsupported scheme %s", u.Scheme)
 	}
diff --git a/db/db_test.go b/db/db_test.go
index 91f17c9..d32f331 100644
--- a/db/db_test.go
+++ b/db/db_test.go
@@ -79,7 +79,8 @@ func createTestDB(t testing.TB, driver string) (DB, func()) {
 		t.Fatal(err)
 	}
 
-	db, err := Open(fmt.Sprintf("%s://%s", driver, dir+"/test.db"))
+	db, err := Open(fmt.Sprintf("%s://%s", driver, dir+"/test.db"),
+		ippb.DefaultIPMask)
 	if err != nil {
 		t.Fatalf("Open: %v", err)
 	}
@@ -126,7 +127,7 @@ func runSanityCheck(t *testing.T, driver string) {
 			t.Fatalf("ScanIP: %v", err)
 		}
 		if n := m[eventType]; n != count {
-			t.Fatalf("read %d events from db, expected %d", n, count)
+			t.Fatalf("(%s/%s): read %d events from db, expected %d", ip, eventType, n, count)
 		}
 	}
 
diff --git a/db/leveldb/leveldb.go b/db/leveldb/leveldb.go
index c374425..a551deb 100644
--- a/db/leveldb/leveldb.go
+++ b/db/leveldb/leveldb.go
@@ -25,10 +25,11 @@ var (
 const deleteBatchSize = 10000
 
 type DB struct {
-	db *leveldb.DB
+	db   *leveldb.DB
+	mask ippb.IPMask
 }
 
-func Open(path string) (*DB, error) {
+func Open(path string, mask ippb.IPMask) (*DB, error) {
 	opts := &opt.Options{
 		BlockCacheCapacity: 64 * opt.MiB,
 		WriteBuffer:        64 * opt.MiB,
@@ -42,7 +43,10 @@ func Open(path string) (*DB, error) {
 	if err != nil {
 		return nil, err
 	}
-	return &DB{db: db}, nil
+	return &DB{
+		db:   db,
+		mask: mask,
+	}, nil
 }
 
 func (db *DB) Close() {
@@ -60,7 +64,7 @@ func (db *DB) AddAggregate(aggr *ippb.Aggregate) error {
 			ev.Type = bt.Type
 			ev.Ip = item.Ip
 			ev.Count = item.Count
-			enc := marshalEvent(&ev)
+			enc := marshalEvent(&ev, db.mask)
 
 			wb.Put(eventKeyFromID(id), enc.data())
 			wb.Put(indexKey(id, ipIndexPrefix, enc.binaryIP()), enc.data())
@@ -122,7 +126,7 @@ func (db *DB) WipeOldData(age time.Duration) (int64, error) {
 
 func (db *DB) ScanIP(startTime time.Time, ip string) (map[string]int64, error) {
 	iter := db.db.NewIterator(
-		prefixSince(startTime, ipIndexPrefix, ipToBytes(ip)),
+		prefixSince(startTime, ipIndexPrefix, ipToBytes(ip, db.mask)),
 		nil)
 	defer iter.Release()
 
@@ -220,8 +224,8 @@ func ipFromIPIndexKey(key []byte) net.IP {
 	return net.IP(key[len(ipIndexPrefix) : len(ipIndexPrefix)+16])
 }
 
-func ipToBytes(s string) []byte {
-	return []byte(net.ParseIP(s).To16())
+func ipToBytes(s string, mask ippb.IPMask) []byte {
+	return []byte(mask.MaskIP(net.ParseIP(s)).To16())
 }
 
 func prefixSince(start time.Time, indexPrefix, indexKey []byte) *util.Range {
@@ -250,9 +254,9 @@ func prefixUntil(end time.Time, pfx []byte) *util.Range {
 // by the (free form) type field.
 type encodedEvent []byte
 
-func marshalEvent(e *ippb.Event) encodedEvent {
+func marshalEvent(e *ippb.Event, mask ippb.IPMask) encodedEvent {
 	b := make([]byte, 16+8)
-	copy(b[:16], ipToBytes(e.Ip))
+	copy(b[:16], ipToBytes(e.Ip, mask))
 	binary.BigEndian.PutUint64(b[16:24], uint64(e.Count))
 	btype := []byte(e.Type)
 	return encodedEvent(append(b, btype...))
diff --git a/db/sqlite/driver.go b/db/sqlite/driver.go
index df856a9..1133b4e 100644
--- a/db/sqlite/driver.go
+++ b/db/sqlite/driver.go
@@ -9,10 +9,11 @@ import (
 )
 
 type DB struct {
-	db *sql.DB
+	db   *sql.DB
+	mask ippb.IPMask
 }
 
-func Open(path string) (*DB, error) {
+func Open(path string, mask ippb.IPMask) (*DB, error) {
 	// Auto-enable SQLite WAL.
 	if !strings.Contains(path, "?") {
 		path += "?_journal=WAL"
@@ -23,7 +24,8 @@ func Open(path string) (*DB, error) {
 	}
 
 	return &DB{
-		db: db,
+		db:   db,
+		mask: mask,
 	}, nil
 }
 
@@ -49,7 +51,8 @@ func (db *DB) AddAggregate(aggr *ippb.Aggregate) error {
 
 	for _, bt := range aggr.ByType {
 		for _, item := range bt.ByIp {
-			if _, err := stmt.Exec(item.Ip, bt.Type, item.Count, now); err != nil {
+			ip := db.mask.MaskString(item.Ip)
+			if _, err := stmt.Exec(ip, bt.Type, item.Count, now); err != nil {
 				tx.Rollback() // nolint
 				return err
 			}
@@ -81,7 +84,7 @@ func (db *DB) ScanIP(startTime time.Time, ip string) (map[string]int64, error) {
 		`SELECT event_type, SUM(count) AS sum FROM events
                  WHERE ip = ? AND timestamp > ?
                  GROUP BY event_type`,
-		ip, startTime)
+		db.mask.MaskString(ip), startTime)
 	if err != nil {
 		return nil, err
 	}
diff --git a/proto/iprep.go b/proto/iprep.go
index f279d54..ac96513 100644
--- a/proto/iprep.go
+++ b/proto/iprep.go
@@ -1,5 +1,7 @@
 package proto
 
+import "net"
+
 // Map is an in-memory aggregate representation: type/ip/count.
 type Map map[string]map[string]int64
 
@@ -100,3 +102,30 @@ func (a *Aggregate) Recalc() {
 	b := tmp.ToAggregate()
 	a.ByType = b.ByType
 }
+
+// IPMask is a utility type to mask v4/v6 addresses retaining
+// different number of bits for each protocol.
+type IPMask struct {
+	v4Mask net.IPMask
+	v6Mask net.IPMask
+}
+
+var DefaultIPMask = NewIPMask(32, 64)
+
+func NewIPMask(v4bits, v6bits int) IPMask {
+	return IPMask{
+		v4Mask: net.CIDRMask(v4bits, 32),
+		v6Mask: net.CIDRMask(v6bits, 128),
+	}
+}
+
+func (m IPMask) MaskIP(ip net.IP) net.IP {
+	if ip.To4() == nil {
+		return ip.Mask(m.v6Mask)
+	}
+	return ip.Mask(m.v4Mask)
+}
+
+func (m IPMask) MaskString(ipstr string) string {
+	return m.MaskIP(net.ParseIP(ipstr)).String()
+}
diff --git a/proto/iprep_test.go b/proto/iprep_test.go
index 2315cfe..a4b5758 100644
--- a/proto/iprep_test.go
+++ b/proto/iprep_test.go
@@ -3,6 +3,7 @@ package proto
 import (
 	"fmt"
 	"math/rand"
+	"net"
 	"testing"
 )
 
@@ -104,3 +105,32 @@ func TestAggregate_Merge(t *testing.T) {
 		t.Fatalf("bad acount in aggregate: expected 5, got %d", count)
 	}
 }
+
+func TestIPMask(t *testing.T) {
+	testdata := []struct {
+		ip             string
+		v4bits, v6bits int
+		exp            string
+	}{
+		{"1.2.3.4", 32, 128, "1.2.3.4"},
+		{"1.2.3.4", 16, 128, "1.2.0.0"},
+		{"2001:a020:9f1d:c96:2efd:a1ff:fec6:f2fe", 32, 64, "2001:a020:9f1d:c96::"},
+	}
+
+	for _, td := range testdata {
+		mask := NewIPMask(td.v4bits, td.v6bits)
+		ip := net.ParseIP(td.ip)
+		result := mask.MaskIP(ip)
+		resultStr := mask.MaskString(td.ip)
+		if result.String() != resultStr {
+			t.Errorf(
+				"internal inconsistency, MaskIP(%s/%d/%d) -> %s, MaskString -> %s",
+				td.ip, td.v4bits, td.v6bits, result.String(), resultStr)
+		}
+		if resultStr != td.exp {
+			t.Errorf(
+				"MaskIP(%s/%d/%d) -> %s, expected %s",
+				td.ip, td.v4bits, td.v6bits, resultStr, td.exp)
+		}
+	}
+}
diff --git a/server/server.go b/server/server.go
index cfc783e..8e17d45 100644
--- a/server/server.go
+++ b/server/server.go
@@ -34,13 +34,13 @@ var (
 	defaultScript = `score = 0`
 )
 
-func New(dbPath, scriptPath string, extSrcs map[string]ext.ExternalSource) (*Server, error) {
+func New(dbPath, scriptPath string, extSrcs map[string]ext.ExternalSource, mask ippb.IPMask) (*Server, error) {
 	scriptMgr, err := script.NewManager(scriptPath, defaultScript)
 	if err != nil {
 		return nil, err
 	}
 
-	database, err := db.Open(dbPath)
+	database, err := db.Open(dbPath, mask)
 	if err != nil {
 		return nil, err
 	}
-- 
GitLab