aboutsummaryrefslogtreecommitdiff
path: root/contrib/python/podman
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/python/podman')
-rw-r--r--contrib/python/podman/client.py148
-rw-r--r--contrib/python/podman/libs/tunnel.py131
2 files changed, 256 insertions, 23 deletions
diff --git a/contrib/python/podman/client.py b/contrib/python/podman/client.py
index 89fcf5c15..c62c1430a 100644
--- a/contrib/python/podman/client.py
+++ b/contrib/python/podman/client.py
@@ -1,6 +1,6 @@
"""A client for communicating with a Podman varlink service."""
-import contextlib
-import functools
+import os
+from urllib.parse import urlparse
from varlink import Client as VarlinkClient
from varlink import VarlinkError
@@ -10,6 +10,119 @@ from .libs.containers import Containers
from .libs.errors import error_factory
from .libs.images import Images
from .libs.system import System
+from .libs.tunnel import Context, Portal, Tunnel
+
+
+class BaseClient(object):
+ """Context manager for API workers to access varlink."""
+
+ def __call__(self):
+ """Support being called for old API."""
+ return self
+
+ @classmethod
+ def factory(cls,
+ uri=None,
+ interface='io.projectatomic.podman',
+ *args,
+ **kwargs):
+ """Construct a Client based on input."""
+ if uri is None:
+ raise ValueError('uri is required and cannot be None')
+ if interface is None:
+ raise ValueError('interface is required and cannot be None')
+
+ local_path = urlparse(uri).path
+ if local_path == '':
+ raise ValueError('path is required for uri, format'
+ ' "unix://path_to_socket"')
+
+ if kwargs.get('remote_uri') or kwargs.get('identity_file'):
+ # Remote access requires the full tuple of information
+ if kwargs.get('remote_uri') is None:
+ raise ValueError('remote is required, format'
+ ' "ssh://user@hostname/path_to_socket".')
+ remote = urlparse(kwargs['remote_uri'])
+ if remote.username is None:
+ raise ValueError('username is required for remote_uri, format'
+ ' "ssh://user@hostname/path_to_socket".')
+ if remote.path == '':
+ raise ValueError('path is required for remote_uri, format'
+ ' "ssh://user@hostname/path_to_socket".')
+ if remote.hostname is None:
+ raise ValueError('hostname is required for remote_uri, format'
+ ' "ssh://user@hostname/path_to_socket".')
+
+ if kwargs.get('identity_file') is None:
+ raise ValueError('identity_file is required.')
+
+ if not os.path.isfile(kwargs['identity_file']):
+ raise ValueError('identity_file "{}" not found.'.format(
+ kwargs['identity_file']))
+ return RemoteClient(
+ Context(uri, interface, local_path, remote.path,
+ remote.username, remote.hostname,
+ kwargs['identity_file']))
+ else:
+ return LocalClient(
+ Context(uri, interface, None, None, None, None, None))
+
+
+class LocalClient(BaseClient):
+ """Context manager for API workers to access varlink."""
+
+ def __init__(self, context):
+ """Construct LocalClient."""
+ self._context = context
+
+ def __enter__(self):
+ """Enter context for LocalClient."""
+ self._client = VarlinkClient(address=self._context.uri)
+ self._iface = self._client.open(self._context.interface)
+ return self._iface
+
+ def __exit__(self, e_type, e, e_traceback):
+ """Cleanup context for LocalClient."""
+ if hasattr(self._client, 'close'):
+ self._client.close()
+ self._iface.close()
+
+ if isinstance(e, VarlinkError):
+ raise error_factory(e)
+
+
+class RemoteClient(BaseClient):
+ """Context manager for API workers to access remote varlink."""
+
+ def __init__(self, context):
+ """Construct RemoteCLient."""
+ self._context = context
+ self._portal = Portal()
+
+ def __enter__(self):
+ """Context manager for API workers to access varlink."""
+ tunnel = self._portal.get(self._context.uri)
+ if tunnel is None:
+ tunnel = Tunnel(self._context).bore(self._context.uri)
+ self._portal[self._context.uri] = tunnel
+
+ try:
+ self._client = VarlinkClient(address=self._context.uri)
+ self._iface = self._client.open(self._context.interface)
+ return self._iface
+ except Exception:
+ self._close_tunnel(self._context.uri)
+ raise
+
+ def __exit__(self, e_type, e, e_traceback):
+ """Cleanup context for RemoteClient."""
+ if hasattr(self._client, 'close'):
+ self._client.close()
+ self._iface.close()
+
+ # set timer to shutdown ssh tunnel
+ if isinstance(e, VarlinkError):
+ raise error_factory(e)
class Client(object):
@@ -20,37 +133,26 @@ class Client(object):
>>> import podman
>>> c = podman.Client()
>>> c.system.versions
- """
- # TODO: Port to contextlib.AbstractContextManager once
- # Python >=3.6 required
+ Example remote podman:
+
+ >>> import podman
+ >>> c = podman.Client(uri='unix:/tmp/podman.sock',
+ remote_uri='ssh://user@host/run/podman/io.projectatomic.podman',
+ identity_file='~/.ssh/id_rsa')
+ """
def __init__(self,
uri='unix:/run/podman/io.projectatomic.podman',
- interface='io.projectatomic.podman'):
+ interface='io.projectatomic.podman',
+ **kwargs):
"""Construct a podman varlink Client.
uri from default systemd unit file.
interface from io.projectatomic.podman.varlink, do not change unless
you are a varlink guru.
"""
- self._podman = None
-
- @contextlib.contextmanager
- def _podman(uri, interface):
- """Context manager for API workers to access varlink."""
- client = VarlinkClient(address=uri)
- try:
- iface = client.open(interface)
- yield iface
- except VarlinkError as e:
- raise error_factory(e) from e
- finally:
- if hasattr(client, 'close'):
- client.close()
- iface.close()
-
- self._client = functools.partial(_podman, uri, interface)
+ self._client = BaseClient.factory(uri, interface, **kwargs)
# Quick validation of connection data provided
try:
diff --git a/contrib/python/podman/libs/tunnel.py b/contrib/python/podman/libs/tunnel.py
new file mode 100644
index 000000000..2cb178644
--- /dev/null
+++ b/contrib/python/podman/libs/tunnel.py
@@ -0,0 +1,131 @@
+"""Cache for SSH tunnels."""
+import collections
+import os
+import subprocess
+import threading
+import time
+import weakref
+
+Context = collections.namedtuple('Context', (
+ 'uri',
+ 'interface',
+ 'local_socket',
+ 'remote_socket',
+ 'username',
+ 'hostname',
+ 'identity_file',
+))
+
+
+class Portal(collections.MutableMapping):
+ """Expiring container for tunnels."""
+
+ def __init__(self, sweap=25):
+ """Construct portal, reap tunnels every sweap seconds."""
+ self.data = collections.OrderedDict()
+ self.sweap = sweap
+ self.ttl = sweap * 2
+ self.lock = threading.RLock()
+
+ def __getitem__(self, key):
+ """Given uri return tunnel and update TTL."""
+ with self.lock:
+ value, _ = self.data[key]
+ self.data[key] = (value, time.time() + self.ttl)
+ self.data.move_to_end(key)
+ return value
+
+ def __setitem__(self, key, value):
+ """Store given tunnel keyed with uri."""
+ if not isinstance(value, Tunnel):
+ raise ValueError('Portals only support Tunnels.')
+
+ with self.lock:
+ self.data[key] = (value, time.time() + self.ttl)
+ self.data.move_to_end(key)
+
+ def __delitem__(self, key):
+ """Remove and close tunnel from portal."""
+ with self.lock:
+ value, _ = self.data[key]
+ del self.data[key]
+ value.close(key)
+ del value
+
+ def __iter__(self):
+ """Iterate tunnels."""
+ with self.lock:
+ values = self.data.values()
+
+ for tunnel, _ in values:
+ yield tunnel
+
+ def __len__(self):
+ """Return number of tunnels in portal."""
+ with self.lock:
+ return len(self.data)
+
+ def _schedule_reaper(self):
+ timer = threading.Timer(interval=self.sweap, function=self.reap)
+ timer.setName('PortalReaper')
+ timer.setDaemon(True)
+ timer.start()
+
+ def reap(self):
+ """Remove tunnels who's TTL has expired."""
+ with self.lock:
+ now = time.time()
+ for entry, timeout in self.data:
+ if timeout < now:
+ self.__delitem__(entry)
+ else:
+ # StopIteration as soon as possible
+ break
+ self._schedule_reaper()
+
+
+class Tunnel(object):
+ """SSH tunnel."""
+
+ def __init__(self, context):
+ """Construct Tunnel."""
+ self.context = context
+ self._tunnel = None
+
+ def bore(self, id):
+ """Create SSH tunnel from given context."""
+ cmd = [
+ 'ssh',
+ '-nNT',
+ '-L',
+ '{}:{}'.format(self.context.local_socket,
+ self.context.remote_socket),
+ '-i',
+ self.context.identity_file,
+ 'ssh://{}@{}'.format(self.context.username, self.context.hostname),
+ ]
+
+ if os.environ.get('PODMAN_DEBUG'):
+ cmd.append('-vvv')
+
+ self._tunnel = subprocess.Popen(cmd, close_fds=True)
+ for i in range(5):
+ if os.path.exists(self.context.local_socket):
+ break
+ time.sleep(1)
+ else:
+ raise TimeoutError('Failed to create tunnel using: {}'.format(
+ ' '.join(cmd)))
+ weakref.finalize(self, self.close, id)
+ return self
+
+ def close(self, id):
+ """Close SSH tunnel."""
+ print('Tunnel collapsed!')
+ if self._tunnel is None:
+ return
+
+ self._tunnel.kill()
+ self._tunnel.wait(300)
+ os.remove(self.context.local_socket)
+ self._tunnel = None