From dd2e9dd9c96ee62ab20f122453379799786ebdd1 Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Mon, 12 Aug 2013 17:47:50 +0000
Subject: [PATCH] add a cache for validated sso tickets

---
 src/mod_sso/lru_cache.h                    | 167 +++++++++++++++++++++
 src/mod_sso/mod_sso.cc                     |  91 +++++++++--
 src/mod_sso/test/httpd_integration_test.py |  14 +-
 3 files changed, 262 insertions(+), 10 deletions(-)
 create mode 100644 src/mod_sso/lru_cache.h

diff --git a/src/mod_sso/lru_cache.h b/src/mod_sso/lru_cache.h
new file mode 100644
index 0000000..243dcf3
--- /dev/null
+++ b/src/mod_sso/lru_cache.h
@@ -0,0 +1,167 @@
+// Copyright (c) 2010-2011, Tim Day <timday@timday.com> 
+// 
+// Permission to use, copy, modify, and/or distribute this software for any 
+// purpose with or without fee is hereby granted, provided that the above 
+// copyright notice and this permission notice appear in all copies. 
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 
+// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 
+// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 
+// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 
+// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 
+// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 
+// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+#ifndef __mod_sso_lru_cache_h
+#define __mod_sso_lru_cache_h 1
+
+#include <cassert> 
+#include <list> 
+ 
+// Class providing fixed-size (by number of records) 
+// LRU-replacement cache of a function with signature 
+// V f(K). 
+// MAP should be one of std::map or std::unordered_map. 
+// Variadic template args used to deal with the 
+// different type argument signatures of those 
+// containers; the default comparator/hash/allocator 
+// will be used. 
+template <
+  typename K,
+  typename V,
+  template<typename...> class MAP
+  > class lru_cache_using_std
+{
+public:
+
+  typedef K key_type;
+  typedef V value_type;
+
+  // Key access history, most recent at back
+  typedef std::list<key_type> key_tracker_type;
+
+  // Key to value and key history iterator
+  typedef MAP<
+    key_type,
+    std::pair<
+      value_type,
+      typename key_tracker_type::iterator
+      >
+  > key_to_value_type;
+
+  // Constructor specifies the cached function and
+  // the maximum number of records to be stored
+  lru_cache_using_std(
+    value_type (*f)(const key_type&),
+    size_t c
+  )
+    :_fn(f)
+    ,_capacity(c)
+  {
+    assert(_capacity!=0);
+  }
+
+  // Obtain value of the cached function for k 
+  value_type operator()(const key_type& k) { 
+ 
+    // Attempt to find existing record 
+    const typename key_to_value_type::iterator it 
+      =_key_to_value.find(k); 
+ 
+    if (it==_key_to_value.end()) { 
+ 
+      // We don't have it: 
+ 
+      // Evaluate function and create new record 
+      const value_type v=_fn(k); 
+      insert(k,v); 
+ 
+      // Return the freshly computed value 
+      return v; 
+ 
+    } else { 
+ 
+      // We do have it: 
+ 
+      // Update access record by moving 
+      // accessed key to back of list 
+      _key_tracker.splice( 
+        _key_tracker.end(), 
+        _key_tracker, 
+        (*it).second.second 
+      ); 
+ 
+      // Return the retrieved value 
+      return (*it).second.first; 
+    }
+  }
+
+  // Obtain the cached keys, most recently used element 
+  // at head, least recently used at tail. 
+  // This method is provided purely to support testing. 
+  template <typename IT> void get_keys(IT dst) const {
+    typename key_tracker_type::const_reverse_iterator src
+        =_key_tracker.rbegin();
+    while (src!=_key_tracker.rend()) {
+      *dst++ = *src++;
+    }
+  }
+
+private:
+
+  // Record a fresh key-value pair in the cache
+  void insert(const key_type& k,const value_type& v) {
+
+    // Method is only called on cache misses
+    assert(_key_to_value.find(k)==_key_to_value.end());
+
+    // Make space if necessary
+    if (_key_to_value.size()==_capacity)
+      evict();
+
+    // Record k as most-recently-used key
+    typename key_tracker_type::iterator it
+      =_key_tracker.insert(_key_tracker.end(),k);
+
+    // Create the key-value entry,
+    // linked to the usage record.
+    _key_to_value.insert(
+      std::make_pair(
+        k,
+        std::make_pair(v,it)
+      )
+    );
+    // No need to check return,
+    // given previous assert.
+  }
+
+  // Purge the least-recently-used element in the cache
+  void evict() {
+
+    // Assert method is never called when cache is empty
+    assert(!_key_tracker.empty());
+
+    // Identify least recently used key
+    const typename key_to_value_type::iterator it
+      =_key_to_value.find(_key_tracker.front());
+    assert(it!=_key_to_value.end());
+
+    // Erase both elements to completely purge record 
+    _key_to_value.erase(it);
+    _key_tracker.pop_front();
+  }
+
+  // The function to be cached 
+  value_type (*_fn)(const key_type&);
+
+  // Maximum number of key-value pairs to be retained
+  const size_t _capacity;
+
+  // Key access history
+  key_tracker_type _key_tracker;
+
+  // Key-to-value lookup
+  key_to_value_type _key_to_value;
+};
+
+#endif
diff --git a/src/mod_sso/mod_sso.cc b/src/mod_sso/mod_sso.cc
index 0b53440..c78c329 100644
--- a/src/mod_sso/mod_sso.cc
+++ b/src/mod_sso/mod_sso.cc
@@ -24,6 +24,7 @@
 
 #include <iostream>
 #include <sstream>
