diff --git a/server/http.go b/server/http.go index 6161cb29d5c9db4f6835b42502cf96bcfef5e64e..aef46426b637c57db788c06b4b556064a5b838db 100644 --- a/server/http.go +++ b/server/http.go @@ -442,13 +442,20 @@ func (h *Server) Handler() http.Handler { idph = csrf.Protect(h.csrfSecret)(idph) } + // Add CORS headers on the main SSO API endpoint. + c := cors.New(cors.Options{ + AllowedOrigins: h.allowedOrigins, + AllowCredentials: true, + MaxAge: 86400, + }) + // Add the SSO provider endpoints (root path and /exchange), // which do not need CSRF. We use a HandlerFunc to bypass the // '/' dispatch semantics of the standard http.ServeMux. - ssoh := h.withAuth(h.handleHomepage, h.redirectToLogin) + ssoh := c.Handler(h.withAuth(h.handleHomepage, h.redirectToLogin)) userh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.Method == "GET" && r.URL.Path == h.urlFor("/"): + case r.URL.Path == h.urlFor("/"): ssoh.ServeHTTP(w, r) case r.URL.Path == h.urlFor("/exchange"): h.handleExchange(w, r) @@ -457,13 +464,6 @@ func (h *Server) Handler() http.Handler { } }) - // Add CORS headers around user-facing routes. - c := cors.New(cors.Options{ - AllowedOrigins: h.allowedOrigins, - AllowCredentials: true, - MaxAge: 86400, - }) - // User-facing routes require cache-busting and CSP headers. root.PathPrefix(h.urlFor("/")).Handler(withDynamicHeaders(c.Handler(userh))) diff --git a/server/http_test.go b/server/http_test.go index fa711be8dccb80b968a7b2fa1a7895541d097b56..8f2fd0b43df403add756f86771776c8181002219 100644 --- a/server/http_test.go +++ b/server/http_test.go @@ -331,3 +331,48 @@ func TestHTTP_LoginWithKeyStore(t *testing.T) { v.Set("password", "password") doPostForm(t, httpSrv, c, "/login", v, checkRedirectToTargetService) } + +func TestHTTP_CORS(t *testing.T) { + tmpdir, httpSrv := startTestHTTPServer(t) + defer os.RemoveAll(tmpdir) + defer httpSrv.Close() + + c := newTestHTTPClient() + + // To test a CORS preflight request we have to login first. + // Simulate an authorization request from a service, expect to + // see the login page. + v := make(url.Values) + v.Set("s", "service.example.com/") + v.Set("d", "https://service.example.com/admin/") + v.Set("n", "averysecretnonce") + doGet(t, httpSrv, c, "/?"+v.Encode(), checkStatusOk, checkLoginPasswordPage) + + // Attempt to login by submitting the form. We expect the + // result to be a 302 redirect to the target service. + v = make(url.Values) + v.Set("username", "testuser") + v.Set("password", "password") + doPostForm(t, httpSrv, c, "/login", v, checkRedirectToTargetService) + + // Simulate a CORS preflight request. + v = make(url.Values) + v.Set("s", "service.example.com/") + v.Set("d", "https://service.example.com/admin/") + v.Set("n", "averysecretnonce") + req, err := http.NewRequest("OPTIONS", httpSrv.URL+"/?"+v.Encode(), nil) + if err != nil { + t.Fatalf("NewRequest(): %v", err) + } + req.Header.Set("Origin", "https://origin.example.com") + req.Header.Set("Access-Control-Request-Method", "GET") + resp, err := c.Do(req) + if err != nil { + t.Fatalf("http request error: %v", err) + } + defer resp.Body.Close() + checkStatusOk(t, resp) + if s := resp.Header.Get("Access-Control-Allow-Origin"); s != "https://origin.example.com" { + t.Fatalf("Bad Access-Control-Allow-Origin returned to OPTIONS request: %s", s) + } +} diff --git a/server/service_test.go b/server/service_test.go index b61f697036c3eda4827f8793579283bcda207d7a..4d9d6a12f8442137cfeebebc931f1a9d394cdc1c 100644 --- a/server/service_test.go +++ b/server/service_test.go @@ -38,6 +38,8 @@ public_key_file: %s domain: example.com allowed_services: - "^service\\.example\\.com/$" +allowed_cors_origins: + - "https://origin.example.com" service_ttls: - regexp: ".*" ttl: 60