Skip to content
Snippets Groups Projects
Commit 59e04fff authored by ale's avatar ale
Browse files

enforce shard check on authentication

parent a0a93185
No related branches found
No related tags found
No related merge requests found
...@@ -36,12 +36,15 @@ def do_auth(username, service, shard, password, otp_token, source_ip): ...@@ -36,12 +36,15 @@ def do_auth(username, service, shard, password, otp_token, source_ip):
retval = protocol.ERR_AUTHENTICATION_FAILURE retval = protocol.ERR_AUTHENTICATION_FAILURE
errmsg = 'user does not exist' errmsg = 'user does not exist'
shard = None out_shard = None
user = current_app.userdb.get_user(username, service, shard) user = current_app.userdb.get_user(username, service, shard)
if user: if user:
retval, errmsg = auth.authenticate( retval, errmsg = auth.authenticate(
user, service, password, otp_token, source_ip) user, service, password, otp_token, source_ip)
shard = user.get_shard() out_shard = user.get_shard()
if shard and out_shard != shard:
retval = protocol.ERR_AUTHENTICATION_FAILURE
errmsg = 'wrong shard'
if retval != protocol.OK and current_app.config.get('ENABLE_BLACKLIST'): if retval != protocol.OK and current_app.config.get('ENABLE_BLACKLIST'):
if user: if user:
...@@ -50,4 +53,4 @@ def do_auth(username, service, shard, password, otp_token, source_ip): ...@@ -50,4 +53,4 @@ def do_auth(username, service, shard, password, otp_token, source_ip):
and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST'))): and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST'))):
bl.auth_failure('ip', source_ip) bl.auth_failure('ip', source_ip)
return (retval, errmsg, shard) return (retval, errmsg, out_shard)
...@@ -41,11 +41,12 @@ class FakeMemcache(object): ...@@ -41,11 +41,12 @@ class FakeMemcache(object):
class FakeUser(model.User): class FakeUser(model.User):
def __init__(self, username, password=None, asps=None, otp_key=None): def __init__(self, username, password=None, asps=None, otp_key=None, shard=None):
self.username = username self.username = username
self.password = crypt.crypt(password, '$6$abcdef1234567890') self.password = crypt.crypt(password, '$6$abcdef1234567890')
self.asps = asps self.asps = asps
self.otp_key = otp_key self.otp_key = otp_key
self.shard = shard
def otp_enabled(self): def otp_enabled(self):
return self.otp_key is not None return self.otp_key is not None
...@@ -65,6 +66,9 @@ class FakeUser(model.User): ...@@ -65,6 +66,9 @@ class FakeUser(model.User):
def get_name(self): def get_name(self):
return self.username return self.username
def get_shard(self):
return self.shard
class FakeUserDb(model.UserDb): class FakeUserDb(model.UserDb):
...@@ -111,5 +115,3 @@ def free_port(): ...@@ -111,5 +115,3 @@ def free_port():
port = s.getsockname()[1] port = s.getsockname()[1]
s.close() s.close()
return port return port
...@@ -14,7 +14,7 @@ class ServerTest(unittest.TestCase): ...@@ -14,7 +14,7 @@ class ServerTest(unittest.TestCase):
def _time(): def _time():
return self.tick return self.tick
self.users = { self.users = {
'user': FakeUser('user', 'pass'), 'user': FakeUser('user', 'pass', shard='a'),
} }
app = server.create_app(app_main.app, app = server.create_app(app_main.app,
userdb=FakeUserDb(self.users), userdb=FakeUserDb(self.users),
...@@ -44,6 +44,25 @@ class ServerTest(unittest.TestCase): ...@@ -44,6 +44,25 @@ class ServerTest(unittest.TestCase):
self.assertEquals(protocol.ERR_AUTHENTICATION_FAILURE, self.assertEquals(protocol.ERR_AUTHENTICATION_FAILURE,
response.data) response.data)
def test_auth_sharded_ok(self):
response = self.app.post(
URL, data={
'username': 'user',
'password': 'pass',
'service': 'svc',
'shard': 'a'})
self.assertEquals(protocol.OK, response.data)
def test_auth_sharded_fail(self):
response = self.app.post(
URL, data={
'username': 'user',
'password': 'pass',
'service': 'svc',
'shard': 'b'})
self.assertEquals(protocol.ERR_AUTHENTICATION_FAILURE,
response.data)
def test_malformed_requests(self): def test_malformed_requests(self):
bad_data = [ bad_data = [
{'username': 'user'}, {'username': 'user'},
...@@ -186,5 +205,3 @@ class ServerTest(unittest.TestCase): ...@@ -186,5 +205,3 @@ class ServerTest(unittest.TestCase):
'service': 'svc', 'source_ip': '127.0.0.1'}) 'service': 'svc', 'source_ip': '127.0.0.1'})
self.assertEquals(200, response.status_code) self.assertEquals(200, response.status_code)
self.assertEquals(protocol.OK, response.data) self.assertEquals(protocol.OK, response.data)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment