From 6474e1ab21699f1ccd0835ccd1fb403765bf64bb Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Wed, 25 Jun 2014 18:44:28 +0200
Subject: [PATCH] run the NGINX mail_http_auth handler on a separate, non-SSL
 port

---
 authserv/__init__.py                          |   2 -
 authserv/app_common.py                        |  18 +++
 authserv/app_main.py                          |  50 ++++++++
 authserv/app_nginx.py                         |  39 ++++++
 authserv/server.py                            | 121 ++++--------------
 .../test/{test_server.py => test_app_main.py} |  29 +----
 authserv/test/test_app_nginx.py               |  50 ++++++++
 authserv/test/test_integration.py             |   4 +-
 8 files changed, 186 insertions(+), 127 deletions(-)
 create mode 100644 authserv/app_common.py
 create mode 100644 authserv/app_main.py
 create mode 100644 authserv/app_nginx.py
 rename authserv/test/{test_server.py => test_app_main.py} (87%)
 create mode 100644 authserv/test/test_app_nginx.py

diff --git a/authserv/__init__.py b/authserv/__init__.py
index cf1f414..e69de29 100644
--- a/authserv/__init__.py
+++ b/authserv/__init__.py
@@ -1,2 +0,0 @@
-from flask import Flask
-app = Flask(__name__)
diff --git a/authserv/app_common.py b/authserv/app_common.py
new file mode 100644
index 0000000..f10a66b
--- /dev/null
+++ b/authserv/app_common.py
@@ -0,0 +1,18 @@
+from flask import current_app
+from authserv import auth
+from authserv import protocol
+from authserv.ratelimit import *
+
+_blacklisted = (protocol.ERR_AUTHENTICATION_FAILURE, None)
+
+
+@blacklist_on_auth_failure(key_from_args(0), count=5, period=600, ttl=43200,
+                           bl_return_value=_blacklisted)
+@blacklist_on_auth_failure(key_from_args(5), count=5, period=600, ttl=43200,
+                           check_wl=True, bl_return_value=_blacklisted)
+def do_auth(username, service, shard, password, otp_token, source_ip):
+    user = current_app.userdb.get_user(username, service, shard)
+    if not user:
+        return _blacklisted
+    return (auth.authenticate(user, service, password, otp_token),
+            user.get_shard())
diff --git a/authserv/app_main.py b/authserv/app_main.py
new file mode 100644
index 0000000..5731361
--- /dev/null
+++ b/authserv/app_main.py
@@ -0,0 +1,50 @@
+from flask import Flask, request, abort, make_response
+from authserv import auth
+from authserv import protocol
+from authserv.ratelimit import *
+from authserv.app_common import do_auth
+
+app = Flask(__name__)
+
+
+# Quick clarification on the rate limits: 'username' is the one that's
+# going to be used all the time, while the X-Forwarded-For header on
+# the request is only going to be present for those authentication
+# requests where we have knowledge of the original users' IP (remember
+# that 'source_ip' can sometimes be the server address or localhost).
+# For instance, authentication requests that come from PAM usually do
+# not have knowledge of the users' IP address, as the protocols for
+# which we use PAM handlers do not support forwarding of the IP
+# address. So we're practically only going to use X-Forwarded-For for
+# requests that reach our frontends via HTTP.
+@app.route('/api/1/auth', methods=('POST',))
+@ratelimit_http_request(key_from_request(header='HTTP_X_FORWARDED_FOR'),
+                        count=10, period=60)
+@ratelimit_http_request(key_from_request(param='username'),
+                        count=10, period=60)
+def api_auth():
+    service = request.form.get('service')
+    username = request.form.get('username')
+    password = request.form.get('password')
+    otp_token = request.form.get('otp')
+    source_ip = request.form.get('source_ip')
+    shard = request.form.get('shard')
+
+    if not service or not username:
+        abort(400)
+
+    try:
+        result, _ = do_auth(username, service, shard, password, otp_token, source_ip)
+    except Exception, e:
+        app.logger.exception('Unexpected exception in authenticate()')
+        abort(500)
+
+    app.logger.info(
+        'AUTH %s %s otp=%s %s',
+        username, service, otp_token and 'y' or 'n', result)
+
+    response = make_response(result)
+    response.headers['Cache-Control'] = 'no-cache'
+    response.headers['Content-Type'] = 'text/plain'
+    response.headers['Expires'] = '-1'
+    return response
diff --git a/authserv/app_nginx.py b/authserv/app_nginx.py
new file mode 100644
index 0000000..b754c8a
--- /dev/null
+++ b/authserv/app_nginx.py
@@ -0,0 +1,39 @@
+from flask import Flask, request, abort, make_response
+from authserv.app_common import do_auth
+
+app = Flask(__name__)
+
+_default_port_map = {'imap': 143, 'pop3': 110}
+
+
+@app.route('/auth', methods=('GET',))
+def do_nginx_http_auth():
+    service = app.config.get('NGINX_AUTH_SERVICE', 'mail')
+    username = request.environ.get('HTTP_AUTH_USER')
+    password = request.environ.get('HTTP_AUTH_PASS')
+    source_ip = request.environ.get('HTTP_CLIENT_IP')
+    protocol = request.environ.get('HTTP_AUTH_PROTOCOL')
+    try:
+        n_attempt = int(request.environ.get('HTTP_AUTH_LOGIN_ATTEMPT'))
+    except ValueError:
+        n_attempt = 1
+
+    try:
+        auth_status, shard = do_auth(
+            username, service, None, password, None, source_ip)
+    except Exception, e:
+        app.logger.exception('Unexpected exception in authenticate()')
+        abort(500)
+
+    response = make_response('')
+    if auth_status == 'OK':
+        response.headers['Auth-Status'] = 'OK'
+        response.headers['Auth-Server'] = shard
+        response.headers['Auth-Port'] = str(
+            app.config.get('NGINX_AUTH_PORT_MAP', _default_port_map)[protocol])
+    else:
+        response.headers['Auth-Status'] = 'Invalid login or password'
+        if n_attempt <= 3:
+            response.headers['Auth-Wait'] = '3'
+    return response
+
diff --git a/authserv/server.py b/authserv/server.py
index 5124d88..75aa406 100644
--- a/authserv/server.py
+++ b/authserv/server.py
@@ -3,104 +3,12 @@ import logging.handlers
 import optparse
 import os
 import signal
