Skip to content
Snippets Groups Projects
Commit 59e04fff authored by ale's avatar ale
Browse files

enforce shard check on authentication

parent a0a93185
No related branches found
No related tags found
No related merge requests found
......@@ -36,12 +36,15 @@ def do_auth(username, service, shard, password, otp_token, source_ip):
retval = protocol.ERR_AUTHENTICATION_FAILURE
errmsg = 'user does not exist'
shard = None
out_shard = None
user = current_app.userdb.get_user(username, service, shard)
if user:
retval, errmsg = auth.authenticate(
user, service, password, otp_token, source_ip)
shard = user.get_shard()
out_shard = user.get_shard()
if shard and out_shard != shard:
retval = protocol.ERR_AUTHENTICATION_FAILURE
errmsg = 'wrong shard'
if retval != protocol.OK and current_app.config.get('ENABLE_BLACKLIST'):
if user:
......@@ -50,4 +53,4 @@ def do_auth(username, service, shard, password, otp_token, source_ip):
and not whitelisted(source_ip, current_app.config.get('SOURCE_IP_WHITELIST'))):
bl.auth_failure('ip', source_ip)
return (retval, errmsg, shard)
return (retval, errmsg, out_shard)
......@@ -41,11 +41,12 @@ class FakeMemcache(object):
class FakeUser(model.User):
def __init__(self, username, password=None, asps=None, otp_key=None):
def __init__(self, username, password=None, asps=None, otp_key=None, shard=None):
self.username = username
self.password = crypt.crypt(password, '$6$abcdef1234567890')
self.asps = asps
self.otp_key = otp_key
self.shard = shard
def otp_enabled(self):
return self.otp_key is not None
......@@ -65,6 +66,9 @@ class FakeUser(model.User):
def get_name(self):
return self.username
def get_shard(self):
return self.shard
class FakeUserDb(model.UserDb):
......@@ -111,5 +115,3 @@ def free_port():
port = s.getsockname()[1]
s.close()
return port
......@@ -14,7 +14,7 @@ class ServerTest(unittest.TestCase):
def _time():
return self.tick
self.users = {
'user': FakeUser('user', 'pass'),
'user': FakeUser('user', 'pass', shard='a'),
}
app = server.create_app(app_main.app,
userdb=FakeUserDb(self.users),
......@@ -44,6 +44,25 @@ class ServerTest(unittest.TestCase):
self.assertEquals(protocol.ERR_AUTHENTICATION_FAILURE,
response.data)
def test_auth_sharded_ok(self):
response = self.app.post(
URL, data={
'username': 'user',
'password': 'pass',
'service': 'svc',
'shard': 'a'})
self.assertEquals(protocol.OK, response.data)
def test_auth_sharded_fail(self):
response = self.app.post(
URL, data={
'username': 'user',
'password': 'pass',
'service': 'svc',
'shard': 'b'})
self.assertEquals(protocol.ERR_AUTHENTICATION_FAILURE,
response.data)
def test_malformed_requests(self):
bad_data = [
{'username': 'user'},
......@@ -186,5 +205,3 @@ class ServerTest(unittest.TestCase):
'service': 'svc', 'source_ip': '127.0.0.1'})
self.assertEquals(200, response.status_code)
self.assertEquals(protocol.OK, response.data)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment