-
Chris Hines authoredChris Hines authored
__init__.py 8.14 KiB
"""
This module persistently stores informion on tunnels in an in memory structure.
"""
import datetime
import yaml
import threading
class SshAgentException(Exception):
    """Raised when an ssh-agent operation (e.g. adding a key, listing identities) fails."""
class SSHSession:
    """State for one user's ssh-agent and the processes forked on their behalf.

    The web application forks processes for an ssh-agent, for ssh tunnels and
    for command execution; an SSHSession records what is needed to talk to and
    clean up those processes (agent socket, pids, control processes).
    Instances are kept in the package-level ``sshsessions`` mapping, keyed by
    an id stored in the Flask session (see ``get_sshsession``).
    """

    def __init__(self, **kwargs):
        """Create an empty session; ``kwargs`` pre-populate instance attributes.

        ``sshagent``, ``sshadd``, ``sshkeygen``, ``ctrl_processes`` and
        ``lock`` are assigned after ``kwargs`` is applied, so they cannot be
        overridden through ``kwargs``.
        """
        self.last = datetime.datetime.now()  # last time the session was used (see refresh)
        self.socket = None  # SSH_AUTH_SOCK path of the agent, once known
        self.token = None
        self.port = {}
        self.key = ''
        self.cert = ''
        self.pids = []  # pids signalled by kill(); start_agent stores them as strings
        self.authtok = None
        self.__dict__.update(kwargs)
        self.sshagent = 'ssh-agent'
        self.sshadd = '/usr/bin/ssh-add'
        self.sshkeygen = 'ssh-keygen'
        # NOTE(review): presumably control-socket path -> subprocess.Popen (see kill)
        self.ctrl_processes = {}
        self.lock = threading.Lock()

    def start_agent(self):
        """Ensure this session has an ssh-agent, recording its socket and pid.

        If ``ENABLELAUNCH`` is set in the app config and the environment
        already provides a non-empty ``SSH_AUTH_SOCK``, that agent is reused.
        Otherwise a new ``ssh-agent`` is spawned and its output is parsed for
        ``SSH_AUTH_SOCK`` and ``SSH_AGENT_PID``.

        Raises:
            SshAgentException: if the spawned agent did not report a socket.
        """
        import subprocess
        from .. import app
        import logging
        import os
        logger = logging.getLogger()
        logger.debug('starting agent')
        if app.config['ENABLELAUNCH'] and os.environ.get('SSH_AUTH_SOCK'):
            logger.debug('using existing agent')
            self.socket = os.environ['SSH_AUTH_SOCK']
            return
        # -s forces sh-style "VAR=value; export VAR;" output. Without it,
        # ssh-agent emits csh "setenv VAR value;" lines when $SHELL looks like
        # csh, and the parser below would find no socket.
        p = subprocess.Popen([self.sshagent, '-s'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        (stdout, stderr) = p.communicate()
        sock_path = None
        for field in stdout.decode().split(';'):
            name, sep, value = field.strip().partition('=')
            if not sep:
                continue
            if name == 'SSH_AUTH_SOCK':
                sock_path = value
            elif name == 'SSH_AGENT_PID':
                self.pids.append(value)
        if sock_path is None:
            # Fail loudly here rather than later with SSH_AUTH_SOCK=None in an env dict.
            logger.error("ssh-agent did not report SSH_AUTH_SOCK")
            logger.error(stdout)
            logger.error(stderr)
            raise SshAgentException()
        self.socket = sock_path

    def add_keycert(self, key, cert):
        """Load a private key and its certificate into this session's agent.

        The key and certificate are written to temporary files (ssh-add picks
        up the certificate from ``<keyfile>-cert.pub``); both files are always
        removed afterwards, even when ssh-add fails. The identity is added with
        a lifetime (``ssh-add -t``) equal to the time remaining until the
        certificate expires, so the agent drops it at expiry. Starts an agent
        first if none is known.

        Args:
            key: private key text.
            cert: OpenSSH certificate text.

        Raises:
            SshAgentException: if ssh-add exits non-zero.
        """
        import tempfile
        import os
        import subprocess
        import logging
        logger = logging.getLogger()
        if self.socket is None:
            self.start_agent()
        keyf = tempfile.NamedTemporaryFile(mode='w', delete=False)
        keyname = keyf.name
        certname = keyname + '-cert.pub'
        try:
            with keyf:
                keyf.write(key)
            with open(certname, mode='w') as certf:
                certf.write(cert)
            p = subprocess.Popen([self.sshkeygen, '-L', '-f', '-'], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            keygenout, keygenerr = p.communicate(cert.encode())
            # Examine the cert to determine its expiry. The line reads
            # "Valid: from <start> to <end>", so token 3 is the end time.
            certcontents = SSHSession.parse_cert_contents(keygenout.decode().splitlines())
            endtime = datetime.datetime.strptime(certcontents['Valid'][0].split()[3], "%Y-%m-%dT%H:%M:%S")
            # NOTE(review): ssh-keygen -L appears to print validity in local time
            # even though the cert stores it in UTC; confirm this naive comparison.
            delta = endtime - datetime.datetime.now()
            env = os.environ.copy()
            env['SSH_AUTH_SOCK'] = self.socket
            cmd = [self.sshadd, '-t', "{}".format(int(delta.total_seconds())), keyname]
            p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
            (stdout, stderr) = p.communicate()
            if p.returncode != 0:
                logger.error("Couldn't add key and cert")
                logger.error(stdout)
                logger.error(stderr)
                raise SshAgentException()
        finally:
            # The key file holds private key material: never leave it on disk,
            # including when ssh-keygen/ssh-add fail.
            for path in (certname, keyname):
                try:
                    os.unlink(path)
                except FileNotFoundError:
                    pass

    def get_cert_contents(self):
        """Return the parsed contents of every certificate held by the agent.

        Returns an empty list if no agent socket is known. Otherwise every line
        of ``ssh-add -L`` output containing ``cert`` is fed to
        ``ssh-keygen -L`` and the result parsed with parse_cert_contents.
        """
        import os
        import subprocess
        res = []
        if self.socket is None:
            return res
        env = os.environ.copy()
        env['SSH_AUTH_SOCK'] = self.socket
        p = subprocess.Popen([self.sshadd, '-L'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
        (stdout, stderr) = p.communicate()
        for line in stdout.splitlines():
            if b'cert' in line:
                keygen = subprocess.Popen([self.sshkeygen, '-L', '-f', '-'], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                keygenout, keygenerr = keygen.communicate(line)
                res.append(SSHSession.parse_cert_contents(keygenout.decode().splitlines()))
        return res

    @staticmethod
    def parse_cert_contents(lines):
        """Parse ``ssh-keygen -L`` output lines into a dict of field -> values.

        A line containing ':' starts a new field named by the text before the
        first ':'; any non-empty text after it becomes the field's first
        value. Following non-blank lines without ':' are appended to the
        current field. E.g. ``["Principals:", "  alice"]`` gives
        ``{"Principals": ["alice"]}``.
        """
        key = None
        values = []
        res = {}
        for l in lines:
            l = l.strip()
            if ':' in l:
                if key is not None:
                    res[key] = values
                    values = []
                (key, v) = l.split(':', 1)
                v = v.strip()
                if v != '':
                    values = [v]
            elif l != '':
                values.append(l)
        # Store the final field as well (previously dropped, e.g. "Extensions").
        if key is not None:
            res[key] = values
        return res

    def refresh(self):
        """Mark the session as used now by updating ``last``."""
        self.last = datetime.datetime.now()

    def addkey(self, key, cert):
        """Placeholder that does nothing; see add_keycert."""
        pass

    def kill(self):
        """Terminate every process associated with this session (best effort).

        For each pid in ``self.pids`` its process group is sent SIGTERM; if the
        process still exists after about two seconds the group is sent
        SIGKILL. Then each entry of ``self.ctrl_processes`` is killed and
        waited for (5 s), and its key is unlinked. Failures are logged at debug
        level and otherwise ignored so cleanup continues.
        """
        import os
        import signal
        import logging
        import time
        logger = logging.getLogger()
        logger.debug("shuting down ssh session for {} last seen at {}".format(self.authtok, self.last))
        for pid in self.pids:
            logger.debug("killing pid {}".format(pid))
            try:
                # NOTE(review): killpg raises ProcessLookupError when pid is not a
                # process-group id; the original author saw intermittent failures here.
                os.killpg(int(pid), signal.SIGTERM)
                try:
                    # Poll (signal 0 only checks existence) for up to 2s instead of
                    # always sleeping 2s; ProcessLookupError means it has exited.
                    deadline = time.monotonic() + 2
                    while time.monotonic() < deadline:
                        os.kill(int(pid), 0)
                        time.sleep(0.1)
                    os.killpg(int(pid), signal.SIGKILL)
                    logger.error('resorting to sigkill for pid {}'.format(pid))
                except ProcessLookupError:
                    pass
                logger.debug("killed {}".format(pid))
            except ProcessLookupError:
                logger.debug("process {} not found".format(pid))
        for ctrl_path, proc in self.ctrl_processes.items():
            logger.debug("killing ctrl pid {}".format(proc))
            try:
                proc.kill()
            except Exception:
                logger.debug("could not kill ctrl process {}".format(proc), exc_info=True)
            try:
                proc.wait(5)
            except Exception:
                logger.debug("ctrl process {} did not exit".format(proc), exc_info=True)
            try:
                os.unlink(ctrl_path)
            except Exception:
                logger.debug("could not remove {}".format(ctrl_path), exc_info=True)

    @staticmethod
    def test_sshsession(sess):
        """Check that ``sess``'s ssh-agent answers, starting one if needed.

        Runs ``ssh-add -l`` against the agent; an agent with no identities
        loaded is accepted.

        Raises:
            SshAgentException: if ssh-add exits non-zero for any other reason.
        """
        import os
        import subprocess
        import logging
        logger = logging.getLogger()
        env = os.environ.copy()
        if sess.socket is None:
            sess.start_agent()
        env['SSH_AUTH_SOCK'] = sess.socket
        cmd = [sess.sshadd, '-l']
        p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
        (stdout, stderr) = p.communicate()
        if p.returncode != 0:
            # A non-zero return code also occurs when the agent is running but
            # has no keys loaded; that is not an error condition.
            if b'The agent has no identities' in stdout:
                return
            logger.error("Couldn't communicate with the ssh agent")
            logger.error(stdout)
            logger.error(stderr)
            raise SshAgentException()

    @staticmethod
    def get_sshsession():
        """Return the SSHSession for the current Flask session, creating it if needed.

        A random 8-character id (uppercase letters and digits) not already
        used in ``sshsessions`` is stored in the Flask session under
        ``'sshsessid'``. The session's agent is checked with test_sshsession;
        on SshAgentException the old session is killed and replaced with a new
        one, which is checked again (a second failure propagates).
        """
        import random
        import string
        from .. import sshsessions
        from flask import session
        sshsessid = session.get('sshsessid', None)
        N = 8
        alphabet = string.ascii_uppercase + string.digits
        rng = random.SystemRandom()
        while sshsessid is None:
            key = ''.join(rng.choice(alphabet) for _ in range(N))
            # Retry on collision with an id already in the registry.
            if key not in sshsessions:
                sshsessid = key
                session['sshsessid'] = sshsessid
        if sshsessid not in sshsessions:
            sshsessions[sshsessid] = SSHSession()
        sess = sshsessions[sshsessid]
        try:
            SSHSession.test_sshsession(sess)
        except SshAgentException:
            sess.kill()
            sshsessions[sshsessid] = SSHSession()
            sess = sshsessions[sshsessid]
            SSHSession.test_sshsession(sess)
        return sess

    @staticmethod
    def remove_sshsession():
        """Drop the current Flask session's entry from ``sshsessions``.

        Does nothing if there is no entry. The session's processes are not
        killed here.
        """
        from .. import sshsessions
        from flask import session
        sshsessid = session.get('sshsessid', None)
        if sshsessid in sshsessions:
            del sshsessions[sshsessid]