package backend

import (
	"context"
	"log"
	"strings"

	"gopkg.in/ldap.v2"
)

// ldapConn is the subset of LDAP client operations this package relies
// on. Declaring it as an interface lets tests substitute a stub for the
// real LDAP client.
type ldapConn interface {
	Search(ctx context.Context, req *ldap.SearchRequest) (*ldap.SearchResult, error)
	Add(ctx context.Context, req *ldap.AddRequest) error
	Modify(ctx context.Context, req *ldap.ModifyRequest) error
	Close()
}

// ldapAttr is a single buffered change: the full list of values that
// attribute attr of the object dn should hold once the transaction is
// committed.
type ldapAttr struct {
	dn     string
	attr   string
	values []string
}

// An LDAP "transaction" is really just a buffer of attribute changes,
// that are all executed at once at Commit() time.
//
// Unfortunately, in order to issue LDAP Modify requests properly
// (which have separate Add and Replace options, for one), we need to
// keep state about the observed data, so we cache the results of all
// Search operations and compare those with the new data at commit
// time. If you attempt to modify an object that you haven't Searched
// for previously in the same transaction, you're most likely to get a
// LDAP error in return. Which is fine because all our workflows are
// read/modify/update ones anyway.
//
// Since ordering of Modify requests is important in LDAP, this object
// will preserve the ordering of DNs and attributes when calling
// Commit().
//
type ldapTX struct {
	conn ldapConn

	cache   map[string][]string
45
	newDNs  map[string]struct{} // nolint (it's plural DN, not DNS)
46 47 48 49 50
	changes []ldapAttr
}

func newLDAPTX(conn ldapConn) *ldapTX {
	return &ldapTX{
51 52 53
		conn:   conn,
		cache:  make(map[string][]string),
		newDNs: make(map[string]struct{}),
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
	}
}

// cacheKey builds the key under which the values of attribute attr of
// object dn are stored in the transaction cache: the DN and the
// attribute name separated by a semicolon.
func cacheKey(dn, attr string) string {
	var key strings.Builder
	key.Grow(len(dn) + len(attr) + 1)
	key.WriteString(dn)
	key.WriteByte(';')
	key.WriteString(attr)
	return key.String()
}

// search runs req on the underlying connection and, on success, stores
// the values of every attribute of every returned entry in the
// transaction cache before handing the result back to the caller.
// Errors from the connection are returned unchanged and leave the
// cache untouched.
func (tx *ldapTX) search(ctx context.Context, req *ldap.SearchRequest) (*ldap.SearchResult, error) {
	result, err := tx.conn.Search(ctx, req)
	if err != nil {
		return nil, err
	}

	for i := range result.Entries {
		e := result.Entries[i]
		for j := range e.Attributes {
			a := e.Attributes[j]
			tx.cache[cacheKey(e.DN, a.Name)] = a.Values
		}
	}

	return result, nil
}

// create marks dn as an object that does not exist yet, so that the
// changes buffered for it are sent as a single Add request at commit
// time instead of a Modify request. Call it before setAttr() on the
// new DN.
func (tx *ldapTX) create(dn string) {
	var marker struct{}
	tx.newDNs[dn] = marker
}

// setAttr modifies a single attribute of an object. To delete an
// attribute, pass an empty list of values.
func (tx *ldapTX) setAttr(dn, attr string, values ...string) {
ale's avatar
ale committed
86 87 88
	if dn == "" {
		panic("empty dn in setAttr!")
	}
89 90 91 92 93 94 95
	tx.changes = append(tx.changes, ldapAttr{dn: dn, attr: attr, values: values})
}

// Commit the transaction, sending all changes to the LDAP server.
func (tx *ldapTX) Commit(ctx context.Context) error {
	// Iterate through the changes, and generate ModifyRequest
	// objects grouped by DN (while preserving the order of DNs).
96
	adds, mods, dns := tx.aggregateChanges(ctx)
97

98 99
	// Now issue all Modify or Add requests, one by one, in the
	// same order as we have seen them. Abort on the first error.
100
	for _, dn := range dns {
101 102 103 104 105 106 107 108 109 110 111 112 113 114
		var err error
		if ar, ok := adds[dn]; ok {
			if isEmptyAddRequest(ar) {
				continue
			}
			log.Printf("issuing AddRequest: %+v", ar)
			err = tx.conn.Add(ctx, ar)
		} else {
			mr := mods[dn]
			if isEmptyModifyRequest(mr) {
				continue
			}
			log.Printf("issuing ModifyRequest: %+v", mr)
			err = tx.conn.Modify(ctx, mr)
115
		}
116
		if err != nil {
117 118 119
			return err
		}
	}
120 121 122 123 124

	// Cleanup
	tx.changes = nil
	tx.newDNs = make(map[string]struct{})

125 126 127
	return nil
}

