Skip to content
Snippets Groups Projects
Commit ae875b25 authored by ale's avatar ale
Browse files

Deduplicate EndpointSet entries based on endpoint names

So we are always returning a set of unique Endpoints.
parent 90667aae
No related branches found
No related tags found
1 merge request!1v2.0
...@@ -2,8 +2,10 @@ package presence ...@@ -2,8 +2,10 @@ package presence
import ( import (
"context" "context"
"fmt"
"log" "log"
"math/rand" "math/rand"
"strings"
"sync" "sync"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
...@@ -15,6 +17,10 @@ import ( ...@@ -15,6 +17,10 @@ import (
// EndpointSet is a container of Endpoints that is synchronizable with // EndpointSet is a container of Endpoints that is synchronizable with
// the contents of a presence tree as created by Register(). // the contents of a presence tree as created by Register().
//
// Note that while an Endpoint might appear multiple times in the
// presence registry (before a stale lease expires, for instance), we
// deduplicate endpoints based on their Name.
type EndpointSet struct { type EndpointSet struct {
mx sync.Mutex mx sync.Mutex
...@@ -28,11 +34,30 @@ type EndpointSet struct { ...@@ -28,11 +34,30 @@ type EndpointSet struct {
epl []*pb.Endpoint epl []*pb.Endpoint
} }
func (n *EndpointSet) debugString() string {
var eps []string
for _, ep := range n.epl {
eps = append(eps, fmt.Sprintf("%s(%s)", ep.Name, ep.Addrs[0]))
}
return fmt.Sprintf("<%s>", strings.Join(eps, ","))
}
func (n *EndpointSet) rebuildIndexes() {
epn := make(map[string]*pb.Endpoint)
epl := make([]*pb.Endpoint, 0, len(n.eps))
for _, ep := range n.eps {
if _, ok := epn[ep.Name]; !ok {
epn[ep.Name] = ep
epl = append(epl, ep)
}
}
n.epn = epn
n.epl = epl
}
// Reset implements the watcher.Watchable interface. // Reset implements the watcher.Watchable interface.
func (n *EndpointSet) Reset(m map[string]string) { func (n *EndpointSet) Reset(m map[string]string) {
eps := make(map[string]*pb.Endpoint) eps := make(map[string]*pb.Endpoint)
epn := make(map[string]*pb.Endpoint)
var epl []*pb.Endpoint
for k, v := range m { for k, v := range m {
var ep pb.Endpoint var ep pb.Endpoint
if err := proto.Unmarshal([]byte(v), &ep); err != nil { if err := proto.Unmarshal([]byte(v), &ep); err != nil {
...@@ -40,13 +65,11 @@ func (n *EndpointSet) Reset(m map[string]string) { ...@@ -40,13 +65,11 @@ func (n *EndpointSet) Reset(m map[string]string) {
continue continue
} }
eps[k] = &ep eps[k] = &ep
epn[ep.Name] = &ep
epl = append(epl, &ep)
} }
n.mx.Lock() n.mx.Lock()
n.eps = eps n.eps = eps
n.epn = epn n.rebuildIndexes()
n.epl = epl log.Printf("presence state change (reset): %s", n.debugString())
n.mx.Unlock() n.mx.Unlock()
} }
...@@ -58,18 +81,12 @@ func (n *EndpointSet) Set(k, v string) { ...@@ -58,18 +81,12 @@ func (n *EndpointSet) Set(k, v string) {
return return
} }
n.mx.Lock() n.mx.Lock()
if old, ok := n.eps[k]; ok {
n.epl = withoutEndpoint(n.epl, old)
}
n.epl = append(n.epl, &ep)
if n.eps == nil { if n.eps == nil {
n.eps = make(map[string]*pb.Endpoint) n.eps = make(map[string]*pb.Endpoint)
} }
if n.epn == nil {
n.epn = make(map[string]*pb.Endpoint)
}
n.eps[k] = &ep n.eps[k] = &ep
n.epn[ep.Name] = &ep n.rebuildIndexes()
log.Printf("presence state change: %s", n.debugString())
n.mx.Unlock() n.mx.Unlock()
} }
...@@ -80,21 +97,11 @@ func (n *EndpointSet) Delete(k string) { ...@@ -80,21 +97,11 @@ func (n *EndpointSet) Delete(k string) {
if n.eps == nil { if n.eps == nil {
return return
} }
if old, ok := n.eps[k]; ok { if _, ok := n.eps[k]; ok {
delete(n.eps, k) delete(n.eps, k)
delete(n.epn, old.Name) n.rebuildIndexes()
n.epl = withoutEndpoint(n.epl, old)
} }
} log.Printf("presence state change: %s", n.debugString())
func withoutEndpoint(l []*pb.Endpoint, old *pb.Endpoint) []*pb.Endpoint {
out := make([]*pb.Endpoint, 0, len(l))
for _, ep := range l {
if ep != old {
out = append(out, ep)
}
}
return out
} }
// Endpoints returns all registered endpoints. // Endpoints returns all registered endpoints.
...@@ -163,3 +170,10 @@ func WatchEndpoints(ctx context.Context, cli *clientv3.Client, prefix string) *E ...@@ -163,3 +170,10 @@ func WatchEndpoints(ctx context.Context, cli *clientv3.Client, prefix string) *E
go watcher.Watch(ctx, cli, prefix, &endpoints) go watcher.Watch(ctx, cli, prefix, &endpoints)
return &endpoints return &endpoints
} }
func WatchEndpointsReady(ctx context.Context, cli *clientv3.Client, prefix string) (*EndpointSet, <-chan struct{}) {
var endpoints EndpointSet
w := watcher.NewReadyWatchable(&endpoints)
go watcher.Watch(ctx, cli, prefix, w)
return &endpoints, w.WaitForInit()
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment