diff options
Diffstat (limited to 'pkg/bindings/connection.go')
-rw-r--r-- | pkg/bindings/connection.go | 311 |
1 files changed, 285 insertions, 26 deletions
diff --git a/pkg/bindings/connection.go b/pkg/bindings/connection.go index 551a63c62..4fe4dd72d 100644 --- a/pkg/bindings/connection.go +++ b/pkg/bindings/connection.go @@ -1,14 +1,34 @@ package bindings import ( + "bufio" + "context" "fmt" "io" + "io/ioutil" + "net" "net/http" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/containers/libpod/pkg/api/handlers" + jsoniter "github.com/json-iterator/go" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "k8s.io/client-go/util/homedir" ) -const ( - defaultConnection string = "http://localhost:8080/v1.24/libpod" - pingConnection string = "http://localhost:8080/_ping" +var ( + basePath = &url.URL{ + Scheme: "http", + Host: "d", + Path: "/v" + handlers.MinimalApiVersion + "/libpod", + } ) type APIResponse struct { @@ -17,46 +37,285 @@ type APIResponse struct { } type Connection struct { - url string + _url *url.URL client *http.Client } -func NewConnection(url string) (Connection, error) { - if len(url) < 1 { - url = defaultConnection +type valueKey string + +const ( + clientKey = valueKey("client") +) + +// GetClient from context build by NewConnection() +func GetClient(ctx context.Context) (*Connection, error) { + c, ok := ctx.Value(clientKey).(*Connection) + if !ok { + return nil, errors.Errorf("ClientKey not set in context") + } + return c, nil +} + +// JoinURL elements with '/' +func JoinURL(elements ...string) string { + return strings.Join(elements, "/") +} + +// NewConnection takes a URI as a string and returns a context with the +// Connection embedded as a value. This context needs to be passed to each +// endpoint to work correctly. +// +// A valid URI connection should be scheme:// +// For example tcp://localhost:<port> +// or unix:///run/podman/podman.sock +// or ssh://<user>@<host>[:port]/run/podman/podman.sock?secure=True +func NewConnection(ctx context.Context, uri string, identity ...string) (context.Context, error) { + var ( + err error + secure bool + ) + if v, found := os.LookupEnv("PODMAN_HOST"); found { + uri = v + } + + if v, found := os.LookupEnv("PODMAN_SSHKEY"); found { + identity = []string{v} + } + + _url, err := url.Parse(uri) + if err != nil { + return nil, errors.Wrapf(err, "Value of PODMAN_HOST is not a valid url: %s", uri) } - newConn := Connection{ - url: url, - client: &http.Client{}, + + // Now we setup the http client to use the connection above + var client *http.Client + switch _url.Scheme { + case "ssh": + secure, err = strconv.ParseBool(_url.Query().Get("secure")) + if err != nil { + secure = false + } + client, err = sshClient(_url, identity[0], secure) + case "unix": + if !strings.HasPrefix(uri, "unix:///") { + // autofix unix://path_element vs unix:///path_element + _url.Path = JoinURL(_url.Host, _url.Path) + _url.Host = "" + } + client, err = unixClient(_url) + case "tcp": + if !strings.HasPrefix(uri, "tcp://") { + return nil, errors.New("tcp URIs should begin with tcp://") + } + client, err = tcpClient(_url) + default: + return nil, errors.Errorf("'%s' is not a supported schema", _url.Scheme) } - response, err := http.Get(pingConnection) if err != nil { - return newConn, err + return nil, errors.Wrapf(err, "Failed to create %sClient", _url.Scheme) } - if err := response.Body.Close(); err != nil { - return newConn, err + + ctx = context.WithValue(ctx, clientKey, &Connection{_url, client}) + if err := pingNewConnection(ctx); err != nil { + return nil, err } - return newConn, err + return ctx, nil +} + +func tcpClient(_url *url.URL) (*http.Client, error) { + return &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("tcp", _url.Path) + }, + DisableCompression: true, + }, + }, nil } -func (c Connection) makeEndpoint(u string) string { - return fmt.Sprintf("%s%s", defaultConnection, u) +// pingNewConnection pings to make sure the RESTFUL service is up +// and running. it should only be used where initializing a connection +func pingNewConnection(ctx context.Context) error { + client, err := GetClient(ctx) + if err != nil { + return err + } + // the ping endpoint sits at / in this case + response, err := client.DoRequest(nil, http.MethodGet, "../../../_ping", nil) + if err != nil { + return err + } + if response.StatusCode == http.StatusOK { + return nil + } + return errors.Errorf("ping response was %q", response.StatusCode) } -func (c Connection) newRequest(httpMethod, endpoint string, httpBody io.Reader, params map[string]string) (*APIResponse, error) { - e := c.makeEndpoint(endpoint) +func sshClient(_url *url.URL, identity string, secure bool) (*http.Client, error) { + auth, err := publicKey(identity) + if err != nil { + return nil, errors.Wrapf(err, "Failed to parse identity %s: %v\n", _url.String(), identity) + } + + callback := ssh.InsecureIgnoreHostKey() + if secure { + key := hostKey(_url.Hostname()) + if key != nil { + callback = ssh.FixedHostKey(key) + } + } + + port := _url.Port() + if port == "" { + port = "22" + } + + bastion, err := ssh.Dial("tcp", + net.JoinHostPort(_url.Hostname(), port), + &ssh.ClientConfig{ + User: _url.User.Username(), + Auth: []ssh.AuthMethod{auth}, + HostKeyCallback: callback, + HostKeyAlgorithms: []string{ + ssh.KeyAlgoRSA, + ssh.KeyAlgoDSA, + ssh.KeyAlgoECDSA256, + ssh.KeyAlgoECDSA384, + ssh.KeyAlgoECDSA521, + ssh.KeyAlgoED25519, + }, + Timeout: 5 * time.Second, + }, + ) + if err != nil { + return nil, errors.Wrapf(err, "Connection to bastion host (%s) failed.", _url.String()) + } + return &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return bastion.Dial("unix", _url.Path) + }, + }}, nil +} + +func unixClient(_url *url.URL) (*http.Client, error) { + return &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, "unix", _url.Path) + }, + DisableCompression: true, + }, + }, nil +} + +// DoRequest assembles the http request and returns the response +func (c *Connection) DoRequest(httpBody io.Reader, httpMethod, endpoint string, queryParams url.Values, pathValues ...string) (*APIResponse, error) { + var ( + err error + response *http.Response + ) + safePathValues := make([]interface{}, len(pathValues)) + // Make sure path values are http url safe + for i, pv := range pathValues { + safePathValues[i] = url.PathEscape(pv) + } + // Lets eventually use URL for this which might lead to safer + // usage + safeEndpoint := fmt.Sprintf(endpoint, safePathValues...) + e := basePath.String() + safeEndpoint req, err := http.NewRequest(httpMethod, e, httpBody) if err != nil { return nil, err } - if len(params) > 0 { - // if more desirable we could use url to form the encoded endpoint with params - r := req.URL.Query() - for k, v := range params { - r.Add(k, v) + if len(queryParams) > 0 { + req.URL.RawQuery = queryParams.Encode() + } + // Give the Do three chances in the case of a comm/service hiccup + for i := 0; i < 3; i++ { + response, err = c.client.Do(req) // nolint + if err == nil { + break } - req.URL.RawQuery = r.Encode() + time.Sleep(time.Duration(i*100) * time.Millisecond) } - response, err := c.client.Do(req) // nolint return &APIResponse{response, req}, err } + +// FiltersToString converts our typical filter format of a +// map[string][]string to a query/html safe string. +func FiltersToString(filters map[string][]string) (string, error) { + lowerCaseKeys := make(map[string][]string) + for k, v := range filters { + lowerCaseKeys[strings.ToLower(k)] = v + } + return jsoniter.MarshalToString(lowerCaseKeys) +} + +// IsInformation returns true if the response code is 1xx +func (h *APIResponse) IsInformational() bool { + return h.Response.StatusCode/100 == 1 +} + +// IsSuccess returns true if the response code is 2xx +func (h *APIResponse) IsSuccess() bool { + return h.Response.StatusCode/100 == 2 +} + +// IsRedirection returns true if the response code is 3xx +func (h *APIResponse) IsRedirection() bool { + return h.Response.StatusCode/100 == 3 +} + +// IsClientError returns true if the response code is 4xx +func (h *APIResponse) IsClientError() bool { + return h.Response.StatusCode/100 == 4 +} + +// IsServerError returns true if the response code is 5xx +func (h *APIResponse) IsServerError() bool { + return h.Response.StatusCode/100 == 5 +} + +func publicKey(path string) (ssh.AuthMethod, error) { + key, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + return nil, err + } + + return ssh.PublicKeys(signer), nil +} + +func hostKey(host string) ssh.PublicKey { + // parse OpenSSH known_hosts file + // ssh or use ssh-keyscan to get initial key + known_hosts := filepath.Join(homedir.HomeDir(), ".ssh", "known_hosts") + fd, err := os.Open(known_hosts) + if err != nil { + logrus.Error(err) + return nil + } + + scanner := bufio.NewScanner(fd) + for scanner.Scan() { + _, hosts, key, _, _, err := ssh.ParseKnownHosts(scanner.Bytes()) + if err != nil { + logrus.Errorf("Failed to parse known_hosts: %s", scanner.Text()) + continue + } + + for _, h := range hosts { + if h == host { + return key + } + } + } + + return nil +} |