diff --git a/authserv/ratelimit.py b/authserv/ratelimit.py index abcb0170f69c2c41a1f3356e39673b84cbada080..e438b54c5d054ed343650cf5a324b7f0dc9f3e17 100644 --- a/authserv/ratelimit.py +++ b/authserv/ratelimit.py @@ -4,9 +4,57 @@ from authserv import app from authserv import protocol +key_sep = ':' + + +def key_from_request(header=None, param=None): + """Build a key from HTTP request headers and query arguments.""" + if not header and not param: + raise Exception('Must set either header or param') + + required_parts = 2 + if header: + required_parts += 1 + if param: + required_parts += 1 + + def _key(prefix, args): + parts = [request.method, request.path] + if header: + value = request.environ.get(header) + if value: + parts.append('%s=%s' % (header, value)) + if param: + if request.method == 'POST': + value = request.form.get(param) + else: + value = request.args.get(param) + if value: + parts.append('%s=%s' % (param, value)) + if len(parts) >= required_parts: + return prefix + key_sep + key_sep.join(parts) + return None + return _key + + +def key_from_args(*args_idx): + """Build a key from the wrapped function's arguments. + + This function should be passed the positional index of the + arguments that will be used to build the key. + """ + def _key(prefix, args): + parts = [args[x] for x in args_idx if args[x]] + if not parts: + return None + return prefix + key_sep + key_sep.join(parts) + return _key + + class RateLimit(object): + """Keep track of request rates using Memcache.""" - prefix = 'authserv/' + prefix = 'authserv/r:' def __init__(self, count, period): self.count = count @@ -14,52 +62,48 @@ class RateLimit(object): def check(self, mc, key): key = self.prefix + key - result = mc.incr(key) + try: + result = mc.incr(key) + except: + result = None if result is None: result = 1 if not mc.add(key, result, time=self.period): - result = mc.incr(key) + try: + result = mc.incr(key) + except: + result = None if result is None: # Memcache is failing. return True return result <= self.count -def ratelimit(header=None, param=None, count=0, period=0): - if not header and not param: - raise Exception('Must set either header or param') +def ratelimit_http_request(key_fn, count=0, period=0): + """Rate limit an HTTP request handler. + + If the rate limit is triggered, the request will generate a HTTP + 503 error. + """ + rl = RateLimit(count, period) - required_parts = 2 - if header: - required_parts += 1 - if param: - required_parts += 1 def decoratorfn(fn): @functools.wraps(fn) def _ratelimit(*args, **kwargs): - parts = [request.method, request.path] - if header: - value = request.environ.get(header) - if value: - parts.append('%s=%s' % (header, value)) - if param: - value = request.form.get(param) - if value: - parts.append('%s=%s' % (param, value)) - if len(parts) >= required_parts: - key = ' '.join(parts) - if not rl.check(app.memcache, key): - app.logger.debug('ratelimited: %s', key) - abort(503) + key = key_fn(fn.__name__, args) + if key and not rl.check(app.memcache, key): + app.logger.debug('ratelimited: %s', key) + abort(503) return fn(*args, **kwargs) return _ratelimit return decoratorfn class BlackList(object): + """Block requests once a rate limit is triggered.""" - prefix = 'bl/' + prefix = 'authserv/b:' def __init__(self, count, period, ttl): self.rl = RateLimit(count, period) @@ -73,21 +117,30 @@ class BlackList(object): def incr(self, mc, key): key = self.prefix + key if not self.rl.check(mc, key): - mc.add(key, 1, time=self.ttl) + mc.set(key, 'true', time=self.ttl) + +def blacklist_on_auth_failure(key_fn, count=0, period=0, ttl=0): + """Blacklist authentication failures. + + The wrapped function should return one of the error codes from + protocol.py. If it returns ERR_AUTHENTICATION_FAILURE more often + than the specified threshold, the request parameters are added to + the blacklist. + + When a request is blacklisted, this function will return + ERR_AUTHENTICATION_FAILURE. + """ -def blacklist(args_idx=None, count=0, period=0, ttl=0): - if not args_idx: - raise Exception('Must set args_idx') bl = BlackList(count, period, ttl) def decoratorfn(fn): @functools.wraps(fn) def _blacklist(*args, **kwargs): - parts = [args[x] for x in args_idx if args[x]] - if not parts: + key = key_fn(fn.__name__, args) + if not key: return fn(*args, **kwargs) - key = ' '.join(parts) + if not bl.check(app.memcache, key): app.logger.debug('blacklisted %s', key) return protocol.ERR_AUTHENTICATION_FAILURE diff --git a/authserv/server.py b/authserv/server.py index 2addcef867fde453071eaf10abc4731f684041f6..c5be265fbecf86695c4357921dfb14b54fa3068d 100644 --- a/authserv/server.py +++ b/authserv/server.py @@ -1,12 +1,12 @@ from authserv import app from authserv import auth from authserv import protocol -from authserv.ratelimit import blacklist, ratelimit +from authserv.ratelimit import * from flask import Flask, request, abort, make_response -@blacklist([0], count=5, period=600, ttl=43200) -@blacklist([4], count=5, period=600, ttl=43200) +@blacklist_on_auth_failure(key_from_args(0), count=5, period=600, ttl=43200) +@blacklist_on_auth_failure(key_from_args(4), count=5, period=600, ttl=43200) def _auth(username, service, password, otp_token, source_ip): user = app.userdb.get_user(username, service) if not user: @@ -15,8 +15,10 @@ def _auth(username, service, password, otp_token, source_ip): @app.route('/api/1/auth', methods=('POST',)) -@ratelimit(count=10, period=60, header='HTTP_X_FORWARDED_FOR') -@ratelimit(count=10, period=60, param='username') +@ratelimit_http_request(key_from_request(header='HTTP_X_FORWARDED_FOR'), + count=10, period=60) +@ratelimit_http_request(key_from_request(param='username'), + count=10, period=60) def do_auth(): service = request.form.get('service') username = request.form.get('username') diff --git a/authserv/test/__init__.py b/authserv/test/__init__.py index d05b74e85bf5206a4d6d298daa3e122a43797e26..8ef8f3674f631ac56a26fdb68d552fae1a3f0737 100644 --- a/authserv/test/__init__.py +++ b/authserv/test/__init__.py @@ -19,6 +19,9 @@ class FakeMemcache(object): if result and result[1] > self.t(): return result[0] + def set(self, key, value, time=0): + self.data[key] = (value, self.t() + time) + def incr(self, key): now = self.t() result = self.data.get(key) diff --git a/authserv/test/test_ratelimit.py b/authserv/test/test_ratelimit.py index c6030bad73c4ee02f90318df496e294bbc4ebff2..120d26c42a4d8ec399230bb025b8f0cab05e0277 100644 --- a/authserv/test/test_ratelimit.py +++ b/authserv/test/test_ratelimit.py @@ -28,3 +28,24 @@ class RateLimitTest(unittest.TestCase): if self.rl.check(self.mc, 'key'): ok += 1 self.assertEquals(100, ok) + + +class BlackListTest(unittest.TestCase): + + def setUp(self): + self.tick = 0 + def _time(): + return self.tick + self.mc = FakeMemcache(_time) + self.bl = BlackList(100, 10, 3600) + + def test_block(self): + for i in xrange(100): + self.bl.incr(self.mc, 'key') + self.assertTrue(self.bl.check(self.mc, 'key')) + # 101th request is blocked + self.bl.incr(self.mc, 'key') + self.assertFalse(self.bl.check(self.mc, 'key')) + # check expiration of the block + self.tick = 3601 + self.assertTrue(self.bl.check(self.mc, 'key'))