From 835b1627737a4cc94d797450b2f5cec3666f15e9 Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Fri, 7 May 2021 14:11:57 +0100
Subject: [PATCH] Make LDAP "transactions" self-consistent

The data written will be available for reading even before calling
Commit on the transaction. This allows constructs such as loops that
repeatedly look up and modify the same attribute to work, which fixes
the behavior of deleteAllApplicationSpecificPasswords() which
previously would only succeed in deleting one of them.

Fixes issue ai3/prod#233.
---
 backend/ldap/tx.go                   | 40 ++++++++++++++++++++++------
 integrationtest/account_mgmt_test.go | 18 +++++++++++--
 integrationtest/testdata/test3.ldif  |  2 ++
 3 files changed, 50 insertions(+), 10 deletions(-)

diff --git a/backend/ldap/tx.go b/backend/ldap/tx.go
index 22945750..8399ee25 100644
--- a/backend/ldap/tx.go
+++ b/backend/ldap/tx.go
@@ -41,15 +41,28 @@ type ldapAttr struct {
 type ldapTX struct {
 	conn ldapConn
 
-	cache   map[string][]string
-	newDNs  map[string]struct{} // nolint (it's plural DN, not DNS)
+	// Read cache, containing data from the db. Only used to
+	// figure out if we need to issue a ModifyRequest or an
+	// AddRequest at commit time.
+	rcache map[string][]string
+
+	// Write cache, used to store modified attributes. Allows the
+	// transaction to be self-consistent: you can read the data
+	// that you've just written even before calling Commit().
+	wcache map[string][]string
+
+	// List of the new DNs that need to be created.
+	// nolint: it's plural DN, not DNS.
+	newDNs map[string]struct{}
+
 	changes []*ldapAttr
 }
 
 func newLDAPTX(conn ldapConn) *ldapTX {
 	return &ldapTX{
 		conn:   conn,
-		cache:  make(map[string][]string),
+		rcache: make(map[string][]string),
+		wcache: make(map[string][]string),
 		newDNs: make(map[string]struct{}),
 	}
 }
@@ -67,7 +80,7 @@ func (tx *ldapTX) search(ctx context.Context, req *ldap.SearchRequest) (*ldap.Se
 
 	for _, entry := range res.Entries {
 		for _, attr := range entry.Attributes {
-			tx.cache[cacheKey(entry.DN, attr.Name)] = attr.Values
+			tx.rcache[cacheKey(entry.DN, attr.Name)] = attr.Values
 		}
 	}
 
@@ -86,6 +99,10 @@ func (tx *ldapTX) setAttr(dn, attr string, values ...string) {
 	if dn == "" {
 		panic("empty dn in setAttr!")
 	}
+
+	// Set the value in the transaction write cache.
+	tx.wcache[cacheKey(dn, attr)] = values
+
 	// Reuse previous change, if any. Prevents value duplication
 	// in the same ModifyRequest.
 	found := false
@@ -133,6 +150,7 @@ func (tx *ldapTX) Commit(ctx context.Context) error {
 	// Cleanup
 	tx.changes = nil
 	tx.newDNs = make(map[string]struct{})
+	tx.wcache = make(map[string][]string)
 
 	return nil
 }
@@ -167,14 +185,13 @@ func (tx *ldapTX) aggregateChanges(ctx context.Context) (map[string]*ldap.AddReq
 }
 
 func (tx *ldapTX) updateModifyRequest(ctx context.Context, mr *ldap.ModifyRequest, attr *ldapAttr) {
-	old, ok := tx.cache[cacheKey(attr.dn, attr.attr)]
-
 	// Pessimistic approach: if we haven't seen this attribute
 	// before, try to fetch it from LDAP so we know if we need to
 	// perform an Add or a Replace.
+	old, ok := tx.rcache[cacheKey(attr.dn, attr.attr)]
 	if !ok {
 		log.Printf("tx: pessimistic fallback for %s %s", attr.dn, attr.attr)
-		oldFromLDAP := tx.readAttributeValues(ctx, attr.dn, attr.attr)
+		oldFromLDAP := tx.readAttributeValuesNoCache(ctx, attr.dn, attr.attr)
 		if len(oldFromLDAP) > 0 {
 			ok = true
 			old = oldFromLDAP
@@ -191,7 +208,7 @@ func (tx *ldapTX) updateModifyRequest(ctx context.Context, mr *ldap.ModifyReques
 	}
 }
 
-func (tx *ldapTX) readAttributeValues(ctx context.Context, dn, attr string) []string {
+func (tx *ldapTX) readAttributeValuesNoCache(ctx context.Context, dn, attr string) []string {
 	result, err := tx.search(ctx, ldap.NewSearchRequest(
 		dn,
 		ldap.ScopeBaseObject,
@@ -209,6 +226,13 @@ func (tx *ldapTX) readAttributeValues(ctx context.Context, dn, attr string) []st
 	return nil
 }
 
+func (tx *ldapTX) readAttributeValues(ctx context.Context, dn, attr string) []string {
+	if values, ok := tx.wcache[cacheKey(dn, attr)]; ok {
+		return values
+	}
+	return tx.readAttributeValuesNoCache(ctx, dn, attr)
+}
+
 func isEmptyModifyRequest(mr *ldap.ModifyRequest) bool {
 	return len(mr.Changes) == 0
 }
diff --git a/integrationtest/account_mgmt_test.go b/integrationtest/account_mgmt_test.go
index 0a395995..e1195030 100644
--- a/integrationtest/account_mgmt_test.go
+++ b/integrationtest/account_mgmt_test.go
@@ -159,6 +159,13 @@ func TestIntegration_AccountRecovery_WithEncryptionKeysAndCache(t *testing.T) {
 	}
 }
 
+func TestIntegration_AccountRecovery_ClearsAppSpecificPasswords(t *testing.T) {
+	user := runAccountRecoveryTest(t, "tre@investici.org", false, false)
+	if len(user.AppSpecificPasswords) > 0 {
+		t.Fatal("app-specific passwords were not cleared after account recovery")
+	}
+}
+
 func runAccountRecoveryTest(t *testing.T, username string, enableCache, enableOpportunisticEncryption bool) *as.RawUser {
 	cfg := as.Config{
 		EnableOpportunisticEncryption: enableOpportunisticEncryption,
@@ -247,7 +254,14 @@ func TestIntegration_AppSpecificPassword(t *testing.T) {
 	if err != nil {
 		t.Fatalf("GetUser error: %v", err)
 	}
-	if len(user.AppSpecificPasswords) == 0 {
-		t.Errorf("no ASPs were retrieved: %+v", user)
+	found := false
+	for _, asp := range user.AppSpecificPasswords {
+		if asp.Service == "service" {
+			found = true
+			break
+		}
+	}
+	if !found {
+		t.Errorf("could not find the ASPs that was just created: %+v", user)
 	}
 }
diff --git a/integrationtest/testdata/test3.ldif b/integrationtest/testdata/test3.ldif
index af2ad5a9..1dcd0c7f 100644
--- a/integrationtest/testdata/test3.ldif
+++ b/integrationtest/testdata/test3.ldif
@@ -20,6 +20,8 @@ uid: tre@investici.org
 uidNumber: 256799
 userPassword:: JGEyJDQkMzI3NjgkMSQwZDgyMzU1YjQ0Mzg0M2NmZDY4MjU1MzE4ZTVjYTdiZSRmNTQ0ODkxOTFiNWZlYzk2MDRlNWQ2ODZjMDQxZjJkNTFmOTgxOGY4ZTFmM2E4MDYzY2U3ZTEwMTE3OTc2OGI0
 totpSecret: ABCDEF
+appSpecificPassword: id1:email:encryptedpassword:comment
+appSpecificPassword: id2:jabber:encryptedpassword:comment
 status: active
 host: host2
 
-- 
GitLab