diff --git a/db/db.go b/db/db.go index fc972521c08937f396fb8a49bbd8a4bf91cec666..9ca9c5d56300af37a213109c8f2548ad928c29cc 100644 --- a/db/db.go +++ b/db/db.go @@ -13,8 +13,8 @@ import ( type DB interface { AddAggregate(*ippb.Aggregate) error WipeOldData(time.Duration) error - ScanIP(time.Time, string) (ippb.Map, error) - ScanType(time.Time, string) (ippb.Map, error) + ScanIP(time.Time, string) (map[string]int64, error) + ScanType(time.Time, string) (map[string]int64, error) Close() } diff --git a/db/db_test.go b/db/db_test.go index fd6e3cb444c42dc999dc50bb0a5964c73dd2f47d..ad92837e527dc26650bca02271eddd748742543c 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -29,8 +29,13 @@ func randomIPs(n int) []string { return out } +var eventTypes = []string{ + "detection1", "detection2", "detection3", "detection4", "detection5", + "detection6", "detection7", "detection8", "detection9", "detection10", +} + func randomType() string { - return "detection" + return eventTypes[rand.Intn(len(eventTypes))] } func randomEvent() *ippb.Event { @@ -80,7 +85,7 @@ func runRWTest(t *testing.T, driver string) { if err != nil { t.Fatalf("ScanIP(%s): %v", refIP, err) } - if count := m[refType][refIP]; count < 1 { + if count := m[refType]; count < 1 { t.Fatalf("read %d events from db, expected > 0", count) } @@ -190,7 +195,7 @@ func runReadBenchmark(b *testing.B, driver string, eventsPerIP int, threadCounts for _, ip := range refIPs { a := new(ippb.Aggregate) for i := 0; i < eventsPerIP; i++ { - a.AddEvent(&ippb.Event{Type: "test", Ip: ip, Count: 1}) + a.AddEvent(&ippb.Event{Type: randomType(), Ip: ip, Count: 1}) } db.AddAggregate(a) } @@ -217,7 +222,7 @@ func runReadBenchmark(b *testing.B, driver string, eventsPerIP int, threadCounts if len(m) == 0 { b.Fatalf("ScanIP(%d): returned empty result", i) } - if len(m["test"]) < 1 { + if len(m) < 1 { b.Fatalf("ScanIP(%d): returned bad results: %v", i, m) } } diff --git a/db/leveldb/leveldb.go b/db/leveldb/leveldb.go index 8a38fd5806230c9110c8d7adb162a51d77ce941b..fe68f8143a423f9c0f1566b524ec85d731f05f5d 100644 --- a/db/leveldb/leveldb.go +++ b/db/leveldb/leveldb.go @@ -73,24 +73,6 @@ func (db *DB) AddAggregate(aggr *ippb.Aggregate) error { return db.db.Write(wb, nil) } -func (db *DB) scanIndex(startTime time.Time, indexPrefix, indexKey []byte) (ippb.Map, error) { - iter := db.db.NewIterator(prefixSince(startTime, indexPrefix, indexKey), &opt.ReadOptions{ - DontFillCache: true, - }) - m := make(ippb.Map) - - for iter.Next() { - var err error - var e ippb.Event - if err = proto.Unmarshal(iter.Value(), &e); err != nil { - continue - } - m.Incr(e.Type, e.Ip, e.Count) - } - - return m, iter.Error() -} - func (db *DB) delete(endTime time.Time) error { iter := db.db.NewIterator(prefixUntil(endTime, eventPrefix), &opt.ReadOptions{ DontFillCache: true, @@ -116,12 +98,36 @@ func (db *DB) WipeOldData(age time.Duration) error { return db.db.CompactRange(util.Range{}) } -func (db *DB) ScanIP(startTime time.Time, ip string) (ippb.Map, error) { - return db.scanIndex(startTime, ipIndexPrefix, ipToBytes(ip)) +func (db *DB) ScanIP(startTime time.Time, ip string) (map[string]int64, error) { + iter := db.db.NewIterator(prefixSince(startTime, ipIndexPrefix, ipToBytes(ip)), &opt.ReadOptions{ + DontFillCache: true, + }) + m := make(map[string]int64) + for iter.Next() { + var err error + var e ippb.Event + if err = proto.Unmarshal(iter.Value(), &e); err != nil { + continue + } + m[e.Type] += e.Count + } + return m, iter.Error() } -func (db *DB) ScanType(startTime time.Time, t string) (ippb.Map, error) { - return db.scanIndex(startTime, typeIndexPrefix, []byte(t)) +func (db *DB) ScanType(startTime time.Time, t string) (map[string]int64, error) { + iter := db.db.NewIterator(prefixSince(startTime, typeIndexPrefix, []byte(t)), &opt.ReadOptions{ + DontFillCache: true, + }) + m := make(map[string]int64) + for iter.Next() { + var err error + var e ippb.Event + if err = proto.Unmarshal(iter.Value(), &e); err != nil { + continue + } + m[e.Ip] += e.Count + } + return m, iter.Error() } func (db *DB) Compact() error { diff --git a/db/sqlite/driver.go b/db/sqlite/driver.go index aacaa48ebb48c94b5a2e85ee04e16caf2751a523..72f74f83a80ecc935f7019c926eefb6f2d8143bc 100644 --- a/db/sqlite/driver.go +++ b/db/sqlite/driver.go @@ -30,7 +30,7 @@ type DB struct { } func Open(path string) (*DB, error) { - db, err := sqlOpen(path) + db, err := sqlOpen(path + "?_journal=WAL") if err != nil { return nil, err } @@ -78,7 +78,7 @@ DELETE FROM events WHERE timestamp < ? return err } -func (db *DB) ScanIP(startTime time.Time, ip string) (ippb.Map, error) { +func (db *DB) ScanIP(startTime time.Time, ip string) (map[string]int64, error) { tx, err := db.db.Begin() if err != nil { return nil, err @@ -91,19 +91,19 @@ func (db *DB) ScanIP(startTime time.Time, ip string) (ippb.Map, error) { } defer rows.Close() - m := make(ippb.Map) + m := make(map[string]int64) for rows.Next() { var evType string var count int64 if err := rows.Scan(&evType, &count); err != nil { return nil, err } - m.Incr(evType, ip, count) + m[evType] = count } return m, nil } -func (db *DB) ScanType(startTime time.Time, t string) (ippb.Map, error) { +func (db *DB) ScanType(startTime time.Time, t string) (map[string]int64, error) { return nil, nil } diff --git a/script/script.go b/script/script.go index 72e8076e15fb976c115f7a364a2d13b259a21c64..8bfe9d3d1968fb4294e80bd86b37481158c2514f 100644 --- a/script/script.go +++ b/script/script.go @@ -8,8 +8,6 @@ import ( "github.com/d5/tengo/script" "github.com/d5/tengo/stdlib" - - ippb "git.autistici.org/ai3/tools/iprep/proto" ) type Script struct { @@ -48,23 +46,13 @@ func NewScript(src []byte) (*Script, error) { return &Script{compiled: c}, nil } -func buildIPMap(ip string, m ippb.Map) map[string]interface{} { - out := make(map[string]interface{}) - for t, bt := range m { - if n, ok := bt[ip]; ok { - out[t] = n - } - } - return out -} - -func (script *Script) RunIP(ctx context.Context, ip string, m ippb.Map, intervalSecs float64, ext map[string]interface{}) (float64, error) { +func (script *Script) RunIP(ctx context.Context, ip string, counts map[string]int64, intervalSecs float64, ext map[string]interface{}) (float64, error) { c := script.compiled.Clone() c.Set("ip", ip) c.Set("score", 0.0) c.Set("interval", intervalSecs) c.Set("ext", ext) - if err := c.Set("counts", buildIPMap(ip, m)); err != nil { + if err := c.Set("counts", intMap(counts)); err != nil { return 0, err } if err := c.RunContext(ctx); err != nil { diff --git a/script/script_test.go b/script/script_test.go index 0ddc89e243a59a7b6b63db33cc21cba104bcc428..49946f6beec0e6644c7efe995ca2f63ca3b006ff 100644 --- a/script/script_test.go +++ b/script/script_test.go @@ -3,8 +3,6 @@ package script import ( "context" "testing" - - ippb "git.autistici.org/ai3/tools/iprep/proto" ) func TestScript(t *testing.T) { @@ -17,10 +15,10 @@ score = counts["test"] / 2 + counts["test2"] if err != nil { t.Fatalf("NewScript: %v", err) } - m := make(ippb.Map) - m.Incr("test", "1.2.3.4", 10) - m.Incr("test2", "1.2.3.4", 2) - m.Incr("test2", "2.3.4.5", 3) + m := map[string]int64{ + "test": 10, + "test2": 2, + } score, err := s.RunIP(context.Background(), "1.2.3.4", m, 3600, nil) if err != nil { diff --git a/script/types.go b/script/types.go new file mode 100644 index 0000000000000000000000000000000000000000..4e9d7a01b8cc02ba20098b6d29fa5741b2ac642e --- /dev/null +++ b/script/types.go @@ -0,0 +1,107 @@ +package script + +import ( + "github.com/d5/tengo/compiler/token" + "github.com/d5/tengo/objects" +) + +// Tengo object type that wraps a read-only map[string]int64, without +// requiring us to go through a map[string]interface{}. +type intMap map[string]int64 + +func (m intMap) String() string { + return "<intMap>" +} + +func (m intMap) TypeName() string { + return "int-map" +} + +func (m intMap) Copy() objects.Object { + return m +} + +func (m intMap) IsFalsy() bool { + return len(m) == 0 +} + +func (m intMap) Equals(o objects.Object) bool { + return false +} + +func (m intMap) BinaryOp(op token.Token, rhs objects.Object) (objects.Object, error) { + return nil, objects.ErrInvalidOperator +} + +func (m intMap) IndexGet(index objects.Object) (objects.Object, error) { + indexStr, ok := index.(*objects.String) + if !ok { + return nil, objects.ErrInvalidIndexType + } + value, ok := m[indexStr.Value] + if !ok { + return objects.UndefinedValue, nil + } + return &objects.Int{Value: value}, nil +} + +func (m intMap) Iterate() objects.Iterator { + return newIntMapIterator(m) +} + +// We need to make a copy of the map to iterate on it with this API. +// Note that in the iterator, 'idx' maps to the *next* item. +type intMapIterator struct { + arr []keyIntPair + idx int +} + +type keyIntPair struct { + key string + value int64 +} + +func newIntMapIterator(m intMap) *intMapIterator { + arr := make([]keyIntPair, 0, len(m)) + for k, v := range m { + arr = append(arr, keyIntPair{key: k, value: v}) + } + return &intMapIterator{arr: arr} +} + +func (i *intMapIterator) String() string { + return "<intMapIterator>" +} + +func (i *intMapIterator) TypeName() string { + return "int-map-iterator" +} + +func (i *intMapIterator) Copy() objects.Object { + return i +} + +func (i *intMapIterator) IsFalsy() bool { + return len(i.arr) == 0 +} + +func (i *intMapIterator) Equals(o objects.Object) bool { + return false +} + +func (i *intMapIterator) BinaryOp(op token.Token, rhs objects.Object) (objects.Object, error) { + return nil, objects.ErrInvalidOperator +} + +func (i *intMapIterator) Next() bool { + i.idx++ + return i.idx <= len(i.arr) +} + +func (i *intMapIterator) Key() objects.Object { + return &objects.String{Value: i.arr[i.idx-1].key} +} + +func (i *intMapIterator) Value() objects.Object { + return &objects.Int{Value: i.arr[i.idx-1].value} +} diff --git a/server/server.go b/server/server.go index cf2a91985f9a01ae3acb14c7858bf818facac81f..9edf106a08619e824192f7698afd185a914184e2 100644 --- a/server/server.go +++ b/server/server.go @@ -83,12 +83,12 @@ func (s *Server) Submit(ctx context.Context, req *ippb.SubmitRequest) (*empty.Em } func (s *Server) GetScore(ctx context.Context, req *ippb.GetScoreRequest) (*ippb.GetScoreResponse, error) { - m, err := s.db.ScanIP(time.Now().Add(s.horizon), req.Ip) + counts, err := s.db.ScanIP(time.Now().Add(s.horizon), req.Ip) if err != nil { return nil, status.Errorf(codes.Unavailable, "%v", err) } - score, err := s.Script().RunIP(ctx, req.Ip, m, s.horizon.Seconds(), nil) + score, err := s.Script().RunIP(ctx, req.Ip, counts, s.horizon.Seconds(), nil) if err != nil { return nil, status.Errorf(codes.Internal, "%v", err) }