Skip to content
Snippets Groups Projects
__init__.py 7.45 KiB
Newer Older
"""
This module persistently stores information on tunnels in a YAML file.
It probably shouldn't be used on a server handling multiple requests,
due to the overhead of opening, reading, writing, locking, etc., but it's
probably OK for a single-user computer.
"""
import datetime
import yaml

class SSHSession:
    """Interfaces for working with processes forked from flask.

    In particular, we fork processes for ssh-agent and ssh tunnels and
    execution. The pids of forked processes are recorded in ``self.pids``
    so that ``kill`` can terminate them later.
    """

    def __init__(self, **kwargs):
        """Create a session record; any keyword argument becomes an attribute."""
        self.last = datetime.datetime.now()
        self.socket = None
        self.token = None
        self.port = None
        self.key = ''
        self.cert = ''
        self.pids = []
        self.authtok = None
        # NOTE(review): the definitions of sshagent/sshkeygen were missing
        # from the visible source although both are used below; these default
        # paths are an assumption - confirm. They are set before the kwargs
        # update so a caller that already supplies them keeps its own values.
        self.sshagent = '/usr/bin/ssh-agent'
        self.sshkeygen = '/usr/bin/ssh-keygen'
        self.__dict__.update(kwargs)
        # Assigned after the kwargs update, as in the original, so this path
        # cannot be overridden through kwargs.
        self.sshadd = '/usr/bin/ssh-add'

    def start_agent(self):
        """Fork an ssh-agent and record its socket path and pid.

        The agent's stdout is split on ';' and scanned for the
        ``SSH_AUTH_SOCK=`` and ``SSH_AGENT_PID=`` assignments.
        """
        import subprocess
        proc = subprocess.Popen([self.sshagent], stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE)
        stdout, _stderr = proc.communicate()
        for chunk in stdout.decode().split(';'):
            if 'SSH_AUTH_SOCK=' in chunk:
                # maxsplit=1 keeps a socket path containing '=' intact.
                self.socket = chunk.split('=', 1)[1].strip()
            if 'SSH_AGENT_PID=' in chunk:
                # Stored as a string; kill() converts it with int().
                self.pids.append(chunk.split('=', 1)[1].strip())

    def add_keycert(self, key, cert):
        """Load a private key and its certificate into this session's agent.

        The key is written to a temporary file and the certificate to
        ``<keyfile>-cert.pub`` next to it, then ssh-add is run on the key
        file. Both files are removed afterwards, even if ssh-add could not
        be started. An agent is started first if none is running.

        Args:
            key: private key text.
            cert: certificate text matching ``key``.
        """
        import tempfile
        import os
        import subprocess
        import logging
        logger = logging.getLogger()
        if self.socket is None:
            self.start_agent()
        if self.socket is None:
            # Without a socket ssh-add cannot reach the agent; bail out
            # before the key material is ever written to disk.
            logger.error("Couldn't add key and cert")
            logger.debug("no ssh-agent socket available")
            return
        keyf = tempfile.NamedTemporaryFile(mode='w', delete=False)
        keyname = keyf.name
        certname = keyname + '-cert.pub'
        try:
            with keyf:
                keyf.write(key)
            with open(certname, mode='w') as certf:
                certf.write(cert)
            env = os.environ.copy()
            env['SSH_AUTH_SOCK'] = self.socket
            cmd = [self.sshadd, keyname]
            proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE, env=env)
            stdout, stderr = proc.communicate()
            if proc.returncode != 0:
                logger.error("Couldn't add key and cert")
                logger.debug(stdout)
                logger.debug(stderr)
                print(stderr)
                print(stdout)
        finally:
            # Never leave key material behind on disk.
            for path in (certname, keyname):
                try:
                    os.unlink(path)
                except FileNotFoundError:
                    pass

    # NOTE(review): the original ``def`` line of this method was lost from
    # the visible source; the name ``get_cert_contents`` is a reconstruction.
    # Confirm it against the callers.
    def get_cert_contents(self):
        """Return the parsed contents of every certificate held by the agent.

        Runs ``ssh-add -L`` against this session's agent and pipes each
        output line containing ``cert`` through ``ssh-keygen -L -f -``.

        Returns:
            A list with one dict (see ``parse_cert_contents``) per
            certificate; empty when no agent socket is set.
        """
        import os
        import subprocess
        import logging
        logger = logging.getLogger()
        res = []
        if self.socket is None:
            return res
        env = os.environ.copy()
        env['SSH_AUTH_SOCK'] = self.socket
        cmd = [self.sshadd, '-L']
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE, env=env)
        stdout, stderr = proc.communicate()
        if stderr is not None:
            logger.debug('called sshadd, got stderr {}'.format(stderr))
        if stdout is not None:
            logger.debug('called sshadd, got stdout {}'.format(stdout))
        for line in stdout.splitlines():
            logger.debug('is {} a cert?'.format(line))
            if b'cert' in line:
                logger.debug('decoding {}'.format(line))
                keygen = subprocess.Popen(
                    [self.sshkeygen, '-L', '-f', '-'],
                    stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE)
                keygenout, _keygenerr = keygen.communicate(line)
                res.append(SSHSession.parse_cert_contents(
                    keygenout.decode().splitlines()))
        return res

    @staticmethod
    def parse_cert_contents(lines):
        """Parse ``key: value`` style lines into a dict of value lists.

        A line containing ':' starts a new key; the text after the first
        ':' (if any) is that key's first value. Following lines without a
        ':' are appended as further values of the current key. Blank lines
        are ignored, and values seen before the first key are dropped.

        Args:
            lines: iterable of text lines.

        Returns:
            dict mapping each key to the list of its values.
        """
        key = None
        values = []
        res = {}
        for raw in lines:
            line = raw.strip()
            if ':' in line:
                if key is not None:
                    res[key] = values
                values = []
                key, value = line.split(':', 1)
                key = key.strip()
                value = value.strip()
                if value != '':
                    values.append(value)
            elif line != '':
                values.append(line)
        # Flush the final key, which has no following key line to trigger it.
        if key is not None:
            res[key] = values
        return res

    def refresh(self):
        """Record the current time as the moment this session was last used."""
        import logging
        self.last = datetime.datetime.now()
        logger = logging.getLogger()
        logger.debug("updated datetime to {} for {}".format(self.last, self.authtok))
        print("updated datetime to {} for {}".format(self.last, self.authtok))

    def set_authtok(self, authtok):
        """Store the auth token that ``get_port`` compares against."""
        self.authtok = authtok

    def addkey(self, key, cert):
        """Placeholder that does nothing."""
        pass

    def kill(self):
        """Send SIGTERM to every pid recorded for this session.

        Pids that no longer exist are reported and skipped.
        """
        import os
        import signal
        import logging
        logger = logging.getLogger()
        print("killing all processes associated with this sshsession")
        logger.debug("shuting down ssh session for {} last seen at {}".format(self.authtok, self.last))
        print("shuting down ssh session for {} last seen at {}".format(self.authtok, self.last))
        print(self.pids)
        for pid in self.pids:
            try:
                os.kill(int(pid), signal.SIGTERM)
            except ProcessLookupError:
                print("process {} not found".format(pid))

    def get_port(self, authtok):
        """Return this session's port if ``authtok`` matches, else ``None``."""
        if self.authtok is not None and self.authtok == authtok:
            return self.port
        return None

    @staticmethod
    def get_sshsession():
        """Return the SSHSession for the current flask session, creating it if needed.

        The id is kept in the flask session under ``'sshsessid'``; a new
        random 8-character id is generated when none is stored yet.
        """
        import random
        import string
        from .. import sshsessions
        from flask import session
        sshsessid = session.get('sshsessid', None)
        id_length = 8
        alphabet = string.ascii_uppercase + string.digits
        rng = random.SystemRandom()
        while sshsessid is None:
            candidate = ''.join(rng.choice(alphabet) for _ in range(id_length))
            # Uniqueness must be checked against the registry of sessions;
            # the original tested the flask session's own keys instead.
            if candidate not in sshsessions:
                sshsessid = candidate
                session['sshsessid'] = sshsessid
        if sshsessid not in sshsessions:
            sshsessions[sshsessid] = SSHSession()
        return sshsessions[sshsessid]

    @staticmethod
    def remove_sshsession():
        """Drop the current flask session's SSHSession from the registry.

        Does nothing when the flask session has no stored id or the id is
        not in the registry.
        """
        from .. import sshsessions
        from flask import session
        sshsessid = session.get('sshsessid', None)
        if sshsessid is not None:
            sshsessions.pop(sshsessid, None)

