Skip to content
Snippets Groups Projects
Commit 134389bf authored by ale's avatar ale
Browse files

implement sharded services (server side)

parent 107ab9ee
No related branches found
No related tags found
No related merge requests found
...@@ -11,6 +11,13 @@ class Error(Exception): ...@@ -11,6 +11,13 @@ class Error(Exception):
pass pass
def _expandvars(s, vars, quotefn):
if not s:
return s
qvars = dict((k, quotefn(str(v))) for k, v in vars.iteritems())
return s % qvars
class UserDb(model.UserDb): class UserDb(model.UserDb):
ldap_attrs = [ ldap_attrs = [
...@@ -33,7 +40,7 @@ class UserDb(model.UserDb): ...@@ -33,7 +40,7 @@ class UserDb(model.UserDb):
yield c yield c
c.unbind_s() c.unbind_s()
def _query_user(self, username, service): def _query_user(self, username, service, shard):
# Allow referencing a service to another, by specifying a # Allow referencing a service to another, by specifying a
# string rather than a dictionary as the value. If you build # string rather than a dictionary as the value. If you build
# infinite loops this way, it's your fault. # infinite loops this way, it's your fault.
...@@ -44,6 +51,11 @@ class UserDb(model.UserDb): ...@@ -44,6 +51,11 @@ class UserDb(model.UserDb):
logging.error('unknown service "%s"', service) logging.error('unknown service "%s"', service)
return None return None
# Arguments used for variable substitution in the LDAP filters.
ldap_vars = {'user': username,
'shard': shard,
'service': service}
with self._conn() as c: with self._conn() as c:
# LDAP queries can be built in two ways: # LDAP queries can be built in two ways:
# #
...@@ -56,12 +68,13 @@ class UserDb(model.UserDb): ...@@ -56,12 +68,13 @@ class UserDb(model.UserDb):
# 'filter' is required. # 'filter' is required.
# #
if 'dn' in ldap_params: if 'dn' in ldap_params:
base = ldap_params['dn'].replace('%s', escape_dn_chars(username)) base = _expandvars(ldap_params['dn'], ldap_vars, escape_dn_chars)
filt = ldap_params.get('filter', '(objectClass=*)').replace('%s', escape_filter_chars(username)) filt = _expandvars(ldap_params.get('filter', '(objectClass=*)'),
ldap_vars, escape_filter_chars)
scope = ldap.SCOPE_BASE scope = ldap.SCOPE_BASE
else: else:
base = ldap_params['base'].replace('%s', escape_dn_chars(username)) base = _expandvars(ldap_params['base'], ldap_vars, escape_dn_chars)
filt = ldap_params['filter'].replace('%s', escape_filter_chars(username)) filt = _expandvars(ldap_params['filter'], ldap_vars, escape_filter_chars)
scope = ldap.SCOPE_SUBTREE scope = ldap.SCOPE_SUBTREE
logging.debug('ldap search: base=%s, scope=%s, filt=%s', base, scope, filt) logging.debug('ldap search: base=%s, scope=%s, filt=%s', base, scope, filt)
result = c.search_s(base, scope, filt, self.ldap_attrs) result = c.search_s(base, scope, filt, self.ldap_attrs)
...@@ -73,9 +86,9 @@ class UserDb(model.UserDb): ...@@ -73,9 +86,9 @@ class UserDb(model.UserDb):
return User(username, result[0][0], result[0][1]) return User(username, result[0][0], result[0][1])
def get_user(self, username, service): def get_user(self, username, service, shard):
try: try:
return self._query_user(username, service) return self._query_user(username, service, shard)
except (Error, ldap.LDAPError), e: except (Error, ldap.LDAPError), e:
logging.error('userdb error: %s', e) logging.error('userdb error: %s', e)
return None return None
......
...@@ -5,7 +5,7 @@ class UserDb(object): ...@@ -5,7 +5,7 @@ class UserDb(object):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def get_user(self, username, service): def get_user(self, username, service, shard):
pass pass
......
...@@ -12,8 +12,8 @@ from flask import Flask, request, abort, make_response ...@@ -12,8 +12,8 @@ from flask import Flask, request, abort, make_response
@blacklist_on_auth_failure(key_from_args(0), count=5, period=600, ttl=43200) @blacklist_on_auth_failure(key_from_args(0), count=5, period=600, ttl=43200)
@blacklist_on_auth_failure(key_from_args(4), count=5, period=600, ttl=43200, @blacklist_on_auth_failure(key_from_args(4), count=5, period=600, ttl=43200,
check_wl=True) check_wl=True)
def _auth(username, service, password, otp_token, source_ip): def _auth(username, service, shard, password, otp_token, source_ip):
user = app.userdb.get_user(username, service) user = app.userdb.get_user(username, service, shard)
if not user: if not user:
return protocol.ERR_AUTHENTICATION_FAILURE return protocol.ERR_AUTHENTICATION_FAILURE
return auth.authenticate(user, service, password, otp_token) return auth.authenticate(user, service, password, otp_token)
...@@ -30,12 +30,16 @@ def do_auth(): ...@@ -30,12 +30,16 @@ def do_auth():
password = request.form.get('password') password = request.form.get('password')
otp_token = request.form.get('otp') otp_token = request.form.get('otp')
source_ip = request.form.get('source_ip') source_ip = request.form.get('source_ip')
try:
shard = int(request.form.get('shard'))
except:
shard = -1
if not service or not username: if not service or not username:
abort(400) abort(400)
try: try:
result = _auth(username, service, password, otp_token, source_ip) result = _auth(username, service, shard, password, otp_token, source_ip)
except Exception, e: except Exception, e:
app.logger.exception('Unexpected exception in authenticate()') app.logger.exception('Unexpected exception in authenticate()')
abort(500) abort(500)
......
...@@ -71,7 +71,7 @@ class FakeUserDb(model.UserDb): ...@@ -71,7 +71,7 @@ class FakeUserDb(model.UserDb):
def __init__(self, users): def __init__(self, users):
self.users = users self.users = users
def get_user(self, username, service): def get_user(self, username, service, shard):
return self.users.get(username) return self.users.get(username)
......
...@@ -10,10 +10,14 @@ class LdapAuthTestBase(LdapTestBase): ...@@ -10,10 +10,14 @@ class LdapAuthTestBase(LdapTestBase):
SERVICE_MAP = { SERVICE_MAP = {
'mail': { 'mail': {
'base': 'ou=People,dc=investici,dc=org,o=Anarchy', 'base': 'ou=People,dc=investici,dc=org,o=Anarchy',
'filter': '(&(status=active)(mail=%s))', 'filter': '(&(status=active)(mail=%(user)s))',
},
'sharded': {
'base': 'ou=People,dc=investici,dc=org,o=Anarchy',
'filter': '(&(status=active)(mail=%(user)s)(host=%(shard)s))',
}, },
'account': { 'account': {
'dn': 'uid=%s,ou=People,dc=investici,dc=org,o=Anarchy', 'dn': 'uid=%(user)s,ou=People,dc=investici,dc=org,o=Anarchy',
}, },
'aliased-service': 'account', 'aliased-service': 'account',
} }
...@@ -34,25 +38,35 @@ class LdapAuthTest(LdapAuthTestBase): ...@@ -34,25 +38,35 @@ class LdapAuthTest(LdapAuthTestBase):
def test_userdb_get_user(self): def test_userdb_get_user(self):
self.assertTrue( self.assertTrue(
self.userdb.get_user('test@investici.org', 'account')) self.userdb.get_user('test@investici.org', 'account', -1))
self.assertTrue(
self.userdb.get_user('test@investici.org', 'account', 'whatever'))
def test_userdb_get_user_sharded(self):
self.assertTrue(
self.userdb.get_user('test@investici.org', 'sharded', 'latitanza'))
self.assertFalse(
self.userdb.get_user('test@investici.org', 'sharded', 'contumacia'))
self.assertFalse(
self.userdb.get_user('test@investici.org', 'sharded', -1))
def test_userdb_unknown_service(self): def test_userdb_unknown_service(self):
self.assertFalse( self.assertFalse(
self.userdb.get_user('test@investici.org', 'unknownservice')) self.userdb.get_user('test@investici.org', 'unknownservice', -1))
def test_userdb_service_alias(self): def test_userdb_service_alias(self):
self.assertTrue( self.assertTrue(
self.userdb.get_user('test@investici.org', 'aliased-service')) self.userdb.get_user('test@investici.org', 'aliased-service', -1))
def test_auth_password_ok(self): def test_auth_password_ok(self):
u = self.userdb.get_user('test@investici.org', 'mail') u = self.userdb.get_user('test@investici.org', 'mail', -1)
self.assertTrue(u) self.assertTrue(u)
self.assertEquals( self.assertEquals(
protocol.OK, protocol.OK,
authenticate(u, 'mail', 'password', None)) authenticate(u, 'mail', 'password', None))
def test_auth_password_fail(self): def test_auth_password_fail(self):
u = self.userdb.get_user('test@investici.org', 'mail') u = self.userdb.get_user('test@investici.org', 'mail', -1)
self.assertTrue(u) self.assertTrue(u)
self.assertEquals( self.assertEquals(
protocol.ERR_AUTHENTICATION_FAILURE, protocol.ERR_AUTHENTICATION_FAILURE,
...@@ -66,21 +80,21 @@ class LdapOtpTest(LdapAuthTestBase): ...@@ -66,21 +80,21 @@ class LdapOtpTest(LdapAuthTestBase):
] ]
def test_auth_password_requires_otp(self): def test_auth_password_requires_otp(self):
u = self.userdb.get_user('test@investici.org', 'account') u = self.userdb.get_user('test@investici.org', 'account', -1)
self.assertTrue(u) self.assertTrue(u)
self.assertEquals( self.assertEquals(
protocol.ERR_OTP_REQUIRED, protocol.ERR_OTP_REQUIRED,
authenticate(u, 'account', 'password', None)) authenticate(u, 'account', 'password', None))
def test_auth_bad_password_requires_otp(self): def test_auth_bad_password_requires_otp(self):
u = self.userdb.get_user('test@investici.org', 'account') u = self.userdb.get_user('test@investici.org', 'account', -1)
self.assertTrue(u) self.assertTrue(u)
self.assertEquals( self.assertEquals(
protocol.ERR_OTP_REQUIRED, protocol.ERR_OTP_REQUIRED,
authenticate(u, 'account', 'wrong password', None)) authenticate(u, 'account', 'wrong password', None))
def test_auth_otp_ok(self): def test_auth_otp_ok(self):
u = self.userdb.get_user('test@investici.org', 'account') u = self.userdb.get_user('test@investici.org', 'account', -1)
self.assertTrue(u) self.assertTrue(u)
secret= '089421' secret= '089421'
token = totp(secret, format='dec6', period=30) token = totp(secret, format='dec6', period=30)
...@@ -89,7 +103,7 @@ class LdapOtpTest(LdapAuthTestBase): ...@@ -89,7 +103,7 @@ class LdapOtpTest(LdapAuthTestBase):
authenticate(u, 'account', 'password', str(token))) authenticate(u, 'account', 'password', str(token)))
def test_auth_otp_ok_bad_password(self): def test_auth_otp_ok_bad_password(self):
u = self.userdb.get_user('test@investici.org', 'account') u = self.userdb.get_user('test@investici.org', 'account', -1)
self.assertTrue(u) self.assertTrue(u)
secret= '089421' secret= '089421'
token = totp(secret, format='dec6', period=30) token = totp(secret, format='dec6', period=30)
...@@ -98,7 +112,7 @@ class LdapOtpTest(LdapAuthTestBase): ...@@ -98,7 +112,7 @@ class LdapOtpTest(LdapAuthTestBase):
authenticate(u, 'account', 'wrong password', str(token))) authenticate(u, 'account', 'wrong password', str(token)))
def test_auth_bad_otp(self): def test_auth_bad_otp(self):
u = self.userdb.get_user('test@investici.org', 'account') u = self.userdb.get_user('test@investici.org', 'account', -1)
self.assertTrue(u) self.assertTrue(u)
self.assertEquals( self.assertEquals(
protocol.ERR_AUTHENTICATION_FAILURE, protocol.ERR_AUTHENTICATION_FAILURE,
...@@ -112,21 +126,21 @@ class LdapASPTest(LdapAuthTestBase): ...@@ -112,21 +126,21 @@ class LdapASPTest(LdapAuthTestBase):
] ]
def test_app_specific_password_ok(self): def test_app_specific_password_ok(self):
u = self.userdb.get_user('test@investici.org', 'mail') u = self.userdb.get_user('test@investici.org', 'mail', -1)
self.assertTrue(u) self.assertTrue(u)
self.assertEquals( self.assertEquals(
protocol.OK, protocol.OK,
authenticate(u, 'mail', 'veryspecificpassword', None)) authenticate(u, 'mail', 'veryspecificpassword', None))
def test_plain_password_fails(self): def test_plain_password_fails(self):
u = self.userdb.get_user('test@investici.org', 'mail') u = self.userdb.get_user('test@investici.org', 'mail', -1)
self.assertTrue(u) self.assertTrue(u)
self.assertEquals( self.assertEquals(
protocol.ERR_AUTHENTICATION_FAILURE, protocol.ERR_AUTHENTICATION_FAILURE,
authenticate(u, 'mail', 'password', None)) authenticate(u, 'mail', 'password', None))
def test_plain_password_and_otp_fails(self): def test_plain_password_and_otp_fails(self):
u = self.userdb.get_user('test@investici.org', 'mail') u = self.userdb.get_user('test@investici.org', 'mail', -1)
self.assertTrue(u) self.assertTrue(u)
self.assertEquals( self.assertEquals(
protocol.ERR_AUTHENTICATION_FAILURE, protocol.ERR_AUTHENTICATION_FAILURE,
......
...@@ -12,18 +12,18 @@ LDAP_SERVICE_MAP = { ...@@ -12,18 +12,18 @@ LDAP_SERVICE_MAP = {
# Mail accounts (dovecot, nginx-mail-mapper). # Mail accounts (dovecot, nginx-mail-mapper).
'mail': { 'mail': {
'base': 'ou=People, dc=investici, dc=org, o=Anarchy', 'base': 'ou=People, dc=investici, dc=org, o=Anarchy',
'filter': '(&(objectClass=virtualMailUser)(status=active)(mail=%s))', 'filter': '(&(objectClass=virtualMailUser)(status=active)(mail=%(user)s))',
}, },
# DAV access (webdav fcgi handler). # DAV access (webdav fcgi handler).
'dav': { 'dav': {
'base': 'ou=People, dc=investici, dc=org, o=Anarchy', 'base': 'ou=People, dc=investici, dc=org, o=Anarchy',
'filter': '(&(objectClass=ftpAccount)(status=active)(host=%s)(ftpname=%%s))' % host, 'filter': '(&(objectClass=ftpAccount)(status=active)(host=%(shard)s)(ftpname=%%(user)s))' % host,
}, },
# Main account (pannello). # Main account (pannello).
'account': { 'account': {
'dn': 'uid=%s, ou=People, dc=investici, dc=org, o=Anarchy', 'dn': 'uid=%(user)s, ou=People, dc=investici, dc=org, o=Anarchy',
}, },
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment