Skip to content
Snippets Groups Projects
apiendpoints.py 15 KiB
Newer Older
"""
All the API endpoints for controlling processes and tunnels via SSH
"""
import json

import flask_restful
from flask import session, redirect, request, Response, make_response
from flask import render_template
from flask_restful import Resource

from . import api, app
from .sshwrapper import Ssh, SshAgentException
from .tunnelstat import SSHSession
class AppParamsException(Exception):
    """
    Raised when an application's connection parameters cannot be
    retrieved from the compute node or fail validation.
    """

class GetCert(Resource):
    """
    This class is necessary because I'm not reconfiguring
    SSHAuthZ to support CORS, but the TES does support CORS
    """

    def post(self):
        """
        takes a public key, returns to the SPA a certificate

        Expects a JSON body with ``token``, ``pubkey`` and ``signing_url``.

        :returns: ``{'cert': <certificate>}`` on success, or the tuple
                  ``("Unable to get a signed certificate", 500)`` if anything
                  fails (the failure is logged together with its traceback)
        """
        import logging
        logger = logging.getLogger()
        logger.debug('entering GetCert.post')
        data = request.get_json()
        try:
            response = {'cert': GetCert.get_cert(data['token'], data['pubkey'], data['signing_url'])}
        except Exception:
            # Request boundary: log the full traceback, return a generic 500.
            logger.exception('Failed to get certificate')
            response = ("Unable to get a signed certificate", 500)
        return response

    @staticmethod
    def get_cert(access_token, pub_key, url, timeout=60):
        """
        Sign a pub key into a cert

        POSTs ``{"public_key": pub_key}`` to ``url`` with a bearer token and
        returns the ``certificate`` field of the JSON reply.

        :param access_token: bearer token sent in the Authorization header
        :param pub_key: public key to be signed
        :param url: signing service endpoint
        :param timeout: seconds to wait for the signing service before giving up
        :returns: the certificate returned by the signing service
        :raises requests.RequestException: on network errors, timeouts or an
                HTTP error status
        :raises KeyError: if the reply has no ``certificate`` field
        """
        import requests
        headers = {"Authorization": "Bearer %s" % access_token}
        payload = {"public_key": pub_key}
        # NOTE(review): verify=False disables TLS certificate checking while a
        # bearer token is sent to a caller-supplied URL. Kept for compatibility
        # with existing deployments; consider enabling verification.
        with requests.Session() as sess:
            resp = sess.post(url, json=payload, headers=headers, verify=False, timeout=timeout)
            resp.raise_for_status()
            return resp.json()['certificate']

class TestAuth(Resource):
    """
    Tests whether the backend can login to the selected compute Resource
    """
    def get(self):
        """
        tell the SPA if the TES is logged in

        :returns: True if the Flask session holds a 'token' entry
        """
        return 'token' in session

    def post(self):
        """
        Add a key/cert pair to this Flask session's SSHSession and return the
        certificate contents it then holds.

        Expects a JSON body with 'key' and 'cert' entries.

        :returns: [] if the SSH session has no agent socket yet, otherwise
                  ``sshsess.get_cert_contents()``
        Any exception while reading the certificates is logged at debug level
        and turned into an HTTP 500 via ``flask_restful.abort``.

        NOTE(review): this body appears to be several handlers fused together.
        Its log messages name ``SSHAgent.post`` and ``SSHAgent.get``, its imports
        are repeated, and its final kill/remove section is unreachable. The route
        table registers an ``SSHAgent`` resource, which suggests these were
        meant to be separate methods of that class. Confirm and split.
        """
        import logging
        logger = logging.getLogger()
        logger.debug('entering SSHAgent.post')
        session.permanent = True
        from .tunnelstat import SSHSession
        sshsess = SSHSession.get_sshsession()
        data = request.get_json()
        sshsess.add_keycert(key=data['key'],cert=data['cert'])
        sshsess.refresh()
        logger.debug('leaving SSHAgent.post')
        # NOTE(review): from here down looks like the body of SSHAgent.get
        # (re-import, fresh logger, "SSHAgent.get" log message).
        from .tunnelstat import SSHSession
        import logging
        logger = logging.getLogger()
        try:
            sshsess = SSHSession.get_sshsession()
            if sshsess.socket == None:
                return [] # The agent hasn't even been started yet
            return sshsess.get_cert_contents()
        except Exception as e:
            import traceback
            logger.debug('SSHAgent.get: Exception {}'.format(e))
            logger.debug(traceback.format_exc())
            flask_restful.abort(500)

        # NOTE(review): unreachable. Every path through the try/except above
        # returns or raises. This kill-and-remove sequence looks like the body
        # of a separate handler (perhaps a delete) that lost its def line.
        from .tunnelstat import SSHSession
        sshsess = SSHSession.get_sshsession()
        sshsess.kill()
        SSHSession.remove_sshsession()
        return []
def get_conn_params(args=None):
    """
    Return parameters relating to the backend compute service.

    They are retrieved from the query string, where ``identity``,
    ``interface`` and (optionally) ``app`` are each JSON-encoded.

    :param args: mapping to read the parameters from; defaults to
                 ``request.args`` of the current Flask request
    :returns: dict with keys ``identity``, ``interface`` and ``app`` (None if
              ``app`` is absent or not valid JSON), every key of the interface
              definition, plus ``user`` and ``host`` taken from the identity
    :raises TypeError: if ``identity`` or ``interface`` is missing
    :raises ValueError: if ``identity`` or ``interface`` is not valid JSON
    :raises KeyError: if the identity lacks ``username`` or ``site.host``
    """
    import logging
    logger = logging.getLogger()
    logger.debug('entering get_conn_params')
    if args is None:
        args = request.args
    identityparams = json.loads(args.get('identity'))
    interfaceparams = json.loads(args.get('interface'))

    # Not every endpoint needs the app definition (stat, cancel), so a missing
    # or malformed 'app' is tolerated and recorded as None.
    appparams = None
    appstr = args.get('app')
    if appstr is not None:
        try:
            appparams = json.loads(appstr)
        except (TypeError, ValueError):
            logger.exception('exception retrieving app params')

    params = {
        'identity': identityparams,
        'interface': interfaceparams,
        'app': appparams,
    }
    # The interface's commands (statcmd, submitcmd, ...) are also exposed at
    # the top level of the returned dict.
    params.update(interfaceparams)
    params['user'] = identityparams['username']
    params['host'] = identityparams['site']['host']

    logger.debug('leaving get_conn_params')
    return params


def get_app_params(args=None):
    """
    Return the parameters for the application.

    They are read from the JSON-encoded ``app`` query-string parameter.

    :param args: mapping to read from; defaults to ``request.args`` of the
                 current Flask request
    :returns: the decoded ``app`` definition
    :raises TypeError: if ``app`` is absent
    :raises ValueError: if ``app`` is not valid JSON
    """
    if args is None:
        args = request.args
    return json.loads(args.get('app'))

class TunnelstatEP(Resource):
    """
    Endpoints used by the WS proxy
    """

    def get(self, authtok):
        """
        Given an authtoken, return the port of the matching tunnel.

        Searches every SSH session in ``sshsessions`` for a tunnel registered
        under ``authtok``.

        :param authtok: token identifying the tunnel
        :returns: the port registered for ``authtok``, or None if no session
                  has it (or the lookup raised; the error is logged)
        """
        import logging
        logger = logging.getLogger()
        logger.debug('entering TunnelstatEP.get')
        from . import sshsessions
        logger.debug('TunnelstatEP.get: iterating sshsessions {}'.format(authtok))
        try:
            for (sessid, sshsess) in sshsessions.items():
                logger.debug('sshsession id {}'.format(sessid))
                # Tokens of other sessions are credentials: compare, don't log.
                for (tok, port) in sshsess.port.items():
                    if tok == authtok:
                        logger.debug("found port {} for authtok {}".format(port, tok))
                        logger.debug('leaving TunnelstatEP.get')
                        return port
        except Exception:
            # Best effort: a malformed session must not break the proxy lookup.
            logger.exception("exception in TunnelstatEP.get")
        logger.debug("No ports found for authtok {}".format(authtok))
        logger.debug('leaving TunnelstatEP.get')
        return None


