From a2dac15b1adaa4ea426dc1a1f9a750bb0e8ded4a Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Thu, 6 Jul 2017 21:44:08 +0100
Subject: [PATCH] Use a ThreadMappedPool for memcache connections

Now requires pylibmc.
---
 authserv/app_common.py            | 17 +++++++++++++----
 authserv/ratelimit.py             | 16 ++++++++++------
 authserv/server.py                | 12 ++++--------
 authserv/test/__init__.py         | 11 +++++++++++
 authserv/test/test_app_main.py    |  2 +-
 authserv/test/test_app_nginx.py   |  2 +-
 authserv/test/test_integration.py |  2 +-
 authserv/test/test_ratelimit.py   | 20 +++++++++++---------
 setup.py                          |  4 ++--
 9 files changed, 54 insertions(+), 32 deletions(-)

diff --git a/authserv/app_common.py b/authserv/app_common.py
index 32b6c0d..de8a4fa 100644
--- a/authserv/app_common.py
+++ b/authserv/app_common.py
@@ -4,7 +4,8 @@ from authserv import auth
 from authserv import protocol
 from authserv.ratelimit import *
 
-_blacklisted = (protocol.ERR_AUTHENTICATION_FAILURE, 'blacklisted', None)
+_user_blacklisted = (protocol.ERR_AUTHENTICATION_FAILURE, 'user_blacklisted', None)
+_ip_blacklisted = (protocol.ERR_AUTHENTICATION_FAILURE, 'ip_blacklisted', None)
 
 
 def check_ratelimit(request, username, source_ip):
@@ -50,17 +51,25 @@ def _validate_username(username):
 
 def do_auth(username, service, shard, password, otp_token, source_ip,
             password_only=False):
+    with current_app.memcache.reserve() as mc:
+        return _do_auth(mc, username, service, shard, password, otp_token,
+                        source_ip, password_only)
+
+
+def _do_auth(mc, username, service, shard, password, otp_token, source_ip,
+            password_only):
     # Username must be an ASCII string.
     bl = AuthBlackList(current_app.config.get('BLACKLIST_COUNT', 5),
                        current_app.config.get('BLACKLIST_PERIOD', 600),
-                       current_app.config.get('BLACKLIST_TIME', 6*3600))
+                       current_app.config.get('BLACKLIST_TIME', 6*3600),
+                       mc)
     if current_app.config.get('ENABLE_BLACKLIST'):
         if bl.is_blacklisted('u', username):
-            return _blacklisted
+            return _user_blacklisted
         if (source_ip
             and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST'))
             and bl.is_blacklisted('ip', source_ip)):
-            return _blacklisted
+            return _ip_blacklisted
 
     retval = protocol.ERR_AUTHENTICATION_FAILURE
     errmsg = 'user does not exist'
diff --git a/authserv/ratelimit.py b/authserv/ratelimit.py
index dcba833..43529b6 100644
--- a/authserv/ratelimit.py
+++ b/authserv/ratelimit.py
@@ -76,13 +76,14 @@ def ratelimit_http_request(req, value, tag='', count=0, period=0):
 
     rl = RateLimit(count, period)
     key = key_sep.join([req.method, req.path, value])
-    return rl.check(current_app.memcache, key)
+    with current_app.memcache.reserve() as mc:
+        return rl.check(mc, key)
 
 
 class BlackList(object):
     """Block requests once a rate limit is triggered."""
 
-    prefix = 'authserv/b:'
+    prefix = 'authserv/b/'
 
     def __init__(self, count, period, ttl):
         self.rl = RateLimit(count, period)
@@ -99,17 +100,20 @@ class BlackList(object):
             mc.set(key, 'true', time=self.ttl)
 
 
-class AuthBlackList(BlackList):
+class AuthBlackList(object):
+
+    def __init__(self, count, period, ttl, mc):
+        self.blacklist = BlackList(count, period, ttl)
+        self.mc = mc
 
     def is_blacklisted(self, tag, value):
         if not value:
             return False
         key = key_sep.join([tag, value])
-        return not self.check(current_app.memcache, key)
+        return not self.blacklist.check(self.mc, key)
 
     def auth_failure(self, tag, value):
         if not value:
             return
         key = key_sep.join([tag, value])
-        self.incr(current_app.memcache, key)
-
+        self.blacklist.incr(self.mc, key)
diff --git a/authserv/server.py b/authserv/server.py
index 28b27e6..7a18c86 100644
--- a/authserv/server.py
+++ b/authserv/server.py
@@ -7,6 +7,7 @@ import logging
 import logging.handlers
 import optparse
 import os
+import pylibmc
 import signal
 import sys
 from authserv import app_main
@@ -27,14 +28,9 @@ def create_app(app, userdb=None, mc=None):
     app.userdb = userdb
 
     if not mc:
