Commit c544c4bd authored by ale's avatar ale

Refactor backends into a cleaner API

This should make the server code more readable by disentangling it
from the backend implementation.
parent fca48375
Pipeline #2655 passed with stages
in 1 minute and 51 seconds
package backend
import (
"context"
"github.com/tstranex/u2f"
"gopkg.in/yaml.v2"
"git.autistici.org/id/auth"
)
// User contains the attributes of a user account as relevant to the
// authentication server. It is only used internally, to communicate
// between the authserver and its storage backends.
type User struct {
Name string
Email string
Shard string
EncryptedPassword []byte
TOTPSecret string
U2FRegistrations []u2f.Registration
AppSpecificPasswords []*AppSpecificPassword
Groups []string
}
// AppSpecificPassword is a password tied to a single service.
type AppSpecificPassword struct {
Service string
EncryptedPassword []byte
}
// Has2FA returns true if the user supports any 2FA method.
func (u *User) Has2FA() bool {
return u.HasU2F() || u.HasOTP()
}
// HasOTP returns true if the user supports (T)OTP.
func (u *User) HasOTP() bool {
return u.TOTPSecret != ""
}
// HasU2F returns true if the user supports U2F.
func (u *User) HasU2F() bool {
return len(u.U2FRegistrations) > 0
}
// UserInfo returns extra user information in the format required by
// the auth wire protocol.
func (u *User) UserInfo() *auth.UserInfo {
return &auth.UserInfo{
Email: u.Email,
Shard: u.Shard,
Groups: u.Groups,
}
}
// Spec specifies backend-specific configuration for a service.
type Spec struct {
BackendName string `yaml:"backend"`
Params yaml.MapSlice `yaml:"params"`
StaticGroups []string `yaml:"static_groups"`
}
// UserBackend provides us with per-service user information.
type UserBackend interface {
Close()
NewServiceBackend(*Spec) (ServiceBackend, error)
}
// ServiceBackend looks up user info for a specific service.
type ServiceBackend interface {
GetUser(context.Context, string) (*User, bool)
}
...@@ -4,12 +4,12 @@ import ( ...@@ -4,12 +4,12 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"path/filepath"
"strings"
ct "git.autistici.org/ai3/go-common/ldap/compositetypes" ct "git.autistici.org/ai3/go-common/ldap/compositetypes"
"github.com/tstranex/u2f" "github.com/tstranex/u2f"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"git.autistici.org/id/auth/backend"
) )
// BackendSpec parameters for the file backend. // BackendSpec parameters for the file backend.
...@@ -59,8 +59,8 @@ func (f *fileUser) getU2FRegistrations() []u2f.Registration { ...@@ -59,8 +59,8 @@ func (f *fileUser) getU2FRegistrations() []u2f.Registration {
return out return out
} }
func (f *fileUser) ToUser() *User { func (f *fileUser) ToUser() *backend.User {
return &User{ return &backend.User{
Name: f.Name, Name: f.Name,
Email: f.Email, Email: f.Email,
Shard: f.Shard, Shard: f.Shard,
...@@ -80,7 +80,7 @@ type fileBackend struct { ...@@ -80,7 +80,7 @@ type fileBackend struct {
func loadUsersFile(path string) (map[string]*fileUser, error) { func loadUsersFile(path string) (map[string]*fileUser, error) {
var userList []*fileUser var userList []*fileUser
if err := loadYAML(path, &userList); err != nil { if err := backend.LoadYAML(path, &userList); err != nil {
return nil, err return nil, err
} }
users := make(map[string]*fileUser) users := make(map[string]*fileUser)
...@@ -90,20 +90,14 @@ func loadUsersFile(path string) (map[string]*fileUser, error) { ...@@ -90,20 +90,14 @@ func loadUsersFile(path string) (map[string]*fileUser, error) {
return users, nil return users, nil
} }
func newFileBackend(config *Config, _ yaml.MapSlice) (*fileBackend, error) { // New creates a new file-based UserBackend.
func New(_ yaml.MapSlice, configDir string) (backend.UserBackend, error) {
return &fileBackend{ return &fileBackend{
files: make(map[string]map[string]*fileUser), files: make(map[string]map[string]*fileUser),
configDir: filepath.Dir(config.path), configDir: configDir,
}, nil }, nil
} }
func (b *fileBackend) relativePath(path string) string {
if strings.HasPrefix(path, "/") {
return path
}
return filepath.Join(b.configDir, path)
}
func (b *fileBackend) getUserMap(path string) (map[string]*fileUser, error) { func (b *fileBackend) getUserMap(path string) (map[string]*fileUser, error) {
m, ok := b.files[path] m, ok := b.files[path]
if !ok { if !ok {
...@@ -119,12 +113,12 @@ func (b *fileBackend) getUserMap(path string) (map[string]*fileUser, error) { ...@@ -119,12 +113,12 @@ func (b *fileBackend) getUserMap(path string) (map[string]*fileUser, error) {
func (b *fileBackend) Close() {} func (b *fileBackend) Close() {}
func (b *fileBackend) NewServiceBackend(spec *BackendSpec) (serviceBackend, error) { func (b *fileBackend) NewServiceBackend(spec *backend.Spec) (backend.ServiceBackend, error) {
var params fileServiceParams var params fileServiceParams
if err := unmarshalMapSlice(spec.Params, &params); err != nil { if err := backend.UnmarshalMapSlice(spec.Params, &params); err != nil {
return nil, err return nil, err
} }
m, err := b.getUserMap(b.relativePath(params.Src)) m, err := b.getUserMap(backend.ResolvePath(params.Src, b.configDir))
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -133,7 +127,7 @@ func (b *fileBackend) NewServiceBackend(spec *BackendSpec) (serviceBackend, erro ...@@ -133,7 +127,7 @@ func (b *fileBackend) NewServiceBackend(spec *BackendSpec) (serviceBackend, erro
type fileServiceBackend map[string]*fileUser type fileServiceBackend map[string]*fileUser
func (b fileServiceBackend) GetUser(_ context.Context, name string) (*User, bool) { func (b fileServiceBackend) GetUser(_ context.Context, name string) (*backend.User, bool) {
u, ok := b[name] u, ok := b[name]
if !ok { if !ok {
return nil, false return nil, false
......
...@@ -12,6 +12,8 @@ import ( ...@@ -12,6 +12,8 @@ import (
"github.com/tstranex/u2f" "github.com/tstranex/u2f"
"gopkg.in/ldap.v3" "gopkg.in/ldap.v3"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"git.autistici.org/id/auth/backend"
) )
// ldapServiceParams defines a search to be performed when looking up // ldapServiceParams defines a search to be performed when looking up
...@@ -63,11 +65,11 @@ func getListFromLDAPEntry(entry *ldap.Entry, attr string) []string { ...@@ -63,11 +65,11 @@ func getListFromLDAPEntry(entry *ldap.Entry, attr string) []string {
return entry.GetAttributeValues(attr) return entry.GetAttributeValues(attr)
} }
func decodeAppSpecificPasswordList(encodedAsps []string) []*AppSpecificPassword { func decodeAppSpecificPasswordList(encodedAsps []string) []*backend.AppSpecificPassword {
var out []*AppSpecificPassword var out []*backend.AppSpecificPassword
for _, enc := range encodedAsps { for _, enc := range encodedAsps {
if p, err := ct.UnmarshalAppSpecificPassword(enc); err == nil { if p, err := ct.UnmarshalAppSpecificPassword(enc); err == nil {
out = append(out, &AppSpecificPassword{ out = append(out, &backend.AppSpecificPassword{
Service: p.Service, Service: p.Service,
EncryptedPassword: []byte(p.EncryptedPassword), EncryptedPassword: []byte(p.EncryptedPassword),
}) })
...@@ -115,10 +117,11 @@ type ldapBackend struct { ...@@ -115,10 +117,11 @@ type ldapBackend struct {
pool *ldaputil.ConnectionPool pool *ldaputil.ConnectionPool
} }
func newLDAPBackend(config *Config, params yaml.MapSlice) (*ldapBackend, error) { // New returns a new LDAP backend.
func New(params yaml.MapSlice, configDir string) (backend.UserBackend, error) {
// Unmarshal and validate configuration. // Unmarshal and validate configuration.
var lc ldapConfig var lc ldapConfig
if err := unmarshalMapSlice(params, &lc); err != nil { if err := backend.UnmarshalMapSlice(params, &lc); err != nil {
return nil, err return nil, err
} }
if err := lc.valid(); err != nil { if err := lc.valid(); err != nil {
...@@ -128,7 +131,7 @@ func newLDAPBackend(config *Config, params yaml.MapSlice) (*ldapBackend, error) ...@@ -128,7 +131,7 @@ func newLDAPBackend(config *Config, params yaml.MapSlice) (*ldapBackend, error)
// Read the bind password. // Read the bind password.
bindPw := lc.BindPw bindPw := lc.BindPw
if lc.BindPwFile != "" { if lc.BindPwFile != "" {
pwData, err := ioutil.ReadFile(lc.BindPwFile) pwData, err := ioutil.ReadFile(backend.ResolvePath(lc.BindPwFile, configDir))
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -151,9 +154,9 @@ func (b *ldapBackend) Close() { ...@@ -151,9 +154,9 @@ func (b *ldapBackend) Close() {
b.pool.Close() b.pool.Close()
} }
func (b *ldapBackend) NewServiceBackend(spec *BackendSpec) (serviceBackend, error) { func (b *ldapBackend) NewServiceBackend(spec *backend.Spec) (backend.ServiceBackend, error) {
var params ldapServiceParams var params ldapServiceParams
if err := unmarshalMapSlice(spec.Params, &params); err != nil { if err := backend.UnmarshalMapSlice(spec.Params, &params); err != nil {
return nil, err return nil, err
} }
return newLDAPServiceBackend(b.pool, &params) return newLDAPServiceBackend(b.pool, &params)
...@@ -224,7 +227,7 @@ func (b *ldapServiceBackend) searchRequest(username string) *ldap.SearchRequest ...@@ -224,7 +227,7 @@ func (b *ldapServiceBackend) searchRequest(username string) *ldap.SearchRequest
} }
// Build a User object from a LDAP response. // Build a User object from a LDAP response.
func (b *ldapServiceBackend) userFromResponse(username string, result *ldap.SearchResult) (*User, bool) { func (b *ldapServiceBackend) userFromResponse(username string, result *ldap.SearchResult) (*backend.User, bool) {
if len(result.Entries) < 1 { if len(result.Entries) < 1 {
return nil, false return nil, false
} }
...@@ -235,7 +238,7 @@ func (b *ldapServiceBackend) userFromResponse(username string, result *ldap.Sear ...@@ -235,7 +238,7 @@ func (b *ldapServiceBackend) userFromResponse(username string, result *ldap.Sear
// Apply the attribute map. We don't care if an attribute is // Apply the attribute map. We don't care if an attribute is
// not defined in the map, as the get* functions will silently // not defined in the map, as the get* functions will silently
// ignore an empty attribute name. // ignore an empty attribute name.
u := User{ u := backend.User{
Name: username, Name: username,
Email: getStringFromLDAPEntry(entry, b.attrs["email"]), Email: getStringFromLDAPEntry(entry, b.attrs["email"]),
Shard: getStringFromLDAPEntry(entry, b.attrs["shard"]), Shard: getStringFromLDAPEntry(entry, b.attrs["shard"]),
...@@ -248,7 +251,7 @@ func (b *ldapServiceBackend) userFromResponse(username string, result *ldap.Sear ...@@ -248,7 +251,7 @@ func (b *ldapServiceBackend) userFromResponse(username string, result *ldap.Sear
return &u, true return &u, true
} }
func (b *ldapServiceBackend) GetUser(ctx context.Context, name string) (*User, bool) { func (b *ldapServiceBackend) GetUser(ctx context.Context, name string) (*backend.User, bool) {
result, err := b.pool.Search(ctx, b.searchRequest(name)) result, err := b.pool.Search(ctx, b.searchRequest(name))
if err != nil { if err != nil {
log.Printf("LDAP error: %v", err) log.Printf("LDAP error: %v", err)
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"errors" "errors"
"log" "log"
ct "git.autistici.org/ai3/go-common/ldap/compositetypes"
"github.com/tstranex/u2f" "github.com/tstranex/u2f"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
...@@ -13,7 +14,7 @@ import ( ...@@ -13,7 +14,7 @@ import (
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
ct "git.autistici.org/ai3/go-common/ldap/compositetypes" "git.autistici.org/id/auth/backend"
) )
// Names for the known SQL queries. // Names for the known SQL queries.
...@@ -61,9 +62,10 @@ func compileStatements(db *sql.DB, queries map[string]string) (map[string]*sql.S ...@@ -61,9 +62,10 @@ func compileStatements(db *sql.DB, queries map[string]string) (map[string]*sql.S
return m, nil return m, nil
} }
func newSQLBackend(config *Config, params yaml.MapSlice) (*sqlBackend, error) { // New returns a new SQL backend.
func New(params yaml.MapSlice, _ string) (backend.UserBackend, error) {
var sc sqlConfig var sc sqlConfig
if err := unmarshalMapSlice(params, &sc); err != nil { if err := backend.UnmarshalMapSlice(params, &sc); err != nil {
return nil, err return nil, err
} }
if sc.Driver == "" { if sc.Driver == "" {
...@@ -84,9 +86,9 @@ func (b *sqlBackend) Close() { ...@@ -84,9 +86,9 @@ func (b *sqlBackend) Close() {
b.db.Close() b.db.Close()
} }
func (b *sqlBackend) NewServiceBackend(spec *BackendSpec) (serviceBackend, error) { func (b *sqlBackend) NewServiceBackend(spec *backend.Spec) (backend.ServiceBackend, error) {
var sc sqlServiceConfig var sc sqlServiceConfig
if err := unmarshalMapSlice(spec.Params, &sc); err != nil { if err := backend.UnmarshalMapSlice(spec.Params, &sc); err != nil {
return nil, err return nil, err
} }
return newSQLServiceBackend(b.db, &sc) return newSQLServiceBackend(b.db, &sc)
...@@ -111,14 +113,14 @@ func newSQLServiceBackend(db *sql.DB, sc *sqlServiceConfig) (*sqlServiceBackend, ...@@ -111,14 +113,14 @@ func newSQLServiceBackend(db *sql.DB, sc *sqlServiceConfig) (*sqlServiceBackend,
}, nil }, nil
} }
func (b *sqlServiceBackend) GetUser(ctx context.Context, name string) (*User, bool) { func (b *sqlServiceBackend) GetUser(ctx context.Context, name string) (*backend.User, bool) {
tx, err := b.db.Begin() tx, err := b.db.Begin()
if err != nil { if err != nil {
return nil, false return nil, false
} }
defer tx.Rollback() // nolint defer tx.Rollback() // nolint
user := User{Name: name} user := backend.User{Name: name}
// Use NullStrings for optional fields. // Use NullStrings for optional fields.
var nullableTOTP, nullableShard sql.NullString var nullableTOTP, nullableShard sql.NullString
...@@ -176,7 +178,7 @@ func (b *sqlServiceBackend) getUserU2FRegistrations(tx *sql.Tx, name string) ([] ...@@ -176,7 +178,7 @@ func (b *sqlServiceBackend) getUserU2FRegistrations(tx *sql.Tx, name string) ([]
return out, nil return out, nil
} }
func (b *sqlServiceBackend) getUserASPs(tx *sql.Tx, name string) ([]*AppSpecificPassword, error) { func (b *sqlServiceBackend) getUserASPs(tx *sql.Tx, name string) ([]*backend.AppSpecificPassword, error) {
stmt, ok := b.stmts[sqlQueryGetASP] stmt, ok := b.stmts[sqlQueryGetASP]
if !ok { if !ok {
return nil, nil return nil, nil
...@@ -187,9 +189,9 @@ func (b *sqlServiceBackend) getUserASPs(tx *sql.Tx, name string) ([]*AppSpecific ...@@ -187,9 +189,9 @@ func (b *sqlServiceBackend) getUserASPs(tx *sql.Tx, name string) ([]*AppSpecific
} }
defer rows.Close() defer rows.Close()
var out []*AppSpecificPassword var out []*backend.AppSpecificPassword
for rows.Next() { for rows.Next() {
var asp AppSpecificPassword var asp backend.AppSpecificPassword
if err := rows.Scan(&asp.Service, &asp.EncryptedPassword); err != nil { if err := rows.Scan(&asp.Service, &asp.EncryptedPassword); err != nil {
continue continue
} }
......
package backend
import (
"io/ioutil"
"path/filepath"
"gopkg.in/yaml.v2"
)
// Unmarshal a partially-parsed yaml.MapSlice.
func UnmarshalMapSlice(raw yaml.MapSlice, obj interface{}) error {
b, err := yaml.Marshal(raw)
if err != nil {
return err
}
return yaml.Unmarshal(b, obj)
}
// Load and unmarshal a YAML file.
func LoadYAML(path string, obj interface{}) error {
data, err := ioutil.ReadFile(path) // #nosec
if err != nil {
return err
}
return yaml.Unmarshal(data, obj)
}
// ResolvePath returns the path evaluated as relative to base.
func ResolvePath(path, base string) string {
if !filepath.IsAbs(path) {
path = filepath.Join(base, path)
}
return path
}
This diff is collapsed.
package server package server
import ( import (
"io/ioutil"
"log" "log"
"path/filepath" "path/filepath"
"sort" "sort"
"git.autistici.org/ai3/go-common/clientutil" "git.autistici.org/ai3/go-common/clientutil"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
)
// BackendSpec specifies backend-specific configuration for a service. "git.autistici.org/id/auth/backend"
type BackendSpec struct { )
BackendName string `yaml:"backend"`
Params yaml.MapSlice `yaml:"params"`
StaticGroups []string `yaml:"static_groups"`
}
// ServiceConfig defines the authentication backends for a service. // ServiceConfig defines the authentication backends for a service.
type ServiceConfig struct { type ServiceConfig struct {
BackendSpecs []*BackendSpec `yaml:"backends"` BackendSpecs []*backend.Spec `yaml:"backends"`
ChallengeResponse bool `yaml:"challenge_response"` ChallengeResponse bool `yaml:"challenge_response"`
Enforce2FA bool `yaml:"enforce_2fa"` Enforce2FA bool `yaml:"enforce_2fa"`
EnableLastLoginReporting bool `yaml:"enable_last_login_reporting"` EnableLastLoginReporting bool `yaml:"enable_last_login_reporting"`
EnableDeviceTracking bool `yaml:"enable_device_tracking"` EnableDeviceTracking bool `yaml:"enable_device_tracking"`
Ratelimits []string `yaml:"rate_limits"` Ratelimits []string `yaml:"rate_limits"`
} }
// Config for the authentication server. // Config for the authentication server.
...@@ -54,20 +48,11 @@ type Config struct { ...@@ -54,20 +48,11 @@ type Config struct {
path string path string
} }
// Load and unmarshal a YAML file.
func loadYAML(path string, obj interface{}) error {
data, err := ioutil.ReadFile(path) // #nosec
if err != nil {
return err
}
return yaml.Unmarshal(data, obj)
}
// Load a standalone service configuration: a YAML-encoded file that // Load a standalone service configuration: a YAML-encoded file that
// may contain one or more ServiceConfig definitions. // may contain one or more ServiceConfig definitions.
func loadStandaloneServiceConfig(path string) (map[string]*ServiceConfig, error) { func loadStandaloneServiceConfig(path string) (map[string]*ServiceConfig, error) {
var out map[string]*ServiceConfig var out map[string]*ServiceConfig
if err := loadYAML(path, &out); err != nil { if err := backend.LoadYAML(path, &out); err != nil {
return nil, err return nil, err
} }
return out, nil return out, nil
...@@ -77,7 +62,7 @@ func loadStandaloneServiceConfig(path string) (map[string]*ServiceConfig, error) ...@@ -77,7 +62,7 @@ func loadStandaloneServiceConfig(path string) (map[string]*ServiceConfig, error)
// may contain one or more ServiceConfig definitions. // may contain one or more ServiceConfig definitions.
func loadStandaloneBackendConfig(path string) (map[string]yaml.MapSlice, error) { func loadStandaloneBackendConfig(path string) (map[string]yaml.MapSlice, error) {
var out map[string]yaml.MapSlice var out map[string]yaml.MapSlice
if err := loadYAML(path, &out); err != nil { if err := backend.LoadYAML(path, &out); err != nil {
return nil, err return nil, err
} }
return out, nil return out, nil
...@@ -113,7 +98,7 @@ func LoadConfig(path string) (*Config, error) { ...@@ -113,7 +98,7 @@ func LoadConfig(path string) (*Config, error) {
"file": nil, "file": nil,
}, },
} }
if err := loadYAML(path, &config); err != nil { if err := backend.LoadYAML(path, &config); err != nil {
return nil, err return nil, err
} }
...@@ -148,12 +133,3 @@ func LoadConfig(path string) (*Config, error) { ...@@ -148,12 +133,3 @@ func LoadConfig(path string) (*Config, error) {
return &config, nil return &config, nil
} }
// Unmarshal a partially-parsed yaml.MapSlice.
func unmarshalMapSlice(raw yaml.MapSlice, obj interface{}) error {
b, err := yaml.Marshal(raw)
if err != nil {
return err
}
return yaml.Unmarshal(b, obj)
}
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"time" "time"
"git.autistici.org/ai3/go-common/clientutil" "git.autistici.org/ai3/go-common/clientutil"
"git.autistici.org/id/auth/backend"
"git.autistici.org/id/usermetadb/client" "git.autistici.org/id/usermetadb/client"
"git.autistici.org/id/auth" "git.autistici.org/id/auth"
...@@ -32,7 +33,7 @@ func newDeviceFilter(config *clientutil.BackendConfig) (*deviceFilter, error) { ...@@ -32,7 +33,7 @@ func newDeviceFilter(config *clientutil.BackendConfig) (*deviceFilter, error) {
return &deviceFilter{c}, nil return &deviceFilter{c}, nil
} }
func (f *deviceFilter) Filter(user *User, req *auth.Request, resp *auth.Response) *auth.Response { func (f *deviceFilter) Filter(user *backend.User, req *auth.Request, resp *auth.Response) *auth.Response {
// If there is no DeviceInfo, skip. // If there is no DeviceInfo, skip.
if req.DeviceInfo == nil { if req.DeviceInfo == nil {
return resp return resp
...@@ -64,7 +65,7 @@ func (f *deviceFilter) Filter(user *User, req *auth.Request, resp *auth.Response ...@@ -64,7 +65,7 @@ func (f *deviceFilter) Filter(user *User, req *auth.Request, resp *auth.Response
return resp return resp
} }
func (f *deviceFilter) sendNewDeviceEmail(user *User, dev *auth.DeviceInfo) error { func (f *deviceFilter) sendNewDeviceEmail(user *backend.User, dev *auth.DeviceInfo) error {
// TODO: Not implemented. // TODO: Not implemented.
log.Printf("new device for user %s: %+v", user.Name, dev) log.Printf("new device for user %s: %+v", user.Name, dev)
return nil return nil
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"git.autistici.org/ai3/go-common/clientutil" "git.autistici.org/ai3/go-common/clientutil"
"git.autistici.org/id/auth" "git.autistici.org/id/auth"
"git.autistici.org/id/auth/backend"
"git.autistici.org/id/usermetadb" "git.autistici.org/id/usermetadb"
"git.autistici.org/id/usermetadb/client" "git.autistici.org/id/usermetadb/client"
) )
...@@ -29,7 +30,7 @@ func newLastLoginFilter(config *clientutil.BackendConfig) (*lastloginFilter, err ...@@ -29,7 +30,7 @@ func newLastLoginFilter(config *clientutil.BackendConfig) (*lastloginFilter, err
var lastloginTimeout = 30 * time.Second var lastloginTimeout = 30 * time.Second
func (f *lastloginFilter) Filter(user *User, req *auth.Request, resp *auth.Response) *auth.Response { func (f *lastloginFilter) Filter(user *backend.User, req *auth.Request, resp *auth.Response) *auth.Response {
if resp.Status != auth.StatusOK { if resp.Status != auth.StatusOK {
return resp return resp
} }
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"time" "time"
"git.autistici.org/id/auth" "git.autistici.org/id/auth"
"git.autistici.org/id/auth/backend"
) )
// Try to use as little memory as possible for each entry: use a UNIX // Try to use as little memory as possible for each entry: use a UNIX
...@@ -149,15 +150,15 @@ func (b *Blacklist) Incr(key string) { ...@@ -149,15 +150,15 @@ func (b *Blacklist) Incr(key string) {
} }
// Function that extracts a key from a request. // Function that extracts a key from a request.
type ratelimitKeyFunc func(*User, *auth.Request) string type ratelimitKeyFunc func(*backend.User, *auth.Request) string
// Extract the username from the request. // Extract the username from the request.
func usernameKey(user *User, _ *auth.Request) string { func usernameKey(user *backend.User, _ *auth.Request) string {
return user.Name return user.Name
} }
// Extract the client IP address (if present) from the request. // Extract the client IP address (if present) from the request.
func ipAddrKey(_ *User, req *auth.Request) string { func ipAddrKey(_ *backend.User, req *auth.Request) string {
if req.DeviceInfo != nil { if req.DeviceInfo != nil {
return req.DeviceInfo.RemoteAddr return req.DeviceInfo.RemoteAddr
} }
...@@ -197,7 +198,7 @@ type authRatelimiterBase struct { ...@@ -197,7 +198,7 @@ type authRatelimiterBase struct {
keyFuncs []ratelimitKeyFunc keyFuncs []ratelimitKeyFunc
} }
func (r *authRatelimiterBase) key(user *User, req *auth.Request) string { func (r *authRatelimiterBase) key(user *backend.User, req *auth.Request) string {
if len(r.keyFuncs) == 1 { if len(r.keyFuncs) == 1 {
return r.keyFuncs[0](user, req) return r.keyFuncs[0](user, req)
} }
...@@ -226,7 +227,7 @@ func newAuthRatelimiter(config *authRatelimiterConfig) (*authRatelimiter, error) ...@@ -226,7 +227,7 @@ func newAuthRatelimiter(config *authRatelimiterConfig) (*authRatelimiter, error)
}, nil }, nil
} }
func (r *authRatelimiter) AllowIncr(user *User, req *auth.Request) bool { func (r *authRatelimiter) AllowIncr(user *backend.User, req *auth.Request) bool {
return r.rl.AllowIncr(r.key(user, req)) return r.rl.AllowIncr(r.key(user, req))
} }
...@@ -249,11 +250,11 @@ func newAuthBlacklist(config *authRatelimiterConfig) (*authBlacklist, error) { ...@@ -249,11 +250,11 @@ func newAuthBlacklist(config *authRatelimiterConfig) (*authBlacklist, error) {
}, nil }, nil
} }
func (b *authBlacklist) Allow(user *User, req *auth.Request) bool { func (b *authBlacklist) Allow(user *backend.User, req *auth.Request) bool {
return b.bl.Allow(b.key(user, req)) return b.bl.Allow(b.key(user, req))
} }
func (b *authBlacklist) Incr(user *User, req *auth.Request, resp *auth.Response) { func (b *authBlacklist) Incr(user *backend.User, req *auth.Request, resp *auth.Response) {
if b.onFailure && resp.Status == auth.StatusOK { if b.onFailure && resp.Status == auth.StatusOK {
return return
} }
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (