From ae875b252b4ad83e1d210481a97885c8de29ed56 Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Sat, 13 Apr 2019 12:02:49 +0100
Subject: [PATCH] Deduplicate EndpointSet entries based on endpoint names

So we are always returning a set of unique Endpoints.
---
 coordination/presence/presence.go | 68 +++++++++++++++++++------------
 1 file changed, 41 insertions(+), 27 deletions(-)

diff --git a/coordination/presence/presence.go b/coordination/presence/presence.go
index 3f5e9e51..5ebc1773 100644
--- a/coordination/presence/presence.go
+++ b/coordination/presence/presence.go
@@ -2,8 +2,10 @@ package presence
 
 import (
 	"context"
+	"fmt"
 	"log"
 	"math/rand"
+	"strings"
 	"sync"
 
 	"github.com/golang/protobuf/proto"
@@ -15,6 +17,10 @@ import (
 
 // EndpointSet is a container of Endpoints that is synchronizable with
 // the contents of a presence tree as created by Register().
+//
+// Note that while an Endpoint might appear multiple times in the
+// presence registry (before a stale lease expires, for instance), we
+// deduplicate endpoints based on their Name.
 type EndpointSet struct {
 	mx sync.Mutex
 
@@ -28,11 +34,30 @@ type EndpointSet struct {
 	epl []*pb.Endpoint
 }
 
+func (n *EndpointSet) debugString() string {
+	var eps []string
+	for _, ep := range n.epl {
+		eps = append(eps, fmt.Sprintf("%s(%s)", ep.Name, ep.Addrs[0]))
+	}
+	return fmt.Sprintf("<%s>", strings.Join(eps, ","))
+}
+
+func (n *EndpointSet) rebuildIndexes() {
+	epn := make(map[string]*pb.Endpoint)
+	epl := make([]*pb.Endpoint, 0, len(n.eps))
+	for _, ep := range n.eps {
+		if _, ok := epn[ep.Name]; !ok {
+			epn[ep.Name] = ep
+			epl = append(epl, ep)
+		}
+	}
+	n.epn = epn
+	n.epl = epl
+}
+
 // Reset implements the watcher.Watchable interface.
 func (n *EndpointSet) Reset(m map[string]string) {
 	eps := make(map[string]*pb.Endpoint)
-	epn := make(map[string]*pb.Endpoint)
-	var epl []*pb.Endpoint
 	for k, v := range m {
 		var ep pb.Endpoint
 		if err := proto.Unmarshal([]byte(v), &ep); err != nil {
@@ -40,13 +65,11 @@ func (n *EndpointSet) Reset(m map[string]string) {
 			continue
 		}
 		eps[k] = &ep
-		epn[ep.Name] = &ep
-		epl = append(epl, &ep)
 	}
 	n.mx.Lock()
 	n.eps = eps
-	n.epn = epn
-	n.epl = epl
+	n.rebuildIndexes()
+	log.Printf("presence state change (reset): %s", n.debugString())
 	n.mx.Unlock()
 }
 
@@ -58,18 +81,12 @@ func (n *EndpointSet) Set(k, v string) {
 		return
 	}
 	n.mx.Lock()
-	if old, ok := n.eps[k]; ok {
-		n.epl = withoutEndpoint(n.epl, old)
-	}
-	n.epl = append(n.epl, &ep)
 	if n.eps == nil {
 		n.eps = make(map[string]*pb.Endpoint)
 	}
-	if n.epn == nil {
-		n.epn = make(map[string]*pb.Endpoint)
-	}
 	n.eps[k] = &ep
-	n.epn[ep.Name] = &ep
+	n.rebuildIndexes()
+	log.Printf("presence state change: %s", n.debugString())
 	n.mx.Unlock()
 }
 
@@ -80,21 +97,11 @@ func (n *EndpointSet) Delete(k string) {
 	if n.eps == nil {
 		return
 	}
-	if old, ok := n.eps[k]; ok {
+	if _, ok := n.eps[k]; ok {
 		delete(n.eps, k)
-		delete(n.epn, old.Name)
-		n.epl = withoutEndpoint(n.epl, old)
+		n.rebuildIndexes()
 	}
-}
-
-func withoutEndpoint(l []*pb.Endpoint, old *pb.Endpoint) []*pb.Endpoint {
-	out := make([]*pb.Endpoint, 0, len(l))
-	for _, ep := range l {
-		if ep != old {
-			out = append(out, ep)
-		}
-	}
-	return out
+	log.Printf("presence state change: %s", n.debugString())
 }
 
 // Endpoints returns all registered endpoints.
@@ -163,3 +170,10 @@ func WatchEndpoints(ctx context.Context, cli *clientv3.Client, prefix string) *E
 	go watcher.Watch(ctx, cli, prefix, &endpoints)
 	return &endpoints
 }
+
+func WatchEndpointsReady(ctx context.Context, cli *clientv3.Client, prefix string) (*EndpointSet, <-chan struct{}) {
+	var endpoints EndpointSet
+	w := watcher.NewReadyWatchable(&endpoints)
+	go watcher.Watch(ctx, cli, prefix, w)
+	return &endpoints, w.WaitForInit()
+}
-- 
GitLab