Commit 041e8fab authored by ale's avatar ale

Refactor validators to take a RequestContext

In preparation of supporting validators that take into account the
identity of the requestor (as per issue #11).
parent b0fd19b7
Pipeline #6350 passed with stages
in 1 minute and 54 seconds
......@@ -84,7 +84,7 @@ func (r *CreateResourcesRequest) Validate(rctx *RequestContext) error {
}
// Validate the resource.
if err := rctx.resourceValidator.validateResource(rctx.Context, rsrc, tplUser, true); err != nil {
if err := rctx.resourceValidator.validateResource(rctx, rsrc, tplUser, true); err != nil {
log.Printf("validation error while creating resource %s: %v", rsrc.String(), err)
return err
}
......@@ -177,12 +177,12 @@ func (r *CreateUserRequest) Validate(rctx *RequestContext) error {
// Validate the user *and* all resources. The request must contain at
// least one email resource with the same name as the user.
for _, rsrc := range r.User.Resources {
if err := rctx.resourceValidator.validateResource(rctx.Context, rsrc, r.User, true); err != nil {
if err := rctx.resourceValidator.validateResource(rctx, rsrc, r.User, true); err != nil {
log.Printf("validation error while creating resource %+v: %v", rsrc, err)
return err
}
}
if err := rctx.userValidator(rctx.Context, r.User, true); err != nil {
if err := rctx.userValidator(rctx, r.User, true); err != nil {
log.Printf("validation error while creating user %+v: %v", r.User, err)
return err
}
......
......@@ -135,7 +135,7 @@ func (r *AdminUpdateResourceRequest) Validate(rctx *RequestContext) error {
}
// Validate the resource.
if err := rctx.resourceValidator.validateResource(rctx.Context, r.Resource, tplUser, false); err != nil {
if err := rctx.resourceValidator.validateResource(rctx, r.Resource, tplUser, false); err != nil {
return err
}
......@@ -210,7 +210,7 @@ func (r *CheckResourceAvailabilityRequest) Serve(rctx *RequestContext) (interfac
}
var resp CheckResourceAvailabilityResponse
if err := check(rctx.Context, r.Name); err == nil {
if err := check(rctx, r.Name); err == nil {
resp.Available = true
}
return &resp, nil
......@@ -324,7 +324,7 @@ func (r *AddEmailAliasRequest) Validate(rctx *RequestContext) error {
return errors.New("too many aliases")
}
if err := rctx.fieldValidators.newEmail(rctx.Context, r.Addr); err != nil {
if err := rctx.fieldValidators.newEmail(rctx, r.Addr); err != nil {
return newValidationError(nil, "addr", err.Error())
}
return nil
......
......@@ -72,7 +72,7 @@ func (r *ChangeUserPasswordRequest) Sanitize() {
// Validate the request.
func (r *ChangeUserPasswordRequest) Validate(rctx *RequestContext) error {
if err := rctx.fieldValidators.password(rctx.Context, r.Password); err != nil {
if err := rctx.fieldValidators.password(rctx, r.Password); err != nil {
return newValidationError(nil, "password", err.Error())
}
if r.Password == r.CurPassword {
......@@ -125,7 +125,7 @@ func (r *AccountRecoveryRequest) Sanitize() {
func (r *AccountRecoveryRequest) Validate(rctx *RequestContext) error {
// Only validate the password if attempting recovery.
if r.RecoveryPassword != "" {
if err := rctx.fieldValidators.password(rctx.Context, r.Password); err != nil {
if err := rctx.fieldValidators.password(rctx, r.Password); err != nil {
return newValidationError(nil, "password", err.Error())
}
}
......@@ -248,7 +248,7 @@ func (r *SetAccountRecoveryHintRequest) Validate(rctx *RequestContext) error {
if r.Hint == "" {
err = newValidationError(err, "recovery_hint", "mandatory field")
}
if verr := rctx.fieldValidators.password(rctx.Context, r.Response); verr != nil {
if verr := rctx.fieldValidators.password(rctx, r.Response); verr != nil {
err = newValidationError(err, "recovery_response", verr.Error())
}
return err.orNil()
......
......@@ -18,7 +18,7 @@ import (
// A domainBackend manages the list of domains users are allowed to request services on.
type domainBackend interface {
GetAllowedDomains(context.Context, string) []string
IsAllowedDomain(context.Context, string, string) bool
IsAllowedDomain(*RequestContext, string, string) bool
}
// A shardBackend can return information about available / allowed service shards.
......@@ -147,7 +147,7 @@ func (d *staticDomainBackend) GetAllowedDomains(_ context.Context, kind string)
return d.sets[kind].List()
}
func (d *staticDomainBackend) IsAllowedDomain(_ context.Context, kind, domain string) bool {
func (d *staticDomainBackend) IsAllowedDomain(_ *RequestContext, kind, domain string) bool {
return d.sets[kind].Contains(domain)
}
......@@ -192,10 +192,10 @@ func loadStringSetFromFile(path string) (stringSet, error) {
// ValidatorFunc is the generic interface for unstructured data field
// (string) validators.
type ValidatorFunc func(context.Context, string) error
type ValidatorFunc func(*RequestContext, string) error
func notInSet(set stringSet) ValidatorFunc {
return func(_ context.Context, value string) error {
return func(_ *RequestContext, value string) error {
if set.Contains(value) {
return errors.New("invalid value (blacklisted)")
}
......@@ -204,7 +204,7 @@ func notInSet(set stringSet) ValidatorFunc {
}
func minLength(minLen int) ValidatorFunc {
return func(_ context.Context, value string) error {
return func(_ *RequestContext, value string) error {
if len(value) < minLen {
return fmt.Errorf("value must be at least %d characters", minLen)
}
......@@ -213,7 +213,7 @@ func minLength(minLen int) ValidatorFunc {
}
func maxLength(maxLen int) ValidatorFunc {
return func(_ context.Context, value string) error {
return func(_ *RequestContext, value string) error {
if len(value) > maxLen {
return fmt.Errorf("value must be at most %d characters", maxLen)
}
......@@ -222,7 +222,7 @@ func maxLength(maxLen int) ValidatorFunc {
}
func matchRegexp(rx *regexp.Regexp, errmsg string) ValidatorFunc {
return func(_ context.Context, value string) error {
return func(_ *RequestContext, value string) error {
if !rx.MatchString(value) {
return errors.New(errmsg)
}
......@@ -249,15 +249,15 @@ func matchIdentifierRx() ValidatorFunc {
}
func validateUsernameAndDomain(validateUsername, validateDomain ValidatorFunc) ValidatorFunc {
return func(ctx context.Context, value string) error {
return func(rctx *RequestContext, value string) error {
parts := strings.SplitN(value, "@", 2)
if len(parts) != 2 {
return errors.New("malformed email address")
}
if err := validateUsername(ctx, parts[0]); err != nil {
if err := validateUsername(rctx, parts[0]); err != nil {
return err
}
return validateDomain(ctx, parts[1])
return validateDomain(rctx, parts[1])
}
}
......@@ -270,7 +270,7 @@ func isRegistered(domain string) bool {
return true
}
func validDomainName(_ context.Context, value string) error {
func validDomainName(_ *RequestContext, value string) error {
if !domainRx.MatchString(value) {
return errors.New("invalid domain name")
}
......@@ -340,8 +340,8 @@ func relatedDomains(ctx context.Context, be domainBackend, value string) []FindR
}
func (v *validationContext) isAllowedDomain(rtype string) ValidatorFunc {
return func(ctx context.Context, value string) error {
if !v.domains.IsAllowedDomain(ctx, rtype, value) {
return func(rctx *RequestContext, value string) error {
if !v.domains.IsAllowedDomain(rctx, rtype, value) {
return errors.New("unavailable domain")
}
return nil
......@@ -349,8 +349,8 @@ func (v *validationContext) isAllowedDomain(rtype string) ValidatorFunc {
}
func (v *validationContext) isAvailableEmailAddr() ValidatorFunc {
return func(ctx context.Context, value string) error {
rel := relatedEmails(ctx, v.domains, value)
return func(rctx *RequestContext, value string) error {
rel := relatedEmails(rctx.Context, v.domains, value)
// Run the presence check in a new transaction. Unavailability
// of the server results in a validation error (fail close).
......@@ -359,7 +359,7 @@ func (v *validationContext) isAvailableEmailAddr() ValidatorFunc {
return err
}
// Errors will cause to consider the address unavailable.
if ok, _ := tx.HasAnyResource(ctx, rel); ok { // nolint
if ok, _ := tx.HasAnyResource(rctx.Context, rel); ok { // nolint
return errors.New("address unavailable")
}
return nil
......@@ -367,8 +367,8 @@ func (v *validationContext) isAvailableEmailAddr() ValidatorFunc {
}
func (v *validationContext) isAvailableDomain() ValidatorFunc {
return func(ctx context.Context, value string) error {
rel := relatedDomains(ctx, v.domains, value)
return func(rctx *RequestContext, value string) error {
rel := relatedDomains(rctx.Context, v.domains, value)
// Run the presence check in a new transaction. Unavailability
// of the server results in a validation error (fail close).
......@@ -377,7 +377,7 @@ func (v *validationContext) isAvailableDomain() ValidatorFunc {
return err
}
// Errors will cause to consider the resource unavailable.
if ok, _ := tx.HasAnyResource(ctx, rel); ok { // nolint
if ok, _ := tx.HasAnyResource(rctx.Context, rel); ok { // nolint
return errors.New("address unavailable")
}
return nil
......@@ -385,8 +385,8 @@ func (v *validationContext) isAvailableDomain() ValidatorFunc {
}
func (v *validationContext) isAvailableWebsite() ValidatorFunc {
return func(ctx context.Context, value string) error {
rel := relatedWebsites(ctx, v.domains, value)
return func(rctx *RequestContext, value string) error {
rel := relatedWebsites(rctx.Context, v.domains, value)
// Run the presence check in a new transaction. Unavailability
// of the server results in a validation error (fail close).
......@@ -395,7 +395,7 @@ func (v *validationContext) isAvailableWebsite() ValidatorFunc {
return err
}
// Errors will cause to consider the resource unavailable.
if ok, _ := tx.HasAnyResource(ctx, rel); ok { // nolint
if ok, _ := tx.HasAnyResource(rctx.Context, rel); ok { // nolint
return errors.New("address unavailable")
}
return nil
......@@ -403,7 +403,7 @@ func (v *validationContext) isAvailableWebsite() ValidatorFunc {
}
func (v *validationContext) isAvailableDAV() ValidatorFunc {
return func(ctx context.Context, value string) error {
return func(rctx *RequestContext, value string) error {
rel := []FindResourceRequest{
{
Type: ResourceTypeDAV,
......@@ -418,7 +418,7 @@ func (v *validationContext) isAvailableDAV() ValidatorFunc {
return err
}
// Errors will cause to consider the resource unavailable.
if ok, _ := tx.HasAnyResource(ctx, rel); ok { // nolint
if ok, _ := tx.HasAnyResource(rctx.Context, rel); ok { // nolint
return errors.New("name unavailable")
}
return nil
......@@ -426,7 +426,7 @@ func (v *validationContext) isAvailableDAV() ValidatorFunc {
}
func (v *validationContext) isAvailableDatabase() ValidatorFunc {
return func(ctx context.Context, value string) error {
return func(rctx *RequestContext, value string) error {
rel := []FindResourceRequest{
{
Type: ResourceTypeDatabase,
......@@ -441,7 +441,7 @@ func (v *validationContext) isAvailableDatabase() ValidatorFunc {
return err
}
// Errors will cause to consider the resource unavailable.
if ok, _ := tx.HasAnyResource(ctx, rel); ok { // nolint
if ok, _ := tx.HasAnyResource(rctx.Context, rel); ok { // nolint
return errors.New("name unavailable")
}
return nil
......@@ -488,9 +488,9 @@ func (v *validationContext) validPassword() ValidatorFunc {
}
func allOf(funcs ...ValidatorFunc) ValidatorFunc {
return func(ctx context.Context, value string) error {
return func(rctx *RequestContext, value string) error {
for _, f := range funcs {
if err := f(ctx, value); err != nil {
if err := f(rctx, value); err != nil {
return err
}
}
......@@ -500,9 +500,9 @@ func allOf(funcs ...ValidatorFunc) ValidatorFunc {
// ResourceValidatorFunc is a composite type validator that checks
// various fields in a Resource, depending on its type.
type ResourceValidatorFunc func(context.Context, *Resource, *User, bool) error
type ResourceValidatorFunc func(*RequestContext, *Resource, *User, bool) error
func (v *validationContext) validateResource(_ context.Context, r *Resource, user *User) error {
func (v *validationContext) validateResource(_ *RequestContext, r *Resource, user *User) error {
// Validate the status enum.
switch r.Status {
case ResourceStatusActive, ResourceStatusInactive, ResourceStatusReadonly, ResourceStatusArchived:
......@@ -524,19 +524,19 @@ func (v *validationContext) validateResource(_ context.Context, r *Resource, use
return nil
}
func (v *validationContext) validateShardedResource(ctx context.Context, r *Resource, user *User) error {
if err := v.validateResource(ctx, r, user); err != nil {
func (v *validationContext) validateShardedResource(rctx *RequestContext, r *Resource, user *User) error {
if err := v.validateResource(rctx, r, user); err != nil {
return err
}
if r.Shard == "" {
return errors.New("empty shard")
}
if !v.shards.IsAllowedShard(ctx, r.Type, r.Shard) {
if !v.shards.IsAllowedShard(rctx.Context, r.Type, r.Shard) {
return fmt.Errorf(
"invalid shard %s for resource type %s (allowed: %v)",
r.Shard,
r.Type,
v.shards.GetAllowedShards(ctx, r.Type),
v.shards.GetAllowedShards(rctx.Context, r.Type),
)
}
if r.OriginalShard == "" {
......@@ -549,8 +549,8 @@ func (v *validationContext) validEmailResource() ResourceValidatorFunc {
emailValidator := v.validHostedEmail()
newEmailValidator := v.validHostedNewEmail()
return func(ctx context.Context, r *Resource, user *User, isNew bool) error {
if err := v.validateShardedResource(ctx, r, user); err != nil {
return func(rctx *RequestContext, r *Resource, user *User, isNew bool) error {
if err := v.validateShardedResource(rctx, r, user); err != nil {
return err
}
......@@ -565,9 +565,9 @@ func (v *validationContext) validEmailResource() ResourceValidatorFunc {
var err error
if isNew {
err = newEmailValidator(ctx, r.Name)
err = newEmailValidator(rctx, r.Name)
} else {
err = emailValidator(ctx, r.Name)
err = emailValidator(rctx, r.Name)
}
if err != nil {
return err
......@@ -584,8 +584,8 @@ func (v *validationContext) validListResource() ResourceValidatorFunc {
listValidator := v.validHostedMailingList()
newListValidator := allOf(listValidator, v.isAvailableEmailAddr())
return func(ctx context.Context, r *Resource, user *User, isNew bool) error {
if err := v.validateShardedResource(ctx, r, user); err != nil {
return func(rctx *RequestContext, r *Resource, user *User, isNew bool) error {
if err := v.validateShardedResource(rctx, r, user); err != nil {
return err
}
if r.List == nil {
......@@ -594,9 +594,9 @@ func (v *validationContext) validListResource() ResourceValidatorFunc {
var err error
if isNew {
err = newListValidator(ctx, r.Name)
err = newListValidator(rctx, r.Name)
} else {
err = listValidator(ctx, r.Name)
err = listValidator(rctx, r.Name)
}
if err != nil {
return err
......@@ -613,8 +613,8 @@ func (v *validationContext) validNewsletterResource() ResourceValidatorFunc {
listValidator := v.validHostedMailingList()
newListValidator := allOf(listValidator, v.isAvailableEmailAddr())
return func(ctx context.Context, r *Resource, user *User, isNew bool) error {
if err := v.validateShardedResource(ctx, r, user); err != nil {
return func(rctx *RequestContext, r *Resource, user *User, isNew bool) error {
if err := v.validateShardedResource(rctx, r, user); err != nil {
return err
}
if r.Newsletter == nil {
......@@ -623,9 +623,9 @@ func (v *validationContext) validNewsletterResource() ResourceValidatorFunc {
var err error
if isNew {
err = newListValidator(ctx, r.Name)
err = newListValidator(rctx, r.Name)
} else {
err = listValidator(ctx, r.Name)
err = listValidator(rctx, r.Name)
}
if err != nil {
return err
......@@ -693,8 +693,8 @@ func (v *validationContext) validDomainResource() ResourceValidatorFunc {
v.isAvailableDomain(),
)
return func(ctx context.Context, r *Resource, user *User, isNew bool) error {
if err := v.validateShardedResource(ctx, r, user); err != nil {
return func(rctx *RequestContext, r *Resource, user *User, isNew bool) error {
if err := v.validateShardedResource(rctx, r, user); err != nil {
return err
}
......@@ -709,9 +709,9 @@ func (v *validationContext) validDomainResource() ResourceValidatorFunc {
var err error
if isNew {
err = newDomainValidator(ctx, r.Name)
err = newDomainValidator(rctx, r.Name)
} else {
err = domainValidator(ctx, r.Name)
err = domainValidator(rctx, r.Name)
}
if err != nil {
return err
......@@ -736,8 +736,8 @@ func (v *validationContext) validWebsiteResource() ResourceValidatorFunc {
)
parentValidator := v.isAllowedDomain(ResourceTypeWebsite)
return func(ctx context.Context, r *Resource, user *User, isNew bool) error {
if err := v.validateShardedResource(ctx, r, user); err != nil {
return func(rctx *RequestContext, r *Resource, user *User, isNew bool) error {
if err := v.validateShardedResource(rctx, r, user); err != nil {
return err
}
......@@ -752,14 +752,14 @@ func (v *validationContext) validWebsiteResource() ResourceValidatorFunc {
var err error
if isNew {
err = newNameValidator(ctx, r.Name)
err = newNameValidator(rctx, r.Name)
} else {
err = nameValidator(ctx, r.Name)
err = nameValidator(rctx, r.Name)
}
if err != nil {
return err
}
if err := parentValidator(ctx, r.Website.ParentDomain); err != nil {
if err := parentValidator(rctx, r.Website.ParentDomain); err != nil {
return err
}
......@@ -776,8 +776,8 @@ func (v *validationContext) validDAVResource() ResourceValidatorFunc {
davValidator,
v.isAvailableDAV(),
)
return func(ctx context.Context, r *Resource, user *User, isNew bool) error {
if err := v.validateShardedResource(ctx, r, user); err != nil {
return func(rctx *RequestContext, r *Resource, user *User, isNew bool) error {
if err := v.validateShardedResource(rctx, r, user); err != nil {
return err
}
......@@ -788,9 +788,9 @@ func (v *validationContext) validDAVResource() ResourceValidatorFunc {
var err error
if isNew {
err = newDAVValidator(ctx, r.Name)
err = newDAVValidator(rctx, r.Name)
} else {
err = davValidator(ctx, r.Name)
err = davValidator(rctx, r.Name)
}
if err != nil {
return err
......@@ -816,8 +816,8 @@ func (v *validationContext) validDatabaseResource() ResourceValidatorFunc {
dbValidator,
v.isAvailableDatabase(),
)
return func(ctx context.Context, r *Resource, user *User, isNew bool) error {
if err := v.validateShardedResource(ctx, r, user); err != nil {
return func(rctx *RequestContext, r *Resource, user *User, isNew bool) error {
if err := v.validateShardedResource(rctx, r, user); err != nil {
return err
}
......@@ -835,9 +835,9 @@ func (v *validationContext) validDatabaseResource() ResourceValidatorFunc {
var err error
if isNew {
err = newDBValidator(ctx, r.Name)
err = newDBValidator(rctx, r.Name)
} else {
err = dbValidator(ctx, r.Name)
err = dbValidator(rctx, r.Name)
}
if err != nil {
return err
......@@ -869,7 +869,7 @@ func newResourceValidator(v *validationContext) *resourceValidator {
}
}
func (v *resourceValidator) validateResource(ctx context.Context, r *Resource, user *User, isNew bool) error {
func (v *resourceValidator) validateResource(rctx *RequestContext, r *Resource, user *User, isNew bool) error {
// Obvious basic sanity checks on the Resource parameters.
if r.Name == "" {
return errors.New("resource name unset")
......@@ -883,7 +883,7 @@ func (v *resourceValidator) validateResource(ctx context.Context, r *Resource, u
return fmt.Errorf("unknown resource type '%s'", r.Type)
}
return rv(ctx, r, user, isNew)
return rv(rctx, r, user, isNew)
}
// Common validators for specific field types.
......@@ -900,7 +900,7 @@ func newFieldValidators(v *validationContext) *fieldValidators {
}
// UserValidatorFunc is a compound validator for User objects.
type UserValidatorFunc func(context.Context, *User, bool) error
type UserValidatorFunc func(*RequestContext, *User, bool) error
// Verify that user-level invariants are respected. This check can be applied
// to new or existing objects.
......@@ -952,12 +952,12 @@ func checkUserInvariants(user *User) error {
func (v *validationContext) validUser() UserValidatorFunc {
nameValidator := v.validHostedEmail()
newNameValidator := v.validHostedNewEmail()
return func(ctx context.Context, user *User, isNew bool) error {
return func(rctx *RequestContext, user *User, isNew bool) error {
var err error
if isNew {
err = newNameValidator(ctx, user.Name)
err = newNameValidator(rctx, user.Name)
} else {
err = nameValidator(ctx, user.Name)
err = nameValidator(rctx, user.Name)
}
if err != nil {
return err
......
......@@ -13,7 +13,7 @@ type validationTestData struct {
func runValidationTest(t testing.TB, v ValidatorFunc, testData []validationTestData) {
for _, td := range testData {
err := v(context.Background(), td.value)
err := v(&RequestContext{Context: context.Background()}, td.value)
if (err == nil && !td.ok) || (err != nil && td.ok) {
t.Errorf("test for '%s' failed: expected %v, got error %v", td.value, td.ok, err)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment