diff options
author | OpenShift Merge Robot <openshift-merge-robot@users.noreply.github.com> | 2019-06-25 21:40:38 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-06-25 21:40:38 +0200 |
commit | 5b7086abda91f4301af3bfb642d416a22349c276 (patch) | |
tree | bf139f29b261e55c161394637f1c7073da5103f0 /vendor/k8s.io/client-go/transport | |
parent | a488e197a6e3947dd420b40ed834b50db9c829c3 (diff) | |
parent | 2388222e98462fdbbe44f3e091b2b79d80956a9a (diff) | |
download | podman-5b7086abda91f4301af3bfb642d416a22349c276.tar.gz podman-5b7086abda91f4301af3bfb642d416a22349c276.tar.bz2 podman-5b7086abda91f4301af3bfb642d416a22349c276.zip |
Merge pull request #3418 from vrothberg/go-modules
update dependencies
Diffstat (limited to 'vendor/k8s.io/client-go/transport')
-rw-r--r-- | vendor/k8s.io/client-go/transport/OWNERS | 2 | ||||
-rw-r--r-- | vendor/k8s.io/client-go/transport/cache.go | 8 | ||||
-rw-r--r-- | vendor/k8s.io/client-go/transport/config.go | 33 | ||||
-rw-r--r-- | vendor/k8s.io/client-go/transport/round_trippers.go | 82 | ||||
-rw-r--r-- | vendor/k8s.io/client-go/transport/spdy/spdy.go | 2 | ||||
-rw-r--r-- | vendor/k8s.io/client-go/transport/token_source.go | 149 | ||||
-rw-r--r-- | vendor/k8s.io/client-go/transport/transport.go | 92 |
7 files changed, 332 insertions, 36 deletions
diff --git a/vendor/k8s.io/client-go/transport/OWNERS b/vendor/k8s.io/client-go/transport/OWNERS index bf0ba5b9f..a52176903 100644 --- a/vendor/k8s.io/client-go/transport/OWNERS +++ b/vendor/k8s.io/client-go/transport/OWNERS @@ -1,3 +1,5 @@ +# See the OWNERS docs at https://go.k8s.io/owners + reviewers: - smarterclayton - wojtek-t diff --git a/vendor/k8s.io/client-go/transport/cache.go b/vendor/k8s.io/client-go/transport/cache.go index 83291c575..7cffe2a5f 100644 --- a/vendor/k8s.io/client-go/transport/cache.go +++ b/vendor/k8s.io/client-go/transport/cache.go @@ -43,6 +43,7 @@ type tlsCacheKey struct { caData string certData string keyData string + getCert string serverName string dial string } @@ -52,7 +53,7 @@ func (t tlsCacheKey) String() string { if len(t.keyData) > 0 { keyText = "<redacted>" } - return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, dial:%s", t.insecure, t.caData, t.certData, keyText, t.serverName, t.dial) + return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, getCert: %s, serverName:%s, dial:%s", t.insecure, t.caData, t.certData, keyText, t.getCert, t.serverName, t.dial) } func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { @@ -85,7 +86,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { dial = (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, - }).Dial + }).DialContext } // Cache a single transport for these options c.transports[key] = utilnet.SetTransportDefaults(&http.Transport{ @@ -93,7 +94,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: tlsConfig, MaxIdleConnsPerHost: idleConnsPerHost, - Dial: dial, + DialContext: dial, }) return c.transports[key], nil } @@ -109,6 +110,7 @@ func tlsConfigKey(c *Config) (tlsCacheKey, error) { caData: string(c.TLS.CAData), certData: string(c.TLS.CertData), keyData: string(c.TLS.KeyData), + getCert: fmt.Sprintf("%p", c.TLS.GetCert), serverName: c.TLS.ServerName, dial: fmt.Sprintf("%p", c.Dial), }, nil diff --git a/vendor/k8s.io/client-go/transport/config.go b/vendor/k8s.io/client-go/transport/config.go index af347dafe..5de0a2cb1 100644 --- a/vendor/k8s.io/client-go/transport/config.go +++ b/vendor/k8s.io/client-go/transport/config.go @@ -17,6 +17,8 @@ limitations under the License. package transport import ( + "context" + "crypto/tls" "net" "net/http" ) @@ -37,6 +39,11 @@ type Config struct { // Bearer token for authentication BearerToken string + // Path to a file containing a BearerToken. + // If set, the contents are periodically read. + // The last successfully read value takes precedence over BearerToken. + BearerTokenFile string + // Impersonate is the config that this Config will impersonate using Impersonate ImpersonationConfig @@ -50,10 +57,13 @@ type Config struct { // from TLSClientConfig, Transport, or http.DefaultTransport). The // config may layer other RoundTrippers on top of the returned // RoundTripper. - WrapTransport func(rt http.RoundTripper) http.RoundTripper + // + // A future release will change this field to an array. Use config.Wrap() + // instead of setting this value directly. + WrapTransport WrapperFunc // Dial specifies the dial function for creating unencrypted TCP connections. - Dial func(network, addr string) (net.Conn, error) + Dial func(ctx context.Context, network, address string) (net.Conn, error) } // ImpersonationConfig has all the available impersonation options @@ -78,12 +88,25 @@ func (c *Config) HasBasicAuth() bool { // HasTokenAuth returns whether the configuration has token authentication or not. func (c *Config) HasTokenAuth() bool { - return len(c.BearerToken) != 0 + return len(c.BearerToken) != 0 || len(c.BearerTokenFile) != 0 } // HasCertAuth returns whether the configuration has certificate authentication or not. func (c *Config) HasCertAuth() bool { - return len(c.TLS.CertData) != 0 || len(c.TLS.CertFile) != 0 + return (len(c.TLS.CertData) != 0 || len(c.TLS.CertFile) != 0) && (len(c.TLS.KeyData) != 0 || len(c.TLS.KeyFile) != 0) +} + +// HasCertCallbacks returns whether the configuration has certificate callback or not. +func (c *Config) HasCertCallback() bool { + return c.TLS.GetCert != nil +} + +// Wrap adds a transport middleware function that will give the caller +// an opportunity to wrap the underlying http.RoundTripper prior to the +// first API call being made. The provided function is invoked after any +// existing transport wrappers are invoked. +func (c *Config) Wrap(fn WrapperFunc) { + c.WrapTransport = Wrappers(c.WrapTransport, fn) } // TLSConfig holds the information needed to set up a TLS transport. @@ -98,4 +121,6 @@ type TLSConfig struct { CAData []byte // Bytes of the PEM-encoded server trusted root certificates. Supercedes CAFile. CertData []byte // Bytes of the PEM-encoded client certificate. Supercedes CertFile. KeyData []byte // Bytes of the PEM-encoded client key. Supercedes KeyFile. + + GetCert func() (*tls.Certificate, error) // Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field. } diff --git a/vendor/k8s.io/client-go/transport/round_trippers.go b/vendor/k8s.io/client-go/transport/round_trippers.go index de64e0078..117a9c8c4 100644 --- a/vendor/k8s.io/client-go/transport/round_trippers.go +++ b/vendor/k8s.io/client-go/transport/round_trippers.go @@ -17,13 +17,13 @@ limitations under the License. package transport import ( - "bytes" "fmt" "net/http" "strings" "time" - "github.com/golang/glog" + "golang.org/x/oauth2" + "k8s.io/klog" utilnet "k8s.io/apimachinery/pkg/util/net" ) @@ -45,7 +45,11 @@ func HTTPWrappersForConfig(config *Config, rt http.RoundTripper) (http.RoundTrip case config.HasBasicAuth() && config.HasTokenAuth(): return nil, fmt.Errorf("username/password or bearer token may be set, but not both") case config.HasTokenAuth(): - rt = NewBearerAuthRoundTripper(config.BearerToken, rt) + var err error + rt, err = NewBearerAuthWithRefreshRoundTripper(config.BearerToken, config.BearerTokenFile, rt) + if err != nil { + return nil, err + } case config.HasBasicAuth(): rt = NewBasicAuthRoundTripper(config.Username, config.Password, rt) } @@ -63,13 +67,13 @@ func HTTPWrappersForConfig(config *Config, rt http.RoundTripper) (http.RoundTrip // DebugWrappers wraps a round tripper and logs based on the current log level. func DebugWrappers(rt http.RoundTripper) http.RoundTripper { switch { - case bool(glog.V(9)): + case bool(klog.V(9)): rt = newDebuggingRoundTripper(rt, debugCurlCommand, debugURLTiming, debugResponseHeaders) - case bool(glog.V(8)): + case bool(klog.V(8)): rt = newDebuggingRoundTripper(rt, debugJustURL, debugRequestHeaders, debugResponseStatus, debugResponseHeaders) - case bool(glog.V(7)): + case bool(klog.V(7)): rt = newDebuggingRoundTripper(rt, debugJustURL, debugRequestHeaders, debugResponseStatus) - case bool(glog.V(6)): + case bool(klog.V(6)): rt = newDebuggingRoundTripper(rt, debugURLTiming) } @@ -139,7 +143,7 @@ func (rt *authProxyRoundTripper) CancelRequest(req *http.Request) { if canceler, ok := rt.rt.(requestCanceler); ok { canceler.CancelRequest(req) } else { - glog.Errorf("CancelRequest not implemented") + klog.Errorf("CancelRequest not implemented by %T", rt.rt) } } @@ -167,7 +171,7 @@ func (rt *userAgentRoundTripper) CancelRequest(req *http.Request) { if canceler, ok := rt.rt.(requestCanceler); ok { canceler.CancelRequest(req) } else { - glog.Errorf("CancelRequest not implemented") + klog.Errorf("CancelRequest not implemented by %T", rt.rt) } } @@ -198,7 +202,7 @@ func (rt *basicAuthRoundTripper) CancelRequest(req *http.Request) { if canceler, ok := rt.rt.(requestCanceler); ok { canceler.CancelRequest(req) } else { - glog.Errorf("CancelRequest not implemented") + klog.Errorf("CancelRequest not implemented by %T", rt.rt) } } @@ -258,7 +262,7 @@ func (rt *impersonatingRoundTripper) CancelRequest(req *http.Request) { if canceler, ok := rt.delegate.(requestCanceler); ok { canceler.CancelRequest(req) } else { - glog.Errorf("CancelRequest not implemented") + klog.Errorf("CancelRequest not implemented by %T", rt.delegate) } } @@ -266,13 +270,35 @@ func (rt *impersonatingRoundTripper) WrappedRoundTripper() http.RoundTripper { r type bearerAuthRoundTripper struct { bearer string + source oauth2.TokenSource rt http.RoundTripper } // NewBearerAuthRoundTripper adds the provided bearer token to a request // unless the authorization header has already been set. func NewBearerAuthRoundTripper(bearer string, rt http.RoundTripper) http.RoundTripper { - return &bearerAuthRoundTripper{bearer, rt} + return &bearerAuthRoundTripper{bearer, nil, rt} +} + +// NewBearerAuthRoundTripper adds the provided bearer token to a request +// unless the authorization header has already been set. +// If tokenFile is non-empty, it is periodically read, +// and the last successfully read content is used as the bearer token. +// If tokenFile is non-empty and bearer is empty, the tokenFile is read +// immediately to populate the initial bearer token. +func NewBearerAuthWithRefreshRoundTripper(bearer string, tokenFile string, rt http.RoundTripper) (http.RoundTripper, error) { + if len(tokenFile) == 0 { + return &bearerAuthRoundTripper{bearer, nil, rt}, nil + } + source := NewCachedFileTokenSource(tokenFile) + if len(bearer) == 0 { + token, err := source.Token() + if err != nil { + return nil, err + } + bearer = token.AccessToken + } + return &bearerAuthRoundTripper{bearer, source, rt}, nil } func (rt *bearerAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { @@ -281,7 +307,13 @@ func (rt *bearerAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, } req = utilnet.CloneRequest(req) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", rt.bearer)) + token := rt.bearer + if rt.source != nil { + if refreshedToken, err := rt.source.Token(); err == nil { + token = refreshedToken.AccessToken + } + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) return rt.rt.RoundTrip(req) } @@ -289,7 +321,7 @@ func (rt *bearerAuthRoundTripper) CancelRequest(req *http.Request) { if canceler, ok := rt.rt.(requestCanceler); ok { canceler.CancelRequest(req) } else { - glog.Errorf("CancelRequest not implemented") + klog.Errorf("CancelRequest not implemented by %T", rt.rt) } } @@ -336,7 +368,7 @@ func (r *requestInfo) toCurl() string { } } - return fmt.Sprintf("curl -k -v -X%s %s %s", r.RequestVerb, headers, r.RequestURL) + return fmt.Sprintf("curl -k -v -X%s %s '%s'", r.RequestVerb, headers, r.RequestURL) } // debuggingRoundTripper will display information about the requests passing @@ -373,7 +405,7 @@ func (rt *debuggingRoundTripper) CancelRequest(req *http.Request) { if canceler, ok := rt.delegatedRoundTripper.(requestCanceler); ok { canceler.CancelRequest(req) } else { - glog.Errorf("CancelRequest not implemented") + klog.Errorf("CancelRequest not implemented by %T", rt.delegatedRoundTripper) } } @@ -381,17 +413,17 @@ func (rt *debuggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, e reqInfo := newRequestInfo(req) if rt.levels[debugJustURL] { - glog.Infof("%s %s", reqInfo.RequestVerb, reqInfo.RequestURL) + klog.Infof("%s %s", reqInfo.RequestVerb, reqInfo.RequestURL) } if rt.levels[debugCurlCommand] { - glog.Infof("%s", reqInfo.toCurl()) + klog.Infof("%s", reqInfo.toCurl()) } if rt.levels[debugRequestHeaders] { - glog.Infof("Request Headers:") + klog.Infof("Request Headers:") for key, values := range reqInfo.RequestHeaders { for _, value := range values { - glog.Infof(" %s: %s", key, value) + klog.Infof(" %s: %s", key, value) } } } @@ -403,16 +435,16 @@ func (rt *debuggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, e reqInfo.complete(response, err) if rt.levels[debugURLTiming] { - glog.Infof("%s %s %s in %d milliseconds", reqInfo.RequestVerb, reqInfo.RequestURL, reqInfo.ResponseStatus, reqInfo.Duration.Nanoseconds()/int64(time.Millisecond)) + klog.Infof("%s %s %s in %d milliseconds", reqInfo.RequestVerb, reqInfo.RequestURL, reqInfo.ResponseStatus, reqInfo.Duration.Nanoseconds()/int64(time.Millisecond)) } if rt.levels[debugResponseStatus] { - glog.Infof("Response Status: %s in %d milliseconds", reqInfo.ResponseStatus, reqInfo.Duration.Nanoseconds()/int64(time.Millisecond)) + klog.Infof("Response Status: %s in %d milliseconds", reqInfo.ResponseStatus, reqInfo.Duration.Nanoseconds()/int64(time.Millisecond)) } if rt.levels[debugResponseHeaders] { - glog.Infof("Response Headers:") + klog.Infof("Response Headers:") for key, values := range reqInfo.ResponseHeaders { for _, value := range values { - glog.Infof(" %s: %s", key, value) + klog.Infof(" %s: %s", key, value) } } } @@ -435,7 +467,7 @@ func shouldEscape(b byte) bool { } func headerKeyEscape(key string) string { - var buf bytes.Buffer + buf := strings.Builder{} for i := 0; i < len(key); i++ { b := key[i] if shouldEscape(b) { diff --git a/vendor/k8s.io/client-go/transport/spdy/spdy.go b/vendor/k8s.io/client-go/transport/spdy/spdy.go index e0eb468ba..53cc7ee18 100644 --- a/vendor/k8s.io/client-go/transport/spdy/spdy.go +++ b/vendor/k8s.io/client-go/transport/spdy/spdy.go @@ -38,7 +38,7 @@ func RoundTripperFor(config *restclient.Config) (http.RoundTripper, Upgrader, er if err != nil { return nil, nil, err } - upgradeRoundTripper := spdy.NewRoundTripper(tlsConfig, true) + upgradeRoundTripper := spdy.NewRoundTripper(tlsConfig, true, false) wrapper, err := restclient.HTTPWrappersForConfig(config, upgradeRoundTripper) if err != nil { return nil, nil, err diff --git a/vendor/k8s.io/client-go/transport/token_source.go b/vendor/k8s.io/client-go/transport/token_source.go new file mode 100644 index 000000000..b8cadd382 --- /dev/null +++ b/vendor/k8s.io/client-go/transport/token_source.go @@ -0,0 +1,149 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package transport + +import ( + "fmt" + "io/ioutil" + "net/http" + "strings" + "sync" + "time" + + "golang.org/x/oauth2" + "k8s.io/klog" +) + +// TokenSourceWrapTransport returns a WrapTransport that injects bearer tokens +// authentication from an oauth2.TokenSource. +func TokenSourceWrapTransport(ts oauth2.TokenSource) func(http.RoundTripper) http.RoundTripper { + return func(rt http.RoundTripper) http.RoundTripper { + return &tokenSourceTransport{ + base: rt, + ort: &oauth2.Transport{ + Source: ts, + Base: rt, + }, + } + } +} + +// NewCachedFileTokenSource returns a oauth2.TokenSource reads a token from a +// file at a specified path and periodically reloads it. +func NewCachedFileTokenSource(path string) oauth2.TokenSource { + return &cachingTokenSource{ + now: time.Now, + leeway: 10 * time.Second, + base: &fileTokenSource{ + path: path, + // This period was picked because it is half of the duration between when the kubelet + // refreshes a projected service account token and when the original token expires. + // Default token lifetime is 10 minutes, and the kubelet starts refreshing at 80% of lifetime. + // This should induce re-reading at a frequency that works with the token volume source. + period: time.Minute, + }, + } +} + +// NewCachedTokenSource returns a oauth2.TokenSource reads a token from a +// designed TokenSource. The ts would provide the source of token. +func NewCachedTokenSource(ts oauth2.TokenSource) oauth2.TokenSource { + return &cachingTokenSource{ + now: time.Now, + base: ts, + } +} + +type tokenSourceTransport struct { + base http.RoundTripper + ort http.RoundTripper +} + +func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // This is to allow --token to override other bearer token providers. + if req.Header.Get("Authorization") != "" { + return tst.base.RoundTrip(req) + } + return tst.ort.RoundTrip(req) +} + +type fileTokenSource struct { + path string + period time.Duration +} + +var _ = oauth2.TokenSource(&fileTokenSource{}) + +func (ts *fileTokenSource) Token() (*oauth2.Token, error) { + tokb, err := ioutil.ReadFile(ts.path) + if err != nil { + return nil, fmt.Errorf("failed to read token file %q: %v", ts.path, err) + } + tok := strings.TrimSpace(string(tokb)) + if len(tok) == 0 { + return nil, fmt.Errorf("read empty token from file %q", ts.path) + } + + return &oauth2.Token{ + AccessToken: tok, + Expiry: time.Now().Add(ts.period), + }, nil +} + +type cachingTokenSource struct { + base oauth2.TokenSource + leeway time.Duration + + sync.RWMutex + tok *oauth2.Token + + // for testing + now func() time.Time +} + +var _ = oauth2.TokenSource(&cachingTokenSource{}) + +func (ts *cachingTokenSource) Token() (*oauth2.Token, error) { + now := ts.now() + // fast path + ts.RLock() + tok := ts.tok + ts.RUnlock() + + if tok != nil && tok.Expiry.Add(-1*ts.leeway).After(now) { + return tok, nil + } + + // slow path + ts.Lock() + defer ts.Unlock() + if tok := ts.tok; tok != nil && tok.Expiry.Add(-1*ts.leeway).After(now) { + return tok, nil + } + + tok, err := ts.base.Token() + if err != nil { + if ts.tok == nil { + return nil, err + } + klog.Errorf("Unable to rotate token: %v", err) + return ts.tok, nil + } + + ts.tok = tok + return tok, nil +} diff --git a/vendor/k8s.io/client-go/transport/transport.go b/vendor/k8s.io/client-go/transport/transport.go index c2bb7ae5e..2a145c971 100644 --- a/vendor/k8s.io/client-go/transport/transport.go +++ b/vendor/k8s.io/client-go/transport/transport.go @@ -17,6 +17,7 @@ limitations under the License. package transport import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -28,7 +29,7 @@ import ( // or transport level security defined by the provided Config. func New(config *Config) (http.RoundTripper, error) { // Set transport level security - if config.Transport != nil && (config.HasCA() || config.HasCertAuth() || config.TLS.Insecure) { + if config.Transport != nil && (config.HasCA() || config.HasCertAuth() || config.HasCertCallback() || config.TLS.Insecure) { return nil, fmt.Errorf("using a custom transport with TLS certificate options or the insecure flag is not allowed") } @@ -52,7 +53,7 @@ func New(config *Config) (http.RoundTripper, error) { // TLSConfigFor returns a tls.Config that will provide the transport level security defined // by the provided Config. Will return nil if no transport level security is requested. func TLSConfigFor(c *Config) (*tls.Config, error) { - if !(c.HasCA() || c.HasCertAuth() || c.TLS.Insecure || len(c.TLS.ServerName) > 0) { + if !(c.HasCA() || c.HasCertAuth() || c.HasCertCallback() || c.TLS.Insecure || len(c.TLS.ServerName) > 0) { return nil, nil } if c.HasCA() && c.TLS.Insecure { @@ -75,12 +76,40 @@ func TLSConfigFor(c *Config) (*tls.Config, error) { tlsConfig.RootCAs = rootCertPool(c.TLS.CAData) } + var staticCert *tls.Certificate if c.HasCertAuth() { + // If key/cert were provided, verify them before setting up + // tlsConfig.GetClientCertificate. cert, err := tls.X509KeyPair(c.TLS.CertData, c.TLS.KeyData) if err != nil { return nil, err } - tlsConfig.Certificates = []tls.Certificate{cert} + staticCert = &cert + } + + if c.HasCertAuth() || c.HasCertCallback() { + tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + // Note: static key/cert data always take precedence over cert + // callback. + if staticCert != nil { + return staticCert, nil + } + if c.HasCertCallback() { + cert, err := c.TLS.GetCert() + if err != nil { + return nil, err + } + // GetCert may return empty value, meaning no cert. + if cert != nil { + return cert, nil + } + } + + // Both c.TLS.CertData/KeyData were unset and GetCert didn't return + // anything. Return an empty tls.Certificate, no client cert will + // be sent to the server. + return &tls.Certificate{}, nil + } } return tlsConfig, nil @@ -139,3 +168,60 @@ func rootCertPool(caData []byte) *x509.CertPool { certPool.AppendCertsFromPEM(caData) return certPool } + +// WrapperFunc wraps an http.RoundTripper when a new transport +// is created for a client, allowing per connection behavior +// to be injected. +type WrapperFunc func(rt http.RoundTripper) http.RoundTripper + +// Wrappers accepts any number of wrappers and returns a wrapper +// function that is the equivalent of calling each of them in order. Nil +// values are ignored, which makes this function convenient for incrementally +// wrapping a function. +func Wrappers(fns ...WrapperFunc) WrapperFunc { + if len(fns) == 0 { + return nil + } + // optimize the common case of wrapping a possibly nil transport wrapper + // with an additional wrapper + if len(fns) == 2 && fns[0] == nil { + return fns[1] + } + return func(rt http.RoundTripper) http.RoundTripper { + base := rt + for _, fn := range fns { + if fn != nil { + base = fn(base) + } + } + return base + } +} + +// ContextCanceller prevents new requests after the provided context is finished. +// err is returned when the context is closed, allowing the caller to provide a context +// appropriate error. +func ContextCanceller(ctx context.Context, err error) WrapperFunc { + return func(rt http.RoundTripper) http.RoundTripper { + return &contextCanceller{ + ctx: ctx, + rt: rt, + err: err, + } + } +} + +type contextCanceller struct { + ctx context.Context + rt http.RoundTripper + err error +} + +func (b *contextCanceller) RoundTrip(req *http.Request) (*http.Response, error) { + select { + case <-b.ctx.Done(): + return nil, b.err + default: + return b.rt.RoundTrip(req) + } +} |