Commit 21484927 authored by ale's avatar ale

Add a short-term storage for OTP token replay protection

Backed by memcache, like the U2F challenge storage. The code has been
refactored to abstract the common replicated memcache client away.
parent fc2ad88c
Pipeline #1345 passed with stage
in 18 seconds
......@@ -70,6 +70,13 @@ type UserBackend interface {
GetUser(context.Context, *BackendSpec, string) (*User, bool)
}
// OTPShortTermStorage keeps track of recently used OTP tokens for a
// short period of time, so that the same token can not be accepted
// twice (replay protection).
type OTPShortTermStorage interface {
	// AddToken records that the given user has consumed the given token.
	AddToken(username, token string) error
	// HasToken reports whether the user/token pair was recently recorded.
	HasToken(username, token string) bool
}
// U2FShortTermStorage stores short-term u2f challenges.
type U2FShortTermStorage interface {
SetUserChallenge(string, *u2f.Challenge) error
......@@ -199,7 +206,7 @@ func (c *ServiceConfig) notifyBlacklists(user *User, req *auth.Request, resp *au
}
}
type u2fShortTermStorageConfig struct {
// shortTermStorageConfig holds the configuration for a memcache-backed
// short-term storage; it is shared by the U2F challenge storage and the
// OTP replay-protection storage.
type shortTermStorageConfig struct {
	Servers []string `yaml:"memcache_servers"`
}
......@@ -225,7 +232,11 @@ type Config struct {
// Configuration for the U2F short-term challenge storage
// (backed by memcached).
U2FShortTerm *u2fShortTermStorageConfig `yaml:"u2f_short_term_storage"`
U2FShortTerm *shortTermStorageConfig `yaml:"u2f_short_term_storage"`
// Configuration for the OTP short-term replay protection
// storage (backed by memcached).
OTPShortTerm *shortTermStorageConfig `yaml:"otp_short_term_storage"`
// Runtime versions of the above. These objects are shared by
// all services, as they contain the actual map data.
......@@ -376,6 +387,7 @@ type Server struct {
backends []UserBackend
config *Config
u2fShortTerm U2FShortTermStorage
otpShortTerm OTPShortTermStorage
}
func newError() *auth.Response {
......@@ -395,6 +407,11 @@ func NewServer(config *Config) (*Server, error) {
if config.U2FShortTerm != nil {
s.u2fShortTerm = newMemcacheU2FStorage(config.U2FShortTerm.Servers)
}
if config.OTPShortTerm != nil {
s.otpShortTerm = newMemcacheOTPStorage(config.OTPShortTerm.Servers)
} else {
s.otpShortTerm = &nullOTPStorage{}
}
for _, name := range config.EnabledBackends {
var b UserBackend
......@@ -563,7 +580,11 @@ func (s *Server) authenticateUserWith2FA(user *User, req *auth.Request) (*auth.R
}
return nil, errors.New("bad U2F response")
case req.OTP != "":
if user.HasOTP() && checkOTP(req.OTP, user.TOTPSecret) {
if user.HasOTP() && s.checkOTP(user, req.OTP, user.TOTPSecret) {
// Save the token for replay protection.
if err := s.otpShortTerm.AddToken(user.Name, req.OTP); err != nil {
log.Printf("error saving OTP token to short-term storage: %v", err)
}
return newOK(), nil
}
return nil, errors.New("bad OTP")
......@@ -616,6 +637,12 @@ func checkPassword(password, hash []byte) bool {
return pwhash.ComparePassword(string(hash), string(password))
}
func checkOTP(otp, secret string) bool {
// checkOTP validates an OTP token for the given user. Tokens that
// appear in the short-TTL storage have already been used recently and
// are rejected outright (replay protection).
func (s *Server) checkOTP(user *User, otp, secret string) bool {
	if !s.otpShortTerm.HasToken(user.Name, otp) {
		// First sighting within the TTL window: run normal TOTP
		// validation.
		return totp.Validate(otp, secret)
	}
	log.Printf("replay protection triggered for %s", user.Name)
	return false
}
package server
import (
"errors"
"sync"
"github.com/bradfitz/gomemcache/memcache"
)
// memcacheReplicatedClient fans out memcache operations to a set of
// independent servers: writes go to all of them (writeAll), reads are
// satisfied by the first server that answers (readAny).
type memcacheReplicatedClient struct {
	caches []*memcache.Client
}
// newMemcacheReplicatedClient builds a replicated client with one
// memcache connection per server address.
func newMemcacheReplicatedClient(servers []string) *memcacheReplicatedClient {
	m := new(memcacheReplicatedClient)
	for _, addr := range servers {
		client := memcache.New(addr)
		client.Timeout = u2fClientTimeout
		client.MaxIdleConns = u2fClientMaxIdleConns
		m.caches = append(m.caches, client)
	}
	return m
}
// writeAll stores key/value on every configured memcache server in
// parallel, with the given TTL in seconds. The write is considered
// successful if at least one server accepted it.
func (m *memcacheReplicatedClient) writeAll(key string, value []byte, ttl int) error {
	item := &memcache.Item{
		Key:        key,
		Value:      value,
		Expiration: int32(ttl),
	}

	// Fire one Set per server; the channel is buffered so no sender
	// can block, and every result is collected below.
	results := make(chan error, len(m.caches))
	defer close(results)
	for _, client := range m.caches {
		go func(client *memcache.Client) {
			results <- client.Set(item)
		}(client)
	}

	succeeded := false
	for range m.caches {
		if <-results == nil {
			succeeded = true
		}
	}
	if succeeded {
		return nil
	}
	return errors.New("all memcache servers failed")
}
// readAny fetches key from all memcache servers in parallel and
// returns the first successful result, or false if no server had it.
func (m *memcacheReplicatedClient) readAny(key string) ([]byte, bool) {
	// The memcache API takes no Context, so pending Gets can't be
	// canceled once a result arrives: they keep running in the
	// background and their results are dropped by the non-blocking
	// send below. A separate goroutine closes the channel once every
	// Get has finished, so the channel is never leaked.
	results := make(chan []byte, 1)
	var pending sync.WaitGroup
	for _, client := range m.caches {
		pending.Add(1)
		go func(client *memcache.Client) {
			defer pending.Done()
			item, err := client.Get(key)
			if err != nil {
				return
			}
			// Deliver the first value; later ones are discarded.
			select {
			case results <- item.Value:
			default:
			}
		}(client)
	}
	go func() {
		pending.Wait()
		close(results)
	}()

	value := <-results
	if value == nil {
		// Channel closed without a result (or a nil value): miss.
		return nil, false
	}
	return value, true
}
package server
import (
	"crypto/sha256"
	"fmt"
)
// otpReplayProtectionTTL is how long (in seconds) used tokens are
// remembered. NOTE(review): presumably this should cover at least the
// TOTP validity window — confirm against the validator settings.
const otpReplayProtectionTTL = 300

// memcacheOTPStorage implements OTPShortTermStorage on top of a
// replicated memcache client.
type memcacheOTPStorage struct {
	*memcacheReplicatedClient
}
// newMemcacheOTPStorage creates an OTP short-term storage backed by
// the given memcache servers.
func newMemcacheOTPStorage(servers []string) *memcacheOTPStorage {
	replicated := newMemcacheReplicatedClient(servers)
	return &memcacheOTPStorage{replicated}
}
// AddToken records the (username, token) pair so that later HasToken
// calls can detect a replayed OTP within the TTL window.
func (m *memcacheOTPStorage) AddToken(username, token string) error {
	key := otpMemcacheKey(username, token)
	// The stored value is irrelevant: only the key's presence matters.
	return m.writeAll(key, []byte{1}, otpReplayProtectionTTL)
}
// HasToken reports whether the (username, token) pair was recorded
// recently, i.e. whether accepting the token would be a replay.
func (m *memcacheOTPStorage) HasToken(username, token string) bool {
	key := otpMemcacheKey(username, token)
	_, found := m.readAny(key)
	return found
}
// otpMemcacheKey returns the memcache key used to track the given
// user/token pair.
//
// The pair is hashed so that the resulting key is always legal for
// memcache (bounded length, no whitespace or control characters)
// regardless of the characters in the username. The username is
// length-prefixed before hashing so that two distinct (username,
// token) pairs can never produce the same key — with the previous
// plain "otp/user/token" format, ("a/b", "c") and ("a", "b/c")
// collided.
func otpMemcacheKey(username, token string) string {
	h := sha256.New()
	fmt.Fprintf(h, "%d/%s/%s", len(username), username, token)
	return fmt.Sprintf("otp/%x", h.Sum(nil))
}
// nullOTPStorage is the no-op implementation of OTPShortTermStorage,
// used when no memcache servers are configured. It remembers nothing,
// so replay protection is effectively disabled.
type nullOTPStorage struct{}

// AddToken discards the token and always succeeds.
func (n *nullOTPStorage) AddToken(username, token string) error { return nil }

// HasToken always reports the token as unseen.
func (n *nullOTPStorage) HasToken(username, token string) bool { return false }
......@@ -3,18 +3,15 @@ package server
import (
"bytes"
"encoding/gob"
"errors"
"sync"
"time"
"github.com/bradfitz/gomemcache/memcache"
"github.com/tstranex/u2f"
)
var (
u2fClientMaxIdleConns = 5
u2fClientTimeout = 500 * time.Millisecond
u2fCacheExpirationSeconds int32 = 300
u2fClientMaxIdleConns = 5
u2fClientTimeout = 500 * time.Millisecond
u2fCacheExpirationSeconds = 300
)
func init() {
......@@ -31,18 +28,11 @@ func init() {
// servers in parallel.
//
type memcacheU2FStorage struct {
caches []*memcache.Client
*memcacheReplicatedClient
}
func newMemcacheU2FStorage(servers []string) *memcacheU2FStorage {
var m memcacheU2FStorage
for _, s := range servers {
c := memcache.New(s)
c.Timeout = u2fClientTimeout
c.MaxIdleConns = u2fClientMaxIdleConns
m.caches = append(m.caches, c)
}
return &m
return &memcacheU2FStorage{newMemcacheReplicatedClient(servers)}
}
func (m *memcacheU2FStorage) SetUserChallenge(user string, chal *u2f.Challenge) error {
......@@ -50,69 +40,15 @@ func (m *memcacheU2FStorage) SetUserChallenge(user string, chal *u2f.Challenge)
if err != nil {
return err
}
item := &memcache.Item{
Key: u2fChallengeKey(user),
Value: data,
Expiration: u2fCacheExpirationSeconds,
}
// Write to the memcache servers. At least one write must succeed.
ch := make(chan error, len(m.caches))
defer close(ch)
for _, c := range m.caches {
go func(c *memcache.Client) {
ch <- c.Set(item)
}(c)
}
var ok bool
for i := 0; i < len(m.caches); i++ {
if err := <-ch; err == nil {
ok = true
}
}
if !ok {
return errors.New("all memcache servers failed")
}
return nil
return m.writeAll(u2fChallengeKey(user), data, u2fCacheExpirationSeconds)
}
func (m *memcacheU2FStorage) GetUserChallenge(user string) (*u2f.Challenge, bool) {
// Run all reads in parallel, return the first non-error result.
//
// This would be better if the memcache API took a Context, so
// we could cancel all pending calls as soon as a result is
// received. This way, we keep them running in the background,
// ignore their results, and fire a goroutine to avoid leaking
// the result channel.
ch := make(chan *u2f.Challenge, 1)
var wg sync.WaitGroup
for _, c := range m.caches {
wg.Add(1)
go func(c *memcache.Client) {
defer wg.Done()
item, err := c.Get(u2fChallengeKey(user))
if err != nil {
return
}
chal, _ := deserializeU2FChallenge(item.Value) // nolint
select {
case ch <- chal:
default:
}
}(c)
}
go func() {
wg.Wait()
close(ch)
}()
chal := <-ch
if chal == nil {
value, ok := m.readAny(u2fChallengeKey(user))
if !ok {
return nil, false
}
chal, _ := deserializeU2FChallenge(value)
return chal, true
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment