diff --git a/acl/acl.go b/acl/acl.go index f8628d216179516a8ef6cd492e593aa653b48928..b925b905899d2878fb8a2e7b0e1ac561ddcc1f8e 100644 --- a/acl/acl.go +++ b/acl/acl.go @@ -10,7 +10,7 @@ import ( "git.autistici.org/ai3/tools/replds2/common" ) -type ACLOp int +type Op int const ( OpRead = iota @@ -18,7 +18,7 @@ const ( OpPeer ) -func (op ACLOp) String() string { +func (op Op) String() string { switch op { case OpRead: return "READ" @@ -34,12 +34,12 @@ type NullManager struct{} func NewNullManager() *NullManager { return new(NullManager) } -func (m *NullManager) Check(identity string, op ACLOp, path string) bool { return true } +func (m *NullManager) Check(identity string, op Op, path string) bool { return true } type Entry struct { Identity string Path string - Op ACLOp + Op Op } type Manager struct { @@ -54,7 +54,7 @@ func NewManager(acls []Entry) *Manager { return m } -func (m *Manager) Check(identity string, op ACLOp, path string) bool { +func (m *Manager) Check(identity string, op Op, path string) bool { e := Entry{Identity: identity, Op: op} for _, p := range common.PathPrefixes(path) { e.Path = p @@ -141,7 +141,7 @@ func parseEntry(line string) (e Entry, err error) { return } -func parseOp(s string) (op ACLOp, err error) { +func parseOp(s string) (op Op, err error) { switch strings.ToLower(s) { case "read", "r": op = OpRead diff --git a/acl/acl_test.go b/acl/acl_test.go index 316a60cac2b5aa69da30495428845bbd163418bd..3944abc07e4d60e8155b3a85b8293a9e2fe77e39 100644 --- a/acl/acl_test.go +++ b/acl/acl_test.go @@ -12,7 +12,7 @@ func TestACL(t *testing.T) { testdata := []struct { identity, path string - op ACLOp + op Op expectedOk bool }{ {"user", "/home/user/foo", OpWrite, true}, diff --git a/integrationtest/doc.go b/integrationtest/doc.go new file mode 100644 index 0000000000000000000000000000000000000000..a74af46ae0ec65fad4de9167ce148ff871db5bf5 --- /dev/null +++ b/integrationtest/doc.go @@ -0,0 +1 @@ +package integrationtest diff --git a/integrationtest/integration_test.go b/integrationtest/integration_test.go new file mode 100644 index 0000000000000000000000000000000000000000..f2a917ea0f12d8d674e63ef3f31fd297728ab6e8 --- /dev/null +++ b/integrationtest/integration_test.go @@ -0,0 +1,366 @@ +package integrationtest + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "fmt" + "math/big" + "net" + "os" + "os/exec" + "path/filepath" + "testing" + "time" + + replds "git.autistici.org/ai3/tools/replds2" + "git.autistici.org/ai3/tools/replds2/acl" + "git.autistici.org/ai3/tools/replds2/common" + pb "git.autistici.org/ai3/tools/replds2/proto" + "git.autistici.org/ai3/tools/replds2/store/memlog" + "git.autistici.org/ai3/tools/replds2/watcher" + grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func writeSSLCert(t *testing.T, der []byte, path string) { + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + pem.Encode(f, &pem.Block{ + Type: "CERTIFICATE", + Bytes: der, + }) +} + +func writeSSLKey(t *testing.T, key *rsa.PrivateKey, path string) { + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + pem.Encode(f, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) +} + +func createTestCA(t *testing.T, dir string) { + t.Logf("initializing test TLS credentials") + + ca := &x509.Certificate{ + SerialNumber: big.NewInt(2023), + Subject: pkix.Name{ + Organization: []string{"Test"}, + Country: []string{"XX"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + + caDER, err := x509.CreateCertificate(rand.Reader, ca, ca, &caKey.PublicKey, caKey) + if err != nil { + t.Fatal(err) + } + + writeSSLCert(t, caDER, filepath.Join(dir, "ca.pem")) + + for idx, name := range []string{"server", "client1", "client2"} { + cert := &x509.Certificate{ + SerialNumber: big.NewInt(int64(idx + 1)), + Subject: pkix.Name{ + CommonName: name, + }, + DNSNames: []string{name}, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + + der, err := x509.CreateCertificate(rand.Reader, cert, ca, &priv.PublicKey, caKey) + if err != nil { + t.Fatal(err) + } + + writeSSLCert(t, der, filepath.Join(dir, name+".pem")) + writeSSLKey(t, priv, filepath.Join(dir, name+".key")) + } +} + +func loadCA(path string) (*x509.CertPool, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + cas := x509.NewCertPool() + if !cas.AppendCertsFromPEM(data) { + return nil, fmt.Errorf("no certificates could be parsed in %s", path) + } + return cas, nil +} + +func serverTLSConfig(sslCert, sslKey, sslCA string) *tls.Config { + cert, err := tls.LoadX509KeyPair(sslCert, sslKey) + if err != nil { + panic(fmt.Sprintf("load x509: %s: %v", sslCert, err)) + } + ca, err := loadCA(sslCA) + if err != nil { + panic(fmt.Sprintf("load x509 ca: %s: %v", sslCA, err)) + } + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + PreferServerCipherSuites: true, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: ca, + } +} + +func clientTLSConfig(sslCert, sslKey, sslCA string) *tls.Config { + cert, err := tls.LoadX509KeyPair(sslCert, sslKey) + if err != nil { + panic(fmt.Sprintf("load x509: %s: %v", sslCert, err)) + } + ca, err := loadCA(sslCA) + if err != nil { + panic(fmt.Sprintf("load x509 ca: %s: %v", sslCA, err)) + } + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + RootCAs: ca, + } +} + +func runTestServer(ctx context.Context, l net.Listener, memlogDir, sslDir string, acls replds.ACLMgr) error { + store, err := memlog.NewMemFileStore(memlogDir) + if err != nil { + return fmt.Errorf("NewMemFileStore(): %w", err) + } + defer store.Close() + + cluster, err := replds.NewGRPCCluster(nil) + if err != nil { + return fmt.Errorf("NewGRPCCluster(): %w", err) + } + + server := replds.NewServer(ctx, cluster, store, acls) + + grpcSrv := grpc.NewServer( + grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(common.TLSAuthFunc)), + grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(common.TLSAuthFunc)), + grpc.Creds(credentials.NewTLS(serverTLSConfig( + filepath.Join(sslDir, "server.pem"), + filepath.Join(sslDir, "server.key"), + filepath.Join(sslDir, "ca.pem"), + ))), + ) + pb.RegisterInternalReplServer(grpcSrv, server) + pb.RegisterRepldsServer(grpcSrv, server) + + go func() { + <-ctx.Done() + grpcSrv.GracefulStop() + }() + + return grpcSrv.Serve(l) +} + +func runPuller(ctx context.Context, grpcURL, pullDir, triggersDir, sslDir, certName string) error { + triggers, err := watcher.LoadTriggersFromDir(triggersDir) + if err != nil { + return fmt.Errorf("LoadTriggersFromDir(): %w", err) + } + + // Create a file-backed store for the local target path. + store, err := memlog.NewMemFileStore(pullDir) + if err != nil { + return fmt.Errorf("NewMemFileStore(): %w", err) + } + defer store.Close() + + conn, err := grpc.Dial( + grpcURL, + grpc.WithTransportCredentials(credentials.NewTLS(clientTLSConfig( + filepath.Join(sslDir, certName+".pem"), + filepath.Join(sslDir, certName+".key"), + filepath.Join(sslDir, "ca.pem"), + ))), + ) + if err != nil { + return fmt.Errorf("grpc.Dial(): %w", err) + } + defer conn.Close() + + prefix := "/certs" + w := watcher.New(conn, store, prefix, triggers) + + w.Run(ctx) + + return nil +} + +func uploadSomeData(grpcURL, sslDir, certName string) error { + conn, err := grpc.Dial( + grpcURL, + grpc.WithTransportCredentials(credentials.NewTLS(clientTLSConfig( + filepath.Join(sslDir, certName+".pem"), + filepath.Join(sslDir, certName+".key"), + filepath.Join(sslDir, "ca.pem"), + ))), + ) + if err != nil { + return fmt.Errorf("grpc.Dial: %w", err) + } + defer conn.Close() + + client := pb.NewRepldsClient(conn) + + req := pb.StoreRequest{ + Nodes: []*pb.Node{ + &pb.Node{ + Path: "/certs/somedata", + Data: []byte("some data"), + Version: 1, + Timestamp: timestamppb.New(time.Now()), + }, + &pb.Node{ + Path: "/certs/somedata2", + Data: []byte("some more data"), + Version: 1, + Timestamp: timestamppb.New(time.Now()), + }, + }, + } + + _, err = client.Store(context.Background(), &req) + if err != nil { + return fmt.Errorf("client.Store: %w", err) + } + return nil +} + +var errStop = errors.New("stop") + +func TestIntegration_Pull(t *testing.T) { + dir, err := os.MkdirTemp("", "") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + sslDir := filepath.Join(dir, "ssl") + memlogDir := filepath.Join(dir, "memlog") + pullDir := filepath.Join(dir, "pull") + triggersDir := filepath.Join(dir, "triggers") + for _, s := range []string{sslDir, memlogDir, pullDir, triggersDir} { + os.Mkdir(s, 0700) + } + + // Set up ACLs. + os.WriteFile( + filepath.Join(dir, "acls"), + []byte(` +client1 /certs READ +client2 /certs WRITE +`), + 0600, + ) + + // Set up a test trigger. + os.WriteFile( + filepath.Join(triggersDir, "test-trigger"), + []byte(fmt.Sprintf( + `{"path": "/certs/somedata", "command": "touch '%s/trigger-stamp'"}`, dir)), + 0700, + ) + + createTestCA(t, sslDir) + + g, ctx := errgroup.WithContext(context.Background()) + + // Listener for the grpc server. + l, err := net.Listen("tcp", "127.0.0.1:") + if err != nil { + t.Fatalf("listen(): %v", err) + } + grpcURL := l.Addr().String() + + acls, err := acl.Load(filepath.Join(dir, "acls")) + if err != nil { + t.Fatalf("acl.Load(): %v", err) + } + + g.Go(func() error { + t.Logf("starting GRPC server on %s", grpcURL) + defer t.Logf("GRPC server stopped") + return runTestServer(ctx, l, memlogDir, sslDir, acls) + }) + + g.Go(func() error { + t.Logf("starting puller") + defer t.Logf("puller stopped") + return runPuller(ctx, grpcURL, pullDir, triggersDir, sslDir, "client1") + }) + + g.Go(func() error { + t.Logf("uploading some data") + if err := uploadSomeData(grpcURL, sslDir, "client2"); err != nil { + return err + } + + // Also check that client1 can't write due to ACL. + if err := uploadSomeData(grpcURL, sslDir, "client1"); err == nil { + return errors.New("ACL failure: client1 was able to write") + } + + time.Sleep(2 * time.Second) + t.Logf("stopping everything") + return errStop + }) + + if err := g.Wait(); err != nil && !errors.Is(err, errStop) { + t.Fatal(err) + } + + // Verify that data has been propagated by the puller. + if _, err := os.Stat(filepath.Join(pullDir, "certs/somedata")); err != nil { + out, _ := exec.Command("/bin/ls", "-lR", pullDir).CombinedOutput() + t.Logf("pullDir contents:\n%s", out) + t.Fatal("'somedata' does not exist in pullDir") + } + + // Verify that the trigger has been called. + if _, err := os.Stat(filepath.Join(dir, "trigger-stamp")); err != nil { + t.Fatal("trigger did not run") + } +} diff --git a/server.go b/server.go index 5bc912d793248e23a04848005d040f55e5e4bfaf..b096b571c14721f779b1eb19ef888cef35341474 100644 --- a/server.go +++ b/server.go @@ -28,7 +28,7 @@ type Store interface { // ACLMgr is an ACL manager that can check access rules. type ACLMgr interface { - Check(string, acl.ACLOp, string) bool + Check(string, acl.Op, string) bool } // ClusterMgr holds state about the service cluster network layout. diff --git a/watcher/triggers.go b/watcher/triggers.go index ae54875989e3fe6700b738dd24f0288b712c1aac..f100a6ca6d2c060deec7f94afcd961ca7545202e 100644 --- a/watcher/triggers.go +++ b/watcher/triggers.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "log" "os" + "os/exec" "path/filepath" "strings" @@ -37,12 +38,34 @@ func (m scriptTriggerManager) Notify(b *common.NotifyBatch) { } type scriptTrigger struct { + name string + Path string `json:"path"` Command string `json:"command"` } func (t *scriptTrigger) Run(nodes []*pb.Node) error { - return nil + // Build an environment for the script. + changedEnvStr := "REPLDS_CHANGES=" + for _, node := range nodes { + if !node.Deleted { + changedEnvStr += node.Path + } + } + + env := os.Environ() + env = append(env, "REPLDS=1") + env = append(env, changedEnvStr) + + // Run the command using the shell. + cmd := exec.Command("/bin/sh", "-c", t.Command) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stdout + cmd.Env = env + + log.Printf("executing trigger '%s'", t.name) + + return cmd.Run() } func LoadTriggersFromDir(dir string) (TriggerManager, error) { @@ -70,6 +93,7 @@ func LoadTriggersFromDir(dir string) (TriggerManager, error) { continue } var trig scriptTrigger + trig.name = f.Name() if err := json.Unmarshal(data, &trig); err != nil { log.Printf("invalid JSON in %s: %v", f.Name(), err) continue diff --git a/watcher/watcher.go b/watcher/watcher.go index 9cf5cba1fa4f88b85629eef680e216dda5c0a37e..3ca7299f7bec7e6c4914f7e130ab4a5bc4f82641 100644 --- a/watcher/watcher.go +++ b/watcher/watcher.go @@ -82,7 +82,7 @@ func (w *Watcher) Run(ctx context.Context) { for { resp, err := stream.Recv() - if errors.Is(err, io.EOF) { + if errors.Is(err, io.EOF) || status.Code(err) == codes.Canceled { break } if err != nil {