Skip to content
Snippets Groups Projects
Commit eab7d283 authored by ale's avatar ale
Browse files

refactored ratelimiting code

parent 27d30c44
No related branches found
No related tags found
No related merge requests found
......@@ -4,9 +4,57 @@ from authserv import app
from authserv import protocol
key_sep = ':'
def key_from_request(header=None, param=None):
"""Build a key from HTTP request headers and query arguments."""
if not header and not param:
raise Exception('Must set either header or param')
required_parts = 2
if header:
required_parts += 1
if param:
required_parts += 1
def _key(prefix, args):
parts = [request.method, request.path]
if header:
value = request.environ.get(header)
if value:
parts.append('%s=%s' % (header, value))
if param:
if request.method == 'POST':
value = request.form.get(param)
else:
value = request.args.get(param)
if value:
parts.append('%s=%s' % (param, value))
if len(parts) >= required_parts:
return prefix + key_sep + key_sep.join(parts)
return None
return _key
def key_from_args(*args_idx):
"""Build a key from the wrapped function's arguments.
This function should be passed the positional index of the
arguments that will be used to build the key.
"""
def _key(prefix, args):
parts = [args[x] for x in args_idx if args[x]]
if not parts:
return None
return prefix + key_sep + key_sep.join(parts)
return _key
class RateLimit(object):
"""Keep track of request rates using Memcache."""
prefix = 'authserv/'
prefix = 'authserv/r:'
def __init__(self, count, period):
self.count = count
......@@ -14,52 +62,48 @@ class RateLimit(object):
def check(self, mc, key):
key = self.prefix + key
result = mc.incr(key)
try:
result = mc.incr(key)
except:
result = None
if result is None:
result = 1
if not mc.add(key, result, time=self.period):
result = mc.incr(key)
try:
result = mc.incr(key)
except:
result = None
if result is None:
# Memcache is failing.
return True
return result <= self.count
def ratelimit(header=None, param=None, count=0, period=0):
if not header and not param:
raise Exception('Must set either header or param')
def ratelimit_http_request(key_fn, count=0, period=0):
"""Rate limit an HTTP request handler.
If the rate limit is triggered, the request will generate a HTTP
503 error.
"""
rl = RateLimit(count, period)
required_parts = 2
if header:
required_parts += 1
if param:
required_parts += 1
def decoratorfn(fn):
@functools.wraps(fn)
def _ratelimit(*args, **kwargs):
parts = [request.method, request.path]
if header:
value = request.environ.get(header)
if value:
parts.append('%s=%s' % (header, value))
if param:
value = request.form.get(param)
if value:
parts.append('%s=%s' % (param, value))
if len(parts) >= required_parts:
key = ' '.join(parts)
if not rl.check(app.memcache, key):
app.logger.debug('ratelimited: %s', key)
abort(503)
key = key_fn(fn.__name__, args)
if key and not rl.check(app.memcache, key):
app.logger.debug('ratelimited: %s', key)
abort(503)
return fn(*args, **kwargs)
return _ratelimit
return decoratorfn
class BlackList(object):
"""Block requests once a rate limit is triggered."""
prefix = 'bl/'
prefix = 'authserv/b:'
def __init__(self, count, period, ttl):
self.rl = RateLimit(count, period)
......@@ -73,21 +117,30 @@ class BlackList(object):
def incr(self, mc, key):
key = self.prefix + key
if not self.rl.check(mc, key):
mc.add(key, 1, time=self.ttl)
mc.set(key, 'true', time=self.ttl)
def blacklist_on_auth_failure(key_fn, count=0, period=0, ttl=0):
"""Blacklist authentication failures.
The wrapped function should return one of the error codes from
protocol.py. If it returns ERR_AUTHENTICATION_FAILURE more often
than the specified threshold, the request parameters are added to
the blacklist.
When a request is blacklisted, this function will return
ERR_AUTHENTICATION_FAILURE.
"""
def blacklist(args_idx=None, count=0, period=0, ttl=0):
if not args_idx:
raise Exception('Must set args_idx')
bl = BlackList(count, period, ttl)
def decoratorfn(fn):
@functools.wraps(fn)
def _blacklist(*args, **kwargs):
parts = [args[x] for x in args_idx if args[x]]
if not parts:
key = key_fn(fn.__name__, args)
if not key:
return fn(*args, **kwargs)
key = ' '.join(parts)
if not bl.check(app.memcache, key):
app.logger.debug('blacklisted %s', key)
return protocol.ERR_AUTHENTICATION_FAILURE
......
from authserv import app
from authserv import auth
from authserv import protocol
from authserv.ratelimit import blacklist, ratelimit
from authserv.ratelimit import *
from flask import Flask, request, abort, make_response
@blacklist([0], count=5, period=600, ttl=43200)
@blacklist([4], count=5, period=600, ttl=43200)
@blacklist_on_auth_failure(key_from_args(0), count=5, period=600, ttl=43200)
@blacklist_on_auth_failure(key_from_args(4), count=5, period=600, ttl=43200)
def _auth(username, service, password, otp_token, source_ip):
user = app.userdb.get_user(username, service)
if not user:
......@@ -15,8 +15,10 @@ def _auth(username, service, password, otp_token, source_ip):
@app.route('/api/1/auth', methods=('POST',))
@ratelimit(count=10, period=60, header='HTTP_X_FORWARDED_FOR')
@ratelimit(count=10, period=60, param='username')
@ratelimit_http_request(key_from_request(header='HTTP_X_FORWARDED_FOR'),
count=10, period=60)
@ratelimit_http_request(key_from_request(param='username'),
count=10, period=60)
def do_auth():
service = request.form.get('service')
username = request.form.get('username')
......
......@@ -19,6 +19,9 @@ class FakeMemcache(object):
if result and result[1] > self.t():
return result[0]
def set(self, key, value, time=0):
self.data[key] = (value, self.t() + time)
def incr(self, key):
now = self.t()
result = self.data.get(key)
......
......@@ -28,3 +28,24 @@ class RateLimitTest(unittest.TestCase):
if self.rl.check(self.mc, 'key'):
ok += 1
self.assertEquals(100, ok)
class BlackListTest(unittest.TestCase):
def setUp(self):
self.tick = 0
def _time():
return self.tick
self.mc = FakeMemcache(_time)
self.bl = BlackList(100, 10, 3600)
def test_block(self):
for i in xrange(100):
self.bl.incr(self.mc, 'key')
self.assertTrue(self.bl.check(self.mc, 'key'))
# 101th request is blocked
self.bl.incr(self.mc, 'key')
self.assertFalse(self.bl.check(self.mc, 'key'))
# check expiration of the block
self.tick = 3601
self.assertTrue(self.bl.check(self.mc, 'key'))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment