Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • ai3/tools/replds2
1 result
Show changes
Commits on Source (3)
......@@ -10,7 +10,7 @@ import (
"git.autistici.org/ai3/tools/replds2/common"
)
type ACLOp int
type Op int
const (
OpRead = iota
......@@ -18,7 +18,7 @@ const (
OpPeer
)
func (op ACLOp) String() string {
func (op Op) String() string {
switch op {
case OpRead:
return "READ"
......@@ -34,12 +34,12 @@ type NullManager struct{}
func NewNullManager() *NullManager { return new(NullManager) }
func (m *NullManager) Check(identity string, op ACLOp, path string) bool { return true }
func (m *NullManager) Check(identity string, op Op, path string) bool { return true }
type Entry struct {
Identity string
Path string
Op ACLOp
Op Op
}
type Manager struct {
......@@ -54,7 +54,7 @@ func NewManager(acls []Entry) *Manager {
return m
}
func (m *Manager) Check(identity string, op ACLOp, path string) bool {
func (m *Manager) Check(identity string, op Op, path string) bool {
e := Entry{Identity: identity, Op: op}
for _, p := range common.PathPrefixes(path) {
e.Path = p
......@@ -141,7 +141,7 @@ func parseEntry(line string) (e Entry, err error) {
return
}
func parseOp(s string) (op ACLOp, err error) {
func parseOp(s string) (op Op, err error) {
switch strings.ToLower(s) {
case "read", "r":
op = OpRead
......
......@@ -12,7 +12,7 @@ func TestACL(t *testing.T) {
testdata := []struct {
identity, path string
op ACLOp
op Op
expectedOk bool
}{
{"user", "/home/user/foo", OpWrite, true},
......
......@@ -19,12 +19,10 @@ import (
)
type storeCommand struct {
sslCert string
sslKey string
sslCA string
storePath string
triggersPath string
serverAddr string
sslCert string
sslKey string
sslCA string
serverAddr string
}
func init() {
......
module git.autistici.org/ai3/tools/replds2
go 1.15
go 1.19
require (
github.com/DataDog/zstd v1.5.3-0.20220606203749-fd035e54e312
......@@ -11,9 +11,20 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
github.com/prometheus/client_golang v1.11.0
github.com/prometheus/common v0.31.1 // indirect
github.com/prometheus/procfs v0.7.3 // indirect
golang.org/x/sync v0.1.0
google.golang.org/grpc v1.53.0
google.golang.org/protobuf v1.29.1
)
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.31.1 // indirect
github.com/prometheus/procfs v0.7.3 // indirect
golang.org/x/net v0.5.0 // indirect
golang.org/x/sys v0.4.0 // indirect
golang.org/x/text v0.6.0 // indirect
google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect
)
This diff is collapsed.
package integrationtest
package integrationtest
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"math/big"
"net"
"os"
"os/exec"
"path/filepath"
"testing"
"time"
replds "git.autistici.org/ai3/tools/replds2"
"git.autistici.org/ai3/tools/replds2/acl"
"git.autistici.org/ai3/tools/replds2/common"
pb "git.autistici.org/ai3/tools/replds2/proto"
"git.autistici.org/ai3/tools/replds2/store/memlog"
"git.autistici.org/ai3/tools/replds2/watcher"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/protobuf/types/known/timestamppb"
)
func writeSSLCert(t *testing.T, der []byte, path string) {
f, err := os.Create(path)
if err != nil {
t.Fatal(err)
}
defer f.Close()
pem.Encode(f, &pem.Block{
Type: "CERTIFICATE",
Bytes: der,
})
}
func writeSSLKey(t *testing.T, key *rsa.PrivateKey, path string) {
f, err := os.Create(path)
if err != nil {
t.Fatal(err)
}
defer f.Close()
pem.Encode(f, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
}
func createTestCA(t *testing.T, dir string) {
t.Logf("initializing test TLS credentials")
ca := &x509.Certificate{
SerialNumber: big.NewInt(2023),
Subject: pkix.Name{
Organization: []string{"Test"},
Country: []string{"XX"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
caDER, err := x509.CreateCertificate(rand.Reader, ca, ca, &caKey.PublicKey, caKey)
if err != nil {
t.Fatal(err)
}
writeSSLCert(t, caDER, filepath.Join(dir, "ca.pem"))
for idx, name := range []string{"server", "client1", "client2"} {
cert := &x509.Certificate{
SerialNumber: big.NewInt(int64(idx + 1)),
Subject: pkix.Name{
CommonName: name,
},
DNSNames: []string{name},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
der, err := x509.CreateCertificate(rand.Reader, cert, ca, &priv.PublicKey, caKey)
if err != nil {
t.Fatal(err)
}
writeSSLCert(t, der, filepath.Join(dir, name+".pem"))
writeSSLKey(t, priv, filepath.Join(dir, name+".key"))
}
}
func loadCA(path string) (*x509.CertPool, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
cas := x509.NewCertPool()
if !cas.AppendCertsFromPEM(data) {
return nil, fmt.Errorf("no certificates could be parsed in %s", path)
}
return cas, nil
}
func serverTLSConfig(sslCert, sslKey, sslCA string) *tls.Config {
cert, err := tls.LoadX509KeyPair(sslCert, sslKey)
if err != nil {
panic(fmt.Sprintf("load x509: %s: %v", sslCert, err))
}
ca, err := loadCA(sslCA)
if err != nil {
panic(fmt.Sprintf("load x509 ca: %s: %v", sslCA, err))
}
return &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: ca,
}
}
func clientTLSConfig(sslCert, sslKey, sslCA string) *tls.Config {
cert, err := tls.LoadX509KeyPair(sslCert, sslKey)
if err != nil {
panic(fmt.Sprintf("load x509: %s: %v", sslCert, err))
}
ca, err := loadCA(sslCA)
if err != nil {
panic(fmt.Sprintf("load x509 ca: %s: %v", sslCA, err))
}
return &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
RootCAs: ca,
}
}
func runTestServer(ctx context.Context, l net.Listener, memlogDir, sslDir string, acls replds.ACLMgr) error {
store, err := memlog.NewMemFileStore(memlogDir)
if err != nil {
return fmt.Errorf("NewMemFileStore(): %w", err)
}
defer store.Close()
cluster, err := replds.NewGRPCCluster(nil)
if err != nil {
return fmt.Errorf("NewGRPCCluster(): %w", err)
}
server := replds.NewServer(ctx, cluster, store, acls)
grpcSrv := grpc.NewServer(
grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(common.TLSAuthFunc)),
grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(common.TLSAuthFunc)),
grpc.Creds(credentials.NewTLS(serverTLSConfig(
filepath.Join(sslDir, "server.pem"),
filepath.Join(sslDir, "server.key"),
filepath.Join(sslDir, "ca.pem"),
))),
)
pb.RegisterInternalReplServer(grpcSrv, server)
pb.RegisterRepldsServer(grpcSrv, server)
go func() {
<-ctx.Done()
grpcSrv.GracefulStop()
}()
return grpcSrv.Serve(l)
}
func runPuller(ctx context.Context, grpcURL, pullDir, triggersDir, sslDir, certName string) error {
triggers, err := watcher.LoadTriggersFromDir(triggersDir)
if err != nil {
return fmt.Errorf("LoadTriggersFromDir(): %w", err)
}
// Create a file-backed store for the local target path.
store, err := memlog.NewMemFileStore(pullDir)
if err != nil {
return fmt.Errorf("NewMemFileStore(): %w", err)
}
defer store.Close()
conn, err := grpc.Dial(
grpcURL,
grpc.WithTransportCredentials(credentials.NewTLS(clientTLSConfig(
filepath.Join(sslDir, certName+".pem"),
filepath.Join(sslDir, certName+".key"),
filepath.Join(sslDir, "ca.pem"),
))),
)
if err != nil {
return fmt.Errorf("grpc.Dial(): %w", err)
}
defer conn.Close()
prefix := "/certs"
w := watcher.New(conn, store, prefix, triggers)
w.Run(ctx)
return nil
}
func uploadSomeData(grpcURL, sslDir, certName string) error {
conn, err := grpc.Dial(
grpcURL,
grpc.WithTransportCredentials(credentials.NewTLS(clientTLSConfig(
filepath.Join(sslDir, certName+".pem"),
filepath.Join(sslDir, certName+".key"),
filepath.Join(sslDir, "ca.pem"),
))),
)
if err != nil {
return fmt.Errorf("grpc.Dial: %w", err)
}
defer conn.Close()
client := pb.NewRepldsClient(conn)
req := pb.StoreRequest{
Nodes: []*pb.Node{
&pb.Node{
Path: "/certs/somedata",
Data: []byte("some data"),
Version: 1,
Timestamp: timestamppb.New(time.Now()),
},
&pb.Node{
Path: "/certs/somedata2",
Data: []byte("some more data"),
Version: 1,
Timestamp: timestamppb.New(time.Now()),
},
},
}
_, err = client.Store(context.Background(), &req)
if err != nil {
return fmt.Errorf("client.Store: %w", err)
}
return nil
}
var errStop = errors.New("stop")
func TestIntegration_Pull(t *testing.T) {
dir, err := os.MkdirTemp("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
sslDir := filepath.Join(dir, "ssl")
memlogDir := filepath.Join(dir, "memlog")
pullDir := filepath.Join(dir, "pull")
triggersDir := filepath.Join(dir, "triggers")
for _, s := range []string{sslDir, memlogDir, pullDir, triggersDir} {
os.Mkdir(s, 0700)
}
// Set up ACLs.
os.WriteFile(
filepath.Join(dir, "acls"),
[]byte(`
client1 /certs READ
client2 /certs WRITE
`),
0600,
)
// Set up a test trigger.
os.WriteFile(
filepath.Join(triggersDir, "test-trigger"),
[]byte(fmt.Sprintf(
`{"path": "/certs/somedata", "command": "touch '%s/trigger-stamp'"}`, dir)),
0700,
)
createTestCA(t, sslDir)
g, ctx := errgroup.WithContext(context.Background())
// Listener for the grpc server.
l, err := net.Listen("tcp", "127.0.0.1:")
if err != nil {
t.Fatalf("listen(): %v", err)
}
grpcURL := l.Addr().String()
acls, err := acl.Load(filepath.Join(dir, "acls"))
if err != nil {
t.Fatalf("acl.Load(): %v", err)
}
g.Go(func() error {
t.Logf("starting GRPC server on %s", grpcURL)
defer t.Logf("GRPC server stopped")
return runTestServer(ctx, l, memlogDir, sslDir, acls)
})
g.Go(func() error {
t.Logf("starting puller")
defer t.Logf("puller stopped")
return runPuller(ctx, grpcURL, pullDir, triggersDir, sslDir, "client1")
})
g.Go(func() error {
t.Logf("uploading some data")
if err := uploadSomeData(grpcURL, sslDir, "client2"); err != nil {
return err
}
// Also check that client1 can't write due to ACL.
if err := uploadSomeData(grpcURL, sslDir, "client1"); err == nil {
return errors.New("ACL failure: client1 was able to write")
}
time.Sleep(2 * time.Second)
t.Logf("stopping everything")
return errStop
})
if err := g.Wait(); err != nil && !errors.Is(err, errStop) {
t.Fatal(err)
}
// Verify that data has been propagated by the puller.
if _, err := os.Stat(filepath.Join(pullDir, "certs/somedata")); err != nil {
out, _ := exec.Command("/bin/ls", "-lR", pullDir).CombinedOutput()
t.Logf("pullDir contents:\n%s", out)
t.Fatal("'somedata' does not exist in pullDir")
}
// Verify that the trigger has been called.
if _, err := os.Stat(filepath.Join(dir, "trigger-stamp")); err != nil {
t.Fatal("trigger did not run")
}
}
......@@ -28,7 +28,7 @@ type Store interface {
// ACLMgr is an ACL manager that can check access rules.
type ACLMgr interface {
Check(string, acl.ACLOp, string) bool
Check(string, acl.Op, string) bool
}
// ClusterMgr holds state about the service cluster network layout.
......
......@@ -52,7 +52,7 @@ func TestLog(t *testing.T) {
for i := 0; i < numEntries; i++ {
node := &pb.Node{
Path: fmt.Sprintf("/path/to/%08d", (i % maxSz)),
Version: int64(time.Now().Unix()),
Version: time.Now().Unix(),
Data: []byte("just some random data, blah blah, blah blah!"),
}
if err := l.Write(node); err != nil {
......@@ -95,7 +95,7 @@ func TestLog(t *testing.T) {
log.Printf(">>> writing new data, closing and re-opening log...")
err = l.Write(&pb.Node{
Path: testPath,
Version: int64(time.Now().Unix()),
Version: time.Now().Unix(),
Data: []byte("new data"),
})
if err != nil {
......
......@@ -5,6 +5,7 @@ import (
"io/ioutil"
"log"
"os"
"os/exec"
"path/filepath"
"strings"
......@@ -37,12 +38,34 @@ func (m scriptTriggerManager) Notify(b *common.NotifyBatch) {
}
type scriptTrigger struct {
name string
Path string `json:"path"`
Command string `json:"command"`
}
func (t *scriptTrigger) Run(nodes []*pb.Node) error {
return nil
// Build an environment for the script.
changedEnvStr := "REPLDS_CHANGES="
for _, node := range nodes {
if !node.Deleted {
changedEnvStr += node.Path
}
}
env := os.Environ()
env = append(env, "REPLDS=1")
env = append(env, changedEnvStr)
// Run the command using the shell.
cmd := exec.Command("/bin/sh", "-c", t.Command)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stdout
cmd.Env = env
log.Printf("executing trigger '%s'", t.name)
return cmd.Run()
}
func LoadTriggersFromDir(dir string) (TriggerManager, error) {
......@@ -70,6 +93,7 @@ func LoadTriggersFromDir(dir string) (TriggerManager, error) {
continue
}
var trig scriptTrigger
trig.name = f.Name()
if err := json.Unmarshal(data, &trig); err != nil {
log.Printf("invalid JSON in %s: %v", f.Name(), err)
continue
......
......@@ -82,7 +82,7 @@ func (w *Watcher) Run(ctx context.Context) {
for {
resp, err := stream.Recv()
if errors.Is(err, io.EOF) {
if errors.Is(err, io.EOF) || status.Code(err) == codes.Canceled {
break
}
if err != nil {
......