package mem

import (
	"encoding/binary"
	"fmt"
	"io"
	"log"
	"os"
	"path/filepath"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	pb "git.autistici.org/ai3/tools/replds2/proto"
	"github.com/DataDog/zstd"
	"google.golang.org/protobuf/proto"
)
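
// This file implements the zstd-compressed, append-only log used to
// persist pb.Node entries. Each record is a 4-byte little-endian
// length, a one-byte dirty flag, and the protobuf-encoded node. Log
// files are named log.NNNNNN and replayed in index order on startup;
// a checkpoint is a full dump of the state written as clean records,
// after which older log files can be removed.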

var (
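	// maxLogSize is the size threshold after which Write rotates to a
	// new log file.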
	maxLogSize int64 = 50 * 1024 * 1024

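	// flushInterval is how often the background flusher flushes the
	// compressed writer to disk.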
	flushInterval = 30 * time.Second
)

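// logWrite appends a single record to w: a 4-byte little-endian size,
// a one-byte dirty flag, and the marshaled node. It returns the total
// number of bytes written for the record.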
func logWrite(w io.Writer, node *pb.Node, dirty bool) (int, error) {
	data, err := proto.Marshal(node)
	if err != nil {
		return 0, err
	}

	// Encode 4-byte size.
	var sz [4]byte
	binary.LittleEndian.PutUint32(sz[:], uint32(len(data)))
	if _, err := w.Write(sz[:]); err != nil {
		return 0, err
	}
	// Encode dirty bit (byte).
	var b [1]byte
	if dirty {
		b[0] = 1
	}
	if _, err := w.Write(b[:]); err != nil {
		return 0, err
	}
	// Write data.
	n, err := w.Write(data)
	return n + 5, err
}

// logRead reads a single entry from the log, reusing buf as the decoding
// buffer when it is large enough. It returns the decoded node, the dirty
// flag, the (possibly reallocated) buffer, and an error.
func logRead(r io.Reader, buf []byte) (*pb.Node, bool, []byte, error) {
	// Read the 4-byte record size. io.ReadFull retries short reads
	// from the decompressor instead of treating them as errors; a
	// clean io.EOF here marks the end of the log.
	var szbuf [4]byte
	if _, err := io.ReadFull(r, szbuf[:]); err != nil {
		return nil, false, buf, err
	}
	sz := int(binary.LittleEndian.Uint32(szbuf[:]))

	// Read the dirty flag (a single byte).
	var db [1]byte
	if _, err := io.ReadFull(r, db[:]); err != nil {
		return nil, false, buf, err
	}
	dirty := db[0] == 1

	// Read the data, reallocating the decoding buffer if it
	// isn't large enough.
	if len(buf) < sz {
		buf = make([]byte, sz)
	}
	if _, err := io.ReadFull(r, buf[:sz]); err != nil {
		return nil, false, buf, err
	}

	var node pb.Node
	if err := proto.Unmarshal(buf[:sz], &node); err != nil {
		return nil, false, buf, err
	}

	return &node, dirty, buf, nil
}

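// listLogfiles returns the names of the "log.*" files in path, sorted
// lexicographically, which matches index order thanks to the
// zero-padded file names. Errors are treated as an empty directory.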
func listLogfiles(path string) []string {
	dir, err := os.Open(path)
	if err != nil {
		return nil
	}
	defer dir.Close()
	files, err := dir.Readdir(0)
	if err != nil {
		return nil
	}
	var lf []string
	for _, f := range files {
		if f.Mode().IsRegular() && strings.HasPrefix(f.Name(), "log.") {
			lf = append(lf, f.Name())
		}
	}
	sort.Strings(lf)
	return lf
}

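// logNameFromIndex builds the zero-padded file name for a log index.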
func logNameFromIndex(idx int) string {
	return fmt.Sprintf("log.%06d", idx)
}

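// indexFromLogName parses the numeric index out of a log file name.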
func indexFromLogName(name string) int {
	idx, _ := strconv.Atoi(name[4:])
	return idx
}

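// processLog replays a single log file, calling fn on every decoded
// node. It returns true if any record in the file had the dirty flag
// set, i.e. the file contains writes made after the last checkpoint.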
func processLog(path string, fn func(*pb.Node) error) (bool, error) {
	f, err := os.Open(path)
	if err != nil {
		return false, err
	}
	defer f.Close()

	ff := zstd.NewReader(f)
	defer ff.Close()

	var dirty bool
	var buf []byte
	for {
		node, nodeDirty, newBuf, err := logRead(ff, buf)
		buf = newBuf

		if err == io.EOF {
			return dirty, nil
		}
		if err != nil {
			return false, err
		}

		if nodeDirty {
			dirty = true
		}

		if err := fn(node); err != nil {
			return false, err
		}
	}
}

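// checkpoint writes a full dump of the state (as provided by dumpFn)
// to w as clean, non-dirty records, returning the number of bytes
// written.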
func checkpoint(w io.Writer, dumpFn func(func(*pb.Node) error) error) (int64, error) {
	var sz int64
	err := dumpFn(func(node *pb.Node) error {
		n, err := logWrite(w, node, false)
		if err != nil {
			return err
		}
		sz += int64(n)
		return nil
	})
	return sz, err
}

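// appendLog is a zstd-compressed, size-bounded append-only log split
// across numbered files. New records go to the current file, which is
// rotated once it grows past maxSize; checkpoints allow the older
// files to be dropped.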
type appendLog struct {
	path    string
	maxSize int64

	curIdx int
	curF   *os.File
	curW   *zstd.Writer

	stopCh      chan bool
	mx          sync.Mutex
	curSize     int64
	initialSize int64
}

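// openLog replays the existing log files through setFn to rebuild the
// in-memory state, then opens a fresh log file. If any replayed file
// contained dirty records, a checkpoint is written immediately and the
// old files are removed. It also starts the periodic flusher.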
func openLog(path string, setFn func(*pb.Node) error, dumpFn func(func(*pb.Node) error) error) (*appendLog, error) {
	if err := os.MkdirAll(path, 0700); err != nil {
		return nil, err
	}

	// Replay the existing log files, in index order, to rebuild the
	// in-memory state.
	var curIdx int
	atCheckpoint := true
	for _, logf := range listLogfiles(path) {
		curIdx = indexFromLogName(logf)
		dirty, err := processLog(filepath.Join(path, logf), setFn)
		if err != nil {
			return nil, fmt.Errorf("in %s: %w", logf, err)
		}
		log.Printf("loaded %s (%v)", logf, dirty)
		if dirty {
			atCheckpoint = false
		}
	}

	l := &appendLog{
		path:    path,
		maxSize: maxLogSize,
		curIdx:  curIdx,
		stopCh:  make(chan bool),
	}

	// Open a new log file, checkpointing into it if the replayed state
	// contained dirty records.
	if err := l.Rotate(!atCheckpoint, dumpFn); err != nil {
		return nil, err
	}

	// Start periodic flusher.
	go l.flusher()

	return l, nil
}

// rotateLog opens the next numbered log file and makes it the current
// one, closing the previous writer and file. The caller must hold l.mx.
func (l *appendLog) rotateLog() error {
	idx := l.curIdx + 1
	logf := logNameFromIndex(idx)
	log.Printf("opening log %s", logf)
	f, err := os.Create(filepath.Join(l.path, logf))
	if err != nil {
		return err
	}
	zf := zstd.NewWriterLevel(f, zstd.BestSpeed)

	if l.curW != nil {
		if err := l.curW.Close(); err != nil {
			log.Printf("error flushing log %s: %v", l.curF.Name(), err)
		}
	}
	if l.curF != nil {
		l.curF.Close()
	}

	l.curW = zf
	l.curF = f
	l.curIdx = idx
	l.curSize = 0
	l.initialSize = 0
	return nil
}

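// Write appends a single node to the log as a dirty record, rotating
// to a new log file once the current one exceeds maxSize.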
func (l *appendLog) Write(node *pb.Node) error {
	l.mx.Lock()
	defer l.mx.Unlock()

	n, err := logWrite(l.curW, node, true)
	if err != nil {
		return err
	}

	l.curSize += int64(n)
	if l.curSize > l.maxSize {
		return l.rotateLog()
	}
	return nil
}

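// Rotate switches to a new log file. When doCheckpoint is true, it
// also writes a full dump obtained from dumpFn into the new file and
// removes the log files that preceded it.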
func (l *appendLog) Rotate(doCheckpoint bool, dumpFn func(func(*pb.Node) error) error) error {
	l.mx.Lock()
	defer l.mx.Unlock()

	// Make a list of old log files to remove.
	var oldLogs []string
	if doCheckpoint {
		oldLogs = listLogfiles(l.path)
	}

	if err := l.rotateLog(); err != nil {
		return err
	}

	if doCheckpoint {
		log.Printf("starting checkpoint")
		n, err := checkpoint(l.curW, dumpFn)
		if err != nil {
			return err
		}
		log.Printf("checkpointed %d bytes", n)
		if err := l.curW.Flush(); err != nil {
			return err
		}
		// Count the checkpoint towards the current file size so that
		// UpdateSize only reflects writes made after it.
		l.curSize = n
		l.initialSize = n

		// Remove previous logs.
		for _, logf := range oldLogs {
			log.Printf("removing log %s", logf)
			os.Remove(filepath.Join(l.path, logf))
		}
	}

	return nil
}

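// Close stops the periodic flusher and closes the current log file.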
func (l *appendLog) Close() error {
	close(l.stopCh)

	// Hold the mutex so we don't close the writer while the flusher
	// is still using it.
	l.mx.Lock()
	defer l.mx.Unlock()

	if l.curW != nil {
		if err := l.curW.Close(); err != nil {
			return err
		}
	}
	return l.curF.Close()
}

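// RealSize returns the compressed, on-disk size of the current log
// file.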
func (l *appendLog) RealSize() int64 {
	fi, err := l.curF.Stat()
	if err != nil {
		return 0
	}
	return fi.Size()
}

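// UpdateSize returns the number of (uncompressed) bytes written to the
// current log file since the last checkpoint.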
func (l *appendLog) UpdateSize() int64 {
	l.mx.Lock()
	defer l.mx.Unlock()
	return l.curSize - l.initialSize
}

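// flusher periodically flushes the compressed writer so that recent
// records reach disk, until stopCh is closed.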
func (l *appendLog) flusher() {
	timer := time.NewTicker(flushInterval)
	defer timer.Stop()
	for {
		select {
		case <-timer.C:
			l.mx.Lock()
			if l.curW != nil {
				if err := l.curW.Flush(); err != nil {
					log.Printf("log flush error: %v", err)
				}
			}
			l.mx.Unlock()
		case <-l.stopCh:
			return
		}
	}
}