Skip to content
Snippets Groups Projects
apiendpoints.py 11.4 KiB
Newer Older
"""
All the API endpoints for controlling processes and tunnels via SSH
"""
import json
from flask import session, redirect, request, Response, make_response
from flask_restful import Resource
import flask_restful
from . import api, islocal

    from .localssh import Ssh, SshAgentException
    from .localtunnelstat import SSHSession
    # from .localtunnelstat import Tunnelstat


class GetCert(Resource):
    """
    Proxy certificate-signing requests from the SPA to a signing service.

    This class is necessary because SSHAuthZ is not being reconfigured to
    support CORS, but the TES does support CORS, so the SPA posts here and
    the TES forwards the request.
    """

    def post(self):
        """
        Take a public key from the SPA and return a signed certificate.

        Expects a JSON body with ``token`` (bearer access token), ``pubkey``
        (the public key to sign) and ``signing_url`` (signing endpoint).

        :returns: ``{'cert': <certificate>}``
        """
        import logging
        logger = logging.getLogger()
        logger.debug("in GetCert.post")
        data = request.get_json()
        # The request body is deliberately not logged: it contains the
        # bearer token.
        # NOTE(review): signing_url is chosen by the client and the bearer
        # token is forwarded to it; consider restricting it to known signers.
        return {'cert': GetCert.get_cert(data['token'], data['pubkey'], data['signing_url'])}

    @staticmethod
    def get_cert(access_token, pub_key, url):
        """
        Ask the signing service at ``url`` to sign ``pub_key`` into a cert.

        :param access_token: bearer token sent in the Authorization header
        :param pub_key: the public key to be signed
        :param url: endpoint of the signing service
        :returns: the ``certificate`` field of the service's JSON response
        """
        import logging
        import requests
        logger = logging.getLogger()
        # Never log access_token: it is a credential.
        logger.debug("pub_key %s", pub_key)
        logger.debug("url %s", url)
        headers = {"Authorization": "Bearer %s" % access_token}
        payload = {"public_key": pub_key}
        # NOTE(review): verify=False disables TLS certificate verification, so
        # the bearer token could be intercepted. Kept for compatibility with
        # the existing deployment; enable it once the signer's cert is trusted.
        with requests.Session() as sess:  # closes the connection pool
            resp = sess.post(url, json=payload, headers=headers, verify=False)
        logger.debug("get_cert returned from its external call")
        return resp.json()['certificate']

class TestAuth(Resource):
    """
    Lets the SPA ask whether the backend is logged in to the selected
    compute resource.
    """

    def get(self):
        """
        Return True when the Flask session holds a ``token`` entry,
        otherwise False.
        """
        logged_in = 'token' in session
        return logged_in

class StartAgent(Resource):
    """
    Endpoint that starts the agent for the current user's SSH session.
    """

    def get(self):
        """
        Ask the current SSHSession to start its agent, refresh the session,
        and return the session's ``socket`` attribute formatted as a string.
        """
        from .localtunnelstat import SSHSession as _SSHSession
        current = _SSHSession.get_sshsession()
        current.start_agent()
        current.refresh()
        socket_text = "{}".format(current.socket)
        return socket_text

class AddKey(Resource):
    """
    Manage the keys/certificates held by the user's SSH session.
    """

    def post(self):
        """
        Add a posted key and certificate to the current SSH session.

        Expects a JSON body with ``key`` and ``cert``. Marks the Flask
        session as permanent, hands the pair to ``add_keycert`` and
        refreshes the SSH session.

        :returns: the string "OK"
        """
        import logging
        session.permanent = True
        from .localtunnelstat import SSHSession
        current = SSHSession.get_sshsession()
        body = request.get_json()
        log = logging.getLogger()
        log.debug('adding a key to the session')
        current.add_keycert(key=body['key'], cert=body['cert'])
        log.debug('started an agent and added the key')
        current.refresh()
        return "OK"

    def get(self):
        """
        Return whatever ``get_principals_and_hosts()`` reports for the
        current SSH session.
        """
        from .localtunnelstat import SSHSession
        current = SSHSession.get_sshsession()
        principals_and_hosts = current.get_principals_and_hosts()
        return principals_and_hosts

