Skip to content
Snippets Groups Projects
Commit dd2e9dd9 authored by ale's avatar ale
Browse files

add a cache for validated sso tickets

parent d10bfc4f
No related branches found
No related tags found
No related merge requests found
// Copyright (c) 2010-2011, Tim Day <timday@timday.com>
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
#ifndef __mod_sso_lru_cache_h
#define __mod_sso_lru_cache_h 1
#include <cassert>
#include <list>
// Class providing fixed-size (by number of records)
// LRU-replacement cache of a function with signature
// V f(K).
// MAP should be one of std::map or std::unordered_map.
// Variadic template args used to deal with the
// different type argument signatures of those
// containers; the default comparator/hash/allocator
// will be used.
template <
typename K,
typename V,
template<typename...> class MAP
> class lru_cache_using_std
{
public:
typedef K key_type;
typedef V value_type;
// Key access history, most recent at back
typedef std::list<key_type> key_tracker_type;
// Key to value and key history iterator
typedef MAP<
key_type,
std::pair<
value_type,
typename key_tracker_type::iterator
>
> key_to_value_type;
// Constructor specifies the cached function and
// the maximum number of records to be stored
lru_cache_using_std(
value_type (*f)(const key_type&),
size_t c
)
:_fn(f)
,_capacity(c)
{
assert(_capacity!=0);
}
// Obtain value of the cached function for k
value_type operator()(const key_type& k) {
// Attempt to find existing record
const typename key_to_value_type::iterator it
=_key_to_value.find(k);
if (it==_key_to_value.end()) {
// We don't have it:
// Evaluate function and create new record
const value_type v=_fn(k);
insert(k,v);
// Return the freshly computed value
return v;
} else {
// We do have it:
// Update access record by moving
// accessed key to back of list
_key_tracker.splice(
_key_tracker.end(),
_key_tracker,
(*it).second.second
);
// Return the retrieved value
return (*it).second.first;
}
}
// Obtain the cached keys, most recently used element
// at head, least recently used at tail.
// This method is provided purely to support testing.
template <typename IT> void get_keys(IT dst) const {
typename key_tracker_type::const_reverse_iterator src
=_key_tracker.rbegin();
while (src!=_key_tracker.rend()) {
*dst++ = *src++;
}
}
private:
// Record a fresh key-value pair in the cache
void insert(const key_type& k,const value_type& v) {
// Method is only called on cache misses
assert(_key_to_value.find(k)==_key_to_value.end());
// Make space if necessary
if (_key_to_value.size()==_capacity)
evict();
// Record k as most-recently-used key
typename key_tracker_type::iterator it
=_key_tracker.insert(_key_tracker.end(),k);
// Create the key-value entry,
// linked to the usage record.
_key_to_value.insert(
std::make_pair(
k,
std::make_pair(v,it)
)
);
// No need to check return,
// given previous assert.
}
// Purge the least-recently-used element in the cache
void evict() {
// Assert method is never called when cache is empty
assert(!_key_tracker.empty());
// Identify least recently used key
const typename key_to_value_type::iterator it
=_key_to_value.find(_key_tracker.front());
assert(it!=_key_to_value.end());
// Erase both elements to completely purge record
_key_to_value.erase(it);
_key_tracker.pop_front();
}
// The function to be cached
value_type (*_fn)(const key_type&);
// Maximum number of key-value pairs to be retained
const size_t _capacity;
// Key access history
key_tracker_type _key_tracker;
// Key-to-value lookup
key_to_value_type _key_to_value;
};
#endif
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <memory>
#include "httpd.h" #include "httpd.h"
#include "http_config.h" #include "http_config.h"
...@@ -35,6 +36,7 @@ ...@@ -35,6 +36,7 @@
#include "ap_config.h" #include "ap_config.h"
#include "apr_strings.h" #include "apr_strings.h"
#include "lru_cache.h"
#include "mod_sso.h" #include "mod_sso.h"
extern "C" module AP_MODULE_DECLARE_DATA sso_module; extern "C" module AP_MODULE_DECLARE_DATA sso_module;
...@@ -42,6 +44,7 @@ extern "C" module AP_MODULE_DECLARE_DATA sso_module; ...@@ -42,6 +44,7 @@ extern "C" module AP_MODULE_DECLARE_DATA sso_module;
using std::map; using std::map;
using std::string; using std::string;
using std::ostringstream; using std::ostringstream;
using std::shared_ptr;
typedef struct { typedef struct {
const char *login_server; const char *login_server;
...@@ -480,6 +483,62 @@ static string encode_groups(const sso::groups_t& groups) ...@@ -480,6 +483,62 @@ static string encode_groups(const sso::groups_t& groups)
return obuf.str(); return obuf.str();
} }
struct verify_context {
string public_key;
string service;
string domain;
sso::groups_t req_groups;
string sso_cookie;
bool operator<(const verify_context& b) const {
if (domain < b.domain)
return true;
if (service < b.service)
return true;
if (req_groups != b.req_groups)
return true;
return sso_cookie < b.sso_cookie;
}
};
class verify_response {
public:
~verify_response() {
delete t;
}
sso::Ticket *t;
string error;
};
static shared_ptr<verify_response> verify_ticket(const verify_context& ctx) {
verify_response *resp = new verify_response;
try {
sso::Verifier verifier(ctx.public_key, ctx.service, ctx.domain, ctx.req_groups);
resp->t = verifier.verify(ctx.sso_cookie);
} catch (sso::sso_error& e) {
resp->t = NULL;
resp->error = e.what();
}
return shared_ptr<verify_response>(resp);
}
typedef lru_cache_using_std<verify_context, shared_ptr<verify_response>, std::map
> sso_cache_t;
static sso_cache_t sso_cache(verify_ticket, 128);
#if APR_HAS_THREADS
static apr_thread_mutex_t *sso_cache_lock;
#endif
static void mod_sso_child_init(apr_pool_t *p, server_rec *s) {
#if APR_HAS_THREADS
apr_thread_mutex_create(&(sso_cache_lock), APR_THREAD_MUTEX_DEFAULT, p);
#endif
}
/** /**
* Apache authentication handler for mod_sso. * Apache authentication handler for mod_sso.
* *
...@@ -529,11 +588,26 @@ static int mod_sso_authenticate_user(request_rec *r) ...@@ -529,11 +588,26 @@ static int mod_sso_authenticate_user(request_rec *r)
string sso_cookie = get_cookie(r, sso_cookie_name); string sso_cookie = get_cookie(r, sso_cookie_name);
if (!sso_cookie.empty()) { if (!sso_cookie.empty()) {
string pkey(s_cfg->public_key, s_cfg->public_key_len); string pkey(s_cfg->public_key, s_cfg->public_key_len);
sso::Verifier verifier(pkey, s_cfg->service,
s_cfg->domain, req_groups);
try { verify_context vctx;
sso::Ticket *t = verifier.verify(sso_cookie); vctx.public_key = pkey;
vctx.service = s_cfg->service;
vctx.domain = s_cfg->domain;
vctx.req_groups = req_groups;
vctx.sso_cookie = sso_cookie;
#if APR_HAS_THREADS
apr_thread_mutex_lock(sso_cache_lock);
#endif
shared_ptr<verify_response> vr = sso_cache(vctx);
#if APR_HAS_THREADS
apr_thread_mutex_unlock(sso_cache_lock);
#endif
if (vr->error.empty()) {
sso::Ticket *t = vr->t;
// Check user authorization lists // Check user authorization lists
if (allow_any_user if (allow_any_user
...@@ -545,17 +619,15 @@ static int mod_sso_authenticate_user(request_rec *r) ...@@ -545,17 +619,15 @@ static int mod_sso_authenticate_user(request_rec *r)
r->user = apr_pstrdup(r->pool, t->user().c_str()); r->user = apr_pstrdup(r->pool, t->user().c_str());
ap_log_error(APLOG_MARK, APLOG_DEBUG, 0, r->server, ap_log_error(APLOG_MARK, APLOG_DEBUG, 0, r->server,
"sso: authorized user '%s'", r->user); "sso: authorized user '%s'", r->user);
delete t;
return OK; return OK;
} else { } else {
ap_log_error(APLOG_MARK, APLOG_WARNING, 0, r->server, ap_log_error(APLOG_MARK, APLOG_WARNING, 0, r->server,
"sso: unauthorized user '%s'", t->user().c_str()); "sso: unauthorized user '%s'", t->user().c_str());
delete t;
return HTTP_UNAUTHORIZED; return HTTP_UNAUTHORIZED;
} }
} catch (sso::sso_error& e) { } else {
ap_log_error(APLOG_MARK, APLOG_WARNING, 0, r->server, ap_log_error(APLOG_MARK, APLOG_WARNING, 0, r->server,
"sso: validation error: %s", e.what()); "sso: validation error: %s", vr->error.c_str());
} }
} }
...@@ -610,6 +682,7 @@ static int mod_sso_auth_checker(request_rec *r) ...@@ -610,6 +682,7 @@ static int mod_sso_auth_checker(request_rec *r)
static void mod_sso_register_hooks (apr_pool_t *p) static void mod_sso_register_hooks (apr_pool_t *p)
{ {
static const char * const mssoPost[] = {"mod_sso.c", NULL}; static const char * const mssoPost[] = {"mod_sso.c", NULL};
ap_hook_child_init(mod_sso_child_init, NULL, NULL, APR_HOOK_FIRST);
ap_hook_handler(mod_sso_method_handler, NULL, NULL, APR_HOOK_FIRST); ap_hook_handler(mod_sso_method_handler, NULL, NULL, APR_HOOK_FIRST);
ap_hook_auth_checker(mod_sso_auth_checker, NULL, mssoPost, APR_HOOK_MIDDLE); ap_hook_auth_checker(mod_sso_auth_checker, NULL, mssoPost, APR_HOOK_MIDDLE);
ap_hook_check_user_id(mod_sso_authenticate_user, NULL, NULL, APR_HOOK_MIDDLE); ap_hook_check_user_id(mod_sso_authenticate_user, NULL, NULL, APR_HOOK_MIDDLE);
......
...@@ -98,7 +98,19 @@ class HttpdIntegrationTest(unittest.TestCase): ...@@ -98,7 +98,19 @@ class HttpdIntegrationTest(unittest.TestCase):
print 'ticket:', signedt print 'ticket:', signedt
return signedt return signedt
def testManyRequests(self): def testTicketCacheEviction(self):
# Requests for more users than the cache size, forcing eviction.
n = 200
errors = 0
for i in xrange(n):
cookie = 'SSO_test=%s' % self._ticket('user%d' % i)
status, body, location = _query("/index.html", cookie=cookie)
if status != 200:
errors += 1
self.assertEquals(0, errors)
def testTicketCache(self):
# Lots of requests for the same user.
n = 100 n = 100
errors = 0 errors = 0
for i in xrange(n): for i in xrange(n):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment