package acmeserver

import (
	"bytes"
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"io/ioutil"
	"log"
	"os"
	"path/filepath"
	"strings"
	"time"

	"git.autistici.org/ai3/replds"
)

// dirStorage keeps certificates on the local filesystem, with one
// subdirectory of root per common name.
type dirStorage struct {
	// root is the base directory; files for a name cn live in root/cn/.
	root string
}

// GetCert loads the stored certificate for cn from the directory
// <root>/<cn>/: the chain comes from fullchain.pem (as a list of DER
// blocks) and the key from privkey.pem. On any read or parse error that
// error is returned as-is and both other results are nil.
func (d *dirStorage) GetCert(cn string) ([][]byte, crypto.Signer, error) {
	base := filepath.Join(d.root, cn)

	chain, err := parseCertsFromFile(filepath.Join(base, "fullchain.pem"))
	if err != nil {
		return nil, nil, err
	}

	signer, err := parsePrivateKeyFromFile(filepath.Join(base, "privkey.pem"))
	if err != nil {
		return nil, nil, err
	}

	return chain, signer, nil
}

// PutCert stores the certificate chain der and its private key under
// <root>/<cn>/. It writes fullchain.pem and cert.pem with mode 0644 and
// privkey.pem with mode 0400, creating the directory (mode 0750) if needed.
//
// Each file is written to a temporary file in the destination directory
// and then renamed into place. As a result, readers never observe a
// half-written file. It also keeps renewals working when the existing
// privkey.pem is read-only (0400): opening that file for writing fails
// for non-root users, while replacing it by rename only requires write
// access to the directory.
func (d *dirStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error {
	filemap, err := dumpCertsAndKey(cn, der, key)
	if err != nil {
		return err
	}

	// MkdirAll is a no-op when the directory already exists.
	if err := os.MkdirAll(filepath.Join(d.root, cn), 0750); err != nil {
		return err
	}

	for path, data := range filemap {
		var mode os.FileMode = 0644
		if strings.HasSuffix(path, "privkey.pem") {
			mode = 0400
		}
		log.Printf("writing %s (%03o)", path, mode)
		if err := writeFileAtomically(filepath.Join(d.root, path), data, mode); err != nil {
			return err
		}
	}
	return nil
}

// writeFileAtomically writes data to path with exactly the given
// permissions, via a temporary file in the same directory followed by a
// rename. On failure the temporary file is removed and path is untouched.
func writeFileAtomically(path string, data []byte, mode os.FileMode) (err error) {
	tmp, err := ioutil.TempFile(filepath.Dir(path), "."+filepath.Base(path)+".tmp")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	defer func() {
		if err != nil {
			// Best-effort cleanup; Close may report "already closed".
			_ = tmp.Close()
			_ = os.Remove(tmpName)
		}
	}()

	if _, err = tmp.Write(data); err != nil {
		return err
	}
	// TempFile creates the file with mode 0600. Apply the requested mode
	// explicitly (Chmod is not affected by the umask). Writing through the
	// already-open descriptor still works after dropping write permission.
	if err = tmp.Chmod(mode); err != nil {
		return err
	}
	if err = tmp.Sync(); err != nil {
		return err
	}
	if err = tmp.Close(); err != nil {
		return err
	}
	return os.Rename(tmpName, path)
}

// dumpCertsAndKey serializes a certificate chain and its private key into
// PEM files. It returns a map from path (relative to the storage root,
// i.e. "<cn>/<file>") to file contents:
//
//   - <cn>/fullchain.pem: every certificate in der, in order
//   - <cn>/cert.pem: only the first certificate in der
//   - <cn>/privkey.pem: the private key
//
// It returns an error if der is empty (there would be nothing to put in
// cert.pem) or if any PEM encoding step fails.
func dumpCertsAndKey(cn string, der [][]byte, key crypto.Signer) (map[string][]byte, error) {
	// Guard against der[:1] panicking on an empty chain.
	if len(der) == 0 {
		return nil, errors.New("empty certificate chain")
	}

	m := make(map[string][]byte, 3)

	data, err := encodeCerts(der)
	if err != nil {
		return nil, err
	}
	m[filepath.Join(cn, "fullchain.pem")] = data

	data, err = encodeCerts(der[:1])
	if err != nil {
		return nil, err
	}
	m[filepath.Join(cn, "cert.pem")] = data

	data, err = encodePrivateKey(key)
	if err != nil {
		return nil, err
	}
	m[filepath.Join(cn, "privkey.pem")] = data

	return m, nil
}

// replStorage wraps a dirStorage but overrides PutCert so that new
// certificates are published to replds rather than written to the local
// directory. Methods it does not define itself are promoted from the
// embedded *dirStorage.
type replStorage struct {
	*dirStorage

	// replClient is the replds client that PutCert sends nodes through.
	replClient replds.PublicClient
}

// PutCert serializes the chain and key for cn into PEM files (see
// dumpCertsAndKey) and pushes them to replds in one SetNodes request,
// with every node carrying the same timestamp. The request is bounded by
// a one-minute timeout. It succeeds only if the response reports at least
// one host that accepted the write.
func (d *replStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error {
	files, err := dumpCertsAndKey(cn, der, key)
	if err != nil {
		return err
	}

	stamp := time.Now()
	req := replds.SetNodesRequest{}
	for name, contents := range files {
		node := replds.Node{
			Path:      name,
			Value:     contents,
			Timestamp: stamp,
		}
		req.Nodes = append(req.Nodes, node)
	}

	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	resp, err := d.replClient.SetNodes(ctx, &req)
	switch {
	case err != nil:
		return err
	case resp.HostsOk < 1:
		return errors.New("not enough successful replds writes")
	default:
		return nil
	}
}

// parseCertsFromFile reads the PEM file at path and returns the DER bytes
// of every CERTIFICATE block it contains, in file order. Blocks of any
// other type (e.g. a stray private key) are skipped rather than being
// returned as certificates. It returns an error if the file cannot be
// read or contains no certificates, so callers never receive an empty
// chain together with a nil error.
func parseCertsFromFile(path string) ([][]byte, error) {
	data, err := ioutil.ReadFile(path) // nolint: gosec
	if err != nil {
		return nil, err
	}

	var der [][]byte
	for {
		var block *pem.Block
		block, data = pem.Decode(data)
		if block == nil {
			break
		}
		if block.Type != "CERTIFICATE" {
			continue
		}
		der = append(der, block.Bytes)
	}
	if len(der) == 0 {
		return nil, errors.New("no certificates found in " + path)
	}
	return der, nil
}

// parsePrivateKeyFromFile reads the PEM file at path and parses its first
// PEM block with parsePrivateKey. The block's type must contain the word
// "PRIVATE" (e.g. "RSA PRIVATE KEY", "EC PRIVATE KEY"). Anything after
// the first block is ignored. Read errors are returned unchanged.
func parsePrivateKeyFromFile(path string) (crypto.Signer, error) {
	data, err := ioutil.ReadFile(path) // nolint: gosec
	if err != nil {
		return nil, err
	}

	block, _ := pem.Decode(data)
	if block == nil || !strings.Contains(block.Type, "PRIVATE") {
		// Certificate keys (privkey.pem) are loaded through here too, so
		// don't label every failure as an account-key problem; name the
		// offending file instead.
		return nil, errors.New("invalid private key in " + path)
	}
	return parsePrivateKey(block.Bytes)
}

// encodeCerts wraps each DER certificate in der in a PEM block of type
// "CERTIFICATE" and returns the blocks concatenated in input order. An
// empty der produces an empty result. The first encoding error aborts
// the whole operation.
func encodeCerts(der [][]byte) ([]byte, error) {
	out := new(bytes.Buffer)
	for i := range der {
		block := &pem.Block{Type: "CERTIFICATE", Bytes: der[i]}
		if err := pem.Encode(out, block); err != nil {
			return nil, err
		}
	}
	return out.Bytes(), nil
}

// encodePrivateKey PEM-encodes key in an algorithm-specific format:
//
//   - *rsa.PrivateKey: PKCS #1 DER in an "RSA PRIVATE KEY" block
//   - *ecdsa.PrivateKey: x509.MarshalECPrivateKey DER in an "EC PRIVATE KEY" block
//
// Any other key type, including a nil interface, is rejected with an
// error, as is any marshaling or encoding failure.
func encodePrivateKey(key crypto.Signer) ([]byte, error) {
	var (
		blockType string
		keyDER    []byte
	)
	switch k := key.(type) {
	case *rsa.PrivateKey:
		blockType, keyDER = "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(k)
	case *ecdsa.PrivateKey:
		marshaled, err := x509.MarshalECPrivateKey(k)
		if err != nil {
			return nil, err
		}
		blockType, keyDER = "EC PRIVATE KEY", marshaled
	default:
		return nil, errors.New("unknown private key type")
	}

	out := new(bytes.Buffer)
	if err := pem.Encode(out, &pem.Block{Type: blockType, Bytes: keyDER}); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}