class JobStat(Resource):
    """
    endpoints to return info on jobs on the backend compute Resource
    """
    def get(self):
        """
        get info on the job from the backend

        Runs the interface's ``statcmd`` on the login host and parses its
        stdout as a JSON list of jobs. Each job is tagged with the identity
        used to query it.

        :returns: list of job dicts, each with an added ``identity`` key
        Aborts with HTTP 400 if the identity/interface definition is
        incomplete, the SSH identity is unusable, or the command writes to
        stderr.
        """
        import logging
        logger = logging.getLogger()
        params = get_conn_params()
        sshsess = SSHSession.get_sshsession()
        sshsess.refresh()
        try:
            host = params['identity']['site']['host']
        except (TypeError, KeyError):
            flask_restful.abort(400, message="stat: definition of login host incomplete")
        try:
            user = params['identity']['username']
        except (TypeError, KeyError):
            flask_restful.abort(400, message="stat: definition of username incomplete")
        try:
            cmd = params['interface']['statcmd']
        except (TypeError, KeyError):
            flask_restful.abort(400, message="stat: definition of batch interface incomplete")
        try:
            res = Ssh.execute(sshsess, host=host, user=user, cmd=cmd)
        except SshAgentException as e:
            logger.error(e)
            flask_restful.abort(400, message="Identity error {}".format(e))
        stderr = res['stderr']
        if stderr:
            logger.error(stderr)
            # stderr may arrive as bytes or str
            flask_restful.abort(400, message=stderr.decode() if isinstance(stderr, bytes) else stderr)
        jobs = json.loads(res['stdout'].decode())

        # Attach the identity information to the job before returning it
        for j in jobs:
            j['identity'] = params['identity']

        return jobs

class JobCancel(Resource):
    """
    Terminate a job on the compute backend
    """
    def delete(self, jobid):
        """
        Terminate a job on the backend

        Runs the interface's ``cancelcmd`` (with ``{jobid}`` substituted) on
        the login host over SSH.

        :param jobid: integer job id taken from the URL
        :returns: the command's decoded stdout
        Aborts with HTTP 400 if the SSH identity is unusable or the command
        writes anything to stderr.
        """
        import logging
        logger = logging.getLogger()
        params = get_conn_params()
        sshsess = SSHSession.get_sshsession()
        host = params['identity']['site']['host']
        user = params['identity']['username']
        cmd = params['interface']['cancelcmd'].format(jobid=jobid)
        try:
            res = Ssh.execute(sshsess, host=host, user=user, cmd=cmd)
        except SshAgentException as e:
            logger.error(e)
            flask_restful.abort(400, message="Identity error {}".format(e))
        stderr = res['stderr']
        if stderr:
            logger.error(stderr)
            # stderr may arrive as bytes or str
            flask_restful.abort(400, message=stderr.decode() if isinstance(stderr, bytes) else stderr)
        return res['stdout'].decode()


class JobSubmit(Resource):
    """
    Start a new job on the compute backend
    """
    def post(self):
        """
        starting a job is a post, since it changes the state of the backend

        Pipes the app's ``startscript`` to the interface's ``submitcmd`` on
        the login host.

        :returns: the submit command's decoded stdout
        Aborts with HTTP 400 if the SSH identity is unusable or the command
        writes anything to stderr.
        """
        import logging
        logger = logging.getLogger()
        params = get_conn_params()
        logger.debug('entering JobSubmit.post {}'.format(params))
        sshsess = SSHSession.get_sshsession()
        try:
            res = Ssh.execute(sshsess, host=params['identity']['site']['host'],
                              user=params['identity']['username'],
                              cmd=params['interface']['submitcmd'],
                              stdin=params['app']['startscript'])
        except SshAgentException as e:
            logger.error(e)
            flask_restful.abort(400, message="Identity error {}".format(e))
        stderr = res['stderr']
        if stderr:
            logger.error('failed to submit job')
            logger.error(stderr)
            # stderr may arrive as bytes or str
            flask_restful.abort(400, message=stderr.decode() if isinstance(stderr, bytes) else stderr)
        return res['stdout'].decode()

