Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • master
  • renovate/github.com-duo-labs-webauthn-digest
  • renovate/github.com-go-ldap-ldap-v3-3.x
  • renovate/github.com-go-webauthn-webauthn-0.x
  • renovate/github.com-google-go-cmp-0.x
  • renovate/github.com-lunixbochs-struc-digest
  • renovate/github.com-mattn-go-sqlite3-1.x
  • renovate/github.com-prometheus-client_golang-1.x
  • renovate/go-1.x
  • renovate/golang.org-x-crypto-0.x
  • renovate/golang.org-x-sync-0.x
  • renovate/opentelemetry-go-monorepo
12 results

Target

Select target project
  • ai3/go-common
1 result
Select Git revision
  • master
  • renovate/github.com-duo-labs-webauthn-digest
  • renovate/github.com-go-ldap-ldap-v3-3.x
  • renovate/github.com-go-webauthn-webauthn-0.x
  • renovate/github.com-google-go-cmp-0.x
  • renovate/github.com-lunixbochs-struc-digest
  • renovate/github.com-mattn-go-sqlite3-1.x
  • renovate/github.com-prometheus-client_golang-1.x
  • renovate/go-1.x
  • renovate/golang.org-x-crypto-0.x
  • renovate/golang.org-x-sync-0.x
  • renovate/opentelemetry-go-monorepo