-        try:
-            import pylibmc
-            mc = pylibmc.Client(
-                app.config['MEMCACHE_ADDR'], binary=True)
-        except ImportError:
-            import memcache
-            mc = memcache.Client(
-                app.config['MEMCACHE_ADDR'], debug=0)
+        client = pylibmc.Client(
+            app.config['MEMCACHE_ADDR'], binary=True)
+        mc = pylibmc.ThreadMappedPool(client)
     app.memcache = mc
 
     return app
diff --git a/authserv/test/__init__.py b/authserv/test/__init__.py
index 2326407..ef91c7e 100644
--- a/authserv/test/__init__.py
+++ b/authserv/test/__init__.py
@@ -1,3 +1,4 @@
+import contextlib
 import crypt
 import logging
 import os
@@ -39,6 +40,16 @@ class FakeMemcache(object):
         return value
 
 
+class FakeMemcachePool(object):
+
+    def __init__(self, t=None):
+        self.mc = FakeMemcache(t)
+
+    @contextlib.contextmanager
+    def reserve(self):
+        yield self.mc
+
+
 class FakeUser(model.User):
 
     def __init__(self, username, password=None, asps=None, otp_key=None, shard=None):
diff --git a/authserv/test/test_app_main.py b/authserv/test/test_app_main.py
index f0d2699..bb74625 100644
--- a/authserv/test/test_app_main.py
+++ b/authserv/test/test_app_main.py
@@ -21,7 +21,7 @@ class ServerTest(unittest.TestCase):
             }
         app = server.create_app(app_main.app,
                                 userdb=FakeUserDb(self.users),
-                                mc=FakeMemcache(_time))
+                                mc=FakeMemcachePool(_time))
         app.config.update({
                 'TESTING': True,
                 'DEBUG': True,
diff --git a/authserv/test/test_app_nginx.py b/authserv/test/test_app_nginx.py
index 089f9cf..1021f05 100644
--- a/authserv/test/test_app_nginx.py
+++ b/authserv/test/test_app_nginx.py
@@ -17,7 +17,7 @@ class ServerTest(unittest.TestCase):
             }
         app = server.create_app(app_nginx.app,
                                 userdb=FakeUserDb(self.users),
-                                mc=FakeMemcache(_time))
+                                mc=FakeMemcachePool(_time))
         app.config.update({
                 'TESTING': True,
                 'DEBUG': True,
diff --git a/authserv/test/test_integration.py b/authserv/test/test_integration.py
index 9104858..0320334 100644
--- a/authserv/test/test_integration.py
+++ b/authserv/test/test_integration.py
@@ -63,7 +63,7 @@ class SSLServerTest(unittest.TestCase):
         def _runserver():
             app = server.create_app(app_main.app,
                                     userdb=FakeUserDb(cls.users),
-                                    mc=FakeMemcache(time.time))
+                                    mc=FakeMemcachePool(time.time))
             app.config.update({
                 'TESTING': True,
                 'DEBUG': True,
diff --git a/authserv/test/test_ratelimit.py b/authserv/test/test_ratelimit.py
index 0d74c1f..7853636 100644
--- a/authserv/test/test_ratelimit.py
+++ b/authserv/test/test_ratelimit.py
@@ -8,25 +8,27 @@ class RateLimitTest(unittest.TestCase):
         self.tick = 0
         def _time():
             return self.tick
-        self.mc = FakeMemcache(_time)
+        self.mc = FakeMemcachePool(_time)
         self.rl = RateLimit(100, 10)
 
     def test_ratelimit_pass(self):
         n = 200
         ok = 0
-        for i in xrange(n):
-            self.tick = i
-            if self.rl.check(self.mc, 'key'):
-                ok += 1
+        with self.mc.reserve() as mc:
+            for i in xrange(n):
+                self.tick = i
+                if self.rl.check(mc, 'key'):
+                    ok += 1
         self.assertEquals(n, ok)
 
     def test_ratelimit_fail(self):
         n = 200
         ok = 0
-        for i in xrange(n):
-            self.tick = i / 20
-            if self.rl.check(self.mc, 'key'):
-                ok += 1
+        with self.mc.reserve() as mc:
+            for i in xrange(n):
+                self.tick = i / 20
+                if self.rl.check(mc, 'key'):
+                    ok += 1
         self.assertEquals(100, ok)
 
 
diff --git a/setup.py b/setup.py
index ce13fc5..5671e43 100644
--- a/setup.py
+++ b/setup.py
@@ -4,12 +4,12 @@ from setuptools import setup, find_packages
 
 setup(
     name="authserv",
-    version="0.1.1",
+    version="0.1.2",
     description="Authentication server",
     author="Autistici/Inventati",
     author_email="info@autistici.org",
     url="https://git.autistici.org/ai/authserv",
-    install_requires=["gevent", "python-ldap", "Flask", "python-memcached"],
+    install_requires=["gevent", "python-ldap", "Flask", "pylibmc"],
     setup_requires=[],
     zip_safe=False,
     packages=find_packages(),
-- 
GitLab