package main

import (
	"flag"
	"fmt"
	"log"
	"math/rand"
	"time"

	"git.autistici.org/ai3/go-common/pwhash"
)

// Command-line flags. The argon2 parameter flags (-time, -mem, -threads) and
// the scrypt parameter flags (-n, -r, -p) only matter when -algo selects the
// corresponding algorithm (see mkhash). -mem is expressed in MiB: mkhash
// multiplies it by 1024 before handing it to the argon2 constructor.
var (
	algo          = flag.String("algo", "argon2", "password hashing algorithm to use")
	argon2Time    = flag.Int("time", 3, "argon2 `time` parameter")
	argon2Mem     = flag.Int("mem", 32, "argon2 `memory` parameter (MiB)")
	argon2Threads = flag.Int("threads", 4, "argon2 `threads` parameter")
	scryptN       = flag.Int("n", 16384, "scrypt `n` parameter")
	scryptR       = flag.Int("r", 8, "scrypt `r` parameter")
	scryptP       = flag.Int("p", 1, "scrypt `p` parameter")
	doBench       = flag.Bool("bench", false, "run a benchmark")
	doCompare     = flag.Bool("compare", false, "compare password against hash")
)

// randSrc is the pseudo-random generator used to synthesize benchmark
// passwords. It is seeded from the wall clock (second resolution). A
// *rand.Rand is not safe for concurrent use, so neither are the helpers
// below.
var randSrc = rand.New(rand.NewSource(time.Now().Unix()))

// fillRandomBytes fills the first n bytes of b with pseudo-random data from
// randSrc and returns b[:n]. If b's capacity is smaller than n, a new slice
// of length n is allocated and returned instead. A negative n is treated as
// 0. Bytes of b beyond index n are never written.
func fillRandomBytes(b []byte, n int) []byte {
	if n < 0 {
		n = 0
	}
	if cap(b) < n {
		b = make([]byte, n)
	}
	b = b[:n]
	for i := 0; i < n; i += 8 {
		r := randSrc.Uint64()
		// Spread the 64 random bits over up to 8 bytes, least significant
		// byte first, stopping at n so a partial final group does not write
		// past the requested length.
		for j := i; j < i+8 && j < n; j++ {
			b[j] = byte(r)
			r >>= 8
		}
	}
	return b
}

// pwbuf is scratch space reused by randomPass to avoid allocating a buffer
// on every call. The shared buffer is another reason randomPass must not be
// called concurrently.
var pwbuf = make([]byte, 128)

// randomPass returns a pseudo-random password of 10 to 29 bytes. The bytes
// are arbitrary values, not restricted to printable characters or valid
// UTF-8. The returned string is a copy, so later calls do not alter it.
func randomPass() string {
	// Draw the length from randSrc too. The package-level math/rand source
	// is deterministically seeded before Go 1.20, which would make the
	// length sequence identical on every run.
	pwlen := 10 + randSrc.Intn(20)
	return string(fillRandomBytes(pwbuf, pwlen))
}

// Benchmark tuning knobs used by runBench.
const (
	// benchTimeout is the approximate wall-clock budget for a benchmark run.
	// The deadline is only checked between chunks, so a run can overshoot it
	// by up to one chunk.
	benchTimeout = 5 * time.Second

	// benchChunkSize is how many password comparisons make up one chunk.
	// runBench keeps running whole chunks until benchTimeout has elapsed.
	benchChunkSize = 100
)

func runBenchChunk(enc string) int {
	pw := randomPass()
	for i := 0; i < benchChunkSize; i++ {
		pwhash.ComparePassword(enc, pw)
	}
	return benchChunkSize
}

// runBench estimates how fast hashes produced by h can be verified. It
// hashes one random password with h, then runs chunks of benchChunkSize
// comparisons against that hash (via runBenchChunk) until benchTimeout has
// elapsed. It logs the throughput as ops/sec and ms/op under the label
// hname.
//
// Timing starts only after the reference hash is computed. The single
// Encrypt call is not a comparison and is not counted in n, so including it
// would skew both figures. At least one chunk is always run so n > 0 and the
// reported rates are finite, even if the deadline has already passed.
func runBench(h pwhash.PasswordHash, hname string) {
	enc := h.Encrypt(randomPass())

	start := time.Now()
	deadline := start.Add(benchTimeout)

	// The deadline is checked only between chunks, so the run may exceed
	// benchTimeout by up to one chunk. The elapsed time below is measured,
	// not assumed, so the rates remain accurate.
	n := runBenchChunk(enc)
	for time.Now().Before(deadline) {
		n += runBenchChunk(enc)
	}

	elapsed := time.Since(start)
	opsPerSec := float64(n) / elapsed.Seconds()
	msPerOp := (elapsed.Seconds() * 1000) / float64(n)

	log.Printf("%s: %.4g ops/sec, %.4g ms/op", hname, opsPerSec, msPerOp)
}

