package mem import ( "encoding/binary" "fmt" "io" "log" "os" "path/filepath" "sort" "strconv" "strings" "sync" "time" pb "git.autistici.org/ai3/tools/replds2/proto" "github.com/DataDog/zstd" "google.golang.org/protobuf/proto" ) var ( maxLogSize int64 = 50 * 1024 * 1024 flushInterval = 30 * time.Second ) func logWrite(w io.Writer, node *pb.Node, dirty bool) (int, error) { data, err := proto.Marshal(node) if err != nil { return 0, err } // Encode 4-byte size. var sz [4]byte binary.LittleEndian.PutUint32(sz[:], uint32(len(data))) if _, err := w.Write(sz[:]); err != nil { return 0, err } // Encode dirty bit (byte). var b [1]byte if dirty { b[0] = 1 } if _, err := w.Write(b[:]); err != nil { return 0, err } // Write data. n, err := w.Write(data) return n + 5, err } // Read an entry from the log, reusing the given buffer for decoding. // Returns the node, dirty bit, new buffer, error. func logRead(r io.Reader, buf []byte) (*pb.Node, bool, []byte, error) { // Read carefully the size. var szbuf [4]byte n, err := r.Read(szbuf[:]) if err != nil { return nil, false, buf, err } if n != 4 { return nil, false, buf, fmt.Errorf("short read (%d/%d)", n, 4) } sz := int(binary.LittleEndian.Uint32(szbuf[:])) // Read the dirty bit. var db [1]byte _, err = r.Read(db[:]) if err != nil { return nil, false, buf, err } var dirty bool if db[0] == 1 { dirty = true } // Read the data, reallocate the decoding buffer if it // isn't large enough. if len(buf) < sz { buf = make([]byte, sz) } n, err = r.Read(buf[:sz]) if err != nil { return nil, false, buf, err } if n != sz { return nil, false, buf, fmt.Errorf("short read (%d/%d)", n, sz) } var node pb.Node if err := proto.Unmarshal(buf[:sz], &node); err != nil { return nil, false, buf, err } return &node, dirty, buf, nil } func listLogfiles(path string) []string { dir, err := os.Open(path) if err != nil { return nil } defer dir.Close() files, err := dir.Readdir(0) if err != nil { return nil } var lf []string for _, f := range files { if f.Mode().IsRegular() && strings.HasPrefix(f.Name(), "log.") { lf = append(lf, f.Name()) } } sort.Strings(lf) return lf } func logNameFromIndex(idx int) string { return fmt.Sprintf("log.%06d", idx) } func indexFromLogName(name string) int { idx, _ := strconv.Atoi(name[4:]) return idx } func processLog(path string, fn func(*pb.Node) error) (bool, error) { f, err := os.Open(path) if err != nil { return false, err } defer f.Close() ff := zstd.NewReader(f) defer ff.Close() var dirty bool var buf []byte for { node, nodeDirty, newBuf, err := logRead(ff, buf) buf = newBuf if err == io.EOF { return dirty, nil } if err != nil { return false, err } if nodeDirty { dirty = true } if err := fn(node); err != nil { return false, err } } } func checkpoint(w io.Writer, dumpFn func(func(*pb.Node) error) error) (int64, error) { var sz int64 err := dumpFn(func(node *pb.Node) error { n, err := logWrite(w, node, false) if err != nil { return err } sz += int64(n) return nil }) return sz, err } type appendLog struct { path string maxSize int64 curIdx int curF *os.File curW *zstd.Writer stopCh chan bool mx sync.Mutex curSize int64 initialSize int64 } func openLog(path string, setFn func(*pb.Node) error, dumpFn func(func(*pb.Node) error) error) (*appendLog, error) { if err := os.MkdirAll(path, 0700); err != nil { return nil, err } // Get the list of log files, and process each one of them. var curIdx int atCheckpoint := true for _, logf := range listLogfiles(path) { curIdx = indexFromLogName(logf) dirty, err := processLog(filepath.Join(path, logf), setFn) if err != nil { return nil, fmt.Errorf("in %s: %w", logf, err) } log.Printf("loaded %s (%v)", logf, dirty) if dirty { atCheckpoint = false } } l := &appendLog{ path: path, maxSize: maxLogSize, curIdx: curIdx, stopCh: make(chan bool), } // Open a new log and checkpoint. if err := l.Rotate(!atCheckpoint, dumpFn); err != nil { return nil, err } // Start periodic flusher. go l.flusher() return l, nil } // Expects caller to hold the mutex. func (l *appendLog) rotateLog() error { idx := l.curIdx + 1 logf := logNameFromIndex(idx) log.Printf("opening log %s", logf) f, err := os.Create(filepath.Join(l.path, logf)) if err != nil { return err } zf := zstd.NewWriterLevel(f, zstd.BestSpeed) if l.curW != nil { if err := l.curW.Close(); err != nil { stat, _ := l.curF.Stat() log.Printf("error flushing log %s: %v", stat.Name(), err) } } if l.curF != nil { l.curF.Close() } l.curW = zf l.curF = f l.curIdx = idx l.curSize = 0 l.initialSize = 0 return nil } func (l *appendLog) Write(node *pb.Node) error { l.mx.Lock() defer l.mx.Unlock() n, err := logWrite(l.curW, node, true) if err != nil { return err } l.curSize += int64(n) if l.curSize > l.maxSize { return l.rotateLog() } return nil } func (l *appendLog) Rotate(doCheckpoint bool, dumpFn func(func(*pb.Node) error) error) error { l.mx.Lock() defer l.mx.Unlock() // Make a list of old log files to remove. var oldLogs []string if doCheckpoint { oldLogs = listLogfiles(l.path) } if err := l.rotateLog(); err != nil { return err } if doCheckpoint { log.Printf("starting checkpoint") n, err := checkpoint(l.curW, dumpFn) if err != nil { return err } log.Printf("checkpointed %d bytes", n) if err := l.curW.Flush(); err != nil { return err } l.initialSize = n // Remove previous logs. for _, logf := range oldLogs { log.Printf("removing log %s", logf) os.Remove(filepath.Join(l.path, logf)) } } return nil } func (l *appendLog) Close() error { close(l.stopCh) if l.curW != nil { if err := l.curW.Close(); err != nil { return err } } return l.curF.Close() } func (l *appendLog) RealSize() int64 { fi, err := l.curF.Stat() if err != nil { return 0 } return fi.Size() } func (l *appendLog) UpdateSize() int64 { l.mx.Lock() defer l.mx.Unlock() return l.curSize - l.initialSize } func (l *appendLog) flusher() { timer := time.NewTicker(flushInterval) defer timer.Stop() for { select { case <-timer.C: l.mx.Lock() if l.curW != nil { if err := l.curW.Flush(); err != nil { log.Printf("log flush error: %v", err) } } l.mx.Unlock() case <-l.stopCh: return } } }