From 188a0870d99aac291fc001edf2f9d5ecf17a9bcc Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Thu, 19 Dec 2019 10:50:45 +0000
Subject: [PATCH] Return 404 for unauthenticated requests to URLs that do not
 exist

This avoids browsers messing up the session state (given that /login
calls session.Reset) with requests to various kinds of well-known URLs
that might not exist.

Also add an integration test for a server with non-nil URL prefix.
---
 server/http.go             | 21 ++++++++++---
 server/http_test.go        | 63 ++++++++++++++++++++++++++++++++------
 server/integration_test.go | 44 ++++++++++++++++++++++++--
 3 files changed, 113 insertions(+), 15 deletions(-)

diff --git a/server/http.go b/server/http.go
index 0502e77..b224026 100644
--- a/server/http.go
+++ b/server/http.go
@@ -140,9 +140,9 @@ func New(loginService *LoginService, authClient authclient.Client, config *Confi
 	// HTTP-based login workflow).
 	root.HandleFunc(h.urlFor("/exchange"), h.handleExchange)
 
-	// Build the main IDP application router, wrap it with a login
-	// handler, optional CSRF protection, custom HTTP headers,
-	// etc.
+	// Build the main application router (which only serves / and
+	// /logout), wrap it with a login handler, optional CSRF
+	// protection, custom HTTP headers, etc.
 	mainh := http.NewServeMux()
 	mainh.HandleFunc(h.urlFor("/logout"), h.handleLogout)
 	mainh.HandleFunc(h.urlFor("/"), h.handleGrantTicket)
@@ -167,7 +167,20 @@ func New(loginService *LoginService, authClient authclient.Client, config *Confi
 	})
 	apph = corsp.Handler(apph)
 
-	root.Handle(h.urlFor("/"), apph)
+	// Now we need to remap 'apph' onto 'root'. We do this by
+	// whitelisting certain methods only, which allows us to
+	// return 404s *before* authentication.
+	root.Handle(h.urlFor("/login"), apph)
+	root.Handle(h.urlFor("/login/"), apph)
+	root.Handle(h.urlFor("/logout"), apph)
+	root.HandleFunc(h.urlFor("/"), func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != h.urlFor("/") {
+			http.NotFound(w, r)
+			return
+		}
+		apph.ServeHTTP(w, r)
+	})
+
 	h.handler = root
 
 	return h, nil
diff --git a/server/http_test.go b/server/http_test.go
index 96f4a9a..8d48948 100644
--- a/server/http_test.go
+++ b/server/http_test.go
@@ -158,6 +158,12 @@ func checkStatusOk(t testing.TB, resp *http.Response) {
 	}
 }
 
