Commit fdcdd645 authored by ale's avatar ale

various fixes to SAML handlers, and more tests

parent 7a5dbdb6
......@@ -207,7 +207,11 @@ def create_app(config_file=None, config={}):
if 'SAML' in app.config:
from sso_server.saml.flask_views import init_app as saml_init_app
saml_init_app(app)
# Automatically add it to the allowed services.
if app.config.get('ALLOWED_SERVICES'):
app.config['ALLOWED_SERVICES'].append(
'%s/saml/' % app.config['SAML']['SSO_LOGIN_SERVER'].replace('.', r'\.'))
app.login_service = login_service.LoginService(app.config)
return app
......
......@@ -57,7 +57,7 @@ class LoginService(object):
"""
SERVICE_PATTERN = re.compile(
r'^(?:[a-z0-9][-a-z0-9]*\.)+[a-z]{2,4}(?::[0-9]{2,5})?(?:/.*)?/$',
r'^(?:(?:[a-z0-9][-a-z0-9]*\.)+[a-z]{2,4}|localhost)(?::[0-9]{2,5})?(?:/.*)?/$',
re.IGNORECASE)
LOGIN_SERVICE = '_login'
......
......@@ -51,9 +51,9 @@ class Processor(object):
processor_path = self._config.get('processor', 'invalid')
self._logger.info('initializing processor',
configured_processor=processor_path,
processor=self.dotted_path)
self._logger.info('initializing processor %s (%s)',
processor_path,
self.dotted_path)
if processor_path != self.dotted_path:
raise exceptions.ImproperlyConfigured(
......@@ -64,7 +64,7 @@ class Processor(object):
"no ACS URL specified in SP configuration: {}".format(
self._config))
self._logger.info('processor configured', config=self._config)
self._logger.info('processor configured: config=%s', self._config)
def _build_assertion(self):
"""
......@@ -112,8 +112,8 @@ class Processor(object):
"""
self._request_xml = base64.b64decode(self._saml_request)
self._logger.debug('SAML request decoded',
decoded_request=self._request_xml)
self._logger.debug('SAML request decoded: %s',
self._request_xml)
def _determine_assertion_id(self):
"""
......@@ -130,7 +130,7 @@ class Processor(object):
if not self._audience:
self._audience = self._request_params.get('PROVIDER_NAME', None)
self._logger.info('determined audience', audience=self._audience)
self._logger.info('determined audience: %s', self._audience)
def _determine_response_id(self):
"""
......
......@@ -4,9 +4,10 @@ import os
import sso
import urllib
from flask import request, session, abort, redirect, make_response, render_template, url_for, g
from flask import request, session, abort, redirect, make_response, render_template, url_for, g, current_app
from . import exceptions
from . import registry
from . import xml_signing
from .app import saml_app
......@@ -26,7 +27,7 @@ def init_app(app):
saml_app.config = saml_config
saml_app.login_server = saml_config['SSO_LOGIN_SERVER']
saml_app.sso_service = saml_app.login_server + 'saml/'
saml_app.sso_service = saml_app.login_server + '/saml/'
url_base = 'https://' + saml_app.sso_service
saml_app.sso_url = url_base + 'login'
saml_app.slo_url = url_base + 'logout'
......@@ -40,15 +41,24 @@ def init_app(app):
[])
class NoCookieError(Exception):
pass
def login_required(fn):
@functools.wraps(fn)
def _wrapper(*args, **kwargs):
# Try to fetch the cookie.
try:
g.sso_ticket = saml_app.sso_verifier.verify(request.cookies.get(sso_cookie_name))
cookie = request.cookies.get(sso_cookie_name)
if not cookie:
raise NoCookieError('no cookie')
current_app.logger.info('retrieved cookie: %s', cookie)
g.sso_ticket = saml_app.sso_verifier.verify(str(cookie))
return fn(*args, **kwargs)
except (TypeError, sso.Error) as e:
redir_url = 'https://%s?%s' % (
except (NoCookieError, TypeError, sso.Error) as e:
current_app.logger.error('auth failed: %s', str(e))
redir_url = 'https://%s/?%s' % (
saml_app.login_server, urllib.urlencode({
's': saml_app.sso_service,
'd': request.url}))
......@@ -62,6 +72,7 @@ def sso_login():
next_url = request.args['d']
resp = redirect(next_url)
resp.set_cookie(sso_cookie_name, tkt_str)
current_app.logger.info('set sso cookie %s to %s', sso_cookie_name, tkt_str)
return resp
......@@ -105,7 +116,7 @@ def login_init(resource, target):
@saml_app.route('/login/process')
@login_required
def login_process():
proc = registry.find_processor(request)
proc = registry.find_processor()
return _generate_response(proc)
......@@ -132,4 +143,4 @@ def _generate_response(processor):
tv = processor.generate_response()
except exceptions.UserNotAuthorized:
return render_template('saml/invalid_user.html')
return render_template('saml/login.html', tv)
return render_template('saml/login.html', **tv)
......@@ -4,6 +4,7 @@ from __future__ import absolute_import
Registers and loads Processor classes from settings.
"""
import base64
import logging
import warnings
import zlib
......@@ -20,9 +21,9 @@ class SSOProcessor(base.Processor):
def _validate_request(self):
super(SSOProcessor, self)._validate_request()
url = self._request_params['ACS_URL']
if '.autistici.org' not in url:
raise exceptions.CannotHandleAssertion('ACS is not a supported URL')
#url = self._request_params['ACS_URL']
#if 'blah' not in url:
# raise exceptions.CannotHandleAssertion('ACS is not a supported URL')
def _decode_request(self):
self._request_xml = zlib.decompress(base64.b64decode(self._saml_request), -15)
......@@ -36,7 +37,7 @@ class SSOProcessor(base.Processor):
def get_processor(name, config):
return SSOProcessor(name, config)
return SSOProcessor(config, name)
def old_get_processor(name, config):
......@@ -74,18 +75,18 @@ def old_get_processor(name, config):
return instance
def find_processor(request):
def find_processor():
"""
Returns the Processor instance that is willing to handle this request.
Returns the Processor instance that is willing to handle the current request.
"""
for name, sp_config in saml_app.config['SAML2IDP_REMOTES'].items():
proc = get_processor(name, sp_config)
try:
if proc.can_handle(request):
if proc.can_handle():
return proc
except exceptions.CannotHandleAssertion as exc:
# Log these, but keep looking.
logger.debug('%s %s' % (proc, exc))
logging.debug('%s %s' % (proc, exc))
raise exceptions.CannotHandleAssertion(
'None of the processors in SAML2IDP_REMOTES could handle this request.')
import os
import re
import shutil
import sso
import tempfile
import unittest
from sso_server import application
def parse_form(response):
v = {}
for match in re.findall(
r'<input.*name="([^"]+)".*value="([^"]+)"', response.data):
v[match[0]] = match[1]
return v
class SSOServerTestBase(unittest.TestCase):
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
public_key, secret_key = sso.generate_keys()
self.public_key_file = os.path.join(self.tmpdir, 'public.key')
with open(self.public_key_file, 'w') as fd:
fd.write(public_key)
self.secret_key_file = os.path.join(self.tmpdir, 'secret.key')
with open(self.secret_key_file, 'w') as fd:
fd.write(secret_key)
self.domain = 'testdomain'
self.app = self._make_app()
def _config(self):
return {}
def _make_app(self, **config):
config_ = self._config()
config_.update(config)
return application.create_app(config=config_)
def tearDown(self):
shutil.rmtree(self.tmpdir)
def _extract_csrf(self, data):
m = re.search(r'<input type="hidden" name="_csrf" value="([^"]+)"', data)
self.assertTrue(m is not None, "Could not extract CSRF\n\nPage contents:\n%s" % data)
return m.group(1)
def _login(self, c, location, query_args):
query_args['d'] = query_args['d'][0].replace('http://', 'https://')
response = c.get(location, query_string=query_args)
self.assertEquals(200, response.status_code, response.data)
values = parse_form(response)
self.assertTrue('_csrf' in values, values)
values['username'] = 'admin'
values['password'] = 'admin'
response = c.post('/', data=values)
self.assertEquals(302, response.status_code, response.data)
return response
import base64
import logging
import os
import shutil
import sso
import tempfile
import unittest
import urllib
import urlparse
import zlib
from datetime import datetime
from sso_server import application
from sso_server.test import SSOServerTestBase
logging.basicConfig()
class SAMLTest(unittest.TestCase):
def parse_args(url):
up = urlparse.urlparse(url)
return up.path, urlparse.parse_qs(up.query)
class SAMLTest(SSOServerTestBase):
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
public_key, secret_key = sso.generate_keys()
self.public_key_file = os.path.join(self.tmpdir, 'public.key')
with open(self.public_key_file, 'w') as fd:
fd.write(public_key)
self.secret_key_file = os.path.join(self.tmpdir, 'secret.key')
with open(self.secret_key_file, 'w') as fd:
fd.write(secret_key)
self.domain = 'testdomain'
self.saml_cert = os.path.join(os.path.dirname(__file__), 'saml.pem')
self.saml_key = os.path.join(os.path.dirname(__file__), 'saml.key')
super(SAMLTest, self).setUp()
self.app = self._make_app()
def _make_app(self, **config):
config_ = {
def _config(self):
saml_cert = os.path.join(os.path.dirname(__file__), 'saml.pem')
saml_key = os.path.join(os.path.dirname(__file__), 'saml.key')
return {
'SSO_SECRET_KEY': self.secret_key_file,
'SSO_PUBLIC_KEY': self.public_key_file,
'SSO_DOMAIN': self.domain,
'SECRET_KEY': 'barbablu',
'ALLOWED_SERVICES': [],
'ALLOWED_SERVICES': ['localhost/saml/'],
'SAML': {
'SSO_LOGIN_SERVER': 'https://localhost:1234/',
'CERTIFICATE_FILE': self.saml_cert,
'PRIVATE_KEY_FILE': self.saml_key,
'SSO_LOGIN_SERVER': 'localhost',
'CERTIFICATE_FILE': saml_cert,
'PRIVATE_KEY_FILE': saml_key,
'SAML2IDP_REMOTES': {
# TODO
'test': {
'acs_url': 'https://saml.example.com/users/auth/saml/callback',
'processor': 'sso_server.saml.registry.SSOProcessor',
},
},
},
}
config_.update(config)
return application.create_app(config=config_)
def tearDown(self):
shutil.rmtree(self.tmpdir)
def test_idp_metadata(self):
# Fetch the IDP XML metadata, to verify that the SAML
......@@ -54,3 +51,45 @@ class SAMLTest(unittest.TestCase):
with self.app.test_client() as c:
response = c.get('/saml/metadata/xml/')
self.assertEquals(200, response.status_code)
def _make_saml_request(self):
saml_request = "<samlp:AuthnRequest AssertionConsumerServiceURL='https://saml.example.com/users/auth/saml/callback' Destination='https://localhost/saml/login' ID='_14443901-284e-4780-ae25-686f6fd781aa' IssueInstant='%(stamp)s' Version='2.0' xmlns:saml='urn:oasis:names:tc:SAML:2.0:assertion' xmlns:samlp='urn:oasis:names:tc:SAML:2.0:protocol'><saml:Issuer>https://saml.example.com</saml:Issuer><samlp:NameIDPolicy AllowCreate='true' Format='urn:oasis:names:tc:SAML:2.0:nameid-format:transient'/></samlp:AuthnRequest>" % {
'stamp': datetime.now().isoformat(),
}
comp = zlib.compressobj(9, zlib.DEFLATED, -15)
comp.compress(saml_request)
return base64.b64encode(comp.flush())
def test_saml_login_empty_request(self):
with self.app.test_client() as c:
response = c.get('/saml/login')
self.assertEquals(400, response.status_code)
def test_saml_login(self):
with self.app.test_client() as c:
def _follow(path, data):
response = c.get(path, query_string=data)
self.assertEquals(302, response.status_code,
'got %d for %s, expecting 302' % (
response.status_code, path))
path, data = parse_args(response.location)
return path, data
saml_request = self._make_saml_request()
path, data = _follow('/saml/login', {'SAMLRequest': saml_request})
# Follow the redirect to /login/process.
path, data = _follow(path, data)
# Submit the login form.
self.assertEquals('/', path)
response = self._login(c, path, data)
# Now request the /sso_login endpoint.
path, data = parse_args(response.location)
self.assertEquals('/saml/sso_login', path)
# Follow to the sso_login endpoint
path, data = _follow(path, data)
self.assertEquals('/saml/login/process', path)
# Back to the SAML app at last
# Finally we're getting some xml thingy or what
response = c.get(path, query_string=data)
self.assertEquals(200, response.status_code)
import base64
import logging
import os
import re
import shutil
import tempfile
import unittest
import urlparse
import urllib
import cookielib
......@@ -13,6 +9,7 @@ import flask
import sso
from sso_server import application
from sso_server.test import SSOServerTestBase
logging.basicConfig()
......@@ -27,33 +24,16 @@ def urldecode(query_string):
return out
class SSOServerTest(unittest.TestCase):
class SSOServerTest(SSOServerTestBase):
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
public_key, secret_key = sso.generate_keys()
self.public_key_file = os.path.join(self.tmpdir, 'public.key')
with open(self.public_key_file, 'w') as fd:
fd.write(public_key)
self.secret_key_file = os.path.join(self.tmpdir, 'secret.key')
with open(self.secret_key_file, 'w') as fd:
fd.write(secret_key)
self.domain = 'testdomain'
self.app = self._make_app()
def _make_app(self, **config):
config_ = {
def _config(self):
return {
'SSO_SECRET_KEY': self.secret_key_file,
'SSO_PUBLIC_KEY': self.public_key_file,
'SSO_DOMAIN': self.domain,
'SECRET_KEY': 'barbablu',
'ALLOWED_SERVICES': [],
}
config_.update(config)
return application.create_app(config=config_)
def tearDown(self):
shutil.rmtree(self.tmpdir)
def get_local_ticket(self, user):
return self.app.login_service.local_generate(user)
......@@ -68,11 +48,6 @@ class SSOServerTest(unittest.TestCase):
True, None, False, None, None, None))
return c
def _extract_csrf(self, data):
m = re.search(r'<input type="hidden" name="_csrf" value="([^"]+)"', data)
self.assertTrue(m is not None, "Could not extract CSRF\n\nPage contents:\n%s" % data)
return m.group(1)
def _get_cookies(self, response):
cookies = {}
for x in response.headers.get_all('Set-Cookie'):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment