diff --git a/dovecot/dict.go b/dovecot/dict.go index 50835edd28283dac6ee062b25551095e98ddb6d1..4b6a508f6a61fcb5f0e69354e198942d232994bf 100644 --- a/dovecot/dict.go +++ b/dovecot/dict.go @@ -17,7 +17,10 @@ var ( noMatchResponse = []byte{'N', '\n'} ) -const supportedDictProtocolVersion = 2 +const ( + supportedDictProtocolVersionMin = 2 + supportedDictProtocolVersionMax = 3 +) // DictDatabase is an interface to a key/value store by way of the Lookup // method. @@ -29,7 +32,7 @@ type DictDatabase interface { } // DictProxyServer exposes a Database using the Dovecot dict proxy -// protocol (see https://wiki2.dovecot.org/AuthDatabase/Dict). +// protocol (see https://doc.dovecot.org/developer_manual/design/dict_protocol/). // // It implements the unix.LineHandler interface from the // ai3/go-common/unix package. @@ -59,21 +62,35 @@ func (p *DictProxyServer) ServeLine(ctx context.Context, lw unix.LineResponseWri } func (p *DictProxyServer) handleHello(ctx context.Context, lw unix.LineResponseWriter, arg []byte) error { - fields := bytes.Split(arg, []byte{'\t'}) + fields := splitFields(arg) if len(fields) < 1 { return errors.New("could not parse HELLO") } majorVersion, _ := strconv.Atoi(string(fields[0])) - if majorVersion != supportedDictProtocolVersion { + if majorVersion < supportedDictProtocolVersionMin || majorVersion > supportedDictProtocolVersionMax { return fmt.Errorf("unsupported protocol version %d", majorVersion) } + // In version 3, HELLO expects a response. + if majorVersion == 3 { + s := fmt.Sprintf("O%d\t0", majorVersion) + return lw.WriteLineCRLF([]byte(s)) + } + return nil } func (p *DictProxyServer) handleLookup(ctx context.Context, lw unix.LineResponseWriter, arg []byte) error { - obj, ok, err := p.db.Lookup(ctx, string(arg)) + // Support protocol versions 2 and 3 by looking for the \t + // field separator, which should not appear in the key in + // version 2 anyway. + fields := splitFields(arg) + if len(fields) < 1 { + return errors.New("could not parse LOOKUP") + } + + obj, ok, err := p.db.Lookup(ctx, string(fields[0])) if err != nil { log.Printf("error: %v", err) return lw.WriteLine(failResponse) @@ -91,3 +108,41 @@ func (p *DictProxyServer) handleLookup(ctx context.Context, lw unix.LineResponse //buf.Write([]byte{'\n'}) return lw.WriteLine(buf.Bytes()) } + +var dovecotEscapeChars = map[byte]byte{ + '0': 0, + '1': 1, + 't': '\t', + 'r': '\r', + 'l': '\n', +} + +var fieldSeparator = []byte{'\t'} + +func unescapeInplace(b []byte) []byte { + var esc bool + var j int + for i := 0; i < len(b); i++ { + c := b[i] + if esc { + if escC, ok := dovecotEscapeChars[c]; ok { + c = escC + } + esc = false + } else if c == '\001' { + esc = true + continue + } + b[j] = c + j++ + } + return b[:j] +} + +func splitFields(b []byte) [][]byte { + fields := bytes.Split(b, fieldSeparator) + for i := 0; i < len(fields); i++ { + fields[i] = unescapeInplace(fields[i]) + } + return fields +} diff --git a/dovecot/dict_test.go b/dovecot/dict_test.go new file mode 100644 index 0000000000000000000000000000000000000000..61446476a9f19ada7813e0f580347ebd819d6912 --- /dev/null +++ b/dovecot/dict_test.go @@ -0,0 +1,126 @@ +package dovecot + +import ( + "context" + "io/ioutil" + "log" + "net/textproto" + "os" + "path/filepath" + "testing" + + "git.autistici.org/ai3/go-common/unix" +) + +func TestUnescape(t *testing.T) { + for _, td := range []struct { + input, exp string + }{ + {"boo", "boo"}, + {"bo\001t", "bo\t"}, + {"bo\001t\001l", "bo\t\n"}, + {"bo\001t\0011", "bo\t\001"}, + } { + out := make([]byte, len(td.input)) + copy(out, td.input) + out = unescapeInplace(out) + if string(out) != td.exp { + t.Errorf("unescape('%s') returned '%s', expected '%s'", td.input, out, td.exp) + } + } +} + +type testDictDatabase struct { + data map[string]interface{} +} + +func (d *testDictDatabase) Lookup(_ context.Context, key string) (interface{}, bool, error) { + value, ok := d.data[key] + return value, ok, nil +} + +func newTestDictDatabase() *testDictDatabase { + return &testDictDatabase{ + data: map[string]interface{}{ + "pass/foo": struct { + Password string `json:"password"` + }{ + Password: "bar", + }, + }, + } +} + +func newTestSocketServer(t *testing.T) (string, func()) { + t.Helper() + + dir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + + socketPath := filepath.Join(dir, "socket") + + lsrv := unix.NewLineServer( + NewDictProxyServer(newTestDictDatabase())) + srv, err := unix.NewUNIXSocketServer(socketPath, lsrv) + if err != nil { + t.Fatal(err) + } + + errCh := make(chan error, 1) + go func() { + errCh <- srv.Serve() + }() + + return socketPath, func() { + srv.Close() + os.RemoveAll(dir) + if err := <-errCh; err != nil { + log.Printf("unix socket server error: %v", err) + } + } +} + +func makeTestRequest(t *testing.T, socketPath string, version int, lookupKey, expected string) { + conn, err := textproto.Dial("unix", socketPath) + if err != nil { + t.Fatalf("dialing socket: %v", err) + } + defer conn.Close() + + if err := conn.PrintfLine("H%d\t0\t1\tuser\tpass", version); err != nil { + t.Fatalf("error writing HELLO: %v", err) + } + + if version > 2 { + resp, err := conn.ReadLine() + if err != nil { + t.Fatalf("error reading HELLO response: %v", err) + } + if resp[0] != 'O' { + t.Fatalf("request returned unexpected status: %s", resp) + } + } + + if err := conn.PrintfLine("L%s", lookupKey); err != nil { + t.Fatalf("error writing LOOKUP: %v", err) + } + resp, err := conn.ReadLine() + if err != nil { + t.Fatalf("error reading HELLO response: %v", err) + } + if resp != expected { + t.Fatalf("unexpected response: got '%s', expected '%s'", resp, expected) + } +} + +func TestLookup(t *testing.T) { + socketPath, cleanup := newTestSocketServer(t) + defer cleanup() + + makeTestRequest(t, socketPath, 2, "pass/foo", "O{\"password\":\"bar\"}") + makeTestRequest(t, socketPath, 3, "pass/foo", "O{\"password\":\"bar\"}") + makeTestRequest(t, socketPath, 2, "unknown", "N") + makeTestRequest(t, socketPath, 3, "unknown", "N") +}