Skip to content
Snippets Groups Projects
Commit 37eb65e6 authored by ale's avatar ale
Browse files

Use net.IPs to build Endpoints

parent ad23ce5a
No related branches found
No related tags found
1 merge request!1v2.0
...@@ -2,10 +2,8 @@ package presence ...@@ -2,10 +2,8 @@ package presence
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"log" "log"
mrand "math/rand" "math/rand"
"sync" "sync"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
...@@ -15,12 +13,6 @@ import ( ...@@ -15,12 +13,6 @@ import (
pb "git.autistici.org/ale/autoradio/proto" pb "git.autistici.org/ale/autoradio/proto"
) )
func getUniqueToken() string {
var b [16]byte
rand.Read(b[:]) // nolint
return hex.EncodeToString(b[:])
}
// EndpointSet is a container of Endpoints that is synchronizable with // EndpointSet is a container of Endpoints that is synchronizable with
// the contents of a presence tree as created by Register(). // the contents of a presence tree as created by Register().
type EndpointSet struct { type EndpointSet struct {
...@@ -123,7 +115,7 @@ func (n *EndpointSet) RandomEndpoint() *pb.Endpoint { ...@@ -123,7 +115,7 @@ func (n *EndpointSet) RandomEndpoint() *pb.Endpoint {
if len(n.epl) == 0 { if len(n.epl) == 0 {
return nil return nil
} }
return n.epl[mrand.Intn(len(n.epl))] return n.epl[rand.Intn(len(n.epl))]
} }
// RandomEndpointExcluding returns a randomly selected Endpoint except // RandomEndpointExcluding returns a randomly selected Endpoint except
...@@ -145,13 +137,13 @@ func (n *EndpointSet) RandomEndpointExcluding(name string) *pb.Endpoint { ...@@ -145,13 +137,13 @@ func (n *EndpointSet) RandomEndpointExcluding(name string) *pb.Endpoint {
} }
// Is the excluded element not in the list? // Is the excluded element not in the list?
if found < 0 { if found < 0 {
return n.epl[mrand.Intn(l)] return n.epl[rand.Intn(l)]
} }
// Is the list empty once we exclude the item? // Is the list empty once we exclude the item?
if l == 1 { if l == 1 {
return nil return nil
} }
i := mrand.Intn(l - 1) i := rand.Intn(l - 1)
if i >= found { if i >= found {
i++ i++
} }
......
...@@ -2,6 +2,9 @@ package presence ...@@ -2,6 +2,9 @@ package presence
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"net"
"strings" "strings"
pb "git.autistici.org/ale/autoradio/proto" pb "git.autistici.org/ale/autoradio/proto"
...@@ -14,17 +17,17 @@ import ( ...@@ -14,17 +17,17 @@ import (
// endpoint registration. Addresses without ports will use the // endpoint registration. Addresses without ports will use the
// DefaultPort. // DefaultPort.
type EndpointRegistration struct { type EndpointRegistration struct {
Prefix string Prefix string
Addrs string Addrs []net.IP
DefaultPort int Port int
} }
// NewRegistration creates a new EndpointRegistration. // NewRegistration creates a new EndpointRegistration.
func NewRegistration(prefix, addrs string, port int) EndpointRegistration { func NewRegistration(prefix string, addrs []net.IP, port int) EndpointRegistration {
return EndpointRegistration{ return EndpointRegistration{
Prefix: prefix, Prefix: prefix,
Addrs: addrs, Addrs: addrs,
DefaultPort: port, Port: port,
} }
} }
...@@ -33,11 +36,8 @@ func NewRegistration(prefix, addrs string, port int) EndpointRegistration { ...@@ -33,11 +36,8 @@ func NewRegistration(prefix, addrs string, port int) EndpointRegistration {
func Register(ctx context.Context, session *concurrency.Session, name string, regs ...EndpointRegistration) ([]*pb.Endpoint, error) { func Register(ctx context.Context, session *concurrency.Session, name string, regs ...EndpointRegistration) ([]*pb.Endpoint, error) {
var endpoints []*pb.Endpoint var endpoints []*pb.Endpoint
for _, r := range regs { for _, r := range regs {
ep, err := pb.ParseEndpoint(name, r.Addrs, r.DefaultPort) ep := pb.NewEndpointWithIPAndPort(name, r.Addrs, r.Port)
if err != nil { if err := registerEndpoint(ctx, session, r.Prefix, ep); err != nil {
return nil, err
}
if err = registerEndpoint(ctx, session, r.Prefix, ep); err != nil {
return nil, err return nil, err
} }
endpoints = append(endpoints, ep) endpoints = append(endpoints, ep)
...@@ -61,3 +61,9 @@ func registerEndpoint(ctx context.Context, session *concurrency.Session, prefix ...@@ -61,3 +61,9 @@ func registerEndpoint(ctx context.Context, session *concurrency.Session, prefix
) )
return err return err
} }
func getUniqueToken() string {
var b [16]byte
rand.Read(b[:]) // nolint
return hex.EncodeToString(b[:])
}
package autoradio package autoradio
import ( import (
"errors"
"fmt"
"net" "net"
"strconv" "strconv"
"strings"
) )
// ParseEndpoint creates an Endpoint with the specified name and // NewEndpointWithIPAndPort creates an Endpoint with the specified IP
// address. If the address does not specify a port, defaultPort is // address and port.
// used. func NewEndpointWithIPAndPort(name string, ips []net.IP, port int) *Endpoint {
func ParseEndpoint(name, s string, defaultPort int) (*Endpoint, error) { ep := Endpoint{
if s == "" { Name: name,
return nil, errors.New("empty endpoint spec")
} }
addrs := strings.Split(s, ",") sport := strconv.Itoa(port)
var ep Endpoint for _, ip := range ips {
ep.Name = name ep.Addrs = append(ep.Addrs, net.JoinHostPort(ip.String(), sport))
for _, addr := range addrs {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("parsing '%s': %v", addr, err)
}
if port == "" {
port = strconv.Itoa(defaultPort)
}
if ip := net.ParseIP(host); ip == nil {
return nil, fmt.Errorf("parsing '%s': bad IP address", host)
}
ep.Addrs = append(ep.Addrs, net.JoinHostPort(host, port))
} }
return &ep, nil return &ep
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment