Commit 989b5d38 authored by ale's avatar ale

Update clientutil.Backend interface

parent 5c61e9fc
......@@ -36,7 +36,7 @@ func (c *ksClient) Open(ctx context.Context, shard, username, password string, t
TTL: ttl,
}
var resp keystore.OpenResponse
return clientutil.DoJSONHTTPRequest(ctx, c.be.Client(shard), c.be.URL(shard)+"/api/open", &req, &resp)
return c.be.Call(ctx, shard, "/api/open", &req, &resp)
}
func (c *ksClient) Get(ctx context.Context, shard, username, ssoTicket string) ([]byte, error) {
......@@ -45,7 +45,7 @@ func (c *ksClient) Get(ctx context.Context, shard, username, ssoTicket string) (
SSOTicket: ssoTicket,
}
var resp keystore.GetResponse
err := clientutil.DoJSONHTTPRequest(ctx, c.be.Client(shard), c.be.URL(shard)+"/api/get_key", &req, &resp)
err := c.be.Call(ctx, shard, "/api/get_key", &req, &resp)
return resp.Key, err
}
......@@ -54,5 +54,5 @@ func (c *ksClient) Close(ctx context.Context, shard, username string) error {
Username: username,
}
var resp keystore.CloseResponse
return clientutil.DoJSONHTTPRequest(ctx, c.be.Client(shard), c.be.URL(shard)+"/api/close", &req, &resp)
return c.be.Call(ctx, shard, "/api/close", &req, &resp)
}
package clientutil
import (
"crypto/tls"
"fmt"
"net/http"
"net/url"
"sync"
"time"
"context"
)
// BackendConfig specifies the configuration to access a service.
// BackendConfig specifies the configuration of a service backend.
//
// Services with multiple backends can be replicated or partitioned,
// depending on a configuration switch, making it a deployment-time
......@@ -18,102 +13,30 @@ import (
// 'shard' parameter on their APIs.
type BackendConfig struct {
URL string `yaml:"url"`
Sharded bool `yaml:"sharded"`
TLSConfig *TLSClientConfig `yaml:"tls_config"`
Sharded bool `yaml:"sharded"`
Debug bool `yaml:"debug"`
}
// Backend is a runtime class that provides http Clients for use with
// a specific service backend. If the service can't be partitioned,
// pass an empty string to the Client method.
// pass an empty string to the Call method.
type Backend interface {
// URL for the service for a specific shard.
URL(string) string
// Call a remote method. The sharding behavior is the following:
//
// Services that support sharding (partitioning) should always
// include the shard ID in their Call() requests. Users can
// then configure backends to be sharded or not in their
// Config. When invoking Call with a shard ID on a non-sharded
// service, the shard ID is simply ignored. Invoking Call
// *without* a shard ID on a sharded service is an error.
Call(context.Context, string, string, interface{}, interface{}) error
// Client that can be used to make a request to the service.
Client(string) *http.Client
// Close all resources associated with the backend.
Close()
}
// NewBackend returns a new Backend with the given config.
func NewBackend(config *BackendConfig) (Backend, error) {
u, err := url.Parse(config.URL)
if err != nil {
return nil, err
}
var tlsConfig *tls.Config
if config.TLSConfig != nil {
tlsConfig, err = config.TLSConfig.TLSConfig()
if err != nil {
return nil, err
}
}
if config.Sharded {
return &replicatedClient{
u: u,
c: newHTTPClient(u, tlsConfig),
}, nil
}
return &shardedClient{
baseURL: u,
tlsConfig: tlsConfig,
urls: make(map[string]*url.URL),
shards: make(map[string]*http.Client),
}, nil
}
type replicatedClient struct {
c *http.Client
u *url.URL
}
func (r *replicatedClient) Client(_ string) *http.Client { return r.c }
func (r *replicatedClient) URL(_ string) string { return r.u.String() }
type shardedClient struct {
baseURL *url.URL
tlsConfig *tls.Config
mx sync.Mutex
urls map[string]*url.URL
shards map[string]*http.Client
}
func (s *shardedClient) getShardURL(shard string) *url.URL {
if shard == "" {
return s.baseURL
}
u, ok := s.urls[shard]
if !ok {
var tmp = *s.baseURL
tmp.Host = fmt.Sprintf("%s.%s", shard, tmp.Host)
u = &tmp
s.urls[shard] = u
}
return u
}
func (s *shardedClient) URL(shard string) string {
s.mx.Lock()
defer s.mx.Unlock()
return s.getShardURL(shard).String()
}
func (s *shardedClient) Client(shard string) *http.Client {
s.mx.Lock()
defer s.mx.Unlock()
client, ok := s.shards[shard]
if !ok {
u := s.getShardURL(shard)
client = newHTTPClient(u, s.tlsConfig)
s.shards[shard] = client
}
return client
}
func newHTTPClient(u *url.URL, tlsConfig *tls.Config) *http.Client {
return &http.Client{
Transport: NewTransport([]string{u.Host}, tlsConfig, nil),
Timeout: 30 * time.Second,
}
return newBalancedBackend(config, defaultResolver)
}
package clientutil
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"log"
"math/rand"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/cenkalti/backoff"
)
// Our own narrow logger interface.
type logger interface {
Printf(string, ...interface{})
}
// A nilLogger is used when Config.Debug is false.
type nilLogger struct{}
func (l nilLogger) Printf(_ string, _ ...interface{}) {}
// Parameters that define the exponential backoff algorithm used.
var (
ExponentialBackOffInitialInterval = 100 * time.Millisecond
ExponentialBackOffMultiplier = 1.4142
)
// newExponentialBackOff creates a backoff.ExponentialBackOff object
// with our own default values.
func newExponentialBackOff() *backoff.ExponentialBackOff {
b := backoff.NewExponentialBackOff()
b.InitialInterval = ExponentialBackOffInitialInterval
b.Multiplier = ExponentialBackOffMultiplier
// Set MaxElapsedTime to 0 because we expect the overall
// timeout to be dictated by the request Context.
b.MaxElapsedTime = 0
return b
}
// Balancer for HTTP connections. It will round-robin across available
// backends, trying to avoid ones that are erroring out, until one
// succeeds or returns a permanent error.
//
// This object should not be used for load balancing of individual
// HTTP requests: it doesn't do anything smart beyond trying to avoid
// broken targets. It's meant to provide a *reliable* connection to a
// set of equivalent services for HA purposes.
type balancedBackend struct {
*backendTracker
*transportCache
baseURI *url.URL
sharded bool
resolver resolver
log logger
}
func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBackend, error) {
u, err := url.Parse(config.URL)
if err != nil {
return nil, err
}
var tlsConfig *tls.Config
if config.TLSConfig != nil {
tlsConfig, err = config.TLSConfig.TLSConfig()
if err != nil {
return nil, err
}
}
var logger logger = &nilLogger{}
if config.Debug {
logger = log.New(os.Stderr, fmt.Sprintf("backend %s: ", u.Host), 0)
}
return &balancedBackend{
backendTracker: newBackendTracker(u.Host, resolver, logger),
transportCache: newTransportCache(tlsConfig),
sharded: config.Sharded,
baseURI: u,
resolver: resolver,
log: logger,
}, nil
}
// Call the backend. Makes an HTTP POST request to the specified uri,
// with a JSON-encoded request body. It will attempt to decode the
// response body as JSON.
func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, resp interface{}) error {
data, err := json.Marshal(req)
if err != nil {
return err
}
var tg targetGenerator = b.backendTracker
if b.sharded && shard != "" {
tg = newShardedGenerator(shard, b.baseURI.Host, b.resolver)
}
seq := newSequence(tg)
b.log.Printf("%016x: initialized", seq.ID())
var httpResp *http.Response
err = backoff.Retry(func() error {
req, rerr := b.newJSONRequest(path, shard, data)
if rerr != nil {
return rerr
}
httpResp, rerr = b.do(ctx, seq, req)
return rerr
}, backoff.WithContext(newExponentialBackOff(), ctx))
if err != nil {
return err
}
defer httpResp.Body.Close() // nolint
if httpResp.Header.Get("Content-Type") != "application/json" {
return errors.New("not a JSON response")
}
if resp == nil {
return nil
}
return json.NewDecoder(httpResp.Body).Decode(resp)
}
// Return the URI to be used for the request. This is used both in the
// Host HTTP header and as the TLS server name used to pick a server
// certificate (if using TLS).
func (b *balancedBackend) getURIForRequest(shard, path string) string {
u := *b.baseURI
if b.sharded && shard != "" {
u.Host = fmt.Sprintf("%s.%s", shard, u.Host)
}
u.Path = appendPath(u.Path, path)
return u.String()
}
// Build a http.Request object.
func (b *balancedBackend) newJSONRequest(path, shard string, data []byte) (*http.Request, error) {
req, err := http.NewRequest("POST", b.getURIForRequest(shard, path), bytes.NewReader(data))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Length", strconv.FormatInt(int64(len(data)), 10))
return req, nil
}
// Select a new target from the given sequence and send the request to
// it. Wrap HTTP errors in a RemoteError object.
func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Request) (resp *http.Response, err error) {
target, terr := seq.Next()
if terr != nil {
return
}
b.log.Printf("sequence %016x: connecting to %s", seq.ID(), target)
client := &http.Client{
Transport: b.transportCache.getTransport(target),
}
resp, err = client.Do(req.WithContext(ctx))
if err == nil && resp.StatusCode != 200 {
err = remoteErrorFromResponse(resp)
if !isStatusTemporary(resp.StatusCode) {
err = backoff.Permanent(err)
}
resp.Body.Close() // nolint
resp = nil
}
seq.Done(target, err)
return
}
var errNoTargets = errors.New("no available backends")
type targetGenerator interface {
getTargets() []string
setStatus(string, bool)
}
// A replicatedSequence repeatedly iterates over available backends in order of
// preference. Once in a while it refreshes its list of available
// targets.
type sequence struct {
id uint64
tg targetGenerator
targets []string
pos int
}
func newSequence(tg targetGenerator) *sequence {
return &sequence{
id: rand.Uint64(),
tg: tg,
targets: tg.getTargets(),
}
}
func (s *sequence) ID() uint64 { return s.id }
func (s *sequence) reloadTargets() {
targets := s.tg.getTargets()
if len(targets) > 0 {
s.targets = targets
s.pos = 0
}
}
// Next returns the next target.
func (s *sequence) Next() (t string, err error) {
if s.pos >= len(s.targets) {
s.reloadTargets()
if len(s.targets) == 0 {
err = errNoTargets
return
}
}
t = s.targets[s.pos]
s.pos++
return
}
func (s *sequence) Done(t string, err error) {
s.tg.setStatus(t, err == nil)
}
// A shardedGenerator returns a single sharded target to a sequence.
type shardedGenerator struct {
id uint64
addrs []string
}
func newShardedGenerator(shard, base string, resolver resolver) *shardedGenerator {
return &shardedGenerator{
id: rand.Uint64(),
addrs: resolver.ResolveIP(fmt.Sprintf("%s.%s", shard, base)),
}
}
func (g *shardedGenerator) getTargets() []string { return g.addrs }
func (g *shardedGenerator) setStatus(_ string, _ bool) {}
// Concatenate two URI paths.
func appendPath(a, b string) string {
if strings.HasSuffix(a, "/") && strings.HasPrefix(b, "/") {
return a + b[1:]
}
return a + b
}
// Some HTTP status codes are treated are temporary errors.
func isStatusTemporary(code int) bool {
switch code {
case http.StatusTooManyRequests, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
return true
default:
return false
}
}
package clientutil
import (
"log"
"net"
"sync"
"time"
"golang.org/x/sync/singleflight"
)
type resolver interface {
ResolveIP(string) []string
}
type dnsResolver struct{}
func (r *dnsResolver) ResolveIP(hostport string) []string {
var resolved []string
host, port, err := net.SplitHostPort(hostport)
if err != nil {
log.Printf("error parsing %s: %v", hostport, err)
return nil
}
hostIPs, err := net.LookupIP(host)
if err != nil {
log.Printf("error resolving %s: %v", host, err)
return nil
}
for _, ip := range hostIPs {
resolved = append(resolved, net.JoinHostPort(ip.String(), port))
}
return resolved
}
var defaultResolver = newDNSCache(&dnsResolver{})
type cacheDatum struct {
addrs []string
deadline time.Time
}
type dnsCache struct {
resolver resolver
sf singleflight.Group
mx sync.RWMutex
cache map[string]cacheDatum
}
func newDNSCache(resolver resolver) *dnsCache {
return &dnsCache{
resolver: resolver,
cache: make(map[string]cacheDatum),
}
}
func (c *dnsCache) get(host string) ([]string, bool) {
d, ok := c.cache[host]
if !ok {
return nil, false
}
return d.addrs, d.deadline.After(time.Now())
}
func (c *dnsCache) update(host string) []string {
v, _, _ := c.sf.Do(host, func() (interface{}, error) {
addrs := c.resolver.ResolveIP(host)
// By uncommenting this, we stop caching negative results.
// if len(addrs) == 0 {
// return nil, nil
// }
c.mx.Lock()
c.cache[host] = cacheDatum{
addrs: addrs,
deadline: time.Now().Add(60 * time.Second),
}
c.mx.Unlock()
return addrs, nil
})
return v.([]string)
}
func (c *dnsCache) ResolveIP(host string) []string {
c.mx.RLock()
addrs, ok := c.get(host)
c.mx.RUnlock()
if ok {
return addrs
}
if len(addrs) > 0 {
go c.update(host)
return addrs
}
return c.update(host)
}
// Package clientutil implements a very simple style of JSON RPC.
//
// Requests and responses are both encoded in JSON, and they should
// have the "application/json" Content-Type.
//
// HTTP response statuses other than 200 indicate an error: in this
// case, the response body may contain (in plain text) further details
// about the error. Some HTTP status codes are considered temporary
// errors (incl. 429 for throttling). The client will retry requests,
// if targets are available, until the context expires - so it's quite
// important to remember to set a timeout on the context given to the
// Call() function!
//
// The client handles both replicated services and sharded
// (partitioned) services. Users of this package that want to support
// sharded deployments are supposed to pass a shard ID to every
// Call(). At the deployment stage, sharding can be enabled via the
// configuration.
//
// For replicated services, the client will expect the provided
// hostname to resolve to one or more IP addresses, in which case it
// will pick a random IP address on every new request, while
// remembering which addresses have had errors and trying to avoid
// them. It will however send an occasional request to the failed
// targets, to see if they've come back.
//
// For sharded services, the client makes simple HTTP requests to the
// specific target identified by the shard. It does this by prepending
// the shard ID to the backend hostname (so a request to "example.com"
// with shard ID "1" becomes a request to "1.example.com").
//
// The difference with other JSON-RPC implementations is that we use a
// different URI for every method, and we force the usage of
// request/response types. This makes it easy for projects to
// eventually migrate to GRPC.
//
package clientutil
package clientutil
import (
"fmt"
"io/ioutil"
"net/http"
)
// RemoteError represents a HTTP error from the server. The status
// code and response body can be retrieved with the StatusCode() and
// Body() methods.
type RemoteError struct {
statusCode int
body string
}
func remoteErrorFromResponse(resp *http.Response) *RemoteError {
// Optimistically read the response body, ignoring errors.
var body string
if data, err := ioutil.ReadAll(resp.Body); err == nil {
body = string(data)
}
return &RemoteError{statusCode: resp.StatusCode, body: body}
}
// Error implements the error interface.
func (e *RemoteError) Error() string {
return fmt.Sprintf("%d - %s", e.statusCode, e.body)
}
// StatusCode returns the HTTP status code.
func (e *RemoteError) StatusCode() int { return e.statusCode }
// Body returns the response body.
func (e *RemoteError) Body() string { return e.body }
package clientutil
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
)
// DoJSONHTTPRequest makes an HTTP POST request to the specified uri,
// with a JSON-encoded request body. It will attempt to decode the
// response body as JSON.
func DoJSONHTTPRequest(ctx context.Context, client *http.Client, uri string, req, resp interface{}) error {
data, err := json.Marshal(req)
if err != nil {
return err
}
httpReq, err := http.NewRequest("POST", uri, bytes.NewReader(data))
if err != nil {
return err
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq = httpReq.WithContext(ctx)
httpResp, err := RetryHTTPDo(client, httpReq, NewExponentialBackOff())
if err != nil {
return err
}
defer httpResp.Body.Close()
if httpResp.StatusCode != 200 {
return fmt.Errorf("HTTP status %d", httpResp.StatusCode)
}
if httpResp.Header.Get("Content-Type") != "application/json" {
return errors.New("not a JSON response")
}
if resp == nil {
return nil
}
return json.NewDecoder(httpResp.Body).Decode(resp)
}
package clientutil
import (
"errors"
"net/http"
"time"
"github.com/cenkalti/backoff"
)
// NewExponentialBackOff creates a backoff.ExponentialBackOff object
// with our own default values.
func NewExponentialBackOff() *backoff.ExponentialBackOff {
b := backoff.NewExponentialBackOff()
b.InitialInterval = 100 * time.Millisecond
//b.Multiplier = 1.4142
return b
}