Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • ai3/tools/acmeserver
  • godog/acmeserver
  • svp-bot/acmeserver
3 results
Show changes
Showing
with 1101 additions and 398 deletions
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"extends": [
"config:base"
]
}
package acmeserver
import (
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"crypto/x509/pkix"
	"errors"
	"log"
	"math/big"
	"time"
)
// selfSignedGenerator is a CertGenerator that carries no state of its
// own.
type selfSignedGenerator struct{}

// NewSelfSignedCertGenerator builds a CertGenerator that answers every
// request with a freshly created self-signed certificate. It is mainly
// meant for integration tests that run acmeserver as a functional
// component.
func NewSelfSignedCertGenerator() CertGenerator {
	return new(selfSignedGenerator)
}
// GetCertificate creates a self-signed certificate for the names in c,
// valid for one year starting now. priv is used both as the certified
// key and as the signing key. The first name becomes the subject
// Common Name, and all names are set as DNS names.
//
// It returns the DER-encoded certificate (a chain of length one)
// together with its parsed form. An error is returned when c holds no
// names, or when creating or parsing the certificate fails.
func (g *selfSignedGenerator) GetCertificate(_ context.Context, priv crypto.Signer, c *certConfig) ([][]byte, *x509.Certificate, error) {
	// Guard against indexing an empty name list below.
	if c == nil || len(c.Names) == 0 {
		return nil, nil, errors.New("no names specified for certificate")
	}

	notBefore := time.Now()
	notAfter := notBefore.AddDate(1, 0, 0)

	// Random 128-bit serial number.
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		return nil, nil, err
	}

	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			CommonName: c.Names[0],
		},
		DNSNames:              c.Names,
		NotBefore:             notBefore,
		NotAfter:              notAfter,
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
	}

	// Ask the signer itself for its public key, so that any
	// crypto.Signer works, not only bare RSA/ECDSA private keys.
	der, err := x509.CreateCertificate(rand.Reader, &template, &template, priv.Public(), priv)
	if err != nil {
		return nil, nil, err
	}
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		return nil, nil, err
	}

	log.Printf("created certificate for %s", c.Names[0])
	return [][]byte{der}, cert, nil
}
// publicKey returns a pointer to the public half of an RSA or ECDSA
// private key. Any other input yields nil.
func publicKey(priv interface{}) interface{} {
	if rsaKey, ok := priv.(*rsa.PrivateKey); ok {
		return &rsaKey.PublicKey
	}
	if ecKey, ok := priv.(*ecdsa.PrivateKey); ok {
		return &ecKey.PublicKey
	}
	return nil
}
......@@ -16,8 +16,7 @@ import (
"strings"
"time"
"git.autistici.org/ai3/go-common/clientutil"
"git.autistici.org/ai3/replds"
"git.autistici.org/ai3/tools/replds"
)
type dirStorage struct {
......@@ -26,7 +25,7 @@ type dirStorage struct {
func (d *dirStorage) GetCert(cn string) ([][]byte, crypto.Signer, error) {
certPath := filepath.Join(d.root, cn, "fullchain.pem")
keyPath := filepath.Join(d.root, cn, "private_key.pem")
keyPath := filepath.Join(d.root, cn, "privkey.pem")
der, err := parseCertsFromFile(certPath)
if err != nil {
......@@ -47,14 +46,14 @@ func (d *dirStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error {
dir := filepath.Join(d.root, cn)
if _, err := os.Stat(dir); err != nil && os.IsNotExist(err) {
if err = os.MkdirAll(dir, 0755); err != nil {
if err = os.MkdirAll(dir, 0750); err != nil {
return err
}
}
for path, data := range filemap {
var mode os.FileMode = 0644
if strings.HasSuffix(path, "private_key.pem") {
if strings.HasSuffix(path, "privkey.pem") {
mode = 0400
}
log.Printf("writing %s (%03o)", path, mode)
......@@ -80,11 +79,19 @@ func dumpCertsAndKey(cn string, der [][]byte, key crypto.Signer) (map[string][]b
}
m[filepath.Join(cn, "cert.pem")] = data
if len(der) > 1 {
data, err = encodeCerts(der[1:])
if err != nil {
return nil, err
}
m[filepath.Join(cn, "chain.pem")] = data
}
data, err = encodePrivateKey(key)
if err != nil {
return nil, err
}
m[filepath.Join(cn, "private_key.pem")] = data
m[filepath.Join(cn, "privkey.pem")] = data
return m, nil
}
......@@ -93,7 +100,7 @@ func dumpCertsAndKey(cn string, der [][]byte, key crypto.Signer) (map[string][]b
// certificates to replds instead.
type replStorage struct {
*dirStorage
replClient clientutil.Backend
replClient replds.PublicClient
}
func (d *replStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error {
......@@ -105,7 +112,7 @@ func (d *replStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error
now := time.Now()
var req replds.SetNodesRequest
for path, data := range filemap {
req.Nodes = append(req.Nodes, replds.Node{
req.Nodes = append(req.Nodes, &replds.Node{
Path: path,
Value: data,
Timestamp: now,
......@@ -115,8 +122,8 @@ func (d *replStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
var resp replds.SetNodesResponse
if err := clientutil.DoJSONHTTPRequest(ctx, d.replClient.Client(""), d.replClient.URL("")+"/api/set_nodes", req, &resp); err != nil {
resp, err := d.replClient.SetNodes(ctx, &req)
if err != nil {
return err
}
if resp.HostsOk < 1 {
......@@ -126,7 +133,7 @@ func (d *replStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error
}
func parseCertsFromFile(path string) ([][]byte, error) {
data, err := ioutil.ReadFile(path)
data, err := ioutil.ReadFile(path) // nolint: gosec
if err != nil {
return nil, err
}
......@@ -144,7 +151,7 @@ func parseCertsFromFile(path string) ([][]byte, error) {
}
func parsePrivateKeyFromFile(path string) (crypto.Signer, error) {
data, err := ioutil.ReadFile(path)
data, err := ioutil.ReadFile(path) // nolint: gosec
if err != nil {
return nil, err
}
......@@ -172,7 +179,7 @@ func encodePrivateKey(key crypto.Signer) ([]byte, error) {
switch priv := key.(type) {
case *rsa.PrivateKey:
pb = &pem.Block{
Type: "PRIVATE KEY",
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(priv),
}
case *ecdsa.PrivateKey:
......
package acmeserver
import (
"context"
"sync"
"time"
)
// Updates are long-term jobs, so they should be interruptible. We run
// updates in a separate goroutine, and cancel them when the
// configuration is reloaded or on exit. A semaphore ensures that only
// one update goroutine will be running at any given time (without
// other ones piling up).
//
// runWithUpdates calls fn every updateInterval with the most recent
// value received on reloadCh (nil until the first reload). A tick that
// fires while an update is still running is skipped. Receiving a value
// on reloadCh cancels the running update and waits for it to finish.
// The function returns once ctx is done and the running update, if
// any, has terminated.
func runWithUpdates(ctx context.Context, fn func(context.Context, interface{}), reloadCh <-chan interface{}, updateInterval time.Duration) {
	// Function to cancel the current update, and the associated
	// WaitGroup to wait for its termination.
	var upCancel context.CancelFunc
	var wg sync.WaitGroup

	sem := make(chan struct{}, 1)

	startUpdate := func(value interface{}) context.CancelFunc {
		// Acquire the semaphore, return if we fail to.
		// Equivalent to a 'try-lock' construct.
		select {
		case sem <- struct{}{}:
		default:
			return nil
		}

		updateCtx, cancel := context.WithCancel(ctx)
		wg.Add(1)
		go func() {
			// Deferred calls run in reverse order: release the
			// context, then the semaphore, and only then signal
			// the WaitGroup, so that once wg.Wait returns a new
			// update can be started right away.
			defer wg.Done()
			defer func() { <-sem }()
			defer cancel()
			fn(updateCtx, value)
		}()
		return cancel
	}

	// Cancel the running update, if any. Called on config
	// updates, when exiting.
	cancelUpdate := func() {
		if upCancel != nil {
			upCancel()
			upCancel = nil
		}
		wg.Wait()
	}
	defer cancelUpdate()

	var cur interface{}
	tick := time.NewTicker(updateInterval)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			// Do not cancel running update when running the ticker.
			if cancel := startUpdate(cur); cancel != nil {
				upCancel = cancel
			}
		case value := <-reloadCh:
			// Cancel the running update when configuration is reloaded.
			cancelUpdate()
			cur = value
		case <-ctx.Done():
			return
		}
	}
}
package acmeserver
import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
"errors"
"fmt"
"time"
)
// validCert parses the leaf certificate out of der and runs three
// checks on it, in order: validity at the current time, coverage of
// all the given domains, and correspondence with the private key. It
// returns the parsed leaf, or the first error encountered.
func validCert(domains []string, der [][]byte, key crypto.Signer) (*x509.Certificate, error) {
	leaf, err := leafCertFromDER(der)
	if err != nil {
		return nil, err
	}

	checks := []func() error{
		// The leaf must be within its validity period.
		func() error { return isCertExpired(leaf) },
		// The leaf must cover every requested domain.
		func() error { return certMatchesDomains(leaf, domains) },
		// The leaf must belong to the given private key.
		func() error { return certMatchesPrivateKey(leaf, key) },
	}
	for _, check := range checks {
		if err := check(); err != nil {
			return nil, err
		}
	}
	return leaf, nil
}
// leafCertFromDER parses the concatenation of the given DER blobs and
// returns the first certificate found.
func leafCertFromDER(der [][]byte) (*x509.Certificate, error) {
	certs, err := x509.ParseCertificates(concatDER(der))
	switch {
	case err != nil:
		return nil, err
	case len(certs) == 0:
		return nil, errors.New("no public key found")
	}
	return certs[0], nil
}
// isCertExpired returns an error if the current time falls outside the
// certificate's NotBefore/NotAfter window, and nil otherwise.
func isCertExpired(cert *x509.Certificate) error {
	switch now := time.Now(); {
	case cert.NotBefore.After(now):
		return errors.New("certificate isn't valid yet")
	case cert.NotAfter.Before(now):
		return errors.New("certificate expired")
	default:
		return nil
	}
}
// certMatchesDomains checks that the certificate is valid for every
// name in domains, returning an error naming the first domain that
// fails hostname verification.
func certMatchesDomains(cert *x509.Certificate, domains []string) error {
	for i := range domains {
		if cert.VerifyHostname(domains[i]) != nil {
			return fmt.Errorf("certificate does not match domain %q", domains[i])
		}
	}
	return nil
}
// certMatchesPrivateKey checks that key is the private counterpart of
// the certificate's public key. RSA keys are compared by modulus,
// ECDSA keys by the X and Y coordinates of the public point. Other
// public key types are rejected as unsupported.
func certMatchesPrivateKey(cert *x509.Certificate, key crypto.Signer) error {
	const (
		typeMismatch = "private key type does not match public key type"
		keyMismatch  = "private key does not match public key"
	)

	matches := false
	switch pub := cert.PublicKey.(type) {
	case *rsa.PublicKey:
		prv, ok := key.(*rsa.PrivateKey)
		if !ok {
			return errors.New(typeMismatch)
		}
		matches = pub.N.Cmp(prv.N) == 0
	case *ecdsa.PublicKey:
		prv, ok := key.(*ecdsa.PrivateKey)
		if !ok {
			return errors.New(typeMismatch)
		}
		matches = pub.X.Cmp(prv.X) == 0 && pub.Y.Cmp(prv.Y) == 0
	default:
		return errors.New("unsupported public key algorithm")
	}

	if !matches {
		return errors.New(keyMismatch)
	}
	return nil
}
stages:
- test
run_tests:
stage: test
image: registry.git.autistici.org/ai3/docker/test/golang:master
script:
- run-go-test ./...
artifacts:
when: always
reports:
coverage_report:
coverage_format: cobertura
path: cover.xml
junit: report.xml
ai3/go-common
===
Common code for ai3 services and tools.
A quick overview of the contents:
* [client](clientutil/) and [server](serverutil/) HTTP-based
"RPC" implementation, just JSON POST requests but with retries,
backoff, timeouts, tracing, etc.
* [server implementation of a generic line-based protocol over a UNIX
socket](unix/).
* an [LDAP connection pool](ldap/).
* utilities to [serialize composite data types](ldap/compositetypes/)
used in our LDAP database.
* a [password hashing library](pwhash/) that uses fancy advanced
crypto by default but is also backwards compatible with old
libc crypto.
* utilities to [manage encryption keys](userenckey/), themselves
encrypted with a password and a KDF.
package clientutil
import (
"crypto/tls"
"fmt"
"context"
"net/http"
"net/url"
"sync"
"time"
)
// BackendConfig specifies the configuration to access a service.
// BackendConfig specifies the configuration of a service backend.
//
// Services with multiple backends can be replicated or partitioned,
// depending on a configuration switch, making it a deployment-time
......@@ -18,102 +14,44 @@ import (
// 'shard' parameter on their APIs.
type BackendConfig struct {
URL string `yaml:"url"`
TLSConfig *TLSClientConfig `yaml:"tls"`
Sharded bool `yaml:"sharded"`
TLSConfig *TLSClientConfig `yaml:"tls_config"`
Debug bool `yaml:"debug"`
// Connection timeout (if unset, use default value).
ConnectTimeout string `yaml:"connect_timeout"`
// Maximum timeout for each individual request to this backend
// (if unset, use the Context timeout).
RequestMaxTimeout string `yaml:"request_max_timeout"`
}
// Backend is a runtime class that provides http Clients for use with
// a specific service backend. If the service can't be partitioned,
// pass an empty string to the Client method.
// pass an empty string to the Call method.
type Backend interface {
// URL for the service for a specific shard.
URL(string) string
// Client that can be used to make a request to the service.
Client(string) *http.Client
// Call a remote method. The sharding behavior is the following:
//
// Services that support sharding (partitioning) should always
// include the shard ID in their Call() requests. Users can
// then configure backends to be sharded or not in their
// Config. When invoking Call with a shard ID on a non-sharded
// service, the shard ID is simply ignored. Invoking Call
// *without* a shard ID on a sharded service is an error.
Call(context.Context, string, string, interface{}, interface{}) error
// Make a simple HTTP GET request to the remote backend,
// without parsing the response as JSON.
//
// Useful for streaming large responses, where the JSON
// encoding overhead is undesirable.
Get(context.Context, string, string) (*http.Response, error)
// Close all resources associated with the backend.
Close()
}
// NewBackend returns a new Backend with the given config.
func NewBackend(config *BackendConfig) (Backend, error) {
u, err := url.Parse(config.URL)
if err != nil {
return nil, err
}
var tlsConfig *tls.Config
if config.TLSConfig != nil {
tlsConfig, err = config.TLSConfig.TLSConfig()
if err != nil {
return nil, err
}
}
if config.Sharded {
return &replicatedClient{
u: u,
c: newHTTPClient(u, tlsConfig),
}, nil
}
return &shardedClient{
baseURL: u,
tlsConfig: tlsConfig,
urls: make(map[string]*url.URL),
shards: make(map[string]*http.Client),
}, nil
}
type replicatedClient struct {
c *http.Client
u *url.URL
}
func (r *replicatedClient) Client(_ string) *http.Client { return r.c }
func (r *replicatedClient) URL(_ string) string { return r.u.String() }
type shardedClient struct {
baseURL *url.URL
tlsConfig *tls.Config
mx sync.Mutex
urls map[string]*url.URL
shards map[string]*http.Client
}
func (s *shardedClient) getShardURL(shard string) *url.URL {
if shard == "" {
return s.baseURL
}
u, ok := s.urls[shard]
if !ok {
var tmp = *s.baseURL
tmp.Host = fmt.Sprintf("%s.%s", shard, tmp.Host)
u = &tmp
s.urls[shard] = u
}
return u
}
func (s *shardedClient) URL(shard string) string {
s.mx.Lock()
defer s.mx.Unlock()
return s.getShardURL(shard).String()
}
func (s *shardedClient) Client(shard string) *http.Client {
s.mx.Lock()
defer s.mx.Unlock()
client, ok := s.shards[shard]
if !ok {
u := s.getShardURL(shard)
client = newHTTPClient(u, s.tlsConfig)
s.shards[shard] = client
}
return client
}
func newHTTPClient(u *url.URL, tlsConfig *tls.Config) *http.Client {
return &http.Client{
Transport: NewTransport([]string{u.Host}, tlsConfig, nil),
Timeout: 30 * time.Second,
}
return newBalancedBackend(config, defaultResolver)
}
package clientutil
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"log"
"math/rand"
"mime"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/cenkalti/backoff/v4"
)
// Our own narrow logger interface: anything with a Printf method
// qualifies.
type logger interface {
Printf(string, ...interface{})
}
// A nilLogger discards everything it is asked to print. It is used
// when Config.Debug is false.
type nilLogger struct{}
// Printf implements the logger interface and does nothing.
func (l nilLogger) Printf(_ string, _ ...interface{}) {}
// Parameters that define the exponential backoff algorithm used.
var (
ExponentialBackOffInitialInterval = 100 * time.Millisecond
ExponentialBackOffMultiplier = 1.4142
)
// newExponentialBackOff returns a backoff.ExponentialBackOff set up
// with this package's defaults: the initial interval and multiplier
// come from the ExponentialBackOffInitialInterval and
// ExponentialBackOffMultiplier variables.
func newExponentialBackOff() *backoff.ExponentialBackOff {
	eb := backoff.NewExponentialBackOff()

	// MaxElapsedTime is set to 0 because we expect the overall
	// timeout to be dictated by the request Context.
	eb.MaxElapsedTime = 0
	eb.Multiplier = ExponentialBackOffMultiplier
	eb.InitialInterval = ExponentialBackOffInitialInterval
	return eb
}
// Balancer for HTTP connections. It will round-robin across available
// backends, trying to avoid ones that are erroring out, until one
// succeeds or returns a permanent error.
//
// This object should not be used for load balancing of individual
// HTTP requests: it doesn't do anything smart beyond trying to avoid
// broken targets. It's meant to provide a *reliable* connection to a
// set of equivalent services for HA purposes.
type balancedBackend struct {
// Source of targets for non-sharded calls (see makeSequence).
*backendTracker
// Provides the http transport used to reach each target (see do).
*transportCache
// Parsed form of the configured backend URL.
baseURI *url.URL
// Whether calls must carry a shard ID (from BackendConfig.Sharded).
sharded bool
// Resolver used to look up the addresses of sharded targets.
resolver resolver
// Debug logger; a nilLogger unless BackendConfig.Debug is set.
log logger
// Upper bound for the timeout of each individual attempt; zero
// means no bound beyond the one derived from the Context.
requestMaxTimeout time.Duration
}
// newBalancedBackend builds a balancedBackend from its configuration:
// it parses the backend URL, the optional TLS settings and the
// optional connect / per-request timeouts, and returns an error if any
// of them is invalid. Debug logging goes to standard error when
// config.Debug is set, and is discarded otherwise.
func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBackend, error) {
	baseURI, err := url.Parse(config.URL)
	if err != nil {
		return nil, err
	}

	var tlsConf *tls.Config
	if tc := config.TLSConfig; tc != nil {
		if tlsConf, err = tc.TLSConfig(); err != nil {
			return nil, err
		}
	}

	// Both timeouts are optional and default to zero when unset.
	var connectTimeout, requestMaxTimeout time.Duration
	if s := config.ConnectTimeout; s != "" {
		if connectTimeout, err = time.ParseDuration(s); err != nil {
			return nil, fmt.Errorf("error in connect_timeout: %v", err)
		}
	}
	if s := config.RequestMaxTimeout; s != "" {
		if requestMaxTimeout, err = time.ParseDuration(s); err != nil {
			return nil, fmt.Errorf("error in request_max_timeout: %v", err)
		}
	}

	var l logger = &nilLogger{}
	if config.Debug {
		l = log.New(os.Stderr, fmt.Sprintf("backend %s: ", baseURI.Host), 0)
	}

	return &balancedBackend{
		backendTracker:    newBackendTracker(baseURI.Host, resolver, l),
		transportCache:    newTransportCache(tlsConf, connectTimeout),
		requestMaxTimeout: requestMaxTimeout,
		sharded:           config.Sharded,
		baseURI:           baseURI,
		resolver:          resolver,
		log:               l,
	}, nil
}
// Call the backend. Makes an HTTP POST request to the specified uri,
// with a JSON-encoded request body. It will attempt to decode the
// response body as JSON.
//
// The request is retried, with exponential backoff, on the targets of
// the sequence built for the given shard. A response whose
// Content-Type is not "application/json" is reported as an error even
// when resp is nil; a nil resp only skips the decoding step.
func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, resp interface{}) error {
// Serialize the request body.
data, err := json.Marshal(req)
if err != nil {
return err
}
// Create the target sequence for this call. If there are multiple
// targets, reduce the timeout on each individual call accordingly to
// accommodate eventual failover.
seq, err := b.makeSequence(shard)
if err != nil {
return err
}
// Without a Context deadline each attempt gets one hour; otherwise
// the remaining time is split evenly across the known targets.
innerTimeout := 1 * time.Hour
if deadline, ok := ctx.Deadline(); ok {
innerTimeout = time.Until(deadline) / time.Duration(seq.Len())
}
if b.requestMaxTimeout > 0 && innerTimeout > b.requestMaxTimeout {
innerTimeout = b.requestMaxTimeout
}
// Call the backends in the sequence until one succeeds, with an
// exponential backoff policy controlled by the outer Context.
// NOTE(review): every error returned by the closure below that is
// not wrapped with backoff.Permanent (including request construction
// and JSON decoding errors) appears to be retried - confirm that
// this is intended.
return backoff.Retry(func() error {
// A fresh request is built on every attempt, so that the body
// reader starts from the beginning.
req, rerr := b.newJSONRequest(path, shard, data)
if rerr != nil {
return rerr
}
innerCtx, cancel := context.WithTimeout(ctx, innerTimeout)
defer cancel()
// When do() returns successfully, we already know that the
// response had an HTTP status of 200.
httpResp, rerr := b.do(innerCtx, seq, req)
if rerr != nil {
return rerr
}
defer httpResp.Body.Close() // nolint
// Decode the response, unless the 'resp' output is nil.
if ct, _, _ := mime.ParseMediaType(httpResp.Header.Get("Content-Type")); ct != "application/json" {
return errors.New("not a JSON response")
}
if resp == nil {
return nil
}
return json.NewDecoder(httpResp.Body).Decode(resp)
}, backoff.WithContext(newExponentialBackOff(), ctx))
}
// cancelOnCloseBody wraps a response body so that the Context of the
// request that produced it is released when the body is closed, and
// not earlier.
type cancelOnCloseBody struct {
	// rc is the wrapped body; the interface has the same method set
	// as io.ReadCloser.
	rc interface {
		Read([]byte) (int, error)
		Close() error
	}
	cancel context.CancelFunc
}

// Read reads from the wrapped body.
func (c *cancelOnCloseBody) Read(p []byte) (int, error) { return c.rc.Read(p) }

// Close closes the wrapped body and then releases the request Context.
func (c *cancelOnCloseBody) Close() error {
	err := c.rc.Close()
	c.cancel()
	return err
}

// Get makes a generic HTTP GET request to the backend uri, retrying
// with exponential backoff on the targets of the sequence built for
// the given shard.
//
// On success the caller owns the returned response and must close its
// Body: the per-attempt Context (and its timeout) stays alive until
// then, so it also bounds the time available to read the body. On
// error the returned response is nil.
func (b *balancedBackend) Get(ctx context.Context, shard, path string) (*http.Response, error) {
	// Create the target sequence for this call. If there are multiple
	// targets, reduce the timeout on each individual call accordingly to
	// accommodate eventual failover.
	seq, err := b.makeSequence(shard)
	if err != nil {
		return nil, err
	}
	innerTimeout := 1 * time.Hour
	if deadline, ok := ctx.Deadline(); ok {
		innerTimeout = time.Until(deadline) / time.Duration(seq.Len())
	}
	if b.requestMaxTimeout > 0 && innerTimeout > b.requestMaxTimeout {
		innerTimeout = b.requestMaxTimeout
	}

	req, err := http.NewRequest("GET", b.getURIForRequest(shard, path), nil)
	if err != nil {
		return nil, err
	}

	// Call the backends in the sequence until one succeeds, with an
	// exponential backoff policy controlled by the outer Context.
	var resp *http.Response
	err = backoff.Retry(func() error {
		innerCtx, cancel := context.WithTimeout(ctx, innerTimeout)

		// When do() returns successfully, we already know that the
		// response had an HTTP status of 200.
		r, rerr := b.do(innerCtx, seq, req)
		if rerr != nil {
			cancel()
			return rerr
		}
		if r == nil {
			// No response and no error: nothing could be contacted.
			cancel()
			return errNoTargets
		}

		// The request Context governs reading the body too, so it
		// must not be cancelled before the caller is done with it.
		r.Body = &cancelOnCloseBody{rc: r.Body, cancel: cancel}
		resp = r
		return nil
	}, backoff.WithContext(newExponentialBackOff(), ctx))
	if err != nil {
		return nil, err
	}
	return resp, nil
}
// makeSequence creates the target sequence for a call. Non-sharded
// backends draw their targets from the backend tracker; sharded
// backends require a shard ID and get a dedicated generator for that
// shard. An error is returned if the shard is missing on a sharded
// service, or if there are no targets at all.
func (b *balancedBackend) makeSequence(shard string) (*sequence, error) {
	var gen targetGenerator
	switch {
	case !b.sharded:
		gen = b.backendTracker
	case shard == "":
		return nil, fmt.Errorf("call without shard to sharded service %s", b.baseURI.String())
	default:
		gen = newShardedGenerator(shard, b.baseURI.Host, b.resolver)
	}

	seq := newSequence(gen)
	if seq.Len() == 0 {
		return nil, errNoTargets
	}
	b.log.Printf("%016x: initialized", seq.ID())
	return seq, nil
}
// getURIForRequest returns the URI for a request: a copy of the base
// URI with path appended and, when the backend is sharded and a shard
// is given, with the shard ID prepended to the host name.
func (b *balancedBackend) getURIForRequest(shard, path string) string {
	reqURL := *b.baseURI
	reqURL.Path = appendPath(reqURL.Path, path)
	if shard != "" && b.sharded {
		reqURL.Host = fmt.Sprintf("%s.%s", shard, reqURL.Host)
	}
	return reqURL.String()
}
// newJSONRequest builds a POST request for the given path and shard,
// carrying data as its body, with the JSON Content-Type and an
// explicit Content-Length header.
func (b *balancedBackend) newJSONRequest(path, shard string, data []byte) (*http.Request, error) {
	uri := b.getURIForRequest(shard, path)
	httpReq, err := http.NewRequest("POST", uri, bytes.NewReader(data))
	if err != nil {
		return nil, err
	}
	hdr := httpReq.Header
	hdr.Set("Content-Type", "application/json")
	hdr.Set("Content-Length", strconv.FormatInt(int64(len(data)), 10))
	return httpReq, nil
}
// Select a new target from the given sequence and send the request to
// it. Wrap HTTP errors in a RemoteError object.
//
// On success the response has status 200 and the caller must close
// its body. On failure the response is nil: an error from the
// sequence is returned as is, and a non-200 status becomes a
// RemoteError, marked as permanent (not to be retried) unless the
// status is a temporary one. The outcome is reported back to the
// sequence so that failing targets can be tracked.
func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Request) (*http.Response, error) {
	target, err := seq.Next()
	if err != nil {
		// Report the failure: returning a nil response together
		// with a nil error would crash callers.
		return nil, err
	}

	b.log.Printf("sequence %016x: connecting to %s", seq.ID(), target)
	client := &http.Client{
		Transport: b.transportCache.getTransport(target),
	}
	resp, err := client.Do(propagateDeadline(ctx, req))
	if err == nil && resp.StatusCode != 200 {
		err = remoteErrorFromResponse(resp)
		if !isStatusTemporary(resp.StatusCode) {
			err = backoff.Permanent(err)
		}
		resp.Body.Close() // nolint
		resp = nil
	}

	seq.Done(target, err)
	return resp, err
}
// deadlineHeader is the HTTP header that carries the request deadline.
const deadlineHeader = "X-RPC-Deadline"

// propagateDeadline attaches ctx to the request and mirrors its
// deadline in the deadlineHeader header, as nanoseconds since the
// Unix epoch. If ctx has no deadline the header is removed.
func propagateDeadline(ctx context.Context, req *http.Request) *http.Request {
	out := req.WithContext(ctx)
	deadline, hasDeadline := ctx.Deadline()
	if !hasDeadline {
		out.Header.Del(deadlineHeader)
		return out
	}
	out.Header.Set(deadlineHeader, strconv.FormatInt(deadline.UTC().UnixNano(), 10))
	return out
}
// errNoTargets is returned when no backend address is available.
var errNoTargets = errors.New("no available backends")

// A targetGenerator provides the list of targets for a sequence and
// receives the outcome of the requests made to them.
type targetGenerator interface {
	getTargets() []string
	setStatus(string, bool)
}

// A sequence repeatedly iterates over available backends in order of
// preference. Once in a while it refreshes its list of available
// targets.
type sequence struct {
	id      uint64
	tg      targetGenerator
	targets []string
	pos     int
}

// newSequence creates a sequence with a random ID, loading the initial
// list of targets from tg.
func newSequence(tg targetGenerator) *sequence {
	return &sequence{
		id:      rand.Uint64(),
		tg:      tg,
		targets: tg.getTargets(),
	}
}

// ID returns the random identifier of the sequence.
func (s *sequence) ID() uint64 { return s.id }

// Len returns the number of targets currently known to the sequence.
func (s *sequence) Len() int { return len(s.targets) }

// reloadTargets restarts the iteration. The target list is replaced
// with a fresh one from the generator, unless the generator has
// nothing to offer, in which case the previous list is kept and
// iterated over again.
func (s *sequence) reloadTargets() {
	if targets := s.tg.getTargets(); len(targets) > 0 {
		s.targets = targets
	}
	// Always rewind, even when the old list is kept: leaving pos
	// beyond the end would make Next index out of range.
	s.pos = 0
}

// Next returns the next target, reloading the target list once the
// current one is exhausted. It returns errNoTargets if there are no
// targets at all.
func (s *sequence) Next() (string, error) {
	if s.pos >= len(s.targets) {
		s.reloadTargets()
		if len(s.targets) == 0 {
			return "", errNoTargets
		}
	}
	t := s.targets[s.pos]
	s.pos++
	return t, nil
}

// Done reports the outcome of a request to target t back to the
// generator: the target is considered healthy if err is nil.
func (s *sequence) Done(t string, err error) {
	s.tg.setStatus(t, err == nil)
}
// A shardedGenerator is a targetGenerator with a fixed list of
// addresses, resolved once at creation time for a specific shard.
type shardedGenerator struct {
	id    uint64
	addrs []string
}

// newShardedGenerator resolves the addresses of "<shard>.<base>" with
// the given resolver and returns a generator serving them.
func newShardedGenerator(shard, base string, resolver resolver) *shardedGenerator {
	g := new(shardedGenerator)
	g.id = rand.Uint64()
	g.addrs = resolver.ResolveIP(fmt.Sprintf("%s.%s", shard, base))
	return g
}

// getTargets returns the addresses resolved at creation time.
func (g *shardedGenerator) getTargets() []string { return g.addrs }

// setStatus ignores the reported outcome.
func (g *shardedGenerator) setStatus(_ string, _ bool) {}
// Concatenate two URI paths, making sure that exactly one slash
// separates them: a duplicate slash at the junction is dropped, and a
// missing one is added. If either path is empty the other one is
// returned unchanged.
func appendPath(a, b string) string {
	aSlash := strings.HasSuffix(a, "/")
	bSlash := strings.HasPrefix(b, "/")
	switch {
	case aSlash && bSlash:
		return a + b[1:]
	case !aSlash && !bSlash && a != "" && b != "":
		// Without a separator "/api" + "foo" would become "/apifoo".
		return a + "/" + b
	default:
		return a + b
	}
}
// Some HTTP status codes are treated as temporary errors: 429 (too
// many requests), 502 (bad gateway), 503 (service unavailable) and 504
// (gateway timeout).
func isStatusTemporary(code int) bool {
	return code == http.StatusTooManyRequests ||
		code == http.StatusBadGateway ||
		code == http.StatusServiceUnavailable ||
		code == http.StatusGatewayTimeout
}
File added
// +build go1.9
package clientutil
import (
"context"
"net"
"time"
)
// netDialContext returns a dial function that always connects to
// addr, regardless of the address it is asked to dial. Connections
// are established with the given connect timeout and a 30 seconds
// keep-alive period.
func netDialContext(addr string, connectTimeout time.Duration) func(context.Context, string, string) (net.Conn, error) {
	d := &net.Dialer{
		Timeout:   connectTimeout,
		KeepAlive: 30 * time.Second,
		DualStack: true,
	}
	dial := func(ctx context.Context, network string, _ string) (net.Conn, error) {
		return d.DialContext(ctx, network, addr)
	}
	return dial
}
// +build !go1.9
package clientutil
import (
"context"
"net"
"time"
)
// Go < 1.9 does not have net.DialContext, reimplement it in terms of
// net.DialTimeout.
//
// netDialContext returns a dial function that always connects to
// addr, regardless of the address it is asked to dial. Each dial is
// bounded by connectTimeout and by the deadline of its Context,
// whichever comes first; a connectTimeout of zero (or less) means that
// only the Context deadline applies. A Context that is already done
// fails the dial immediately with the Context error.
func netDialContext(addr string, connectTimeout time.Duration) func(context.Context, string, string) (net.Conn, error) {
	return func(ctx context.Context, network string, _ string) (net.Conn, error) {
		if err := ctx.Err(); err != nil {
			return nil, err
		}

		// Compute the timeout in a local variable: the captured
		// connectTimeout must stay untouched for later dials.
		timeout := connectTimeout
		if deadline, ok := ctx.Deadline(); ok {
			remaining := time.Until(deadline)
			if remaining <= 0 {
				return nil, context.DeadlineExceeded
			}
			if timeout <= 0 || remaining < timeout {
				timeout = remaining
			}
		}
		return net.DialTimeout(network, addr, timeout)
	}
}
package clientutil
import (
"log"
"net"
"sync"
"time"
"golang.org/x/sync/singleflight"
)
// A resolver maps a "host:port" string to the list of "ip:port"
// addresses it stands for.
type resolver interface {
	ResolveIP(string) []string
}

// dnsResolver is a resolver backed by the system's name resolution.
type dnsResolver struct{}

// ResolveIP looks up the host part of hostport and returns one
// "ip:port" string for every address found. Parsing and lookup
// failures are logged and result in a nil list.
func (r *dnsResolver) ResolveIP(hostport string) []string {
	host, port, err := net.SplitHostPort(hostport)
	if err != nil {
		log.Printf("error parsing %s: %v", hostport, err)
		return nil
	}

	ips, err := net.LookupIP(host)
	if err != nil {
		log.Printf("error resolving %s: %v", host, err)
		return nil
	}

	var addrs []string
	for i := range ips {
		addrs = append(addrs, net.JoinHostPort(ips[i].String(), port))
	}
	return addrs
}
// defaultResolver is a DNS resolver wrapped in a cache.
var defaultResolver = newDNSCache(&dnsResolver{})
// A cacheDatum is a cached resolution result: the addresses, and the
// time after which they are no longer considered fresh.
type cacheDatum struct {
addrs []string
deadline time.Time
}
// dnsCacheTTL is how long a cached result stays fresh.
var dnsCacheTTL = 1 * time.Minute
// A dnsCache is a resolver that caches the results of another
// resolver for dnsCacheTTL.
type dnsCache struct {
// Underlying resolver that performs the actual lookups.
resolver resolver
// Groups lookups by host name (see update).
sf singleflight.Group
// Guards cache.
mx sync.RWMutex
cache map[string]cacheDatum
}
// newDNSCache returns an empty cache on top of the given resolver.
func newDNSCache(resolver resolver) *dnsCache {
return &dnsCache{
resolver: resolver,
cache: make(map[string]cacheDatum),
}
}
// get returns the cached addresses for host, if any, along with a flag
// telling whether they are still fresh. Addresses returned with a
// false flag are stale. It does no locking of its own.
func (c *dnsCache) get(host string) ([]string, bool) {
	if entry, found := c.cache[host]; found {
		return entry.addrs, time.Now().Before(entry.deadline)
	}
	return nil, false
}
// update resolves host with the underlying resolver, stores the result
// in the cache with a fresh deadline, and returns it. Lookups go
// through the singleflight group, keyed by host name.
func (c *dnsCache) update(host string) []string {
v, _, _ := c.sf.Do(host, func() (interface{}, error) {
addrs := c.resolver.ResolveIP(host)
// By uncommenting this, we stop caching negative results.
// if len(addrs) == 0 {
// return nil, nil
// }
c.mx.Lock()
c.cache[host] = cacheDatum{
addrs: addrs,
deadline: time.Now().Add(dnsCacheTTL),
}
c.mx.Unlock()
return addrs, nil
})
// The function above always returns a []string (possibly nil), so
// the type assertion cannot fail.
return v.([]string)
}
// ResolveIP returns the addresses for host, using the cache when
// possible. Fresh entries are returned as they are. Stale but
// non-empty entries are returned too, while a refresh runs in the
// background. In every other case the host is resolved synchronously.
func (c *dnsCache) ResolveIP(host string) []string {
	c.mx.RLock()
	cached, fresh := c.get(host)
	c.mx.RUnlock()

	switch {
	case fresh:
		return cached
	case len(cached) > 0:
		go c.update(host)
		return cached
	default:
		return c.update(host)
	}
}
// Package clientutil implements a very simple style of JSON RPC.
//
// Requests and responses are both encoded in JSON, and they should
// have the "application/json" Content-Type.
//
// HTTP response statuses other than 200 indicate an error: in this
// case, the response body may contain (in plain text) further details
// about the error. Some HTTP status codes are considered temporary
// errors (incl. 429 for throttling). The client will retry requests,
// if targets are available, until the context expires - so it's quite
// important to remember to set a timeout on the context given to the
// Call() function!
//
// The client handles both replicated services and sharded
// (partitioned) services. Users of this package that want to support
// sharded deployments are supposed to pass a shard ID to every
// Call(). At the deployment stage, sharding can be enabled via the
// configuration.
//
// For replicated services, the client will expect the provided
// hostname to resolve to one or more IP addresses, in which case it
// will pick a random IP address on every new request, while
// remembering which addresses have had errors and trying to avoid
// them. It will however send an occasional request to the failed
// targets, to see if they've come back.
//
// For sharded services, the client makes simple HTTP requests to the
// specific target identified by the shard. It does this by prepending
// the shard ID to the backend hostname (so a request to "example.com"
// with shard ID "1" becomes a request to "1.example.com").
//
// The difference with other JSON-RPC implementations is that we use a
// different URI for every method, and we force the usage of
// request/response types. This makes it easy for projects to
// eventually migrate to GRPC.
//
package clientutil
package clientutil
import (
"fmt"
"io/ioutil"
"net/http"
)
// RemoteError is an HTTP error reported by the server. Use the
// StatusCode() and Body() methods to access the status code and the
// response body.
type RemoteError struct {
	statusCode int
	body       string
}

// remoteErrorFromResponse builds a RemoteError out of a response. The
// body is read on a best-effort basis: if reading fails, the body of
// the error is left empty.
func remoteErrorFromResponse(resp *http.Response) *RemoteError {
	rerr := &RemoteError{statusCode: resp.StatusCode}
	if data, err := ioutil.ReadAll(resp.Body); err == nil {
		rerr.body = string(data)
	}
	return rerr
}

// Error implements the error interface.
func (e *RemoteError) Error() string {
	return fmt.Sprintf("%d - %s", e.StatusCode(), e.Body())
}

// StatusCode returns the HTTP status code.
func (e *RemoteError) StatusCode() int { return e.statusCode }

// Body returns the response body.
func (e *RemoteError) Body() string { return e.body }
package clientutil
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
)
// DoJSONHTTPRequest sends req, encoded as JSON, in a POST request to
// uri, retrying according to the package's default backoff policy.
// Any status other than 200, and any response whose Content-Type
// header is not exactly "application/json", is an error. The response
// body is decoded as JSON into resp, unless resp is nil.
func DoJSONHTTPRequest(ctx context.Context, client *http.Client, uri string, req, resp interface{}) error {
	payload, err := json.Marshal(req)
	if err != nil {
		return err
	}

	httpReq, err := http.NewRequest("POST", uri, bytes.NewReader(payload))
	if err != nil {
		return err
	}
	httpReq.Header.Set("Content-Type", "application/json")

	httpResp, err := RetryHTTPDo(client, httpReq.WithContext(ctx), NewExponentialBackOff())
	if err != nil {
		return err
	}
	defer httpResp.Body.Close()

	switch {
	case httpResp.StatusCode != 200:
		return fmt.Errorf("HTTP status %d", httpResp.StatusCode)
	case httpResp.Header.Get("Content-Type") != "application/json":
		return errors.New("not a JSON response")
	case resp == nil:
		return nil
	}
	return json.NewDecoder(httpResp.Body).Decode(resp)
}
package clientutil
import (
"errors"
"net/http"
"time"
"github.com/cenkalti/backoff"
)
// NewExponentialBackOff returns a backoff.ExponentialBackOff with the
// library defaults, except for the initial interval which is set to
// 100 milliseconds.
func NewExponentialBackOff() *backoff.ExponentialBackOff {
	policy := backoff.NewExponentialBackOff()
	policy.InitialInterval = 100 * time.Millisecond
	return policy
}
// tempError is implemented by errors that can tell whether they are
// temporary, i.e. whether the failed operation may be retried.
type tempError interface {
	Temporary() bool
}

// tempErrorWrapper marks the wrapped error as temporary.
type tempErrorWrapper struct {
	error
}

// Temporary always reports true.
func (t tempErrorWrapper) Temporary() bool { return true }

// TempError wraps err so that it is reported as a temporary
// (retriable) error.
func TempError(err error) error {
	return tempErrorWrapper{error: err}
}
// Retry runs op until it succeeds, following the backoff policy b.
//
// Unlike backoff.Operation, where errors are retried unless marked as
// permanent, here every error is treated as permanent unless it is
// explicitly temporary, i.e. it has a Temporary() method that returns
// true. This matches the errors returned by the net package more
// closely.
func Retry(op backoff.Operation, b backoff.BackOff) error {
	return backoff.Retry(func() error {
		err := op()
		if err == nil {
			return nil
		}
		if te, ok := err.(tempError); ok && te.Temporary() {
			return err
		}
		return backoff.Permanent(err)
	}, b)
}
var errHTTPBackOff = TempError(errors.New("temporary http error"))
// isStatusTemporary reports whether the HTTP status code denotes a
// condition worth retrying: 429 (too many requests), 502 (bad
// gateway), 503 (service unavailable) or 504 (gateway timeout).
func isStatusTemporary(code int) bool {
	return code == http.StatusTooManyRequests ||
		code == http.StatusBadGateway ||
		code == http.StatusServiceUnavailable ||
		code == http.StatusGatewayTimeout
}
// RetryHTTPDo retries an HTTP request until it succeeds, according to
// the backoff policy b. It will retry on temporary network errors and
// upon receiving specific temporary HTTP errors. It will use the
// context associated with the HTTP request object.
//
// Requests that carry a body are re-sent with a fresh copy of it
// obtained from req.GetBody (which http.NewRequest sets for the
// common in-memory body types). A request with a body but without
// GetBody cannot be replayed reliably.
//
// On success the caller owns the returned response and must close
// its body. When the retries end on a temporary HTTP status, the
// response is closed here and a nil response is returned with the
// error.
func RetryHTTPDo(client *http.Client, req *http.Request, b backoff.BackOff) (*http.Response, error) {
	var resp *http.Response
	attempts := 0
	op := func() error {
		// Clear up previous response if set.
		if resp != nil {
			resp.Body.Close()
			resp = nil
		}

		// The previous attempt consumed the request body: rewind it,
		// otherwise the retry would be sent with an empty body.
		if attempts > 0 && req.GetBody != nil {
			body, err := req.GetBody()
			if err != nil {
				return err
			}
			req.Body = body
		}
		attempts++

		var err error
		resp, err = client.Do(req)
		if err == nil && isStatusTemporary(resp.StatusCode) {
			// Do not hand a closed response back to the caller if
			// this turns out to be the last attempt.
			resp.Body.Close()
			resp = nil
			return errHTTPBackOff
		}
		return err
	}

	err := Retry(op, backoff.WithContext(b, req.Context()))
	return resp, err
}
......@@ -2,6 +2,7 @@ package clientutil
import (
"crypto/tls"
"errors"
common "git.autistici.org/ai3/go-common"
)
......@@ -16,6 +17,10 @@ type TLSClientConfig struct {
// TLSConfig returns a tls.Config object with the current configuration.
func (c *TLSClientConfig) TLSConfig() (*tls.Config, error) {
if c.Cert == "" || c.Key == "" || c.CA == "" {
return nil, errors.New("incomplete client tls specification")
}
cert, err := tls.LoadX509KeyPair(c.Cert, c.Key)
if err != nil {
return nil, err
......@@ -24,13 +29,11 @@ func (c *TLSClientConfig) TLSConfig() (*tls.Config, error) {
Certificates: []tls.Certificate{cert},
}
if c.CA != "" {
cas, err := common.LoadCA(c.CA)
if err != nil {
return nil, err
}
tlsConf.RootCAs = cas
cas, err := common.LoadCA(c.CA)
if err != nil {
return nil, err
}
tlsConf.RootCAs = cas
tlsConf.BuildNameToCertificate()
return tlsConf, nil
......
package clientutil
import (
"math/rand"
"sync"
"time"
)
// The backendTracker tracks the state of the targets associated with
// a backend, and periodically checks DNS for updates.
type backendTracker struct {
// Destination for the messages about target status changes.
log logger
// Backend address, passed to the resolver to obtain the targets.
addr string
// Resolves addr into the list of targets.
resolver resolver
// Closed by Close to terminate the updateProc goroutine.
stopCh chan struct{}
// mx protects the resolved and failed fields below.
mx sync.Mutex
// Targets last obtained from the resolver (periodic updates that
// return no results are ignored).
resolved []string
// Targets currently marked as failed, with the time at which the
// failure was recorded.
failed map[string]time.Time
}
// newBackendTracker creates a tracker for the backend at addr. The
// address is resolved once, synchronously, before the background
// goroutine that keeps the target list up to date is started.
func newBackendTracker(addr string, resolver resolver, logger logger) *backendTracker {
	tracker := &backendTracker{
		log:      logger,
		addr:     addr,
		resolver: resolver,
		stopCh:   make(chan struct{}),
		failed:   make(map[string]time.Time),
	}
	// Initial resolution, so that targets are available right away.
	tracker.resolved = resolver.ResolveIP(addr)

	go tracker.updateProc()
	return tracker
}
// Close stops the background update goroutine by closing stopCh.
// It must be called at most once, since closing an already closed
// channel panics.
func (b *backendTracker) Close() {
close(b.stopCh)
}
// getTargets returns all the resolved targets: first the ones that
// are not marked as failed, then the failed ones. Each of the two
// groups is shuffled randomly.
func (b *backendTracker) getTargets() []string {
	b.mx.Lock()
	defer b.mx.Unlock()

	var healthy, failing []string
	for _, target := range b.resolved {
		_, hasFailed := b.failed[target]
		if hasFailed {
			failing = append(failing, target)
			continue
		}
		healthy = append(healthy, target)
	}
	return append(shuffle(healthy), shuffle(failing)...)
}
// setStatus records the outcome of an attempt to use the target
// addr. Only state transitions are acted upon (and logged): a failed
// target that is now ok is removed from the failed set, a target
// that was not failed and now is gets added to it with the current
// time.
func (b *backendTracker) setStatus(addr string, ok bool) {
	b.mx.Lock()
	defer b.mx.Unlock()

	_, wasFailed := b.failed[addr]
	switch {
	case ok && wasFailed:
		b.log.Printf("target %s now ok", addr)
		delete(b.failed, addr)
	case !ok && !wasFailed:
		b.log.Printf("target %s failed", addr)
		b.failed[addr] = time.Now()
	}
}
var (
// Period of the ticker that drives updateProc (expiration of failed
// targets and DNS re-resolution).
backendUpdateInterval = 60 * time.Second
// Age after which expireFailedTargets removes a target from the
// failed set.
backendFailureRetryInterval = 60 * time.Second
)
// expireFailedTargets forgets the failures that were recorded more
// than backendFailureRetryInterval ago, so that those targets are no
// longer reported as failed.
func (b *backendTracker) expireFailedTargets() {
	b.mx.Lock()
	defer b.mx.Unlock()

	cutoff := time.Now()
	for target, failedAt := range b.failed {
		if age := cutoff.Sub(failedAt); age > backendFailureRetryInterval {
			delete(b.failed, target)
		}
	}
}
// updateProc is the body of the background goroutine: every
// backendUpdateInterval it expires old failures and resolves the
// backend address again, until stopCh is closed.
func (b *backendTracker) updateProc() {
	ticker := time.NewTicker(backendUpdateInterval)
	defer ticker.Stop()

	for {
		select {
		case <-b.stopCh:
			return
		case <-ticker.C:
		}

		b.expireFailedTargets()

		// Keep the previous targets if the resolver returns nothing.
		targets := b.resolver.ResolveIP(b.addr)
		if len(targets) == 0 {
			continue
		}
		b.mx.Lock()
		b.resolved = targets
		b.mx.Unlock()
	}
}
var (
	// shuffleSrc is the random source shared by all shuffle calls.
	// Sources created with rand.NewSource are not safe for concurrent
	// use, so every access must hold shuffleMx.
	shuffleSrc = rand.NewSource(time.Now().UnixNano())
	shuffleMx  sync.Mutex
)

// Re-order elements of a slice randomly.
//
// The slice is shuffled in place (Fisher-Yates) and returned for
// convenience. Slices with fewer than two elements are returned
// untouched. It is safe to call shuffle from multiple goroutines.
func shuffle(values []string) []string {
	if len(values) < 2 {
		return values
	}

	// Serialize access to the shared, non-thread-safe random source:
	// callers hold unrelated per-tracker locks at best.
	shuffleMx.Lock()
	defer shuffleMx.Unlock()

	rnd := rand.New(shuffleSrc)
	for i := len(values) - 1; i > 0; i-- {
		j := rnd.Intn(i + 1)
		values[i], values[j] = values[j], values[i]
	}
	return values
}
package clientutil
import (
"context"
"crypto/tls"
"errors"
"log"
"net"
"net/http"
"sync"
"time"
)
// errAllBackendsFailed is the error reported when none of the known
// backends could be reached.
var errAllBackendsFailed = errors.New("all backends failed")
// dnsResolver resolves backend addresses using the system DNS
// resolver (net.LookupIP).
type dnsResolver struct{}

// ResolveIPs expands every host:port entry of hosts into one ip:port
// address for each IP the host name resolves to. Entries that can
// not be parsed or resolved are logged and skipped.
func (r *dnsResolver) ResolveIPs(hosts []string) []string {
	var out []string
	for _, hp := range hosts {
		name, port, splitErr := net.SplitHostPort(hp)
		if splitErr != nil {
			log.Printf("error parsing %s: %v", hp, splitErr)
			continue
		}

		ips, lookupErr := net.LookupIP(name)
		if lookupErr != nil {
			log.Printf("error resolving %s: %v", name, lookupErr)
			continue
		}

		for i := range ips {
			out = append(out, net.JoinHostPort(ips[i].String(), port))
		}
	}
	return out
}
// defaultResolver is the DNS-based resolver used when the caller
// does not provide one.
var defaultResolver = &dnsResolver{}
"git.autistici.org/ai3/go-common/tracing"
)
// resolver abstracts the discovery of backend addresses, so that an
// alternative implementation can be supplied in place of DNS.
type resolver interface {
// ResolveIPs maps a list of backend specifications to a list of
// addresses (the default implementation turns host:port entries
// into ip:port ones).
ResolveIPs([]string) []string
}
// defaultConnectTimeout is the connection timeout used by
// newTransportCache when the caller passes a zero value.
var defaultConnectTimeout = 30 * time.Second
// Balancer for HTTP connections. It will round-robin across available
// backends, trying to avoid ones that are erroring out, until one
// succeeds or they all fail.
// The transportCache is just a cache of http transports, each
// connecting to a specific address.
//
// This object should not be used for load balancing of individual
// HTTP requests: once a new connection is established, requests will
// be sent over it until it errors out. It's meant to provide a
// *reliable* connection to a set of equivalent backends for HA
// purposes.
type balancer struct {
hosts []string
resolver resolver
stop chan bool
// List of currently valid (or untested) backends, and ones
// that errored out at least once.
mx sync.Mutex
addrs []string
ok map[string]bool
// A transportCache holds one http.RoundTripper per target address.
//
// We use this to control the HTTP Host header and the TLS ServerName
// independently of the target address.
type transportCache struct {
// TLS configuration handed to every transport created by the cache.
tlsConfig *tls.Config
// Connection timeout passed to the dialer of every new transport.
connectTimeout time.Duration
// mx protects the transports map.
mx sync.RWMutex
// Transports created so far, keyed by target address.
transports map[string]http.RoundTripper
}
// backendUpdateInterval is the period of the ticker driving the
// background refresh of the backend list.
var backendUpdateInterval = 60 * time.Second
// Periodically update the list of available backends.
func (b *balancer) updateProc() {
tick := time.NewTicker(backendUpdateInterval)
for {
select {
case <-b.stop:
return
case <-tick.C:
resolved := b.resolver.ResolveIPs(b.hosts)
if len(resolved) > 0 {
b.mx.Lock()
b.addrs = resolved
b.mx.Unlock()
}
}
func newTransportCache(tlsConfig *tls.Config, connectTimeout time.Duration) *transportCache {
if connectTimeout == 0 {
connectTimeout = defaultConnectTimeout
}
}
// Returns a list of all available backends, split into "good ones"
// (no errors seen since last successful connection) and "bad ones".
func (b *balancer) getBackends() ([]string, []string) {
b.mx.Lock()
defer b.mx.Unlock()
var good, bad []string
for _, addr := range b.addrs {
if ok := b.ok[addr]; ok {
good = append(good, addr)
} else {
bad = append(bad, addr)
}
return &transportCache{
tlsConfig: tlsConfig,
connectTimeout: connectTimeout,
transports: make(map[string]http.RoundTripper),
}
return good, bad
}
// notify records, in the ok map, whether the latest attempt to use
// the backend at addr was successful.
func (b *balancer) notify(addr string, ok bool) {
	b.mx.Lock()
	defer b.mx.Unlock()
	b.ok[addr] = ok
}
func netDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
timeout := 30 * time.Second
// Go < 1.9 does not have net.DialContext, reimplement it in
// terms of net.DialTimeout.
if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline)
}
return net.DialTimeout(network, addr, timeout)
// newTransport builds a new HTTP transport, wrapped for tracing,
// whose dialer is set up with the given address and the cache's
// connection timeout, and which uses the cache's TLS configuration.
func (m *transportCache) newTransport(addr string) http.RoundTripper {
	base := &http.Transport{
		TLSClientConfig: m.tlsConfig,
		DialContext:     netDialContext(addr, m.connectTimeout),

		// Parameters match those of net/http.DefaultTransport.
		MaxIdleConns:          100,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}
	return tracing.WrapTransport(base)
}
func (b *balancer) dial(ctx context.Context, network, addr string) (net.Conn, error) {
// Start by attempting a connection on 'good' targets.
good, bad := b.getBackends()
for _, addr := range good {
// Go < 1.9 does not have DialContext, deal with it
conn, err := netDialContext(ctx, network, addr)
if err == nil {
return conn, nil
} else if err == context.Canceled {
// A timeout might be bad, set the error bit
// on the connection.
b.notify(addr, false)
return nil, err
}
b.notify(addr, false)
}
func (m *transportCache) getTransport(addr string) http.RoundTripper {
m.mx.RLock()
t, ok := m.transports[addr]
m.mx.RUnlock()
for _, addr := range bad {
conn, err := netDialContext(ctx, network, addr)
if err == nil {
b.notify(addr, true)
return conn, nil
} else if err == context.Canceled {
return nil, err
if !ok {
m.mx.Lock()
if t, ok = m.transports[addr]; !ok {
t = m.newTransport(addr)
m.transports[addr] = t
}
m.mx.Unlock()
}
return nil, errAllBackendsFailed
}
// NewTransport returns a suitably configured http.RoundTripper that
// talks to a specific backend service. It performs discovery of
// available backends via DNS (using A or AAAA record lookups), tries
// to route traffic away from faulty backends.
//
// It will periodically attempt to rediscover new backends.
func NewTransport(backends []string, tlsConf *tls.Config, resolver resolver) http.RoundTripper {
if resolver == nil {
resolver = defaultResolver
}
addrs := resolver.ResolveIPs(backends)
b := &balancer{
hosts: backends,
resolver: resolver,
addrs: addrs,
ok: make(map[string]bool),
}
go b.updateProc()
return &http.Transport{
DialContext: b.dial,
TLSClientConfig: tlsConf,
}
return t
}