Skip to content
Snippets Groups Projects
apiendpoints.py 13.2 KiB
Newer Older
"""
All the API endpoints for controlling processes and tunnels via SSH
"""
import json

import flask_restful
from flask import session, redirect, request, Response, make_response
from flask import render_template
from flask_restful import Resource

from . import api, app
from .sshwrapper import Ssh, SshAgentException
from .tunnelstat import SSHSession


class AppParamsException(Exception):
    """
    Raised when an application's connection parameters cannot be obtained
    (its paramscmd failed or reported an error) or do not validate.
    """

class GetCert(Resource):
    """
    Proxy certificate-signing requests from the SPA to the signing service.

    This class is necessary because I'm not reconfiguring
    SSHAuthZ to support CORS, but the TES does support CORS
    """

    def post(self):
        """
        takes a public key, returns to the SPA a certificate

        Expects a JSON body with ``token`` (bearer token for the signing
        service), ``pubkey`` and ``signing_url``; returns ``{'cert': ...}``.
        """
        import logging
        logger = logging.getLogger()
        logger.debug("in GetCert.post")
        data = request.get_json()
        # Deliberately not logging ``data``: it carries the user's bearer token.
        return {'cert': GetCert.get_cert(data['token'], data['pubkey'], data['signing_url'])}

    @staticmethod
    def get_cert(access_token, pub_key, url):
        """
        Sign a pub key into a cert.

        :param access_token: bearer token sent in the Authorization header
        :param pub_key: public key to be signed
        :param url: URL of the signing service
        :returns: the ``certificate`` field of the service's JSON response
        :raises requests.HTTPError: if the signing service answers with an
            error status
        """
        import logging
        import requests
        logger = logging.getLogger()
        # Never log access_token: it is a credential.
        logger.debug("requesting certificate for pub_key %s from %s", pub_key, url)
        headers = {"Authorization": "Bearer {}".format(access_token)}
        payload = {"public_key": pub_key}
        # NOTE(review): verify=False disables TLS certificate verification for
        # the signing service; confirm this is really required in deployment.
        with requests.Session() as sess:
            resp = sess.post(url, json=payload, headers=headers, verify=False)
        logger.debug("get_cert returned from its external call")
        # Surface HTTP errors explicitly instead of a KeyError on the body.
        resp.raise_for_status()
        return resp.json()['certificate']

class TestAuth(Resource):
    """
    Tests whether the backend can login to the selected compute Resource
    """
    def get(self):
        """
        tell the SPA if the TES is logged in (whether the Flask session holds a 'token')
        """
        return 'token' in session

    def post(self):
        """
        Add the key/certificate pair posted by the SPA to this session's SSH session.

        Same behaviour as ``SSHAgent.post``; kept so existing callers of
        ``TestAuth.post`` keep working.
        """
        return _add_keycert_to_agent()


def _add_keycert_to_agent():
    """
    Load the ``key``/``cert`` pair from the request's JSON body into the SSH session.

    Marks the Flask session permanent, adds the pair via
    ``SSHSession.add_keycert``, refreshes the SSH session and returns "OK".
    """
    import logging
    logger = logging.getLogger()
    session.permanent = True
    sshsess = SSHSession.get_sshsession()
    data = request.get_json()
    logger.debug('adding a key to the session')
    sshsess.add_keycert(key=data['key'], cert=data['cert'])
    logger.debug('started an agent and added the key')
    sshsess.refresh()
    return "OK"


class SSHAgent(Resource):
    """
    Manage the SSH session/agent that holds the user's key and certificate.

    NOTE(review): ``get`` and ``delete`` were recovered from code that sat
    unreachable after ``TestAuth.post``'s ``return``; confirm the method
    assignment against version control.
    """
    def post(self):
        """add the posted key/certificate pair to this session's agent"""
        return _add_keycert_to_agent()

    def get(self):
        """
        Return ``sshsess.get_cert_contents()`` for this session (presumably the
        certificates held by the agent; confirm), or [] when no agent socket
        exists yet. Any failure is logged and answered with HTTP 500.
        """
        import logging
        import traceback
        logger = logging.getLogger()
        try:
            sshsess = SSHSession.get_sshsession()
            if sshsess.socket is None:
                return []  # The agent hasn't even been started yet
            return sshsess.get_cert_contents()
        except Exception as e:
            logger.debug('SSHAgent.get: Exception {}'.format(e))
            logger.debug(traceback.format_exc())
            flask_restful.abort(500)

    def delete(self):
        """kill this session's agent, forget the SSH session and return []"""
        sshsess = SSHSession.get_sshsession()
        sshsess.kill()
        SSHSession.remove_sshsession()
        return []
