diff --git a/authserv/app_common.py b/authserv/app_common.py index 32b6c0d4f1aa425be99e95de92ce63211f827ae1..de8a4faae1d0c725eb46149670476188e97880a8 100644 --- a/authserv/app_common.py +++ b/authserv/app_common.py @@ -4,7 +4,8 @@ from authserv import auth from authserv import protocol from authserv.ratelimit import * -_blacklisted = (protocol.ERR_AUTHENTICATION_FAILURE, 'blacklisted', None) +_user_blacklisted = (protocol.ERR_AUTHENTICATION_FAILURE, 'user_blacklisted', None) +_ip_blacklisted = (protocol.ERR_AUTHENTICATION_FAILURE, 'ip_blacklisted', None) def check_ratelimit(request, username, source_ip): @@ -50,17 +51,25 @@ def _validate_username(username): def do_auth(username, service, shard, password, otp_token, source_ip, password_only=False): + with current_app.memcache.reserve() as mc: + return _do_auth(mc, username, service, shard, password, otp_token, + source_ip, password_only) + + +def _do_auth(mc, username, service, shard, password, otp_token, source_ip, + password_only): # Username must be an ASCII string. bl = AuthBlackList(current_app.config.get('BLACKLIST_COUNT', 5), current_app.config.get('BLACKLIST_PERIOD', 600), - current_app.config.get('BLACKLIST_TIME', 6*3600)) + current_app.config.get('BLACKLIST_TIME', 6*3600), + mc) if current_app.config.get('ENABLE_BLACKLIST'): if bl.is_blacklisted('u', username): - return _blacklisted + return _user_blacklisted if (source_ip and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST')) and bl.is_blacklisted('ip', source_ip)): - return _blacklisted + return _ip_blacklisted retval = protocol.ERR_AUTHENTICATION_FAILURE errmsg = 'user does not exist' diff --git a/authserv/ratelimit.py b/authserv/ratelimit.py index dcba833a11489d04722df9fffc355607450f7f4c..43529b61ea8c2c24533456df43768584b06ce4e8 100644 --- a/authserv/ratelimit.py +++ b/authserv/ratelimit.py @@ -76,13 +76,14 @@ def ratelimit_http_request(req, value, tag='', count=0, period=0): rl = RateLimit(count, period) key = key_sep.join([req.method, req.path, value]) - return rl.check(current_app.memcache, key) + with current_app.memcache.reserve() as mc: + return rl.check(mc, key) class BlackList(object): """Block requests once a rate limit is triggered.""" - prefix = 'authserv/b:' + prefix = 'authserv/b/' def __init__(self, count, period, ttl): self.rl = RateLimit(count, period) @@ -99,17 +100,20 @@ class BlackList(object): mc.set(key, 'true', time=self.ttl) -class AuthBlackList(BlackList): +class AuthBlackList(object): + + def __init__(self, count, period, ttl, mc): + self.blacklist = BlackList(count, period, ttl) + self.mc = mc def is_blacklisted(self, tag, value): if not value: return False key = key_sep.join([tag, value]) - return not self.check(current_app.memcache, key) + return not self.blacklist.check(self.mc, key) def auth_failure(self, tag, value): if not value: return key = key_sep.join([tag, value]) - self.incr(current_app.memcache, key) - + self.blacklist.incr(self.mc, key) diff --git a/authserv/server.py b/authserv/server.py index 28b27e6275ec55623866cdee905ea40b8efa1e8f..7a18c8618ad556d174dc0a7ff2a6bd90a2274a2b 100644 --- a/authserv/server.py +++ b/authserv/server.py @@ -7,6 +7,7 @@ import logging import logging.handlers import optparse import os +import pylibmc import signal import sys from authserv import app_main @@ -27,14 +28,9 @@ def create_app(app, userdb=None, mc=None): app.userdb = userdb if not mc: - try: - import pylibmc - mc = pylibmc.Client( - app.config['MEMCACHE_ADDR'], binary=True) - except ImportError: - import memcache - mc = memcache.Client( - app.config['MEMCACHE_ADDR'], debug=0) + client = pylibmc.Client( + app.config['MEMCACHE_ADDR'], binary=True) + mc = pylibmc.ThreadMappedPool(client) app.memcache = mc return app diff --git a/authserv/test/__init__.py b/authserv/test/__init__.py index 23264070a19df12f2239d92ce0db85bc30aca830..ef91c7e7b28d88df07344fde3a4af88cfa44ae12 100644 --- a/authserv/test/__init__.py +++ b/authserv/test/__init__.py @@ -1,3 +1,4 @@ +import contextlib import crypt import logging import os @@ -39,6 +40,16 @@ class FakeMemcache(object): return value +class FakeMemcachePool(object): + + def __init__(self, t=None): + self.mc = FakeMemcache(t) + + @contextlib.contextmanager + def reserve(self): + yield self.mc + + class FakeUser(model.User): def __init__(self, username, password=None, asps=None, otp_key=None, shard=None): diff --git a/authserv/test/test_app_main.py b/authserv/test/test_app_main.py index f0d26993fc25a629f45362416193d3735905bb3e..bb746259b145e99c5f243fb406a73f2676abfe61 100644 --- a/authserv/test/test_app_main.py +++ b/authserv/test/test_app_main.py @@ -21,7 +21,7 @@ class ServerTest(unittest.TestCase): } app = server.create_app(app_main.app, userdb=FakeUserDb(self.users), - mc=FakeMemcache(_time)) + mc=FakeMemcachePool(_time)) app.config.update({ 'TESTING': True, 'DEBUG': True, diff --git a/authserv/test/test_app_nginx.py b/authserv/test/test_app_nginx.py index 089f9cfe819a8e2d23f6d8f992b225feb4a9c3e0..1021f051a6f9182a0f80a6f852f5e36b2ef8c1da 100644 --- a/authserv/test/test_app_nginx.py +++ b/authserv/test/test_app_nginx.py @@ -17,7 +17,7 @@ class ServerTest(unittest.TestCase): } app = server.create_app(app_nginx.app, userdb=FakeUserDb(self.users), - mc=FakeMemcache(_time)) + mc=FakeMemcachePool(_time)) app.config.update({ 'TESTING': True, 'DEBUG': True, diff --git a/authserv/test/test_integration.py b/authserv/test/test_integration.py index 9104858c3d605f80947b87b488e79a4d7a20fe90..03203341d5c3e325afdb67579d90184f16af0c1e 100644 --- a/authserv/test/test_integration.py +++ b/authserv/test/test_integration.py @@ -63,7 +63,7 @@ class SSLServerTest(unittest.TestCase): def _runserver(): app = server.create_app(app_main.app, userdb=FakeUserDb(cls.users), - mc=FakeMemcache(time.time)) + mc=FakeMemcachePool(time.time)) app.config.update({ 'TESTING': True, 'DEBUG': True, diff --git a/authserv/test/test_ratelimit.py b/authserv/test/test_ratelimit.py index 0d74c1f2db357c474f0cd43fc2795b05bf45a31b..78536369dac73c08a8092389119ee849f4e23d68 100644 --- a/authserv/test/test_ratelimit.py +++ b/authserv/test/test_ratelimit.py @@ -8,25 +8,27 @@ class RateLimitTest(unittest.TestCase): self.tick = 0 def _time(): return self.tick - self.mc = FakeMemcache(_time) + self.mc = FakeMemcachePool(_time) self.rl = RateLimit(100, 10) def test_ratelimit_pass(self): n = 200 ok = 0 - for i in xrange(n): - self.tick = i - if self.rl.check(self.mc, 'key'): - ok += 1 + with self.mc.reserve() as mc: + for i in xrange(n): + self.tick = i + if self.rl.check(mc, 'key'): + ok += 1 self.assertEquals(n, ok) def test_ratelimit_fail(self): n = 200 ok = 0 - for i in xrange(n): - self.tick = i / 20 - if self.rl.check(self.mc, 'key'): - ok += 1 + with self.mc.reserve() as mc: + for i in xrange(n): + self.tick = i / 20 + if self.rl.check(mc, 'key'): + ok += 1 self.assertEquals(100, ok) diff --git a/setup.py b/setup.py index ce13fc5c57bb8bdfbc5669f69062c2360c327d11..5671e434a9713819cb2265b8ffbf58a8825601eb 100644 --- a/setup.py +++ b/setup.py @@ -4,12 +4,12 @@ from setuptools import setup, find_packages setup( name="authserv", - version="0.1.1", + version="0.1.2", description="Authentication server", author="Autistici/Inventati", author_email="info@autistici.org", url="https://git.autistici.org/ai/authserv", - install_requires=["gevent", "python-ldap", "Flask", "python-memcached"], + install_requires=["gevent", "python-ldap", "Flask", "pylibmc"], setup_requires=[], zip_safe=False, packages=find_packages(),