+func checkStatusNotFound(t testing.TB, resp *http.Response) {
+	if resp.StatusCode != 404 {
+		t.Fatalf("expected status 404, got %s", resp.Status)
+	}
+
+}
 func checkRedirectToTargetService(t testing.TB, resp *http.Response) {
 	if resp.StatusCode != 302 {
 		t.Fatalf("expected status 302, got %s", resp.Status)
@@ -196,10 +202,13 @@ func checkTargetSSOTicket(config *Config) func(testing.TB, *http.Response) {
 
 var usernameFieldRx = regexp.MustCompile(`<input[^>]*name="username"`)
 
-func checkLoginPasswordPage(t testing.TB, resp *http.Response) {
+func checkLoginPageURL(t testing.TB, resp *http.Response) {
 	if resp.Request.URL.Path != "/login" {
 		t.Errorf("request path is not /login (%s)", resp.Request.URL.String())
 	}
+}
+
+func checkLoginPasswordPage(t testing.TB, resp *http.Response) {
 	data, err := ioutil.ReadAll(resp.Body)
 	if err != nil {
 		t.Fatalf("reading body: %v", err)
@@ -253,7 +262,7 @@ func TestHTTP_Login(t *testing.T) {
 	v.Set("d", "https://service.example.com/admin/")
 	v.Set("n", "averysecretnonce")
 	v.Set("g", "users")
-	doGet(t, c, httpSrv.URL+"/?"+v.Encode(), checkStatusOk, checkLoginPasswordPage)
+	doGet(t, c, httpSrv.URL+"/?"+v.Encode(), checkStatusOk, checkLoginPageURL, checkLoginPasswordPage)
 
 	// Attempt to login by submitting the form. We expect the
 	// result to be a 302 redirect to the target service.
@@ -276,13 +285,13 @@ func TestHTTP_LoginOnSecondAttempt(t *testing.T) {
 	v.Set("s", "service.example.com/")
 	v.Set("d", "https://service.example.com/admin/")
 	v.Set("n", "averysecretnonce")
-	doGet(t, c, httpSrv.URL+"/?"+v.Encode(), checkStatusOk, checkLoginPasswordPage)
+	doGet(t, c, httpSrv.URL+"/?"+v.Encode(), checkStatusOk, checkLoginPageURL, checkLoginPasswordPage)
 
 	// Attempt to login with wrong credentials.
 	v = make(url.Values)
 	v.Set("username", "testuser")
 	v.Set("password", "badpassword")
-	doPostForm(t, c, httpSrv.URL+"/login", v, checkStatusOk, checkLoginPasswordPage)
+	doPostForm(t, c, httpSrv.URL+"/login", v, checkStatusOk, checkLoginPageURL, checkLoginPasswordPage)
 
 	// Attempt to login by submitting the form. We expect the
 	// result to be a 302 redirect to the target service.
@@ -305,7 +314,7 @@ func TestHTTP_LoginAndLogout(t *testing.T) {
 	v.Set("s", "service.example.com/")
 	v.Set("d", "https://service.example.com/admin/")
 	v.Set("n", "averysecretnonce")
-	doGet(t, c, httpSrv.URL+"/?"+v.Encode(), checkStatusOk, checkLoginPasswordPage)
+	doGet(t, c, httpSrv.URL+"/?"+v.Encode(), checkStatusOk, checkLoginPageURL, checkLoginPasswordPage)
 
 	// Attempt to login by submitting the form. We expect the
 	// result to be a 302 redirect to the target service.
@@ -322,7 +331,7 @@ func TestHTTP_LoginAndLogout(t *testing.T) {
 	v.Set("s", "service.example.com/")
 	v.Set("d", "https://service.example.com/admin/")
 	v.Set("n", "averysecretnonce")
-	doGet(t, c, httpSrv.URL+"/?"+v.Encode(), checkStatusOk, checkLoginPasswordPage)
+	doGet(t, c, httpSrv.URL+"/?"+v.Encode(), checkStatusOk, checkLoginPageURL, checkLoginPasswordPage)
 }
 
 func TestHTTP_LoginOTP(t *testing.T) {
@@ -338,7 +347,39 @@ func TestHTTP_LoginOTP(t *testing.T) {
 	v.Set("s", "service.example.com/")
 	v.Set("d", "https://service.example.com/admin/")
 	v.Set("n", "averysecretnonce")
-	doGet(t, c, httpSrv.URL+"/?"+v.Encode(), checkStatusOk, checkLoginPasswordPage)
+	doGet(t, c, httpSrv.URL+"/?"+v.Encode(), checkStatusOk, checkLoginPageURL, checkLoginPasswordPage)
+
+	// Attempt to login by submitting the form. We should see the OTP page.
+	v = make(url.Values)
+	v.Set("username", "test2fa")
+	v.Set("password", "password")
+	doPostForm(t, c, httpSrv.URL+"/login", v, checkStatusOk, checkLoginOTPPage)
+
+	// Submit the correct OTP token. We expect the result to be a
+	// 302 redirect to the target service.
+	v = make(url.Values)
+	v.Set("otp", "123456")
+	doPostForm(t, c, httpSrv.URL+"/login/otp", v, checkRedirectToTargetService)
+}
+
+func TestHTTP_LoginOTP_Intermediate404(t *testing.T) {
+	// This test verifies that the session is not disrupted by a
+	// request for a URL that does not exist during a 2FA login
+	// workflow. The point is that the 404 should *not* Reset()
+	// the session.
+	tmpdir, httpSrv := startTestHTTPServer(t)
+	defer os.RemoveAll(tmpdir)
+	defer httpSrv.Close()
+
+	c := newTestHTTPClient()
+
+	// Simulate an authorization request from a service, expect to
+	// see the login page.
+	v := make(url.Values)
+	v.Set("s", "service.example.com/")
+	v.Set("d", "https://service.example.com/admin/")
+	v.Set("n", "averysecretnonce")
+	doGet(t, c, httpSrv.URL+"/?"+v.Encode(), checkStatusOk, checkLoginPageURL, checkLoginPasswordPage)
 
 	// Attempt to login by submitting the form. We should see the OTP page.
 	v = make(url.Values)
@@ -346,6 +387,10 @@ func TestHTTP_LoginOTP(t *testing.T) {
 	v.Set("password", "password")
 	doPostForm(t, c, httpSrv.URL+"/login", v, checkStatusOk, checkLoginOTPPage)
 
+	// Make a request for a URL that does not exist, browsers might do this
+	// for a number of reasons.
+	doGet(t, c, httpSrv.URL+"/apple-iphone-special-icon.ico", checkStatusNotFound)
+
 	// Submit the correct OTP token. We expect the result to be a
 	// 302 redirect to the target service.
 	v = make(url.Values)
@@ -389,7 +434,7 @@ func TestHTTP_LoginWithKeyStore(t *testing.T) {
 	v.Set("s", "service.example.com/")
 	v.Set("d", "https://service.example.com/admin/")
 	v.Set("n", "averysecretnonce")
-	doGet(t, c, httpSrv.URL+"/?"+v.Encode(), checkStatusOk, checkLoginPasswordPage)
+	doGet(t, c, httpSrv.URL+"/?"+v.Encode(), checkStatusOk, checkLoginPageURL, checkLoginPasswordPage)
 
 	// Attempt to login by submitting the form. We expect the
 	// result to be a 302 redirect to the target service.
@@ -413,7 +458,7 @@ func TestHTTP_CORS(t *testing.T) {
 	v.Set("s", "service.example.com/")
 	v.Set("d", "https://service.example.com/admin/")
 	v.Set("n", "averysecretnonce")
-	doGet(t, c, httpSrv.URL+"/?"+v.Encode(), checkStatusOk, checkLoginPasswordPage)
+	doGet(t, c, httpSrv.URL+"/?"+v.Encode(), checkStatusOk, checkLoginPageURL, checkLoginPasswordPage)
 
 	// Attempt to login by submitting the form. We expect the
 	// result to be a 302 redirect to the target service.
diff --git a/server/integration_test.go b/server/integration_test.go
index 930b62a..d1cbace 100644
--- a/server/integration_test.go
+++ b/server/integration_test.go
@@ -48,6 +48,21 @@ func startTestHTTPServerAndApp(t testing.TB) (string, *httptest.Server, *httptes
 	return tmpdir, srv, app
 }
 
+func startTestHTTPServerWithPrefixAndApp(t testing.TB) (string, *httptest.Server, *httptest.Server) {
+	tmpdir, _ := ioutil.TempDir("", "")
+	config := testConfig(t, tmpdir, "")
+	config.URLPrefix = "/sso"
+	srv := createTestHTTPServer(t, config)
+	app := createTestProtectedService(t, "https://login.example.com/sso", tmpdir)
+	return tmpdir, srv, app
+}
+
+func checkLoginPageURLWithPrefix(t testing.TB, resp *http.Response) {
+	if resp.Request.URL.Path != "/sso/login" {
+		t.Errorf("request path is not /sso/login (%s)", resp.Request.URL.String())
+	}
+}
+
 func checkIsProtectedService(t testing.TB, resp *http.Response) {
 	data, err := ioutil.ReadAll(resp.Body)
 	if err != nil {
@@ -91,7 +106,7 @@ func TestIntegration(t *testing.T) {
 		"service.example.com:443": addrFromURL(app.URL),
 	}, true)
 
-	doGet(t, c, "https://service.example.com/", checkStatusOk, checkLoginPasswordPage)
+	doGet(t, c, "https://service.example.com/", checkStatusOk, checkLoginPageURL, checkLoginPasswordPage)
 
 	v := make(url.Values)
 	v.Set("username", "testuser")
@@ -101,5 +116,30 @@ func TestIntegration(t *testing.T) {
 	// Now attempt to logout, and verify that we can't access the service anymore.
 	doGet(t, c, "https://login.example.com/logout", checkStatusOk, checkLogoutPageHasLinks)
 	doGet(t, c, "https://service.example.com/sso_logout", checkStatusOk)
-	doGet(t, c, "https://service.example.com/", checkStatusOk, checkLoginPasswordPage)
+	doGet(t, c, "https://service.example.com/", checkStatusOk, checkLoginPageURL, checkLoginPasswordPage)
+}
+
+// Same test as above, but the server application has a URL prefix.
+func TestIntegration_WithURLPrefix(t *testing.T) {
+	tmpdir, srv, app := startTestHTTPServerWithPrefixAndApp(t)
+	defer os.RemoveAll(tmpdir)
+	defer srv.Close()
+	defer app.Close()
+
+	c := makeHTTPClient(map[string]string{
+		"login.example.com:443":   addrFromURL(srv.URL),
+		"service.example.com:443": addrFromURL(app.URL),
+	}, true)
+
+	doGet(t, c, "https://service.example.com/", checkStatusOk, checkLoginPageURLWithPrefix, checkLoginPasswordPage)
+
+	v := make(url.Values)
+	v.Set("username", "testuser")
+	v.Set("password", "password")
+	doPostForm(t, c, "https://login.example.com/sso/login", v, checkStatusOk, checkIsProtectedService)
+
+	// Now attempt to logout, and verify that we can't access the service anymore.
+	doGet(t, c, "https://login.example.com/sso/logout", checkStatusOk, checkLogoutPageHasLinks)
+	doGet(t, c, "https://service.example.com/sso_logout", checkStatusOk)
+	doGet(t, c, "https://service.example.com/", checkStatusOk, checkLoginPageURLWithPrefix, checkLoginPasswordPage)
 }
-- 
GitLab