diff --git a/dovecot/dict.go b/dovecot/dict.go index 320b9e316d7929966f2321cac49d7d4605b64ee7..4b6a508f6a61fcb5f0e69354e198942d232994bf 100644 --- a/dovecot/dict.go +++ b/dovecot/dict.go @@ -72,6 +72,12 @@ func (p *DictProxyServer) handleHello(ctx context.Context, lw unix.LineResponseW return fmt.Errorf("unsupported protocol version %d", majorVersion) } + // In version 3, HELLO expects a response. + if majorVersion == 3 { + s := fmt.Sprintf("O%d\t0", majorVersion) + return lw.WriteLineCRLF([]byte(s)) + } + return nil } diff --git a/dovecot/dict_test.go b/dovecot/dict_test.go index e9800ae4f6344f8774c2136c4890474d16f73f94..8f03d12007ebb15c0fc2745e8470a0e4715d1fd8 100644 --- a/dovecot/dict_test.go +++ b/dovecot/dict_test.go @@ -1,7 +1,14 @@ package dovecot import ( + "context" + "log" + "net/textproto" + "os" + "path/filepath" "testing" + + "git.autistici.org/ai3/go-common/unix" ) func TestUnescape(t *testing.T) { @@ -21,3 +28,97 @@ func TestUnescape(t *testing.T) { } } } + +type testDictDatabase struct { + data map[string]interface{} +} + +func (d *testDictDatabase) Lookup(_ context.Context, key string) (interface{}, bool, error) { + value, ok := d.data[key] + return value, ok, nil +} + +func newTestDictDatabase() *testDictDatabase { + return &testDictDatabase{ + data: map[string]interface{}{ + "pass/foo": struct { + Password string `json:"password"` + }{ + Password: "bar", + }, + }, + } +} + +func newTestSocketServer(t *testing.T) (string, func()) { + t.Helper() + + dir, err := os.MkdirTemp("", "") + if err != nil { + t.Fatal(err) + } + + socketPath := filepath.Join(dir, "socket") + + lsrv := unix.NewLineServer( + NewDictProxyServer(newTestDictDatabase())) + srv, err := unix.NewUNIXSocketServer(socketPath, lsrv) + if err != nil { + t.Fatal(err) + } + + errCh := make(chan error, 1) + go func() { + errCh <- srv.Serve() + }() + + return socketPath, func() { + srv.Close() + os.RemoveAll(dir) + if err := <-errCh; err != nil { + log.Printf("unix socket server error: %v", err) + } + } +} + +func makeTestRequest(t *testing.T, socketPath string, version int, lookupKey string, expectedStatus byte) { + conn, err := textproto.Dial("unix", socketPath) + if err != nil { + t.Fatalf("dialing socket: %v", err) + } + defer conn.Close() + + if err := conn.PrintfLine("H%d\t0\t1\tuser\tpass", version); err != nil { + t.Fatalf("error writing HELLO: %v", err) + } + + if version > 2 { + resp, err := conn.ReadLine() + if err != nil { + t.Fatalf("error reading HELLO response: %v", err) + } + if resp[0] != 'O' { + t.Fatalf("request returned unexpected status: %s", resp) + } + } + + if err := conn.PrintfLine("L%s", lookupKey); err != nil { + t.Fatalf("error writing LOOKUP: %v", err) + } + resp, err := conn.ReadLine() + if err != nil { + t.Fatalf("error reading HELLO response: %v", err) + } + if resp[0] != expectedStatus { + t.Fatalf("request returned unexpected status: %s, expected %c", resp, expectedStatus) + } +} + +func TestLookup(t *testing.T) { + socketPath, cleanup := newTestSocketServer(t) + defer cleanup() + + makeTestRequest(t, socketPath, 2, "pass/foo", 'O') + makeTestRequest(t, socketPath, 3, "pass/foo", 'O') + makeTestRequest(t, socketPath, 3, "unknown", 'N') +}