From eab7d283a3c3117a37d96a0d33142f6ce6b27354 Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Tue, 18 Mar 2014 08:41:02 +0000
Subject: [PATCH] refactored ratelimiting code

---
 authserv/ratelimit.py           | 119 +++++++++++++++++++++++---------
 authserv/server.py              |  12 ++--
 authserv/test/__init__.py       |   3 +
 authserv/test/test_ratelimit.py |  21 ++++++
 4 files changed, 117 insertions(+), 38 deletions(-)

diff --git a/authserv/ratelimit.py b/authserv/ratelimit.py
index abcb017..e438b54 100644
--- a/authserv/ratelimit.py
+++ b/authserv/ratelimit.py
@@ -4,9 +4,57 @@ from authserv import app
 from authserv import protocol
 
 
+key_sep = ':'
+
+
+def key_from_request(header=None, param=None):
+    """Build a key from HTTP request headers and query arguments."""
+    if not header and not param:
+        raise Exception('Must set either header or param')
+
+    required_parts = 2
+    if header:
+        required_parts += 1
+    if param:
+        required_parts += 1
+
+    def _key(prefix, args):
+        parts = [request.method, request.path]
+        if header:
+            value = request.environ.get(header)
+            if value:
+                parts.append('%s=%s' % (header, value))
+        if param:
+            if request.method == 'POST':
+                value = request.form.get(param)
+            else:
+                value = request.args.get(param)
+            if value:
+                parts.append('%s=%s' % (param, value))
+        if len(parts) >= required_parts:
+            return prefix + key_sep + key_sep.join(parts)
+        return None
+    return _key
+
+
+def key_from_args(*args_idx):
+    """Build a key from the wrapped function's arguments.
+
+    This function should be passed the positional index of the
+    arguments that will be used to build the key.
+    """
+    def _key(prefix, args):
+        parts = [args[x] for x in args_idx if args[x]]
+        if not parts:
+            return None
+        return prefix + key_sep + key_sep.join(parts)
+    return _key
+
+
 class RateLimit(object):
+    """Keep track of request rates using Memcache."""
 
-    prefix = 'authserv/'
+    prefix = 'authserv/r:'
 
     def __init__(self, count, period):
         self.count = count
@@ -14,52 +62,48 @@ class RateLimit(object):
 
     def check(self, mc, key):
         key = self.prefix + key
-        result = mc.incr(key)
+        try:
+            result = mc.incr(key)
+        except:
+            result = None
         if result is None:
             result = 1
             if not mc.add(key, result, time=self.period):
-                result = mc.incr(key)
+                try:
+                    result = mc.incr(key)
+                except:
+                    result = None
                 if result is None:
                     # Memcache is failing.
                     return True
         return result <= self.count
 
 
-def ratelimit(header=None, param=None, count=0, period=0):
-    if not header and not param:
-        raise Exception('Must set either header or param')
+def ratelimit_http_request(key_fn, count=0, period=0):
+    """Rate limit an HTTP request handler.
+
+    If the rate limit is triggered, the request will generate a HTTP
+    503 error.
+    """
+
     rl = RateLimit(count, period)
-    required_parts = 2
-    if header:
-        required_parts += 1
-    if param:
-        required_parts += 1
 
     def decoratorfn(fn):
         @functools.wraps(fn)
         def _ratelimit(*args, **kwargs):
-            parts = [request.method, request.path]
-            if header:
-                value = request.environ.get(header)
-                if value:
-                    parts.append('%s=%s' % (header, value))
-            if param:
-                value = request.form.get(param)
-                if value:
-                    parts.append('%s=%s' % (param, value))
-            if len(parts) >= required_parts:
-                key = ' '.join(parts)
-                if not rl.check(app.memcache, key):
-                    app.logger.debug('ratelimited: %s', key)
-                    abort(503)
+            key = key_fn(fn.__name__, args)
+            if key and not rl.check(app.memcache, key):
+                app.logger.debug('ratelimited: %s', key)
+                abort(503)
             return fn(*args, **kwargs)
         return _ratelimit
     return decoratorfn
 
 
 class BlackList(object):
+    """Block requests once a rate limit is triggered."""
 
-    prefix = 'bl/'
+    prefix = 'authserv/b:'
 
     def __init__(self, count, period, ttl):
         self.rl = RateLimit(count, period)
@@ -73,21 +117,30 @@ class BlackList(object):
     def incr(self, mc, key):
         key = self.prefix + key
         if not self.rl.check(mc, key):