# class Tunnelstat:
#     """
#     class docstring same as module docstring
#     """
#     @staticmethod
#     def loaddata():
#         """load data from a file"""
#         with open('tunnels.yml') as fptr:
#             data = yaml.load(fptr.read())
#         if data is None:
#             data = {}
#         return data
#
#     @staticmethod
#     def savedata(data):
#         """save data to a file"""
#         with open('tunnels.yml', 'w') as fptr:
#             fptr.write(yaml.dump(data))
#
#     @staticmethod
#     def update(authtok):
#         """update the last used time for a tunnel"""
#         data = Tunnelstat.loaddata()
#         for key, value in data.items():
#             if key == authtok:
#                 value['last'] = datetime.datetime.now()
#         Tunnelstat.savedata(data)
#
#
#
#     @staticmethod
#     def getport(authtok):
#         """given authtok, retrieve the port of the corresponding ssh tunnel"""
#         data = Tunnelstat.loaddata()
#         if authtok in data:
#             return data[authtok]['port']
#         return None
#
#     @staticmethod
#     def newtunnel(authtok, port, pids):
#         """a new tunnel has just been created, save the data info so we can
#          a) reap it later
#          b) allow the transparent web socket proxy to connect
#         """
#         newtunnel = {'port': port, 'last':datetime.datetime.now(), 'pids':pids}
#         data = Tunnelstat.loaddata()
#         data[authtok] = newtunnel
#         Tunnelstat.savedata(data)