summaryrefslogtreecommitdiff
path: root/vendor/k8s.io/client-go/transport
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/k8s.io/client-go/transport')
-rw-r--r--vendor/k8s.io/client-go/transport/OWNERS2
-rw-r--r--vendor/k8s.io/client-go/transport/cache.go8
-rw-r--r--vendor/k8s.io/client-go/transport/config.go33
-rw-r--r--vendor/k8s.io/client-go/transport/round_trippers.go82
-rw-r--r--vendor/k8s.io/client-go/transport/spdy/spdy.go2
-rw-r--r--vendor/k8s.io/client-go/transport/token_source.go149
-rw-r--r--vendor/k8s.io/client-go/transport/transport.go92
7 files changed, 332 insertions, 36 deletions
diff --git a/vendor/k8s.io/client-go/transport/OWNERS b/vendor/k8s.io/client-go/transport/OWNERS
index bf0ba5b9f..a52176903 100644
--- a/vendor/k8s.io/client-go/transport/OWNERS
+++ b/vendor/k8s.io/client-go/transport/OWNERS
@@ -1,3 +1,5 @@
+# See the OWNERS docs at https://go.k8s.io/owners
+
reviewers:
- smarterclayton
- wojtek-t
diff --git a/vendor/k8s.io/client-go/transport/cache.go b/vendor/k8s.io/client-go/transport/cache.go
index 83291c575..7cffe2a5f 100644
--- a/vendor/k8s.io/client-go/transport/cache.go
+++ b/vendor/k8s.io/client-go/transport/cache.go
@@ -43,6 +43,7 @@ type tlsCacheKey struct {
caData string
certData string
keyData string
+ getCert string
serverName string
dial string
}
@@ -52,7 +53,7 @@ func (t tlsCacheKey) String() string {
if len(t.keyData) > 0 {
keyText = "<redacted>"
}
- return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, dial:%s", t.insecure, t.caData, t.certData, keyText, t.serverName, t.dial)
+ return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, getCert: %s, serverName:%s, dial:%s", t.insecure, t.caData, t.certData, keyText, t.getCert, t.serverName, t.dial)
}
func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
@@ -85,7 +86,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
dial = (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
- }).Dial
+ }).DialContext
}
// Cache a single transport for these options
c.transports[key] = utilnet.SetTransportDefaults(&http.Transport{
@@ -93,7 +94,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig,
MaxIdleConnsPerHost: idleConnsPerHost,
- Dial: dial,
+ DialContext: dial,
})
return c.transports[key], nil
}
@@ -109,6 +110,7 @@ func tlsConfigKey(c *Config) (tlsCacheKey, error) {
caData: string(c.TLS.CAData),
certData: string(c.TLS.CertData),
keyData: string(c.TLS.KeyData),
+ getCert: fmt.Sprintf("%p", c.TLS.GetCert),
serverName: c.TLS.ServerName,
dial: fmt.Sprintf("%p", c.Dial),
}, nil
diff --git a/vendor/k8s.io/client-go/transport/config.go b/vendor/k8s.io/client-go/transport/config.go
index af347dafe..5de0a2cb1 100644
--- a/vendor/k8s.io/client-go/transport/config.go
+++ b/vendor/k8s.io/client-go/transport/config.go
@@ -17,6 +17,8 @@ limitations under the License.
package transport
import (
+ "context"
+ "crypto/tls"
"net"
"net/http"
)
@@ -37,6 +39,11 @@ type Config struct {
// Bearer token for authentication
BearerToken string
+ // Path to a file containing a BearerToken.
+ // If set, the contents are periodically read.
+ // The last successfully read value takes precedence over BearerToken.
+ BearerTokenFile string
+
// Impersonate is the config that this Config will impersonate using
Impersonate ImpersonationConfig
@@ -50,10 +57,13 @@ type Config struct {
// from TLSClientConfig, Transport, or http.DefaultTransport). The
// config may layer other RoundTrippers on top of the returned
// RoundTripper.
- WrapTransport func(rt http.RoundTripper) http.RoundTripper
+ //
+ // A future release will change this field to an array. Use config.Wrap()
+ // instead of setting this value directly.
+ WrapTransport WrapperFunc
// Dial specifies the dial function for creating unencrypted TCP connections.
- Dial func(network, addr string) (net.Conn, error)
+ Dial func(ctx context.Context, network, address string) (net.Conn, error)
}
// ImpersonationConfig has all the available impersonation options
@@ -78,12 +88,25 @@ func (c *Config) HasBasicAuth() bool {
// HasTokenAuth returns whether the configuration has token authentication or not.
func (c *Config) HasTokenAuth() bool {
- return len(c.BearerToken) != 0
+ return len(c.BearerToken) != 0 || len(c.BearerTokenFile) != 0
}
// HasCertAuth returns whether the configuration has certificate authentication or not.
func (c *Config) HasCertAuth() bool {
- return len(c.TLS.CertData) != 0 || len(c.TLS.CertFile) != 0
+ return (len(c.TLS.CertData) != 0 || len(c.TLS.CertFile) != 0) && (len(c.TLS.KeyData) != 0 || len(c.TLS.KeyFile) != 0)
+}
+
+// HasCertCallbacks returns whether the configuration has certificate callback or not.
+func (c *Config) HasCertCallback() bool {
+ return c.TLS.GetCert != nil
+}
+
+// Wrap adds a transport middleware function that will give the caller
+// an opportunity to wrap the underlying http.RoundTripper prior to the
+// first API call being made. The provided function is invoked after any
+// existing transport wrappers are invoked.
+func (c *Config) Wrap(fn WrapperFunc) {
+ c.WrapTransport = Wrappers(c.WrapTransport, fn)
}
// TLSConfig holds the information needed to set up a TLS transport.
@@ -98,4 +121,6 @@ type TLSConfig struct {
CAData []byte // Bytes of the PEM-encoded server trusted root certificates. Supercedes CAFile.
CertData []byte // Bytes of the PEM-encoded client certificate. Supercedes CertFile.
KeyData []byte // Bytes of the PEM-encoded client key. Supercedes KeyFile.
+
+ GetCert func() (*tls.Certificate, error) // Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field.
}
diff --git a/vendor/k8s.io/client-go/transport/round_trippers.go b/vendor/k8s.io/client-go/transport/round_trippers.go
index de64e0078..117a9c8c4 100644
--- a/vendor/k8s.io/client-go/transport/round_trippers.go
+++ b/vendor/k8s.io/client-go/transport/round_trippers.go
@@ -17,13 +17,13 @@ limitations under the License.
package transport
import (
- "bytes"
"fmt"
"net/http"
"strings"
"time"
- "github.com/golang/glog"
+ "golang.org/x/oauth2"
+ "k8s.io/klog"
utilnet "k8s.io/apimachinery/pkg/util/net"
)
@@ -45,7 +45,11 @@ func HTTPWrappersForConfig(config *Config, rt http.RoundTripper) (http.RoundTrip
case config.HasBasicAuth() && config.HasTokenAuth():
return nil, fmt.Errorf("username/password or bearer token may be set, but not both")
case config.HasTokenAuth():
- rt = NewBearerAuthRoundTripper(config.BearerToken, rt)
+ var err error
+ rt, err = NewBearerAuthWithRefreshRoundTripper(config.BearerToken, config.BearerTokenFile, rt)
+ if err != nil {
+ return nil, err
+ }
case config.HasBasicAuth():
rt = NewBasicAuthRoundTripper(config.Username, config.Password, rt)
}
@@ -63,13 +67,13 @@ func HTTPWrappersForConfig(config *Config, rt http.RoundTripper) (http.RoundTrip
// DebugWrappers wraps a round tripper and logs based on the current log level.
func DebugWrappers(rt http.RoundTripper) http.RoundTripper {
switch {
- case bool(glog.V(9)):
+ case bool(klog.V(9)):
rt = newDebuggingRoundTripper(rt, debugCurlCommand, debugURLTiming, debugResponseHeaders)
- case bool(glog.V(8)):
+ case bool(klog.V(8)):
rt = newDebuggingRoundTripper(rt, debugJustURL, debugRequestHeaders, debugResponseStatus, debugResponseHeaders)
- case bool(glog.V(7)):
+ case bool(klog.V(7)):
rt = newDebuggingRoundTripper(rt, debugJustURL, debugRequestHeaders, debugResponseStatus)
- case bool(glog.V(6)):
+ case bool(klog.V(6)):
rt = newDebuggingRoundTripper(rt, debugURLTiming)
}
@@ -139,7 +143,7 @@ func (rt *authProxyRoundTripper) CancelRequest(req *http.Request) {
if canceler, ok := rt.rt.(requestCanceler); ok {
canceler.CancelRequest(req)
} else {
- glog.Errorf("CancelRequest not implemented")
+ klog.Errorf("CancelRequest not implemented by %T", rt.rt)
}
}
@@ -167,7 +171,7 @@ func (rt *userAgentRoundTripper) CancelRequest(req *http.Request) {
if canceler, ok := rt.rt.(requestCanceler); ok {
canceler.CancelRequest(req)
} else {
- glog.Errorf("CancelRequest not implemented")
+ klog.Errorf("CancelRequest not implemented by %T", rt.rt)
}
}
@@ -198,7 +202,7 @@ func (rt *basicAuthRoundTripper) CancelRequest(req *http.Request) {
if canceler, ok := rt.rt.(requestCanceler); ok {
canceler.CancelRequest(req)
} else {
- glog.Errorf("CancelRequest not implemented")
+ klog.Errorf("CancelRequest not implemented by %T", rt.rt)
}
}
@@ -258,7 +262,7 @@ func (rt *impersonatingRoundTripper) CancelRequest(req *http.Request) {
if canceler, ok := rt.delegate.(requestCanceler); ok {
canceler.CancelRequest(req)
} else {
- glog.Errorf("CancelRequest not implemented")
+ klog.Errorf("CancelRequest not implemented by %T", rt.delegate)
}
}
@@ -266,13 +270,35 @@ func (rt *impersonatingRoundTripper) WrappedRoundTripper() http.RoundTripper { r
type bearerAuthRoundTripper struct {
bearer string
+ source oauth2.TokenSource
rt http.RoundTripper
}
// NewBearerAuthRoundTripper adds the provided bearer token to a request
// unless the authorization header has already been set.
func NewBearerAuthRoundTripper(bearer string, rt http.RoundTripper) http.RoundTripper {
- return &bearerAuthRoundTripper{bearer, rt}
+ return &bearerAuthRoundTripper{bearer, nil, rt}
+}
+
+// NewBearerAuthRoundTripper adds the provided bearer token to a request
+// unless the authorization header has already been set.
+// If tokenFile is non-empty, it is periodically read,
+// and the last successfully read content is used as the bearer token.
+// If tokenFile is non-empty and bearer is empty, the tokenFile is read
+// immediately to populate the initial bearer token.
+func NewBearerAuthWithRefreshRoundTripper(bearer string, tokenFile string, rt http.RoundTripper) (http.RoundTripper, error) {
+ if len(tokenFile) == 0 {
+ return &bearerAuthRoundTripper{bearer, nil, rt}, nil
+ }
+ source := NewCachedFileTokenSource(tokenFile)
+ if len(bearer) == 0 {
+ token, err := source.Token()
+ if err != nil {
+ return nil, err
+ }
+ bearer = token.AccessToken
+ }
+ return &bearerAuthRoundTripper{bearer, source, rt}, nil
}
func (rt *bearerAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -281,7 +307,13 @@ func (rt *bearerAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response,
}
req = utilnet.CloneRequest(req)
- req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", rt.bearer))
+ token := rt.bearer
+ if rt.source != nil {
+ if refreshedToken, err := rt.source.Token(); err == nil {
+ token = refreshedToken.AccessToken
+ }
+ }
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
return rt.rt.RoundTrip(req)
}
@@ -289,7 +321,7 @@ func (rt *bearerAuthRoundTripper) CancelRequest(req *http.Request) {
if canceler, ok := rt.rt.(requestCanceler); ok {
canceler.CancelRequest(req)
} else {
- glog.Errorf("CancelRequest not implemented")
+ klog.Errorf("CancelRequest not implemented by %T", rt.rt)
}
}
@@ -336,7 +368,7 @@ func (r *requestInfo) toCurl() string {
}
}
- return fmt.Sprintf("curl -k -v -X%s %s %s", r.RequestVerb, headers, r.RequestURL)
+ return fmt.Sprintf("curl -k -v -X%s %s '%s'", r.RequestVerb, headers, r.RequestURL)
}
// debuggingRoundTripper will display information about the requests passing
@@ -373,7 +405,7 @@ func (rt *debuggingRoundTripper) CancelRequest(req *http.Request) {
if canceler, ok := rt.delegatedRoundTripper.(requestCanceler); ok {
canceler.CancelRequest(req)
} else {
- glog.Errorf("CancelRequest not implemented")
+ klog.Errorf("CancelRequest not implemented by %T", rt.delegatedRoundTripper)
}
}
@@ -381,17 +413,17 @@ func (rt *debuggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
reqInfo := newRequestInfo(req)
if rt.levels[debugJustURL] {
- glog.Infof("%s %s", reqInfo.RequestVerb, reqInfo.RequestURL)
+ klog.Infof("%s %s", reqInfo.RequestVerb, reqInfo.RequestURL)
}
if rt.levels[debugCurlCommand] {
- glog.Infof("%s", reqInfo.toCurl())
+ klog.Infof("%s", reqInfo.toCurl())
}
if rt.levels[debugRequestHeaders] {
- glog.Infof("Request Headers:")
+ klog.Infof("Request Headers:")
for key, values := range reqInfo.RequestHeaders {
for _, value := range values {
- glog.Infof(" %s: %s", key, value)
+ klog.Infof(" %s: %s", key, value)
}
}
}
@@ -403,16 +435,16 @@ func (rt *debuggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
reqInfo.complete(response, err)
if rt.levels[debugURLTiming] {
- glog.Infof("%s %s %s in %d milliseconds", reqInfo.RequestVerb, reqInfo.RequestURL, reqInfo.ResponseStatus, reqInfo.Duration.Nanoseconds()/int64(time.Millisecond))
+ klog.Infof("%s %s %s in %d milliseconds", reqInfo.RequestVerb, reqInfo.RequestURL, reqInfo.ResponseStatus, reqInfo.Duration.Nanoseconds()/int64(time.Millisecond))
}
if rt.levels[debugResponseStatus] {
- glog.Infof("Response Status: %s in %d milliseconds", reqInfo.ResponseStatus, reqInfo.Duration.Nanoseconds()/int64(time.Millisecond))
+ klog.Infof("Response Status: %s in %d milliseconds", reqInfo.ResponseStatus, reqInfo.Duration.Nanoseconds()/int64(time.Millisecond))
}
if rt.levels[debugResponseHeaders] {
- glog.Infof("Response Headers:")
+ klog.Infof("Response Headers:")
for key, values := range reqInfo.ResponseHeaders {
for _, value := range values {
- glog.Infof(" %s: %s", key, value)
+ klog.Infof(" %s: %s", key, value)
}
}
}
@@ -435,7 +467,7 @@ func shouldEscape(b byte) bool {
}
func headerKeyEscape(key string) string {
- var buf bytes.Buffer
+ buf := strings.Builder{}
for i := 0; i < len(key); i++ {
b := key[i]
if shouldEscape(b) {
diff --git a/vendor/k8s.io/client-go/transport/spdy/spdy.go b/vendor/k8s.io/client-go/transport/spdy/spdy.go
index e0eb468ba..53cc7ee18 100644
--- a/vendor/k8s.io/client-go/transport/spdy/spdy.go
+++ b/vendor/k8s.io/client-go/transport/spdy/spdy.go
@@ -38,7 +38,7 @@ func RoundTripperFor(config *restclient.Config) (http.RoundTripper, Upgrader, er
if err != nil {
return nil, nil, err
}
- upgradeRoundTripper := spdy.NewRoundTripper(tlsConfig, true)
+ upgradeRoundTripper := spdy.NewRoundTripper(tlsConfig, true, false)
wrapper, err := restclient.HTTPWrappersForConfig(config, upgradeRoundTripper)
if err != nil {
return nil, nil, err
diff --git a/vendor/k8s.io/client-go/transport/token_source.go b/vendor/k8s.io/client-go/transport/token_source.go
new file mode 100644
index 000000000..b8cadd382
--- /dev/null
+++ b/vendor/k8s.io/client-go/transport/token_source.go
@@ -0,0 +1,149 @@
+/*
+Copyright 2018 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package transport
+
+import (
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "strings"
+ "sync"
+ "time"
+
+ "golang.org/x/oauth2"
+ "k8s.io/klog"
+)
+
+// TokenSourceWrapTransport returns a WrapTransport that injects bearer tokens
+// authentication from an oauth2.TokenSource.
+func TokenSourceWrapTransport(ts oauth2.TokenSource) func(http.RoundTripper) http.RoundTripper {
+ return func(rt http.RoundTripper) http.RoundTripper {
+ return &tokenSourceTransport{
+ base: rt,
+ ort: &oauth2.Transport{
+ Source: ts,
+ Base: rt,
+ },
+ }
+ }
+}
+
+// NewCachedFileTokenSource returns a oauth2.TokenSource reads a token from a
+// file at a specified path and periodically reloads it.
+func NewCachedFileTokenSource(path string) oauth2.TokenSource {
+ return &cachingTokenSource{
+ now: time.Now,
+ leeway: 10 * time.Second,
+ base: &fileTokenSource{
+ path: path,
+ // This period was picked because it is half of the duration between when the kubelet
+ // refreshes a projected service account token and when the original token expires.
+ // Default token lifetime is 10 minutes, and the kubelet starts refreshing at 80% of lifetime.
+ // This should induce re-reading at a frequency that works with the token volume source.
+ period: time.Minute,
+ },
+ }
+}
+
+// NewCachedTokenSource returns a oauth2.TokenSource reads a token from a
+// designed TokenSource. The ts would provide the source of token.
+func NewCachedTokenSource(ts oauth2.TokenSource) oauth2.TokenSource {
+ return &cachingTokenSource{
+ now: time.Now,
+ base: ts,
+ }
+}
+
+type tokenSourceTransport struct {
+ base http.RoundTripper
+ ort http.RoundTripper
+}
+
+func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ // This is to allow --token to override other bearer token providers.
+ if req.Header.Get("Authorization") != "" {
+ return tst.base.RoundTrip(req)
+ }
+ return tst.ort.RoundTrip(req)
+}
+
+type fileTokenSource struct {
+ path string
+ period time.Duration
+}
+
+var _ = oauth2.TokenSource(&fileTokenSource{})
+
+func (ts *fileTokenSource) Token() (*oauth2.Token, error) {
+ tokb, err := ioutil.ReadFile(ts.path)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read token file %q: %v", ts.path, err)
+ }
+ tok := strings.TrimSpace(string(tokb))
+ if len(tok) == 0 {
+ return nil, fmt.Errorf("read empty token from file %q", ts.path)
+ }
+
+ return &oauth2.Token{
+ AccessToken: tok,
+ Expiry: time.Now().Add(ts.period),
+ }, nil
+}
+
+type cachingTokenSource struct {
+ base oauth2.TokenSource
+ leeway time.Duration
+
+ sync.RWMutex
+ tok *oauth2.Token
+
+ // for testing
+ now func() time.Time
+}
+
+var _ = oauth2.TokenSource(&cachingTokenSource{})
+
+func (ts *cachingTokenSource) Token() (*oauth2.Token, error) {
+ now := ts.now()
+ // fast path
+ ts.RLock()
+ tok := ts.tok
+ ts.RUnlock()
+
+ if tok != nil && tok.Expiry.Add(-1*ts.leeway).After(now) {
+ return tok, nil
+ }
+
+ // slow path
+ ts.Lock()
+ defer ts.Unlock()
+ if tok := ts.tok; tok != nil && tok.Expiry.Add(-1*ts.leeway).After(now) {
+ return tok, nil
+ }
+
+ tok, err := ts.base.Token()
+ if err != nil {
+ if ts.tok == nil {
+ return nil, err
+ }
+ klog.Errorf("Unable to rotate token: %v", err)
+ return ts.tok, nil
+ }
+
+ ts.tok = tok
+ return tok, nil
+}
diff --git a/vendor/k8s.io/client-go/transport/transport.go b/vendor/k8s.io/client-go/transport/transport.go
index c2bb7ae5e..2a145c971 100644
--- a/vendor/k8s.io/client-go/transport/transport.go
+++ b/vendor/k8s.io/client-go/transport/transport.go
@@ -17,6 +17,7 @@ limitations under the License.
package transport
import (
+ "context"
"crypto/tls"
"crypto/x509"
"fmt"
@@ -28,7 +29,7 @@ import (
// or transport level security defined by the provided Config.
func New(config *Config) (http.RoundTripper, error) {
// Set transport level security
- if config.Transport != nil && (config.HasCA() || config.HasCertAuth() || config.TLS.Insecure) {
+ if config.Transport != nil && (config.HasCA() || config.HasCertAuth() || config.HasCertCallback() || config.TLS.Insecure) {
return nil, fmt.Errorf("using a custom transport with TLS certificate options or the insecure flag is not allowed")
}
@@ -52,7 +53,7 @@ func New(config *Config) (http.RoundTripper, error) {
// TLSConfigFor returns a tls.Config that will provide the transport level security defined
// by the provided Config. Will return nil if no transport level security is requested.
func TLSConfigFor(c *Config) (*tls.Config, error) {
- if !(c.HasCA() || c.HasCertAuth() || c.TLS.Insecure || len(c.TLS.ServerName) > 0) {
+ if !(c.HasCA() || c.HasCertAuth() || c.HasCertCallback() || c.TLS.Insecure || len(c.TLS.ServerName) > 0) {
return nil, nil
}
if c.HasCA() && c.TLS.Insecure {
@@ -75,12 +76,40 @@ func TLSConfigFor(c *Config) (*tls.Config, error) {
tlsConfig.RootCAs = rootCertPool(c.TLS.CAData)
}
+ var staticCert *tls.Certificate
if c.HasCertAuth() {
+ // If key/cert were provided, verify them before setting up
+ // tlsConfig.GetClientCertificate.
cert, err := tls.X509KeyPair(c.TLS.CertData, c.TLS.KeyData)
if err != nil {
return nil, err
}
- tlsConfig.Certificates = []tls.Certificate{cert}
+ staticCert = &cert
+ }
+
+ if c.HasCertAuth() || c.HasCertCallback() {
+ tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
+ // Note: static key/cert data always take precedence over cert
+ // callback.
+ if staticCert != nil {
+ return staticCert, nil
+ }
+ if c.HasCertCallback() {
+ cert, err := c.TLS.GetCert()
+ if err != nil {
+ return nil, err
+ }
+ // GetCert may return empty value, meaning no cert.
+ if cert != nil {
+ return cert, nil
+ }
+ }
+
+ // Both c.TLS.CertData/KeyData were unset and GetCert didn't return
+ // anything. Return an empty tls.Certificate, no client cert will
+ // be sent to the server.
+ return &tls.Certificate{}, nil
+ }
}
return tlsConfig, nil
@@ -139,3 +168,60 @@ func rootCertPool(caData []byte) *x509.CertPool {
certPool.AppendCertsFromPEM(caData)
return certPool
}
+
+// WrapperFunc wraps an http.RoundTripper when a new transport
+// is created for a client, allowing per connection behavior
+// to be injected.
+type WrapperFunc func(rt http.RoundTripper) http.RoundTripper
+
+// Wrappers accepts any number of wrappers and returns a wrapper
+// function that is the equivalent of calling each of them in order. Nil
+// values are ignored, which makes this function convenient for incrementally
+// wrapping a function.
+func Wrappers(fns ...WrapperFunc) WrapperFunc {
+ if len(fns) == 0 {
+ return nil
+ }
+ // optimize the common case of wrapping a possibly nil transport wrapper
+ // with an additional wrapper
+ if len(fns) == 2 && fns[0] == nil {
+ return fns[1]
+ }
+ return func(rt http.RoundTripper) http.RoundTripper {
+ base := rt
+ for _, fn := range fns {
+ if fn != nil {
+ base = fn(base)
+ }
+ }
+ return base
+ }
+}
+
+// ContextCanceller prevents new requests after the provided context is finished.
+// err is returned when the context is closed, allowing the caller to provide a context
+// appropriate error.
+func ContextCanceller(ctx context.Context, err error) WrapperFunc {
+ return func(rt http.RoundTripper) http.RoundTripper {
+ return &contextCanceller{
+ ctx: ctx,
+ rt: rt,
+ err: err,
+ }
+ }
+}
+
+type contextCanceller struct {
+ ctx context.Context
+ rt http.RoundTripper
+ err error
+}
+
+func (b *contextCanceller) RoundTrip(req *http.Request) (*http.Response, error) {
+ select {
+ case <-b.ctx.Done():
+ return nil, b.err
+ default:
+ return b.rt.RoundTrip(req)
+ }
+}