12 results
Show changes
Commits on Source (5)
...@@ -86,7 +86,7 @@ func mkhash() (pwhash.PasswordHash, string, error) { ...@@ -86,7 +86,7 @@ func mkhash() (pwhash.PasswordHash, string, error) {
name := *algo name := *algo
switch *algo { switch *algo {
case "argon2": case "argon2":
h = pwhash.NewArgon2WithParams(uint32(*argon2Time), uint32(*argon2Mem*1024), uint8(*argon2Threads)) h = pwhash.NewArgon2StdWithParams(uint32(*argon2Time), uint32(*argon2Mem*1024), uint8(*argon2Threads))
name = fmt.Sprintf("%s(%d/%d/%d)", *algo, *argon2Time, *argon2Mem, *argon2Threads) name = fmt.Sprintf("%s(%d/%d/%d)", *algo, *argon2Time, *argon2Mem, *argon2Threads)
case "scrypt": case "scrypt":
h = pwhash.NewScryptWithParams(*scryptN, *scryptR, *scryptP) h = pwhash.NewScryptWithParams(*scryptN, *scryptR, *scryptP)
......
...@@ -18,7 +18,7 @@ require ( ...@@ -18,7 +18,7 @@ require (
github.com/gofrs/flock v0.12.1 github.com/gofrs/flock v0.12.1
github.com/google/go-cmp v0.6.0 github.com/google/go-cmp v0.6.0
github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40
github.com/mattn/go-sqlite3 v1.14.23 github.com/mattn/go-sqlite3 v1.14.25
github.com/miscreant/miscreant.go v0.0.0-20200214223636-26d376326b75 github.com/miscreant/miscreant.go v0.0.0-20200214223636-26d376326b75
github.com/prometheus/client_golang v1.20.3 github.com/prometheus/client_golang v1.20.3
github.com/russross/blackfriday/v2 v2.1.0 github.com/russross/blackfriday/v2 v2.1.0
......
...@@ -241,6 +241,8 @@ github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 h1:EnfXoSqDfSNJv0 ...@@ -241,6 +241,8 @@ github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 h1:EnfXoSqDfSNJv0
github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40/go.mod h1:vy1vK6wD6j7xX6O6hXe621WabdtNkou2h7uRtTfRMyg= github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40/go.mod h1:vy1vK6wD6j7xX6O6hXe621WabdtNkou2h7uRtTfRMyg=
github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0= github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0=
github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mattn/go-sqlite3 v1.14.25 h1:rszkIulEvxqZ8JfFG4yWEZh5u9qAKeSOdea67p8kk6s=
github.com/mattn/go-sqlite3 v1.14.25/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miscreant/miscreant.go v0.0.0-20200214223636-26d376326b75 h1:cUVxyR+UfmdEAZGJ8IiKld1O0dbGotEnkMolG5hfMSY= github.com/miscreant/miscreant.go v0.0.0-20200214223636-26d376326b75 h1:cUVxyR+UfmdEAZGJ8IiKld1O0dbGotEnkMolG5hfMSY=
......
...@@ -6,15 +6,15 @@ import ( ...@@ -6,15 +6,15 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"log"
"strconv" "strconv"
"strings" "strings"
"golang.org/x/crypto/argon2" "golang.org/x/crypto/argon2"
) )
var ( const (
argonKeyLen uint32 = 32 argonLegacyKeySize = 32
argonDefaultKeySize = 16
argonSaltLen = 16 argonSaltLen = 16
) )
...@@ -29,9 +29,11 @@ type argon2PasswordHash struct { ...@@ -29,9 +29,11 @@ type argon2PasswordHash struct {
// newArgon2PasswordHash returns an Argon2i-based PasswordHash using the // newArgon2PasswordHash returns an Argon2i-based PasswordHash using the
// specified parameters for time, memory, and number of threads. // specified parameters for time, memory, and number of threads.
func newArgon2PasswordHash(time, mem uint32, threads uint8, codec argon2Codec) PasswordHash { func newArgon2PasswordHash(kind string, keySize int, time, mem uint32, threads uint8, codec argon2Codec) PasswordHash {
return &argon2PasswordHash{ return &argon2PasswordHash{
params: argon2Params{ params: argon2Params{
KeySize: keySize,
Kind: kind,
Time: time, Time: time,
Memory: mem, Memory: mem,
Threads: threads, Threads: threads,
...@@ -41,8 +43,8 @@ func newArgon2PasswordHash(time, mem uint32, threads uint8, codec argon2Codec) P ...@@ -41,8 +43,8 @@ func newArgon2PasswordHash(time, mem uint32, threads uint8, codec argon2Codec) P
} }
// NewArgon2 returns an Argon2i-based PasswordHash using the default parameters. // NewArgon2 returns an Argon2i-based PasswordHash using the default parameters.
func NewArgon2() PasswordHash { func NewArgon2Legacy() PasswordHash {
return NewArgon2WithParams( return NewArgon2LegacyWithParams(
defaultArgon2Params.Time, defaultArgon2Params.Time,
defaultArgon2Params.Memory, defaultArgon2Params.Memory,
defaultArgon2Params.Threads, defaultArgon2Params.Threads,
...@@ -51,8 +53,8 @@ func NewArgon2() PasswordHash { ...@@ -51,8 +53,8 @@ func NewArgon2() PasswordHash {
// NewArgon2WithParams returns an Argon2i-based PasswordHash using the // NewArgon2WithParams returns an Argon2i-based PasswordHash using the
// specified parameters for time, memory, and number of threads. // specified parameters for time, memory, and number of threads.
func NewArgon2WithParams(time, mem uint32, threads uint8) PasswordHash { func NewArgon2LegacyWithParams(time, mem uint32, threads uint8) PasswordHash {
return newArgon2PasswordHash(time, mem, threads, &a2Codec{}) return newArgon2PasswordHash(kindArgon2I, argonLegacyKeySize, time, mem, threads, &a2LegacyCodec{})
} }
// NewArgon2Std returns an Argon2i-based PasswordHash that conforms // NewArgon2Std returns an Argon2i-based PasswordHash that conforms
...@@ -65,12 +67,12 @@ func NewArgon2Std() PasswordHash { ...@@ -65,12 +67,12 @@ func NewArgon2Std() PasswordHash {
) )
} }
// NewArgon2StdWithParams returns an Argon2i-based PasswordHash using // NewArgon2StdWithParams returns an Argon2id-based PasswordHash using
// the specified parameters for time, memory, and number of // the specified parameters for time, memory, and number of
// threads. This will use the string encoding ("$argon2$") documented // threads. This will use the string encoding ("$argon2id$") documented
// in the argon2 reference implementation. // in the argon2 reference implementation.
func NewArgon2StdWithParams(time, mem uint32, threads uint8) PasswordHash { func NewArgon2StdWithParams(time, mem uint32, threads uint8) PasswordHash {
return newArgon2PasswordHash(time, mem, threads, &argon2StdCodec{}) return newArgon2PasswordHash(kindArgon2ID, argonDefaultKeySize, time, mem, threads, &argon2StdCodec{})
} }
// ComparePassword returns true if the given password matches the // ComparePassword returns true if the given password matches the
...@@ -80,28 +82,53 @@ func (s *argon2PasswordHash) ComparePassword(encrypted, password string) bool { ...@@ -80,28 +82,53 @@ func (s *argon2PasswordHash) ComparePassword(encrypted, password string) bool {
if err != nil { if err != nil {
return false return false
} }
dk2 := argon2.Key([]byte(password), salt, params.Time, params.Memory, params.Threads, argonKeyLen)
dk2 := params.hash(password, salt)
return subtle.ConstantTimeCompare(dk, dk2) == 1 return subtle.ConstantTimeCompare(dk, dk2) == 1
} }
// Encrypt the given password with the Argon2 algorithm. // Encrypt the given password with the Argon2 algorithm.
func (s *argon2PasswordHash) Encrypt(password string) string { func (s *argon2PasswordHash) Encrypt(password string) string {
salt := getRandomBytes(argonSaltLen) salt := getRandomBytes(argonSaltLen)
dk := argon2.Key([]byte(password), salt, s.params.Time, s.params.Memory, s.params.Threads, argonKeyLen) dk := s.params.hash(password, salt)
return s.codec.encodeArgon2Hash(s.params, salt, dk) return s.codec.encodeArgon2Hash(s.params, salt, dk)
} }
const (
kindArgon2I = "argon2i"
kindArgon2ID = "argon2id"
)
type argon2Params struct { type argon2Params struct {
Kind string
KeySize int
Time uint32 Time uint32
Memory uint32 Memory uint32
Threads uint8 Threads uint8
} }
func (p argon2Params) hash(password string, salt []byte) []byte {
if p.KeySize == 0 {
panic("key size is 0")
}
switch p.Kind {
case kindArgon2I:
return argon2.Key([]byte(password), salt, p.Time, p.Memory, p.Threads, uint32(p.KeySize))
case kindArgon2ID:
return argon2.IDKey([]byte(password), salt, p.Time, p.Memory, p.Threads, uint32(p.KeySize))
default:
panic("unknown argon2 hash kind")
}
}
// Default Argon2 parameters are tuned for a high-traffic // Default Argon2 parameters are tuned for a high-traffic
// authentication service (<1ms per operation). // authentication service (<1ms per operation).
var defaultArgon2Params = argon2Params{ var defaultArgon2Params = argon2Params{
Kind: kindArgon2ID,
KeySize: 16,
Time: 1, Time: 1,
Memory: 4 * 1024, Memory: 64 * 1024,
Threads: 4, Threads: 4,
} }
...@@ -110,13 +137,14 @@ type argon2Codec interface { ...@@ -110,13 +137,14 @@ type argon2Codec interface {
decodeArgon2Hash(string) (argon2Params, []byte, []byte, error) decodeArgon2Hash(string) (argon2Params, []byte, []byte, error)
} }
type a2Codec struct{} // Argon2i legacy encoding, do not use.
type a2LegacyCodec struct{}
func (*a2Codec) encodeArgon2Hash(params argon2Params, salt, dk []byte) string { func (*a2LegacyCodec) encodeArgon2Hash(params argon2Params, salt, dk []byte) string {
return fmt.Sprintf("$a2$%d$%d$%d$%x$%x", params.Time, params.Memory, params.Threads, salt, dk) return fmt.Sprintf("$a2$%d$%d$%d$%x$%x", params.Time, params.Memory, params.Threads, salt, dk)
} }
func (*a2Codec) decodeArgon2Hash(s string) (params argon2Params, salt []byte, dk []byte, err error) { func (*a2LegacyCodec) decodeArgon2Hash(s string) (params argon2Params, salt []byte, dk []byte, err error) {
if !strings.HasPrefix(s, "$a2$") { if !strings.HasPrefix(s, "$a2$") {
err = errors.New("not an Argon2 password hash") err = errors.New("not an Argon2 password hash")
return return
...@@ -128,6 +156,8 @@ func (*a2Codec) decodeArgon2Hash(s string) (params argon2Params, salt []byte, dk ...@@ -128,6 +156,8 @@ func (*a2Codec) decodeArgon2Hash(s string) (params argon2Params, salt []byte, dk
return return
} }
params.Kind = kindArgon2I
var i uint64 var i uint64
if i, err = strconv.ParseUint(parts[0], 10, 32); err != nil { if i, err = strconv.ParseUint(parts[0], 10, 32); err != nil {
...@@ -149,16 +179,36 @@ func (*a2Codec) decodeArgon2Hash(s string) (params argon2Params, salt []byte, dk ...@@ -149,16 +179,36 @@ func (*a2Codec) decodeArgon2Hash(s string) (params argon2Params, salt []byte, dk
if err != nil { if err != nil {
return return
} }
dk, err = hex.DecodeString(parts[4]) dk, err = hex.DecodeString(parts[4])
if err != nil {
return
}
params.KeySize = len(dk)
switch len(dk) {
case 16, 24, 32:
default:
err = errors.New("bad key size")
}
return return
} }
// Standard Argon2 encoding as per the reference implementation in
// https://github.com/P-H-C/phc-winner-argon2/blob/4ac8640c2adc1257677d27d3f833c8d1ee68c7d2/src/encoding.c#L242-L252
type argon2StdCodec struct{} type argon2StdCodec struct{}
const argon2HashVersionStr = "v=19"
func (*argon2StdCodec) encodeArgon2Hash(params argon2Params, salt, dk []byte) string { func (*argon2StdCodec) encodeArgon2Hash(params argon2Params, salt, dk []byte) string {
encSalt := base64.RawStdEncoding.EncodeToString(salt) encSalt := base64.RawStdEncoding.EncodeToString(salt)
encDK := base64.RawStdEncoding.EncodeToString(dk) encDK := base64.RawStdEncoding.EncodeToString(dk)
return fmt.Sprintf("$argon2i$v=19$m=%d,t=%d,p=%d$%s$%s", params.Memory, params.Time, params.Threads, encSalt, encDK) return fmt.Sprintf(
"$%s$%s$m=%d,t=%d,p=%d$%s$%s",
params.Kind, argon2HashVersionStr,
params.Memory, params.Time, params.Threads,
encSalt, encDK)
} }
func parseArgon2HashParams(s string) (params argon2Params, err error) { func parseArgon2HashParams(s string) (params argon2Params, err error) {
...@@ -182,7 +232,7 @@ func parseArgon2HashParams(s string) (params argon2Params, err error) { ...@@ -182,7 +232,7 @@ func parseArgon2HashParams(s string) (params argon2Params, err error) {
i, err = strconv.ParseUint(kv[1], 10, 8) i, err = strconv.ParseUint(kv[1], 10, 8)
params.Threads = uint8(i) params.Threads = uint8(i)
default: default:
err = errors.New("unknown parameter in hash") err = fmt.Errorf("unknown parameter '%s' in hash", kv[0])
} }
if err != nil { if err != nil {
return return
...@@ -192,30 +242,46 @@ func parseArgon2HashParams(s string) (params argon2Params, err error) { ...@@ -192,30 +242,46 @@ func parseArgon2HashParams(s string) (params argon2Params, err error) {
} }
func (*argon2StdCodec) decodeArgon2Hash(s string) (params argon2Params, salt []byte, dk []byte, err error) { func (*argon2StdCodec) decodeArgon2Hash(s string) (params argon2Params, salt []byte, dk []byte, err error) {
if !strings.HasPrefix(s, "$argon2i$") { var kind string
switch {
case strings.HasPrefix(s, "$argon2i$"):
kind = kindArgon2I
case strings.HasPrefix(s, "$argon2id$"):
kind = kindArgon2ID
default:
err = errors.New("not an Argon2 password hash") err = errors.New("not an Argon2 password hash")
return return
} }
parts := strings.SplitN(s[9:], "$", 4) parts := strings.SplitN(s, "$", 6)
if len(parts) != 4 { if len(parts) != 6 {
err = errors.New("bad encoding") err = errors.New("bad encoding")
return return
} }
if parts[0] != "v=19" { if parts[2] != argon2HashVersionStr {
err = errors.New("bad argon2 hash version") err = errors.New("bad argon2 hash version")
return return
} }
params, err = parseArgon2HashParams(parts[1]) params, err = parseArgon2HashParams(parts[3])
if err != nil { if err != nil {
return return
} }
if salt, err = base64.RawStdEncoding.DecodeString(parts[2]); err != nil { params.Kind = kind
if salt, err = base64.RawStdEncoding.DecodeString(parts[4]); err != nil {
return
}
if dk, err = base64.RawStdEncoding.DecodeString(parts[5]); err != nil {
return return
} }
dk, err = base64.RawStdEncoding.DecodeString(parts[3])
log.Printf("params: %+v", params) params.KeySize = len(dk)
switch len(dk) {
case 16, 24, 32:
default:
err = errors.New("bad key size")
}
return return
} }
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
// //
// go test -bench=Argon2 -run=none . 2>&1 | \ // go test -bench=Argon2 -run=none . 2>&1 | \
// awk '/^Bench/ {ops=1000000000 / $3; print $1 " " ops " ops/sec"}' // awk '/^Bench/ {ops=1000000000 / $3; print $1 " " ops " ops/sec"}'
//
package pwhash package pwhash
import ( import (
...@@ -53,8 +52,9 @@ var prefixRegistry = map[string]PasswordHash{ ...@@ -53,8 +52,9 @@ var prefixRegistry = map[string]PasswordHash{
"$5$": NewSystemCrypt(), "$5$": NewSystemCrypt(),
"$6$": NewSystemCrypt(), "$6$": NewSystemCrypt(),
"$s$": NewScrypt(), "$s$": NewScrypt(),
"$a2$": NewArgon2(), "$a2$": NewArgon2Legacy(),
"$argon2i$": NewArgon2Std(), "$argon2i$": NewArgon2Std(),
"$argon2id$": NewArgon2Std(),
} }
// ComparePassword returns true if the given password matches the // ComparePassword returns true if the given password matches the
...@@ -65,6 +65,7 @@ func ComparePassword(encrypted, password string) bool { ...@@ -65,6 +65,7 @@ func ComparePassword(encrypted, password string) bool {
return h.ComparePassword(encrypted, password) return h.ComparePassword(encrypted, password)
} }
} }
return false return false
} }
...@@ -73,7 +74,7 @@ func ComparePassword(encrypted, password string) bool { ...@@ -73,7 +74,7 @@ func ComparePassword(encrypted, password string) bool {
var DefaultEncryptAlgorithm PasswordHash var DefaultEncryptAlgorithm PasswordHash
func init() { func init() {
DefaultEncryptAlgorithm = NewArgon2() DefaultEncryptAlgorithm = NewArgon2Std()
} }
// Encrypt will encrypt a password with the default algorithm. // Encrypt will encrypt a password with the default algorithm.
......
...@@ -5,8 +5,8 @@ import ( ...@@ -5,8 +5,8 @@ import (
"testing" "testing"
) )
func TestArgon2(t *testing.T) { func TestArgon2Legacy(t *testing.T) {
testImpl(t, NewArgon2()) testImpl(t, NewArgon2Legacy())
} }
func TestArgon2Std(t *testing.T) { func TestArgon2Std(t *testing.T) {
...@@ -65,7 +65,7 @@ func testImpl(t *testing.T, h PasswordHash) { ...@@ -65,7 +65,7 @@ func testImpl(t *testing.T, h PasswordHash) {
} }
} }
func TestStandardArgon2Password(t *testing.T) { func TestStandardArgon2IPassword(t *testing.T) {
enc := "$argon2i$v=19$m=32768,t=4,p=1$DG0B56zlrrx+VMVaM6wvsw$8iV+HwTKmofjrb+q9I2zZGQnGXzXtiIXv8VdHdvbbX8" enc := "$argon2i$v=19$m=32768,t=4,p=1$DG0B56zlrrx+VMVaM6wvsw$8iV+HwTKmofjrb+q9I2zZGQnGXzXtiIXv8VdHdvbbX8"
pw := "idontmindbirds" pw := "idontmindbirds"
if !ComparePassword(enc, pw) { if !ComparePassword(enc, pw) {
...@@ -73,6 +73,15 @@ func TestStandardArgon2Password(t *testing.T) { ...@@ -73,6 +73,15 @@ func TestStandardArgon2Password(t *testing.T) {
} }
} }
func TestStandardArgon2IDPassword(t *testing.T) {
// python3 -c 'from argon2 import PasswordHasher ; print(PasswordHasher().hash("idontmindbirds"))'
enc := "$argon2id$v=19$m=102400,t=2,p=8$7hQLBrHoxYxRO0R8km62pA$Dv5+BCctW4nCrxsy5C9JBg"
pw := "idontmindbirds"
if !ComparePassword(enc, pw) {
t.Fatal("comparison failed")
}
}
func BenchmarkArgon2(b *testing.B) { func BenchmarkArgon2(b *testing.B) {
var testParams []argon2Params var testParams []argon2Params
for iTime := 1; iTime <= 5; iTime++ { for iTime := 1; iTime <= 5; iTime++ {
...@@ -93,7 +102,7 @@ func BenchmarkArgon2(b *testing.B) { ...@@ -93,7 +102,7 @@ func BenchmarkArgon2(b *testing.B) {
for _, tp := range testParams { for _, tp := range testParams {
name := fmt.Sprintf("%d/%d/%d", tp.Time, tp.Memory, tp.Threads) name := fmt.Sprintf("%d/%d/%d", tp.Time, tp.Memory, tp.Threads)
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
h := NewArgon2WithParams(tp.Time, tp.Memory, tp.Threads) h := NewArgon2StdWithParams(tp.Time, tp.Memory, tp.Threads)
encPw := h.Encrypt(goodPw) encPw := h.Encrypt(goodPw)
b.ResetTimer() b.ResetTimer()
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"log" "log"
"net/url"
"strings" "strings"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
...@@ -13,11 +14,24 @@ import ( ...@@ -13,11 +14,24 @@ import (
// DebugMigrations can be set to true to dump statements to stderr. // DebugMigrations can be set to true to dump statements to stderr.
var DebugMigrations bool var DebugMigrations bool
const defaultOptions = "?cache=shared&_busy_timeout=10000&_journal=WAL&_sync=OFF" // See https://github.com/mattn/go-sqlite3/issues/209 for details on
// why these default parameters were chosen. WAL mode is mandatory for
// external litestream support.
func defaultOptions() url.Values {
v := make(url.Values)
v.Set("cache", "shared")
v.Set("_journal", "WAL")
v.Set("_sync", "OFF")
v.Set("_busy_timeout", "999999")
v.Set("_fk", "true")
v.Set("_cache_size", "268435456")
v.Set("_auto_vacuum", "incremental")
return v
}
type sqlOptions struct { type sqlOptions struct {
migrations []func(*sql.Tx) error migrations []func(*sql.Tx) error
sqlopts string sqlopts url.Values
} }
type Option func(*sqlOptions) type Option func(*sqlOptions)
...@@ -28,16 +42,16 @@ func WithMigrations(migrations []func(*sql.Tx) error) Option { ...@@ -28,16 +42,16 @@ func WithMigrations(migrations []func(*sql.Tx) error) Option {
} }
} }
func WithSqliteOptions(sqlopts string) Option { func WithSqliteOption(opt, value string) Option {
return func(opts *sqlOptions) { return func(opts *sqlOptions) {
opts.sqlopts = sqlopts opts.sqlopts.Set(opt, value)
} }
} }
// OpenDB opens a SQLite database and runs the database migrations. // OpenDB opens a SQLite database and runs the database migrations.
func OpenDB(dburi string, options ...Option) (*sql.DB, error) { func OpenDB(dburi string, options ...Option) (*sql.DB, error) {
var opts sqlOptions var opts sqlOptions
opts.sqlopts = defaultOptions opts.sqlopts = defaultOptions()
for _, o := range options { for _, o := range options {
o(&opts) o(&opts)
} }
...@@ -45,7 +59,7 @@ func OpenDB(dburi string, options ...Option) (*sql.DB, error) { ...@@ -45,7 +59,7 @@ func OpenDB(dburi string, options ...Option) (*sql.DB, error) {
// Add sqlite3-specific parameters if none are already // Add sqlite3-specific parameters if none are already
// specified in the connection URI. // specified in the connection URI.
if !strings.Contains(dburi, "?") { if !strings.Contains(dburi, "?") {
dburi += opts.sqlopts dburi = fmt.Sprintf("%s?%s", dburi, opts.sqlopts.Encode())
} }
db, err := sql.Open("sqlite3", dburi) db, err := sql.Open("sqlite3", dburi)
...@@ -56,6 +70,7 @@ func OpenDB(dburi string, options ...Option) (*sql.DB, error) { ...@@ -56,6 +70,7 @@ func OpenDB(dburi string, options ...Option) (*sql.DB, error) {
// Limit the pool to a single connection. // Limit the pool to a single connection.
// https://github.com/mattn/go-sqlite3/issues/209 // https://github.com/mattn/go-sqlite3/issues/209
db.SetMaxOpenConns(1) db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
if err = migrate(db, opts.migrations); err != nil { if err = migrate(db, opts.migrations); err != nil {
db.Close() // nolint db.Close() // nolint
......
...@@ -3,7 +3,6 @@ package sqlutil ...@@ -3,7 +3,6 @@ package sqlutil
import ( import (
"context" "context"
"database/sql" "database/sql"
"io/ioutil"
"os" "os"
"testing" "testing"
) )
...@@ -13,7 +12,7 @@ func init() { ...@@ -13,7 +12,7 @@ func init() {
} }
func TestOpenDB(t *testing.T) { func TestOpenDB(t *testing.T) {
dir, err := ioutil.TempDir("", "") dir, err := os.MkdirTemp("", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -44,7 +43,7 @@ func checkTestValue(t *testing.T, db *sql.DB) { ...@@ -44,7 +43,7 @@ func checkTestValue(t *testing.T, db *sql.DB) {
} }
func TestOpenDB_Migrations_MultipleStatements(t *testing.T) { func TestOpenDB_Migrations_MultipleStatements(t *testing.T) {
dir, err := ioutil.TempDir("", "") dir, err := os.MkdirTemp("", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -64,7 +63,7 @@ func TestOpenDB_Migrations_MultipleStatements(t *testing.T) { ...@@ -64,7 +63,7 @@ func TestOpenDB_Migrations_MultipleStatements(t *testing.T) {
} }
func TestOpenDB_Migrations_SingleStatement(t *testing.T) { func TestOpenDB_Migrations_SingleStatement(t *testing.T) {
dir, err := ioutil.TempDir("", "") dir, err := os.MkdirTemp("", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -86,7 +85,7 @@ func TestOpenDB_Migrations_SingleStatement(t *testing.T) { ...@@ -86,7 +85,7 @@ func TestOpenDB_Migrations_SingleStatement(t *testing.T) {
} }
func TestOpenDB_Migrations_Versions(t *testing.T) { func TestOpenDB_Migrations_Versions(t *testing.T) {
dir, err := ioutil.TempDir("", "") dir, err := os.MkdirTemp("", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -114,7 +113,7 @@ func TestOpenDB_Migrations_Versions(t *testing.T) { ...@@ -114,7 +113,7 @@ func TestOpenDB_Migrations_Versions(t *testing.T) {
} }
func TestOpenDB_Write(t *testing.T) { func TestOpenDB_Write(t *testing.T) {
dir, err := ioutil.TempDir("", "") dir, err := os.MkdirTemp("", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -143,7 +142,7 @@ func TestOpenDB_Write(t *testing.T) { ...@@ -143,7 +142,7 @@ func TestOpenDB_Write(t *testing.T) {
} }
func TestOpenDB_Migrations_Legacy(t *testing.T) { func TestOpenDB_Migrations_Legacy(t *testing.T) {
dir, err := ioutil.TempDir("", "") dir, err := os.MkdirTemp("", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......