+#include <memory>
 
 #include "httpd.h"
 #include "http_config.h"
@@ -35,6 +36,7 @@
 #include "ap_config.h"
 #include "apr_strings.h"
 
+#include "lru_cache.h"
 #include "mod_sso.h"
 
 extern "C" module AP_MODULE_DECLARE_DATA sso_module;
@@ -42,6 +44,7 @@ extern "C" module AP_MODULE_DECLARE_DATA sso_module;
 using std::map;
 using std::string;
 using std::ostringstream;
+using std::shared_ptr;
 
 typedef struct {
   const char *login_server;
@@ -480,6 +483,62 @@ static string encode_groups(const sso::groups_t& groups)
   return obuf.str();
 }
 
+struct verify_context {
+  string public_key;
+  string service;
+  string domain;
+  sso::groups_t req_groups;
+  string sso_cookie;
+
+  bool operator<(const verify_context& b) const {
+    if (domain < b.domain)
+      return true;
+    if (service < b.service)
+      return true;
+    if (req_groups != b.req_groups)
+      return true;
+    return sso_cookie < b.sso_cookie;
+  }
+};
+
+class verify_response {
+public:
+  ~verify_response() {
+    delete t;
+  }
+
+  sso::Ticket *t;
+  string error;
+};
+
+static shared_ptr<verify_response> verify_ticket(const verify_context& ctx) {
+  verify_response *resp = new verify_response;
+
+  try {
+    sso::Verifier verifier(ctx.public_key, ctx.service, ctx.domain, ctx.req_groups);
+    resp->t = verifier.verify(ctx.sso_cookie);
+  } catch (sso::sso_error& e) {
+    resp->t = NULL;
+    resp->error = e.what();
+  }
+  return shared_ptr<verify_response>(resp);
+}
+
+typedef lru_cache_using_std<verify_context, shared_ptr<verify_response>, std::map
+                            > sso_cache_t;
+
+static sso_cache_t sso_cache(verify_ticket, 128);
+
+#if APR_HAS_THREADS
+static apr_thread_mutex_t *sso_cache_lock;
+#endif
+
+static void mod_sso_child_init(apr_pool_t *p, server_rec *s) {
+#if APR_HAS_THREADS
+  apr_thread_mutex_create(&(sso_cache_lock), APR_THREAD_MUTEX_DEFAULT, p);
+#endif
+}
+
 /**
  * Apache authentication handler for mod_sso.
  *
@@ -529,12 +588,27 @@ static int mod_sso_authenticate_user(request_rec *r)
   string sso_cookie = get_cookie(r, sso_cookie_name);
   if (!sso_cookie.empty()) {
     string pkey(s_cfg->public_key, s_cfg->public_key_len);
-    sso::Verifier verifier(pkey, s_cfg->service,
-                           s_cfg->domain, req_groups);
 
-    try {
-      sso::Ticket *t = verifier.verify(sso_cookie);
-    
+    verify_context vctx;
+    vctx.public_key = pkey;
+    vctx.service = s_cfg->service;
+    vctx.domain = s_cfg->domain;
+    vctx.req_groups = req_groups;
+    vctx.sso_cookie = sso_cookie;
+
+#if APR_HAS_THREADS
+    apr_thread_mutex_lock(sso_cache_lock);
+#endif
+
+    shared_ptr<verify_response> vr = sso_cache(vctx);
+
+#if APR_HAS_THREADS
+    apr_thread_mutex_unlock(sso_cache_lock);
+#endif
+
+    if (vr->error.empty()) {
+      sso::Ticket *t = vr->t;
+
       // Check user authorization lists
       if (allow_any_user
 	  || (!req_users.empty()
@@ -545,17 +619,15 @@ static int mod_sso_authenticate_user(request_rec *r)
 	r->user = apr_pstrdup(r->pool, t->user().c_str());
 	ap_log_error(APLOG_MARK, APLOG_DEBUG, 0, r->server,
 		     "sso: authorized user '%s'", r->user);
-	delete t;
 	return OK;
       } else {
 	ap_log_error(APLOG_MARK, APLOG_WARNING, 0, r->server,
 		     "sso: unauthorized user '%s'", t->user().c_str());
-        delete t;
         return HTTP_UNAUTHORIZED;
       }
-    } catch (sso::sso_error& e) {
+    } else {
       ap_log_error(APLOG_MARK, APLOG_WARNING, 0, r->server,
-                   "sso: validation error: %s", e.what());
+                   "sso: validation error: %s", vr->error.c_str());
     }
   }
 
@@ -610,6 +682,7 @@ static int mod_sso_auth_checker(request_rec *r)
 static void mod_sso_register_hooks (apr_pool_t *p) 
 {
   static const char * const mssoPost[] = {"mod_sso.c", NULL};
+  ap_hook_child_init(mod_sso_child_init, NULL, NULL, APR_HOOK_FIRST);
   ap_hook_handler(mod_sso_method_handler, NULL, NULL, APR_HOOK_FIRST);
   ap_hook_auth_checker(mod_sso_auth_checker, NULL, mssoPost, APR_HOOK_MIDDLE);
   ap_hook_check_user_id(mod_sso_authenticate_user, NULL, NULL, APR_HOOK_MIDDLE);
diff --git a/src/mod_sso/test/httpd_integration_test.py b/src/mod_sso/test/httpd_integration_test.py
index 4923c6f..1165430 100755
--- a/src/mod_sso/test/httpd_integration_test.py
+++ b/src/mod_sso/test/httpd_integration_test.py
@@ -98,7 +98,19 @@ class HttpdIntegrationTest(unittest.TestCase):
         print 'ticket:', signedt
         return signedt
 
-    def testManyRequests(self):
+    def testTicketCacheEviction(self):
+        # Requests for more users than the cache size, forcing eviction.
+        n = 200
+        errors = 0
+        for i in xrange(n):
+            cookie = 'SSO_test=%s' % self._ticket('user%d' % i)
+            status, body, location = _query("/index.html", cookie=cookie)
+            if status != 200:
+                errors += 1
+        self.assertEquals(0, errors)
+
+    def testTicketCache(self):
+        # Lots of requests for the same user.
         n = 100
         errors = 0
         for i in xrange(n):
-- 
GitLab