From 9f2bc780024cf669eb6394f7aea888de9e38a2e0 Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Sun, 14 Sep 2014 08:14:18 +0100
Subject: [PATCH] refactor ratelimiting and blacklists

Whitelisting source IPs now works properly.
---
 authserv/app_common.py          |  53 ++++++++++---
 authserv/app_main.py            |   9 +--
 authserv/app_nginx.py           |   4 +-
 authserv/ratelimit.py           | 132 ++++++++------------------------
 authserv/test/test_app_main.py  |  23 +++---
 authserv/test/test_ratelimit.py |  10 +--
 6 files changed, 100 insertions(+), 131 deletions(-)

diff --git a/authserv/app_common.py b/authserv/app_common.py
index a0db00f..2188aec 100644
--- a/authserv/app_common.py
+++ b/authserv/app_common.py
@@ -1,4 +1,4 @@
-from flask import current_app
+from flask import abort, current_app
 from authserv import auth
 from authserv import protocol
 from authserv.ratelimit import *
@@ -6,13 +6,48 @@ from authserv.ratelimit import *
 _blacklisted = (protocol.ERR_AUTHENTICATION_FAILURE, 'blacklisted', None)
 
 
-@blacklist_on_auth_failure(key_from_args(0), count=5, period=600, ttl=43200,
-                           bl_return_value=_blacklisted)
-@blacklist_on_auth_failure(key_from_args(5), count=5, period=600, ttl=43200,
-                           check_wl=True, bl_return_value=_blacklisted)
+def check_ratelimit(request, username, source_ip):
+    if current_app.config.get('ENABLE_RATELIMIT'):
+        if not ratelimit_http_request(
+                request, username, tag='u',
+                count=current_app.config.get('RATELIMIT_USER_COUNT', 10),
+                period=current_app.config.get('RATELIMIT_USER_PERIOD', 60)):
+            abort(503)
+        if (source_ip
+            and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST'))
+            and not ratelimit_http_request(
+                request, source_ip, tag='ip',
+                count=current_app.config.get('RATELIMIT_SOURCEIP_COUNT', 10),
+                period=current_app.config.get('RATELIMIT_SOURCEIP_PERIOD', 60))):
+            abort(503)
+
+
 def do_auth(username, service, shard, password, otp_token, source_ip):
+    bl = AuthBlackList(current_app.config.get('BLACKLIST_COUNT', 5),
+                       current_app.config.get('BLACKLIST_PERIOD', 600),
+                       current_app.config.get('BLACKLIST_TIME', 6*3600))
+    if current_app.config.get('ENABLE_BLACKLIST'):
+        if bl.is_blacklisted('u', username):
+            return _blacklisted
+        if (source_ip
+            and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST'))
+            and bl.is_blacklisted('ip', source_ip)):
+            return _blacklisted
+
+    retval = protocol.ERR_AUTHENTICATION_FAILURE
+    errmsg = 'user does not exist'
+    shard = None
     user = current_app.userdb.get_user(username, service, shard)
