Commit e4a26106 authored by ale's avatar ale

Use ai3/go-common for TLS client code and HA transport

parent 6b6d2bb1
Pipeline #612 passed with stages
in 1 minute and 16 seconds
......@@ -2,7 +2,6 @@ package proxy
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
......@@ -10,6 +9,7 @@ import (
"net/http/httputil"
"net/url"
"git.autistici.org/ai3/go-common/clientutil"
"github.com/gorilla/mux"
"git.autistici.org/id/go-sso/httpsso"
......@@ -18,10 +18,9 @@ import (
// Backend defines a single-host HTTP proxy to a set of upstream
// backends.
type Backend struct {
Host string `yaml:"host"`
Upstream []string `yaml:"upstream"`
ClientTLSConfig *TLSConfig `yaml:"client_tls"`
//ServerTLSConfig *TLSConfig `yaml:"server_tls"`
Host string `yaml:"host"`
Upstream []string `yaml:"upstream"`
ClientTLSConfig *clientutil.TLSClientConfig `yaml:"client_tls"`
AllowedGroups []string `yaml:"allowed_groups"`
}
......@@ -42,51 +41,17 @@ func (b *Backend) newHandler(ssow *httpsso.SSOWrapper) (http.Handler, error) {
var tlsConfig *tls.Config
if b.ClientTLSConfig != nil {
var err error
tlsConfig, err = b.ClientTLSConfig.toClientConfig()
tlsConfig, err = b.ClientTLSConfig.TLSConfig()
if err != nil {
return nil, err
}
}
proxy.Transport = newTransport(b.Upstream, tlsConfig)
proxy.Transport = clientutil.NewTransport(b.Upstream, tlsConfig, nil)
h := ssow.Wrap(proxy, b.Host+"/", b.AllowedGroups)
return h, nil
}
// TLSConfig defines the TLS parameters for a client connection.
type TLSConfig struct {
Cert string `yaml:"cert"`
Key string `yaml:"key"`
CA string `yaml:"ca"`
}
func (c *TLSConfig) toClientConfig() (*tls.Config, error) {
cert, err := tls.LoadX509KeyPair(c.Cert, c.Key)
if err != nil {
return nil, err
}
cas, err := loadCA(c.CA)
if err != nil {
return nil, err
}
return &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: cas,
}, nil
}
func loadCA(path string) (*x509.CertPool, error) {
data, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
}
cas := x509.NewCertPool()
cas.AppendCertsFromPEM(data)
return cas, nil
}
// func buildServerTLSConfig(config *Configuration) (*tls.Config, error) {
// var certs []tls.Certificate
// for _, b := range config.Backends {
......
package proxy
import (
"crypto/tls"
"errors"
"log"
"math/rand"
"net"
"net/http"
"sort"
"sync"
)
func resolveIPs(hosts []string) []string {
var resolved []string
for _, hostport := range hosts {
host, port, err := net.SplitHostPort(hostport)
if err != nil {
log.Printf("error parsing %s: %v", hostport, err)
continue
}
hostIPs, err := net.LookupIP(host)
if err != nil {
log.Printf("error resolving %s: %v", host, err)
continue
}
for _, ip := range hostIPs {
resolved = append(resolved, net.JoinHostPort(ip.String(), port))
}
}
return resolved
}
type balancer struct {
mx sync.Mutex
ips []string
errs []uint64
}
func (b *balancer) incrError(index int) {
b.mx.Lock()
b.errs[index]++
b.mx.Unlock()
}
func (b *balancer) dial(network, addr string) (net.Conn, error) {
ips, err := b.pickIPs()
if err != nil {
return nil, err
}
for _, s := range ips {
conn, err := net.Dial(network, s.ip)
if err == nil {
return conn, nil
}
log.Printf("error connecting to %s: %v", s.ip, err)
b.incrError(s.index)
}
return nil, errors.New("all upstream connections failed")
}
type ipScore struct {
ip string
score int
index int
}
type ipScoreList []ipScore
func (l ipScoreList) Len() int { return len(l) }
func (l ipScoreList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
func (l ipScoreList) Less(i, j int) bool { return l[i].score < l[j].score }
func shuffleScores(scores []ipScore) {
for i, j := range rand.Perm(len(scores)) {
scores[i], scores[j] = scores[j], scores[i]
}
}
const minErrs = 3
func (b *balancer) pickIPs() ([]ipScore, error) {
b.mx.Lock()
scores := make([]ipScore, len(b.ips))
for i, ip := range b.ips {
score := 1
if b.errs[i] > minErrs {
score *= 10
}
scores[i] = ipScore{ip: ip, score: score, index: i}
}
b.mx.Unlock()
sort.Sort(ipScoreList(scores))
// Iterate through the sorted list, shuffling groups of
// elements that have identical scores.
curScore := scores[0].score
head := 0
for i := 1; i < len(scores); i++ {
if scores[i].score != curScore {
group := scores[head : i+1]
if len(group) > 1 {
shuffleScores(group)
}
head = i + 1
}
}
group := scores[head:]
if len(group) > 1 {
shuffleScores(group)
}
return scores, nil
}
func newTransport(backends []string, tlsConf *tls.Config) http.RoundTripper {
ips := resolveIPs(backends)
b := &balancer{
ips: ips,
errs: make([]uint64, len(ips)),
}
return &http.Transport{
Dial: b.dial,
TLSClientConfig: tlsConf,
}
}
package clientutil
import (
"errors"
"net"
"net/http"
"time"
"github.com/cenkalti/backoff"
)
// NewExponentialBackOff creates a backoff.ExponentialBackOff object
// with our own default values.
func NewExponentialBackOff() *backoff.ExponentialBackOff {
b := backoff.NewExponentialBackOff()
b.InitialInterval = 100 * time.Millisecond
//b.Multiplier = 1.4142
return b
}
// Retry operation op until it succeeds according to the backoff policy b.
func Retry(op backoff.Operation, b backoff.BackOff) error {
innerOp := func() error {
err := op()
if err == nil {
return err
}
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
return err
}
return backoff.Permanent(err)
}
return backoff.Retry(innerOp, b)
}
var errHTTPBackOff = errors.New("http status 503")
// RetryHTTPDo retries an HTTP request until it succeeds, according to
// the backoff policy b. It will retry on temporary network errors and
// upon receiving specific throttling HTTP errors (currently just
// status code 503).
func RetryHTTPDo(client *http.Client, req *http.Request, b backoff.BackOff) (*http.Response, error) {
var resp *http.Response
op := func() error {
// Clear up previous response if set.
if resp != nil {
resp.Body.Close()
}
var err error
resp, err = client.Do(req)
if err == nil && resp.StatusCode == 503 {
resp.Body.Close()
return errHTTPBackOff
}
return err
}
err := Retry(op, b)
return resp, err
}
package clientutil
import (
"crypto/tls"
common "git.autistici.org/ai3/go-common"
)
// TLSClientConfig defines the TLS parameters for a client connection
// that should use a client X509 certificate for authentication.
type TLSClientConfig struct {
Cert string `yaml:"cert"`
Key string `yaml:"key"`
CA string `yaml:"ca"`
}
// TLSConfig returns a tls.Config object with the current configuration.
func (c *TLSClientConfig) TLSConfig() (*tls.Config, error) {
cert, err := tls.LoadX509KeyPair(c.Cert, c.Key)
if err != nil {
return nil, err
}
tlsConf := &tls.Config{
Certificates: []tls.Certificate{cert},
}
if c.CA != "" {
cas, err := common.LoadCA(c.CA)
if err != nil {
return nil, err
}
tlsConf.RootCAs = cas
}
tlsConf.BuildNameToCertificate()
return tlsConf, nil
}
package clientutil
import (
"context"
"crypto/tls"
"errors"
"log"
"net"
"net/http"
"sync"
"time"
)
var errAllBackendsFailed = errors.New("all backends failed")
type dnsResolver struct{}
func (r *dnsResolver) ResolveIPs(hosts []string) []string {
var resolved []string
for _, hostport := range hosts {
host, port, err := net.SplitHostPort(hostport)
if err != nil {
log.Printf("error parsing %s: %v", hostport, err)
continue
}
hostIPs, err := net.LookupIP(host)
if err != nil {
log.Printf("error resolving %s: %v", host, err)
continue
}
for _, ip := range hostIPs {
resolved = append(resolved, net.JoinHostPort(ip.String(), port))
}
}
return resolved
}
var defaultResolver = &dnsResolver{}
type resolver interface {
ResolveIPs([]string) []string
}
// Balancer for HTTP connections. It will round-robin across available
// backends, trying to avoid ones that are erroring out, until one
// succeeds or they all fail.
//
// This object should not be used for load balancing of individual
// HTTP requests: once a new connection is established, requests will
// be sent over it until it errors out. It's meant to provide a
// *reliable* connection to a set of equivalent backends for HA
// purposes.
type balancer struct {
hosts []string
resolver resolver
stop chan bool
// List of currently valid (or untested) backends, and ones
// that errored out at least once.
mx sync.Mutex
addrs []string
ok map[string]bool
}
var backendUpdateInterval = 60 * time.Second
// Periodically update the list of available backends.
func (b *balancer) updateProc() {
tick := time.NewTicker(backendUpdateInterval)
for {
select {
case <-b.stop:
return
case <-tick.C:
resolved := b.resolver.ResolveIPs(b.hosts)
if len(resolved) > 0 {
b.mx.Lock()
b.addrs = resolved
b.mx.Unlock()
}
}
}
}
// Returns a list of all available backends, split into "good ones"
// (no errors seen since last successful connection) and "bad ones".
func (b *balancer) getBackends() ([]string, []string) {
b.mx.Lock()
defer b.mx.Unlock()
var good, bad []string
for _, addr := range b.addrs {
if ok := b.ok[addr]; ok {
good = append(good, addr)
} else {
bad = append(bad, addr)
}
}
return good, bad
}
func (b *balancer) notify(addr string, ok bool) {
b.mx.Lock()
b.ok[addr] = ok
b.mx.Unlock()
}
func netDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
timeout := 30 * time.Second
// Go < 1.9 does not have net.DialContext, reimplement it in
// terms of net.DialTimeout.
if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline)
}
return net.DialTimeout(network, addr, timeout)
}
func (b *balancer) dial(ctx context.Context, network, addr string) (net.Conn, error) {
// Start by attempting a connection on 'good' targets.
good, bad := b.getBackends()
for _, addr := range good {
// Go < 1.9 does not have DialContext, deal with it
conn, err := netDialContext(ctx, network, addr)
if err == nil {
return conn, nil
} else if err == context.Canceled {
return nil, err
}
b.notify(addr, false)
}
for _, addr := range bad {
conn, err := netDialContext(ctx, network, addr)
if err == nil {
b.notify(addr, true)
return conn, nil
} else if err == context.Canceled {
return nil, err
}
}
return nil, errAllBackendsFailed
}
// NewTransport returns a suitably configured http.RoundTripper that
// talks to a specific backend service. It performs discovery of
// available backends via DNS (using A or AAAA record lookups), tries
// to route traffic away from faulty backends.
//
// It will periodically attempt to rediscover new backends.
func NewTransport(backends []string, tlsConf *tls.Config, resolver resolver) http.RoundTripper {
if resolver == nil {
resolver = defaultResolver
}
addrs := resolver.ResolveIPs(backends)
b := &balancer{
hosts: backends,
resolver: resolver,
addrs: addrs,
ok: make(map[string]bool),
}
go b.updateProc()
return &http.Transport{
DialContext: b.dial,
TLSClientConfig: tlsConf,
}
}
package common
import (
"crypto/x509"
"io/ioutil"
)
// LoadCA loads a file containing CA certificates into a x509.CertPool.
func LoadCA(path string) (*x509.CertPool, error) {
data, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
}
cas := x509.NewCertPool()
cas.AppendCertsFromPEM(data)
return cas, nil
}
The MIT License (MIT)
Copyright (c) 2014 Cenk Altı
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# Exponential Backoff [![GoDoc][godoc image]][godoc] [![Build Status][travis image]][travis] [![Coverage Status][coveralls image]][coveralls]
This is a Go port of the exponential backoff algorithm from [Google's HTTP Client Library for Java][google-http-java-client].
[Exponential backoff][exponential backoff wiki]
is an algorithm that uses feedback to multiplicatively decrease the rate of some process,
in order to gradually find an acceptable rate.
The retries exponentially increase and stop increasing when a certain threshold is met.
## Usage
See https://godoc.org/github.com/cenkalti/backoff#pkg-examples
## Contributing
* I would like to keep this library as small as possible.
* Please don't send a PR without opening an issue and discussing it first.
* If proposed change is not a common use case, I will probably not accept it.
[godoc]: https://godoc.org/github.com/cenkalti/backoff
[godoc image]: https://godoc.org/github.com/cenkalti/backoff?status.png
[travis]: https://travis-ci.org/cenkalti/backoff
[travis image]: https://travis-ci.org/cenkalti/backoff.png?branch=master
[coveralls]: https://coveralls.io/github/cenkalti/backoff?branch=master
[coveralls image]: https://coveralls.io/repos/github/cenkalti/backoff/badge.svg?branch=master
[google-http-java-client]: https://github.com/google/google-http-java-client
[exponential backoff wiki]: http://en.wikipedia.org/wiki/Exponential_backoff
[advanced example]: https://godoc.org/github.com/cenkalti/backoff#example_
// Package backoff implements backoff algorithms for retrying operations.
//
// Use Retry function for retrying operations that may fail.
// If Retry does not meet your needs,
// copy/paste the function into your project and modify as you wish.
//
// There is also Ticker type similar to time.Ticker.
// You can use it if you need to work with channels.
//
// See Examples section below for usage examples.
package backoff
import "time"
// BackOff is a backoff policy for retrying an operation.
type BackOff interface {
// NextBackOff returns the duration to wait before retrying the operation,
// or backoff. Stop to indicate that no more retries should be made.
//
// Example usage:
//
// duration := backoff.NextBackOff();
// if (duration == backoff.Stop) {
// // Do not retry operation.
// } else {
// // Sleep for duration and retry operation.
// }
//
NextBackOff() time.Duration
// Reset to initial state.
Reset()
}
// Stop indicates that no more retries should be made for use in NextBackOff().
const Stop time.Duration = -1
// ZeroBackOff is a fixed backoff policy whose backoff time is always zero,
// meaning that the operation is retried immediately without waiting, indefinitely.
type ZeroBackOff struct{}
func (b *ZeroBackOff) Reset() {}
func (b *ZeroBackOff) NextBackOff() time.Duration { return 0 }
// StopBackOff is a fixed backoff policy that always returns backoff.Stop for
// NextBackOff(), meaning that the operation should never be retried.
type StopBackOff struct{}
func (b *StopBackOff) Reset() {}
func (b *StopBackOff) NextBackOff() time.Duration { return Stop }
// ConstantBackOff is a backoff policy that always returns the same backoff delay.
// This is in contrast to an exponential backoff policy,
// which returns a delay that grows longer as you call NextBackOff() over and over again.
type ConstantBackOff struct {
Interval time.Duration
}
func (b *ConstantBackOff) Reset() {}
func (b *ConstantBackOff) NextBackOff() time.Duration { return b.Interval }
func NewConstantBackOff(d time.Duration) *ConstantBackOff {
return &ConstantBackOff{Interval: d}
}
package backoff
import (
"time"
"golang.org/x/net/context"
)