Commit 7b1b6837 authored by ale's avatar ale
Browse files

Use the ldap connection pool package from ai3/go-common

parent 451a4320
......@@ -4,10 +4,10 @@ import (
"context"
"errors"
"io/ioutil"
"net/url"
"strings"
"time"
ldaputil "git.autistici.org/ai3/go-common/ldap"
"gopkg.in/ldap.v2"
)
......@@ -16,6 +16,8 @@ import (
type LDAPServiceConfig struct {
// SearchBase, SearchFilter and Scope define parameters for
// the LDAP search. The search should return a single object.
// SearchFilter should contain the string "%s", which will be
// replaced with the username before performing a query.
SearchBase string `yaml:"search_base"`
SearchFilter string `yaml:"search_filter"`
Scope string `yaml:"scope"`
......@@ -162,89 +164,7 @@ func (c *LDAPConfig) Valid() error {
type ldapBackend struct {
config *LDAPConfig
pool *ldapConnectionPool
}
type ldapConnectionPool struct {
network string
target string
bindDN string
bindPw string
c chan *ldap.Conn
}
func (p *ldapConnectionPool) connect() (*ldap.Conn, error) {
conn, err := ldap.Dial(p.network, p.target)
if err != nil {
return nil, err
}
if err = conn.Bind(p.bindDN, p.bindPw); err != nil {
conn.Close()
return nil, err
}
return conn, err
}
func (p *ldapConnectionPool) get(ctx context.Context) (*ldap.Conn, error) {
select {
case conn := <-p.c:
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
func (p *ldapConnectionPool) release(conn *ldap.Conn, err error) {
// We assume that if we get an ErrorNetwork, then we need to reconnect.
for err != nil && ldap.IsErrorWithCode(err, ldap.ErrorNetwork) {
if conn != nil {
conn.Close()
}
conn, err = p.connect()
}
p.c <- conn
}
func (p *ldapConnectionPool) Close() {}
func newLDAPConnectionPool(uri, bindDN, bindPw string, numConns int) (*ldapConnectionPool, error) {
u, err := url.Parse(uri)
if err != nil {
return nil, err
}
network := "tcp"
target := "localhost:389"
switch u.Scheme {
case "ldap":
if u.Host != "" {
target = u.Host
}
case "ldapi":
target = u.Path
default:
return nil, errors.New("unsupported scheme")
}
p := &ldapConnectionPool{
c: make(chan *ldap.Conn, numConns),
network: network,
target: target,
bindDN: bindDN,
bindPw: bindPw,
}
for i := 0; i < numConns; i++ {
conn, err := p.connect()
if err != nil {
p.Close()
return nil, err
}
p.c <- conn
}
return p, nil
pool *ldaputil.ConnectionPool
}
func newLDAPBackend(config *Config) (*ldapBackend, error) {
......@@ -274,7 +194,7 @@ func newLDAPBackend(config *Config) (*ldapBackend, error) {
}
// Initialize the connection pool.
pool, err := newLDAPConnectionPool(config.LDAPConfig.URI, config.LDAPConfig.BindDN, strings.TrimSpace(string(bindPw)), 5)
pool, err := ldaputil.NewConnectionPool(config.LDAPConfig.URI, config.LDAPConfig.BindDN, strings.TrimSpace(string(bindPw)), 5)
if err != nil {
return nil, err
}
......@@ -295,7 +215,7 @@ func (b *ldapBackend) GetUser(ctx context.Context, spec *BackendSpec, name strin
return nil, false
}
conn, err := b.pool.get(ctx)
conn, err := b.pool.Get(ctx)
if err != nil {
return nil, false
}
......@@ -306,7 +226,7 @@ func (b *ldapBackend) GetUser(ctx context.Context, spec *BackendSpec, name strin
}
result, err := conn.Search(serviceConfig.searchRequest(name))
b.pool.release(conn, err)
b.pool.Release(conn, err)
if err != nil {
return nil, false
}
......
package clientutil
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
)
// DoJSONHTTPRequest makes an HTTP POST request to the specified uri,
// with a JSON-encoded request body. It will attempt to decode the
// response body as JSON.
func DoJSONHTTPRequest(ctx context.Context, client *http.Client, uri string, req, resp interface{}) error {
data, err := json.Marshal(req)
if err != nil {
return err
}
httpReq, err := http.NewRequest("POST", uri, bytes.NewReader(data))
if err != nil {
return err
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq = httpReq.WithContext(ctx)
httpResp, err := RetryHTTPDo(client, httpReq, NewExponentialBackOff())
if err != nil {
return err
}
defer httpResp.Body.Close()
if httpResp.StatusCode != 200 {
return fmt.Errorf("HTTP status %d", httpResp.StatusCode)
}
if httpResp.Header.Get("Content-Type") != "application/json" {
return errors.New("not a JSON response")
}
if resp == nil {
return nil
}
return json.NewDecoder(httpResp.Body).Decode(resp)
}
package clientutil
import (
"errors"
"net"
"net/http"
"time"
"github.com/cenkalti/backoff"
)
// NewExponentialBackOff creates a backoff.ExponentialBackOff object
// with our own default values.
func NewExponentialBackOff() *backoff.ExponentialBackOff {
b := backoff.NewExponentialBackOff()
b.InitialInterval = 100 * time.Millisecond
//b.Multiplier = 1.4142
return b
}
// Retry operation op until it succeeds according to the backoff policy b.
func Retry(op backoff.Operation, b backoff.BackOff) error {
innerOp := func() error {
err := op()
if err == nil {
return err
}
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
return err
}
return backoff.Permanent(err)
}
return backoff.Retry(innerOp, b)
}
var errHTTPBackOff = errors.New("temporary http error")
func isStatusTemporary(code int) bool {
switch code {
case http.StatusTooManyRequests, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
return true
default:
return false
}
}
// RetryHTTPDo retries an HTTP request until it succeeds, according to
// the backoff policy b. It will retry on temporary network errors and
// upon receiving specific temporary HTTP errors.
func RetryHTTPDo(client *http.Client, req *http.Request, b backoff.BackOff) (*http.Response, error) {
var resp *http.Response
op := func() error {
// Clear up previous response if set.
if resp != nil {
resp.Body.Close()
}
var err error
resp, err = client.Do(req)
if err == nil && isStatusTemporary(resp.StatusCode) {
resp.Body.Close()
return errHTTPBackOff
}
return err
}
err := Retry(op, b)
return resp, err
}
package clientutil
import (
"crypto/tls"
common "git.autistici.org/ai3/go-common"
)
// TLSClientConfig defines the TLS parameters for a client connection
// that should use a client X509 certificate for authentication.
type TLSClientConfig struct {
Cert string `yaml:"cert"`
Key string `yaml:"key"`
CA string `yaml:"ca"`
}
// TLSConfig returns a tls.Config object with the current configuration.
func (c *TLSClientConfig) TLSConfig() (*tls.Config, error) {
cert, err := tls.LoadX509KeyPair(c.Cert, c.Key)
if err != nil {
return nil, err
}
tlsConf := &tls.Config{
Certificates: []tls.Certificate{cert},
}
if c.CA != "" {
cas, err := common.LoadCA(c.CA)
if err != nil {
return nil, err
}
tlsConf.RootCAs = cas
}
tlsConf.BuildNameToCertificate()
return tlsConf, nil
}
package clientutil
import (
"context"
"crypto/tls"
"errors"
"log"
"net"
"net/http"
"sync"
"time"
)
var errAllBackendsFailed = errors.New("all backends failed")
type dnsResolver struct{}
func (r *dnsResolver) ResolveIPs(hosts []string) []string {
var resolved []string
for _, hostport := range hosts {
host, port, err := net.SplitHostPort(hostport)
if err != nil {
log.Printf("error parsing %s: %v", hostport, err)
continue
}
hostIPs, err := net.LookupIP(host)
if err != nil {
log.Printf("error resolving %s: %v", host, err)
continue
}
for _, ip := range hostIPs {
resolved = append(resolved, net.JoinHostPort(ip.String(), port))
}
}
return resolved
}
var defaultResolver = &dnsResolver{}
type resolver interface {
ResolveIPs([]string) []string
}
// Balancer for HTTP connections. It will round-robin across available
// backends, trying to avoid ones that are erroring out, until one
// succeeds or they all fail.
//
// This object should not be used for load balancing of individual
// HTTP requests: once a new connection is established, requests will
// be sent over it until it errors out. It's meant to provide a
// *reliable* connection to a set of equivalent backends for HA
// purposes.
type balancer struct {
hosts []string
resolver resolver
stop chan bool
// List of currently valid (or untested) backends, and ones
// that errored out at least once.
mx sync.Mutex
addrs []string
ok map[string]bool
}
var backendUpdateInterval = 60 * time.Second
// Periodically update the list of available backends.
func (b *balancer) updateProc() {
tick := time.NewTicker(backendUpdateInterval)
for {
select {
case <-b.stop:
return
case <-tick.C:
resolved := b.resolver.ResolveIPs(b.hosts)
if len(resolved) > 0 {
b.mx.Lock()
b.addrs = resolved
b.mx.Unlock()
}
}
}
}
// Returns a list of all available backends, split into "good ones"
// (no errors seen since last successful connection) and "bad ones".
func (b *balancer) getBackends() ([]string, []string) {
b.mx.Lock()
defer b.mx.Unlock()
var good, bad []string
for _, addr := range b.addrs {
if ok := b.ok[addr]; ok {
good = append(good, addr)
} else {
bad = append(bad, addr)
}
}
return good, bad
}
func (b *balancer) notify(addr string, ok bool) {
b.mx.Lock()
b.ok[addr] = ok
b.mx.Unlock()
}
func netDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
timeout := 30 * time.Second
// Go < 1.9 does not have net.DialContext, reimplement it in
// terms of net.DialTimeout.
if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline)
}
return net.DialTimeout(network, addr, timeout)
}
func (b *balancer) dial(ctx context.Context, network, addr string) (net.Conn, error) {
// Start by attempting a connection on 'good' targets.
good, bad := b.getBackends()
for _, addr := range good {
// Go < 1.9 does not have DialContext, deal with it
conn, err := netDialContext(ctx, network, addr)
if err == nil {
return conn, nil
} else if err == context.Canceled {
return nil, err
}
b.notify(addr, false)
}
for _, addr := range bad {
conn, err := netDialContext(ctx, network, addr)
if err == nil {
b.notify(addr, true)
return conn, nil
} else if err == context.Canceled {
return nil, err
}
}
return nil, errAllBackendsFailed
}
// NewTransport returns a suitably configured http.RoundTripper that
// talks to a specific backend service. It performs discovery of
// available backends via DNS (using A or AAAA record lookups), tries
// to route traffic away from faulty backends.
//
// It will periodically attempt to rediscover new backends.
func NewTransport(backends []string, tlsConf *tls.Config, resolver resolver) http.RoundTripper {
if resolver == nil {
resolver = defaultResolver
}
addrs := resolver.ResolveIPs(backends)
b := &balancer{
hosts: backends,
resolver: resolver,
addrs: addrs,
ok: make(map[string]bool),
}
go b.updateProc()
return &http.Transport{
DialContext: b.dial,
TLSClientConfig: tlsConf,
}
}
package ldaputil
import (
"context"
"errors"
"net/url"
"gopkg.in/ldap.v2"
)
// ConnectionPool provides a goroutine-safe pool of long-lived LDAP
// connections that will reconnect on errors.
type ConnectionPool struct {
network string
addr string
bindDN string
bindPw string
c chan *ldap.Conn
}
func (p *ConnectionPool) connect() (*ldap.Conn, error) {
conn, err := ldap.Dial(p.network, p.addr)
if err != nil {
return nil, err
}
if err = conn.Bind(p.bindDN, p.bindPw); err != nil {
conn.Close()
return nil, err
}
return conn, err
}
// Get a fresh connection from the pool.
func (p *ConnectionPool) Get(ctx context.Context) (*ldap.Conn, error) {
select {
case conn := <-p.c:
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
// Release a used connection onto the pool.
func (p *ConnectionPool) Release(conn *ldap.Conn, err error) {
// We assume that if we get an ErrorNetwork, then we need to reconnect.
for err != nil && ldap.IsErrorWithCode(err, ldap.ErrorNetwork) {
if conn != nil {
conn.Close()
}
conn, err = p.connect()
}
p.c <- conn
}
// Close all connections. Not implemented yet.
func (p *ConnectionPool) Close() {}
// Parse a LDAP URI into network and address strings suitable for
// ldap.Dial.
func parseLDAPURI(uri string) (string, string, error) {
u, err := url.Parse(uri)
if err != nil {
return "", "", err
}
network := "tcp"
addr := "localhost:389"
switch u.Scheme {
case "ldap":
if u.Host != "" {
addr = u.Host
}
case "ldapi":
network = "unix"
addr = u.Path
default:
return "", "", errors.New("unsupported scheme")
}
return network, addr, nil
}
// NewConnectionPool creates a new pool of LDAP connections to the
// specified server, using the provided bind credentials. The pool
// will contain numConns connections.
func NewConnectionPool(uri, bindDN, bindPw string, numConns int) (*ConnectionPool, error) {
network, addr, err := parseLDAPURI(uri)
if err != nil {
return nil, err
}
p := &ConnectionPool{
c: make(chan *ldap.Conn, numConns),
network: network,
addr: addr,
bindDN: bindDN,
bindPw: bindPw,
}
for i := 0; i < numConns; i++ {
conn, err := p.connect()
if err != nil {
p.Close()
return nil, err
}
p.c <- conn
}
return p, nil
}
package common
import (
"crypto/x509"
"io/ioutil"
)
// LoadCA loads a file containing CA certificates into a x509.CertPool.
func LoadCA(path string) (*x509.CertPool, error) {
data, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
}
cas := x509.NewCertPool()
cas.AppendCertsFromPEM(data)
return cas, nil
}
This diff is collapsed.
usermetadb
==========
The *User Metadata Database* (`usermetadb`) stores long-term information
about user access patterns in order to detect anomalous behavior and