package server

import (
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
	"testing"

	"golang.org/x/crypto/ed25519"
	"gopkg.in/yaml.v2"
)

// loadConfig reads the YAML document stored at path and decodes it into a
// freshly allocated Config. The read error or the decode error is returned
// unchanged; on error the returned *Config is nil.
func loadConfig(path string) (*Config, error) {
	raw, readErr := ioutil.ReadFile(path)
	if readErr != nil {
		return nil, readErr
	}

	cfg := new(Config)
	if decodeErr := yaml.Unmarshal(raw, cfg); decodeErr != nil {
		return nil, decodeErr
	}
	return cfg, nil
}

// testConfig builds a compiled test Config rooted in tmpdir.
//
// It generates a fresh ed25519 key pair and writes the private and public
// halves to tmpdir/secret and tmpdir/public. It then writes a YAML
// configuration referencing them to tmpdir/config, loads it with
// loadConfig, and calls Compile on the result. If keystoreURL is
// non-empty, a "keystore" section with that URL is appended to the YAML.
//
// Any failure aborts the calling test via t.Fatal/t.Fatalf.
func testConfig(t testing.TB, tmpdir, keystoreURL string) *Config {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	pub, priv, err := ed25519.GenerateKey(nil)
	if err != nil {
		t.Fatal(err)
	}

	secretPath := filepath.Join(tmpdir, "secret")
	publicPath := filepath.Join(tmpdir, "public")
	configPath := filepath.Join(tmpdir, "config")

	// writeFile writes a fixture file and aborts the test if that fails.
	// A silently missing file would otherwise surface later as an
	// unrelated-looking load or parse error.
	writeFile := func(path string, data []byte) {
		t.Helper()
		if err := ioutil.WriteFile(path, data, 0600); err != nil {
			t.Fatalf("writing %s: %v", path, err)
		}
	}

	writeFile(secretPath, priv)
	writeFile(publicPath, pub)

	cfgstr := fmt.Sprintf(`---
secret_key_file: %s
public_key_file: %s
domain: example.com
allowed_services:
  - "^service\\.example\\.com/$"
allowed_cors_origins:
  - "https://origin.example.com"
service_ttls:
  - regexp: ".*"
    ttl: 60
auth_service: login
`, secretPath, publicPath)
	if keystoreURL != "" {
		cfgstr += fmt.Sprintf(`
keystore:
  url: "%s"
`, keystoreURL)
	}
	writeFile(configPath, []byte(cfgstr))

	config, err := loadConfig(configPath)
	if err != nil {
		t.Fatal("LoadConfig():", err)
	}
	if err := config.Compile(); err != nil {
		t.Fatal("Compile():", err)
	}
	return config
}

// TestLoginService_Ok exercises the happy path. A token issued by the
// login service for an allowed service/destination pair must validate
// with the same nonce, service and groups, and the resulting ticket must
// carry the original user name.
func TestLoginService_Ok(t *testing.T) {
	// A failed TempDir would leave tmpdir == "". The key files would then
	// land in the current directory, and os.RemoveAll("") would not clean
	// them up, so bail out instead.
	tmpdir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Fatal("TempDir():", err)
	}
	defer os.RemoveAll(tmpdir)

	config := testConfig(t, tmpdir, "")
	svc, err := NewLoginService(config)
	if err != nil {
		t.Fatal("NewLoginService():", err)
	}

	token, err := svc.Authorize("user", "service.example.com/", "https://service.example.com/", "nonce", []string{"group1"})
	if err != nil {
		t.Fatal("Authorize():", err)
	}

	v, err := newValidatorFromConfig(config)
	if err != nil {
		t.Fatal("newValidatorFromConfig():", err)
	}
	ticket, err := v.Validate(token, "nonce", "service.example.com/", []string{"group1"})
	if err != nil {
		t.Fatal("Validate():", err)
	}
	if ticket.User != "user" {
		t.Fatalf("bad Username: got=%s, expected=user", ticket.User)
	}
}

// TestLoginService_SanityChecks checks which service/destination pairs
// Authorize accepts. Each table entry records whether Authorize is
// expected to succeed (ok == true) or return an error (ok == false).
func TestLoginService_SanityChecks(t *testing.T) {
	// A failed TempDir would leave tmpdir == "". The key files would then
	// land in the current directory, and os.RemoveAll("") would not clean
	// them up, so bail out instead.
	tmpdir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Fatal("TempDir():", err)
	}
	defer os.RemoveAll(tmpdir)

	config := testConfig(t, tmpdir, "")
	svc, err := NewLoginService(config)
	if err != nil {
		t.Fatal("NewLoginService():", err)
	}

	var testdata = []struct {
		service, destination string
		ok                   bool
	}{
		{"service.example.com/", "https://service.example.com/", true},
		{"service.example.com", "https://service.example.com/", false},
		{"service.example.com/", "https://service.example.com/foo", true},
		{"service.example.com/", "http://service.example.com/", false},
		{"foo.example.com/", "https://foo.example.com/", false},
		{"service.example.com/", "https://foo.example.com/", false},
	}

	for _, td := range testdata {
		if _, err := svc.Authorize("user", td.service, td.destination, "nonce", nil); (err == nil) != td.ok {
			t.Errorf("Authorize error: s=%s d=%s expected=%v got=%v", td.service, td.destination, td.ok, err)
		}
	}
}