From dd2e9dd9c96ee62ab20f122453379799786ebdd1 Mon Sep 17 00:00:00 2001 From: ale <ale@incal.net> Date: Mon, 12 Aug 2013 17:47:50 +0000 Subject: [PATCH] add a cache for validated sso tickets --- src/mod_sso/lru_cache.h | 167 +++++++++++++++++++++ src/mod_sso/mod_sso.cc | 91 +++++++++-- src/mod_sso/test/httpd_integration_test.py | 14 +- 3 files changed, 262 insertions(+), 10 deletions(-) create mode 100644 src/mod_sso/lru_cache.h diff --git a/src/mod_sso/lru_cache.h b/src/mod_sso/lru_cache.h new file mode 100644 index 0000000..243dcf3 --- /dev/null +++ b/src/mod_sso/lru_cache.h @@ -0,0 +1,167 @@ +// Copyright (c) 2010-2011, Tim Day <timday@timday.com> +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +#ifndef __mod_sso_lru_cache_h +#define __mod_sso_lru_cache_h 1 + +#include <cassert> +#include <list> + +// Class providing fixed-size (by number of records) +// LRU-replacement cache of a function with signature +// V f(K). +// MAP should be one of std::map or std::unordered_map. +// Variadic template args used to deal with the +// different type argument signatures of those +// containers; the default comparator/hash/allocator +// will be used. +template < + typename K, + typename V, + template<typename...> class MAP + > class lru_cache_using_std +{ +public: + + typedef K key_type; + typedef V value_type; + + // Key access history, most recent at back + typedef std::list<key_type> key_tracker_type; + + // Key to value and key history iterator + typedef MAP< + key_type, + std::pair< + value_type, + typename key_tracker_type::iterator + > + > key_to_value_type; + + // Constructor specifies the cached function and + // the maximum number of records to be stored + lru_cache_using_std( + value_type (*f)(const key_type&), + size_t c + ) + :_fn(f) + ,_capacity(c) + { + assert(_capacity!=0); + } + + // Obtain value of the cached function for k + value_type operator()(const key_type& k) { + + // Attempt to find existing record + const typename key_to_value_type::iterator it + =_key_to_value.find(k); + + if (it==_key_to_value.end()) { + + // We don't have it: + + // Evaluate function and create new record + const value_type v=_fn(k); + insert(k,v); + + // Return the freshly computed value + return v; + + } else { + + // We do have it: + + // Update access record by moving + // accessed key to back of list + _key_tracker.splice( + _key_tracker.end(), + _key_tracker, + (*it).second.second + ); + + // Return the retrieved value + return (*it).second.first; + } + } + + // Obtain the cached keys, most recently used element + // at head, least recently used at tail. + // This method is provided purely to support testing. + template <typename IT> void get_keys(IT dst) const { + typename key_tracker_type::const_reverse_iterator src + =_key_tracker.rbegin(); + while (src!=_key_tracker.rend()) { + *dst++ = *src++; + } + } + +private: + + // Record a fresh key-value pair in the cache + void insert(const key_type& k,const value_type& v) { + + // Method is only called on cache misses + assert(_key_to_value.find(k)==_key_to_value.end()); + + // Make space if necessary + if (_key_to_value.size()==_capacity) + evict(); + + // Record k as most-recently-used key + typename key_tracker_type::iterator it + =_key_tracker.insert(_key_tracker.end(),k); + + // Create the key-value entry, + // linked to the usage record. + _key_to_value.insert( + std::make_pair( + k, + std::make_pair(v,it) + ) + ); + // No need to check return, + // given previous assert. + } + + // Purge the least-recently-used element in the cache + void evict() { + + // Assert method is never called when cache is empty + assert(!_key_tracker.empty()); + + // Identify least recently used key + const typename key_to_value_type::iterator it + =_key_to_value.find(_key_tracker.front()); + assert(it!=_key_to_value.end()); + + // Erase both elements to completely purge record + _key_to_value.erase(it); + _key_tracker.pop_front(); + } + + // The function to be cached + value_type (*_fn)(const key_type&); + + // Maximum number of key-value pairs to be retained + const size_t _capacity; + + // Key access history + key_tracker_type _key_tracker; + + // Key-to-value lookup + key_to_value_type _key_to_value; +}; + +#endif diff --git a/src/mod_sso/mod_sso.cc b/src/mod_sso/mod_sso.cc index 0b53440..c78c329 100644 --- a/src/mod_sso/mod_sso.cc +++ b/src/mod_sso/mod_sso.cc @@ -24,6 +24,7 @@ #include <iostream> #include <sstream> +#include <memory> #include "httpd.h" #include "http_config.h" @@ -35,6 +36,7 @@ #include "ap_config.h" #include "apr_strings.h" +#include "lru_cache.h" #include "mod_sso.h" extern "C" module AP_MODULE_DECLARE_DATA sso_module; @@ -42,6 +44,7 @@ extern "C" module AP_MODULE_DECLARE_DATA sso_module; using std::map; using std::string; using std::ostringstream; +using std::shared_ptr; typedef struct { const char *login_server; @@ -480,6 +483,62 @@ static string encode_groups(const sso::groups_t& groups) return obuf.str(); } +struct verify_context { + string public_key; + string service; + string domain; + sso::groups_t req_groups; + string sso_cookie; + + bool operator<(const verify_context& b) const { + if (domain < b.domain) + return true; + if (service < b.service) + return true; + if (req_groups != b.req_groups) + return true; + return sso_cookie < b.sso_cookie; + } +}; + +class verify_response { +public: + ~verify_response() { + delete t; + } + + sso::Ticket *t; + string error; +}; + +static shared_ptr<verify_response> verify_ticket(const verify_context& ctx) { + verify_response *resp = new verify_response; + + try { + sso::Verifier verifier(ctx.public_key, ctx.service, ctx.domain, ctx.req_groups); + resp->t = verifier.verify(ctx.sso_cookie); + } catch (sso::sso_error& e) { + resp->t = NULL; + resp->error = e.what(); + } + return shared_ptr<verify_response>(resp); +} + +typedef lru_cache_using_std<verify_context, shared_ptr<verify_response>, std::map + > sso_cache_t; + +static sso_cache_t sso_cache(verify_ticket, 128); + +#if APR_HAS_THREADS +static apr_thread_mutex_t *sso_cache_lock; +#endif + +static void mod_sso_child_init(apr_pool_t *p, server_rec *s) { +#if APR_HAS_THREADS + apr_thread_mutex_create(&(sso_cache_lock), APR_THREAD_MUTEX_DEFAULT, p); +#endif +} + /** * Apache authentication handler for mod_sso. * @@ -529,12 +588,27 @@ static int mod_sso_authenticate_user(request_rec *r) string sso_cookie = get_cookie(r, sso_cookie_name); if (!sso_cookie.empty()) { string pkey(s_cfg->public_key, s_cfg->public_key_len); - sso::Verifier verifier(pkey, s_cfg->service, - s_cfg->domain, req_groups); - try { - sso::Ticket *t = verifier.verify(sso_cookie); - + verify_context vctx; + vctx.public_key = pkey; + vctx.service = s_cfg->service; + vctx.domain = s_cfg->domain; + vctx.req_groups = req_groups; + vctx.sso_cookie = sso_cookie; + +#if APR_HAS_THREADS + apr_thread_mutex_lock(sso_cache_lock); +#endif + + shared_ptr<verify_response> vr = sso_cache(vctx); + +#if APR_HAS_THREADS + apr_thread_mutex_unlock(sso_cache_lock); +#endif + + if (vr->error.empty()) { + sso::Ticket *t = vr->t; + // Check user authorization lists if (allow_any_user || (!req_users.empty() @@ -545,17 +619,15 @@ static int mod_sso_authenticate_user(request_rec *r) r->user = apr_pstrdup(r->pool, t->user().c_str()); ap_log_error(APLOG_MARK, APLOG_DEBUG, 0, r->server, "sso: authorized user '%s'", r->user); - delete t; return OK; } else { ap_log_error(APLOG_MARK, APLOG_WARNING, 0, r->server, "sso: unauthorized user '%s'", t->user().c_str()); - delete t; return HTTP_UNAUTHORIZED; } - } catch (sso::sso_error& e) { + } else { ap_log_error(APLOG_MARK, APLOG_WARNING, 0, r->server, - "sso: validation error: %s", e.what()); + "sso: validation error: %s", vr->error.c_str()); } } @@ -610,6 +682,7 @@ static int mod_sso_auth_checker(request_rec *r) static void mod_sso_register_hooks (apr_pool_t *p) { static const char * const mssoPost[] = {"mod_sso.c", NULL}; + ap_hook_child_init(mod_sso_child_init, NULL, NULL, APR_HOOK_FIRST); ap_hook_handler(mod_sso_method_handler, NULL, NULL, APR_HOOK_FIRST); ap_hook_auth_checker(mod_sso_auth_checker, NULL, mssoPost, APR_HOOK_MIDDLE); ap_hook_check_user_id(mod_sso_authenticate_user, NULL, NULL, APR_HOOK_MIDDLE); diff --git a/src/mod_sso/test/httpd_integration_test.py b/src/mod_sso/test/httpd_integration_test.py index 4923c6f..1165430 100755 --- a/src/mod_sso/test/httpd_integration_test.py +++ b/src/mod_sso/test/httpd_integration_test.py @@ -98,7 +98,19 @@ class HttpdIntegrationTest(unittest.TestCase): print 'ticket:', signedt return signedt - def testManyRequests(self): + def testTicketCacheEviction(self): + # Requests for more users than the cache size, forcing eviction. + n = 200 + errors = 0 + for i in xrange(n): + cookie = 'SSO_test=%s' % self._ticket('user%d' % i) + status, body, location = _query("/index.html", cookie=cookie) + if status != 200: + errors += 1 + self.assertEquals(0, errors) + + def testTicketCache(self): + # Lots of requests for the same user. n = 100 errors = 0 for i in xrange(n): -- GitLab