// aggregateChanges folds the buffered changes into one request per DN.
// DNs announced via create() get an AddRequest that collects every
// change with a non-empty value list; all other DNs get a ModifyRequest
// that is filled in by updateModifyRequest. The returned slice lists
// each DN once, in the order it first appeared among the changes.
func (tx *ldapTX) aggregateChanges(ctx context.Context) (map[string]*ldap.AddRequest, map[string]*ldap.ModifyRequest, []string) {
	var (
		order      []string
		addReqs    = make(map[string]*ldap.AddRequest)
		modifyReqs = make(map[string]*ldap.ModifyRequest)
	)

	for _, change := range tx.changes {
		dn := change.dn

		// Existing object: accumulate into a Modify request.
		if _, created := tx.newDNs[dn]; !created {
			req, seen := modifyReqs[dn]
			if !seen {
				req = ldap.NewModifyRequest(dn)
				modifyReqs[dn] = req
				order = append(order, dn)
			}
			tx.updateModifyRequest(ctx, req, change)
			continue
		}

		// New object: accumulate into an Add request, leaving out
		// attributes without values.
		req, seen := addReqs[dn]
		if !seen {
			req = ldap.NewAddRequest(dn)
			addReqs[dn] = req
			order = append(order, dn)
		}
		if len(change.values) != 0 {
			req.Attribute(change.attr, change.values)
		}
	}

	return addReqs, modifyReqs, order
}

func (tx *ldapTX) updateModifyRequest(ctx context.Context, mr *ldap.ModifyRequest, attr ldapAttr) {
158
	old, ok := tx.cache[cacheKey(attr.dn, attr.attr)]
159 160 161 162 163 164 165 166 167 168 169 170 171

	// Pessimistic approach: if we haven't seen this attribute
	// before, try to fetch it from LDAP so we know if we need to
	// perform an Add or a Replace.
	if !ok {
		log.Printf("tx: pessimistic fallback for %s %s", attr.dn, attr.attr)
		oldFromLDAP := tx.readAttributeValues(ctx, attr.dn, attr.attr)
		if len(oldFromLDAP) > 0 {
			ok = true
			old = oldFromLDAP
		}
	}

172 173 174 175
	switch {
	case ok && !stringListEquals(old, attr.values):
		mr.Replace(attr.attr, attr.values)
	case ok && attr.values == nil:
176
		mr.Delete(attr.attr, old)
177 178 179 180 181
	case !ok && len(attr.values) > 0:
		mr.Add(attr.attr, attr.values)
	}
}

// readAttributeValues fetches the current values of a single attribute
// of the object dn with a base-scope search that goes through
// tx.search (so the result also lands in the cache).
//
// It returns nil when the search fails or returns no entries. A search
// failure is logged rather than returned, because the caller treats
// "no values" as "attribute not present" on a best-effort basis.
func (tx *ldapTX) readAttributeValues(ctx context.Context, dn, attr string) []string {
	req := ldap.NewSearchRequest(
		dn,
		ldap.ScopeBaseObject,
		ldap.NeverDerefAliases,
		0,
		0,
		false,
		"(objectClass=*)",
		[]string{attr},
		nil,
	)

	result, err := tx.search(ctx, req)
	if err != nil {
		log.Printf("tx: could not read %s %s: %v", dn, attr, err)
		return nil
	}
	if len(result.Entries) == 0 {
		return nil
	}
	return result.Entries[0].GetAttributeValues(attr)
}

// isEmptyModifyRequest reports whether mr carries no add, delete or
// replace operations at all.
func isEmptyModifyRequest(mr *ldap.ModifyRequest) bool {
	pending := len(mr.AddAttributes) + len(mr.DeleteAttributes) + len(mr.ReplaceAttributes)
	return pending == 0
}

// isEmptyAddRequest reports whether ar has no attributes set.
func isEmptyAddRequest(ar *ldap.AddRequest) bool {
	attrs := ar.Attributes
	return len(attrs) == 0
}

// Unordered list comparison.
//
// stringListEquals reports whether a and b contain the same strings
// with the same multiplicities, regardless of order. Nil and empty
// slices compare equal.
func stringListEquals(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}

	// Count occurrences in a, then consume them with b. A plain set
	// membership test would accept ("x", "y") vs ("x", "x").
	counts := make(map[string]int, len(a))
	for _, s := range a {
		counts[s]++
	}
	for _, s := range b {
		if counts[s] == 0 {
			return false
		}
		counts[s]--
	}

	// The lengths match and every element of b consumed one element of
	// a, so nothing can be left over.
	return true
}