package sqlutil

import (
	"context"
	"database/sql"
	"fmt"
	"log"
	"strings"

	_ "github.com/mattn/go-sqlite3"
)

// DebugMigrations, when set to true, makes migration steps built with
// Statement log every SQL statement through the standard log package
// (which writes to stderr by default) just before executing it.
var DebugMigrations bool

// defaultOptions is the query string OpenDB appends to a connection URI that
// contains no "?" of its own, unless it is replaced via WithSqliteOptions.
// NOTE(review): the keys are interpreted by the go-sqlite3 driver; they appear
// to select a shared cache, a 10000 ms busy timeout, WAL journaling and
// synchronous=OFF (which trades durability for speed) — confirm against the
// driver's DSN documentation.
const defaultOptions = "?cache=shared&_busy_timeout=10000&_journal=WAL&_sync=OFF"

// sqlOptions collects the settings that Option values apply before OpenDB
// opens the database.
type sqlOptions struct {
	// migrations holds the schema migration steps, in order; OpenDB passes
	// them to migrate.
	migrations []func(*sql.Tx) error
	// sqlopts is the text OpenDB appends to a connection URI that has no
	// "?"; OpenDB initializes it to defaultOptions.
	sqlopts    string
}

// Option customizes how OpenDB opens a database; see WithMigrations and
// WithSqliteOptions.
type Option func(*sqlOptions)

// WithMigrations returns an Option that sets the ordered list of schema
// migrations OpenDB should apply. The slice is stored as given, not copied.
func WithMigrations(migrations []func(*sql.Tx) error) Option {
	return Option(func(cfg *sqlOptions) { cfg.migrations = migrations })
}

// WithSqliteOptions returns an Option that replaces the text OpenDB appends
// to connection URIs lacking a "?". The string is used verbatim.
func WithSqliteOptions(sqlopts string) Option {
	return Option(func(cfg *sqlOptions) { cfg.sqlopts = sqlopts })
}

// OpenDB opens the SQLite database identified by dburi, applies the given
// options, and runs the configured migrations before returning the handle.
//
// When dburi contains no "?", the configured option string (defaultOptions
// unless overridden) is appended to it. If the migrations fail, the handle
// is closed and the migration error is returned.
func OpenDB(dburi string, options ...Option) (*sql.DB, error) {
	cfg := sqlOptions{sqlopts: defaultOptions}
	for _, apply := range options {
		apply(&cfg)
	}

	// A URI that already carries parameters is passed through untouched;
	// only parameter-less URIs receive the sqlite3-specific options.
	dsn := dburi
	if hasParams := strings.Contains(dsn, "?"); !hasParams {
		dsn = dsn + cfg.sqlopts
	}

	db, err := sql.Open("sqlite3", dsn)
	if err != nil {
		return nil, err
	}

	// Limit the pool to a single connection.
	// https://github.com/mattn/go-sqlite3/issues/209
	db.SetMaxOpenConns(1)

	migrateErr := migrate(db, cfg.migrations)
	if migrateErr == nil {
		return db, nil
	}
	db.Close() // nolint
	return nil, migrateErr
}

// getLegacyMigrationVersion reads the highest version recorded in the
// schema_migrations table, the bookkeeping table left behind by
// golang-migrate/migrate/v4. Any query or scan error (including the table
// being absent or empty) is returned unchanged together with a zero version.
func getLegacyMigrationVersion(tx *sql.Tx) (int, error) {
	const query = `SELECT version FROM schema_migrations ORDER BY version DESC LIMIT 1`

	var latest int
	row := tx.QueryRow(query)
	if scanErr := row.Scan(&latest); scanErr != nil {
		return 0, scanErr
	}
	return latest, nil
}

