diff --git a/authserv/app_common.py b/authserv/app_common.py index 2188aec51cea0e1e8b30c882ed188620e07b5456..9fcbfc6f20114a3d5188faa515c77651d06c18eb 100644 --- a/authserv/app_common.py +++ b/authserv/app_common.py @@ -36,12 +36,15 @@ def do_auth(username, service, shard, password, otp_token, source_ip): retval = protocol.ERR_AUTHENTICATION_FAILURE errmsg = 'user does not exist' - shard = None + out_shard = None user = current_app.userdb.get_user(username, service, shard) if user: retval, errmsg = auth.authenticate( user, service, password, otp_token, source_ip) - shard = user.get_shard() + out_shard = user.get_shard() + if shard and out_shard != shard: + retval = protocol.ERR_AUTHENTICATION_FAILURE + errmsg = 'wrong shard' if retval != protocol.OK and current_app.config.get('ENABLE_BLACKLIST'): if user: @@ -50,4 +53,4 @@ def do_auth(username, service, shard, password, otp_token, source_ip): and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST'))): bl.auth_failure('ip', source_ip) - return (retval, errmsg, shard) + return (retval, errmsg, out_shard) diff --git a/authserv/test/__init__.py b/authserv/test/__init__.py index 6df5a86db3533d6f1ea9c76932ebb589a32f8668..23264070a19df12f2239d92ce0db85bc30aca830 100644 --- a/authserv/test/__init__.py +++ b/authserv/test/__init__.py @@ -41,11 +41,12 @@ class FakeMemcache(object): class FakeUser(model.User): - def __init__(self, username, password=None, asps=None, otp_key=None): + def __init__(self, username, password=None, asps=None, otp_key=None, shard=None): self.username = username self.password = crypt.crypt(password, '$6$abcdef1234567890') self.asps = asps self.otp_key = otp_key + self.shard = shard def otp_enabled(self): return self.otp_key is not None @@ -65,6 +66,9 @@ class FakeUser(model.User): def get_name(self): return self.username + def get_shard(self): + return self.shard + class FakeUserDb(model.UserDb): @@ -111,5 +115,3 @@ def free_port(): port = s.getsockname()[1] s.close() return port - - diff --git a/authserv/test/test_app_main.py b/authserv/test/test_app_main.py index c637bf5558daca1be0f2f0966f8d8c3e2b1cf2f9..89c32057092d941782253af71844779e99cfb473 100644 --- a/authserv/test/test_app_main.py +++ b/authserv/test/test_app_main.py @@ -14,7 +14,7 @@ class ServerTest(unittest.TestCase): def _time(): return self.tick self.users = { - 'user': FakeUser('user', 'pass'), + 'user': FakeUser('user', 'pass', shard='a'), } app = server.create_app(app_main.app, userdb=FakeUserDb(self.users), @@ -44,6 +44,25 @@ class ServerTest(unittest.TestCase): self.assertEquals(protocol.ERR_AUTHENTICATION_FAILURE, response.data) + def test_auth_sharded_ok(self): + response = self.app.post( + URL, data={ + 'username': 'user', + 'password': 'pass', + 'service': 'svc', + 'shard': 'a'}) + self.assertEquals(protocol.OK, response.data) + + def test_auth_sharded_fail(self): + response = self.app.post( + URL, data={ + 'username': 'user', + 'password': 'pass', + 'service': 'svc', + 'shard': 'b'}) + self.assertEquals(protocol.ERR_AUTHENTICATION_FAILURE, + response.data) + def test_malformed_requests(self): bad_data = [ {'username': 'user'}, @@ -186,5 +205,3 @@ class ServerTest(unittest.TestCase): 'service': 'svc', 'source_ip': '127.0.0.1'}) self.assertEquals(200, response.status_code) self.assertEquals(protocol.OK, response.data) - -