Skip to content
Snippets Groups Projects
Commit 94461981 authored by ale's avatar ale
Browse files

Implement IP masking in the database

Make it possible to apply an IP mask (different for v4 and v6) to all
addresses stored and retrieved from the database. This is transparent
to the API (you can still send and query individual IPs).

The default behavior becomes to aggregate away IPv6 addresses to a /64
network, while there is no change to IPv4 handling.
parent e81d3c17
Branches
No related tags found
No related merge requests found
Pipeline #15070 passed
......@@ -37,6 +37,8 @@ type serverCommand struct {
externalSrcDir string
tlsCert string
tlsKey string
maskIPv4Bits int
maskIPv6Bits int
}
func (c *serverCommand) Name() string { return "server" }
......@@ -58,6 +60,8 @@ func (c *serverCommand) SetFlags(f *flag.FlagSet) {
f.StringVar(&c.tlsCert, "tls-cert", "", "TLS certificate `path` (grpc only)")
f.StringVar(&c.tlsKey, "tls-key", "", "TLS private key `path` (grpc only)")
f.IntVar(&c.maskIPv4Bits, "mask-ipv4-bits", 32, "bits for masking IPv4 addrs")
f.IntVar(&c.maskIPv6Bits, "mask-ipv6-bits", 64, "bits for masking IPv6 addrs")
}
func (c *serverCommand) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
......@@ -119,7 +123,9 @@ func (c *serverCommand) run(ctx context.Context) error {
if err != nil {
return err
}
srv, err := server.New(c.dbURI, c.scriptPath, srcs)
mask := ippb.NewIPMask(c.maskIPv4Bits, c.maskIPv6Bits)
srv, err := server.New(c.dbURI, c.scriptPath, srcs, mask)
if err != nil {
return err
}
......
......@@ -20,7 +20,7 @@ type DB interface {
Close()
}
func Open(path string) (DB, error) {
func Open(path string, mask ippb.IPMask) (DB, error) {
u, err := url.Parse(path)
if err != nil {
return nil, err
......@@ -31,9 +31,9 @@ func Open(path string) (DB, error) {
switch u.Scheme {
case "", "leveldb":
return leveldb.Open(u.Path)
return leveldb.Open(u.Path, mask)
case "sqlite":
return sqlite.Open(u.Path)
return sqlite.Open(u.Path, mask)
default:
return nil, fmt.Errorf("unsupported scheme %s", u.Scheme)
}
......
......@@ -79,7 +79,8 @@ func createTestDB(t testing.TB, driver string) (DB, func()) {
t.Fatal(err)
}
db, err := Open(fmt.Sprintf("%s://%s", driver, dir+"/test.db"))
db, err := Open(fmt.Sprintf("%s://%s", driver, dir+"/test.db"),
ippb.DefaultIPMask)
if err != nil {
t.Fatalf("Open: %v", err)
}
......@@ -126,7 +127,7 @@ func runSanityCheck(t *testing.T, driver string) {
t.Fatalf("ScanIP: %v", err)
}
if n := m[eventType]; n != count {
t.Fatalf("read %d events from db, expected %d", n, count)
t.Fatalf("(%s/%s): read %d events from db, expected %d", ip, eventType, n, count)
}
}
......
......@@ -26,9 +26,10 @@ const deleteBatchSize = 10000
type DB struct {
db *leveldb.DB
mask ippb.IPMask
}
func Open(path string) (*DB, error) {
func Open(path string, mask ippb.IPMask) (*DB, error) {
opts := &opt.Options{
BlockCacheCapacity: 64 * opt.MiB,
WriteBuffer: 64 * opt.MiB,
......@@ -42,7 +43,10 @@ func Open(path string) (*DB, error) {
if err != nil {
return nil, err
}
return &DB{db: db}, nil
return &DB{
db: db,
mask: mask,
}, nil
}
func (db *DB) Close() {
......@@ -60,7 +64,7 @@ func (db *DB) AddAggregate(aggr *ippb.Aggregate) error {
ev.Type = bt.Type
ev.Ip = item.Ip
ev.Count = item.Count
enc := marshalEvent(&ev)
enc := marshalEvent(&ev, db.mask)
wb.Put(eventKeyFromID(id), enc.data())
wb.Put(indexKey(id, ipIndexPrefix, enc.binaryIP()), enc.data())
......@@ -122,7 +126,7 @@ func (db *DB) WipeOldData(age time.Duration) (int64, error) {
func (db *DB) ScanIP(startTime time.Time, ip string) (map[string]int64, error) {
iter := db.db.NewIterator(
prefixSince(startTime, ipIndexPrefix, ipToBytes(ip)),
prefixSince(startTime, ipIndexPrefix, ipToBytes(ip, db.mask)),
nil)
defer iter.Release()
......@@ -220,8 +224,8 @@ func ipFromIPIndexKey(key []byte) net.IP {
return net.IP(key[len(ipIndexPrefix) : len(ipIndexPrefix)+16])
}
func ipToBytes(s string) []byte {
return []byte(net.ParseIP(s).To16())
func ipToBytes(s string, mask ippb.IPMask) []byte {
return []byte(mask.MaskIP(net.ParseIP(s)).To16())
}
func prefixSince(start time.Time, indexPrefix, indexKey []byte) *util.Range {
......@@ -250,9 +254,9 @@ func prefixUntil(end time.Time, pfx []byte) *util.Range {
// by the (free form) type field.
type encodedEvent []byte
func marshalEvent(e *ippb.Event) encodedEvent {
func marshalEvent(e *ippb.Event, mask ippb.IPMask) encodedEvent {
b := make([]byte, 16+8)
copy(b[:16], ipToBytes(e.Ip))
copy(b[:16], ipToBytes(e.Ip, mask))
binary.BigEndian.PutUint64(b[16:24], uint64(e.Count))
btype := []byte(e.Type)
return encodedEvent(append(b, btype...))
......
......@@ -10,9 +10,10 @@ import (
type DB struct {
db *sql.DB
mask ippb.IPMask
}
func Open(path string) (*DB, error) {
func Open(path string, mask ippb.IPMask) (*DB, error) {
// Auto-enable SQLite WAL.
if !strings.Contains(path, "?") {
path += "?_journal=WAL"
......@@ -24,6 +25,7 @@ func Open(path string) (*DB, error) {
return &DB{
db: db,
mask: mask,
}, nil
}
......@@ -49,7 +51,8 @@ func (db *DB) AddAggregate(aggr *ippb.Aggregate) error {
for _, bt := range aggr.ByType {
for _, item := range bt.ByIp {
if _, err := stmt.Exec(item.Ip, bt.Type, item.Count, now); err != nil {
ip := db.mask.MaskString(item.Ip)
if _, err := stmt.Exec(ip, bt.Type, item.Count, now); err != nil {
tx.Rollback() // nolint
return err
}
......@@ -81,7 +84,7 @@ func (db *DB) ScanIP(startTime time.Time, ip string) (map[string]int64, error) {
`SELECT event_type, SUM(count) AS sum FROM events
WHERE ip = ? AND timestamp > ?
GROUP BY event_type`,
ip, startTime)
db.mask.MaskString(ip), startTime)
if err != nil {
return nil, err
}
......
package proto
import "net"
// Map is an in-memory aggregate representation: type/ip/count.
type Map map[string]map[string]int64
......@@ -100,3 +102,30 @@ func (a *Aggregate) Recalc() {
b := tmp.ToAggregate()
a.ByType = b.ByType
}
// IPMask is a utility type to mask v4/v6 addresses retaining
// different number of bits for each protocol.
type IPMask struct {
v4Mask net.IPMask
v6Mask net.IPMask
}
var DefaultIPMask = NewIPMask(32, 64)
func NewIPMask(v4bits, v6bits int) IPMask {
return IPMask{
v4Mask: net.CIDRMask(v4bits, 32),
v6Mask: net.CIDRMask(v6bits, 128),
}
}
func (m IPMask) MaskIP(ip net.IP) net.IP {
if ip.To4() == nil {
return ip.Mask(m.v6Mask)
}
return ip.Mask(m.v4Mask)
}
func (m IPMask) MaskString(ipstr string) string {
return m.MaskIP(net.ParseIP(ipstr)).String()
}
......@@ -3,6 +3,7 @@ package proto
import (
"fmt"
"math/rand"
"net"
"testing"
)
......@@ -104,3 +105,32 @@ func TestAggregate_Merge(t *testing.T) {
t.Fatalf("bad acount in aggregate: expected 5, got %d", count)
}
}
func TestIPMask(t *testing.T) {
testdata := []struct {
ip string
v4bits, v6bits int
exp string
}{
{"1.2.3.4", 32, 128, "1.2.3.4"},
{"1.2.3.4", 16, 128, "1.2.0.0"},
{"2001:a020:9f1d:c96:2efd:a1ff:fec6:f2fe", 32, 64, "2001:a020:9f1d:c96::"},
}
for _, td := range testdata {
mask := NewIPMask(td.v4bits, td.v6bits)
ip := net.ParseIP(td.ip)
result := mask.MaskIP(ip)
resultStr := mask.MaskString(td.ip)
if result.String() != resultStr {
t.Errorf(
"internal inconsistency, MaskIP(%s/%d/%d) -> %s, MaskString -> %s",
td.ip, td.v4bits, td.v6bits, result.String(), resultStr)
}
if resultStr != td.exp {
t.Errorf(
"MaskIP(%s/%d/%d) -> %s, expected %s",
td.ip, td.v4bits, td.v6bits, resultStr, td.exp)
}
}
}
......@@ -34,13 +34,13 @@ var (
defaultScript = `score = 0`
)
func New(dbPath, scriptPath string, extSrcs map[string]ext.ExternalSource) (*Server, error) {
func New(dbPath, scriptPath string, extSrcs map[string]ext.ExternalSource, mask ippb.IPMask) (*Server, error) {
scriptMgr, err := script.NewManager(scriptPath, defaultScript)
if err != nil {
return nil, err
}
database, err := db.Open(dbPath)
database, err := db.Open(dbPath, mask)
if err != nil {
return nil, err
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment