aboutsummaryrefslogtreecommitdiff
path: root/pkg/bindings/connection.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/bindings/connection.go')
-rw-r--r--pkg/bindings/connection.go105
1 files changed, 76 insertions, 29 deletions
diff --git a/pkg/bindings/connection.go b/pkg/bindings/connection.go
index da3755fc8..e9032f083 100644
--- a/pkg/bindings/connection.go
+++ b/pkg/bindings/connection.go
@@ -15,7 +15,7 @@ import (
"strings"
"time"
- "github.com/containers/libpod/pkg/api/types"
+ "github.com/blang/semver"
jsoniter "github.com/json-iterator/go"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
@@ -27,7 +27,7 @@ var (
basePath = &url.URL{
Scheme: "http",
Host: "d",
- Path: "/v" + types.MinimalAPIVersion + "/libpod",
+ Path: "/v" + APIVersion.String() + "/libpod",
}
)
@@ -39,6 +39,7 @@ type APIResponse struct {
type Connection struct {
_url *url.URL
client *http.Client
+ conn *net.Conn
}
type valueKey string
@@ -88,26 +89,26 @@ func NewConnection(ctx context.Context, uri string, identity ...string) (context
}
// Now we setup the http client to use the connection above
- var client *http.Client
+ var connection Connection
switch _url.Scheme {
case "ssh":
secure, err = strconv.ParseBool(_url.Query().Get("secure"))
if err != nil {
secure = false
}
- client, err = sshClient(_url, identity[0], secure)
+ connection, err = sshClient(_url, identity[0], secure)
case "unix":
if !strings.HasPrefix(uri, "unix:///") {
// autofix unix://path_element vs unix:///path_element
_url.Path = JoinURL(_url.Host, _url.Path)
_url.Host = ""
}
- client, err = unixClient(_url)
+ connection, err = unixClient(_url)
case "tcp":
if !strings.HasPrefix(uri, "tcp://") {
return nil, errors.New("tcp URIs should begin with tcp://")
}
- client, err = tcpClient(_url)
+ connection, err = tcpClient(_url)
default:
return nil, errors.Errorf("'%s' is not a supported schema", _url.Scheme)
}
@@ -115,46 +116,71 @@ func NewConnection(ctx context.Context, uri string, identity ...string) (context
return nil, errors.Wrapf(err, "Failed to create %sClient", _url.Scheme)
}
- ctx = context.WithValue(ctx, clientKey, &Connection{_url, client})
+ ctx = context.WithValue(ctx, clientKey, &connection)
if err := pingNewConnection(ctx); err != nil {
return nil, err
}
return ctx, nil
}
-func tcpClient(_url *url.URL) (*http.Client, error) {
- return &http.Client{
+func tcpClient(_url *url.URL) (Connection, error) {
+ connection := Connection{
+ _url: _url,
+ }
+ connection.client = &http.Client{
Transport: &http.Transport{
- DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
- return net.Dial("tcp", _url.Host)
+ DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
+ conn, err := net.Dial("tcp", _url.Host)
+ if c, ok := ctx.Value(clientKey).(*Connection); ok {
+ c.conn = &conn
+ }
+ return conn, err
},
DisableCompression: true,
},
- }, nil
+ }
+ return connection, nil
}
// pingNewConnection pings to make sure the RESTFUL service is up
-// and running. it should only be used where initializing a connection
+// and running. it should only be used when initializing a connection
func pingNewConnection(ctx context.Context) error {
client, err := GetClient(ctx)
if err != nil {
return err
}
// the ping endpoint sits at / in this case
- response, err := client.DoRequest(nil, http.MethodGet, "../../../_ping", nil)
+ response, err := client.DoRequest(nil, http.MethodGet, "../../../_ping", nil, nil)
if err != nil {
return err
}
+
if response.StatusCode == http.StatusOK {
- return nil
+ versionHdr := response.Header.Get("Libpod-API-Version")
+ if versionHdr == "" {
+ logrus.Info("Service did not provide Libpod-API-Version Header")
+ return nil
+ }
+ versionSrv, err := semver.ParseTolerant(versionHdr)
+ if err != nil {
+ return err
+ }
+
+ switch APIVersion.Compare(versionSrv) {
+ case 1, 0:
+ // Server's job when client version is equal or older
+ return nil
+ case -1:
+ return errors.Errorf("server API version is too old. client %q server %q", APIVersion.String(), versionSrv.String())
+ }
}
return errors.Errorf("ping response was %q", response.StatusCode)
}
-func sshClient(_url *url.URL, identity string, secure bool) (*http.Client, error) {
+func sshClient(_url *url.URL, identity string, secure bool) (Connection, error) {
auth, err := publicKey(identity)
if err != nil {
- return nil, errors.Wrapf(err, "Failed to parse identity %s: %v\n", _url.String(), identity)
+ return Connection{}, errors.Wrapf(err, "Failed to parse identity %s: %v\n", _url.String(), identity)
}
callback := ssh.InsecureIgnoreHostKey()
@@ -188,30 +214,43 @@ func sshClient(_url *url.URL, identity string, secure bool) (*http.Client, error
},
)
if err != nil {
- return nil, errors.Wrapf(err, "Connection to bastion host (%s) failed.", _url.String())
+ return Connection{}, errors.Wrapf(err, "Connection to bastion host (%s) failed.", _url.String())
}
- return &http.Client{
+
+ connection := Connection{_url: _url}
+ connection.client = &http.Client{
Transport: &http.Transport{
- DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
- return bastion.Dial("unix", _url.Path)
+ DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
+ conn, err := bastion.Dial("unix", _url.Path)
+ if c, ok := ctx.Value(clientKey).(*Connection); ok {
+ c.conn = &conn
+ }
+ return conn, err
},
- }}, nil
+ }}
+ return connection, nil
}
-func unixClient(_url *url.URL) (*http.Client, error) {
- return &http.Client{
+func unixClient(_url *url.URL) (Connection, error) {
+ connection := Connection{_url: _url}
+ connection.client = &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
d := net.Dialer{}
- return d.DialContext(ctx, "unix", _url.Path)
+ conn, err := d.DialContext(ctx, "unix", _url.Path)
+ if c, ok := ctx.Value(clientKey).(*Connection); ok {
+ c.conn = &conn
+ }
+ return conn, err
},
DisableCompression: true,
},
- }, nil
+ }
+ return connection, nil
}
// DoRequest assembles the http request and returns the response
-func (c *Connection) DoRequest(httpBody io.Reader, httpMethod, endpoint string, queryParams url.Values, pathValues ...string) (*APIResponse, error) {
+func (c *Connection) DoRequest(httpBody io.Reader, httpMethod, endpoint string, queryParams url.Values, header map[string]string, pathValues ...string) (*APIResponse, error) {
var (
err error
response *http.Response
@@ -232,6 +271,10 @@ func (c *Connection) DoRequest(httpBody io.Reader, httpMethod, endpoint string,
if len(queryParams) > 0 {
req.URL.RawQuery = queryParams.Encode()
}
+ for key, val := range header {
+ req.Header.Set(key, val)
+ }
+ req = req.WithContext(context.WithValue(context.Background(), clientKey, c))
// Give the Do three chances in the case of a comm/service hiccup
for i := 0; i < 3; i++ {
response, err = c.client.Do(req) // nolint
@@ -243,6 +286,10 @@ func (c *Connection) DoRequest(httpBody io.Reader, httpMethod, endpoint string,
return &APIResponse{response, req}, err
}
+func (c *Connection) Write(b []byte) (int, error) {
+ return (*c.conn).Write(b)
+}
+
// FiltersToString converts our typical filter format of a
// map[string][]string to a query/html safe string.
func FiltersToString(filters map[string][]string) (string, error) {
@@ -295,8 +342,8 @@ func publicKey(path string) (ssh.AuthMethod, error) {
func hostKey(host string) ssh.PublicKey {
// parse OpenSSH known_hosts file
// ssh or use ssh-keyscan to get initial key
- known_hosts := filepath.Join(homedir.HomeDir(), ".ssh", "known_hosts")
- fd, err := os.Open(known_hosts)
+ knownHosts := filepath.Join(homedir.HomeDir(), ".ssh", "known_hosts")
+ fd, err := os.Open(knownHosts)
if err != nil {
logrus.Error(err)
return nil