From ba83b3f8bceda3556a2fad4ee6f54da27b4c51a0 Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Wed, 18 Jun 2014 22:07:43 +0100
Subject: [PATCH] enforce type correctness for memcache

---
 authserv/ratelimit.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/authserv/ratelimit.py b/authserv/ratelimit.py
index 20f87de..e1561ba 100644
--- a/authserv/ratelimit.py
+++ b/authserv/ratelimit.py
@@ -18,6 +18,12 @@ _whitelist_rx = [
     for x in WHITELIST]
 
 
+def _tostr(s):
+    if isinstance(s, unicode):
+        return s.encode('utf-8')
+    return s
+
+
 def whitelisted(ip):
     for rx in _whitelist_rx:
         if rx.search(ip):
@@ -79,7 +85,7 @@ class RateLimit(object):
         self.period = period
 
     def check(self, mc, key):
-        key = self.prefix + key
+        key = _tostr(self.prefix + key)
         try:
             result = mc.incr(key)
         except:
@@ -109,7 +115,7 @@ def ratelimit_http_request(key_fn, count=0, period=0):
     def decoratorfn(fn):
         @functools.wraps(fn)
         def _ratelimit(*args, **kwargs):
-            key = key_fn(fn.__name__, args)
+            key = _tostr(key_fn(fn.__name__, args))
             if key and not rl.check(app.memcache, key):
                 app.logger.debug('ratelimited: %s', key)
                 abort(503)
@@ -128,12 +134,12 @@ class BlackList(object):
         self.ttl = ttl
 
     def check(self, mc, key):
-        key = self.prefix + key
+        key = _tostr(self.prefix + key)
         result = mc.get(key)
         return result is None
 
     def incr(self, mc, key):
-        key = self.prefix + key
+        key = _tostr(self.prefix + key)
         if not self.rl.check(mc, key):
             mc.set(key, 'true', time=self.ttl)
 
@@ -155,7 +161,7 @@ def blacklist_on_auth_failure(key_fn, count=0, period=0, ttl=0, check_wl=False):
     def decoratorfn(fn):
         @functools.wraps(fn)
         def _blacklist(*args, **kwargs):
-            key = key_fn(fn.__name__, args)
+            key = _tostr(key_fn(fn.__name__, args))
             if not key:
                 return fn(*args, **kwargs)
 
-- 
GitLab