// mkhash builds the password hasher selected by the -algo flag, configured
// from the algorithm-specific parameter flags. It returns the hasher, a
// human-readable label for it (including parameters for argon2 and scrypt),
// and an error for an unknown algorithm or out-of-range argon2 parameters.
func mkhash() (pwhash.PasswordHash, string, error) {
	var h pwhash.PasswordHash
	name := *algo
	switch *algo {
	case "argon2":
		// The flags are plain ints, but the constructor takes uint32/uint8
		// values. Reject anything that would wrap around in those conversions,
		// or that argon2 cannot use (time or threads below 1), rather than
		// silently hashing with parameters the user did not ask for.
		const (
			// Fits in uint32 and in int on every platform.
			maxTime = 1<<31 - 1
			// Largest MiB count whose KiB value (x1024) fits in uint32.
			maxMemMiB = (1<<32 - 1) / 1024
			// threads is converted to uint8.
			maxThreads = 1<<8 - 1
		)
		if err := checkIntFlag("time", *argon2Time, 1, maxTime); err != nil {
			return nil, "", err
		}
		if err := checkIntFlag("mem", *argon2Mem, 1, maxMemMiB); err != nil {
			return nil, "", err
		}
		if err := checkIntFlag("threads", *argon2Threads, 1, maxThreads); err != nil {
			return nil, "", err
		}
		// Convert before multiplying so the MiB->KiB scaling happens in
		// uint32 (proven in range above) and cannot overflow a 32-bit int.
		h = pwhash.NewArgon2StdWithParams(uint32(*argon2Time), uint32(*argon2Mem)*1024, uint8(*argon2Threads))
		name = fmt.Sprintf("%s(%d/%d/%d)", *algo, *argon2Time, *argon2Mem, *argon2Threads)
	case "scrypt":
		h = pwhash.NewScryptWithParams(*scryptN, *scryptR, *scryptP)
		name = fmt.Sprintf("%s(%d/%d/%d)", *algo, *scryptN, *scryptR, *scryptP)
	case "system":
		h = pwhash.NewSystemCrypt()
	default:
		return nil, "", fmt.Errorf("unknown algo %q", *algo)
	}
	return h, name, nil
}

// checkIntFlag returns an error naming the command-line flag when its value
// v lies outside the inclusive range [lo, hi], and nil otherwise.
func checkIntFlag(flagName string, v, lo, hi int) error {
	if v < lo || v > hi {
		return fmt.Errorf("invalid -%s value %d: must be between %d and %d", flagName, v, lo, hi)
	}
	return nil
}

// main parses the command line, builds the hasher selected by -algo, and
// then does exactly one thing, checked in this order:
//
//   - with -bench: runs the verification benchmark (runBench);
//   - with -compare: passes the first two positional arguments to
//     h.ComparePassword (presumably encoded hash, then plaintext, matching
//     the order runBenchChunk uses) and exits non-zero on mismatch;
//   - otherwise: prints the hash of the first positional argument.
//
// Any error, or too few positional arguments, is fatal.
func main() {
	log.SetFlags(0)
	flag.Parse()

	h, hname, err := mkhash()
	if err != nil {
		log.Fatal(err)
	}

	if *doBench {
		runBench(h, hname)
		return
	}

	if *doCompare {
		if flag.NArg() < 2 {
			log.Fatal("not enough arguments")
		}
		matched := h.ComparePassword(flag.Arg(0), flag.Arg(1))
		if !matched {
			log.Fatal("password does not match")
		}
		log.Printf("password ok")
		return
	}

	if flag.NArg() < 1 {
		log.Fatal("not enough arguments")
	}
	encoded := h.Encrypt(flag.Arg(0))
	fmt.Printf("%s\n", encoded)
}