diff --git a/src/sso_server/sso_server/auth/__init__.py b/src/sso_server/sso_server/auth/__init__.py index 4f16afb6afdb0c27ca00cce4713928ea44cbdc00..1234004855750ca23812e050d0ad67b8cf32c014 100644 --- a/src/sso_server/sso_server/auth/__init__.py +++ b/src/sso_server/sso_server/auth/__init__.py @@ -33,3 +33,6 @@ class AuthBase(object): def match_groups(self, username, groups): return set() + + def get_user_email(self, username): + return None diff --git a/src/sso_server/sso_server/auth/auth_machdb.py b/src/sso_server/sso_server/auth/auth_machdb.py index 894829cadb7b6b25d7b87d9f6274cb39556bb133..1350765e9c4133dd10dadf29d8a45eab32010990 100644 --- a/src/sso_server/sso_server/auth/auth_machdb.py +++ b/src/sso_server/sso_server/auth/auth_machdb.py @@ -17,11 +17,12 @@ class _CredentialsCache(dict): self._lock = threading.Lock() self._data = {'pwcache': {}, 'otpcache': {}, 'grpcache': {}} - def update(self, pwcache, otpcache, grpcache): + def update(self, pwcache, otpcache, grpcache, mailcache): with self._lock: self._data['pwcache'] = pwcache self._data['otpcache'] = otpcache self._data['grpcache'] = grpcache + self._data['mailcache'] = mailcache def get(self, tag): with self._lock: @@ -44,7 +45,7 @@ class Updater(threading.Thread): time.sleep(600) def update_auth_cache(self): - pwcache, otpcache, grpcache = {}, {}, {} + pwcache, otpcache, grpcache, mailcache = {}, {}, {}, {} for user in mdb.User.find(): if not user.enabled: continue @@ -52,7 +53,9 @@ class Updater(threading.Thread): if user.totp_key: otpcache[user.name] = user.totp_key grpcache[user.name] = set(x.name for x in user.groups) - self.auth_cache.update(pwcache, otpcache, grpcache) + if user.email: + mailcache[user.name] = user.email + self.auth_cache.update(pwcache, otpcache, grpcache, mailcache) class Auth(AuthBase): @@ -87,3 +90,7 @@ class Auth(AuthBase): user_groups.intersection_update(groups) return user_groups + def get_user_email(self, username): + mailcache = self.auth_cache.get('mailcache') + return mailcache.get(username) + diff --git a/src/sso_server/sso_server/auth/auth_test.py b/src/sso_server/sso_server/auth/auth_test.py index fc628c0f00df350026efe4df28c540c16cb9bc43..1e232281bf516045e4100e654ada6f2dfcbc7c90 100644 --- a/src/sso_server/sso_server/auth/auth_test.py +++ b/src/sso_server/sso_server/auth/auth_test.py @@ -43,3 +43,6 @@ class Auth(AuthBase): allowed_groups = set(["group1", "group2"]) allowed_groups.intersection_update(groups) return allowed_groups + + def get_user_email(self, u): + return u + '@example.com' diff --git a/src/sso_server/sso_server/saml/flask_views.py b/src/sso_server/sso_server/saml/flask_views.py index 74d97fb16d3b8ae25fbe07f4fff963371782b11d..9a385018aecec0b9e7f220626f8a440d3f750beb 100644 --- a/src/sso_server/sso_server/saml/flask_views.py +++ b/src/sso_server/sso_server/saml/flask_views.py @@ -55,6 +55,10 @@ def login_required(fn): raise NoCookieError('no cookie') current_app.logger.info('retrieved cookie: %s', cookie) g.sso_ticket = saml_app.sso_verifier.verify(str(cookie)) + # Cheat by looking up the email using the LoginService + # private to the main app. + g.user_email = current_app.login_service.auth.get_user_email( + g.sso_ticket.user()) return fn(*args, **kwargs) except (NoCookieError, TypeError, sso.Error) as e: current_app.logger.error('auth failed: %s', str(e)) diff --git a/src/sso_server/sso_server/saml/registry.py b/src/sso_server/sso_server/saml/registry.py index 2f5d827c18296bc6877344e142b1bdf44705d611..3f8451598c02df0ae947cb846b41b6c6d58f5e55 100644 --- a/src/sso_server/sso_server/saml/registry.py +++ b/src/sso_server/sso_server/saml/registry.py @@ -8,6 +8,7 @@ import logging import warnings import zlib +from flask import g from importlib import import_module from . import base @@ -37,6 +38,7 @@ class SSOProcessor(base.Processor): # Add attributes that gitlab needs (?). self._assertion_params['ATTRIBUTES'] = { 'name': self._subject, + 'email': g.user_email, } self._assertion_xml = xml_render.get_assertion_salesforce_xml(self._assertion_params, signed=True)