-            mc.add(key, 1, time=self.ttl)
+            mc.set(key, 'true', time=self.ttl)
+
 
+def blacklist_on_auth_failure(key_fn, count=0, period=0, ttl=0):
+    """Blacklist authentication failures.
+
+    The wrapped function should return one of the error codes from
+    protocol.py. If it returns ERR_AUTHENTICATION_FAILURE more often
+    than the specified threshold, the request parameters are added to
+    the blacklist.
+
+    When a request is blacklisted, this function will return
+    ERR_AUTHENTICATION_FAILURE.
+    """
 
-def blacklist(args_idx=None, count=0, period=0, ttl=0):
-    if not args_idx:
-        raise Exception('Must set args_idx')
     bl = BlackList(count, period, ttl)
 
     def decoratorfn(fn):
         @functools.wraps(fn)
         def _blacklist(*args, **kwargs):
-            parts = [args[x] for x in args_idx if args[x]]
-            if not parts:
+            key = key_fn(fn.__name__, args)
+            if not key:
                 return fn(*args, **kwargs)
-            key = ' '.join(parts)
+
             if not bl.check(app.memcache, key):
                 app.logger.debug('blacklisted %s', key)
                 return protocol.ERR_AUTHENTICATION_FAILURE
diff --git a/authserv/server.py b/authserv/server.py
index 2addcef..c5be265 100644
--- a/authserv/server.py
+++ b/authserv/server.py
@@ -1,12 +1,12 @@
 from authserv import app
 from authserv import auth
 from authserv import protocol
-from authserv.ratelimit import blacklist, ratelimit
+from authserv.ratelimit import *
 from flask import Flask, request, abort, make_response
 
 
-@blacklist([0], count=5, period=600, ttl=43200)
-@blacklist([4], count=5, period=600, ttl=43200)
+@blacklist_on_auth_failure(key_from_args(0), count=5, period=600, ttl=43200)
+@blacklist_on_auth_failure(key_from_args(4), count=5, period=600, ttl=43200)
 def _auth(username, service, password, otp_token, source_ip):
     user = app.userdb.get_user(username, service)
     if not user:
@@ -15,8 +15,10 @@ def _auth(username, service, password, otp_token, source_ip):
 
 
 @app.route('/api/1/auth', methods=('POST',))
-@ratelimit(count=10, period=60, header='HTTP_X_FORWARDED_FOR')
-@ratelimit(count=10, period=60, param='username')
+@ratelimit_http_request(key_from_request(header='HTTP_X_FORWARDED_FOR'),
+                        count=10, period=60)
+@ratelimit_http_request(key_from_request(param='username'),
+                        count=10, period=60)
 def do_auth():
     service = request.form.get('service')
     username = request.form.get('username')
diff --git a/authserv/test/__init__.py b/authserv/test/__init__.py
index d05b74e..8ef8f36 100644
--- a/authserv/test/__init__.py
+++ b/authserv/test/__init__.py
@@ -19,6 +19,9 @@ class FakeMemcache(object):
         if result and result[1] > self.t():
             return result[0]
 
+    def set(self, key, value, time=0):
+        self.data[key] = (value, self.t() + time)
+
     def incr(self, key):
         now = self.t()
         result = self.data.get(key)
diff --git a/authserv/test/test_ratelimit.py b/authserv/test/test_ratelimit.py
index c6030ba..120d26c 100644
--- a/authserv/test/test_ratelimit.py
+++ b/authserv/test/test_ratelimit.py
@@ -28,3 +28,24 @@ class RateLimitTest(unittest.TestCase):
             if self.rl.check(self.mc, 'key'):
                 ok += 1
         self.assertEquals(100, ok)
+
+
+class BlackListTest(unittest.TestCase):
+
+    def setUp(self):
+        self.tick = 0
+        def _time():
+            return self.tick
+        self.mc = FakeMemcache(_time)
+        self.bl = BlackList(100, 10, 3600)
+
+    def test_block(self):
+        for i in xrange(100):
+            self.bl.incr(self.mc, 'key')
+            self.assertTrue(self.bl.check(self.mc, 'key'))
+        # 101th request is blocked
+        self.bl.incr(self.mc, 'key')
+        self.assertFalse(self.bl.check(self.mc, 'key'))
+        # check expiration of the block
+        self.tick = 3601
+        self.assertTrue(self.bl.check(self.mc, 'key'))
-- 
GitLab