Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • ai3/tools/acmeserver
  • godog/acmeserver
  • svp-bot/acmeserver
3 results
Show changes
Commits on Source (90)
Showing
with 1744 additions and 146 deletions
include: "https://git.autistici.org/ai3/build-deb/raw/master/ci-common.yml"
include:
- "https://git.autistici.org/pipelines/debian/raw/master/common.yml"
- "https://git.autistici.org/pipelines/images/test/golang/raw/master/ci.yml"
......@@ -133,10 +133,16 @@ func (a *ACME) acmeClient(ctx context.Context) (*acme.Client, error) {
// account is already registered we get a StatusConflict,
// which we can ignore.
_, err = client.Register(ctx, ac, func(_ string) bool { return true })
if ae, ok := err.(*acme.Error); err == nil || ok && ae.StatusCode == http.StatusConflict {
if ae, ok := err.(*acme.Error); err == nil || err == acme.ErrAccountAlreadyExists || (ok && ae.StatusCode == http.StatusConflict) {
a.client = client
err = nil
}
// Fetch account info and display it.
if acct, err := client.GetReg(ctx, ""); err == nil {
log.Printf("ACME account %s", acct.URI)
}
return a.client, err
}
......@@ -149,7 +155,8 @@ func (a *ACME) GetCertificate(ctx context.Context, key crypto.Signer, c *certCon
return nil, nil, err
}
if err = a.verifyAll(ctx, client, c); err != nil {
o, err := a.verifyAll(ctx, client, c)
if err != nil {
return nil, nil, err
}
......@@ -157,7 +164,7 @@ func (a *ACME) GetCertificate(ctx context.Context, key crypto.Signer, c *certCon
if err != nil {
return nil, nil, err
}
der, _, err = client.CreateCert(ctx, csr, 0, true)
der, _, err = client.CreateOrderCert(ctx, o.FinalizeURL, csr, true)
if err != nil {
return nil, nil, err
}
......@@ -168,56 +175,67 @@ func (a *ACME) GetCertificate(ctx context.Context, key crypto.Signer, c *certCon
return der, leaf, nil
}
func (a *ACME) verifyAll(ctx context.Context, client *acme.Client, c *certConfig) error {
for _, domain := range c.Names {
if err := a.verify(ctx, client, c, domain); err != nil {
return err
}
}
return nil
}
func (a *ACME) verify(ctx context.Context, client *acme.Client, c *certConfig, domain string) error {
func (a *ACME) verifyAll(ctx context.Context, client *acme.Client, c *certConfig) (*acme.Order, error) {
// Make an authorization request to the ACME server, and
// verify that it returns a valid response with challenges.
authz, err := client.Authorize(ctx, domain)
o, err := client.AuthorizeOrder(ctx, acme.DomainIDs(c.Names...))
if err != nil {
return err
}
switch authz.Status {
case acme.StatusValid:
return nil // already authorized
case acme.StatusInvalid:
return fmt.Errorf("invalid authorization %q", authz.URI)
return nil, fmt.Errorf("AuthorizeOrder failed: %v", err)
}
// Pick a challenge that matches our preferences and the
// available validators. The validator fulfills the challenge,
// and returns a cleanup function that we're going to call
// before we return. All steps are sequential and idempotent.
chal := a.pickChallenge(authz.Challenges, c)
if chal == nil {
return fmt.Errorf("unable to authorize %q", domain)
switch o.Status {
case acme.StatusReady:
return o, nil // already authorized
case acme.StatusPending:
default:
return nil, fmt.Errorf("invalid new order status %q", o.Status)
}
v, ok := a.validators[chal.Type]
if !ok {
return fmt.Errorf("challenge type '%s' is not available", chal.Type)
}
cleanup, err := v.Fulfill(ctx, client, domain, chal)
if err != nil {
return err
for _, zurl := range o.AuthzURLs {
z, err := client.GetAuthorization(ctx, zurl)
if err != nil {
return nil, fmt.Errorf("GetAuthorization(%s) failed: %v", zurl, err)
}
if z.Status != acme.StatusPending {
continue
}
// Pick a challenge that matches our preferences and the
// available validators. The validator fulfills the challenge,
// and returns a cleanup function that we're going to call
// before we return. All steps are sequential and idempotent.
chal := a.pickChallenge(z.Challenges, c)
if chal == nil {
return nil, fmt.Errorf("unable to authorize %q", c.Names)
}
v, ok := a.validators[chal.Type]
if !ok {
return nil, fmt.Errorf("challenge type '%s' is not available", chal.Type)
}
log.Printf("attempting fulfillment for %q (identifier: %+v)", c.Names, z.Identifier)
for _, domain := range c.Names {
cleanup, err := v.Fulfill(ctx, client, domain, chal)
if err != nil {
return nil, fmt.Errorf("fulfillment failed: %v", err)
}
defer cleanup()
}
if _, err := client.Accept(ctx, chal); err != nil {
return nil, fmt.Errorf("challenge accept failed: %v", err)
}
if _, err := client.WaitAuthorization(ctx, z.URI); err != nil {
return nil, fmt.Errorf("WaitAuthorization(%s) failed: %v", z.URI, err)
}
}
defer cleanup()
// Tell the ACME server that we've accepted the challenge, and
// then wait, possibly for some time, until there is an
// authorization response (either successful or not) from the
// server.
if _, err = client.Accept(ctx, chal); err != nil {
return err
// Authorizations are satisfied, wait for the CA
// to update the order status.
if _, err = client.WaitOrder(ctx, o.URI); err != nil {
return nil, err
}
_, err = client.WaitAuthorization(ctx, authz.URI)
return err
return o, nil
}
// Pick a challenge with the right type from the Challenge response
......
......@@ -3,16 +3,15 @@ package main
import (
"context"
"flag"
"io/ioutil"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"git.autistici.org/ai3/tools/acmeserver"
"git.autistici.org/ai3/go-common/serverutil"
"gopkg.in/yaml.v2"
"git.autistici.org/ai3/tools/acmeserver"
"gopkg.in/yaml.v3"
)
var (
......@@ -29,12 +28,13 @@ type Config struct {
func loadConfig(path string) (*Config, error) {
// Read YAML config.
data, err := ioutil.ReadFile(path) // nolint: gosec
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
var config Config
if err := yaml.Unmarshal(data, &config); err != nil {
if err := yaml.NewDecoder(f).Decode(&config); err != nil {
return nil, err
}
return &config, nil
......
......@@ -3,11 +3,11 @@ package acmeserver
import (
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"path/filepath"
"gopkg.in/yaml.v2"
"gopkg.in/yaml.v3"
"git.autistici.org/ai3/go-common/clientutil"
)
......@@ -77,12 +77,13 @@ func (c *certConfig) check() error {
}
func readCertConfigs(path string) ([]*certConfig, error) {
data, err := ioutil.ReadFile(path) // nolint: gosec
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
var cc []*certConfig
if err := yaml.Unmarshal(data, &cc); err != nil {
if err := yaml.NewDecoder(f).Decode(&cc); err != nil {
return nil, err
}
return cc, nil
......
......@@ -2,4 +2,4 @@ acmeserver (2.0) unstable; urgency=medium
* Initial Release.
-- Autistici/Inventati <debian@autistici.org> Sat, 15 Jun 2018 09:23:40 +0000
-- Autistici/Inventati <debian@autistici.org> Fri, 15 Jun 2018 09:23:40 +0000
10
......@@ -2,12 +2,12 @@ Source: acmeserver
Section: admin
Priority: optional
Maintainer: Autistici/Inventati <debian@autistici.org>
Build-Depends: debhelper (>=9), golang-any (>=1.11), dh-systemd, dh-golang
Build-Depends: debhelper-compat (= 13), golang-any (>= 1.11), dh-golang
Standards-Version: 3.9.6
Package: acmeserver
Architecture: any
Depends: ${shlibs:Depends}, ${misc:Depends}
Built-Using: ${misc:Built-Using}
Description: ACME server
Automatically manages and renews public SSL certificates.
......@@ -2,18 +2,17 @@
export DH_GOPKG = git.autistici.org/ai3/tools/acmeserver
export DH_GOLANG_EXCLUDES = vendor
export DH_GOLANG_INSTALL_ALL := 1
%:
dh $@ --with systemd --with golang --buildsystem golang
dh $@ --with golang --buildsystem golang
override_dh_auto_install:
dh_auto_install
$(RM) -rv debian/acmeserver/usr/share/gocode
dh_auto_install -- --no-source
override_dh_systemd_enable:
dh_systemd_enable --no-enable
override_dh_systemd_start:
dh_systemd_start --no-start
override_dh_installsystemd:
dh_installsystemd --no-enable
override_dh_installsystemd:
dh_installsystemd --no-start
......@@ -10,6 +10,7 @@ import (
"github.com/miekg/dns"
"golang.org/x/crypto/acme"
"golang.org/x/net/publicsuffix"
)
const (
......@@ -87,8 +88,16 @@ func (d *dnsValidator) client() *dns.Client {
}
func (d *dnsValidator) Fulfill(ctx context.Context, client *acme.Client, domain string, chal *acme.Challenge) (func(), error) {
zone := domain[strings.Index(domain, ".")+1:]
fqdn := dns.Fqdn(domain)
domain = strings.TrimPrefix(domain, "*.")
zone, err := publicsuffix.EffectiveTLDPlusOne(domain)
if err != nil {
return nil, fmt.Errorf("could not determine effective tld: %w", err)
}
zone = dns.Fqdn(zone)
fqdn := dns.Fqdn("_acme-challenge." + domain)
value, err := client.DNS01ChallengeRecord(chal.Token)
if err != nil {
return nil, err
......@@ -109,6 +118,7 @@ func (d *dnsValidator) Fulfill(ctx context.Context, client *acme.Client, domain
}
func (d *dnsValidator) updateNS(ctx context.Context, ns, zone, fqdn, value string, remove bool) error {
log.Printf("updateNS(%s, %s, %s, %s, %v)", ns, zone, fqdn, value, remove)
rrs := d.makeRR(fqdn, value, rfc2136Timeout)
m := d.makeMsg(zone, rrs, remove)
c := d.client()
......
module git.autistici.org/ai3/tools/acmeserver
go 1.19
require (
git.autistici.org/ai3/go-common v0.0.0-20230816213645-b3aa3fb514d6
git.autistici.org/ai3/tools/replds v0.0.0-20230923170339-b6e6e3cc032b
github.com/miekg/dns v1.1.50
github.com/prometheus/client_golang v1.12.2
golang.org/x/crypto v0.24.0
golang.org/x/net v0.26.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/NYTimes/gziphandler v1.1.1 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v4 v4.1.3 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
github.com/felixge/httpsnoop v1.0.3 // indirect
github.com/go-logr/logr v1.2.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/openzipkin/zipkin-go v0.4.0 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.32.1 // indirect
github.com/prometheus/procfs v0.7.3 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.34.0 // indirect
go.opentelemetry.io/contrib/propagators/b3 v1.9.0 // indirect
go.opentelemetry.io/otel v1.10.0 // indirect
go.opentelemetry.io/otel/exporters/zipkin v1.9.0 // indirect
go.opentelemetry.io/otel/metric v0.31.0 // indirect
go.opentelemetry.io/otel/sdk v1.10.0 // indirect
go.opentelemetry.io/otel/trace v1.10.0 // indirect
golang.org/x/mod v0.4.2 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/protobuf v1.27.1 // indirect
)
This diff is collapsed.
......@@ -16,10 +16,15 @@ import (
"sync"
"time"
"git.autistici.org/ai3/replds"
"git.autistici.org/ai3/tools/replds"
"github.com/prometheus/client_golang/prometheus"
)
var (
defaultRenewalDays = 21
updateInterval = 1 * time.Minute
)
// certInfo represents what we know about the state of the certificate
// at runtime.
type certInfo struct {
......@@ -71,18 +76,19 @@ type CertGenerator interface {
GetCertificate(context.Context, crypto.Signer, *certConfig) ([][]byte, *x509.Certificate, error)
}
// Manager periodically renews certificates before they expire, and
// responds to http-01 validation requests.
// Manager periodically renews certificates before they expire.
type Manager struct {
configDirs []string
useRSA bool
storage certStorage
certs []*certInfo
certGen CertGenerator
renewalDays int
configCh chan []*certConfig
doneCh chan bool
mx sync.Mutex
certs []*certInfo
}
// NewManager creates a new Manager with the given configuration.
......@@ -94,6 +100,9 @@ func NewManager(config *Config, certGen CertGenerator) (*Manager, error) {
if config.Output.Path == "" {
return nil, errors.New("'output.path' is unset")
}
if config.RenewalDays <= 0 {
config.RenewalDays = defaultRenewalDays
}
m := &Manager{
useRSA: config.UseRSA,
......@@ -103,9 +112,6 @@ func NewManager(config *Config, certGen CertGenerator) (*Manager, error) {
certGen: certGen,
renewalDays: config.RenewalDays,
}
if m.renewalDays <= 0 {
m.renewalDays = 15
}
ds := &dirStorage{root: config.Output.Path}
if config.Output.ReplDS == nil {
......@@ -158,7 +164,7 @@ func (m *Manager) Reload() {
var (
renewalTimeout = 10 * time.Minute
errorRetryTimeout = 10 * time.Minute
errorRetryTimeout = 6 * time.Hour
)
func (m *Manager) updateAllCerts(ctx context.Context, certs []*certInfo) {
......@@ -241,55 +247,41 @@ func (m *Manager) loadConfig(certs []*certConfig) []*certInfo {
return out
}
func (m *Manager) getCerts() []*certInfo {
m.mx.Lock()
defer m.mx.Unlock()
return m.certs
}
func (m *Manager) setCerts(certs []*certInfo) {
m.mx.Lock()
m.certs = certs
m.mx.Unlock()
}
// This channel is used by the testing code to trigger an update,
// without having to wait for the timer to tick.
var testUpdateCh = make(chan bool)
func (m *Manager) loop(ctx context.Context) {
// Updates are long-term jobs, so they should be
// interruptible. We run updates in a separate goroutine, and
// cancel them when the configuration is reloaded or on exit.
var upCancel context.CancelFunc
var wg sync.WaitGroup
startUpdate := func(certs []*certInfo) context.CancelFunc {
// Ensure the previous update has finished.
wg.Wait()
upCtx, cancel := context.WithCancel(ctx)
wg.Add(1)
go func() {
m.updateAllCerts(upCtx, certs)
wg.Done()
}()
return cancel
}
// Cancel the running update, if any. Called on config
// updates, when exiting.
cancelUpdate := func() {
if upCancel != nil {
upCancel()
reloadCh := make(chan interface{}, 1)
go func() {
for config := range m.configCh {
certs := m.loadConfig(config)
m.setCerts(certs)
reloadCh <- certs
}
wg.Wait()
}
defer cancelUpdate()
}()
tick := time.NewTicker(5 * time.Minute)
defer tick.Stop()
for {
select {
case <-tick.C:
upCancel = startUpdate(m.certs)
case <-testUpdateCh:
upCancel = startUpdate(m.certs)
case certDomains := <-m.configCh:
cancelUpdate()
m.certs = m.loadConfig(certDomains)
case <-ctx.Done():
return
}
}
runWithUpdates(
ctx,
func(ctx context.Context, value interface{}) {
certs := value.([]*certInfo)
m.updateAllCerts(ctx, certs)
},
reloadCh,
updateInterval,
)
}
func concatDER(der [][]byte) []byte {
......@@ -308,10 +300,8 @@ func concatDER(der [][]byte) []byte {
func certRequest(key crypto.Signer, domains []string) ([]byte, error) {
req := &x509.CertificateRequest{
Subject: pkix.Name{CommonName: domains[0]},
}
if len(domains) > 1 {
req.DNSNames = domains[1:]
Subject: pkix.Name{CommonName: domains[0]},
DNSNames: domains,
}
return x509.CreateCertificateRequest(rand.Reader, req, key)
}
......
......@@ -77,36 +77,44 @@ func TestManager_Reload(t *testing.T) {
defer cleanup()
// Data race: we read data owned by another goroutine!
if len(m.certs) < 1 {
certs := m.getCerts()
if len(certs) < 1 {
t.Fatal("configuration not loaded?")
}
if m.certs[0].cn() != "example.com" {
t.Fatalf("certs[0].cn() is %s, expected example.com", m.certs[0].cn())
if certs[0].cn() != "example.com" {
t.Fatalf("certs[0].cn() is %s, expected example.com", certs[0].cn())
}
// Try a reload, catch obvious errors.
m.Reload()
time.Sleep(50 * time.Millisecond)
certs = m.getCerts()
if len(m.certs) != 1 {
if len(certs) != 1 {
t.Fatalf("certs count is %d, expected 1", len(m.certs))
}
if m.certs[0].cn() != "example.com" {
t.Fatalf("certs[0].cn() is %s, expected example.com", m.certs[0].cn())
if certs[0].cn() != "example.com" {
t.Fatalf("certs[0].cn() is %s, expected example.com", certs[0].cn())
}
}
func TestManager_NewCert(t *testing.T) {
var oldUpdateInterval time.Duration
oldUpdateInterval, updateInterval = updateInterval, 50*time.Millisecond
defer func() {
updateInterval = oldUpdateInterval
}()
cleanup, _, m := newTestManager(t, NewSelfSignedCertGenerator())
defer cleanup()
now := time.Now()
ci := m.certs[0]
certs := m.getCerts()
ci := certs[0]
if ci.retryDeadline.After(now) {
t.Fatalf("retry deadline is in the future: %v", ci.retryDeadline)
}
testUpdateCh <- true
time.Sleep(100 * time.Millisecond)
// Verify that the retry/renewal timestamp is in the future.
......@@ -135,7 +143,7 @@ func TestManager_NewCert(t *testing.T) {
m.Reload()
time.Sleep(50 * time.Millisecond)
ci = m.certs[0]
ci = m.getCerts()[0]
if !ci.valid {
t.Fatal("certificate is invalid after a reload")
}
......
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"extends": [
"config:base"
]
}
......@@ -16,7 +16,7 @@ import (
"strings"
"time"
"git.autistici.org/ai3/replds"
"git.autistici.org/ai3/tools/replds"
)
type dirStorage struct {
......@@ -79,6 +79,14 @@ func dumpCertsAndKey(cn string, der [][]byte, key crypto.Signer) (map[string][]b
}
m[filepath.Join(cn, "cert.pem")] = data
if len(der) > 1 {
data, err = encodeCerts(der[1:])
if err != nil {
return nil, err
}
m[filepath.Join(cn, "chain.pem")] = data
}
data, err = encodePrivateKey(key)
if err != nil {
return nil, err
......@@ -104,7 +112,7 @@ func (d *replStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error
now := time.Now()
var req replds.SetNodesRequest
for path, data := range filemap {
req.Nodes = append(req.Nodes, replds.Node{
req.Nodes = append(req.Nodes, &replds.Node{
Path: path,
Value: data,
Timestamp: now,
......
package acmeserver
import (
"context"
"sync"
"time"
)
// Updates are long-term jobs, so they should be interruptible. We run
// updates in a separate goroutine, and cancel them when the
// configuration is reloaded or on exit. A semaphore ensures that only
// one update goroutine will be running at any given time (without
// other ones piling up).
func runWithUpdates(ctx context.Context, fn func(context.Context, interface{}), reloadCh <-chan interface{}, updateInterval time.Duration) {
// Function to cancel the current update, and the associated
// WaitGroup to wait for its termination.
var upCancel context.CancelFunc
var wg sync.WaitGroup
sem := make(chan struct{}, 1)
startUpdate := func(value interface{}) context.CancelFunc {
// Acquire the semaphore, return if we fail to.
// Equivalent to a 'try-lock' construct.
select {
case sem <- struct{}{}:
default:
return nil
}
defer func() {
<-sem
}()
ctx, cancel := context.WithCancel(ctx)
wg.Add(1)
go func() {
fn(ctx, value)
wg.Done()
}()
return cancel
}
// Cancel the running update, if any. Called on config
// updates, when exiting.
cancelUpdate := func() {
if upCancel != nil {
upCancel()
upCancel = nil
}
wg.Wait()
}
defer cancelUpdate()
var cur interface{}
tick := time.NewTicker(updateInterval)
defer tick.Stop()
for {
select {
case <-tick.C:
// Do not cancel running update when running the ticker.
if cancel := startUpdate(cur); cancel != nil {
upCancel = cancel
}
case value := <-reloadCh:
// Cancel the running update when configuration is reloaded.
cancelUpdate()
cur = value
case <-ctx.Done():
return
}
}
}
stages:
- test
run_tests:
stage: test
image: registry.git.autistici.org/ai3/docker/test/golang:master
script:
- run-go-test ./...
artifacts:
when: always
reports:
coverage_report:
coverage_format: cobertura
path: cover.xml
junit: report.xml
......@@ -2,6 +2,7 @@ package clientutil
import (
"context"
"net/http"
)
// BackendConfig specifies the configuration of a service backend.
......@@ -16,6 +17,13 @@ type BackendConfig struct {
TLSConfig *TLSClientConfig `yaml:"tls"`
Sharded bool `yaml:"sharded"`
Debug bool `yaml:"debug"`
// Connection timeout (if unset, use default value).
ConnectTimeout string `yaml:"connect_timeout"`
// Maximum timeout for each individual request to this backend
// (if unset, use the Context timeout).
RequestMaxTimeout string `yaml:"request_max_timeout"`
}
// Backend is a runtime class that provides http Clients for use with
......@@ -32,6 +40,13 @@ type Backend interface {
// *without* a shard ID on a sharded service is an error.
Call(context.Context, string, string, interface{}, interface{}) error
// Make a simple HTTP GET request to the remote backend,
// without parsing the response as JSON.
//
// Useful for streaming large responses, where the JSON
// encoding overhead is undesirable.
Get(context.Context, string, string) (*http.Response, error)
// Close all resources associated with the backend.
Close()
}
......
......@@ -9,6 +9,7 @@ import (
"fmt"
"log"
"math/rand"
"mime"
"net/http"
"net/url"
"os"
......@@ -16,7 +17,7 @@ import (
"strings"
"time"
"github.com/cenkalti/backoff"
"github.com/cenkalti/backoff/v4"
)
// Our own narrow logger interface.
......@@ -60,10 +61,11 @@ func newExponentialBackOff() *backoff.ExponentialBackOff {
type balancedBackend struct {
*backendTracker
*transportCache
baseURI *url.URL
sharded bool
resolver resolver
log logger
baseURI *url.URL
sharded bool
resolver resolver
log logger
requestMaxTimeout time.Duration
}
func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBackend, error) {
......@@ -80,17 +82,36 @@ func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBack
}
}
var connectTimeout time.Duration
if config.ConnectTimeout != "" {
t, err := time.ParseDuration(config.ConnectTimeout)
if err != nil {
return nil, fmt.Errorf("error in connect_timeout: %v", err)
}
connectTimeout = t
}
var reqTimeout time.Duration
if config.RequestMaxTimeout != "" {
t, err := time.ParseDuration(config.RequestMaxTimeout)
if err != nil {
return nil, fmt.Errorf("error in request_max_timeout: %v", err)
}
reqTimeout = t
}
var logger logger = &nilLogger{}
if config.Debug {
logger = log.New(os.Stderr, fmt.Sprintf("backend %s: ", u.Host), 0)
}
return &balancedBackend{
backendTracker: newBackendTracker(u.Host, resolver, logger),
transportCache: newTransportCache(tlsConfig),
sharded: config.Sharded,
baseURI: u,
resolver: resolver,
log: logger,
backendTracker: newBackendTracker(u.Host, resolver, logger),
transportCache: newTransportCache(tlsConfig, connectTimeout),
requestMaxTimeout: reqTimeout,
sharded: config.Sharded,
baseURI: u,
resolver: resolver,
log: logger,
}, nil
}
......@@ -115,6 +136,9 @@ func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, res
if deadline, ok := ctx.Deadline(); ok {
innerTimeout = time.Until(deadline) / time.Duration(seq.Len())
}
if b.requestMaxTimeout > 0 && innerTimeout > b.requestMaxTimeout {
innerTimeout = b.requestMaxTimeout
}
// Call the backends in the sequence until one succeeds, with an
// exponential backoff policy controlled by the outer Context.
......@@ -135,7 +159,7 @@ func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, res
defer httpResp.Body.Close() // nolint
// Decode the response, unless the 'resp' output is nil.
if httpResp.Header.Get("Content-Type") != "application/json" {
if ct, _, _ := mime.ParseMediaType(httpResp.Header.Get("Content-Type")); ct != "application/json" {
return errors.New("not a JSON response")
}
if resp == nil {
......@@ -145,6 +169,44 @@ func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, res
}, backoff.WithContext(newExponentialBackOff(), ctx))
}
// Makes a generic HTTP GET request to the backend uri.
func (b *balancedBackend) Get(ctx context.Context, shard, path string) (*http.Response, error) {
// Create the target sequence for this call. If there are multiple
// targets, reduce the timeout on each individual call accordingly to
// accomodate eventual failover.
seq, err := b.makeSequence(shard)
if err != nil {
return nil, err
}
innerTimeout := 1 * time.Hour
if deadline, ok := ctx.Deadline(); ok {
innerTimeout = time.Until(deadline) / time.Duration(seq.Len())
}
if b.requestMaxTimeout > 0 && innerTimeout > b.requestMaxTimeout {
innerTimeout = b.requestMaxTimeout
}
req, err := http.NewRequest("GET", b.getURIForRequest(shard, path), nil)
if err != nil {
return nil, err
}
// Call the backends in the sequence until one succeeds, with an
// exponential backoff policy controlled by the outer Context.
var resp *http.Response
err = backoff.Retry(func() error {
innerCtx, cancel := context.WithTimeout(ctx, innerTimeout)
defer cancel()
// When do() returns successfully, we already know that the
// response had an HTTP status of 200.
var rerr error
resp, rerr = b.do(innerCtx, seq, req)
return rerr
}, backoff.WithContext(newExponentialBackOff(), ctx))
return resp, err
}
// Initialize a new target sequence.
func (b *balancedBackend) makeSequence(shard string) (*sequence, error) {
var tg targetGenerator = b.backendTracker
......@@ -198,7 +260,7 @@ func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Reque
client := &http.Client{
Transport: b.transportCache.getTransport(target),
}
resp, err = client.Do(req.WithContext(ctx))
resp, err = client.Do(propagateDeadline(ctx, req))
if err == nil && resp.StatusCode != 200 {
err = remoteErrorFromResponse(resp)
if !isStatusTemporary(resp.StatusCode) {
......@@ -212,6 +274,19 @@ func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Reque
return
}
const deadlineHeader = "X-RPC-Deadline"
// Propagate context deadline to the server using a HTTP header.
func propagateDeadline(ctx context.Context, req *http.Request) *http.Request {
req = req.WithContext(ctx)
if deadline, ok := ctx.Deadline(); ok {
req.Header.Set(deadlineHeader, strconv.FormatInt(deadline.UTC().UnixNano(), 10))
} else {
req.Header.Del(deadlineHeader)
}
return req
}
var errNoTargets = errors.New("no available backends")
type targetGenerator interface {
......
// +build go1.9
package clientutil
import (
"context"
"net"
"time"
)
func netDialContext(addr string, connectTimeout time.Duration) func(context.Context, string, string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: connectTimeout,
KeepAlive: 30 * time.Second,
DualStack: true,
}
return func(ctx context.Context, net string, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, net, addr)
}
}