Skip to content
Snippets Groups Projects
Commit 9f2bc780 authored by ale's avatar ale
Browse files

refactor ratelimiting and blacklists

Whitelisting source IPs now works properly.
parent 10b9aa55
No related branches found
No related tags found
No related merge requests found
from flask import current_app
from flask import abort, current_app
from authserv import auth
from authserv import protocol
from authserv.ratelimit import *
......@@ -6,13 +6,48 @@ from authserv.ratelimit import *
_blacklisted = (protocol.ERR_AUTHENTICATION_FAILURE, 'blacklisted', None)
@blacklist_on_auth_failure(key_from_args(0), count=5, period=600, ttl=43200,
bl_return_value=_blacklisted)
@blacklist_on_auth_failure(key_from_args(5), count=5, period=600, ttl=43200,
check_wl=True, bl_return_value=_blacklisted)
def check_ratelimit(request, username, source_ip):
if current_app.config.get('ENABLE_RATELIMIT'):
if not ratelimit_http_request(
request, username, tag='u',
count=current_app.config.get('RATELIMIT_USER_COUNT', 10),
period=current_app.config.get('RATELIMIT_USER_PERIOD', 60)):
abort(503)
if (source_ip
and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST'))
and not ratelimit_http_request(
request, source_ip, tag='ip',
count=current_app.config.get('RATELIMIT_SOURCEIP_COUNT', 10),
period=current_app.config.get('RATELIMIT_SOURCEIP_PERIOD', 60))):
abort(503)
def do_auth(username, service, shard, password, otp_token, source_ip):
bl = AuthBlackList(current_app.config.get('BLACKLIST_COUNT', 5),
current_app.config.get('BLACKLIST_PERIOD', 600),
current_app.config.get('BLACKLIST_TIME', 6*3600))
if current_app.config.get('ENABLE_BLACKLIST'):
if bl.is_blacklisted('u', username):
return _blacklisted
if (source_ip
and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST'))
and bl.is_blacklisted('ip', source_ip)):
return _blacklisted
retval = protocol.ERR_AUTHENTICATION_FAILURE
errmsg = 'user does not exist'
shard = None
user = current_app.userdb.get_user(username, service, shard)
if not user:
return protocol.ERR_AUTHENTICATION_FAILURE, 'user does not exist', None
result, errmsg = auth.authenticate(user, service, password, otp_token, source_ip)
return (result, errmsg, user.get_shard())
if user:
retval, errmsg = auth.authenticate(
user, service, password, otp_token, source_ip)
shard = user.get_shard()
if retval != protocol.OK and current_app.config.get('ENABLE_BLACKLIST'):
if user:
bl.auth_failure('u', username)
if (source_ip
and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST'))):
bl.auth_failure('ip', source_ip)
return (retval, errmsg, shard)
from flask import Flask, request, abort, make_response
from authserv import auth
from authserv import protocol
from authserv.ratelimit import *
from authserv.app_common import do_auth
from authserv.app_common import do_auth, check_ratelimit
app = Flask(__name__)
......@@ -18,10 +17,6 @@ app = Flask(__name__)
# address. So we're practically only going to use X-Forwarded-For for
# requests that reach our frontends via HTTP.
@app.route('/api/1/auth', methods=('POST',))
@ratelimit_http_request(key_from_request(header='HTTP_X_FORWARDED_FOR'),
count=10, period=60)
@ratelimit_http_request(key_from_request(param='username'),
count=10, period=60)
def api_auth():
service = request.form.get('service')
username = request.form.get('username')
......@@ -33,6 +28,8 @@ def api_auth():
if not service or not username:
abort(400)
check_ratelimit(request, username, source_ip)
try:
result, errmsg, unused_shard = do_auth(
username, service, shard, password, otp_token, source_ip)
......
......@@ -2,7 +2,7 @@ import socket
import threading
import urllib
from flask import Flask, request, abort, make_response
from authserv.app_common import do_auth
from authserv.app_common import do_auth, check_ratelimit
app = Flask(__name__)
......@@ -31,6 +31,8 @@ def do_nginx_http_auth():
except ValueError:
n_attempt = 1
check_ratelimit(request, username, source_ip)
try:
auth_status, errmsg, shard = do_auth(
username, service, None, password, None, source_ip)
......
import functools
import re
import threading
from flask import abort, request, current_app
from authserv import protocol
key_sep = '/'
WHITELIST = [
DEFAULT_WHITELIST = [
'127.0.0.1',
'::1',
'172.16.1.*',
]
_whitelist_rx = [
re.compile('%s%s$' % (key_sep, x.replace('.', r'\.').replace('*', '.*')))
for x in WHITELIST]
def _tostr(s):
if isinstance(s, unicode):
......@@ -23,55 +20,24 @@ def _tostr(s):
return s
def whitelisted(ip):
for rx in _whitelist_rx:
if rx.search(ip):
return True
return False
_rxcache = {}
_rxcache_lock = threading.Lock()
def key_from_request(header=None, param=None):
"""Build a key from HTTP request headers and query arguments."""
if not header and not param:
raise Exception('Must set either header or param')
required_parts = 2
if header:
required_parts += 1
if param:
required_parts += 1
def _key(prefix, args):
parts = [request.method, request.path]
if header:
value = request.environ.get(header)
if value:
parts.append('%s=%s' % (header, value))
if param:
if request.method == 'POST':
value = request.form.get(param)
else:
value = request.args.get(param)
if value:
parts.append('%s=%s' % (param, value))
if len(parts) >= required_parts:
return prefix + key_sep + key_sep.join(parts)
return None
return _key
def key_from_args(*args_idx):
"""Build a key from the wrapped function's arguments.
This function should be passed the positional index of the
arguments that will be used to build the key.
"""
def _key(prefix, args):
parts = [args[x] for x in args_idx if args[x]]
if not parts:
return None
return prefix + key_sep + key_sep.join(parts)
return _key
def _rxcompile(s):
with _rxcache_lock:
if s not in _rxcache:
_rxcache[s] = re.compile('^%s$' % s.replace('.', r'\.').replace('*', '.*'))
return _rxcache[s]
def whitelisted(value, wl):
if not wl:
wl = DEFAULT_WHITELIST
for rx in wl:
if _rxcompile(rx).search(value):
return True
return False
class RateLimit(object):
......@@ -102,27 +68,15 @@ class RateLimit(object):
return result <= self.count
def ratelimit_http_request(key_fn, count=0, period=0):
def ratelimit_http_request(req, value, tag='', count=0, period=0):
"""Rate limit an HTTP request handler.
If the rate limit is triggered, the request will generate a HTTP
503 error.
Returns False if the specified request triggers the rate limit.
"""
rl = RateLimit(count, period)
def decoratorfn(fn):
@functools.wraps(fn)
def _ratelimit(*args, **kwargs):
if not current_app.config.get('ENABLE_RATELIMIT'):
return fn(*args, **kwargs)
key = _tostr(key_fn(fn.__name__, args))
if key and not rl.check(current_app.memcache, key):
current_app.logger.debug('ratelimited: %s', key)
abort(503)
return fn(*args, **kwargs)
return _ratelimit
return decoratorfn
key = key_sep.join([req.method, req.path, value])
return rl.check(current_app.memcache, key)
class BlackList(object):
......@@ -145,37 +99,17 @@ class BlackList(object):
mc.set(key, 'true', time=self.ttl)
def blacklist_on_auth_failure(key_fn, count=0, period=0, ttl=0, check_wl=False,
bl_return_value=protocol.ERR_AUTHENTICATION_FAILURE):
"""Blacklist authentication failures.
class AuthBlackList(BlackList):
The wrapped function should return one of the error codes from
protocol.py. If it returns ERR_AUTHENTICATION_FAILURE more often
than the specified threshold, the request parameters are added to
the blacklist.
def is_blacklisted(self, tag, value):
if not value:
return False
key = key_sep.join([tag, value])
return not self.check(current_app.memcache, key)
When a request is blacklisted, this function will return
ERR_AUTHENTICATION_FAILURE.
"""
def auth_failure(self, tag, value):
if not value:
return
key = key_sep.join([tag, value])
self.incr(current_app.memcache, key)
bl = BlackList(count, period, ttl)
def decoratorfn(fn):
@functools.wraps(fn)
def _blacklist(*args, **kwargs):
if not current_app.config.get('ENABLE_BLACKLIST'):
return fn(*args, **kwargs)
key = _tostr(key_fn(fn.__name__, args))
if not key:
return fn(*args, **kwargs)
if ((not check_wl or not whitelisted(key))
and not bl.check(current_app.memcache, key)):
current_app.logger.info('blacklisted %s', key)
return bl_return_value
result = fn(*args, **kwargs)
if result != protocol.OK:
bl.incr(current_app.memcache, key)
return result
return _blacklist
return decoratorfn
......@@ -23,6 +23,7 @@ class ServerTest(unittest.TestCase):
'TESTING': True,
'DEBUG': True,
'ENABLE_BLACKLIST': True,
'ENABLE_RATELIMIT': True,
})
self.app = app.test_client()
......@@ -57,7 +58,7 @@ class ServerTest(unittest.TestCase):
for i in xrange(n):
self.users['user%d' % i] = FakeUser('user%d' % i, 'pass')
def disabledtest_ratelimit_by_client_ip(self):
def test_ratelimit_by_client_ip(self):
n = 20
ok = 0
self._create_many_users(n)
......@@ -65,21 +66,21 @@ class ServerTest(unittest.TestCase):
response = self.app.post(URL, data={
'username': 'user%d' % i,
'password': 'pass',
'service': 'svc'}, headers={
'X-Forwarded-For': '1.2.3.4'})
'service': 'svc',
'source_ip': '1.2.3.4'})
if response.status_code == 200:
ok += 1
self.assertEquals(10, ok)
def disabledtest_ratelimit_by_username(self):
def test_ratelimit_by_username(self):
n = 20
ok = 0
for i in xrange(n):
response = self.app.post(URL, data={
'username': 'user',
'password': 'pass',
'service': 'svc'}, headers={
'X-Forwarded-For': '1.2.3.%d' %i})
'service': 'svc',
'source_ip': '1.2.3.%d' %i})
if response.status_code == 200:
ok += 1
self.assertEquals(10, ok)
......@@ -87,7 +88,7 @@ class ServerTest(unittest.TestCase):
def test_ratelimit_ignores_unset_fields(self):
# This test will fail if the @ratelimit decorator does not
# skip its check if one of the fields is unset (in this case,
# the X-Forwarded-For header).
# 'source_ip').
n = 20
ok = 0
self._create_many_users(n)
......
......@@ -53,9 +53,9 @@ class BlackListTest(unittest.TestCase):
class WhitelistTest(unittest.TestCase):
def test_whitelist(self):
self.assertTrue(whitelisted('auth/127.0.0.1'))
self.assertTrue(whitelisted('auth/172.16.1.5'))
self.assertTrue(whitelisted('auth/::1'))
def test_default_whitelist(self):
self.assertTrue(whitelisted('127.0.0.1', None))
self.assertTrue(whitelisted('172.16.1.5', None))
self.assertTrue(whitelisted('::1', None))
self.assertFalse(whitelisted('auth/1.2.3.4'))
self.assertFalse(whitelisted('1.2.3.4', None))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment