diff --git a/cmd/iprep/server.go b/cmd/iprep/server.go index c770d2d23096c39398c9fb67cc6d18230ac23ebf..b1a6d6f90aeda6841b39eff1be59967261b94e45 100644 --- a/cmd/iprep/server.go +++ b/cmd/iprep/server.go @@ -37,6 +37,8 @@ type serverCommand struct { externalSrcDir string tlsCert string tlsKey string + maskIPv4Bits int + maskIPv6Bits int } func (c *serverCommand) Name() string { return "server" } @@ -58,6 +60,8 @@ func (c *serverCommand) SetFlags(f *flag.FlagSet) { f.StringVar(&c.tlsCert, "tls-cert", "", "TLS certificate `path` (grpc only)") f.StringVar(&c.tlsKey, "tls-key", "", "TLS private key `path` (grpc only)") + f.IntVar(&c.maskIPv4Bits, "mask-ipv4-bits", 32, "bits for masking IPv4 addrs") + f.IntVar(&c.maskIPv6Bits, "mask-ipv6-bits", 64, "bits for masking IPv6 addrs") } func (c *serverCommand) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { @@ -119,7 +123,9 @@ func (c *serverCommand) run(ctx context.Context) error { if err != nil { return err } - srv, err := server.New(c.dbURI, c.scriptPath, srcs) + + mask := ippb.NewIPMask(c.maskIPv4Bits, c.maskIPv6Bits) + srv, err := server.New(c.dbURI, c.scriptPath, srcs, mask) if err != nil { return err } diff --git a/db/db.go b/db/db.go index 2b237fec4f1bf0913887ebf91eef61751f8b318f..2ca49dba29128403f0e962626ea0fd64399f1ddb 100644 --- a/db/db.go +++ b/db/db.go @@ -20,7 +20,7 @@ type DB interface { Close() } -func Open(path string) (DB, error) { +func Open(path string, mask ippb.IPMask) (DB, error) { u, err := url.Parse(path) if err != nil { return nil, err @@ -31,9 +31,9 @@ func Open(path string) (DB, error) { switch u.Scheme { case "", "leveldb": - return leveldb.Open(u.Path) + return leveldb.Open(u.Path, mask) case "sqlite": - return sqlite.Open(u.Path) + return sqlite.Open(u.Path, mask) default: return nil, fmt.Errorf("unsupported scheme %s", u.Scheme) } diff --git a/db/db_test.go b/db/db_test.go index 91f17c96d385c3292b8c14cac97f2e39b32971bc..d32f331539594f5860a3cd34164234a34c754367 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -79,7 +79,8 @@ func createTestDB(t testing.TB, driver string) (DB, func()) { t.Fatal(err) } - db, err := Open(fmt.Sprintf("%s://%s", driver, dir+"/test.db")) + db, err := Open(fmt.Sprintf("%s://%s", driver, dir+"/test.db"), + ippb.DefaultIPMask) if err != nil { t.Fatalf("Open: %v", err) } @@ -126,7 +127,7 @@ func runSanityCheck(t *testing.T, driver string) { t.Fatalf("ScanIP: %v", err) } if n := m[eventType]; n != count { - t.Fatalf("read %d events from db, expected %d", n, count) + t.Fatalf("(%s/%s): read %d events from db, expected %d", ip, eventType, n, count) } } diff --git a/db/leveldb/leveldb.go b/db/leveldb/leveldb.go index c374425c3f7895260ed150bc060e3390c88f0c3b..a551deb57ed937b7aa062334bbe8c3d9db2135f2 100644 --- a/db/leveldb/leveldb.go +++ b/db/leveldb/leveldb.go @@ -25,10 +25,11 @@ var ( const deleteBatchSize = 10000 type DB struct { - db *leveldb.DB + db *leveldb.DB + mask ippb.IPMask } -func Open(path string) (*DB, error) { +func Open(path string, mask ippb.IPMask) (*DB, error) { opts := &opt.Options{ BlockCacheCapacity: 64 * opt.MiB, WriteBuffer: 64 * opt.MiB, @@ -42,7 +43,10 @@ func Open(path string) (*DB, error) { if err != nil { return nil, err } - return &DB{db: db}, nil + return &DB{ + db: db, + mask: mask, + }, nil } func (db *DB) Close() { @@ -60,7 +64,7 @@ func (db *DB) AddAggregate(aggr *ippb.Aggregate) error { ev.Type = bt.Type ev.Ip = item.Ip ev.Count = item.Count - enc := marshalEvent(&ev) + enc := marshalEvent(&ev, db.mask) wb.Put(eventKeyFromID(id), enc.data()) wb.Put(indexKey(id, ipIndexPrefix, enc.binaryIP()), enc.data()) @@ -122,7 +126,7 @@ func (db *DB) WipeOldData(age time.Duration) (int64, error) { func (db *DB) ScanIP(startTime time.Time, ip string) (map[string]int64, error) { iter := db.db.NewIterator( - prefixSince(startTime, ipIndexPrefix, ipToBytes(ip)), + prefixSince(startTime, ipIndexPrefix, ipToBytes(ip, db.mask)), nil) defer iter.Release() @@ -220,8 +224,8 @@ func ipFromIPIndexKey(key []byte) net.IP { return net.IP(key[len(ipIndexPrefix) : len(ipIndexPrefix)+16]) } -func ipToBytes(s string) []byte { - return []byte(net.ParseIP(s).To16()) +func ipToBytes(s string, mask ippb.IPMask) []byte { + return []byte(mask.MaskIP(net.ParseIP(s)).To16()) } func prefixSince(start time.Time, indexPrefix, indexKey []byte) *util.Range { @@ -250,9 +254,9 @@ func prefixUntil(end time.Time, pfx []byte) *util.Range { // by the (free form) type field. type encodedEvent []byte -func marshalEvent(e *ippb.Event) encodedEvent { +func marshalEvent(e *ippb.Event, mask ippb.IPMask) encodedEvent { b := make([]byte, 16+8) - copy(b[:16], ipToBytes(e.Ip)) + copy(b[:16], ipToBytes(e.Ip, mask)) binary.BigEndian.PutUint64(b[16:24], uint64(e.Count)) btype := []byte(e.Type) return encodedEvent(append(b, btype...)) diff --git a/db/sqlite/driver.go b/db/sqlite/driver.go index df856a9842abc046951632c274c768370e186b1c..1133b4eca034f32e1aefae8087ddc05437f90156 100644 --- a/db/sqlite/driver.go +++ b/db/sqlite/driver.go @@ -9,10 +9,11 @@ import ( ) type DB struct { - db *sql.DB + db *sql.DB + mask ippb.IPMask } -func Open(path string) (*DB, error) { +func Open(path string, mask ippb.IPMask) (*DB, error) { // Auto-enable SQLite WAL. if !strings.Contains(path, "?") { path += "?_journal=WAL" @@ -23,7 +24,8 @@ func Open(path string) (*DB, error) { } return &DB{ - db: db, + db: db, + mask: mask, }, nil } @@ -49,7 +51,8 @@ func (db *DB) AddAggregate(aggr *ippb.Aggregate) error { for _, bt := range aggr.ByType { for _, item := range bt.ByIp { - if _, err := stmt.Exec(item.Ip, bt.Type, item.Count, now); err != nil { + ip := db.mask.MaskString(item.Ip) + if _, err := stmt.Exec(ip, bt.Type, item.Count, now); err != nil { tx.Rollback() // nolint return err } @@ -81,7 +84,7 @@ func (db *DB) ScanIP(startTime time.Time, ip string) (map[string]int64, error) { `SELECT event_type, SUM(count) AS sum FROM events WHERE ip = ? AND timestamp > ? GROUP BY event_type`, - ip, startTime) + db.mask.MaskString(ip), startTime) if err != nil { return nil, err } diff --git a/proto/iprep.go b/proto/iprep.go index f279d54e46e21ec54eb05970d5163d1a4f65841f..ac965133402f3345acdeb8c8861a38d0690b62d6 100644 --- a/proto/iprep.go +++ b/proto/iprep.go @@ -1,5 +1,7 @@ package proto +import "net" + // Map is an in-memory aggregate representation: type/ip/count. type Map map[string]map[string]int64 @@ -100,3 +102,30 @@ func (a *Aggregate) Recalc() { b := tmp.ToAggregate() a.ByType = b.ByType } + +// IPMask is a utility type to mask v4/v6 addresses retaining +// different number of bits for each protocol. +type IPMask struct { + v4Mask net.IPMask + v6Mask net.IPMask +} + +var DefaultIPMask = NewIPMask(32, 64) + +func NewIPMask(v4bits, v6bits int) IPMask { + return IPMask{ + v4Mask: net.CIDRMask(v4bits, 32), + v6Mask: net.CIDRMask(v6bits, 128), + } +} + +func (m IPMask) MaskIP(ip net.IP) net.IP { + if ip.To4() == nil { + return ip.Mask(m.v6Mask) + } + return ip.Mask(m.v4Mask) +} + +func (m IPMask) MaskString(ipstr string) string { + return m.MaskIP(net.ParseIP(ipstr)).String() +} diff --git a/proto/iprep_test.go b/proto/iprep_test.go index 2315cfe5a5d7393e9d8c3f7a66a17397b510d4de..a4b5758a2010f31ff9d5a0250834a5921f29fc13 100644 --- a/proto/iprep_test.go +++ b/proto/iprep_test.go @@ -3,6 +3,7 @@ package proto import ( "fmt" "math/rand" + "net" "testing" ) @@ -104,3 +105,32 @@ func TestAggregate_Merge(t *testing.T) { t.Fatalf("bad acount in aggregate: expected 5, got %d", count) } } + +func TestIPMask(t *testing.T) { + testdata := []struct { + ip string + v4bits, v6bits int + exp string + }{ + {"1.2.3.4", 32, 128, "1.2.3.4"}, + {"1.2.3.4", 16, 128, "1.2.0.0"}, + {"2001:a020:9f1d:c96:2efd:a1ff:fec6:f2fe", 32, 64, "2001:a020:9f1d:c96::"}, + } + + for _, td := range testdata { + mask := NewIPMask(td.v4bits, td.v6bits) + ip := net.ParseIP(td.ip) + result := mask.MaskIP(ip) + resultStr := mask.MaskString(td.ip) + if result.String() != resultStr { + t.Errorf( + "internal inconsistency, MaskIP(%s/%d/%d) -> %s, MaskString -> %s", + td.ip, td.v4bits, td.v6bits, result.String(), resultStr) + } + if resultStr != td.exp { + t.Errorf( + "MaskIP(%s/%d/%d) -> %s, expected %s", + td.ip, td.v4bits, td.v6bits, resultStr, td.exp) + } + } +} diff --git a/server/server.go b/server/server.go index cfc783e44ae4dbe85b4095cf0eeff9450fb9fe0c..8e17d4547600cda2b30c34b5ff55129eed1edb1d 100644 --- a/server/server.go +++ b/server/server.go @@ -34,13 +34,13 @@ var ( defaultScript = `score = 0` ) -func New(dbPath, scriptPath string, extSrcs map[string]ext.ExternalSource) (*Server, error) { +func New(dbPath, scriptPath string, extSrcs map[string]ext.ExternalSource, mask ippb.IPMask) (*Server, error) { scriptMgr, err := script.NewManager(scriptPath, defaultScript) if err != nil { return nil, err } - database, err := db.Open(dbPath) + database, err := db.Open(dbPath, mask) if err != nil { return nil, err }