"""
This module persistently stores information on SSH sessions and tunnels in an in-memory structure.
"""
import datetime
import yaml

class SshAgentException(Exception):
    """Raised when a command talking to an ssh-agent (e.g. ssh-add) fails."""

class SSHSession:
    """Interfaces for working with processes forked from flask.

    In particular, we fork processes for ssh-agent, ssh tunnels and
    execution, and keep track of them here so they can be cleaned up.
    """

    def __init__(self, **kwargs):
        """Create an empty session; keyword arguments become attributes.

        ``sshadd`` and ``ctrl_processes`` are (re)set after the keyword
        arguments are applied, so they cannot be overridden here.
        """
        self.last = datetime.datetime.now()
        self.socket = None  # SSH_AUTH_SOCK of the agent used by this session
        self.token = None
        self.key = ''
        self.cert = ''
        self.pids = []  # pids (as strings) whose process groups kill() terminates
        self.authtok = None
        # Helper binaries used by start_agent/add_keycert; overridable via kwargs.
        self.sshagent = '/usr/bin/ssh-agent'
        self.sshkeygen = '/usr/bin/ssh-keygen'
        self.__dict__.update(kwargs)
        self.sshadd = '/usr/bin/ssh-add'
        self.ctrl_processes = {}

    def start_agent(self):
        """Point this session at an ssh-agent, starting one if necessary.

        When ``app.config['ENABLELAUNCH']`` is truthy and ``SSH_AUTH_SOCK`` is
        set in the environment, that existing agent is reused. Otherwise a new
        ssh-agent is spawned; its socket is stored in ``self.socket`` and its
        pid appended to ``self.pids``.

        Raises:
            SshAgentException: if the spawned agent did not report a socket.
        """
        import subprocess
        from .. import app
        import logging
        import os
        logger = logging.getLogger()
        logger.debug('starting agent')
        # A missing SSH_AUTH_SOCK just means "no existing agent", not an error.
        existing = os.environ.get('SSH_AUTH_SOCK')
        if app.config['ENABLELAUNCH'] and existing:
            logger.debug('using existing agent')
            self.socket = existing
            return
        # -s forces Bourne-shell output ("NAME=value; export NAME;") regardless
        # of $SHELL, which is the format parsed below.
        p = subprocess.Popen([self.sshagent, '-s'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        (stdout, stderr) = p.communicate()
        sock = None
        for field in stdout.decode().split(';'):
            field = field.strip()
            if field.startswith('SSH_AUTH_SOCK='):
                sock = field.split('=', 1)[1]
            elif field.startswith('SSH_AGENT_PID='):
                self.pids.append(field.split('=', 1)[1])
        if sock is None:
            logger.error("Couldn't start ssh-agent")
            logger.error(stdout)
            logger.error(stderr)
            raise SshAgentException()
        self.socket = sock

    def add_keycert(self, key, cert):
        """Load a private key and its OpenSSH certificate into this session's agent.

        The key is added with ``ssh-add -t <seconds>`` set to the certificate's
        remaining validity, so the agent drops it when the certificate expires.

        Args:
            key: private key text.
            cert: matching certificate text.

        Raises:
            SshAgentException: if the certificate expiry cannot be determined,
                the certificate has already expired, or ssh-add fails.
        """
        import tempfile
        import os
        import subprocess
        import logging
        logger = logging.getLogger()
        if self.socket is None:
            self.start_agent()

        # Examine the cert to determine its expiry.
        keygen = subprocess.Popen([self.sshkeygen, '-L', '-f', '-'],
                                  stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        keygenout, keygenerr = keygen.communicate(cert.encode())
        certcontents = SSHSession.parse_cert_contents(keygenout.decode().splitlines())
        endtime = SSHSession._cert_endtime(certcontents)
        if endtime is None:
            logger.error("Couldn't determine the certificate expiry")
            logger.error(keygenerr)
            raise SshAgentException()
        # NOTE(review): ssh-keygen -L appears to print validity times in local
        # time, hence the naive local now(); confirm if certs are issued in UTC.
        lifetime = int((endtime - datetime.datetime.now()).total_seconds())
        # ssh-add rejects negative lifetimes and treats "-t 0" as "no lifetime",
        # which would keep the key after the certificate expires.
        if lifetime <= 0:
            logger.error("Certificate has already expired")
            raise SshAgentException()

        env = os.environ.copy()
        env['SSH_AUTH_SOCK'] = self.socket
        keyf = tempfile.NamedTemporaryFile(mode='w', delete=False)
        keyname = keyf.name
        # ssh-add loads "<keyfile>-cert.pub" alongside the key automatically.
        certname = keyname + '-cert.pub'
        try:
            with keyf:
                keyf.write(key)
            with open(certname, mode='w') as certf:
                certf.write(cert)
            cmd = [self.sshadd, '-t', str(lifetime), keyname]
            add = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
            (stdout, stderr) = add.communicate()
        finally:
            # Never leave key material on disk, whatever happened above.
            for path in (certname, keyname):
                try:
                    os.unlink(path)
                except FileNotFoundError:
                    pass
        if add.returncode != 0:
            logger.error("Couldn't add key and cert")
            logger.error(stdout)
            logger.error(stderr)
            raise SshAgentException()

    # NOTE(review): the def line of this method was missing from the reviewed
    # copy; the name is a reconstruction - confirm it matches existing callers.
    def get_cert_contents(self):
        """Return parsed details of every certificate loaded in this session's agent.

        Returns:
            A list with one ``parse_cert_contents`` dict per certificate line
            printed by ``ssh-add -L``; empty when no agent socket is set.
        """
        import os
        import subprocess
        res = []
        if self.socket is None:
            return res
        env = os.environ.copy()
        env['SSH_AUTH_SOCK'] = self.socket
        cmd = [self.sshadd, '-L']
        p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
        (stdout, stderr) = p.communicate()
        for line in stdout.splitlines():
            if b'cert' in line:
                keygen = subprocess.Popen([self.sshkeygen, '-L', '-f', '-'],
                                          stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                keygenout, keygenerr = keygen.communicate(line)
                res.append(SSHSession.parse_cert_contents(keygenout.decode().splitlines()))
        return res

    @staticmethod
    def parse_cert_contents(lines):
        """Parse ``ssh-keygen -L`` output into a dict of field name -> values.

        A line containing ':' starts a new field: the stripped text before the
        first ':' is the field name and the stripped remainder (if non-empty)
        is its first value. Lines without ':' are further values of the most
        recent field (e.g. the list of principals).

        Args:
            lines: iterable of decoded output lines.

        Returns:
            dict mapping each field name to a list of stripped strings.
        """
        # NOTE(review): parts of this body were lost in the reviewed copy; the
        # stripping and final flush are reconstructed from how add_keycert
        # reads certcontents['Valid'] - confirm against history.
        key = None
        values = []
        res = {}
        for line in lines:
            if ':' in line:
                if key is not None:
                    res[key] = values
                values = []
                key, v = line.split(':', 1)
                key = key.strip()
                v = v.strip()
                if v != '':
                    values.append(v)
            else:
                line = line.strip()
                if line != '' and key is not None:
                    values.append(line)
        if key is not None:
            res[key] = values
        return res

    @staticmethod
    def _cert_endtime(certcontents):
        """Return the expiry of a parsed certificate as a naive datetime, or None.

        Understands the "from <start> to <end>" and "before <end>" forms of the
        ``Valid`` field; anything else (e.g. "forever") yields None.
        """
        valid = certcontents.get('Valid')
        if not valid:
            return None
        words = valid[0].split()
        if len(words) == 4 and words[0] == 'from' and words[2] == 'to':
            end = words[3]
        elif len(words) == 2 and words[0] == 'before':
            end = words[1]
        else:
            return None
        try:
            return datetime.datetime.strptime(end, "%Y-%m-%dT%H:%M:%S")
        except ValueError:
            return None

    def refresh(self):
        """Record that the session was just used."""
        self.last = datetime.datetime.now()

    def addkey(self, key, cert):
        """Placeholder that currently does nothing; add_keycert does the work."""
        pass

    def kill(self):
        """Terminate every process this session started and clean up.

        Each pid in ``self.pids`` is treated as a process-group id: it gets
        SIGTERM and, if still alive after two seconds, SIGKILL. Each
        ``ctrl_processes`` entry is killed, waited on for up to five seconds,
        and its key unlinked as a path. Cleanup of ctrl processes is
        best-effort. Afterwards both collections are emptied so a repeated
        call cannot signal a pid that has since been reused.
        """
        import os
        import signal
        import logging
        import time
        logger = logging.getLogger()
        logger.debug("shuting down ssh session for {} last seen at {}".format(self.authtok, self.last))
        for pid in self.pids:
            pgid = int(pid)
            try:
                # The original author noted this occasionally fails (cause unknown).
                os.killpg(pgid, signal.SIGTERM)
            except ProcessLookupError:
                logger.debug("process {} not found".format(pid))
                continue
            # Poll for up to two seconds before escalating to SIGKILL.
            deadline = time.monotonic() + 2
            while True:
                try:
                    os.kill(pgid, 0)  # raises ProcessLookupError once it has exited
                except ProcessLookupError:
                    break
                if time.monotonic() >= deadline:
                    logger.error('resorting to sigkill for pid {}'.format(pid))
                    try:
                        os.killpg(pgid, signal.SIGKILL)
                    except ProcessLookupError:
                        pass
                    break
                time.sleep(0.1)
            logger.debug("killed {}".format(pid))
        self.pids = []
        # Presumably path -> Popen-like objects; every step is best-effort.
        for path, proc in self.ctrl_processes.items():
            try:
                proc.kill()
            except Exception as e:
                logger.debug("could not kill ctrl process {}: {}".format(path, e))
            try:
                proc.wait(5)
            except Exception as e:
                logger.debug("could not wait for ctrl process {}: {}".format(path, e))
            try:
                os.unlink(path)
            except Exception:
                pass
        self.ctrl_processes = {}

    @staticmethod
    def test_sshsession(sess):
        """Check that ``sess`` can talk to its ssh-agent, starting one if needed.

        Runs ``ssh-add -l`` against the session's socket. An agent that is
        running but holds no identities counts as healthy.

        Raises:
            SshAgentException: if the agent cannot be contacted.
        """
        import os
        import subprocess
        import logging
        logger = logging.getLogger()
        env = os.environ.copy()
        if sess.socket is None:
            sess.start_agent()
        env['SSH_AUTH_SOCK'] = sess.socket
        cmd = [sess.sshadd, '-l']
        p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
        (stdout, stderr) = p.communicate()
        if p.returncode != 0:
            # A non-zero return code also occurs when the agent is running but
            # has no keys loaded; that is not an error condition.
            if b'The agent has no identities' in stdout:
                return
            logger.error("Couldn't communicate with the ssh agent")
            logger.error(stdout)
            logger.error(stderr)
            raise SshAgentException()

    @staticmethod
    def get_sshsession():
        """Return the SSHSession for the current flask session, creating it if needed.

        A random 8-character id is stored in ``session['sshsessid']`` and used
        as the key into the shared ``sshsessions`` registry. If the existing
        session's agent cannot be reached it is killed and replaced.

        Raises:
            SshAgentException: if even a freshly created session fails the check.
        """
        import random
        import string
        from .. import sshsessions
        from flask import session
        sshsessid = session.get('sshsessid', None)
        N = 8
        alphabet = string.ascii_uppercase + string.digits
        rng = random.SystemRandom()
        while sshsessid is None:
            key = ''.join(rng.choice(alphabet) for _ in range(N))
            # Ids must be unique across all users, so check the shared
            # registry rather than this user's own flask session.
            if key not in sshsessions:
                sshsessid = key
                session['sshsessid'] = sshsessid
        if sshsessid not in sshsessions:
            sshsessions[sshsessid] = SSHSession()
        sess = sshsessions[sshsessid]
        try:
            SSHSession.test_sshsession(sess)
        except SshAgentException:
            sess.kill()
            sess = SSHSession()
            sshsessions[sshsessid] = sess
            SSHSession.test_sshsession(sess)
        return sess

    @staticmethod
    def remove_sshsession():
        """Drop the current flask session's entry from the ``sshsessions`` registry.

        This method does not itself call ``kill()``. A missing or unknown id
        is ignored.
        """
        from .. import sshsessions
        from flask import session
        sshsessid = session.get('sshsessid', None)
        try:
            del sshsessions[sshsessid]
        except KeyError:
            pass