From 9f2bc780024cf669eb6394f7aea888de9e38a2e0 Mon Sep 17 00:00:00 2001 From: ale <ale@incal.net> Date: Sun, 14 Sep 2014 08:14:18 +0100 Subject: [PATCH] refactor ratelimiting and blacklists Whitelisting source IPs now works properly. --- authserv/app_common.py | 53 ++++++++++--- authserv/app_main.py | 9 +-- authserv/app_nginx.py | 4 +- authserv/ratelimit.py | 132 ++++++++------------------------ authserv/test/test_app_main.py | 23 +++--- authserv/test/test_ratelimit.py | 10 +-- 6 files changed, 100 insertions(+), 131 deletions(-) diff --git a/authserv/app_common.py b/authserv/app_common.py index a0db00f..2188aec 100644 --- a/authserv/app_common.py +++ b/authserv/app_common.py @@ -1,4 +1,4 @@ -from flask import current_app +from flask import abort, current_app from authserv import auth from authserv import protocol from authserv.ratelimit import * @@ -6,13 +6,48 @@ from authserv.ratelimit import * _blacklisted = (protocol.ERR_AUTHENTICATION_FAILURE, 'blacklisted', None) -@blacklist_on_auth_failure(key_from_args(0), count=5, period=600, ttl=43200, - bl_return_value=_blacklisted) -@blacklist_on_auth_failure(key_from_args(5), count=5, period=600, ttl=43200, - check_wl=True, bl_return_value=_blacklisted) +def check_ratelimit(request, username, source_ip): + if current_app.config.get('ENABLE_RATELIMIT'): + if not ratelimit_http_request( + request, username, tag='u', + count=current_app.config.get('RATELIMIT_USER_COUNT', 10), + period=current_app.config.get('RATELIMIT_USER_PERIOD', 60)): + abort(503) + if (source_ip + and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST')) + and not ratelimit_http_request( + request, source_ip, tag='ip', + count=current_app.config.get('RATELIMIT_SOURCEIP_COUNT', 10), + period=current_app.config.get('RATELIMIT_SOURCEIP_PERIOD', 60))): + abort(503) + + def do_auth(username, service, shard, password, otp_token, source_ip): + bl = AuthBlackList(current_app.config.get('BLACKLIST_COUNT', 5), + current_app.config.get('BLACKLIST_PERIOD', 600), + current_app.config.get('BLACKLIST_TIME', 6*3600)) + if current_app.config.get('ENABLE_BLACKLIST'): + if bl.is_blacklisted('u', username): + return _blacklisted + if (source_ip + and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST')) + and bl.is_blacklisted('ip', source_ip)): + return _blacklisted + + retval = protocol.ERR_AUTHENTICATION_FAILURE + errmsg = 'user does not exist' + shard = None user = current_app.userdb.get_user(username, service, shard) - if not user: - return protocol.ERR_AUTHENTICATION_FAILURE, 'user does not exist', None - result, errmsg = auth.authenticate(user, service, password, otp_token, source_ip) - return (result, errmsg, user.get_shard()) + if user: + retval, errmsg = auth.authenticate( + user, service, password, otp_token, source_ip) + shard = user.get_shard() + + if retval != protocol.OK and current_app.config.get('ENABLE_BLACKLIST'): + if user: + bl.auth_failure('u', username) + if (source_ip + and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST'))): + bl.auth_failure('ip', source_ip) + + return (retval, errmsg, shard) diff --git a/authserv/app_main.py b/authserv/app_main.py index 3ee6df6..922dc4a 100644 --- a/authserv/app_main.py +++ b/authserv/app_main.py @@ -1,8 +1,7 @@ from flask import Flask, request, abort, make_response from authserv import auth from authserv import protocol -from authserv.ratelimit import * -from authserv.app_common import do_auth +from authserv.app_common import do_auth, check_ratelimit app = Flask(__name__) @@ -18,10 +17,6 @@ app = Flask(__name__) # address. So we're practically only going to use X-Forwarded-For for # requests that reach our frontends via HTTP. @app.route('/api/1/auth', methods=('POST',)) -@ratelimit_http_request(key_from_request(header='HTTP_X_FORWARDED_FOR'), - count=10, period=60) -@ratelimit_http_request(key_from_request(param='username'), - count=10, period=60) def api_auth(): service = request.form.get('service') username = request.form.get('username') @@ -33,6 +28,8 @@ def api_auth(): if not service or not username: abort(400) + check_ratelimit(request, username, source_ip) + try: result, errmsg, unused_shard = do_auth( username, service, shard, password, otp_token, source_ip) diff --git a/authserv/app_nginx.py b/authserv/app_nginx.py index 923bb5c..905c3c9 100644 --- a/authserv/app_nginx.py +++ b/authserv/app_nginx.py @@ -2,7 +2,7 @@ import socket import threading import urllib from flask import Flask, request, abort, make_response -from authserv.app_common import do_auth +from authserv.app_common import do_auth, check_ratelimit app = Flask(__name__) @@ -31,6 +31,8 @@ def do_nginx_http_auth(): except ValueError: n_attempt = 1 + check_ratelimit(request, username, source_ip) + try: auth_status, errmsg, shard = do_auth( username, service, None, password, None, source_ip) diff --git a/authserv/ratelimit.py b/authserv/ratelimit.py index 9b0bb02..dcba833 100644 --- a/authserv/ratelimit.py +++ b/authserv/ratelimit.py @@ -1,21 +1,18 @@ import functools import re +import threading from flask import abort, request, current_app from authserv import protocol key_sep = '/' -WHITELIST = [ +DEFAULT_WHITELIST = [ '127.0.0.1', '::1', '172.16.1.*', ] -_whitelist_rx = [ - re.compile('%s%s$' % (key_sep, x.replace('.', r'\.').replace('*', '.*'))) - for x in WHITELIST] - def _tostr(s): if isinstance(s, unicode): @@ -23,55 +20,24 @@ def _tostr(s): return s -def whitelisted(ip): - for rx in _whitelist_rx: - if rx.search(ip): - return True - return False +_rxcache = {} +_rxcache_lock = threading.Lock() -def key_from_request(header=None, param=None): - """Build a key from HTTP request headers and query arguments.""" - if not header and not param: - raise Exception('Must set either header or param') - - required_parts = 2 - if header: - required_parts += 1 - if param: - required_parts += 1 - - def _key(prefix, args): - parts = [request.method, request.path] - if header: - value = request.environ.get(header) - if value: - parts.append('%s=%s' % (header, value)) - if param: - if request.method == 'POST': - value = request.form.get(param) - else: - value = request.args.get(param) - if value: - parts.append('%s=%s' % (param, value)) - if len(parts) >= required_parts: - return prefix + key_sep + key_sep.join(parts) - return None - return _key - - -def key_from_args(*args_idx): - """Build a key from the wrapped function's arguments. - - This function should be passed the positional index of the - arguments that will be used to build the key. - """ - def _key(prefix, args): - parts = [args[x] for x in args_idx if args[x]] - if not parts: - return None - return prefix + key_sep + key_sep.join(parts) - return _key +def _rxcompile(s): + with _rxcache_lock: + if s not in _rxcache: + _rxcache[s] = re.compile('^%s$' % s.replace('.', r'\.').replace('*', '.*')) + return _rxcache[s] + + +def whitelisted(value, wl): + if not wl: + wl = DEFAULT_WHITELIST + for rx in wl: + if _rxcompile(rx).search(value): + return True + return False class RateLimit(object): @@ -102,27 +68,15 @@ class RateLimit(object): return result <= self.count -def ratelimit_http_request(key_fn, count=0, period=0): +def ratelimit_http_request(req, value, tag='', count=0, period=0): """Rate limit an HTTP request handler. - If the rate limit is triggered, the request will generate a HTTP - 503 error. + Returns False if the specified request triggers the rate limit. """ rl = RateLimit(count, period) - - def decoratorfn(fn): - @functools.wraps(fn) - def _ratelimit(*args, **kwargs): - if not current_app.config.get('ENABLE_RATELIMIT'): - return fn(*args, **kwargs) - key = _tostr(key_fn(fn.__name__, args)) - if key and not rl.check(current_app.memcache, key): - current_app.logger.debug('ratelimited: %s', key) - abort(503) - return fn(*args, **kwargs) - return _ratelimit - return decoratorfn + key = key_sep.join([req.method, req.path, value]) + return rl.check(current_app.memcache, key) class BlackList(object): @@ -145,37 +99,17 @@ class BlackList(object): mc.set(key, 'true', time=self.ttl) -def blacklist_on_auth_failure(key_fn, count=0, period=0, ttl=0, check_wl=False, - bl_return_value=protocol.ERR_AUTHENTICATION_FAILURE): - """Blacklist authentication failures. +class AuthBlackList(BlackList): - The wrapped function should return one of the error codes from - protocol.py. If it returns ERR_AUTHENTICATION_FAILURE more often - than the specified threshold, the request parameters are added to - the blacklist. + def is_blacklisted(self, tag, value): + if not value: + return False + key = key_sep.join([tag, value]) + return not self.check(current_app.memcache, key) - When a request is blacklisted, this function will return - ERR_AUTHENTICATION_FAILURE. - """ + def auth_failure(self, tag, value): + if not value: + return + key = key_sep.join([tag, value]) + self.incr(current_app.memcache, key) - bl = BlackList(count, period, ttl) - - def decoratorfn(fn): - @functools.wraps(fn) - def _blacklist(*args, **kwargs): - if not current_app.config.get('ENABLE_BLACKLIST'): - return fn(*args, **kwargs) - key = _tostr(key_fn(fn.__name__, args)) - if not key: - return fn(*args, **kwargs) - - if ((not check_wl or not whitelisted(key)) - and not bl.check(current_app.memcache, key)): - current_app.logger.info('blacklisted %s', key) - return bl_return_value - result = fn(*args, **kwargs) - if result != protocol.OK: - bl.incr(current_app.memcache, key) - return result - return _blacklist - return decoratorfn diff --git a/authserv/test/test_app_main.py b/authserv/test/test_app_main.py index 2ff5591..c637bf5 100644 --- a/authserv/test/test_app_main.py +++ b/authserv/test/test_app_main.py @@ -23,6 +23,7 @@ class ServerTest(unittest.TestCase): 'TESTING': True, 'DEBUG': True, 'ENABLE_BLACKLIST': True, + 'ENABLE_RATELIMIT': True, }) self.app = app.test_client() @@ -57,29 +58,29 @@ class ServerTest(unittest.TestCase): for i in xrange(n): self.users['user%d' % i] = FakeUser('user%d' % i, 'pass') - def disabledtest_ratelimit_by_client_ip(self): + def test_ratelimit_by_client_ip(self): n = 20 ok = 0 self._create_many_users(n) for i in xrange(n): response = self.app.post(URL, data={ - 'username': 'user%d' % i, - 'password': 'pass', - 'service': 'svc'}, headers={ - 'X-Forwarded-For': '1.2.3.4'}) + 'username': 'user%d' % i, + 'password': 'pass', + 'service': 'svc', + 'source_ip': '1.2.3.4'}) if response.status_code == 200: ok += 1 self.assertEquals(10, ok) - def disabledtest_ratelimit_by_username(self): + def test_ratelimit_by_username(self): n = 20 ok = 0 for i in xrange(n): response = self.app.post(URL, data={ - 'username': 'user', - 'password': 'pass', - 'service': 'svc'}, headers={ - 'X-Forwarded-For': '1.2.3.%d' %i}) + 'username': 'user', + 'password': 'pass', + 'service': 'svc', + 'source_ip': '1.2.3.%d' %i}) if response.status_code == 200: ok += 1 self.assertEquals(10, ok) @@ -87,7 +88,7 @@ class ServerTest(unittest.TestCase): def test_ratelimit_ignores_unset_fields(self): # This test will fail if the @ratelimit decorator does not # skip its check if one of the fields is unset (in this case, - # the X-Forwarded-For header). + # 'source_ip'). n = 20 ok = 0 self._create_many_users(n) diff --git a/authserv/test/test_ratelimit.py b/authserv/test/test_ratelimit.py index 546c354..0d74c1f 100644 --- a/authserv/test/test_ratelimit.py +++ b/authserv/test/test_ratelimit.py @@ -53,9 +53,9 @@ class BlackListTest(unittest.TestCase): class WhitelistTest(unittest.TestCase): - def test_whitelist(self): - self.assertTrue(whitelisted('auth/127.0.0.1')) - self.assertTrue(whitelisted('auth/172.16.1.5')) - self.assertTrue(whitelisted('auth/::1')) + def test_default_whitelist(self): + self.assertTrue(whitelisted('127.0.0.1', None)) + self.assertTrue(whitelisted('172.16.1.5', None)) + self.assertTrue(whitelisted('::1', None)) - self.assertFalse(whitelisted('auth/1.2.3.4')) + self.assertFalse(whitelisted('1.2.3.4', None)) -- GitLab