def get_conn_params():
    """
    Return parameters relating to the backend compute service.

    Retrieve them from the query string:

    * ``identity`` -- JSON object; must contain ``username`` and ``site.host``
    * ``interface`` -- JSON object describing the batch interface; its keys
      are also merged into the top level of the result
    * ``app`` -- optional JSON object; ``{}`` when absent or not valid JSON

    The returned dict has ``identity``, ``interface`` and ``app`` entries, the
    keys of ``interface``, and ``user``/``host`` copied from ``identity``.
    """
    identityparams = json.loads(request.args.get('identity'))
    interfaceparams = json.loads(request.args.get('interface'))
    try:
        appparams = json.loads(request.args.get('app'))
    except (TypeError, ValueError):
        # 'app' missing (json.loads(None) -> TypeError) or malformed JSON
        appparams = {}

    params = {
        'identity': identityparams,
        'interface': interfaceparams,
        'app': appparams,
    }
    params.update(interfaceparams)
    params['user'] = identityparams['username']
    params['host'] = identityparams['site']['host']
    return params


def get_app_params():
    """
    Return the parameters for the application, parsed from the JSON in the
    ``app`` query-string argument.

    The application definition is expected to carry keys such as
    ``startscript``, ``paramscmd``, ``client`` and ``localbind``; they are
    not validated here.

    :raises TypeError: if the ``app`` argument is absent
    :raises ValueError: if the ``app`` argument is not valid JSON
    """
    return json.loads(request.args.get('app'))

class TunnelstatEP(Resource):
    """
    Endpoints used by the WS proxy
    """
    # def put(self, authtok):
    #     """
    #     update the last used time on a tunnel
    #     """
    #     Tunnelstat.update(authtok)

    def get(self, authtok):
        """
        given an authtoken, return the port of the tunnel it was issued for

        Every SSH session in ``sshsessions`` is searched; None is returned
        when no session holds a tunnel for ``authtok``.
        """
        import logging
        from . import sshsessions
        logger = logging.getLogger()

        for sshsess in sshsessions.values():
            if authtok in sshsess.port:
                port = sshsess.port[authtok]
                # The token itself is a credential, so only the port is logged.
                logger.debug("found port {} for requested authtok".format(port))
                return port
        return None


class JobStat(Resource):
    """
    endpoints to return info on jobs on the backend compute Resource
    """
    def get(self):
        """
        get info on the job from the backend

        Runs the interface's ``statcmd`` on the login host and returns the
        decoded JSON list of jobs, each annotated with the caller's identity.
        Incomplete parameters, identity errors and any stderr output from the
        command are answered with HTTP 400.
        """
        import logging
        logger = logging.getLogger()
        logger.info('/stat endpoint entered')
        params = get_conn_params()
        sshsess = SSHSession.get_sshsession()
        sshsess.refresh()
        logger.info('/stat endpoint, all parameters collected')
        try:
            host = params['identity']['site']['host']
        except (TypeError, KeyError):
            flask_restful.abort(400, message="stat: definition of login host incomplete")
        try:
            user = params['identity']['username']
        except (TypeError, KeyError):
            flask_restful.abort(400, message="stat: definition of username incomplete")
        try:
            cmd = params['interface']['statcmd']
        except (TypeError, KeyError):
            flask_restful.abort(400, message="stat: definition of batch interface incomplete")
        try:
            res = Ssh.execute(sshsess, host=host, user=user, cmd=cmd)
        except SshAgentException as e:
            logger.error(e)
            flask_restful.abort(400, message="Identity error {}".format(e))
        # stderr may be '', b'' or None when the command succeeded
        if res['stderr']:
            logger.error(res['stderr'])
            flask_restful.abort(400, message=res['stderr'].decode())
        jobs = json.loads(res['stdout'].decode())

        # Attach the identity information to each job before returning it
        for job in jobs:
            job['identity'] = params['identity']

        return jobs