def get_conn_params():
    """
    Return parameters relating to the backend compute service
    Retrieve them from the session (ideally)
    """
    params =  get_m3_params()
    # Default assume there is only one certificate with one principal available
    sshsess = SSHSession.get_sshsession()
    certs = sshsess.get_cert_contents()
    import logging
    logger = logging.getLogger()
    logger.debug("certs loaded: {}".format(certs))
    try: 
        params['user'] = certs[0]['Principals'][0]
    except:
        params['user']=None

def get_m3_params():
    """
    Return the hard-coded connection parameters for the M3 cluster.

    This is temporary: the values are meant to move into a per-backend
    site config file that the frontend parses and sets. A new dict is
    built on every call, so callers may mutate the result freely.
    """
    return {
        'host': 'm3.massive.org.au',
        'cancelcmd': 'scancel {jobid}',
        'statcmd': '/home/chines/jsonstat.py',
        'submitcmd': 'sbatch --partition=m3f',
        'internalfirewall': False,
    }

def get_app_params():
    """
    Return the application parameters stored on the Flask session.

    The result always has the keys ``startscript``, ``paramscmd``,
    ``client`` and ``localbind``; a key that was never set maps to None.

    :returns: dict of app parameters
    """
    keys = ['startscript', 'paramscmd', 'client', 'localbind']
    # session.get() returns None for a missing key rather than raising, so
    # no exception handling is needed (and none should hide real errors).
    returnvalue = {k: session.get(k) for k in keys}
    print("got app params", returnvalue)
    return returnvalue

class TunnelstatEP(Resource):
    """
    Endpoints used by the WS proxy
    """
    def put(self, authtok):
        """
        Update the last used time on the tunnel identified by ``authtok``.

        Delegates to ``Tunnelstat.update``. The module-level import of
        ``Tunnelstat`` is commented out, so it is imported here on demand;
        when it is not available the request is answered with 501 instead
        of failing with a NameError.
        """
        try:
            from .localtunnelstat import Tunnelstat
        except ImportError:
            flask_restful.abort(501, message="tunnel usage tracking is not available")
        Tunnelstat.update(authtok)

    def get(self, authtok):
        """
        Given an auth token, return the port of the SSH session whose
        ``authtok`` matches, or None when no session matches.
        """
        from . import sshsessions
        return next((sess.port for sess in sshsessions.values()
                     if sess.authtok == authtok), None)
class Stat(Resource):
    """
    endpoints to return info on jobs on the backend compute Resource
    """
    def get(self):
        """
        Get info on the jobs from the backend.

        Runs ``params['statcmd']`` on the login host and returns its stdout
        parsed as JSON.

        Aborts with 404 when Ssh.execute raises SshAgentException, and with
        400 when the command writes to stderr or its stdout is not JSON.
        """
        import logging
        logger = logging.getLogger()
        logger.info('/stat endpoint entered')
        params = get_conn_params()
        sshsess = SSHSession.get_sshsession()
        sshsess.refresh()
        logger.info('/stat endpoint, all parameters collected')
        try:
            res = Ssh.execute(sshsess, host=params['host'], user=params['user'], cmd=params['statcmd'])
        except SshAgentException as e:
            logger.error(e)
            return flask_restful.abort(404, message="{}".format(e))
        stderr = res['stderr']
        # Accept either bytes or str from Ssh.execute; None/empty means no error.
        if stderr:
            logger.error(stderr)
            flask_restful.abort(400, message=stderr.decode() if isinstance(stderr, bytes) else stderr)
        stdout = res['stdout']
        if isinstance(stdout, bytes):
            stdout = stdout.decode()
        try:
            return json.loads(stdout)
        except ValueError as e:  # json.JSONDecodeError subclasses ValueError
            logger.error("stat command returned invalid JSON: %s", e)
            flask_restful.abort(400, message="stat command returned invalid JSON: {}".format(e))

