Skip to content
Snippets Groups Projects
Commit 29862d1b authored by ale's avatar ale
Browse files

Fix CORS headers on the /sso_logout endpoint

Compute the origin properly.
parent 55d6fd60
Branches
No related tags found
No related merge requests found
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"crypto/rand" "crypto/rand"
"encoding/gob" "encoding/gob"
"encoding/hex" "encoding/hex"
"fmt"
"io" "io"
"log" "log"
"net/http" "net/http"
...@@ -36,6 +37,7 @@ type SSOWrapper struct { ...@@ -36,6 +37,7 @@ type SSOWrapper struct {
sessionAuthKey []byte sessionAuthKey []byte
sessionEncKey []byte sessionEncKey []byte
serverURL string serverURL string
serverOrigin string
} }
// NewSSOWrapper returns a new SSOWrapper that will authenticate users // NewSSOWrapper returns a new SSOWrapper that will authenticate users
...@@ -49,6 +51,7 @@ func NewSSOWrapper(serverURL string, pkey []byte, domain string, sessionAuthKey, ...@@ -49,6 +51,7 @@ func NewSSOWrapper(serverURL string, pkey []byte, domain string, sessionAuthKey,
return &SSOWrapper{ return &SSOWrapper{
v: v, v: v,
serverURL: serverURL, serverURL: serverURL,
serverOrigin: originFromURL(serverURL),
sessionAuthKey: sessionAuthKey, sessionAuthKey: sessionAuthKey,
sessionEncKey: sessionEncKey, sessionEncKey: sessionEncKey,
}, nil }, nil
...@@ -130,8 +133,10 @@ func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, sess ...@@ -130,8 +133,10 @@ func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, sess
} }
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
w.Header().Set("Access-Control-Allow-Origin", strings.TrimRight(s.serverURL, "/")) if s.serverOrigin != "" {
w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Allow-Origin", s.serverOrigin)
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
io.WriteString(w, "OK") io.WriteString(w, "OK")
} }
...@@ -181,3 +186,12 @@ func makeUniqueNonce() string { ...@@ -181,3 +186,12 @@ func makeUniqueNonce() string {
} }
return hex.EncodeToString(b[:]) return hex.EncodeToString(b[:])
} }
// Return the origin from a URL (stripping path and other components).
func originFromURL(s string) string {
parsed, err := url.Parse(s)
if err != nil {
return ""
}
return fmt.Sprintf("%s://%s", parsed.Scheme, parsed.Host)
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment