From 59e04fff18ae2803dbd11a70d0435f15f34daa97 Mon Sep 17 00:00:00 2001 From: ale <ale@incal.net> Date: Sun, 9 Nov 2014 17:29:05 +0000 Subject: [PATCH] enforce shard check on authentication --- authserv/app_common.py | 9 ++++++--- authserv/test/__init__.py | 8 +++++--- authserv/test/test_app_main.py | 23 ++++++++++++++++++++--- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/authserv/app_common.py b/authserv/app_common.py index 2188aec..9fcbfc6 100644 --- a/authserv/app_common.py +++ b/authserv/app_common.py @@ -36,12 +36,15 @@ def do_auth(username, service, shard, password, otp_token, source_ip): retval = protocol.ERR_AUTHENTICATION_FAILURE errmsg = 'user does not exist' - shard = None + out_shard = None user = current_app.userdb.get_user(username, service, shard) if user: retval, errmsg = auth.authenticate( user, service, password, otp_token, source_ip) - shard = user.get_shard() + out_shard = user.get_shard() + if shard and out_shard != shard: + retval = protocol.ERR_AUTHENTICATION_FAILURE + errmsg = 'wrong shard' if retval != protocol.OK and current_app.config.get('ENABLE_BLACKLIST'): if user: @@ -50,4 +53,4 @@ def do_auth(username, service, shard, password, otp_token, source_ip): and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST'))): bl.auth_failure('ip', source_ip) - return (retval, errmsg, shard) + return (retval, errmsg, out_shard) diff --git a/authserv/test/__init__.py b/authserv/test/__init__.py index 6df5a86..2326407 100644 --- a/authserv/test/__init__.py +++ b/authserv/test/__init__.py @@ -41,11 +41,12 @@ class FakeMemcache(object): class FakeUser(model.User): - def __init__(self, username, password=None, asps=None, otp_key=None): + def __init__(self, username, password=None, asps=None, otp_key=None, shard=None): self.username = username self.password = crypt.crypt(password, '$6$abcdef1234567890') self.asps = asps self.otp_key = otp_key + self.shard = shard def otp_enabled(self): return self.otp_key is not None @@ -65,6 +66,9 @@ class FakeUser(model.User): def get_name(self): return self.username + def get_shard(self): + return self.shard + class FakeUserDb(model.UserDb): @@ -111,5 +115,3 @@ def free_port(): port = s.getsockname()[1] s.close() return port - - diff --git a/authserv/test/test_app_main.py b/authserv/test/test_app_main.py index c637bf5..89c3205 100644 --- a/authserv/test/test_app_main.py +++ b/authserv/test/test_app_main.py @@ -14,7 +14,7 @@ class ServerTest(unittest.TestCase): def _time(): return self.tick self.users = { - 'user': FakeUser('user', 'pass'), + 'user': FakeUser('user', 'pass', shard='a'), } app = server.create_app(app_main.app, userdb=FakeUserDb(self.users), @@ -44,6 +44,25 @@ class ServerTest(unittest.TestCase): self.assertEquals(protocol.ERR_AUTHENTICATION_FAILURE, response.data) + def test_auth_sharded_ok(self): + response = self.app.post( + URL, data={ + 'username': 'user', + 'password': 'pass', + 'service': 'svc', + 'shard': 'a'}) + self.assertEquals(protocol.OK, response.data) + + def test_auth_sharded_fail(self): + response = self.app.post( + URL, data={ + 'username': 'user', + 'password': 'pass', + 'service': 'svc', + 'shard': 'b'}) + self.assertEquals(protocol.ERR_AUTHENTICATION_FAILURE, + response.data) + def test_malformed_requests(self): bad_data = [ {'username': 'user'}, @@ -186,5 +205,3 @@ class ServerTest(unittest.TestCase): 'service': 'svc', 'source_ip': '127.0.0.1'}) self.assertEquals(200, response.status_code) self.assertEquals(protocol.OK, response.data) - - -- GitLab