-    if not user:
-        return protocol.ERR_AUTHENTICATION_FAILURE, 'user does not exist', None
-    result, errmsg = auth.authenticate(user, service, password, otp_token, source_ip)
-    return (result, errmsg, user.get_shard())
+    if user:
+        retval, errmsg = auth.authenticate(
+            user, service, password, otp_token, source_ip)
+        shard = user.get_shard()
+
+    if retval != protocol.OK and current_app.config.get('ENABLE_BLACKLIST'):
+        if user:
+            bl.auth_failure('u', username)
+        if (source_ip
+            and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST'))):
+            bl.auth_failure('ip', source_ip)
+
+    return (retval, errmsg, shard)
diff --git a/authserv/app_main.py b/authserv/app_main.py
index 3ee6df6..922dc4a 100644
--- a/authserv/app_main.py
+++ b/authserv/app_main.py
@@ -1,8 +1,7 @@
 from flask import Flask, request, abort, make_response
 from authserv import auth
 from authserv import protocol
-from authserv.ratelimit import *
-from authserv.app_common import do_auth
+from authserv.app_common import do_auth, check_ratelimit
 
 app = Flask(__name__)
 
@@ -18,10 +17,6 @@ app = Flask(__name__)
 # address. So we're practically only going to use X-Forwarded-For for
 # requests that reach our frontends via HTTP.
 @app.route('/api/1/auth', methods=('POST',))
-@ratelimit_http_request(key_from_request(header='HTTP_X_FORWARDED_FOR'),
-                        count=10, period=60)
-@ratelimit_http_request(key_from_request(param='username'),
-                        count=10, period=60)
 def api_auth():
     service = request.form.get('service')
     username = request.form.get('username')
@@ -33,6 +28,8 @@ def api_auth():
     if not service or not username:
         abort(400)
 
+    check_ratelimit(request, username, source_ip)
+
     try:
         result, errmsg, unused_shard = do_auth(
             username, service, shard, password, otp_token, source_ip)
diff --git a/authserv/app_nginx.py b/authserv/app_nginx.py
index 923bb5c..905c3c9 100644
--- a/authserv/app_nginx.py
+++ b/authserv/app_nginx.py
@@ -2,7 +2,7 @@ import socket
 import threading
 import urllib
 from flask import Flask, request, abort, make_response
-from authserv.app_common import do_auth
+from authserv.app_common import do_auth, check_ratelimit
 
 app = Flask(__name__)
 
@@ -31,6 +31,8 @@ def do_nginx_http_auth():
     except ValueError:
         n_attempt = 1
 
+    check_ratelimit(request, username, source_ip)
+
     try:
         auth_status, errmsg, shard = do_auth(
             username, service, None, password, None, source_ip)
diff --git a/authserv/ratelimit.py b/authserv/ratelimit.py
index 9b0bb02..dcba833 100644
--- a/authserv/ratelimit.py
+++ b/authserv/ratelimit.py
@@ -1,21 +1,18 @@
 import functools
 import re
+import threading
 from flask import abort, request, current_app
 from authserv import protocol
 
 
 key_sep = '/'
 
-WHITELIST = [
+DEFAULT_WHITELIST = [
     '127.0.0.1',
     '::1',
     '172.16.1.*',
 ]
 
-_whitelist_rx = [
-    re.compile('%s%s$' % (key_sep, x.replace('.', r'\.').replace('*', '.*')))
-    for x in WHITELIST]
-
 
 def _tostr(s):
     if isinstance(s, unicode):
@@ -23,55 +20,24 @@ def _tostr(s):
     return s
 
 
-def whitelisted(ip):
-    for rx in _whitelist_rx:
-        if rx.search(ip):
-            return True
-    return False
+_rxcache = {}
+_rxcache_lock = threading.Lock()
 
 
-def key_from_request(header=None, param=None):
-    """Build a key from HTTP request headers and query arguments."""
-    if not header and not param:
-        raise Exception('Must set either header or param')
-
-    required_parts = 2
-    if header:
-        required_parts += 1
-    if param:
-        required_parts += 1
-
-    def _key(prefix, args):
-        parts = [request.method, request.path]
-        if header:
-            value = request.environ.get(header)
-            if value:
-                parts.append('%s=%s' % (header, value))
-        if param:
-            if request.method == 'POST':
-                value = request.form.get(param)
-            else:
-                value = request.args.get(param)
-            if value:
-                parts.append('%s=%s' % (param, value))
-        if len(parts) >= required_parts:
-            return prefix + key_sep + key_sep.join(parts)
-        return None
-    return _key
-
-
-def key_from_args(*args_idx):
-    """Build a key from the wrapped function's arguments.
-
-    This function should be passed the positional index of the
-    arguments that will be used to build the key.
-    """
-    def _key(prefix, args):
-        parts = [args[x] for x in args_idx if args[x]]
-        if not parts:
-            return None
-        return prefix + key_sep + key_sep.join(parts)
-    return _key
+def _rxcompile(s):
+    with _rxcache_lock:
+        if s not in _rxcache:
+            _rxcache[s] = re.compile('^%s$' % s.replace('.', r'\.').replace('*', '.*'))
+        return _rxcache[s]
+
+
+def whitelisted(value, wl):
+    if not wl:
+        wl = DEFAULT_WHITELIST
+    for rx in wl:
+        if _rxcompile(rx).search(value):
+            return True
+    return False
 
 
 class RateLimit(object):
@@ -102,27 +68,15 @@ class RateLimit(object):
         return result <= self.count
 
 
-def ratelimit_http_request(key_fn, count=0, period=0):
+def ratelimit_http_request(req, value, tag='', count=0, period=0):
     """Rate limit an HTTP request handler.
 
-    If the rate limit is triggered, the request will generate a HTTP
-    503 error.
+    Returns False if the specified request triggers the rate limit.
     """
 
     rl = RateLimit(count, period)
-
-    def decoratorfn(fn):
-        @functools.wraps(fn)
-        def _ratelimit(*args, **kwargs):
-            if not current_app.config.get('ENABLE_RATELIMIT'):
-                return fn(*args, **kwargs)
-            key = _tostr(key_fn(fn.__name__, args))
-            if key and not rl.check(current_app.memcache, key):
-                current_app.logger.debug('ratelimited: %s', key)
-                abort(503)
-            return fn(*args, **kwargs)
-        return _ratelimit
-    return decoratorfn
+    key = key_sep.join([req.method, req.path, value])
+    return rl.check(current_app.memcache, key)
 
 
 class BlackList(object):
@@ -145,37 +99,17 @@ class BlackList(object):
             mc.set(key, 'true', time=self.ttl)
 
 
-def blacklist_on_auth_failure(key_fn, count=0, period=0, ttl=0, check_wl=False,
-                              bl_return_value=protocol.ERR_AUTHENTICATION_FAILURE):
-    """Blacklist authentication failures.
+class AuthBlackList(BlackList):
 
-    The wrapped function should return one of the error codes from
-    protocol.py. If it returns ERR_AUTHENTICATION_FAILURE more often
-    than the specified threshold, the request parameters are added to
-    the blacklist.
+    def is_blacklisted(self, tag, value):
+        if not value:
+            return False
+        key = key_sep.join([tag, value])
+        return not self.check(current_app.memcache, key)
 
-    When a request is blacklisted, this function will return
-    ERR_AUTHENTICATION_FAILURE.
-    """
+    def auth_failure(self, tag, value):
+        if not value:
+            return
+        key = key_sep.join([tag, value])
+        self.incr(current_app.memcache, key)
 
-    bl = BlackList(count, period, ttl)
-
-    def decoratorfn(fn):
-        @functools.wraps(fn)
-        def _blacklist(*args, **kwargs):
-            if not current_app.config.get('ENABLE_BLACKLIST'):
-                return fn(*args, **kwargs)
-            key = _tostr(key_fn(fn.__name__, args))
-            if not key:
-                return fn(*args, **kwargs)
-
-            if ((not check_wl or not whitelisted(key))
-                and not bl.check(current_app.memcache, key)):
-                current_app.logger.info('blacklisted %s', key)
-                return bl_return_value
-            result = fn(*args, **kwargs)
-            if result != protocol.OK:
-                bl.incr(current_app.memcache, key)
-            return result
-        return _blacklist
-    return decoratorfn
diff --git a/authserv/test/test_app_main.py b/authserv/test/test_app_main.py
index 2ff5591..c637bf5 100644
--- a/authserv/test/test_app_main.py
+++ b/authserv/test/test_app_main.py
@@ -23,6 +23,7 @@ class ServerTest(unittest.TestCase):
                 'TESTING': True,
                 'DEBUG': True,
                 'ENABLE_BLACKLIST': True,
+                'ENABLE_RATELIMIT': True,
                 })
         self.app = app.test_client()
 
