diff options
Diffstat (limited to 'pkg/bindings/connection.go')
-rw-r--r-- | pkg/bindings/connection.go | 234 |
1 files changed, 182 insertions, 52 deletions
diff --git a/pkg/bindings/connection.go b/pkg/bindings/connection.go index f270060a6..75f1fc6a5 100644 --- a/pkg/bindings/connection.go +++ b/pkg/bindings/connection.go @@ -1,22 +1,34 @@ package bindings import ( + "bufio" "context" "fmt" "io" + "io/ioutil" "net" "net/http" "net/url" + "os" "path/filepath" + "strconv" "strings" + "time" "github.com/containers/libpod/pkg/api/handlers" jsoniter "github.com/json-iterator/go" "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "k8s.io/client-go/util/homedir" ) var ( - defaultConnectionPath string = filepath.Join(fmt.Sprintf("v%s", handlers.MinimalApiVersion), "libpod") + basePath = &url.URL{ + Scheme: "http", + Host: "d", + Path: "/v" + handlers.MinimalApiVersion + "/libpod", + } ) type APIResponse struct { @@ -25,9 +37,28 @@ type APIResponse struct { } type Connection struct { - scheme string - address string - client *http.Client + _url *url.URL + client *http.Client +} + +type valueKey string + +const ( + clientKey = valueKey("client") +) + +// GetClient from context build by NewConnection() +func GetClient(ctx context.Context) (*Connection, error) { + c, ok := ctx.Value(clientKey).(*Connection) + if !ok { + return nil, errors.Errorf("ClientKey not set in context") + } + return c, nil +} + +// JoinURL elements with '/' +func JoinURL(elements ...string) string { + return strings.Join(elements, "/") } // NewConnection takes a URI as a string and returns a context with the @@ -36,46 +67,81 @@ type Connection struct { // // A valid URI connection should be scheme:// // For example tcp://localhost:<port> -// or unix://run/podman/podman.sock -func NewConnection(uri string) (context.Context, error) { - u, err := url.Parse(uri) - if err != nil { - return nil, err - } - // TODO once ssh is implemented, remove this block and - // add it to the conditional beneath it - if u.Scheme == "ssh" { - return nil, ErrNotImplemented +// or unix:///run/podman/podman.sock +// or ssh://<user>@<host>[:port]/run/podman/podman.sock?secure=True +func NewConnection(ctx context.Context, uri string, identity ...string) (context.Context, error) { + var ( + err error + secure bool + ) + if v, found := os.LookupEnv("PODMAN_HOST"); found { + uri = v } - if u.Scheme != "tcp" && u.Scheme != "unix" { - return nil, errors.Errorf("%s is not a support schema", u.Scheme) + + if v, found := os.LookupEnv("PODMAN_SSHKEY"); found { + identity = []string{v} } - if u.Scheme == "tcp" && !strings.HasPrefix(uri, "tcp://") { - return nil, errors.New("tcp URIs should begin with tcp://") + _url, err := url.Parse(uri) + if err != nil { + return nil, errors.Wrapf(err, "Value of PODMAN_HOST is not a valid url: %s", uri) } - address := u.Path - if u.Scheme == "tcp" { - address = u.Host + // Now we setup the http client to use the connection above + var client *http.Client + switch _url.Scheme { + case "ssh": + secure, err = strconv.ParseBool(_url.Query().Get("secure")) + if err != nil { + secure = false + } + client, err = sshClient(_url, identity[0], secure) + case "unix": + if !strings.HasPrefix(uri, "unix:///") { + // autofix unix://path_element vs unix:///path_element + _url.Path = JoinURL(_url.Host, _url.Path) + _url.Host = "" + } + client, err = unixClient(_url) + case "tcp": + if !strings.HasPrefix(uri, "tcp://") { + return nil, errors.New("tcp URIs should begin with tcp://") + } + client, err = tcpClient(_url) + default: + return nil, errors.Errorf("%s is not a support schema", _url.Scheme) + } + if err != nil { + return nil, errors.Wrapf(err, "Failed to create %sClient", _url.Scheme) } - newConn := newConnection(u.Scheme, address) - ctx := context.WithValue(context.Background(), "conn", &newConn) + + ctx = context.WithValue(ctx, clientKey, &Connection{_url, client}) if err := pingNewConnection(ctx); err != nil { return nil, err } return ctx, nil } +func tcpClient(_url *url.URL) (*http.Client, error) { + return &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("tcp", _url.Path) + }, + DisableCompression: true, + }, + }, nil +} + // pingNewConnection pings to make sure the RESTFUL service is up // and running. it should only be used where initializing a connection func pingNewConnection(ctx context.Context) error { - conn, err := GetConnectionFromContext(ctx) + client, err := GetClient(ctx) if err != nil { return err } // the ping endpoint sits at / in this case - response, err := conn.DoRequest(nil, http.MethodGet, "../../../_ping", nil) + response, err := client.DoRequest(nil, http.MethodGet, "../../../_ping", nil) if err != nil { return err } @@ -85,26 +151,58 @@ func pingNewConnection(ctx context.Context) error { return errors.Errorf("ping response was %q", response.StatusCode) } -// newConnection takes a scheme and address and creates a connection from it -func newConnection(scheme, address string) Connection { - client := http.Client{ - Transport: &http.Transport{ - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return net.Dial(scheme, address) +func sshClient(_url *url.URL, identity string, secure bool) (*http.Client, error) { + auth, err := publicKey(identity) + if err != nil { + return nil, errors.Wrapf(err, "Failed to parse identity %s: %v\n", _url.String(), identity) + } + + callback := ssh.InsecureIgnoreHostKey() + if secure { + key := hostKey(_url.Hostname()) + if key != nil { + callback = ssh.FixedHostKey(key) + } + } + + bastion, err := ssh.Dial("tcp", + net.JoinHostPort(_url.Hostname(), _url.Port()), + &ssh.ClientConfig{ + User: _url.User.Username(), + Auth: []ssh.AuthMethod{auth}, + HostKeyCallback: callback, + HostKeyAlgorithms: []string{ + ssh.KeyAlgoRSA, + ssh.KeyAlgoDSA, + ssh.KeyAlgoECDSA256, + ssh.KeyAlgoECDSA384, + ssh.KeyAlgoECDSA521, + ssh.KeyAlgoED25519, }, + Timeout: 5 * time.Second, }, + ) + if err != nil { + return nil, errors.Wrapf(err, "Connection to bastion host (%s) failed.", _url.String()) } - newConn := Connection{ - client: &client, - address: address, - scheme: scheme, - } - return newConn + return &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return bastion.Dial("unix", _url.Path) + }, + }}, nil } -func (c *Connection) makeEndpoint(u string) string { - // The d character in the url is discarded and is meaningless - return fmt.Sprintf("http://d/%s%s", defaultConnectionPath, u) +func unixClient(_url *url.URL) (*http.Client, error) { + return &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, "unix", _url.Path) + }, + DisableCompression: true, + }, + }, nil } // DoRequest assembles the http request and returns the response @@ -121,7 +219,7 @@ func (c *Connection) DoRequest(httpBody io.Reader, httpMethod, endpoint string, // Lets eventually use URL for this which might lead to safer // usage safeEndpoint := fmt.Sprintf(endpoint, safePathValues...) - e := c.makeEndpoint(safeEndpoint) + e := basePath.String() + safeEndpoint req, err := http.NewRequest(httpMethod, e, httpBody) if err != nil { return nil, err @@ -140,21 +238,11 @@ func (c *Connection) DoRequest(httpBody io.Reader, httpMethod, endpoint string, if err == nil { break } + time.Sleep(time.Duration(i*100) * time.Millisecond) } return &APIResponse{response, req}, err } -// GetConnectionFromContext returns a bindings connection from the context -// being passed into each method. -func GetConnectionFromContext(ctx context.Context) (*Connection, error) { - c := ctx.Value("conn") - if c == nil { - return nil, errors.New("unable to get connection from context") - } - conn := c.(*Connection) - return conn, nil -} - // FiltersToString converts our typical filter format of a // map[string][]string to a query/html safe string. func FiltersToString(filters map[string][]string) (string, error) { @@ -189,3 +277,45 @@ func (h *APIResponse) IsClientError() bool { func (h *APIResponse) IsServerError() bool { return h.Response.StatusCode/100 == 5 } + +func publicKey(path string) (ssh.AuthMethod, error) { + key, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + return nil, err + } + + return ssh.PublicKeys(signer), nil +} + +func hostKey(host string) ssh.PublicKey { + // parse OpenSSH known_hosts file + // ssh or use ssh-keyscan to get initial key + known_hosts := filepath.Join(homedir.HomeDir(), ".ssh", "known_hosts") + fd, err := os.Open(known_hosts) + if err != nil { + logrus.Error(err) + return nil + } + + scanner := bufio.NewScanner(fd) + for scanner.Scan() { + _, hosts, key, _, _, err := ssh.ParseKnownHosts(scanner.Bytes()) + if err != nil { + logrus.Errorf("Failed to parse known_hosts: %s", scanner.Text()) + continue + } + + for _, h := range hosts { + if h == host { + return key + } + } + } + + return nil +} |