diff --git a/cmd/saml-server/main.go b/cmd/saml-server/main.go index bddc3355baa9c2af40b6f0bcdb47189e0ddf0e5f..f9d5e1d0c8873fbb4bfcb84ef979e5a90f8fbbf7 100644 --- a/cmd/saml-server/main.go +++ b/cmd/saml-server/main.go @@ -1,20 +1,14 @@ package main import ( - "context" "flag" "io/ioutil" "log" - "net/http" - "os" - "os/signal" - "strings" - "syscall" - "time" - - "git.autistici.org/id/go-sso/saml" + "git.autistici.org/ai3/go-common/serverutil" "gopkg.in/yaml.v2" + + "git.autistici.org/id/go-sso/saml" ) var ( @@ -22,32 +16,28 @@ var ( configFile = flag.String("config", "/etc/sso/saml.yml", "`path` of config file") ) -func loadConfig() (*saml.Config, error) { +// Config wraps together the standard HTTP server config and the SAML +// service configuration. +type Config struct { + SAMLConfig *saml.Config `yaml:"saml"` + ServerConfig *serverutil.ServerConfig `yaml:"http_server"` +} + +func loadConfig() (*Config, error) { // Read YAML config. data, err := ioutil.ReadFile(*configFile) if err != nil { return nil, err } - var config saml.Config + var config Config if err := yaml.Unmarshal(data, &config); err != nil { return nil, err } return &config, nil } -// Set defaults for command-line flags using variables from the environment. -func setFlagDefaultsFromEnv() { - flag.VisitAll(func(f *flag.Flag) { - envVar := "SAML_" + strings.ToUpper(strings.Replace(f.Name, "-", "_", -1)) - if value := os.Getenv(envVar); value != "" { - f.DefValue = value - f.Value.Set(value) - } - }) -} - func main() { - setFlagDefaultsFromEnv() + log.SetFlags(0) flag.Parse() config, err := loadConfig() @@ -55,27 +45,12 @@ func main() { log.Fatal(err) } - s, err := saml.NewSAMLIDP(config) + s, err := saml.NewSAMLIDP(config.SAMLConfig) if err != nil { log.Fatal(err) } - srv := &http.Server{ - Addr: *addr, - Handler: s, - } - - sigCh := make(chan os.Signal, 1) - go func() { - <-sigCh - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _ = srv.Shutdown(ctx) - _ = srv.Close() - }() - signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) - - if err := srv.ListenAndServe(); err != nil { + if err := serverutil.Serve(s, config.ServerConfig, *addr); err != nil { log.Fatal(err) } } diff --git a/cmd/sso-proxy/main.go b/cmd/sso-proxy/main.go index 57b27a72b4c4769c54ac9d160d49dbe8814a57c1..a6b2d97933de2c62c7c5ceedd31c28e2ce522ab1 100644 --- a/cmd/sso-proxy/main.go +++ b/cmd/sso-proxy/main.go @@ -8,7 +8,6 @@ import ( "net/http" "os" "os/signal" - "strings" "syscall" "time" @@ -35,19 +34,8 @@ func loadConfig() (*proxy.Config, error) { return &config, nil } -// Set defaults for command-line flags using variables from the environment. -func setFlagDefaultsFromEnv() { - flag.VisitAll(func(f *flag.Flag) { - envVar := "SSOPROXY_" + strings.ToUpper(strings.Replace(f.Name, "-", "_", -1)) - if value := os.Getenv(envVar); value != "" { - f.DefValue = value - f.Value.Set(value) - } - }) -} - func main() { - setFlagDefaultsFromEnv() + log.SetFlags(0) flag.Parse() config, err := loadConfig() diff --git a/cmd/sso-server/main.go b/cmd/sso-server/main.go index a9d245986f203a22bb78ce7bcc1d13cfa140d3a4..e53a5bab3b22fbe1fbf9f5bbf48aa0854a10cacb 100644 --- a/cmd/sso-server/main.go +++ b/cmd/sso-server/main.go @@ -1,74 +1,67 @@ package main import ( - "context" "flag" + "io/ioutil" "log" - "net/http" - "os" - "os/signal" - "syscall" - "time" + "git.autistici.org/ai3/go-common/serverutil" "git.autistici.org/id/auth/client" + "gopkg.in/yaml.v2" + "git.autistici.org/id/go-sso/server" ) var ( addr = flag.String("addr", ":4141", "tcp `address` to listen on") - configPath = flag.String("config", "/etc/sso/server.yml", "configuration `file`") + configFile = flag.String("config", "/etc/sso/server.yml", "configuration `file`") authSocket = flag.String("auth-socket", client.DefaultSocketPath, "authentication socket `path`") ) +// Config wraps together the sso-server configuration and the standard +// HTTP server config. +type Config struct { + *server.Config + ServerConfig *serverutil.ServerConfig `yaml:"http_server"` +} + +func loadConfig() (*Config, error) { + // Read YAML config. + data, err := ioutil.ReadFile(*configFile) + if err != nil { + return nil, err + } + var config Config + if err := yaml.Unmarshal(data, &config); err != nil { + return nil, err + } + return &config, nil +} + func main() { log.SetFlags(0) flag.Parse() - config, err := server.LoadConfig(*configPath) + config, err := loadConfig() if err != nil { log.Fatal(err) } + if err = config.Config.Compile(); err != nil { + log.Fatal(err) + } - loginService, err := server.NewLoginService(config) + loginService, err := server.NewLoginService(config.Config) if err != nil { log.Fatal(err) } authClient := client.New(*authSocket) - httpSrv, err := server.New(loginService, authClient, config) + httpSrv, err := server.New(loginService, authClient, config.Config) if err != nil { log.Fatal(err) } - srv := &http.Server{ - Addr: *addr, - Handler: httpSrv.Handler(), - ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, - IdleTimeout: 60 * time.Second, - } - - done := make(chan struct{}) - sigCh := make(chan os.Signal, 1) - go func() { - <-sigCh - log.Printf("exiting") - - // Gracefully terminate for 3 seconds max, then shut - // down remaining clients. - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - if err := srv.Shutdown(ctx); err == context.Canceled { - srv.Close() - } - - close(done) - }() - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - - if err = srv.ListenAndServe(); err != http.ErrServerClosed { - log.Fatal("error: %v", err) + if err := serverutil.Serve(httpSrv.Handler(), config.ServerConfig, *addr); err != nil { + log.Fatal(err) } - - <-done } diff --git a/server/bindata.go b/server/bindata.go index ec1216ad3998d5b79363328103dd0eccc56fc679..0dbb7893d137860c1ae563e8a7626ecc1c91b23d 100644 --- a/server/bindata.go +++ b/server/bindata.go @@ -73,7 +73,7 @@ func staticCssBootstrapMinCss() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "static/css/bootstrap.min.css", size: 124962, mode: os.FileMode(436), modTime: time.Unix(1510996183, 0)} + info := bindataFileInfo{name: "static/css/bootstrap.min.css", size: 124962, mode: os.FileMode(420), modTime: time.Unix(1509120975, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -132,7 +132,7 @@ func staticCssSigninCss() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "static/css/signin.css", size: 802, mode: os.FileMode(436), modTime: time.Unix(1511081405, 0)} + info := bindataFileInfo{name: "static/css/signin.css", size: 802, mode: os.FileMode(420), modTime: time.Unix(1511166680, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -154,7 +154,7 @@ func staticJsBootstrap400BetaMinJs() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "static/js/bootstrap-4.0.0-beta.min.js", size: 51143, mode: os.FileMode(436), modTime: time.Unix(1510996183, 0)} + info := bindataFileInfo{name: "static/js/bootstrap-4.0.0-beta.min.js", size: 51143, mode: os.FileMode(420), modTime: time.Unix(1509120962, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -175,7 +175,7 @@ func staticJsJquery321MinJs() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "static/js/jquery-3.2.1.min.js", size: 86659, mode: os.FileMode(436), modTime: time.Unix(1510996183, 0)} + info := bindataFileInfo{name: "static/js/jquery-3.2.1.min.js", size: 86659, mode: os.FileMode(420), modTime: time.Unix(1509120962, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -197,7 +197,7 @@ func staticJsPopper1110MinJs() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "static/js/popper-1.11.0.min.js", size: 19033, mode: os.FileMode(436), modTime: time.Unix(1510996183, 0)} + info := bindataFileInfo{name: "static/js/popper-1.11.0.min.js", size: 19033, mode: os.FileMode(420), modTime: time.Unix(1509120962, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -962,7 +962,7 @@ func staticJsU2fApiJs() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "static/js/u2f-api.js", size: 20880, mode: os.FileMode(436), modTime: time.Unix(1510996183, 0)} + info := bindataFileInfo{name: "static/js/u2f-api.js", size: 20880, mode: os.FileMode(420), modTime: time.Unix(1509120962, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -1031,7 +1031,7 @@ func staticJsU2fJs() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "static/js/u2f.js", size: 1281, mode: os.FileMode(436), modTime: time.Unix(1510996183, 0)} + info := bindataFileInfo{name: "static/js/u2f.js", size: 1281, mode: os.FileMode(420), modTime: time.Unix(1509260310, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -1069,7 +1069,7 @@ func templatesLogin_otpHtml() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "templates/login_otp.html", size: 529, mode: os.FileMode(436), modTime: time.Unix(1510996183, 0)} + info := bindataFileInfo{name: "templates/login_otp.html", size: 529, mode: os.FileMode(420), modTime: time.Unix(1509218738, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -1129,7 +1129,7 @@ func templatesLogin_passwordHtml() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "templates/login_password.html", size: 1074, mode: os.FileMode(436), modTime: time.Unix(1510996183, 0)} + info := bindataFileInfo{name: "templates/login_password.html", size: 1074, mode: os.FileMode(420), modTime: time.Unix(1509218731, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -1169,7 +1169,7 @@ func templatesLogin_u2fHtml() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "templates/login_u2f.html", size: 498, mode: os.FileMode(436), modTime: time.Unix(1510996183, 0)} + info := bindataFileInfo{name: "templates/login_u2f.html", size: 498, mode: os.FileMode(420), modTime: time.Unix(1509260387, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -1227,7 +1227,7 @@ func templatesLogoutHtml() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "templates/logout.html", size: 820, mode: os.FileMode(436), modTime: time.Unix(1511083629, 0)} + info := bindataFileInfo{name: "templates/logout.html", size: 820, mode: os.FileMode(420), modTime: time.Unix(1511166680, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -1272,7 +1272,7 @@ func templatesPageHtml() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "templates/page.html", size: 1493, mode: os.FileMode(436), modTime: time.Unix(1511203590, 0)} + info := bindataFileInfo{name: "templates/page.html", size: 1493, mode: os.FileMode(420), modTime: time.Unix(1511337830, 0)} a := &asset{bytes: bytes, info: info} return a, nil } diff --git a/server/config.go b/server/config.go index d61ec3ade832611009113cb2c0dc52f4c684799a..258ad00b283bc3ca0a24d0b74744ac5373361ae3 100644 --- a/server/config.go +++ b/server/config.go @@ -2,14 +2,12 @@ package server import ( "errors" - "io/ioutil" "log" "regexp" "time" "git.autistici.org/id/go-sso/server/device" "github.com/gorilla/securecookie" - "gopkg.in/yaml.v2" ) // Config data for the SSO service. @@ -38,25 +36,6 @@ type Config struct { allowedServicesRx []*regexp.Regexp } -// LoadConfig reads configuration from a file. -func LoadConfig(path string) (*Config, error) { - data, err := ioutil.ReadFile(path) - if err != nil { - return nil, err - } - var config Config - if err := yaml.Unmarshal(data, &config); err != nil { - return nil, err - } - if err := config.valid(); err != nil { - return nil, err - } - if err := config.compile(); err != nil { - return nil, err - } - return &config, nil -} - // Check syntax (missing required values). func (c *Config) valid() error { if c.SecretKeyFile == "" { @@ -88,9 +67,13 @@ func (c *Config) valid() error { return nil } -// Compile the configuration (regular expressions etc). -func (c *Config) compile() error { - var err error +// Compile the configuration (parse regular expressions, etc). +func (c *Config) Compile() error { + err := c.valid() + if err != nil { + return err + } + for _, svcttl := range c.ServiceTTLs { svcttl.rx, err = regexp.Compile(svcttl.Regexp) if err != nil { diff --git a/server/service_test.go b/server/service_test.go index a5087b4183067edd8bb542b2b9946c5cea14ad31..e7e1ff7fe2ddb18f4aef4e22e3703f90fbec5a70 100644 --- a/server/service_test.go +++ b/server/service_test.go @@ -8,8 +8,22 @@ import ( "testing" "golang.org/x/crypto/ed25519" + "gopkg.in/yaml.v2" ) +func loadConfig(path string) (*Config, error) { + // Read YAML config. + data, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + var config Config + if err := yaml.Unmarshal(data, &config); err != nil { + return nil, err + } + return &config, nil +} + func testConfig(t testing.TB, tmpdir string) *Config { pub, priv, err := ed25519.GenerateKey(nil) if err != nil { @@ -29,10 +43,13 @@ service_ttls: auth_service: login `, filepath.Join(tmpdir, "secret"), filepath.Join(tmpdir, "public"))), 0600) - config, err := LoadConfig(filepath.Join(tmpdir, "config")) + config, err := loadConfig(filepath.Join(tmpdir, "config")) if err != nil { t.Fatal("LoadConfig():", err) } + if err := config.Compile(); err != nil { + t.Fatal("Compile():", err) + } return config }