diff --git a/actions.go b/actions.go index 3dfc65e7827381e89f51eb315c5d98b25385f721..9bb16400ca5eb18fce33eb39207d5fa2abfb98f9 100644 --- a/actions.go +++ b/actions.go @@ -143,7 +143,7 @@ type ChangeUserPasswordRequest struct { // Validate the request. func (r *ChangeUserPasswordRequest) Validate(ctx context.Context, s *AccountService) error { - return s.passwordValidator(ctx, r.Password) + return s.fieldValidators.password(ctx, r.Password) } // ChangeUserPassword updates a user's password. It will also take @@ -169,7 +169,7 @@ type PasswordRecoveryRequest struct { // Validate the request. func (r *PasswordRecoveryRequest) Validate(ctx context.Context, s *AccountService) error { - return s.passwordValidator(ctx, r.Password) + return s.fieldValidators.password(ctx, r.Password) } // RecoverPassword lets users reset their password by providing @@ -198,7 +198,7 @@ type ResetPasswordRequest struct { // Validate the request. func (r *ResetPasswordRequest) Validate(ctx context.Context, s *AccountService) error { - return s.passwordValidator(ctx, r.Password) + return s.fieldValidators.password(ctx, r.Password) } // ResetPassword is an admin operation to forcefully reset the @@ -231,7 +231,7 @@ type SetPasswordRecoveryHintRequest struct { // Validate the request. func (r *SetPasswordRecoveryHintRequest) Validate(ctx context.Context, s *AccountService) error { - return s.passwordValidator(ctx, r.Response) + return s.fieldValidators.password(ctx, r.Response) } // SetPasswordRecoveryHint lets users set the password recovery hint @@ -583,7 +583,7 @@ func (r *AddEmailAliasRequest) Validate(ctx context.Context, s *AccountService) if r.ResourceID.Type() != ResourceTypeEmail { return errors.New("this operation only works on email resources") } - return s.emailValidator(ctx, r.Addr) + return s.fieldValidators.email(ctx, r.Addr) } const maxEmailAliases = 5 diff --git a/actions_test.go b/actions_test.go index a3338455e0ab06a7ded8c5fb1d68a5884233680d..5adf9f08aeadaac43c4bdae61a0fc10c5d5cabec 100644 --- a/actions_test.go +++ b/actions_test.go @@ -32,6 +32,7 @@ func (b *fakeBackend) GetResource(_ context.Context, resourceID ResourceID) (*Re } func (b *fakeBackend) UpdateResource(_ context.Context, r *Resource) error { + b.resources[r.ID.User()][r.ID.String()] = r return nil } @@ -80,6 +81,15 @@ func (b *fakeBackend) DeleteUserTOTPSecret(_ context.Context, user *User) error } func (b *fakeBackend) HasAnyResource(_ context.Context, rsrcs []FindResourceRequest) (bool, error) { + for _, fr := range rsrcs { + for _, ur := range b.resources { + for _, r := range ur { + if r.ID.Type() == fr.Type && r.ID.Name() == fr.Name { + return true, nil + } + } + } + } return false, nil } @@ -310,7 +320,7 @@ func TestService_AddEmailAlias(t *testing.T) { func TestService_Create(t *testing.T) { svc, tx := testService("") - _, err := svc.CreateResources(context.Background(), tx, &CreateResourcesRequest{ + req := &CreateResourcesRequest{ Resources: []*Resource{ &Resource{ ID: NewResourceID(ResourceTypeDomain, "testuser", "example2.com"), @@ -325,8 +335,17 @@ func TestService_Create(t *testing.T) { }, }, }, - }) + } + + // The request should succeed the first time around. + _, err := svc.CreateResources(context.Background(), tx, req) if err != nil { t.Fatal("CreateResources", err) } + + // The object already exists, so the same request should fail now. + _, err = svc.CreateResources(context.Background(), tx, req) + if err == nil { + t.Fatal("creating a duplicate resource did not fail") + } } diff --git a/config.go b/config.go index 0576517418ae460d7d32e241c43aa4162bb301ec..1cf2edf4f2f3989d937df71807dda507cb7c7b30 100644 --- a/config.go +++ b/config.go @@ -31,7 +31,7 @@ func (c *Config) domainBackend() domainBackend { return b } -func (c *Config) validationConfig(be Backend) (*validationConfig, error) { +func (c *Config) validationContext(be Backend) (*validationContext, error) { fu, err := newStringSetFromFileOrList(c.ForbiddenUsernames, c.ForbiddenUsernamesFile) if err != nil { return nil, err @@ -40,7 +40,7 @@ func (c *Config) validationConfig(be Backend) (*validationConfig, error) { if err != nil { return nil, err } - return &validationConfig{ + return &validationContext{ forbiddenUsernames: fu, forbiddenPasswords: fp, minPasswordLength: 6, diff --git a/service.go b/service.go index 8650ef8489d16258248830d22ed5864d93f969c4..c0098fcff64de5e894cfb863722ca1aebaae0e0b 100644 --- a/service.go +++ b/service.go @@ -70,9 +70,7 @@ type AccountService struct { audit auditLogger - passwordValidator ValidatorFunc - emailValidator ValidatorFunc - listValidator ValidatorFunc + fieldValidators *fieldValidators resourceValidator *resourceValidator } @@ -95,14 +93,12 @@ func newAccountServiceWithSSO(backend Backend, config *Config, ssoValidator sso. audit: &syslogAuditLogger{}, } - validationConfig, err := config.validationConfig(backend) + vc, err := config.validationContext(backend) if err != nil { return nil, err } - s.passwordValidator = validPassword(validationConfig) - s.emailValidator = validHostedEmail(validationConfig) - s.listValidator = validHostedMailingList(validationConfig) - s.resourceValidator = newResourceValidator(validationConfig) + s.fieldValidators = newFieldValidators(vc) + s.resourceValidator = newResourceValidator(vc) return s, nil } diff --git a/validators.go b/validators.go index efaf87cb2fe7155854dd6c6030cb339711d68d44..9e3bbfd374dc5138e39ef3807789bbb0efca1912 100644 --- a/validators.go +++ b/validators.go @@ -14,11 +14,13 @@ import ( // A domainBackend manages the list of domains users are allowed to request services on. type domainBackend interface { - GetAvailableDomains(context.Context, string) []string - IsAvailableDomain(context.Context, string, string) bool + GetAllowedDomains(context.Context, string) []string + IsAllowedDomain(context.Context, string, string) bool } -type validationConfig struct { +// The validationContext contains all configuration and backends that +// the various validation functions will need. +type validationContext struct { forbiddenUsernames stringSet forbiddenPasswords stringSet minPasswordLength int @@ -62,11 +64,11 @@ type staticDomainBackend struct { sets map[string]stringSet } -func (d *staticDomainBackend) GetAvailableDomains(_ context.Context, kind string) []string { +func (d *staticDomainBackend) GetAllowedDomains(_ context.Context, kind string) []string { return d.sets[kind].List() } -func (d *staticDomainBackend) IsAvailableDomain(_ context.Context, kind, domain string) bool { +func (d *staticDomainBackend) IsAllowedDomain(_ context.Context, kind, domain string) bool { return d.sets[kind].Contains(domain) } @@ -198,7 +200,7 @@ func relatedEmails(ctx context.Context, be domainBackend, addr string) []FindRes // Mailing lists must have unique names regardless of the domain, so we // add potential conflicts for mailing lists with the same name over all // list-enabled domains. - for _, d := range be.GetAvailableDomains(ctx, ResourceTypeMailingList) { + for _, d := range be.GetAllowedDomains(ctx, ResourceTypeMailingList) { rel = append(rel, FindResourceRequest{ Type: ResourceTypeMailingList, Name: fmt.Sprintf("%s@%s", user, d), @@ -207,59 +209,58 @@ func relatedEmails(ctx context.Context, be domainBackend, addr string) []FindRes return rel } -func splitSubsite(value string) (string, string) { - parts := strings.SplitN(value, "/", 2) - return parts[0], parts[1] -} - -func isSubsite(value string) bool { - return strings.Contains(value, "/") +func relatedWebsites(ctx context.Context, be domainBackend, value string) []FindResourceRequest { + // Ignore the parent domain (websites share a global namespace). + return []FindResourceRequest{ + { + Type: ResourceTypeWebsite, + Name: value, + }, + } } -func relatedWebsites(ctx context.Context, be domainBackend, value string) []FindResourceRequest { - var resourceIDs []FindResourceRequest - if isSubsite(value) { - _, path := splitSubsite(value) - for _, d := range be.GetAvailableDomains(ctx, ResourceTypeWebsite) { - resourceIDs = append(resourceIDs, FindResourceRequest{ - Type: ResourceTypeWebsite, - Name: fmt.Sprintf("%s/%s", d, path), - }) - } - } else { - resourceIDs = append(resourceIDs, FindResourceRequest{ +func relatedDomains(ctx context.Context, be domainBackend, value string) []FindResourceRequest { + return []FindResourceRequest{ + { Type: ResourceTypeDomain, Name: value, - }) + }, } - return resourceIDs } -func isAvailableEmailHostingDomain(config *validationConfig) ValidatorFunc { +func (v *validationContext) isAllowedDomain(rtype string) ValidatorFunc { return func(ctx context.Context, value string) error { - if !config.domains.IsAvailableDomain(ctx, ResourceTypeEmail, value) { + if !v.domains.IsAllowedDomain(ctx, rtype, value) { return errors.New("unavailable domain") } return nil } } -func isAvailableMailingListDomain(config *validationConfig) ValidatorFunc { +func (v *validationContext) isAvailableEmailAddr() ValidatorFunc { return func(ctx context.Context, value string) error { - if !config.domains.IsAvailableDomain(ctx, ResourceTypeMailingList, value) { - return errors.New("unavailable domain") + rel := relatedEmails(ctx, v.domains, value) + + // Run the presence check in a new transaction. Unavailability + // of the server results in a validation error (fail close). + tx, err := v.backend.NewTransaction() + if err != nil { + return err + } + if ok, _ := tx.HasAnyResource(ctx, rel); ok { + return errors.New("address unavailable") } return nil } } -func isAvailableEmailAddr(config *validationConfig) ValidatorFunc { +func (v *validationContext) isAvailableDomain() ValidatorFunc { return func(ctx context.Context, value string) error { - rel := relatedEmails(ctx, config.domains, value) + rel := relatedDomains(ctx, v.domains, value) // Run the presence check in a new transaction. Unavailability // of the server results in a validation error (fail close). - tx, err := config.backend.NewTransaction() + tx, err := v.backend.NewTransaction() if err != nil { return err } @@ -270,13 +271,13 @@ func isAvailableEmailAddr(config *validationConfig) ValidatorFunc { } } -func isAvailableDomain(config *validationConfig) ValidatorFunc { +func (v *validationContext) isAvailableWebsite() ValidatorFunc { return func(ctx context.Context, value string) error { - rel := relatedWebsites(ctx, config.domains, value) + rel := relatedWebsites(ctx, v.domains, value) // Run the presence check in a new transaction. Unavailability // of the server results in a validation error (fail close). - tx, err := config.backend.NewTransaction() + tx, err := v.backend.NewTransaction() if err != nil { return err } @@ -287,39 +288,31 @@ func isAvailableDomain(config *validationConfig) ValidatorFunc { } } -func validHostedEmail(config *validationConfig) ValidatorFunc { +func (v *validationContext) validHostedEmail() ValidatorFunc { return allOf( validateUsernameAndDomain( - allOf(matchUsernameRx(), minLength(4), maxLength(64), notInSet(config.forbiddenUsernames)), - allOf(isAvailableEmailHostingDomain(config)), + allOf(matchUsernameRx(), minLength(4), maxLength(64), notInSet(v.forbiddenUsernames)), + allOf(v.isAllowedDomain(ResourceTypeEmail)), ), - isAvailableEmailAddr(config), + v.isAvailableEmailAddr(), ) } -func validHostedMailingList(config *validationConfig) ValidatorFunc { +func (v *validationContext) validHostedMailingList() ValidatorFunc { return allOf( validateUsernameAndDomain( - allOf(matchUsernameRx(), minLength(4), maxLength(64), notInSet(config.forbiddenUsernames)), - allOf(isAvailableMailingListDomain(config)), + allOf(matchUsernameRx(), minLength(4), maxLength(64), notInSet(v.forbiddenUsernames)), + allOf(v.isAllowedDomain(ResourceTypeMailingList)), ), - isAvailableEmailAddr(config), + v.isAvailableEmailAddr(), ) } -func validHostedDomain(config *validationConfig) ValidatorFunc { +func (v *validationContext) validPassword() ValidatorFunc { return allOf( - minLength(6), - validDomainName, - isAvailableDomain(config), - ) -} - -func validPassword(config *validationConfig) ValidatorFunc { - return allOf( - minLength(config.minPasswordLength), - maxLength(config.maxPasswordLength), - notInSet(config.forbiddenPasswords), + minLength(v.minPasswordLength), + maxLength(v.maxPasswordLength), + notInSet(v.forbiddenPasswords), ) } @@ -338,16 +331,16 @@ func allOf(funcs ...ValidatorFunc) ValidatorFunc { // various fields in a Resource, depending on its type. type ResourceValidatorFunc func(ctx context.Context, r *Resource) error -func validEmailResource(config *validationConfig) ResourceValidatorFunc { - emailValidator := validHostedEmail(config) +func (v *validationContext) validEmailResource() ResourceValidatorFunc { + emailValidator := v.validHostedEmail() return func(ctx context.Context, r *Resource) error { return emailValidator(ctx, r.ID.Name()) } } -func validListResource(config *validationConfig) ResourceValidatorFunc { - listValidator := validHostedMailingList(config) +func (v *validationContext) validListResource() ResourceValidatorFunc { + listValidator := v.validHostedMailingList() return func(ctx context.Context, r *Resource) error { if err := listValidator(ctx, r.ID.Name()); err != nil { @@ -360,28 +353,63 @@ func validListResource(config *validationConfig) ResourceValidatorFunc { } } -func validDomain(config *validationConfig) ResourceValidatorFunc { - domainValidator := validHostedDomain(config) +func (v *validationContext) validDomainResource() ResourceValidatorFunc { + domainValidator := allOf( + minLength(6), + validDomainName, + v.isAvailableDomain(), + ) return func(ctx context.Context, r *Resource) error { return domainValidator(ctx, r.ID.Name()) } } +func (v *validationContext) validWebsiteResource() ResourceValidatorFunc { + nameValidator := allOf( + minLength(6), + matchSitenameRx(), + v.isAvailableWebsite(), + ) + parentValidator := v.isAllowedDomain(ResourceTypeWebsite) + + return func(ctx context.Context, r *Resource) error { + if err := nameValidator(ctx, r.ID.Name()); err != nil { + return err + } + return parentValidator(ctx, r.Website.ParentDomain) + } +} + +// Validator for arbitrary resource types. type resourceValidator struct { - v map[string]ResourceValidatorFunc + rvs map[string]ResourceValidatorFunc } -func newResourceValidator(config *validationConfig) *resourceValidator { +func newResourceValidator(v *validationContext) *resourceValidator { return &resourceValidator{ - v: map[string]ResourceValidatorFunc{ - ResourceTypeEmail: validEmailResource(config), - ResourceTypeMailingList: validListResource(config), - ResourceTypeDomain: validDomain(config), + rvs: map[string]ResourceValidatorFunc{ + ResourceTypeEmail: v.validEmailResource(), + ResourceTypeMailingList: v.validListResource(), + ResourceTypeDomain: v.validDomainResource(), + ResourceTypeWebsite: v.validWebsiteResource(), }, } } func (v *resourceValidator) validateResource(ctx context.Context, r *Resource) error { - return v.v[r.ID.Type()](ctx, r) + return v.rvs[r.ID.Type()](ctx, r) +} + +// Common validators for specific field types. +type fieldValidators struct { + password ValidatorFunc + email ValidatorFunc +} + +func newFieldValidators(v *validationContext) *fieldValidators { + return &fieldValidators{ + password: v.validPassword(), + email: v.validHostedEmail(), + } } diff --git a/validators_test.go b/validators_test.go index 833e3cda3b445f7ee3587d8fe4260176e27b9d92..95845bb80ae02d3146d41393f09372acb36546ff 100644 --- a/validators_test.go +++ b/validators_test.go @@ -97,8 +97,8 @@ func (f *fakeCheckTX) HasAnyResource(_ context.Context, reqs []FindResourceReque return false, nil } -func newTestValidationConfig(entries ...string) *validationConfig { - return &validationConfig{ +func newTestValidationConfig(entries ...string) *validationContext { + return &validationContext{ forbiddenUsernames: newStringSetFromList(entries), } } @@ -128,7 +128,7 @@ func TestValidator_HostedEmail(t *testing.T) { {Type: ResourceTypeEmail, Name: "existing@example.com"}, }) vc.domains = newFakeDomainBackend("example.com") - runValidationTest(t, validHostedEmail(vc), td) + runValidationTest(t, vc.validHostedEmail(), td) } func TestValidator_HostedMailingList(t *testing.T) { @@ -146,5 +146,5 @@ func TestValidator_HostedMailingList(t *testing.T) { {Type: ResourceTypeMailingList, Name: "existing@domain2.com"}, }) vc.domains = newFakeDomainBackend("domain1.com", "domain2.com") - runValidationTest(t, validHostedMailingList(vc), td) + runValidationTest(t, vc.validHostedMailingList(), td) }