Commit ac2aa256 authored by ale's avatar ale
Browse files

Implement a transaction-like interface for the backend

This should make it easier to implement a SQL backend in the future if
necessary, even though LDAP knows no such thing as transactions.

As a result of a better low-level interface, reducing the boilerplate
LDAP code, the business logic in model.go should be quite more
readable.
parent b299bbd7
...@@ -14,6 +14,16 @@ import ( ...@@ -14,6 +14,16 @@ import (
// Backend user database interface. // Backend user database interface.
// //
// We are using a transactional interface even if the actual backend
// (LDAP) does not support atomic transactions, just so it is easy to
// add more backends in the future (like SQL).
type Backend interface {
NewTransaction() (TX, error)
}
// TX represents a single transaction with the backend and offers a
// high-level data management abstraction.
//
// All methods share similar semantics: Get methods will return nil if // All methods share similar semantics: Get methods will return nil if
// the requested object is not found, and only return an error in case // the requested object is not found, and only return an error in case
// of trouble reaching the backend itself. // of trouble reaching the backend itself.
...@@ -26,7 +36,9 @@ import ( ...@@ -26,7 +36,9 @@ import (
// We might add more sophisticated resource query methods later, as // We might add more sophisticated resource query methods later, as
// admin-level functionality. // admin-level functionality.
// //
type Backend interface { type TX interface {
Commit(context.Context) error
GetUser(context.Context, string) (*User, error) GetUser(context.Context, string) (*User, error)
GetResource(context.Context, string, string) (*Resource, error) GetResource(context.Context, string, string) (*Resource, error)
UpdateResource(context.Context, string, *Resource) error UpdateResource(context.Context, string, *Resource) error
...@@ -78,8 +90,8 @@ func newAccountServiceWithSSO(backend Backend, config *Config, ssoValidator sso. ...@@ -78,8 +90,8 @@ func newAccountServiceWithSSO(backend Backend, config *Config, ssoValidator sso.
validationConfig := config.validationConfig() validationConfig := config.validationConfig()
domainBackend := config.domainBackend() domainBackend := config.domainBackend()
s.dataValidators = map[string]ValidatorFunc{ s.dataValidators = map[string]ValidatorFunc{
ResourceTypeEmail: validHostedEmail(validationConfig, domainBackend, backend), ResourceTypeEmail: validHostedEmail(validationConfig, domainBackend, nil),
ResourceTypeMailingList: validHostedMailingList(validationConfig, domainBackend, backend), ResourceTypeMailingList: validHostedMailingList(validationConfig, domainBackend, nil),
} }
return s return s
...@@ -100,7 +112,7 @@ var ( ...@@ -100,7 +112,7 @@ var (
ErrResourceNotFound = errors.New("resource not found") ErrResourceNotFound = errors.New("resource not found")
) )
func (s *AccountService) authorizeAdmin(ctx context.Context, username, ssoToken string) (*User, error) { func (s *AccountService) authorizeAdmin(ctx context.Context, tx TX, username, ssoToken string) (*User, error) {
// Validate the SSO ticket. // Validate the SSO ticket.
tkt, err := s.validator.Validate(ssoToken, "", s.ssoService, s.ssoGroups) tkt, err := s.validator.Validate(ssoToken, "", s.ssoService, s.ssoGroups)
if err != nil { if err != nil {
...@@ -113,7 +125,7 @@ func (s *AccountService) authorizeAdmin(ctx context.Context, username, ssoToken ...@@ -113,7 +125,7 @@ func (s *AccountService) authorizeAdmin(ctx context.Context, username, ssoToken
return nil, newAuthError(ErrUnauthorized) return nil, newAuthError(ErrUnauthorized)
} }
user, err := s.backend.GetUser(ctx, username) user, err := tx.GetUser(ctx, username)
if err != nil { if err != nil {
return nil, newBackendError(err) return nil, newBackendError(err)
} }
...@@ -123,7 +135,7 @@ func (s *AccountService) authorizeAdmin(ctx context.Context, username, ssoToken ...@@ -123,7 +135,7 @@ func (s *AccountService) authorizeAdmin(ctx context.Context, username, ssoToken
return user, nil return user, nil
} }
func (s *AccountService) authorizeUser(ctx context.Context, username, ssoToken string) (*User, error) { func (s *AccountService) authorizeUser(ctx context.Context, tx TX, username, ssoToken string) (*User, error) {
// First, check that the username matches the SSO ticket // First, check that the username matches the SSO ticket
// username (or that the SSO ticket has admin permissions). // username (or that the SSO ticket has admin permissions).
tkt, err := s.validator.Validate(ssoToken, "", s.ssoService, s.ssoGroups) tkt, err := s.validator.Validate(ssoToken, "", s.ssoService, s.ssoGroups)
...@@ -137,7 +149,7 @@ func (s *AccountService) authorizeUser(ctx context.Context, username, ssoToken s ...@@ -137,7 +149,7 @@ func (s *AccountService) authorizeUser(ctx context.Context, username, ssoToken s
return nil, newAuthError(ErrUnauthorized) return nil, newAuthError(ErrUnauthorized)
} }
user, err := s.backend.GetUser(ctx, username) user, err := tx.GetUser(ctx, username)
if err != nil { if err != nil {
return nil, newBackendError(err) return nil, newBackendError(err)
} }
...@@ -150,9 +162,9 @@ func (s *AccountService) authorizeUser(ctx context.Context, username, ssoToken s ...@@ -150,9 +162,9 @@ func (s *AccountService) authorizeUser(ctx context.Context, username, ssoToken s
// Extended version of authorizeUser that also directly checks the // Extended version of authorizeUser that also directly checks the
// user password. Used for account-privileged operations related to // user password. Used for account-privileged operations related to
// credential manipulation. // credential manipulation.
func (s *AccountService) authorizeUserWithPassword(ctx context.Context, username, ssoToken, password string) (*User, error) { func (s *AccountService) authorizeUserWithPassword(ctx context.Context, tx TX, username, ssoToken, password string) (*User, error) {
// TODO: call out to the auth-server? // TODO: call out to the auth-server?
return s.authorizeUser(ctx, username, ssoToken) return s.authorizeUser(ctx, tx, username, ssoToken)
} }
// RequestBase contains parameters shared by all request types. // RequestBase contains parameters shared by all request types.
...@@ -173,14 +185,14 @@ type GetUserRequest struct { ...@@ -173,14 +185,14 @@ type GetUserRequest struct {
} }
// GetUser returns public information about a user. // GetUser returns public information about a user.
func (s *AccountService) GetUser(ctx context.Context, req *GetUserRequest) (*User, error) { func (s *AccountService) GetUser(ctx context.Context, tx TX, req *GetUserRequest) (*User, error) {
return s.authorizeUser(ctx, req.Username, req.SSO) return s.authorizeUser(ctx, tx, req.Username, req.SSO)
} }
// setResourceStatus sets the status of a single resource (shared // setResourceStatus sets the status of a single resource (shared
// logic between enable / disable resource methods). // logic between enable / disable resource methods).
func (s *AccountService) setResourceStatus(ctx context.Context, username, resourceID, status string) error { func (s *AccountService) setResourceStatus(ctx context.Context, tx TX, username, resourceID, status string) error {
r, err := s.backend.GetResource(ctx, username, resourceID) r, err := tx.GetResource(ctx, username, resourceID)
if err != nil { if err != nil {
return newBackendError(err) return newBackendError(err)
} }
...@@ -188,7 +200,7 @@ func (s *AccountService) setResourceStatus(ctx context.Context, username, resour ...@@ -188,7 +200,7 @@ func (s *AccountService) setResourceStatus(ctx context.Context, username, resour
return ErrResourceNotFound return ErrResourceNotFound
} }
r.Status = status r.Status = status
if err := s.backend.UpdateResource(ctx, username, r); err != nil { if err := tx.UpdateResource(ctx, username, r); err != nil {
return newBackendError(err) return newBackendError(err)
} }
return nil return nil
...@@ -200,11 +212,11 @@ type DisableResourceRequest struct { ...@@ -200,11 +212,11 @@ type DisableResourceRequest struct {
} }
// DisableResource disables a resource belonging to the user. // DisableResource disables a resource belonging to the user.
func (s *AccountService) DisableResource(ctx context.Context, req *DisableResourceRequest) error { func (s *AccountService) DisableResource(ctx context.Context, tx TX, req *DisableResourceRequest) error {
if _, err := s.authorizeUser(ctx, req.Username, req.SSO); err != nil { if _, err := s.authorizeUser(ctx, tx, req.Username, req.SSO); err != nil {
return err return err
} }
return s.setResourceStatus(ctx, req.Username, req.ResourceID, ResourceStatusInactive) return s.setResourceStatus(ctx, tx, req.Username, req.ResourceID, ResourceStatusInactive)
} }
type EnableResourceRequest struct { type EnableResourceRequest struct {
...@@ -213,11 +225,11 @@ type EnableResourceRequest struct { ...@@ -213,11 +225,11 @@ type EnableResourceRequest struct {
} }
// EnableResource enables a resource belonging to the user. // EnableResource enables a resource belonging to the user.
func (s *AccountService) EnableResource(ctx context.Context, req *EnableResourceRequest) error { func (s *AccountService) EnableResource(ctx context.Context, tx TX, req *EnableResourceRequest) error {
if _, err := s.authorizeUser(ctx, req.Username, req.SSO); err != nil { if _, err := s.authorizeUser(ctx, tx, req.Username, req.SSO); err != nil {
return err return err
} }
return s.setResourceStatus(ctx, req.Username, req.ResourceID, ResourceStatusActive) return s.setResourceStatus(ctx, tx, req.Username, req.ResourceID, ResourceStatusActive)
} }
type ChangeUserPasswordRequest struct { type ChangeUserPasswordRequest struct {
...@@ -234,8 +246,8 @@ func (r *ChangeUserPasswordRequest) Validate() error { ...@@ -234,8 +246,8 @@ func (r *ChangeUserPasswordRequest) Validate() error {
// ChangeUserPassword updates a user's password. It will also take // ChangeUserPassword updates a user's password. It will also take
// care of re-encrypting the user encryption key, if present. // care of re-encrypting the user encryption key, if present.
func (s *AccountService) ChangeUserPassword(ctx context.Context, req *ChangeUserPasswordRequest) error { func (s *AccountService) ChangeUserPassword(ctx context.Context, tx TX, req *ChangeUserPasswordRequest) error {
user, err := s.authorizeUserWithPassword(ctx, req.Username, req.SSO, req.CurPassword) user, err := s.authorizeUserWithPassword(ctx, tx, req.Username, req.SSO, req.CurPassword)
if err != nil { if err != nil {
return err return err
} }
...@@ -246,9 +258,9 @@ func (s *AccountService) ChangeUserPassword(ctx context.Context, req *ChangeUser ...@@ -246,9 +258,9 @@ func (s *AccountService) ChangeUserPassword(ctx context.Context, req *ChangeUser
// If the user does not yet have an encryption key, generate one now. // If the user does not yet have an encryption key, generate one now.
if !user.HasEncryptionKeys { if !user.HasEncryptionKeys {
err = s.initializeUserEncryptionKeys(ctx, user, req.CurPassword) err = s.initializeUserEncryptionKeys(ctx, tx, user, req.CurPassword)
} else { } else {
err = s.updateUserEncryptionKeys(ctx, user, req.CurPassword, req.Password, UserEncryptionKeyMainID) err = s.updateUserEncryptionKeys(ctx, tx, user, req.CurPassword, req.Password, UserEncryptionKeyMainID)
} }
if err != nil { if err != nil {
return err return err
...@@ -256,11 +268,11 @@ func (s *AccountService) ChangeUserPassword(ctx context.Context, req *ChangeUser ...@@ -256,11 +268,11 @@ func (s *AccountService) ChangeUserPassword(ctx context.Context, req *ChangeUser
// Set the encrypted password attribute on the user and email resources. // Set the encrypted password attribute on the user and email resources.
encPass := pwhash.Encrypt(req.Password) encPass := pwhash.Encrypt(req.Password)
if err := s.backend.SetUserPassword(ctx, user, encPass); err != nil { if err := tx.SetUserPassword(ctx, user, encPass); err != nil {
return newBackendError(err) return newBackendError(err)
} }
for _, r := range user.GetResourcesByType(ResourceTypeEmail) { for _, r := range user.GetResourcesByType(ResourceTypeEmail) {
if err := s.backend.SetResourcePassword(ctx, user.Name, r, encPass); err != nil { if err := tx.SetResourcePassword(ctx, user.Name, r, encPass); err != nil {
return newBackendError(err) return newBackendError(err)
} }
} }
...@@ -271,7 +283,7 @@ func (s *AccountService) ChangeUserPassword(ctx context.Context, req *ChangeUser ...@@ -271,7 +283,7 @@ func (s *AccountService) ChangeUserPassword(ctx context.Context, req *ChangeUser
// Initialize the user encryption key list, by creating a new "main" key // Initialize the user encryption key list, by creating a new "main" key
// encrypted with the given password (which must be the primary password for the // encrypted with the given password (which must be the primary password for the
// user). // user).
func (s *AccountService) initializeUserEncryptionKeys(ctx context.Context, user *User, curPassword string) error { func (s *AccountService) initializeUserEncryptionKeys(ctx context.Context, tx TX, user *User, curPassword string) error {
// Generate a new key pair. // Generate a new key pair.
pub, priv, err := userenckey.GenerateKey() pub, priv, err := userenckey.GenerateKey()
if err != nil { if err != nil {
...@@ -291,10 +303,10 @@ func (s *AccountService) initializeUserEncryptionKeys(ctx context.Context, user ...@@ -291,10 +303,10 @@ func (s *AccountService) initializeUserEncryptionKeys(ctx context.Context, user
} }
// Update the backend database. // Update the backend database.
if err := s.backend.SetUserEncryptionKeys(ctx, user, keys); err != nil { if err := tx.SetUserEncryptionKeys(ctx, user, keys); err != nil {
return newBackendError(err) return newBackendError(err)
} }
if err := s.backend.SetUserEncryptionPublicKey(ctx, user, pub); err != nil { if err := tx.SetUserEncryptionPublicKey(ctx, user, pub); err != nil {
return newBackendError(err) return newBackendError(err)
} }
...@@ -304,8 +316,8 @@ func (s *AccountService) initializeUserEncryptionKeys(ctx context.Context, user ...@@ -304,8 +316,8 @@ func (s *AccountService) initializeUserEncryptionKeys(ctx context.Context, user
// Re-encrypt the specified user encryption key with newPassword. For this // Re-encrypt the specified user encryption key with newPassword. For this
// operation to succeed, we must be able to decrypt one of the keys (not // operation to succeed, we must be able to decrypt one of the keys (not
// necessarily the same one) with curPassword. // necessarily the same one) with curPassword.
func (s *AccountService) updateUserEncryptionKeys(ctx context.Context, user *User, curPassword, newPassword, keyID string) error { func (s *AccountService) updateUserEncryptionKeys(ctx context.Context, tx TX, user *User, curPassword, newPassword, keyID string) error {
keys, err := s.backend.GetUserEncryptionKeys(ctx, user) keys, err := tx.GetUserEncryptionKeys(ctx, user)
if err != nil { if err != nil {
return newBackendError(err) return newBackendError(err)
} }
...@@ -313,7 +325,7 @@ func (s *AccountService) updateUserEncryptionKeys(ctx context.Context, user *Use ...@@ -313,7 +325,7 @@ func (s *AccountService) updateUserEncryptionKeys(ctx context.Context, user *Use
if err != nil { if err != nil {
return newRequestError(err) return newRequestError(err)
} }
if err := s.backend.SetUserEncryptionKeys(ctx, user, keys); err != nil { if err := tx.SetUserEncryptionKeys(ctx, user, keys); err != nil {
return newBackendError(err) return newBackendError(err)
} }
return nil return nil
...@@ -369,8 +381,8 @@ type CreateApplicationSpecificPasswordResponse struct { ...@@ -369,8 +381,8 @@ type CreateApplicationSpecificPasswordResponse struct {
// CreateApplicationSpecificPassword will generate a new // CreateApplicationSpecificPassword will generate a new
// application-specific password for the given service. // application-specific password for the given service.
func (s *AccountService) CreateApplicationSpecificPassword(ctx context.Context, req *CreateApplicationSpecificPasswordRequest) (*CreateApplicationSpecificPasswordResponse, error) { func (s *AccountService) CreateApplicationSpecificPassword(ctx context.Context, tx TX, req *CreateApplicationSpecificPasswordRequest) (*CreateApplicationSpecificPasswordResponse, error) {
user, err := s.authorizeUserWithPassword(ctx, req.Username, req.SSO, req.CurPassword) user, err := s.authorizeUserWithPassword(ctx, tx, req.Username, req.SSO, req.CurPassword)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -394,7 +406,7 @@ func (s *AccountService) CreateApplicationSpecificPassword(ctx context.Context, ...@@ -394,7 +406,7 @@ func (s *AccountService) CreateApplicationSpecificPassword(ctx context.Context,
} }
password := randomAppSpecificPassword() password := randomAppSpecificPassword()
encPass := pwhash.Encrypt(password) encPass := pwhash.Encrypt(password)
if err := s.backend.SetApplicationSpecificPassword(ctx, user, asp, encPass); err != nil { if err := tx.SetApplicationSpecificPassword(ctx, user, asp, encPass); err != nil {
return nil, newBackendError(err) return nil, newBackendError(err)
} }
...@@ -403,7 +415,7 @@ func (s *AccountService) CreateApplicationSpecificPassword(ctx context.Context, ...@@ -403,7 +415,7 @@ func (s *AccountService) CreateApplicationSpecificPassword(ctx context.Context,
// have an 'asp_' prefix, followed by the ASP ID. // have an 'asp_' prefix, followed by the ASP ID.
if user.HasEncryptionKeys { if user.HasEncryptionKeys {
keyID := "asp_" + asp.ID keyID := "asp_" + asp.ID
if err := s.updateUserEncryptionKeys(ctx, user, req.CurPassword, password, keyID); err != nil { if err := s.updateUserEncryptionKeys(ctx, tx, user, req.CurPassword, password, keyID); err != nil {
return nil, err return nil, err
} }
} }
...@@ -420,19 +432,19 @@ type DeleteApplicationSpecificPasswordRequest struct { ...@@ -420,19 +432,19 @@ type DeleteApplicationSpecificPasswordRequest struct {
// DeleteApplicationSpecificPassword destroys an application-specific // DeleteApplicationSpecificPassword destroys an application-specific
// password, identified by its unique ID. // password, identified by its unique ID.
func (s *AccountService) DeleteApplicationSpecificPassword(ctx context.Context, req *DeleteApplicationSpecificPasswordRequest) error { func (s *AccountService) DeleteApplicationSpecificPassword(ctx context.Context, tx TX, req *DeleteApplicationSpecificPasswordRequest) error {
user, err := s.authorizeUser(ctx, req.Username, req.SSO) user, err := s.authorizeUser(ctx, tx, req.Username, req.SSO)
if err != nil { if err != nil {
return err return err
} }
if err = s.backend.DeleteApplicationSpecificPassword(ctx, user, req.AspID); err != nil { if err = tx.DeleteApplicationSpecificPassword(ctx, user, req.AspID); err != nil {
return err return err
} }
// Delete the user encryption key associated with this // Delete the user encryption key associated with this
// password (we're going to find it via its ID). // password (we're going to find it via its ID).
keys, err := s.backend.GetUserEncryptionKeys(ctx, user) keys, err := tx.GetUserEncryptionKeys(ctx, user)
if err != nil { if err != nil {
return err return err
} }
...@@ -446,7 +458,7 @@ func (s *AccountService) DeleteApplicationSpecificPassword(ctx context.Context, ...@@ -446,7 +458,7 @@ func (s *AccountService) DeleteApplicationSpecificPassword(ctx context.Context,
newKeys = append(newKeys, k) newKeys = append(newKeys, k)
} }
} }
return s.backend.SetUserEncryptionKeys(ctx, user, newKeys) return tx.SetUserEncryptionKeys(ctx, user, newKeys)
} }
type ChangeResourcePasswordRequest struct { type ChangeResourcePasswordRequest struct {
...@@ -465,8 +477,8 @@ func (r *ChangeResourcePasswordRequest) Validate() error { ...@@ -465,8 +477,8 @@ func (r *ChangeResourcePasswordRequest) Validate() error {
// ChangeResourcePassword modifies the password associated with a // ChangeResourcePassword modifies the password associated with a
// specific resource. Resources that do not support this method should // specific resource. Resources that do not support this method should
// return an error from the backend. // return an error from the backend.
func (s *AccountService) ChangeResourcePassword(ctx context.Context, req *ChangeResourcePasswordRequest) error { func (s *AccountService) ChangeResourcePassword(ctx context.Context, tx TX, req *ChangeResourcePasswordRequest) error {
_, err := s.authorizeUser(ctx, req.Username, req.SSO) _, err := s.authorizeUser(ctx, tx, req.Username, req.SSO)
if err != nil { if err != nil {
return err return err
} }
...@@ -475,7 +487,7 @@ func (s *AccountService) ChangeResourcePassword(ctx context.Context, req *Change ...@@ -475,7 +487,7 @@ func (s *AccountService) ChangeResourcePassword(ctx context.Context, req *Change
return newRequestError(err) return newRequestError(err)
} }
r, err := s.backend.GetResource(ctx, req.Username, req.ResourceID) r, err := tx.GetResource(ctx, req.Username, req.ResourceID)
if err != nil { if err != nil {
return newBackendError(err) return newBackendError(err)
} }
...@@ -484,7 +496,7 @@ func (s *AccountService) ChangeResourcePassword(ctx context.Context, req *Change ...@@ -484,7 +496,7 @@ func (s *AccountService) ChangeResourcePassword(ctx context.Context, req *Change
} }
encPass := pwhash.Encrypt(req.Password) encPass := pwhash.Encrypt(req.Password)
if err := s.backend.SetResourcePassword(ctx, req.Username, r, encPass); err != nil { if err := tx.SetResourcePassword(ctx, req.Username, r, encPass); err != nil {
return newBackendError(err) return newBackendError(err)
} }
return nil return nil
...@@ -504,14 +516,14 @@ type MoveResourceResponse struct { ...@@ -504,14 +516,14 @@ type MoveResourceResponse struct {
// between shards. Resources that are part of a group are moved all at // between shards. Resources that are part of a group are moved all at
// once regardless of which individual ResourceID is provided as long // once regardless of which individual ResourceID is provided as long
// as it belongs to the group. // as it belongs to the group.
func (s *AccountService) MoveResource(ctx context.Context, req *MoveResourceRequest) (*MoveResourceResponse, error) { func (s *AccountService) MoveResource(ctx context.Context, tx TX, req *MoveResourceRequest) (*MoveResourceResponse, error) {
user, err := s.authorizeAdmin(ctx, req.Username, req.SSO) user, err := s.authorizeAdmin(ctx, tx, req.Username, req.SSO)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Collect all related resources, as they should all be moved at once. // Collect all related resources, as they should all be moved at once.
r, err := s.backend.GetResource(ctx, req.Username, req.ResourceID) r, err := tx.GetResource(ctx, req.Username, req.ResourceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -525,10 +537,10 @@ func (s *AccountService) MoveResource(ctx context.Context, req *MoveResourceRequ ...@@ -525,10 +537,10 @@ func (s *AccountService) MoveResource(ctx context.Context, req *MoveResourceRequ
var resp MoveResourceResponse var resp MoveResourceResponse
for _, r := range resources { for _, r := range resources {
r.Shard = req.Shard r.Shard = req.Shard
if err := s.backend.UpdateResource(ctx, req.Username, r); err != nil { if err := tx.UpdateResource(ctx, req.Username, r); err != nil {
return nil, err return nil, err
} }
resp.MovedIDs = append(resp.MovedIDs, r.ID) resp.MovedIDs = append(resp.MovedIDs, r.ID.String())
} }
return &resp, nil return &resp, nil
...@@ -556,8 +568,8 @@ type EnableOTPResponse struct { ...@@ -556,8 +568,8 @@ type EnableOTPResponse struct {
// (useful for UX that confirms that the user is able to login first), // (useful for UX that confirms that the user is able to login first),
// or it can let the server generate a new secret by passing an empty // or it can let the server generate a new secret by passing an empty
// totp_secret. // totp_secret.
func (s *AccountService) EnableOTP(ctx context.Context, req *EnableOTPRequest) (*EnableOTPResponse, error) { func (s *AccountService) EnableOTP(ctx context.Context, tx TX, req *EnableOTPRequest) (*EnableOTPResponse, error) {
user, err := s.authorizeUser(ctx, req.Username, req.SSO) user, err := s.authorizeUser(ctx, tx, req.Username, req.SSO)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -573,7 +585,7 @@ func (s *AccountService) EnableOTP(ctx context.Context, req *EnableOTPRequest) ( ...@@ -573,7 +585,7 @@ func (s *AccountService) EnableOTP(ctx context.Context, req *EnableOTPRequest) (
return nil, err return nil, err
} }
} }
if err := s.backend.SetUserTOTPSecret(ctx, user, req.TOTPSecret); err != nil { if err := tx.SetUserTOTPSecret(ctx, user, req.TOTPSecret); err != nil {
return nil, newBackendError(err) return nil, newBackendError(err)
} }
...@@ -587,14 +599,14 @@ type DisableOTPRequest struct { ...@@ -587,14 +599,14 @@ type DisableOTPRequest struct {
} }
// DisableOTP disables two-factor authentication for a user. // DisableOTP disables two-factor authentication for a user.
func (s *AccountService) DisableOTP(ctx context.Context, req *DisableOTPRequest) error { func (s *AccountService) DisableOTP(ctx context.Context, tx TX, req *DisableOTPRequest) error {
user, err := s.authorizeUser(ctx, req.Username, req.SSO) user, err := s.authorizeUser(ctx, tx, req.Username, req.SSO)
if err != nil { if err != nil {
return err return err
} }
// Delete the TOTP secret (if present). // Delete the TOTP secret (if present).
if err := s.backend.DeleteUserTOTPSecret(ctx, user); err != nil { if err := tx.DeleteUserTOTPSecret(ctx, user); err != nil {
return newBackendError(err) return newBackendError(err)
} }
return nil return nil
......
...@@ -22,13 +22,13 @@ type ldapConn interface { ...@@ -22,13 +22,13 @@ type ldapConn interface {
Close() Close()
} }
// LDAPBackend is the interface to an LDAP-backed user database. // backend is the interface to an LDAP-backed user database.
// //
// We keep a set of LDAP queries for each resource type, each having a // We keep a set of LDAP queries for each resource type, each having a
// "resource" query to return a specific resource belonging to a user, // "resource" query to return a specific resource belonging to a user,
// and a "presence" query that checks for existence of a resource for // and a "presence" query that checks for existence of a resource for
// all users. // all users.
type LDAPBackend struct { type backend struct {
conn ldapConn conn ldapConn
userQuery *queryConfig userQuery *queryConfig
userResourceQueries []*queryConfig userResourceQueries []*queryConfig
...@@ -36,18 +36,33 @@ type LDAPBackend struct { ...@@ -36,18 +36,33 @@ type LDAPBackend struct {
presenceQueries map[string]*queryConfig presenceQueries map[string]*queryConfig
} }
type backendTX struct {
*ldapTX
backend *backend
}
const ldapPoolSize = 20 const ldapPoolSize = 20
func (b *backend) NewTransaction() (accountserver.TX, error) {
return &backendTX{
ldapTX: newLDAPTX(b.conn),
backend: b,
}, nil
}
// NewLDAPBackend initializes an LDAPBackend object with the given LDAP // NewLDAPBackend initializes an LDAPBackend object with the given LDAP
// connection pool. // connection pool.
func NewLDAPBackend(uri, bindDN, bindPw, base string) (*LDAPBackend, error) { func NewLDAPBackend(uri, bindDN, bindPw, base string) (accountserver.Backend, error) {
pool, err := ldaputil.NewConnectionPool(uri, bindDN, bindPw, ldapPoolSize) pool, err := ldaputil.NewConnectionPool(uri, bindDN, bindPw, ldapPoolSize)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newLDAPBackendWithConn(pool, base)
}
return &LDAPBackend{ func newLDAPBackendWithConn(conn ldapConn, base string) (*backend, error) {
conn: pool, return &backend{
conn: conn,
userQuery: mustCompileQueryConfig(&queryConfig{ userQuery: mustCompileQueryConfig(&queryConfig{
Base: "uid=${user},ou=People," + base, Base: "uid=${user},ou=People," + base,
Scope: "base", Scope: "base",
...@@ -195,7 +210,7 @@ func b2s(b bool) string { ...@@ -195,7 +210,7 @@ func b2s(b bool) string {
func newResourceFromLDAP(entry *ldap.Entry, resourceType, nameAttr string) *accountserver.Resource { func newResourceFromLDAP(entry *ldap.Entry, resourceType, nameAttr string) *accountserver.Resource {