class JobCancel(Resource):
    """
    Terminate a job on the compute backend
    """
    def delete(self, jobid):
        """
        Terminate job ``jobid`` on the backend.

        Runs ``params['cancelcmd']`` formatted with the job id on the login
        host. The route is registered as ``<int:jobid>``, so ``jobid`` is an
        int when it is interpolated into the command.

        :returns: the command's stdout as text
        Aborts with 400 carrying the command's stderr if it wrote any.
        """
        print("in jobcancle jobid is {}".format(jobid))
        params = get_conn_params()
        sshsess = SSHSession.get_sshsession()
        res = Ssh.execute(sshsess, host=params['host'], user=params['user'],
                          cmd=params['cancelcmd'].format(jobid=jobid))
        stderr = res['stderr']
        # Accept either bytes or str from Ssh.execute; None/empty means no error.
        if stderr:
            flask_restful.abort(400, message=stderr.decode() if isinstance(stderr, bytes) else stderr)
        stdout = res['stdout']
        return stdout.decode() if isinstance(stdout, bytes) else stdout

class AppSetup(Resource):
    """
    configure the session for the app the user wants
    """
    def post(self):
        """
        Store the details of the app to be run or connected to on the session.

        Expects a JSON object containing ``startscript``, ``paramscmd``,
        ``client`` and ``localbind``; each value is copied onto the Flask
        session. Aborts with 400, leaving the session untouched, if the body
        is not a JSON object or any of those keys is missing.
        """
        data = request.get_json()
        keys = ['startscript', 'paramscmd', 'client', 'localbind']
        if not isinstance(data, dict):
            flask_restful.abort(400, message="expected a JSON object")
        # Validate everything first so the session is never half-updated.
        missing = [key for key in keys if key not in data]
        if missing:
            flask_restful.abort(400, message="missing keys: {}".format(', '.join(missing)))
        for key in keys:
            session[key] = data[key]
            print(data[key])


class JobSubmit(Resource):
    """
    Class dealing with starting a new job on the compute backend
    """
    def post(self):
        """
        Start a job; this is a POST since it changes the state of the backend.

        Runs ``params['submitcmd']`` on the login host with the app's
        ``startscript`` (from the session) supplied as stdin.

        :returns: the submit command's stdout as text
        Aborts with 400 carrying the command's stderr if it wrote any.
        """
        params = get_conn_params()
        appparams = get_app_params()
        sshsess = SSHSession.get_sshsession()
        res = Ssh.execute(sshsess, host=params['host'], user=params['user'],
                          cmd=params['submitcmd'], stdin=appparams['startscript'])
        stderr = res['stderr']
        # Accept either bytes or str from Ssh.execute; None/empty means no error.
        if stderr:
            print(stderr)
            flask_restful.abort(400, message=stderr.decode() if isinstance(stderr, bytes) else stderr)
        stdout = res['stdout']
        return stdout.decode() if isinstance(stdout, bytes) else stdout

def gen_authtok():
    """
    Return a random 16-character string of uppercase ASCII letters and
    digits, suitable for an auth token stored in a cookie.

    Characters are drawn with random.SystemRandom, which uses os.urandom.
    """
    import random
    import string
    alphabet = string.ascii_uppercase + string.digits
    rng = random.SystemRandom()
    return ''.join([rng.choice(alphabet) for _ in range(16)])

