From 59e04fff18ae2803dbd11a70d0435f15f34daa97 Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Sun, 9 Nov 2014 17:29:05 +0000
Subject: [PATCH] enforce shard check on authentication

---
 authserv/app_common.py         |  9 ++++++---
 authserv/test/__init__.py      |  8 +++++---
 authserv/test/test_app_main.py | 23 ++++++++++++++++++++---
 3 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/authserv/app_common.py b/authserv/app_common.py
index 2188aec..9fcbfc6 100644
--- a/authserv/app_common.py
+++ b/authserv/app_common.py
@@ -36,12 +36,15 @@ def do_auth(username, service, shard, password, otp_token, source_ip):
 
     retval = protocol.ERR_AUTHENTICATION_FAILURE
     errmsg = 'user does not exist'
-    shard = None
+    out_shard = None
     user = current_app.userdb.get_user(username, service, shard)
     if user:
         retval, errmsg = auth.authenticate(
             user, service, password, otp_token, source_ip)
-        shard = user.get_shard()
+        out_shard = user.get_shard()
+        if shard and out_shard != shard:
+            retval = protocol.ERR_AUTHENTICATION_FAILURE
+            errmsg = 'wrong shard'
 
     if retval != protocol.OK and current_app.config.get('ENABLE_BLACKLIST'):
         if user:
@@ -50,4 +53,4 @@ def do_auth(username, service, shard, password, otp_token, source_ip):
             and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST'))):
             bl.auth_failure('ip', source_ip)
 
-    return (retval, errmsg, shard)
+    return (retval, errmsg, out_shard)
diff --git a/authserv/test/__init__.py b/authserv/test/__init__.py
index 6df5a86..2326407 100644
--- a/authserv/test/__init__.py
+++ b/authserv/test/__init__.py
@@ -41,11 +41,12 @@ class FakeMemcache(object):
 
 class FakeUser(model.User):
 
-    def __init__(self, username, password=None, asps=None, otp_key=None):
+    def __init__(self, username, password=None, asps=None, otp_key=None, shard=None):
         self.username = username
         self.password = crypt.crypt(password, '$6$abcdef1234567890')
         self.asps = asps
         self.otp_key = otp_key
+        self.shard = shard
 
     def otp_enabled(self):
         return self.otp_key is not None
@@ -65,6 +66,9 @@ class FakeUser(model.User):
     def get_name(self):
         return self.username
 
+    def get_shard(self):
+        return self.shard
+
 
 class FakeUserDb(model.UserDb):
 
@@ -111,5 +115,3 @@ def free_port():
     port = s.getsockname()[1]
     s.close()
     return port
-
-
diff --git a/authserv/test/test_app_main.py b/authserv/test/test_app_main.py
index c637bf5..89c3205 100644
--- a/authserv/test/test_app_main.py
+++ b/authserv/test/test_app_main.py
@@ -14,7 +14,7 @@ class ServerTest(unittest.TestCase):
         def _time():
             return self.tick
         self.users = {
-            'user': FakeUser('user', 'pass'),
+            'user': FakeUser('user', 'pass', shard='a'),
             }
         app = server.create_app(app_main.app,
                                 userdb=FakeUserDb(self.users),
@@ -44,6 +44,25 @@ class ServerTest(unittest.TestCase):
         self.assertEquals(protocol.ERR_AUTHENTICATION_FAILURE,
                           response.data)
 
+    def test_auth_sharded_ok(self):
+        response = self.app.post(
+            URL, data={
+                'username': 'user',
+                'password': 'pass',
+                'service': 'svc',
+                'shard': 'a'})
+        self.assertEquals(protocol.OK, response.data)
+
+    def test_auth_sharded_fail(self):
+        response = self.app.post(
+            URL, data={
+                'username': 'user',
+                'password': 'pass',
+                'service': 'svc',
+                'shard': 'b'})
+        self.assertEquals(protocol.ERR_AUTHENTICATION_FAILURE,
+                          response.data)
+
     def test_malformed_requests(self):
         bad_data = [
             {'username': 'user'},
@@ -186,5 +205,3 @@ class ServerTest(unittest.TestCase):
             'service': 'svc', 'source_ip': '127.0.0.1'})
         self.assertEquals(200, response.status_code)
         self.assertEquals(protocol.OK, response.data)
-
-        
-- 
GitLab