-from authserv import app
-from authserv import auth
-from authserv import protocol
-from authserv.ratelimit import *
-from flask import Flask, request, abort, make_response
-
-_blacklisted = (protocol.ERR_AUTHENTICATION_FAILURE, None)
-
-@blacklist_on_auth_failure(key_from_args(0), count=5, period=600, ttl=43200,
-                           bl_return_value=_blacklisted)
-@blacklist_on_auth_failure(key_from_args(5), count=5, period=600, ttl=43200,
-                           check_wl=True, bl_return_value=_blacklisted)
-def _auth(username, service, shard, password, otp_token, source_ip):
-    user = app.userdb.get_user(username, service, shard)
-    if not user:
-        return _blacklisted
-    return (auth.authenticate(user, service, password, otp_token),
-            user.get_shard())
-
-
-# Quick clarification on the rate limits: 'username' is the one that's
-# going to be used all the time, while the X-Forwarded-For header on
-# the request is only going to be present for those authentication
-# requests where we have knowledge of the original users' IP (remember
-# that 'source_ip' can sometimes be the server address or localhost).
-# For instance, authentication requests that come from PAM usually do
-# not have knowledge of the users' IP address, as the protocols for
-# which we use PAM handlers do not support forwarding of the IP
-# address. So we're practically only going to use X-Forwarded-For for
-# requests that reach our frontends via HTTP.
-@app.route('/api/1/auth', methods=('POST',))
-@ratelimit_http_request(key_from_request(header='HTTP_X_FORWARDED_FOR'),
-                        count=10, period=60)
-@ratelimit_http_request(key_from_request(param='username'),
-                        count=10, period=60)
-def do_auth():
-    service = request.form.get('service')
-    username = request.form.get('username')
-    password = request.form.get('password')
-    otp_token = request.form.get('otp')
-    source_ip = request.form.get('source_ip')
-    shard = request.form.get('shard')
-
-    if not service or not username:
-        abort(400)
-
-    try:
-        result, _ = _auth(username, service, shard, password, otp_token, source_ip)
-    except Exception, e:
-        app.logger.exception('Unexpected exception in authenticate()')
-        abort(500)
-
-    app.logger.info(
-        'AUTH %s %s otp=%s %s',
-        username, service, otp_token and 'y' or 'n', result)
-
-    response = make_response(result)
-    response.headers['Cache-Control'] = 'no-cache'
-    response.headers['Content-Type'] = 'text/plain'
-    response.headers['Expires'] = '-1'
-    return response
-
-
-_default_port_map = {'imap': 143, 'pop3': 110}
-
-@app.route('/auth', methods=('GET',))
-def do_nginx_http_auth():
-    service = app.config.get('NGINX_AUTH_SERVICE', 'mail')
-    username = request.environ.get('HTTP_AUTH_USER')
-    password = request.environ.get('HTTP_AUTH_PASS')
-    source_ip = request.environ.get('HTTP_CLIENT_IP')
-    protocol = request.environ.get('HTTP_AUTH_PROTOCOL')
-    try:
-        n_attempt = int(request.environ.get('HTTP_AUTH_LOGIN_ATTEMPT'))
-    except ValueError:
-        n_attempt = 1
-
-    try:
-        auth_status, shard = _auth(
-            username, service, None, password, None, source_ip)
-    except Exception, e:
-        app.logger.exception('Unexpected exception in authenticate()')
-        abort(500)
-
-    response = make_response('')
-    if auth_status == 'OK':
-        response.headers['Auth-Status'] = 'OK'
-        response.headers['Auth-Server'] = shard
-        response.headers['Auth-Port'] = str(
-            app.config.get('NGINX_AUTH_PORT_MAP', _default_port_map)[protocol])
-    else:
-        response.headers['Auth-Status'] = 'Invalid login or password'
-        if n_attempt <= 3:
-            response.headers['Auth-Wait'] = '3'
-    return response
+import threading
+from authserv import app_main
+from authserv import app_nginx
 
 
-def create_app(userdb=None, mc=None):
+def create_app(app, userdb=None, mc=None):
     app.config.from_envvar('APP_CONFIG', silent=True)
 
     if not userdb:
