Commit aa79431d authored by ale's avatar ale

Add a SQL backend

This first version (with tests) only supports the SQLite driver.
parent f095fa5e
......@@ -149,8 +149,10 @@ func createBackends(config *Config) (map[string]userBackend, error) {
b, err = newFileBackend(config, params)
case "ldap":
b, err = newLDAPBackend(config, params)
case "sql":
b, err = newSQLBackend(config, params)
default:
err = fmt.Errorf("unknown backend %s", name)
err = fmt.Errorf("unknown backend type %s", name)
}
if err != nil {
return nil, err
......
......@@ -40,7 +40,6 @@ func createTestServer(t testing.TB, configFiles map[string]string) *testServer {
if err != nil {
t.Fatal("NewServer():", err)
}
srv.Close()
return &testServer{
tmpdir: tmpdir,
......@@ -49,6 +48,7 @@ func createTestServer(t testing.TB, configFiles map[string]string) *testServer {
}
func (s *testServer) Close() {
s.srv.Close()
_ = os.RemoveAll(s.tmpdir)
}
......
package server
import (
"context"
"database/sql"
"errors"
"log"
_ "github.com/mattn/go-sqlite3"
"github.com/tstranex/u2f"
"gopkg.in/yaml.v2"
ct "git.autistici.org/ai3/go-common/ldap/compositetypes"
)
// Names for the known SQL queries.
const (
sqlQueryGetUser = "get_user"
sqlQueryGetGroups = "get_user_groups"
sqlQueryGetU2F = "get_user_u2f"
sqlQueryGetASP = "get_user_asp"
)
// Default SQL queries.
var defaultSQLQueries = map[string]string{
sqlQueryGetUser: `
SELECT email, password, totp_secret, '' AS shard FROM users WHERE name = ?
`,
}
type sqlConfig struct {
Driver string `yaml:"driver"`
URI string `yaml:"db_uri"`
}
type sqlServiceConfig struct {
Queries map[string]string `yaml:"queries"`
}
type sqlBackend struct {
db *sql.DB
}
type sqlServiceBackend struct {
db *sql.DB
stmts map[string]*sql.Stmt
}
func compileStatements(db *sql.DB, queries map[string]string) (map[string]*sql.Stmt, error) {
m := make(map[string]*sql.Stmt)
for name, query := range queries {
stmt, err := db.Prepare(query)
if err != nil {
return nil, err
}
m[name] = stmt
}
return m, nil
}
func newSQLBackend(config *Config, params yaml.MapSlice) (*sqlBackend, error) {
var sc sqlConfig
if err := unmarshalMapSlice(params, &sc); err != nil {
return nil, err
}
if sc.Driver == "" {
return nil, errors.New("driver is empty")
}
db, err := sql.Open(sc.Driver, sc.URI)
if err != nil {
return nil, err
}
return &sqlBackend{
db: db,
}, nil
}
func (b *sqlBackend) Close() {
b.db.Close()
}
func (b *sqlBackend) NewServiceBackend(spec *BackendSpec) (serviceBackend, error) {
var sc sqlServiceConfig
if err := unmarshalMapSlice(spec.Params, &sc); err != nil {
return nil, err
}
return newSQLServiceBackend(b.db, &sc)
}
func newSQLServiceBackend(db *sql.DB, sc *sqlServiceConfig) (*sqlServiceBackend, error) {
// Apply default queries.
for name, q := range defaultSQLQueries {
if _, ok := sc.Queries[name]; !ok {
sc.Queries[name] = q
}
}
// Compile the SQL statements.
stmts, err := compileStatements(db, sc.Queries)
if err != nil {
return nil, err
}
return &sqlServiceBackend{
db: db,
stmts: stmts,
}, nil
}
func (b *sqlServiceBackend) GetUser(ctx context.Context, name string) (*User, bool) {
tx, err := b.db.Begin()
if err != nil {
return nil, false
}
defer tx.Rollback() // nolint
user := User{Name: name}
// Use NullStrings for optional fields.
var nullableTOTP, nullableShard sql.NullString
row := tx.Stmt(b.stmts[sqlQueryGetUser]).QueryRow(name)
if err := row.Scan(&user.Email, &user.EncryptedPassword, &nullableTOTP, &nullableShard); err != nil {
return nil, false
}
if nullableTOTP.Valid {
user.TOTPSecret = nullableTOTP.String
}
if nullableShard.Valid {
user.Shard = nullableShard.String
}
// Now read the one-to-many relations.
if groups, err := b.getUserGroups(tx, name); err == nil {
user.Groups = groups
}
if regs, err := b.getUserU2FRegistrations(tx, name); err == nil {
user.U2FRegistrations = regs
}
if asps, err := b.getUserASPs(tx, name); err == nil {
user.AppSpecificPasswords = asps
}
return &user, true
}
func (b *sqlServiceBackend) getUserU2FRegistrations(tx *sql.Tx, name string) ([]u2f.Registration, error) {
stmt, ok := b.stmts[sqlQueryGetU2F]
if !ok {
return nil, nil
}
rows, err := tx.Stmt(stmt).Query(name)
if err != nil {
return nil, err
}
defer rows.Close()
// Use the compositetypes.U2FRegistration type to decode the
// U2F registration public key data into a usable format.
var out []u2f.Registration
for rows.Next() {
var ctr ct.U2FRegistration
if err := rows.Scan(&ctr.PublicKey, &ctr.KeyHandle); err != nil {
continue
}
reg, err := ctr.Decode()
if err != nil {
log.Printf("invalid u2f registration: %v", err)
continue
}
out = append(out, *reg)
}
return out, nil
}
func (b *sqlServiceBackend) getUserASPs(tx *sql.Tx, name string) ([]*AppSpecificPassword, error) {
stmt, ok := b.stmts[sqlQueryGetASP]
if !ok {
return nil, nil
}
rows, err := tx.Stmt(stmt).Query(name)
if err != nil {
return nil, err
}
defer rows.Close()
var out []*AppSpecificPassword
for rows.Next() {
var asp AppSpecificPassword
if err := rows.Scan(&asp.Service, &asp.EncryptedPassword); err != nil {
continue
}
out = append(out, &asp)
}
return out, nil
}
func (b *sqlServiceBackend) getUserGroups(tx *sql.Tx, name string) ([]string, error) {
stmt, ok := b.stmts[sqlQueryGetGroups]
if !ok {
return nil, nil
}
rows, err := tx.Stmt(stmt).Query(name)
if err != nil {
return nil, err
}
defer rows.Close()
var out []string
for rows.Next() {
var group string
if err := rows.Scan(&group); err != nil {
continue
}
out = append(out, group)
}
return out, nil
}
package server
import (
"context"
"database/sql"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"testing"
"git.autistici.org/id/auth"
)
var (
testConfigTemplateWithSimpleDB = `---
enabled_backends:
sql:
driver: sqlite3
db_uri: "%s"
services:
test:
backends:
- backend: sql
params:
queries:
get_user: "SELECT email, password, '' AS totp_secret, '' AS shard FROM users WHERE email = ?"
`
testConfigTemplateWithFullDB = `---
enabled_backends:
sql:
driver: sqlite3
db_uri: "%s"
services:
test:
backends:
- backend: sql
params:
queries:
get_user: "SELECT email, password, totp_secret, '' AS shard FROM users WHERE name = ?"
interactive:
challenge_response: true
backends:
- backend: sql
params:
queries:
get_user: "SELECT email, password, totp_secret, '' AS shard FROM users WHERE name = ?"
get_user_groups: "SELECT group_name FROM group_membership WHERE name = ?"
get_user_u2f: "SELECT public_key, key_handle FROM u2f_registrations WHERE name = ?"
get_user_asp: "SELECT service, password FROM asps WHERE name = ?"
`
)
func withTestDB(t testing.TB, schema string) (func(), string) {
dir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
dbPath := filepath.Join(dir, "test.db")
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
t.Fatalf("sql.Open: %v", err)
}
_, err = db.Exec(schema)
if err != nil {
t.Fatalf("sql error: %v", err)
}
db.Close()
return func() {
os.RemoveAll(dir)
}, dbPath
}
func TestBackend_SQL_SimpleSchema(t *testing.T) {
// Test a minimal database schema.
cleanup, dbPath := withTestDB(t, `
CREATE TABLE users (
email text,
password text
);
CREATE UNIQUE INDEX users_idx ON users(email);
INSERT INTO users (email, password) VALUES (
'test@example.com', '$s$16384$8$1$c479e8eb722f1b071efea7826ccf9c20$96d63ebed0c64afb746026f56f71b2a1f8796c73141d2d6b1958d4ea26c60a0b'
);
`)
defer cleanup()
conf := fmt.Sprintf(testConfigTemplateWithSimpleDB, dbPath)
s := createTestServer(t, map[string]string{
"config.yml": conf,
})
defer s.Close()
client := &clientAdapter{s.srv}
resp, err := client.Authenticate(context.Background(), &auth.Request{
Service: "test",
Username: "test@example.com",
Password: []byte("password"),
})
if err != nil {
t.Fatalf("Authenticate: %v", err)
}
if resp.Status != auth.StatusOK {
t.Fatalf("authentication failed: %v", resp.Status)
}
}
func TestBackend_SQL(t *testing.T) {
// Full schema that can run standard authentication tests.
cleanup, dbPath := withTestDB(t, `
CREATE TABLE users (
name text,
email text,
totp_secret text,
password text
);
CREATE UNIQUE INDEX users_name_idx ON users(name);
CREATE TABLE group_membership (
name text,
group_name text
);
CREATE INDEX group_membership_idx ON group_membership(name);
CREATE TABLE u2f_registrations (
name text,
key_handle blob,
public_key blob
);
CREATE INDEX u2f_registrations_idx ON u2f_registrations(name);
CREATE TABLE asps (
name text,
service text,
password text
);
CREATE INDEX asp_idx ON asps(name);
INSERT INTO users (name, email, totp_secret, password) VALUES (
'testuser', 'testuser@example.com', NULL, '$s$16384$8$1$c479e8eb722f1b071efea7826ccf9c20$96d63ebed0c64afb746026f56f71b2a1f8796c73141d2d6b1958d4ea26c60a0b'), (
'2fauser', '2fauser@example.com', 'O32OBVS5BL5EAPB5', '$s$16384$8$1$c479e8eb722f1b071efea7826ccf9c20$96d63ebed0c64afb746026f56f71b2a1f8796c73141d2d6b1958d4ea26c60a0b');
INSERT INTO group_membership (name, group_name) VALUES (
'testuser', 'group1'), (
'2fauser', 'group2');
INSERT INTO u2f_registrations (name, key_handle, public_key) VALUES (
'2fauser', X'25ca255c0e8a6a88a13bc56ec52ba0b424f98f287eea516e5972e41403def2cf6ab33c5332f0c0b499fc826620f6e18efa49a381aa7587496572196aaa30a92b', X'0498ee4565cd348031cf36ee3549b63b5ea23b5e7ea6f297e7cccaeba99983d185110fb94fa6455c82d3e5c8d0be10be71308d76062fb5fa50d3ea8228048f0037');
`)
defer cleanup()
conf := fmt.Sprintf(testConfigTemplateWithFullDB, dbPath)
s := createTestServer(t, map[string]string{
"config.yml": conf,
})
defer s.Close()
runAuthenticationTest(t, &clientAdapter{s.srv})
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment