Skip to content
Snippets Groups Projects
Commit b4f12e2a authored by ale's avatar ale
Browse files

Implement triggers, add an integration test

The integration test checks the server + pull combination, with all
features including TLS ACLs.
parent 223638d2
No related branches found
No related tags found
No related merge requests found
Pipeline #53863 passed
...@@ -10,7 +10,7 @@ import ( ...@@ -10,7 +10,7 @@ import (
"git.autistici.org/ai3/tools/replds2/common" "git.autistici.org/ai3/tools/replds2/common"
) )
type ACLOp int type Op int
const ( const (
OpRead = iota OpRead = iota
...@@ -18,7 +18,7 @@ const ( ...@@ -18,7 +18,7 @@ const (
OpPeer OpPeer
) )
func (op ACLOp) String() string { func (op Op) String() string {
switch op { switch op {
case OpRead: case OpRead:
return "READ" return "READ"
...@@ -34,12 +34,12 @@ type NullManager struct{} ...@@ -34,12 +34,12 @@ type NullManager struct{}
func NewNullManager() *NullManager { return new(NullManager) } func NewNullManager() *NullManager { return new(NullManager) }
func (m *NullManager) Check(identity string, op ACLOp, path string) bool { return true } func (m *NullManager) Check(identity string, op Op, path string) bool { return true }
type Entry struct { type Entry struct {
Identity string Identity string
Path string Path string
Op ACLOp Op Op
} }
type Manager struct { type Manager struct {
...@@ -54,7 +54,7 @@ func NewManager(acls []Entry) *Manager { ...@@ -54,7 +54,7 @@ func NewManager(acls []Entry) *Manager {
return m return m
} }
func (m *Manager) Check(identity string, op ACLOp, path string) bool { func (m *Manager) Check(identity string, op Op, path string) bool {
e := Entry{Identity: identity, Op: op} e := Entry{Identity: identity, Op: op}
for _, p := range common.PathPrefixes(path) { for _, p := range common.PathPrefixes(path) {
e.Path = p e.Path = p
...@@ -141,7 +141,7 @@ func parseEntry(line string) (e Entry, err error) { ...@@ -141,7 +141,7 @@ func parseEntry(line string) (e Entry, err error) {
return return
} }
func parseOp(s string) (op ACLOp, err error) { func parseOp(s string) (op Op, err error) {
switch strings.ToLower(s) { switch strings.ToLower(s) {
case "read", "r": case "read", "r":
op = OpRead op = OpRead
......
...@@ -12,7 +12,7 @@ func TestACL(t *testing.T) { ...@@ -12,7 +12,7 @@ func TestACL(t *testing.T) {
testdata := []struct { testdata := []struct {
identity, path string identity, path string
op ACLOp op Op
expectedOk bool expectedOk bool
}{ }{
{"user", "/home/user/foo", OpWrite, true}, {"user", "/home/user/foo", OpWrite, true},
......
package integrationtest
package integrationtest
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"math/big"
"net"
"os"
"os/exec"
"path/filepath"
"testing"
"time"
replds "git.autistici.org/ai3/tools/replds2"
"git.autistici.org/ai3/tools/replds2/acl"
"git.autistici.org/ai3/tools/replds2/common"
pb "git.autistici.org/ai3/tools/replds2/proto"
"git.autistici.org/ai3/tools/replds2/store/memlog"
"git.autistici.org/ai3/tools/replds2/watcher"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/protobuf/types/known/timestamppb"
)
func writeSSLCert(t *testing.T, der []byte, path string) {
f, err := os.Create(path)
if err != nil {
t.Fatal(err)
}
defer f.Close()
pem.Encode(f, &pem.Block{
Type: "CERTIFICATE",
Bytes: der,
})
}
func writeSSLKey(t *testing.T, key *rsa.PrivateKey, path string) {
f, err := os.Create(path)
if err != nil {
t.Fatal(err)
}
defer f.Close()
pem.Encode(f, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
}
func createTestCA(t *testing.T, dir string) {
t.Logf("initializing test TLS credentials")
ca := &x509.Certificate{
SerialNumber: big.NewInt(2023),
Subject: pkix.Name{
Organization: []string{"Test"},
Country: []string{"XX"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
caDER, err := x509.CreateCertificate(rand.Reader, ca, ca, &caKey.PublicKey, caKey)
if err != nil {
t.Fatal(err)
}
writeSSLCert(t, caDER, filepath.Join(dir, "ca.pem"))
for idx, name := range []string{"server", "client1", "client2"} {
cert := &x509.Certificate{
SerialNumber: big.NewInt(int64(idx + 1)),
Subject: pkix.Name{
CommonName: name,
},
DNSNames: []string{name},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
der, err := x509.CreateCertificate(rand.Reader, cert, ca, &priv.PublicKey, caKey)
if err != nil {
t.Fatal(err)
}
writeSSLCert(t, der, filepath.Join(dir, name+".pem"))
writeSSLKey(t, priv, filepath.Join(dir, name+".key"))
}
}
func loadCA(path string) (*x509.CertPool, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
cas := x509.NewCertPool()
if !cas.AppendCertsFromPEM(data) {
return nil, fmt.Errorf("no certificates could be parsed in %s", path)
}
return cas, nil
}
func serverTLSConfig(sslCert, sslKey, sslCA string) *tls.Config {
cert, err := tls.LoadX509KeyPair(sslCert, sslKey)
if err != nil {
panic(fmt.Sprintf("load x509: %s: %v", sslCert, err))
}
ca, err := loadCA(sslCA)
if err != nil {
panic(fmt.Sprintf("load x509 ca: %s: %v", sslCA, err))
}
return &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: ca,
}
}
func clientTLSConfig(sslCert, sslKey, sslCA string) *tls.Config {
cert, err := tls.LoadX509KeyPair(sslCert, sslKey)
if err != nil {
panic(fmt.Sprintf("load x509: %s: %v", sslCert, err))
}
ca, err := loadCA(sslCA)
if err != nil {
panic(fmt.Sprintf("load x509 ca: %s: %v", sslCA, err))
}
return &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
RootCAs: ca,
}
}
func runTestServer(ctx context.Context, l net.Listener, memlogDir, sslDir string, acls replds.ACLMgr) error {
store, err := memlog.NewMemFileStore(memlogDir)
if err != nil {
return fmt.Errorf("NewMemFileStore(): %w", err)
}
defer store.Close()
cluster, err := replds.NewGRPCCluster(nil)
if err != nil {
return fmt.Errorf("NewGRPCCluster(): %w", err)
}
server := replds.NewServer(ctx, cluster, store, acls)
grpcSrv := grpc.NewServer(
grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(common.TLSAuthFunc)),
grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(common.TLSAuthFunc)),
grpc.Creds(credentials.NewTLS(serverTLSConfig(
filepath.Join(sslDir, "server.pem"),
filepath.Join(sslDir, "server.key"),
filepath.Join(sslDir, "ca.pem"),
))),
)
pb.RegisterInternalReplServer(grpcSrv, server)
pb.RegisterRepldsServer(grpcSrv, server)
go func() {
<-ctx.Done()
grpcSrv.GracefulStop()
}()
return grpcSrv.Serve(l)
}
func runPuller(ctx context.Context, grpcURL, pullDir, triggersDir, sslDir, certName string) error {
triggers, err := watcher.LoadTriggersFromDir(triggersDir)
if err != nil {
return fmt.Errorf("LoadTriggersFromDir(): %w", err)
}
// Create a file-backed store for the local target path.
store, err := memlog.NewMemFileStore(pullDir)
if err != nil {
return fmt.Errorf("NewMemFileStore(): %w", err)
}
defer store.Close()
conn, err := grpc.Dial(
grpcURL,
grpc.WithTransportCredentials(credentials.NewTLS(clientTLSConfig(
filepath.Join(sslDir, certName+".pem"),
filepath.Join(sslDir, certName+".key"),
filepath.Join(sslDir, "ca.pem"),
))),
)
if err != nil {
return fmt.Errorf("grpc.Dial(): %w", err)
}
defer conn.Close()
prefix := "/certs"
w := watcher.New(conn, store, prefix, triggers)
w.Run(ctx)
return nil
}
func uploadSomeData(grpcURL, sslDir, certName string) error {
conn, err := grpc.Dial(
grpcURL,
grpc.WithTransportCredentials(credentials.NewTLS(clientTLSConfig(
filepath.Join(sslDir, certName+".pem"),
filepath.Join(sslDir, certName+".key"),
filepath.Join(sslDir, "ca.pem"),
))),
)
if err != nil {
return fmt.Errorf("grpc.Dial: %w", err)
}
defer conn.Close()
client := pb.NewRepldsClient(conn)
req := pb.StoreRequest{
Nodes: []*pb.Node{
&pb.Node{
Path: "/certs/somedata",
Data: []byte("some data"),
Version: 1,
Timestamp: timestamppb.New(time.Now()),
},
&pb.Node{
Path: "/certs/somedata2",
Data: []byte("some more data"),
Version: 1,
Timestamp: timestamppb.New(time.Now()),
},
},
}
_, err = client.Store(context.Background(), &req)
if err != nil {
return fmt.Errorf("client.Store: %w", err)
}
return nil
}
var errStop = errors.New("stop")
func TestIntegration_Pull(t *testing.T) {
dir, err := os.MkdirTemp("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
sslDir := filepath.Join(dir, "ssl")
memlogDir := filepath.Join(dir, "memlog")
pullDir := filepath.Join(dir, "pull")
triggersDir := filepath.Join(dir, "triggers")
for _, s := range []string{sslDir, memlogDir, pullDir, triggersDir} {
os.Mkdir(s, 0700)
}
// Set up ACLs.
os.WriteFile(
filepath.Join(dir, "acls"),
[]byte(`
client1 /certs READ
client2 /certs WRITE
`),
0600,
)
// Set up a test trigger.
os.WriteFile(
filepath.Join(triggersDir, "test-trigger"),
[]byte(fmt.Sprintf(
`{"path": "/certs/somedata", "command": "touch '%s/trigger-stamp'"}`, dir)),
0700,
)
createTestCA(t, sslDir)
g, ctx := errgroup.WithContext(context.Background())
// Listener for the grpc server.
l, err := net.Listen("tcp", "127.0.0.1:")
if err != nil {
t.Fatalf("listen(): %v", err)
}
grpcURL := l.Addr().String()
acls, err := acl.Load(filepath.Join(dir, "acls"))
if err != nil {
t.Fatalf("acl.Load(): %v", err)
}
g.Go(func() error {
t.Logf("starting GRPC server on %s", grpcURL)
defer t.Logf("GRPC server stopped")
return runTestServer(ctx, l, memlogDir, sslDir, acls)
})
g.Go(func() error {
t.Logf("starting puller")
defer t.Logf("puller stopped")
return runPuller(ctx, grpcURL, pullDir, triggersDir, sslDir, "client1")
})
g.Go(func() error {
t.Logf("uploading some data")
if err := uploadSomeData(grpcURL, sslDir, "client2"); err != nil {
return err
}
// Also check that client1 can't write due to ACL.
if err := uploadSomeData(grpcURL, sslDir, "client1"); err == nil {
return errors.New("ACL failure: client1 was able to write")
}
time.Sleep(2 * time.Second)
t.Logf("stopping everything")
return errStop
})
if err := g.Wait(); err != nil && !errors.Is(err, errStop) {
t.Fatal(err)
}
// Verify that data has been propagated by the puller.
if _, err := os.Stat(filepath.Join(pullDir, "certs/somedata")); err != nil {
out, _ := exec.Command("/bin/ls", "-lR", pullDir).CombinedOutput()
t.Logf("pullDir contents:\n%s", out)
t.Fatal("'somedata' does not exist in pullDir")
}
// Verify that the trigger has been called.
if _, err := os.Stat(filepath.Join(dir, "trigger-stamp")); err != nil {
t.Fatal("trigger did not run")
}
}
...@@ -28,7 +28,7 @@ type Store interface { ...@@ -28,7 +28,7 @@ type Store interface {
// ACLMgr is an ACL manager that can check access rules. // ACLMgr is an ACL manager that can check access rules.
type ACLMgr interface { type ACLMgr interface {
Check(string, acl.ACLOp, string) bool Check(string, acl.Op, string) bool
} }
// ClusterMgr holds state about the service cluster network layout. // ClusterMgr holds state about the service cluster network layout.
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"os" "os"
"os/exec"
"path/filepath" "path/filepath"
"strings" "strings"
...@@ -37,12 +38,34 @@ func (m scriptTriggerManager) Notify(b *common.NotifyBatch) { ...@@ -37,12 +38,34 @@ func (m scriptTriggerManager) Notify(b *common.NotifyBatch) {
} }
type scriptTrigger struct { type scriptTrigger struct {
name string
Path string `json:"path"` Path string `json:"path"`
Command string `json:"command"` Command string `json:"command"`
} }
func (t *scriptTrigger) Run(nodes []*pb.Node) error { func (t *scriptTrigger) Run(nodes []*pb.Node) error {
return nil // Build an environment for the script.
changedEnvStr := "REPLDS_CHANGES="
for _, node := range nodes {
if !node.Deleted {
changedEnvStr += node.Path
}
}
env := os.Environ()
env = append(env, "REPLDS=1")
env = append(env, changedEnvStr)
// Run the command using the shell.
cmd := exec.Command("/bin/sh", "-c", t.Command)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stdout
cmd.Env = env
log.Printf("executing trigger '%s'", t.name)
return cmd.Run()
} }
func LoadTriggersFromDir(dir string) (TriggerManager, error) { func LoadTriggersFromDir(dir string) (TriggerManager, error) {
...@@ -70,6 +93,7 @@ func LoadTriggersFromDir(dir string) (TriggerManager, error) { ...@@ -70,6 +93,7 @@ func LoadTriggersFromDir(dir string) (TriggerManager, error) {
continue continue
} }
var trig scriptTrigger var trig scriptTrigger
trig.name = f.Name()
if err := json.Unmarshal(data, &trig); err != nil { if err := json.Unmarshal(data, &trig); err != nil {
log.Printf("invalid JSON in %s: %v", f.Name(), err) log.Printf("invalid JSON in %s: %v", f.Name(), err)
continue continue
......
...@@ -82,7 +82,7 @@ func (w *Watcher) Run(ctx context.Context) { ...@@ -82,7 +82,7 @@ func (w *Watcher) Run(ctx context.Context) {
for { for {
resp, err := stream.Recv() resp, err := stream.Recv()
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) || status.Code(err) == codes.Canceled {
break break
} }
if err != nil { if err != nil {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment