Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • ai3/tools/replds2
1 result
Show changes
Commits on Source (3)
...@@ -10,7 +10,7 @@ import ( ...@@ -10,7 +10,7 @@ import (
"git.autistici.org/ai3/tools/replds2/common" "git.autistici.org/ai3/tools/replds2/common"
) )
type ACLOp int type Op int
const ( const (
OpRead = iota OpRead = iota
...@@ -18,7 +18,7 @@ const ( ...@@ -18,7 +18,7 @@ const (
OpPeer OpPeer
) )
func (op ACLOp) String() string { func (op Op) String() string {
switch op { switch op {
case OpRead: case OpRead:
return "READ" return "READ"
...@@ -34,12 +34,12 @@ type NullManager struct{} ...@@ -34,12 +34,12 @@ type NullManager struct{}
func NewNullManager() *NullManager { return new(NullManager) } func NewNullManager() *NullManager { return new(NullManager) }
func (m *NullManager) Check(identity string, op ACLOp, path string) bool { return true } func (m *NullManager) Check(identity string, op Op, path string) bool { return true }
type Entry struct { type Entry struct {
Identity string Identity string
Path string Path string
Op ACLOp Op Op
} }
type Manager struct { type Manager struct {
...@@ -54,7 +54,7 @@ func NewManager(acls []Entry) *Manager { ...@@ -54,7 +54,7 @@ func NewManager(acls []Entry) *Manager {
return m return m
} }
func (m *Manager) Check(identity string, op ACLOp, path string) bool { func (m *Manager) Check(identity string, op Op, path string) bool {
e := Entry{Identity: identity, Op: op} e := Entry{Identity: identity, Op: op}
for _, p := range common.PathPrefixes(path) { for _, p := range common.PathPrefixes(path) {
e.Path = p e.Path = p
...@@ -141,7 +141,7 @@ func parseEntry(line string) (e Entry, err error) { ...@@ -141,7 +141,7 @@ func parseEntry(line string) (e Entry, err error) {
return return
} }
func parseOp(s string) (op ACLOp, err error) { func parseOp(s string) (op Op, err error) {
switch strings.ToLower(s) { switch strings.ToLower(s) {
case "read", "r": case "read", "r":
op = OpRead op = OpRead
......
...@@ -12,7 +12,7 @@ func TestACL(t *testing.T) { ...@@ -12,7 +12,7 @@ func TestACL(t *testing.T) {
testdata := []struct { testdata := []struct {
identity, path string identity, path string
op ACLOp op Op
expectedOk bool expectedOk bool
}{ }{
{"user", "/home/user/foo", OpWrite, true}, {"user", "/home/user/foo", OpWrite, true},
......
...@@ -19,12 +19,10 @@ import ( ...@@ -19,12 +19,10 @@ import (
) )
type storeCommand struct { type storeCommand struct {
sslCert string sslCert string
sslKey string sslKey string
sslCA string sslCA string
storePath string serverAddr string
triggersPath string
serverAddr string
} }
func init() { func init() {
......
module git.autistici.org/ai3/tools/replds2 module git.autistici.org/ai3/tools/replds2
go 1.15 go 1.19
require ( require (
github.com/DataDog/zstd v1.5.3-0.20220606203749-fd035e54e312 github.com/DataDog/zstd v1.5.3-0.20220606203749-fd035e54e312
...@@ -11,9 +11,20 @@ require ( ...@@ -11,9 +11,20 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 github.com/grpc-ecosystem/go-grpc-middleware v1.4.0
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
github.com/prometheus/client_golang v1.11.0 github.com/prometheus/client_golang v1.11.0
github.com/prometheus/common v0.31.1 // indirect
github.com/prometheus/procfs v0.7.3 // indirect
golang.org/x/sync v0.1.0 golang.org/x/sync v0.1.0
google.golang.org/grpc v1.53.0 google.golang.org/grpc v1.53.0
google.golang.org/protobuf v1.29.1 google.golang.org/protobuf v1.29.1
) )
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.31.1 // indirect
github.com/prometheus/procfs v0.7.3 // indirect
golang.org/x/net v0.5.0 // indirect
golang.org/x/sys v0.4.0 // indirect
golang.org/x/text v0.6.0 // indirect
google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect
)
This diff is collapsed.
package integrationtest
package integrationtest
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"math/big"
"net"
"os"
"os/exec"
"path/filepath"
"testing"
"time"
replds "git.autistici.org/ai3/tools/replds2"
"git.autistici.org/ai3/tools/replds2/acl"
"git.autistici.org/ai3/tools/replds2/common"
pb "git.autistici.org/ai3/tools/replds2/proto"
"git.autistici.org/ai3/tools/replds2/store/memlog"
"git.autistici.org/ai3/tools/replds2/watcher"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/protobuf/types/known/timestamppb"
)
func writeSSLCert(t *testing.T, der []byte, path string) {
f, err := os.Create(path)
if err != nil {
t.Fatal(err)
}
defer f.Close()
pem.Encode(f, &pem.Block{
Type: "CERTIFICATE",
Bytes: der,
})
}
func writeSSLKey(t *testing.T, key *rsa.PrivateKey, path string) {
f, err := os.Create(path)
if err != nil {
t.Fatal(err)
}
defer f.Close()
pem.Encode(f, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
}
func createTestCA(t *testing.T, dir string) {
t.Logf("initializing test TLS credentials")
ca := &x509.Certificate{
SerialNumber: big.NewInt(2023),
Subject: pkix.Name{
Organization: []string{"Test"},
Country: []string{"XX"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
caDER, err := x509.CreateCertificate(rand.Reader, ca, ca, &caKey.PublicKey, caKey)
if err != nil {
t.Fatal(err)
}
writeSSLCert(t, caDER, filepath.Join(dir, "ca.pem"))
for idx, name := range []string{"server", "client1", "client2"} {
cert := &x509.Certificate{
SerialNumber: big.NewInt(int64(idx + 1)),
Subject: pkix.Name{
CommonName: name,
},
DNSNames: []string{name},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
der, err := x509.CreateCertificate(rand.Reader, cert, ca, &priv.PublicKey, caKey)
if err != nil {
t.Fatal(err)
}
writeSSLCert(t, der, filepath.Join(dir, name+".pem"))
writeSSLKey(t, priv, filepath.Join(dir, name+".key"))
}
}
func loadCA(path string) (*x509.CertPool, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
cas := x509.NewCertPool()
if !cas.AppendCertsFromPEM(data) {
return nil, fmt.Errorf("no certificates could be parsed in %s", path)
}
return cas, nil
}
func serverTLSConfig(sslCert, sslKey, sslCA string) *tls.Config {
cert, err := tls.LoadX509KeyPair(sslCert, sslKey)
if err != nil {
panic(fmt.Sprintf("load x509: %s: %v", sslCert, err))
}
ca, err := loadCA(sslCA)
if err != nil {
panic(fmt.Sprintf("load x509 ca: %s: %v", sslCA, err))
}
return &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: ca,
}
}
func clientTLSConfig(sslCert, sslKey, sslCA string) *tls.Config {
cert, err := tls.LoadX509KeyPair(sslCert, sslKey)
if err != nil {
panic(fmt.Sprintf("load x509: %s: %v", sslCert, err))
}
ca, err := loadCA(sslCA)
if err != nil {
panic(fmt.Sprintf("load x509 ca: %s: %v", sslCA, err))
}
return &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
RootCAs: ca,
}
}
func runTestServer(ctx context.Context, l net.Listener, memlogDir, sslDir string, acls replds.ACLMgr) error {
store, err := memlog.NewMemFileStore(memlogDir)
if err != nil {
return fmt.Errorf("NewMemFileStore(): %w", err)
}
defer store.Close()
cluster, err := replds.NewGRPCCluster(nil)
if err != nil {
return fmt.Errorf("NewGRPCCluster(): %w", err)
}
server := replds.NewServer(ctx, cluster, store, acls)
grpcSrv := grpc.NewServer(
grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(common.TLSAuthFunc)),
grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(common.TLSAuthFunc)),
grpc.Creds(credentials.NewTLS(serverTLSConfig(
filepath.Join(sslDir, "server.pem"),
filepath.Join(sslDir, "server.key"),
filepath.Join(sslDir, "ca.pem"),
))),
)
pb.RegisterInternalReplServer(grpcSrv, server)
pb.RegisterRepldsServer(grpcSrv, server)
go func() {
<-ctx.Done()
grpcSrv.GracefulStop()
}()
return grpcSrv.Serve(l)
}
func runPuller(ctx context.Context, grpcURL, pullDir, triggersDir, sslDir, certName string) error {
triggers, err := watcher.LoadTriggersFromDir(triggersDir)
if err != nil {
return fmt.Errorf("LoadTriggersFromDir(): %w", err)
}
// Create a file-backed store for the local target path.
store, err := memlog.NewMemFileStore(pullDir)
if err != nil {
return fmt.Errorf("NewMemFileStore(): %w", err)
}
defer store.Close()
conn, err := grpc.Dial(
grpcURL,
grpc.WithTransportCredentials(credentials.NewTLS(clientTLSConfig(
filepath.Join(sslDir, certName+".pem"),
filepath.Join(sslDir, certName+".key"),
filepath.Join(sslDir, "ca.pem"),
))),
)
if err != nil {
return fmt.Errorf("grpc.Dial(): %w", err)
}
defer conn.Close()
prefix := "/certs"
w := watcher.New(conn, store, prefix, triggers)
w.Run(ctx)
return nil
}
func uploadSomeData(grpcURL, sslDir, certName string) error {
conn, err := grpc.Dial(
grpcURL,
grpc.WithTransportCredentials(credentials.NewTLS(clientTLSConfig(
filepath.Join(sslDir, certName+".pem"),
filepath.Join(sslDir, certName+".key"),
filepath.Join(sslDir, "ca.pem"),
))),
)
if err != nil {
return fmt.Errorf("grpc.Dial: %w", err)
}
defer conn.Close()
client := pb.NewRepldsClient(conn)
req := pb.StoreRequest{
Nodes: []*pb.Node{
&pb.Node{
Path: "/certs/somedata",
Data: []byte("some data"),
Version: 1,
Timestamp: timestamppb.New(time.Now()),
},
&pb.Node{
Path: "/certs/somedata2",
Data: []byte("some more data"),
Version: 1,
Timestamp: timestamppb.New(time.Now()),
},
},
}
_, err = client.Store(context.Background(), &req)
if err != nil {
return fmt.Errorf("client.Store: %w", err)
}
return nil
}
var errStop = errors.New("stop")
func TestIntegration_Pull(t *testing.T) {
dir, err := os.MkdirTemp("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
sslDir := filepath.Join(dir, "ssl")
memlogDir := filepath.Join(dir, "memlog")
pullDir := filepath.Join(dir, "pull")
triggersDir := filepath.Join(dir, "triggers")
for _, s := range []string{sslDir, memlogDir, pullDir, triggersDir} {
os.Mkdir(s, 0700)
}
// Set up ACLs.
os.WriteFile(
filepath.Join(dir, "acls"),
[]byte(`
client1 /certs READ
client2 /certs WRITE
`),
0600,
)
// Set up a test trigger.
os.WriteFile(
filepath.Join(triggersDir, "test-trigger"),
[]byte(fmt.Sprintf(
`{"path": "/certs/somedata", "command": "touch '%s/trigger-stamp'"}`, dir)),
0700,
)
createTestCA(t, sslDir)
g, ctx := errgroup.WithContext(context.Background())
// Listener for the grpc server.
l, err := net.Listen("tcp", "127.0.0.1:")
if err != nil {
t.Fatalf("listen(): %v", err)
}
grpcURL := l.Addr().String()
acls, err := acl.Load(filepath.Join(dir, "acls"))
if err != nil {
t.Fatalf("acl.Load(): %v", err)
}
g.Go(func() error {
t.Logf("starting GRPC server on %s", grpcURL)
defer t.Logf("GRPC server stopped")
return runTestServer(ctx, l, memlogDir, sslDir, acls)
})
g.Go(func() error {
t.Logf("starting puller")
defer t.Logf("puller stopped")
return runPuller(ctx, grpcURL, pullDir, triggersDir, sslDir, "client1")
})
g.Go(func() error {
t.Logf("uploading some data")
if err := uploadSomeData(grpcURL, sslDir, "client2"); err != nil {
return err
}
// Also check that client1 can't write due to ACL.
if err := uploadSomeData(grpcURL, sslDir, "client1"); err == nil {
return errors.New("ACL failure: client1 was able to write")
}
time.Sleep(2 * time.Second)
t.Logf("stopping everything")
return errStop
})
if err := g.Wait(); err != nil && !errors.Is(err, errStop) {
t.Fatal(err)
}
// Verify that data has been propagated by the puller.
if _, err := os.Stat(filepath.Join(pullDir, "certs/somedata")); err != nil {
out, _ := exec.Command("/bin/ls", "-lR", pullDir).CombinedOutput()
t.Logf("pullDir contents:\n%s", out)
t.Fatal("'somedata' does not exist in pullDir")
}
// Verify that the trigger has been called.
if _, err := os.Stat(filepath.Join(dir, "trigger-stamp")); err != nil {
t.Fatal("trigger did not run")
}
}
...@@ -28,7 +28,7 @@ type Store interface { ...@@ -28,7 +28,7 @@ type Store interface {
// ACLMgr is an ACL manager that can check access rules. // ACLMgr is an ACL manager that can check access rules.
type ACLMgr interface { type ACLMgr interface {
Check(string, acl.ACLOp, string) bool Check(string, acl.Op, string) bool
} }
// ClusterMgr holds state about the service cluster network layout. // ClusterMgr holds state about the service cluster network layout.
......
...@@ -52,7 +52,7 @@ func TestLog(t *testing.T) { ...@@ -52,7 +52,7 @@ func TestLog(t *testing.T) {
for i := 0; i < numEntries; i++ { for i := 0; i < numEntries; i++ {
node := &pb.Node{ node := &pb.Node{
Path: fmt.Sprintf("/path/to/%08d", (i % maxSz)), Path: fmt.Sprintf("/path/to/%08d", (i % maxSz)),
Version: int64(time.Now().Unix()), Version: time.Now().Unix(),
Data: []byte("just some random data, blah blah, blah blah!"), Data: []byte("just some random data, blah blah, blah blah!"),
} }
if err := l.Write(node); err != nil { if err := l.Write(node); err != nil {
...@@ -95,7 +95,7 @@ func TestLog(t *testing.T) { ...@@ -95,7 +95,7 @@ func TestLog(t *testing.T) {
log.Printf(">>> writing new data, closing and re-opening log...") log.Printf(">>> writing new data, closing and re-opening log...")
err = l.Write(&pb.Node{ err = l.Write(&pb.Node{
Path: testPath, Path: testPath,
Version: int64(time.Now().Unix()), Version: time.Now().Unix(),
Data: []byte("new data"), Data: []byte("new data"),
}) })
if err != nil { if err != nil {
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"os" "os"
"os/exec"
"path/filepath" "path/filepath"
"strings" "strings"
...@@ -37,12 +38,34 @@ func (m scriptTriggerManager) Notify(b *common.NotifyBatch) { ...@@ -37,12 +38,34 @@ func (m scriptTriggerManager) Notify(b *common.NotifyBatch) {
} }
type scriptTrigger struct { type scriptTrigger struct {
name string
Path string `json:"path"` Path string `json:"path"`
Command string `json:"command"` Command string `json:"command"`
} }
func (t *scriptTrigger) Run(nodes []*pb.Node) error { func (t *scriptTrigger) Run(nodes []*pb.Node) error {
return nil // Build an environment for the script.
changedEnvStr := "REPLDS_CHANGES="
for _, node := range nodes {
if !node.Deleted {
changedEnvStr += node.Path
}
}
env := os.Environ()
env = append(env, "REPLDS=1")
env = append(env, changedEnvStr)
// Run the command using the shell.
cmd := exec.Command("/bin/sh", "-c", t.Command)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stdout
cmd.Env = env
log.Printf("executing trigger '%s'", t.name)
return cmd.Run()
} }
func LoadTriggersFromDir(dir string) (TriggerManager, error) { func LoadTriggersFromDir(dir string) (TriggerManager, error) {
...@@ -70,6 +93,7 @@ func LoadTriggersFromDir(dir string) (TriggerManager, error) { ...@@ -70,6 +93,7 @@ func LoadTriggersFromDir(dir string) (TriggerManager, error) {
continue continue
} }
var trig scriptTrigger var trig scriptTrigger
trig.name = f.Name()
if err := json.Unmarshal(data, &trig); err != nil { if err := json.Unmarshal(data, &trig); err != nil {
log.Printf("invalid JSON in %s: %v", f.Name(), err) log.Printf("invalid JSON in %s: %v", f.Name(), err)
continue continue
......
...@@ -82,7 +82,7 @@ func (w *Watcher) Run(ctx context.Context) { ...@@ -82,7 +82,7 @@ func (w *Watcher) Run(ctx context.Context) {
for { for {
resp, err := stream.Recv() resp, err := stream.Recv()
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) || status.Code(err) == codes.Canceled {
break break
} }
if err != nil { if err != nil {
......