class JobConnect(Resource):
    """
    endpoints for connecting to an existing job
    """

    def create_tunnel(self, params, appparams, batchhost, data):
        """
        Discover how to reach a running job's app and open a tunnel to it.

        When ``appparams['paramscmd']`` is set, it is run on ``batchhost``
        (via ``ssh`` from the login host ``params['host']``). Its stdout is
        parsed as JSON and merged into the result. An SSH tunnel is then
        opened to the reported ``port``, and a fresh auth token is generated
        and stored on the SSH session.

        :param params: connection parameters (``host``, ``user``,
            ``internalfirewall``) as returned by get_conn_params()
        :param appparams: app parameters as returned by get_app_params()
        :param batchhost: node the job runs on, taken from the request URL
        :param data: passed positionally to ``appparams['paramscmd'].format``
        :returns: dict with ``batchhost``, the JSON fields reported by the
            command, ``localtunnelport`` and ``authtok``; an empty dict when
            the app has no ``paramscmd``
        Aborts with 400 if the command writes to stderr or reports no ``port``.
        """
        import logging
        import shlex
        logger = logging.getLogger()
        connectparams = {}
        sshsess = SSHSession.get_sshsession()

        if appparams.get('paramscmd') is None:
            return connectparams

        connectparams['batchhost'] = batchhost
        # batchhost comes from the request URL and ends up in a shell command
        # on the login host, so quote it to prevent command injection. The
        # app's own command is formatted separately so that braces in
        # batchhost are never interpreted as format fields.
        paramcmd = 'ssh {} '.format(shlex.quote(batchhost)) + appparams['paramscmd'].format(data)
        res = Ssh.execute(sshsess, host=params['host'], user=params['user'], cmd=paramcmd)
        try:
            connectparams.update(json.loads(res['stdout']))
        except json.decoder.JSONDecodeError:
            # Keep going so that a stderr message (checked next) is reported.
            logger.error("paramscmd stdout is not JSON: %s", res['stdout'])
            logger.error("paramscmd stderr: %s", res['stderr'])
        stderr = res['stderr']
        # Accept either bytes or str from Ssh.execute; None/empty means no error.
        if stderr:
            flask_restful.abort(400, message=stderr.decode() if isinstance(stderr, bytes) else stderr)

        if 'port' not in connectparams:
            # Without a port there is nothing to tunnel to (and no token).
            flask_restful.abort(400, message="could not determine the port of the app on {}".format(batchhost))

        tunnelport, _pids = Ssh.tunnel(sshsess, port=connectparams['port'],
                                       batchhost=connectparams['batchhost'],
                                       user=params['user'], host=params['host'],
                                       internalfirewall=params['internalfirewall'],
                                       localbind=appparams['localbind'])
        authtok = gen_authtok()
        sshsess.set_authtok(authtok)
        connectparams['localtunnelport'] = tunnelport
        connectparams['authtok'] = authtok
        return connectparams

    def get(self, jobid, batchhost):
        """
        Connecting to a job is a get operation (i.e. it does not make modifications)

        Opens the tunnel to the app on ``batchhost`` and then connects the
        client configured for the app.
        """
        params = get_conn_params()
        appparams = get_app_params()
        # A GET normally carries no JSON body; silent=True returns None in
        # that case instead of raising (recent Flask raises 415 otherwise).
        data = request.get_json(silent=True)
        connectparams = self.create_tunnel(params, appparams, batchhost, data)
        return self.connect(appparams, connectparams)

    def connect(self, appparams, connectparams):
        """
        perform the connection either by forking a local client or returning a redirect

        ``appparams['client']`` selects the mode:

        * ``cmd``: a list of argument templates. Each is formatted with
          ``connectparams`` and the result is started as a local process
          (argument list, no shell). The process is not waited on.
        * ``redir``: a URL path template. It is formatted with
          ``connectparams`` and appended to the WS proxy base URL. The
          response redirects there and sets the ``twsproxyauth`` cookie to
          ``connectparams['authtok']``.

        Aborts with 400 when the client defines neither.
        """
        import subprocess
        client = appparams.get('client') or {}
        if 'cmd' in client and client['cmd'] is not None:
            # We need to fork a local process such as vncviewer or a terminal.
            # We may need a wrapper for local processes to find the correct
            # process on all OS.
            cmdlist = [cmdarg.format(**connectparams) for cmdarg in client['cmd']]
            # Output goes to DEVNULL: pipes that are never read can fill up
            # and block a long-running client.
            subprocess.Popen(cmdlist, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            return "Connecting with cmd {}".format(cmdlist)
        if 'redir' in client and client['redir'] is not None:
            # Presumably for local development use 'http://localhost:4000/'.
            # NOTE(review): hard-coded deployment URL; consider moving to config.
            twsproxy = 'https://vm-118-138-240-255.erc.monash.edu.au/'
            response = make_response(redirect(twsproxy + client['redir'].format(**connectparams)))
            response.set_cookie('twsproxyauth', connectparams['authtok'])
            return response
        flask_restful.abort(400, message="the client defines neither 'cmd' nor 'redir'")


# Register each Resource with the flask_restful Api at its URL rule.
for _resource, _route in (
        (TunnelstatEP, '/tunnelstat/<string:authtok>'),
        (GetCert, '/getcert'),
        (Stat, '/stat'),
        (JobCancel, '/cancel/<int:jobid>'),
        (JobSubmit, '/submit'),
        (JobConnect, '/connect/<int:jobid>/<string:batchhost>'),
        (AppSetup, '/appsetup'),
        # (SessionTest, '/sesstest'),
        (StartAgent, '/startagent'),
        (AddKey, '/addkey'),
):
    api.add_resource(_resource, _route)
del _resource, _route