class JobCancel(Resource):
    """
    Terminate a job on the compute backend
    """
    def delete(self, jobid):
        """
        Terminate a job on the backend

        Formats the interface's ``cancelcmd`` with ``jobid``, runs it on the
        login host and returns its decoded stdout. Incomplete parameters,
        identity errors and any stderr output are answered with HTTP 400.
        """
        import logging
        logger = logging.getLogger()
        params = get_conn_params()
        sshsess = SSHSession.get_sshsession()
        try:
            host = params['identity']['site']['host']
            user = params['identity']['username']
            cmd = params['interface']['cancelcmd'].format(jobid=jobid)
        except (TypeError, KeyError):
            flask_restful.abort(400, message="cancel: definition of identity or batch interface incomplete")
        try:
            res = Ssh.execute(sshsess, host=host, user=user, cmd=cmd)
        except SshAgentException as e:
            logger.error(e)
            flask_restful.abort(400, message="Identity error {}".format(e))
        # stderr may be '', b'' or None when the command succeeded
        if res['stderr']:
            logger.error(res['stderr'])
            flask_restful.abort(400, message=res['stderr'].decode())
        return res['stdout'].decode()


class JobSubmit(Resource):
    """
    Class dealing with starting a new job on the compute backend
    """
    def post(self):
        """
        starting a job is a post, since it changes the state of the backend

        Runs the interface's ``submitcmd`` on the login host with the app's
        ``startscript`` on stdin and returns the decoded stdout. Incomplete
        parameters, identity errors and any stderr output are answered with
        HTTP 400.
        """
        import logging
        logger = logging.getLogger()
        params = get_conn_params()
        sshsess = SSHSession.get_sshsession()
        try:
            host = params['identity']['site']['host']
            user = params['identity']['username']
            cmd = params['interface']['submitcmd']
            script = params['app']['startscript']
        except (TypeError, KeyError):
            flask_restful.abort(400, message="submit: definition of identity, batch interface or application incomplete")
        try:
            res = Ssh.execute(sshsess, host=host, user=user, cmd=cmd, stdin=script)
        except SshAgentException as e:
            logger.error(e)
            flask_restful.abort(400, message="Identity error {}".format(e))
        # stderr may be '', b'' or None when the command succeeded
        if res['stderr']:
            logger.error(res['stderr'])
            flask_restful.abort(400, message=res['stderr'].decode())
        return res['stdout'].decode()

def gen_authtok():
    """
    Return a fresh 16-character token of uppercase ASCII letters and digits,
    drawn from the OS randomness source (random.SystemRandom), suitable for
    storing in a cookie.
    """
    import random
    import string
    alphabet = string.ascii_uppercase + string.digits
    rng = random.SystemRandom()
    picks = [rng.choice(alphabet) for _ in range(16)]
    return ''.join(picks)

