summaryrefslogtreecommitdiff
path: root/pkg/bindings/connection.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/bindings/connection.go')
-rw-r--r--pkg/bindings/connection.go85
1 files changed, 48 insertions, 37 deletions
diff --git a/pkg/bindings/connection.go b/pkg/bindings/connection.go
index b2e949f67..332aa97c8 100644
--- a/pkg/bindings/connection.go
+++ b/pkg/bindings/connection.go
@@ -15,7 +15,6 @@ import (
"github.com/blang/semver"
"github.com/containers/podman/v3/pkg/terminal"
"github.com/containers/podman/v3/version"
- jsoniter "github.com/json-iterator/go"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
@@ -35,16 +34,24 @@ type Connection struct {
type valueKey string
const (
- clientKey = valueKey("Client")
+ clientKey = valueKey("Client")
+ versionKey = valueKey("ServiceVersion")
)
// GetClient from context build by NewConnection()
func GetClient(ctx context.Context) (*Connection, error) {
- c, ok := ctx.Value(clientKey).(*Connection)
- if !ok {
- return nil, errors.Errorf("ClientKey not set in context")
+ if c, ok := ctx.Value(clientKey).(*Connection); ok {
+ return c, nil
}
- return c, nil
+ return nil, errors.Errorf("%s not set in context", clientKey)
+}
+
+// ServiceVersion from context build by NewConnection()
+func ServiceVersion(ctx context.Context) *semver.Version {
+ if v, ok := ctx.Value(versionKey).(*semver.Version); ok {
+ return v
+ }
+ return new(semver.Version)
}
// JoinURL elements with '/'
@@ -52,6 +59,7 @@ func JoinURL(elements ...string) string {
return "/" + strings.Join(elements, "/")
}
+// NewConnection creates a new service connection without an identity
func NewConnection(ctx context.Context, uri string) (context.Context, error) {
return NewConnectionWithIdentity(ctx, uri, "")
}
@@ -116,9 +124,11 @@ func NewConnectionWithIdentity(ctx context.Context, uri string, identity string)
}
ctx = context.WithValue(ctx, clientKey, &connection)
- if err := pingNewConnection(ctx); err != nil {
+ serviceVersion, err := pingNewConnection(ctx)
+ if err != nil {
return nil, errors.Wrap(err, "unable to connect to Podman socket")
}
+ ctx = context.WithValue(ctx, versionKey, serviceVersion)
return ctx, nil
}
@@ -139,15 +149,15 @@ func tcpClient(_url *url.URL) Connection {
// pingNewConnection pings to make sure the RESTFUL service is up
// and running. it should only be used when initializing a connection
-func pingNewConnection(ctx context.Context) error {
+func pingNewConnection(ctx context.Context) (*semver.Version, error) {
client, err := GetClient(ctx)
if err != nil {
- return err
+ return nil, err
}
// the ping endpoint sits at / in this case
response, err := client.DoRequest(ctx, nil, http.MethodGet, "/_ping", nil, nil)
if err != nil {
- return err
+ return nil, err
}
defer response.Body.Close()
@@ -155,23 +165,23 @@ func pingNewConnection(ctx context.Context) error {
versionHdr := response.Header.Get("Libpod-API-Version")
if versionHdr == "" {
logrus.Info("Service did not provide Libpod-API-Version Header")
- return nil
+ return new(semver.Version), nil
}
versionSrv, err := semver.ParseTolerant(versionHdr)
if err != nil {
- return err
+ return nil, err
}
switch version.APIVersion[version.Libpod][version.MinimalAPI].Compare(versionSrv) {
case -1, 0:
// Server's job when Client version is equal or older
- return nil
+ return &versionSrv, nil
case 1:
- return errors.Errorf("server API version is too old. Client %q server %q",
+ return nil, errors.Errorf("server API version is too old. Client %q server %q",
version.APIVersion[version.Libpod][version.MinimalAPI].String(), versionSrv.String())
}
}
- return errors.Errorf("ping response was %d", response.StatusCode)
+ return nil, errors.Errorf("ping response was %d", response.StatusCode)
}
func sshClient(_url *url.URL, secure bool, passPhrase string, identity string) (Connection, error) {
@@ -306,26 +316,29 @@ func unixClient(_url *url.URL) Connection {
}
// DoRequest assembles the http request and returns the response
-func (c *Connection) DoRequest(ctx context.Context, httpBody io.Reader, httpMethod, endpoint string, queryParams url.Values, header map[string]string, pathValues ...string) (*APIResponse, error) {
+func (c *Connection) DoRequest(ctx context.Context, httpBody io.Reader, httpMethod, endpoint string, queryParams url.Values, headers http.Header, pathValues ...string) (*APIResponse, error) {
var (
err error
response *http.Response
)
- params := make([]interface{}, len(pathValues)+3)
+ params := make([]interface{}, len(pathValues)+1)
+
+ if v := headers.Values("API-Version"); len(v) > 0 {
+ params[0] = v[0]
+ } else {
+ // Including the semver suffices breaks older services... so do not include them
+ v := version.APIVersion[version.Libpod][version.CurrentAPI]
+ params[0] = fmt.Sprintf("%d.%d.%d", v.Major, v.Minor, v.Patch)
+ }
- // Including the semver suffices breaks older services... so do not include them
- v := version.APIVersion[version.Libpod][version.CurrentAPI]
- params[0] = v.Major
- params[1] = v.Minor
- params[2] = v.Patch
for i, pv := range pathValues {
// url.URL lacks the semantics for escaping embedded path parameters... so we manually
// escape each one and assume the caller included the correct formatting in "endpoint"
- params[i+3] = url.PathEscape(pv)
+ params[i+1] = url.PathEscape(pv)
}
- uri := fmt.Sprintf("http://d/v%d.%d.%d/libpod"+endpoint, params...)
+ uri := fmt.Sprintf("http://d/v%s/libpod"+endpoint, params...)
logrus.Debugf("DoRequest Method: %s URI: %v", httpMethod, uri)
req, err := http.NewRequestWithContext(ctx, httpMethod, uri, httpBody)
@@ -335,9 +348,17 @@ func (c *Connection) DoRequest(ctx context.Context, httpBody io.Reader, httpMeth
if len(queryParams) > 0 {
req.URL.RawQuery = queryParams.Encode()
}
- for key, val := range header {
- req.Header.Set(key, val)
+
+ for key, val := range headers {
+ if key == "API-Version" {
+ continue
+ }
+
+ for _, v := range val {
+ req.Header.Add(key, v)
+ }
}
+
// Give the Do three chances in the case of a comm/service hiccup
for i := 1; i <= 3; i++ {
response, err = c.Client.Do(req) // nolint
@@ -349,7 +370,7 @@ func (c *Connection) DoRequest(ctx context.Context, httpBody io.Reader, httpMeth
return &APIResponse{response, req}, err
}
-// Get raw Transport.DialContext from client
+// GetDialer returns raw Transport.DialContext from client
func (c *Connection) GetDialer(ctx context.Context) (net.Conn, error) {
client := c.Client
transport := client.Transport.(*http.Transport)
@@ -360,16 +381,6 @@ func (c *Connection) GetDialer(ctx context.Context) (net.Conn, error) {
return nil, errors.New("Unable to get dial context")
}
-// FiltersToString converts our typical filter format of a
-// map[string][]string to a query/html safe string.
-func FiltersToString(filters map[string][]string) (string, error) {
- lowerCaseKeys := make(map[string][]string)
- for k, v := range filters {
- lowerCaseKeys[strings.ToLower(k)] = v
- }
- return jsoniter.MarshalToString(lowerCaseKeys)
-}
-
// IsInformational returns true if the response code is 1xx
func (h *APIResponse) IsInformational() bool {
return h.Response.StatusCode/100 == 1