def gen_authtok(length=16):
    """
    generate a random string suitable for an auth token stored in a cookie

    :param length: number of characters in the token (default 16)
    :returns: a string of ``length`` characters drawn from A-Z and 0-9
    """
    import logging
    import secrets
    import string
    logger = logging.getLogger()
    logger.debug('generating new authtok')
    alphabet = string.ascii_uppercase + string.digits
    # secrets draws from the OS CSPRNG (the same source random.SystemRandom uses)
    return ''.join(secrets.choice(alphabet) for _ in range(length))

class JobConnect(Resource):
    """
    endpoints for connecting to an existing job
    """

    def create_tunnel(self, username, loginhost, appparams, batchhost, firewall, data):
        """
        Discover an application's connection parameters and open a tunnel to it.

        If the app defines ``paramscmd``, it is run on ``batchhost`` via ssh
        from ``loginhost``. It must print a JSON object, whose keys are merged
        into the connection parameters. The parameters are then validated and
        an SSH tunnel is opened to ``batchhost:port``.

        :param username: account on the login host
        :param loginhost: login host that commands and the tunnel go through
        :param appparams: application definition (``paramscmd``, ``localbind``, ...)
        :param batchhost: compute node the job runs on (taken from the URL)
        :param firewall: interface ``internalfirewall`` setting, passed to Ssh.tunnel
        :param data: request JSON body (may be None); substituted positionally
                     into ``paramscmd`` via ``str.format``
        :returns: dict of connection parameters including ``batchhost``,
                  ``port``, ``localtunnelport`` and ``authtok``
        :raises AppParamsException: if the parameters cannot be obtained or
                are invalid
        """
        import logging
        import shlex
        logger = logging.getLogger()
        logger.debug('entering JobConnect.create_tunnel {} {}'.format(username, batchhost))
        sshsess = SSHSession.get_sshsession()
        connectparams = {'batchhost': batchhost}

        if appparams.get('paramscmd') is not None:
            # batchhost comes from the URL: quote it so it cannot inject shell
            # syntax into the command executed on the login host.
            paramcmd = ('ssh -o StrictHostKeyChecking=no -o CheckHostIP=no {batchhost} '
                        .format(batchhost=shlex.quote(batchhost))
                        + appparams['paramscmd'].format(data))
            logger.debug('JobCreate.create_tunnel: using ssh to extract connection parameters')
            res = Ssh.execute(sshsess, host=loginhost, user=username, cmd=paramcmd)
            try:
                remoteparams = json.loads(res['stdout'])
            except json.JSONDecodeError:
                raise AppParamsException(res['stderr'] + res['stdout'])
            if res['stderr']:
                raise AppParamsException(res['stderr'])
            if 'error' in remoteparams:
                raise AppParamsException(remoteparams['error'])
            connectparams.update(remoteparams)

        # Without a paramscmd there is no port, so validation fails and the
        # caller gets an explicit AppParamsException rather than None.
        if not self.validate_connect_params(connectparams, username, loginhost):
            raise AppParamsException("connection parameters invalid {} {} {}".format(connectparams, username, loginhost))

        authtok = gen_authtok()
        logger.debug('JobCreate.create_tunnel: creating a tunnel for authtok {}'.format(authtok))
        tunnelport, _pids = Ssh.tunnel(sshsess, port=connectparams['port'],
                                       batchhost=connectparams['batchhost'],
                                       user=username, host=loginhost,
                                       internalfirewall=firewall,
                                       localbind=appparams['localbind'], authtok=authtok)

        connectparams['localtunnelport'] = tunnelport
        # connect() sets this as the 'twsproxyauth' cookie; presumably the same
        # token Ssh.tunnel registered the tunnel under (see TunnelstatEP).
        connectparams['authtok'] = authtok
        logger.debug('JobCreate.create_tunnel: created a tunnel for authtok {} port {}'.format(authtok, tunnelport))
        return connectparams

    def validate_connect_params(self, connectparams, username, host):
        """
        Sanity-check connection parameters before opening a tunnel.

        :returns: True if ``connectparams`` has a ``batchhost`` and an
                  int-convertible ``port``, and neither ``username`` nor
                  ``host`` contains a space or newline; otherwise False
        """
        if 'port' not in connectparams or 'batchhost' not in connectparams:
            return False
        try:
            int(connectparams['port'])
        except (TypeError, ValueError, OverflowError):
            return False
        # This really needs more validation
        for value in (username, host):
            if ' ' in value or '\n' in value:
                return False
        return True

    def get(self, jobid, batchhost):
        """
        Connecting to a job is a get operation (i.e. it does not make modifications)

        Creates the tunnel, then hands off to :meth:`connect`. Invalid app
        parameters are rendered with the ``appparams.html.j2`` template.
        """
        import logging
        logger = logging.getLogger()
        logger.debug('entering JobConnect.get for jobid {} {}'.format(jobid, batchhost))
        params = get_conn_params()
        appparams = get_app_params()
        # A GET normally carries no JSON body; silent=True yields None instead
        # of an error when the body is absent or not JSON.
        data = request.get_json(silent=True)
        try:
            connectparams = self.create_tunnel(params['identity']['username'], params['identity']['site']['host'],
                                               appparams, batchhost, params['interface']['internalfirewall'],
                                               data)
        except AppParamsException as e:
            return make_response(render_template('appparams.html.j2', data="{}".format(e)))
        logger.debug('JobConnect.get tunnels created for jobid {} {}, moving to redirect'.format(jobid, batchhost))
        return self.connect(appparams, connectparams)

    def connect(self, appparams, connectparams):
        """
        perform the connection either by forking a local client or returning a redirect

        :param appparams: application definition; its ``client`` entry holds
                          either a ``cmd`` argument list or a ``redir`` template
        :param connectparams: values substituted into ``cmd``/``redir``
        :returns: a status string (local client) or a JSON response carrying
                  the redirect location and the ``twsproxyauth`` cookie
        Aborts with HTTP 400 if the client defines neither ``cmd`` nor ``redir``.
        """
        import logging
        import subprocess
        logger = logging.getLogger()
        client = appparams['client']
        if client.get('cmd') is not None:
            # We need for fork a local process such as vncviewer or a terminal
            # We may need a wrapper for local processes to find the correct
            # process on all OS
            cmdlist = [cmdarg.format(**connectparams) for cmdarg in client['cmd']]
            # Output is never read; DEVNULL keeps the child from blocking on a full pipe.
            subprocess.Popen(cmdlist, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            return "Connecting with cmd {}".format(cmdlist)
        if client.get('redir') is not None:
            twsproxy = app.config['TWSPROXY']
            data = json.dumps({'location': twsproxy + client['redir'].format(**connectparams)})

            response = make_response(data)
            # Werkzeug's attribute is ``mimetype``; ``mime_type`` was silently ignored.
            response.mimetype = 'application/json'
            response.set_cookie('twsproxyauth', connectparams['authtok'])
            logger.debug('JobConnect.connect: connecting via redirect with cookie authtok set to  {}'.format(connectparams['authtok']))
            return response
        flask_restful.abort(400, message="application client defines neither cmd nor redir")


# Route table: bind each Resource class to its URL on the shared ``api``.
api.add_resource(TunnelstatEP, '/tunnelstat/<string:authtok>')
api.add_resource(GetCert, '/getcert')
api.add_resource(JobStat, '/stat')
api.add_resource(JobCancel, '/cancel/<int:jobid>')
api.add_resource(JobSubmit, '/submit')
api.add_resource(JobConnect, '/connect/<int:jobid>/<string:batchhost>')
# api.add_resource(SessionTest,'/sesstest')
#api.add_resource(StartAgent,'/startagent')
# NOTE(review): no ``SSHAgent`` class is defined in this file as shown, so this
# line would raise NameError at import time. Confirm where it is meant to come from.
api.add_resource(SSHAgent,'/sshagent')