package userenckey

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"encoding/pem"
	"errors"
)

var (
	// ErrBadPassword is returned by Decrypt when none of the supplied
	// encrypted keys could be both unmarshaled and decrypted with the
	// given password (including when no keys are supplied at all).
	ErrBadPassword = errors.New("could not decrypt key with password")
)

// encodePublicKeyToPEM serializes pub as a PKIX (SubjectPublicKeyInfo)
// ASN.1 DER structure and wraps it in a PEM block of type "PUBLIC KEY".
// It returns the marshaling error, if any, with a nil slice.
func encodePublicKeyToPEM(pub *ecdsa.PublicKey) ([]byte, error) {
	block := &pem.Block{Type: "PUBLIC KEY"}
	var err error
	if block.Bytes, err = x509.MarshalPKIXPublicKey(pub); err != nil {
		return nil, err
	}
	return pem.EncodeToMemory(block), nil
}

// Key is an unencrypted ECDSA private key. It holds the key in the
// ASN.1 DER form produced by x509.MarshalECPrivateKey (see GenerateKey)
// and parsed by x509.ParseECPrivateKey (see Key.PrivateKey).
type Key struct {
	// rawBytes is the DER-encoded private key. Decrypt fills it with the
	// plaintext recovered from a container, which PrivateKey then expects
	// to be in that same DER form.
	rawBytes []byte
}

// GenerateKey generates a new ECDSA key pair on the NIST P-224 curve.
//
// It returns, in order:
//   - the public key as a PEM-encoded PKIX block of type "PUBLIC KEY";
//   - the private key wrapped in a Key, which stores it in ASN.1 DER
//     (SEC 1) form as produced by x509.MarshalECPrivateKey. This is NOT
//     PEM; call Key.PEM to obtain a PEM encoding of the private key.
//
// On error, both returned keys are nil.
func GenerateKey() ([]byte, *Key, error) {
	// NOTE(review): P-224 offers roughly 112-bit security and is not
	// supported by every ecosystem; switching curves would change the
	// keys handed to consumers, so confirm compatibility before changing.
	priv, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
	if err != nil {
		return nil, nil, err
	}

	// The private key is kept as DER; PEM encoding is deferred to Key.PEM.
	privDER, err := x509.MarshalECPrivateKey(priv)
	if err != nil {
		return nil, nil, err
	}
	pubPEM, err := encodePublicKeyToPEM(&priv.PublicKey)
	if err != nil {
		return nil, nil, err
	}

	return pubPEM, &Key{rawBytes: privDER}, nil
}

// PEM returns the private key held by k in PEM-encoded form. It first
// parses k's DER bytes via PrivateKey and returns any parse error as-is.
// The actual encoding is delegated to encodePrivateKeyToPEM, which the
// original author noted uses PKCS #8 framing (confirm in that helper).
func (k *Key) PEM() ([]byte, error) {
	ecKey, parseErr := k.PrivateKey()
	if parseErr != nil {
		return nil, parseErr
	}
	return encodePrivateKeyToPEM(ecKey)
}

// PrivateKey decodes the ASN.1 DER bytes stored in k with
// x509.ParseECPrivateKey and returns the resulting *ecdsa.PrivateKey.
// On failure it returns a nil key and the parser's error.
func (k *Key) PrivateKey() (*ecdsa.PrivateKey, error) {
	der := k.rawBytes
	ecKey, err := x509.ParseECPrivateKey(der)
	if err != nil {
		return nil, err
	}
	return ecKey, nil
}

// Encrypt protects key's raw DER bytes with the password pw and returns
// the serialized container. Container construction is delegated to
// newContainer, which the original doc says applies a random salt
// (confirm there). Errors from construction or marshaling are returned
// unchanged.
func Encrypt(key *Key, pw []byte) ([]byte, error) {
	sealed, buildErr := newContainer(key.rawBytes, pw)
	if buildErr != nil {
		return nil, buildErr
	}
	return sealed.Marshal()
}

// Decrypt recovers a Key from a set of encrypted candidates. Callers
// store the same cleartext key encrypted under several different
// passwords, and Decrypt tries each entry of encKeys in order, returning
// the first one that both unmarshals and decrypts with pw.
//
// Entries that fail to unmarshal or decrypt are skipped without
// reporting why. If no entry succeeds (including when encKeys is
// empty), Decrypt returns ErrBadPassword.
func Decrypt(encKeys [][]byte, pw []byte) (*Key, error) {
	for i := range encKeys {
		container, parseErr := unmarshalContainer(encKeys[i])
		if parseErr != nil {
			// Malformed entry: best-effort, move on to the next one.
			continue
		}
		plain, decErr := container.decrypt(pw)
		if decErr != nil {
			continue
		}
		return &Key{rawBytes: plain}, nil
	}
	return nil, ErrBadPassword
}