class JobConnect(Resource):
    """
    endpoints for connecting to an existing job
    """

    def create_tunnel(self, username, loginhost, appparams, batchhost, firewall, data):
        """
        Discover an application's connection parameters and open a tunnel to them.

        If ``appparams`` defines a ``paramscmd``, it is run on ``batchhost``
        (via ``ssh`` from ``loginhost``) and must print a JSON object. That
        object is merged into the connection parameters, which are validated
        and used to open a tunnel through ``loginhost``.

        :returns: dict of connection parameters (``batchhost``, whatever the
            paramscmd reported, ``localtunnelport`` and ``authtok``); empty
            when the application defines no ``paramscmd``
        :raises AppParamsException: if ``batchhost`` is not a plain host name,
            the paramscmd fails or reports an error, or its output does not
            validate
        """
        import re
        sshsess = SSHSession.get_sshsession()
        connectparams = {}
        if appparams.get('paramscmd') is None:
            return connectparams

        # batchhost comes from the URL path and is spliced into the command run
        # on the login host, so only accept plain host names / addresses.
        if not re.fullmatch(r'[A-Za-z0-9._:-]+', batchhost) or batchhost.startswith('-'):
            raise AppParamsException("invalid batch host {!r}".format(batchhost))

        connectparams['batchhost'] = batchhost
        paramcmd = 'ssh -o StrictHostKeyChecking=no {batchhost} '.format(batchhost=batchhost) + appparams['paramscmd']
        res = Ssh.execute(sshsess, host=loginhost, user=username, cmd=paramcmd.format(data))
        try:
            reported = json.loads(res['stdout'])
        except json.decoder.JSONDecodeError as e:
            raise AppParamsException(res['stderr'] + res['stdout']) from e
        # stderr may be '', b'' or None when the command succeeded
        if res['stderr']:
            raise AppParamsException(res['stderr'])
        if 'error' in reported:
            raise AppParamsException(reported['error'])
        connectparams.update(reported)

        if not self.validate_connect_params(connectparams, username, loginhost):
            raise AppParamsException("connection parameters invalid {} {} {}".format(connectparams, username, loginhost))

        authtok = gen_authtok()
        tunnelport, pids = Ssh.tunnel(sshsess, port=connectparams['port'],
                                      batchhost=connectparams['batchhost'],
                                      user=username, host=loginhost,
                                      internalfirewall=firewall,
                                      localbind=appparams['localbind'], authtok=authtok)
        connectparams['localtunnelport'] = tunnelport
        connectparams['authtok'] = authtok
        return connectparams

    def validate_connect_params(self, connectparams, username, host):
        """
        Sanity-check the parameters needed to open a tunnel.

        Requires ``port`` (convertible to int) and ``batchhost`` in
        ``connectparams``, and no whitespace in ``username`` or ``host``.
        """
        if 'port' not in connectparams or 'batchhost' not in connectparams:
            return False
        try:
            int(connectparams['port'])
        except (TypeError, ValueError, OverflowError):
            return False
        # This really needs more validation: both values end up in ssh command lines
        if any(c.isspace() for c in username) or any(c.isspace() for c in host):
            return False
        return True

    def get(self, jobid, batchhost):
        """
        Connecting to a job is a get operation (i.e. it does not make modifications)

        Parameter-discovery failures are rendered with ``appparams.html.j2``;
        otherwise the result of ``connect`` is returned.
        """
        params = get_conn_params()
        appparams = get_app_params()
        # silent=True: a browser GET normally carries no JSON body, and newer
        # Flask/Werkzeug versions raise instead of returning None in that case.
        data = request.get_json(silent=True)
        try:
            connectparams = self.create_tunnel(params['identity']['username'], params['identity']['site']['host'],
                                               appparams, batchhost, params['interface']['internalfirewall'],
                                               data)
        except AppParamsException as e:
            return make_response(render_template('appparams.html.j2', data="{}".format(e)))
        return self.connect(appparams, connectparams)

    def connect(self, appparams, connectparams):
        """
        perform the connection either by forking a local client or returning a redirect

        A ``client.cmd`` list is formatted with ``connectparams`` and started
        as a local process. Otherwise, a ``client.redir`` template yields a
        redirect to the TWSPROXY carrying the ``twsproxyauth`` cookie.
        """
        import subprocess
        client = appparams['client']
        cmdlist = []
        if client.get('cmd') is not None:
            # We need for fork a local process such as vncviewer or a terminal
            # We may need a wrapper for local processes to find the correct
            # process on all OS
            cmdlist = [cmdarg.format(**connectparams) for cmdarg in client['cmd']]
            # Output is discarded rather than piped: nothing ever reads the
            # pipes, so a chatty client would block once the pipe buffer filled.
            subprocess.Popen(cmdlist, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        elif client.get('redir') is not None:
            twsproxy = app.config['TWSPROXY']
            response = make_response(redirect(twsproxy + client['redir'].format(**connectparams)))
            response.set_cookie('twsproxyauth', connectparams['authtok'])
            return response
        return "Connecting with cmd {}".format(cmdlist)


# Register each REST resource with the flask_restful Api at its URL.
# Note that TestAuth is defined in this module but not registered here.
api.add_resource(TunnelstatEP, '/tunnelstat/<string:authtok>')
api.add_resource(GetCert, '/getcert')
api.add_resource(JobStat, '/stat')
api.add_resource(JobCancel, '/cancel/<int:jobid>')
api.add_resource(JobSubmit, '/submit')
api.add_resource(JobConnect, '/connect/<int:jobid>/<string:batchhost>')
# Disabled registrations; SessionTest and StartAgent are not defined in this module.
# api.add_resource(SessionTest,'/sesstest')
#api.add_resource(StartAgent,'/startagent')
api.add_resource(SSHAgent,'/sshagent')