diff options
Diffstat (limited to 'pkg/bindings/connection.go')
-rw-r--r-- | pkg/bindings/connection.go | 90 |
1 files changed, 65 insertions, 25 deletions
diff --git a/pkg/bindings/connection.go b/pkg/bindings/connection.go index da3755fc8..d21d55beb 100644 --- a/pkg/bindings/connection.go +++ b/pkg/bindings/connection.go @@ -15,6 +15,7 @@ import ( "strings" "time" + "github.com/blang/semver" "github.com/containers/libpod/pkg/api/types" jsoniter "github.com/json-iterator/go" "github.com/pkg/errors" @@ -39,6 +40,7 @@ type APIResponse struct { type Connection struct { _url *url.URL client *http.Client + conn *net.Conn } type valueKey string @@ -88,26 +90,26 @@ func NewConnection(ctx context.Context, uri string, identity ...string) (context } // Now we setup the http client to use the connection above - var client *http.Client + var connection Connection switch _url.Scheme { case "ssh": secure, err = strconv.ParseBool(_url.Query().Get("secure")) if err != nil { secure = false } - client, err = sshClient(_url, identity[0], secure) + connection, err = sshClient(_url, identity[0], secure) case "unix": if !strings.HasPrefix(uri, "unix:///") { // autofix unix://path_element vs unix:///path_element _url.Path = JoinURL(_url.Host, _url.Path) _url.Host = "" } - client, err = unixClient(_url) + connection, err = unixClient(_url) case "tcp": if !strings.HasPrefix(uri, "tcp://") { return nil, errors.New("tcp URIs should begin with tcp://") } - client, err = tcpClient(_url) + connection, err = tcpClient(_url) default: return nil, errors.Errorf("'%s' is not a supported schema", _url.Scheme) } @@ -115,26 +117,34 @@ func NewConnection(ctx context.Context, uri string, identity ...string) (context return nil, errors.Wrapf(err, "Failed to create %sClient", _url.Scheme) } - ctx = context.WithValue(ctx, clientKey, &Connection{_url, client}) + ctx = context.WithValue(ctx, clientKey, &connection) if err := pingNewConnection(ctx); err != nil { return nil, err } return ctx, nil } -func tcpClient(_url *url.URL) (*http.Client, error) { - return &http.Client{ +func tcpClient(_url *url.URL) (Connection, error) { + connection := Connection{ + _url: _url, + } + connection.client = &http.Client{ Transport: &http.Transport{ - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return net.Dial("tcp", _url.Host) + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + conn, err := net.Dial("tcp", _url.Host) + if c, ok := ctx.Value(clientKey).(*Connection); ok { + c.conn = &conn + } + return conn, err }, DisableCompression: true, }, - }, nil + } + return connection, nil } // pingNewConnection pings to make sure the RESTFUL service is up -// and running. it should only be used where initializing a connection +// and running. it should only be used when initializing a connection func pingNewConnection(ctx context.Context) error { client, err := GetClient(ctx) if err != nil { @@ -145,16 +155,28 @@ func pingNewConnection(ctx context.Context) error { if err != nil { return err } + if response.StatusCode == http.StatusOK { - return nil + v, err := semver.ParseTolerant(response.Header.Get("Libpod-API-Version")) + if err != nil { + return err + } + + switch APIVersion.Compare(v) { + case 1, 0: + // Server's job when client version is equal or older + return nil + case -1: + return errors.Errorf("server API version is too old. client %q server %q", APIVersion.String(), v.String()) + } } return errors.Errorf("ping response was %q", response.StatusCode) } -func sshClient(_url *url.URL, identity string, secure bool) (*http.Client, error) { +func sshClient(_url *url.URL, identity string, secure bool) (Connection, error) { auth, err := publicKey(identity) if err != nil { - return nil, errors.Wrapf(err, "Failed to parse identity %s: %v\n", _url.String(), identity) + return Connection{}, errors.Wrapf(err, "Failed to parse identity %s: %v\n", _url.String(), identity) } callback := ssh.InsecureIgnoreHostKey() @@ -188,26 +210,39 @@ func sshClient(_url *url.URL, identity string, secure bool) (*http.Client, error }, ) if err != nil { - return nil, errors.Wrapf(err, "Connection to bastion host (%s) failed.", _url.String()) + return Connection{}, errors.Wrapf(err, "Connection to bastion host (%s) failed.", _url.String()) } - return &http.Client{ + + connection := Connection{_url: _url} + connection.client = &http.Client{ Transport: &http.Transport{ - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return bastion.Dial("unix", _url.Path) + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + conn, err := bastion.Dial("unix", _url.Path) + if c, ok := ctx.Value(clientKey).(*Connection); ok { + c.conn = &conn + } + return conn, err }, - }}, nil + }} + return connection, nil } -func unixClient(_url *url.URL) (*http.Client, error) { - return &http.Client{ +func unixClient(_url *url.URL) (Connection, error) { + connection := Connection{_url: _url} + connection.client = &http.Client{ Transport: &http.Transport{ DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { d := net.Dialer{} - return d.DialContext(ctx, "unix", _url.Path) + conn, err := d.DialContext(ctx, "unix", _url.Path) + if c, ok := ctx.Value(clientKey).(*Connection); ok { + c.conn = &conn + } + return conn, err }, DisableCompression: true, }, - }, nil + } + return connection, nil } // DoRequest assembles the http request and returns the response @@ -232,6 +267,7 @@ func (c *Connection) DoRequest(httpBody io.Reader, httpMethod, endpoint string, if len(queryParams) > 0 { req.URL.RawQuery = queryParams.Encode() } + req = req.WithContext(context.WithValue(context.Background(), clientKey, c)) // Give the Do three chances in the case of a comm/service hiccup for i := 0; i < 3; i++ { response, err = c.client.Do(req) // nolint @@ -243,6 +279,10 @@ func (c *Connection) DoRequest(httpBody io.Reader, httpMethod, endpoint string, return &APIResponse{response, req}, err } +func (c *Connection) Write(b []byte) (int, error) { + return (*c.conn).Write(b) +} + // FiltersToString converts our typical filter format of a // map[string][]string to a query/html safe string. func FiltersToString(filters map[string][]string) (string, error) { @@ -295,8 +335,8 @@ func publicKey(path string) (ssh.AuthMethod, error) { func hostKey(host string) ssh.PublicKey { // parse OpenSSH known_hosts file // ssh or use ssh-keyscan to get initial key - known_hosts := filepath.Join(homedir.HomeDir(), ".ssh", "known_hosts") - fd, err := os.Open(known_hosts) + knownHosts := filepath.Join(homedir.HomeDir(), ".ssh", "known_hosts") + fd, err := os.Open(knownHosts) if err != nil { logrus.Error(err) return nil |