@@ -57,29 +58,29 @@ class ServerTest(unittest.TestCase):
         for i in xrange(n):
             self.users['user%d' % i] = FakeUser('user%d' % i, 'pass')
 
-    def disabledtest_ratelimit_by_client_ip(self):
+    def test_ratelimit_by_client_ip(self):
         n = 20
         ok = 0
         self._create_many_users(n)
         for i in xrange(n):
             response = self.app.post(URL, data={
-                    'username': 'user%d' % i,
-                    'password': 'pass',
-                    'service': 'svc'}, headers={
-                    'X-Forwarded-For': '1.2.3.4'})
+                'username': 'user%d' % i,
+                'password': 'pass',
+                'service': 'svc',
+                'source_ip': '1.2.3.4'})
             if response.status_code == 200:
                 ok += 1
         self.assertEquals(10, ok)
 
-    def disabledtest_ratelimit_by_username(self):
+    def test_ratelimit_by_username(self):
         n = 20
         ok = 0
         for i in xrange(n):
             response = self.app.post(URL, data={
-                    'username': 'user',
-                    'password': 'pass',
-                    'service': 'svc'}, headers={
-                    'X-Forwarded-For': '1.2.3.%d' %i})
+                'username': 'user',
+                'password': 'pass',
+                'service': 'svc',
+                'source_ip': '1.2.3.%d' %i})
             if response.status_code == 200:
                 ok += 1
         self.assertEquals(10, ok)
@@ -87,7 +88,7 @@ class ServerTest(unittest.TestCase):
     def test_ratelimit_ignores_unset_fields(self):
         # This test will fail if the @ratelimit decorator does not
         # skip its check if one of the fields is unset (in this case,
-        # the X-Forwarded-For header).
+        # 'source_ip').
         n = 20
         ok = 0
         self._create_many_users(n)
diff --git a/authserv/test/test_ratelimit.py b/authserv/test/test_ratelimit.py
index 546c354..0d74c1f 100644
--- a/authserv/test/test_ratelimit.py
+++ b/authserv/test/test_ratelimit.py
@@ -53,9 +53,9 @@ class BlackListTest(unittest.TestCase):
 
 class WhitelistTest(unittest.TestCase):
 
-    def test_whitelist(self):
-        self.assertTrue(whitelisted('auth/127.0.0.1'))
-        self.assertTrue(whitelisted('auth/172.16.1.5'))
-        self.assertTrue(whitelisted('auth/::1'))
+    def test_default_whitelist(self):
+        self.assertTrue(whitelisted('127.0.0.1', None))
+        self.assertTrue(whitelisted('172.16.1.5', None))
+        self.assertTrue(whitelisted('::1', None))
 
-        self.assertFalse(whitelisted('auth/1.2.3.4'))
+        self.assertFalse(whitelisted('1.2.3.4', None))
-- 
GitLab