Skip to content
Snippets Groups Projects
__init__.py 5.34 KiB
Newer Older
"""
This module runs remote commands and opens port-forwarding tunnels over SSH.
"""
import subprocess

class SshAgentException(Exception):
    """Raised when an ssh operation needs the session's ssh-agent socket but none is set."""

class Ssh(object):
    """
    Ssh class can execute commands on a remote host or create port-forwarding
    tunnels to it.

    Every method is a staticmethod taking a ``sess`` object. The code reads
    ``sess.socket`` (the ssh-agent socket path, or None) and writes
    ``sess.port`` and ``sess.pids`` (a list); presumably a session object
    defined elsewhere -- confirm against callers.
    """

    @staticmethod
    def _agent_env(sess):
        """
        Return a copy of os.environ with SSH_AUTH_SOCK set to the session's
        ssh-agent socket, so the spawned ssh can authenticate via the agent.

        :raises SshAgentException: if ``sess.socket`` is None
        """
        import os
        if sess.socket is None:
            raise SshAgentException("No ssh-agent yet")
        env = os.environ.copy()
        env['SSH_AUTH_SOCK'] = sess.socket
        return env

    @staticmethod
    def _terminate(procs):
        """Stop tunnel processes started by a setup attempt that failed."""
        for proc in procs:
            if proc.poll() is None:
                proc.terminate()
                try:
                    proc.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    proc.kill()

    @staticmethod
    def execute(sess, host, user, cmd, stdin=None):
        """
        Execute the command ``cmd`` on ``host`` as ``user`` via ssh and wait
        for it to finish. Agent forwarding (-A) is enabled.

        :param stdin: optional str, encoded and written to the command's stdin
        :returns: dict with the raw ``stdout`` and ``stderr`` bytes
        :raises SshAgentException: if the session has no ssh-agent socket
        """
        env = Ssh._agent_env(sess)
        # NOTE(review): Stricthostkeychecking=no accepts unknown host keys
        # without verification (MITM risk); kept as the existing behaviour.
        exec_p = subprocess.Popen(['ssh', '-A', '-o', 'Stricthostkeychecking=no', '-l',
                                   user, host, cmd],
                                  stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                  stdin=subprocess.PIPE, env=env)
        data = stdin.encode() if stdin is not None else None
        stdout, stderr = exec_p.communicate(data)
        return {'stdout': stdout, 'stderr': stderr}

    @staticmethod
    def tunnel(sess, port, batchhost, user, host, internalfirewall=True, localbind=True):
        """
        Open a tunnel from a free local port to ``port`` on ``batchhost``,
        going through ``host``.

        A single tunnel is used when ``port`` is 22, or when neither
        ``internalfirewall`` nor ``localbind`` is set; otherwise a double
        tunnel is used. The double tunnel is needed if the server we run on
        the batch host is only addressable on localhost, e.g. jupyter in its
        default config runs on localhost:8888, so a web browser on the login
        node can not connect to it.

        :returns: (localport, list of ssh pids)
        """
        print("localbind is", localbind, internalfirewall)
        if port == 22:
            print("port is 22, using single tunnel")
            return Ssh.singletunnel(sess, port, batchhost, user, host)
        if not internalfirewall and not localbind:
            return Ssh.singletunnel(sess, port, batchhost, user, host)
        print("using doubletunnel")
        return Ssh.doubletunnel(sess, port, batchhost, user, host)

    @staticmethod
    def addkey(sess, key, cert):
        """Placeholder: not implemented, does nothing and returns None."""
        pass

    @staticmethod
    def singletunnel(sess, port, batchhost, user, host):
        """
        Start a background ssh process holding a tunnel open, like
            ssh -N -l user -L localport:batchhost:port host
        and wait until the local end accepts connections.

        Stores the local port in ``sess.port`` and appends the ssh pid to
        ``sess.pids``. If the tunnel fails to come up, the ssh process is
        terminated and the error re-raised.

        :returns: (localport, [pid])
        :raises SshAgentException: if the session has no ssh-agent socket
        """
        env = Ssh._agent_env(sess)
        localport = Ssh.get_free_port()
        tunnel_p = subprocess.Popen(['ssh', '-N', '-o', 'Stricthostkeychecking=no', '-L',
                                     '{localport}:{batchhost}:{port}'.format(port=port,
                                                                             localport=localport,
                                                                             batchhost=batchhost),
                                     '-l', user, host], stdin=subprocess.PIPE, env=env)
        try:
            Ssh.wait_for_tunnel(localport, proc=tunnel_p)
        except BaseException:
            # Includes KeyboardInterrupt: don't leave an orphaned ssh behind.
            Ssh._terminate([tunnel_p])
            raise
        sess.port = localport
        sess.pids.append(tunnel_p.pid)
        return localport, [tunnel_p.pid]

    @staticmethod
    def doubletunnel(sess, port, batchhost, user, host):
        """
        Chain two background ssh processes so that localhost:``port`` on the
        batch host becomes reachable through a local port:

        1. ssh -N -l user -L localport1:batchhost:22 host
           (forwards a local port to the batch host's ssh port via ``host``)
        2. ssh -N -l user -L localport2:localhost:port localhost -p localport1
           (logs in to the batch host through tunnel 1 and forwards its
           loopback ``port``)

        Stores localport2 in ``sess.port`` and extends ``sess.pids`` with both
        pids. If either tunnel fails to come up, any ssh process already
        started is terminated and the error re-raised.

        :returns: (localport2, [pid1, pid2])
        :raises SshAgentException: if the session has no ssh-agent socket
        """
        env = Ssh._agent_env(sess)
        procs = []
        try:
            localport1 = Ssh.get_free_port()
            procs.append(subprocess.Popen(['ssh', '-N', '-o', 'Stricthostkeychecking=no', '-L',
                                           '{localport1}:{batchhost}:22'.format(localport1=localport1,
                                                                                batchhost=batchhost),
                                           '-l', user, host], stdin=subprocess.PIPE, env=env))
            Ssh.wait_for_tunnel(localport1, proc=procs[-1])
            localport2 = Ssh.get_free_port()
            # NOTE(review): if known_hosts holds a different key for
            # "localhost", OpenSSH may disable port forwarding on this hop
            # even with Stricthostkeychecking=no -- confirm in deployment.
            procs.append(subprocess.Popen(['ssh', '-N', '-o', 'Stricthostkeychecking=no', '-L',
                                           '{localport2}:localhost:{port}'.format(port=port,
                                                                                  localport2=localport2),
                                           '-l', user, 'localhost', '-p',
                                           '{}'.format(localport1)], stdin=subprocess.PIPE, env=env))
            Ssh.wait_for_tunnel(localport2, proc=procs[-1])
        except BaseException:
            # Don't leak the first tunnel when the second one fails.
            Ssh._terminate(procs)
            raise
        pids = [proc.pid for proc in procs]
        sess.port = localport2
        sess.pids.extend(pids)
        return localport2, pids

    @staticmethod
    def wait_for_tunnel(localport, proc=None, timeout=None, interval=0.1):
        """
        Block until something accepts TCP connections on 127.0.0.1:localport.

        ssh binds its local forward asynchronously, so callers use this to
        avoid racing it.

        :param proc: optional Popen of the ssh process providing the tunnel;
            if it exits before the port opens, ConnectionError is raised
            instead of polling forever.
        :param timeout: optional seconds after which TimeoutError is raised;
            None (the default) waits indefinitely.
        :param interval: seconds to sleep between attempts, so the loop does
            not spin a CPU core.
        :raises ConnectionError: if ``proc`` exited before the port opened
        :raises TimeoutError: if ``timeout`` elapsed first
        """
        import socket
        import time
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as ssock:
                ssock.setblocking(True)
                try:
                    ssock.connect(('127.0.0.1', localport))
                    return
                except ConnectionRefusedError:
                    pass
            if proc is not None and proc.poll() is not None:
                raise ConnectionError(
                    "ssh tunnel process exited with status {} before port {} "
                    "opened".format(proc.returncode, localport))
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    "tunnel on port {} not open after {} s".format(localport, timeout))
            time.sleep(interval)

    @staticmethod
    def get_free_port(start=1025, stop=65500):
        """
        Return the lowest port in [start, stop) that can be bound on
        127.0.0.1.

        The probe socket is always closed before returning, so the port is
        free again. Another process could still take it before ssh binds it.

        :raises OSError: if no port in the range is free (previously this
            silently returned None)
        """
        import socket
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            for testport in range(start, stop):
                try:
                    probe.bind(('127.0.0.1', testport))
                except OSError:
                    continue
                return testport
        raise OSError("no free TCP port on 127.0.0.1 in range {}-{}".format(start, stop - 1))