// migrate brings the database schema up to date by running, inside a single
// transaction, every entry of migrations that has not been applied yet.
//
// The number of already-applied migrations is read from PRAGMA user_version;
// when that is 0, the legacy schema_migrations table is consulted as a
// fallback. migrations[i] upgrades the schema from version i to i+1, so a
// database at version v needs migrations[v:]. On success user_version is set
// to len(migrations) and the transaction is committed; on any failure the
// transaction is rolled back and an error is returned. A stored version that
// is negative or greater than len(migrations) is reported as an error.
func migrate(db *sql.DB, migrations []func(*sql.Tx) error) error {
	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("DB migration begin transaction: %w", err)
	}
	// Rolls back on every early return; a no-op once Commit has run.
	defer tx.Rollback() // nolint: errcheck

	var idx int
	if err = tx.QueryRow("PRAGMA user_version").Scan(&idx); err != nil {
		return fmt.Errorf("getting latest applied migration: %w", err)
	}
	if idx == 0 {
		// Best effort: a failure here (e.g. no legacy table) simply means
		// there is no legacy version to inherit.
		if legacyIdx, err := getLegacyMigrationVersion(tx); err == nil {
			idx = legacyIdx
		}
	}

	n := len(migrations)
	switch {
	case idx < 0:
		// Guard the slice expression below, which would otherwise panic.
		return fmt.Errorf("database reports invalid schema version %d", idx)
	case idx == n:
		// Already fully migrated, nothing needed.
		return nil
	case idx > n:
		return fmt.Errorf("database is at version %d, which is more recent than this binary understands", idx)
	}

	for i, f := range migrations[idx:] {
		if err := f(tx); err != nil {
			// i is relative to the sub-slice, so offset it by idx to
			// report the real target version.
			return fmt.Errorf("migration to version %d failed: %w", idx+i+1, err)
		}
	}

	// Here 0 <= idx < n, so at least one migration ran and n > 0.
	// For some reason, ? substitution doesn't work in PRAGMA
	// statements, sqlite reports a parse error.
	if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version=%d", n)); err != nil {
		return fmt.Errorf("recording new DB version: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return err
	}
	// Only announce the upgrade once it is durable.
	log.Printf("db migration: upgraded schema version %d -> %d", idx, n)
	return nil
}

// Statement builds a migration step that executes the given SQL statements
// in order within the supplied transaction, stopping at (and returning) the
// first execution error. When DebugMigrations is set, each statement is
// logged immediately before it is executed.
func Statement(idl ...string) func(*sql.Tx) error {
	run := func(tx *sql.Tx) error {
		for i := 0; i < len(idl); i++ {
			query := idl[i]
			if DebugMigrations {
				log.Printf("db migration: executing: %s", query)
			}
			_, execErr := tx.Exec(query)
			if execErr != nil {
				return execErr
			}
		}
		return nil
	}
	return run
}

// WithTx runs f inside a transaction started on db with ctx and the default
// transaction options. The transaction is committed when f returns nil, and
// the result of Commit is returned. If f returns an error, the transaction
// is rolled back and that error is returned unchanged. If f panics, the
// transaction is rolled back before the panic continues to propagate.
func WithTx(ctx context.Context, db *sql.DB, f func(*sql.Tx) error) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Deferred so the transaction and its pooled connection are released even
	// when f panics. After Commit this is a no-op returning sql.ErrTxDone.
	defer tx.Rollback() // nolint: errcheck

	if err := f(tx); err != nil {
		return err
	}

	return tx.Commit()
}

// WithReadonlyTx runs f inside a transaction that is requested as read-only
// from db. The transaction is never committed: it is always rolled back once
// f returns, and f's result is returned as is.
func WithReadonlyTx(ctx context.Context, db *sql.DB, f func(*sql.Tx) error) error {
	readOnly := sql.TxOptions{ReadOnly: true}

	tx, beginErr := db.BeginTx(ctx, &readOnly)
	if beginErr != nil {
		return beginErr
	}
	// Nothing to persist, so rolling back is the only way the transaction ends.
	defer func() {
		_ = tx.Rollback() // nolint
	}()

	return f(tx)
}