diff --git a/server/authserver.go b/server/authserver.go index c67d33fff0fc758b2ee8c1aed955cf3fc6a72303..c5fb3532f7ffb7342a5fc07c0de4da4980e1799b 100644 --- a/server/authserver.go +++ b/server/authserver.go @@ -71,16 +71,122 @@ type U2FShortTermStorage interface { GetUserChallenge(string) (*u2f.Challenge, bool) } +type ratelimitKeyFunc func(*User, *auth.Request) string + +func usernameKey(user *User, _ *auth.Request) string { + return user.Name +} + +func ipAddrKey(_ *User, req *auth.Request) string { + if req.DeviceInfo != nil { + return req.DeviceInfo.RemoteAddr + } + return "" +} + +type authRatelimiterConfig struct { + Limit int `yaml:"limit"` + Period int `yaml:"period"` + BlacklistTime int `yaml:"blacklist_for"` + OnFailure bool `yaml:"on_failure"` + Keys []string `yaml:"keys"` + + keyFuncs []ratelimitKeyFunc +} + +func (r *authRatelimiterConfig) compile() error { + for _, k := range r.Keys { + var f ratelimitKeyFunc + switch k { + case "ip": + f = ipAddrKey + case "user": + f = usernameKey + default: + return fmt.Errorf("unknown key %s", k) + } + r.keyFuncs = append(r.keyFuncs, f) + } + return nil +} + +const rlKeySep = ";" + +func (r *authRatelimiterConfig) key(user *User, req *auth.Request) string { + if len(r.keyFuncs) == 1 { + return r.keyFuncs[0](user, req) + } + + var parts []string + for _, f := range r.keyFuncs { + parts = append(parts, f(user, req)) + } + return strings.Join(parts, rlKeySep) +} + +type authRatelimiter struct { + *authRatelimiterConfig + rl *Ratelimiter +} + +func (r *authRatelimiter) AllowIncr(user *User, req *auth.Request) bool { + return r.rl.AllowIncr(r.key(user, req)) +} + +type authBlacklist struct { + *authRatelimiterConfig + bl *Blacklist +} + +func (b *authBlacklist) Allow(user *User, req *auth.Request) bool { + return b.bl.Allow(b.key(user, req)) +} + +func (b *authBlacklist) Incr(user *User, req *auth.Request, resp *auth.Response) { + if b.OnFailure && resp.Status == auth.StatusOK { + return + } + b.bl.Incr(b.key(user, req)) +} + +type requestFilterFunc func(*User, *auth.Request, *auth.Response) *auth.Response + // BackendSpec specifies backend-specific configuration for a service. type BackendSpec struct { LDAPSpec *LDAPServiceConfig `yaml:"ldap"` FileSpec string `yaml:"file"` } -// ServiceConfig configures authentication backends for a service. +// ServiceConfig defines the authentication backends for a service. type ServiceConfig struct { BackendSpecs []*BackendSpec `yaml:"backends"` ChallengeResponse bool `yaml:"challenge_response"` + + Ratelimits []string `yaml:"rate_limits"` + + rl []*authRatelimiter + bl []*authBlacklist + filters []requestFilterFunc +} + +func (c *ServiceConfig) checkRateLimits(user *User, req *auth.Request) bool { + for _, rl := range c.rl { + if !rl.AllowIncr(user, req) { + return false + } + } + for _, bl := range c.bl { + if !bl.Allow(user, req) { + return false + } + } + return true +} + +func (c *ServiceConfig) notifyBlacklists(user *User, req *auth.Request, resp *auth.Response) { + for _, bl := range c.bl { + bl.Incr(user, req, resp) + } } // Config for the authentication server. @@ -94,6 +200,14 @@ type Config struct { // Service-specific configuration. Services map[string]*ServiceConfig `yaml:"services"` + // Named rate limiter configurations. + RateLimiters map[string]*authRatelimiterConfig `yaml:"rate_limits"` + + // Runtime versions of the above. These objects are shared by + // all services, as they contain the actual map data. + rl map[string]*Ratelimiter + bl map[string]*Blacklist + path string } @@ -115,6 +229,39 @@ func (c *Config) relativePath(path string) string { return filepath.Join(filepath.Dir(c.path), path) } +func (c *Config) compile() error { + c.rl = make(map[string]*Ratelimiter) + c.bl = make(map[string]*Blacklist) + for name, params := range c.RateLimiters { + if err := params.compile(); err != nil { + return err + } + if params.BlacklistTime > 0 { + c.bl[name] = newBlacklist(params.Limit, params.Period, params.BlacklistTime) + } else { + c.rl[name] = newRatelimiter(params.Limit, params.Period) + } + } + + for _, sc := range c.Services { + for _, name := range sc.Ratelimits { + config, ok := c.RateLimiters[name] + if !ok { + return fmt.Errorf("unknown rate limiter %s", name) + } + if rl, ok := c.rl[name]; ok { + sc.rl = append(sc.rl, &authRatelimiter{config, rl}) + } else if bl, ok := c.bl[name]; ok { + sc.bl = append(sc.bl, &authBlacklist{config, bl}) + } else { + panic("can't find rl/bl") + } + } + } + + return nil +} + // LoadConfig loads the configuration from a YAML-encoded file. func LoadConfig(path string) (*Config, error) { data, err := ioutil.ReadFile(path) @@ -125,6 +272,10 @@ func LoadConfig(path string) (*Config, error) { if err := yaml.Unmarshal(data, &config); err != nil { return nil, err } + log.Printf("configuration: %+v", config) + if err := config.compile(); err != nil { + return nil, err + } return &config, nil } @@ -183,6 +334,17 @@ func (s *Server) getServiceConfig(service string) (*ServiceConfig, bool) { return c, ok } +func (s *Server) getUser(ctx context.Context, serviceConfig *ServiceConfig, username string) (*User, bool) { + for _, spec := range serviceConfig.BackendSpecs { + for _, b := range s.backends { + if user, ok := b.GetUser(ctx, spec, username); ok { + return user, true + } + } + } + return nil, false +} + // Authenticate a user with the parameters specified in the incoming AuthRequest. func (s *Server) Authenticate(ctx context.Context, req *auth.Request) *auth.Response { serviceConfig, ok := s.getServiceConfig(req.Service) @@ -191,22 +353,28 @@ func (s *Server) Authenticate(ctx context.Context, req *auth.Request) *auth.Resp return newError() } - var user *User -outer: - for _, spec := range serviceConfig.BackendSpecs { - for _, b := range s.backends { - if user, ok = b.GetUser(ctx, spec, req.Username); ok { - break outer - } - } - } - + user, ok := s.getUser(ctx, serviceConfig, req.Username) if !ok { // User is unknown to all backends. log.Printf("unknown user %s", req.Username) return newError() } + // Apply rate limiting and blacklisting _before_ invoking the + // authentication handlers, as they may be CPU intensive. + if allowed := serviceConfig.checkRateLimits(user, req); !allowed { + return newError() + } + + resp := s.authenticateUser(req, serviceConfig, user) + + // Notify blacklists of the result. + serviceConfig.notifyBlacklists(user, req, resp) + + return resp +} + +func (s *Server) authenticateUser(req *auth.Request, serviceConfig *ServiceConfig, user *User) *auth.Response { // Verify different credentials depending on whether the user // has 2FA enabled or not, and on whether the service itself // supports challenge-response authentication. @@ -218,7 +386,16 @@ outer: resp = s.authenticateUserWithASP(user, req) } } else { - resp = s.authenticateUser(user, req) + resp = s.authenticateUserWithPassword(user, req) + } + + // Process the response through filters (device info checks, + // etc) that may or may not change the response itself. + for _, f := range serviceConfig.filters { + if resp.Status == auth.StatusError { + break + } + resp = f(user, req, resp) } // If the response is successful, augment it with user information. @@ -229,7 +406,7 @@ outer: return resp } -func (s *Server) authenticateUser(user *User, req *auth.Request) *auth.Response { +func (s *Server) authenticateUserWithPassword(user *User, req *auth.Request) *auth.Response { // Ok we only need to check the password here. if checkPassword(req.Password, user.EncryptedPassword) { return newOK() diff --git a/server/authserver_test.go b/server/authserver_test.go index e96ce19515822293896565ff8675d4aba494082e..710ceb80a2a8405da79601233df4710614902570 100644 --- a/server/authserver_test.go +++ b/server/authserver_test.go @@ -89,6 +89,24 @@ services: backends: - { file: users.yml } ` + + testConfigStrWithRatelimit = `--- +enabled_backends: + - file +services: + test: + backends: + - { file: users.yml } + rate_limits: + - failed_login_bl +rate_limits: + failed_login_bl: + limit: 10 + period: 300 + blacklist_for: 3600 + on_failure: true + keys: [user] +` ) func runAuthenticationTest(t *testing.T, client client.Client) { @@ -160,3 +178,56 @@ func TestAuthServer(t *testing.T) { defer s.Close() runAuthenticationTest(t, &clientAdapter{s.srv}) } + +func TestAuthServer_Blacklist(t *testing.T) { + s := createTestServer(t, map[string]string{ + "users.yml": testUsersFileStr, + "config.yml": testConfigStrWithRatelimit, + }) + defer s.Close() + c := &clientAdapter{s.srv} + + // Trigger the failed login blacklist, then verify that the + // user is blacklisted even when trying with the right password. + for i := 0; i < 100; i++ { + c.Authenticate(context.Background(), &auth.Request{ + Service: "test", + Username: "testuser", + Password: []byte("bad_password"), + }) + } + resp, _ := c.Authenticate(context.Background(), &auth.Request{ + Service: "test", + Username: "testuser", + Password: []byte("password"), + }) + if resp.Status != auth.StatusError { + t.Fatalf("user was not blacklisted: %v", resp) + } +} + +func TestAuthServer_Blacklist_BelowLimit(t *testing.T) { + s := createTestServer(t, map[string]string{ + "users.yml": testUsersFileStr, + "config.yml": testConfigStrWithRatelimit, + }) + defer s.Close() + c := &clientAdapter{s.srv} + + // A small number of failures should not trigger the blacklist. + for i := 0; i < 8; i++ { + c.Authenticate(context.Background(), &auth.Request{ + Service: "test", + Username: "testuser", + Password: []byte("bad_password"), + }) + } + resp, _ := c.Authenticate(context.Background(), &auth.Request{ + Service: "test", + Username: "testuser", + Password: []byte("password"), + }) + if resp.Status != auth.StatusOK { + t.Fatalf("user was incorrectly blacklisted: %+v", s.srv.config.Services["test"]) + } +} diff --git a/server/ratelimit.go b/server/ratelimit.go new file mode 100644 index 0000000000000000000000000000000000000000..34e60e8285a8f62fa7d4f569f563eeedd9f2992e --- /dev/null +++ b/server/ratelimit.go @@ -0,0 +1,145 @@ +package server + +import ( + "sync" + "time" +) + +// Try to use as little memory as possible for each entry: use a UNIX +// timestamp instead of a time.Time, and use an int32 as a saturating +// counter. +type ratelimitDatum struct { + stamp int64 + counter int32 +} + +func (d ratelimitDatum) age(now int64) int64 { + return now - d.stamp +} + +// Ratelimiter is a simple counter-based rate limiter, allowing the +// first N requests over each period of time T. +type Ratelimiter struct { + limit int32 + period int64 + + mx sync.Mutex + c map[string]ratelimitDatum +} + +func newRatelimiter(limit, period int) *Ratelimiter { + r := &Ratelimiter{ + limit: int32(limit), + period: int64(period), + c: make(map[string]ratelimitDatum), + } + go r.expungeThread() + return r +} + +// AllowIncr performs a check and an increment at the same time, while +// holding a mutex, so it is robust in face of concurrent requests. +func (r *Ratelimiter) AllowIncr(key string) bool { + if key == "" { + return true + } + r.mx.Lock() + d := r.get(key) + var allowed bool + if d.counter <= r.limit { + allowed = true + d.counter++ + r.set(key, d) + } + r.mx.Unlock() + + return allowed +} + +func (r *Ratelimiter) get(key string) ratelimitDatum { + now := time.Now().Unix() + d, ok := r.c[key] + if !ok || d.age(now) > r.period { + d = ratelimitDatum{stamp: now} + r.c[key] = d + } + return d +} + +func (r *Ratelimiter) set(key string, d ratelimitDatum) { + r.c[key] = d +} + +func (r *Ratelimiter) expunge() { + cutoff := time.Now().Unix() - 2*r.period + r.mx.Lock() + for k, d := range r.c { + if d.stamp < cutoff { + delete(r.c, k) + } + } + r.mx.Unlock() +} + +var ratelimitExpungePeriod = 300 * time.Second + +func (r *Ratelimiter) expungeThread() { + for range time.NewTicker(ratelimitExpungePeriod).C { + r.expunge() + } +} + +// Blacklist can blacklist keys whose request rate is above a +// specified threshold. +type Blacklist struct { + r *Ratelimiter + bl map[string]int64 + blTime int64 +} + +func newBlacklist(limit, period, blacklistTime int) *Blacklist { + return &Blacklist{ + r: newRatelimiter(limit, period), + bl: make(map[string]int64), + blTime: int64(blacklistTime), + } +} + +// Allow returns true if this request (identified by the given key) +// should be allowed. +func (b *Blacklist) Allow(key string) bool { + if key == "" { + return true + } + + b.r.mx.Lock() + deadline, ok := b.bl[key] + if ok && deadline < time.Now().Unix() { + delete(b.bl, key) + ok = false + } + b.r.mx.Unlock() + return !ok +} + +// Incr increments the counter for the given key for the current time +// period. +func (b *Blacklist) Incr(key string) { + if key == "" { + return + } + + // Count one higher than limit, and trigger the blacklist when + // we reach that. + limitp1 := b.r.limit + 1 + + b.r.mx.Lock() + d := b.r.get(key) + if d.counter < limitp1 { + d.counter++ + b.r.set(key, d) + } else if d.counter == limitp1 { + b.bl[key] = time.Now().Unix() + b.blTime + } + b.r.mx.Unlock() +}