@@ -178,10 +86,13 @@ def main():
     parser = optparse.OptionParser()
     parser.add_option('--config',
                       help='Configuration file')
-    parser.add_option('--addr', dest='addr', default='0.0.0.0',
+    parser.add_option('--addr', default='0.0.0.0',
                       help='Address to listen on (default: %default)')
     parser.add_option('--port', type='int', default=1616,
                       help='TCP port to listen on (default: %default)')
+    parser.add_option('--nginx-port', dest='nginx_port', type='int', default=0,
+                      help='TCP port for the NGINX mail_http_auth handler '
+                      '(forced bind to localhost, default: disabled)')
     parser.add_option('--engine', dest='engine',
                       help='HTTP engine to use (default: try gevent, then werkzeug)')
     parser.add_option('--ca', dest='ssl_ca',
@@ -213,7 +124,9 @@ def main():
 
     if opts.config:
         os.environ['APP_CONFIG'] = opts.config
-    app.config.update({'DEBUG': opts.debug})
+    if opts.debug:
+        for a in (app_main.app, app_nginx.app):
+            a.config['DEBUG'] = True
 
     def _stopall(signo, frame):
         logging.info('terminating with signal %d', signo)
@@ -221,7 +134,17 @@ def main():
     signal.signal(signal.SIGINT, _stopall)
     signal.signal(signal.SIGTERM, _stopall)
 
-    run(create_app(),
+    # Start the applications that were requested: the NGINX
+    # mail_http_auth handler (on its own thread), and the main auth
+    # server application.
+    if opts.nginx_port > 0:
+        t = threading.Thread(target=run, args=(
+            create_app(app_nginx.app),
+            opts.engine, '127.0.0.1', opts.nginx_port, None,
+            None, None, None))
+        t.setDaemon(True)
+        t.start()
+    run(create_app(app_main.app),
         opts.engine, opts.addr, opts.port, opts.ssl_ca,
         opts.ssl_cert, opts.ssl_key, opts.dh_params)
 
diff --git a/authserv/test/test_server.py b/authserv/test/test_app_main.py
similarity index 87%
rename from authserv/test/test_server.py
rename to authserv/test/test_app_main.py
index 2008383..298af68 100644
--- a/authserv/test/test_server.py
+++ b/authserv/test/test_app_main.py
@@ -2,6 +2,7 @@ from authserv.test import *
 from authserv.ratelimit import *
 from authserv import protocol
 from authserv import server
+from authserv import app_main
 
 URL = '/api/1/auth'
 
@@ -15,7 +16,8 @@ class ServerTest(unittest.TestCase):
         self.users = {
             'user': FakeUser('user', 'pass'),
             }
-        app = server.create_app(userdb=FakeUserDb(self.users),
+        app = server.create_app(app_main.app,
+                                userdb=FakeUserDb(self.users),
                                 mc=FakeMemcache(_time))
         app.config.update({
                 'TESTING': True,
@@ -183,27 +185,4 @@ class ServerTest(unittest.TestCase):
         self.assertEquals(200, response.status_code)
         self.assertEquals(protocol.OK, response.data)
 
-    def test_nginx_http_auth_ok(self):
-        response = self.app.get(
-            '/auth', headers={
-                'Auth-User': 'user',
-                'Auth-Pass': 'pass',
-                'Client-IP': '127.0.0.1',
-                'Auth-Protocol': 'imap',
-                'Auth-Login-Attempt': '1',
-            })
-        self.assertEquals(200, response.status_code)
-        self.assertEquals('OK', response.headers['Auth-Status'])
-
-    def test_nginx_http_auth_fail(self):
-        response = self.app.get(
-            '/auth', headers={
-                'Auth-User': 'user',
-                'Auth-Pass': 'bad password',
-                'Client-IP': '127.0.0.1',
-                'Auth-Protocol': 'imap',
-                'Auth-Login-Attempt': '1',
-            })
-        self.assertEquals(200, response.status_code)
-        self.assertNotEquals('OK', response.headers['Auth-Status'])
-
+        
diff --git a/authserv/test/test_app_nginx.py b/authserv/test/test_app_nginx.py
new file mode 100644
index 0000000..1f1e6c6
--- /dev/null
+++ b/authserv/test/test_app_nginx.py
@@ -0,0 +1,50 @@
+from authserv.test import *
+from authserv.ratelimit import *
+from authserv import protocol
+from authserv import server
+from authserv import app_nginx
+
+URL = '/api/1/auth'
+
+
+class ServerTest(unittest.TestCase):
+
+    def setUp(self):
+        self.tick = 0
+        def _time():
+            return self.tick
+        self.users = {
+            'user': FakeUser('user', 'pass'),
+            }
+        app = server.create_app(app_nginx.app,
+                                userdb=FakeUserDb(self.users),
+                                mc=FakeMemcache(_time))
+        app.config.update({
+                'TESTING': True,
+                'DEBUG': True,
+                })
+        self.app = app.test_client()
+
+    def test_nginx_http_auth_ok(self):
+        response = self.app.get(
+            '/auth', headers={
+                'Auth-User': 'user',
+                'Auth-Pass': 'pass',
+                'Client-IP': '127.0.0.1',
+                'Auth-Protocol': 'imap',
+                'Auth-Login-Attempt': '1',
+            })
+        self.assertEquals(200, response.status_code)
+        self.assertEquals('OK', response.headers['Auth-Status'])
+
+    def test_nginx_http_auth_fail(self):
+        response = self.app.get(
+            '/auth', headers={
+                'Auth-User': 'user',
+                'Auth-Pass': 'bad password',
+                'Client-IP': '127.0.0.1',
+                'Auth-Protocol': 'imap',
+                'Auth-Login-Attempt': '1',
+            })
+        self.assertEquals(200, response.status_code)
+        self.assertNotEquals('OK', response.headers['Auth-Status'])
diff --git a/authserv/test/test_integration.py b/authserv/test/test_integration.py
index 5db318a..ea07810 100644
--- a/authserv/test/test_integration.py
+++ b/authserv/test/test_integration.py
@@ -10,6 +10,7 @@ from authserv.test import *
 from authserv.ratelimit import *
 from authserv import protocol
 from authserv import server
+from authserv import app_main
 
 URL = '/api/1/auth'
 
@@ -56,7 +57,8 @@ class SSLServerTest(unittest.TestCase):
         self.users = {
             'user': FakeUser('user', 'pass'),
         }
-        app = server.create_app(userdb=FakeUserDb(self.users),
+        app = server.create_app(app_main.app,
+                                userdb=FakeUserDb(self.users),
                                 mc=FakeMemcache(time.time))
         app.config.update({
                 'TESTING': True,
-- 
GitLab