# apiendpoints.py
"""
All the API endpoints for controlling processes and tunnels via SSH
"""
import json

import flask_restful
from flask import session, redirect, request, Response, make_response
from flask import render_template
from flask_restful import Resource

from . import api, app
from .sshwrapper import Ssh, SshAgentException
from .tunnelstat import SSHSession
class AppParamsException(Exception):
    """Raised by the endpoints in this module when request parameters are
    missing or invalid, or when a remote command reports an error."""

class GetCert(Resource):
    """
    Proxy certificate-signing requests on behalf of the SPA.

    This class is necessary because I'm not reconfiguring
    SSHAuthZ to support CORS, but the TES does support CORS
    """

    def post(self):
        """
        Take a public key and return a signed certificate to the SPA.

        The JSON request body must provide ``token``, ``pubkey`` and
        ``signing_url``.

        Returns:
            ``{'cert': <certificate>}`` on success, otherwise the tuple
            ``("Unable to get a signed certificate", 500)``.
        """
        import logging
        logger = logging.getLogger()
        logger.debug('entering GetCert.post')
        data = request.get_json()
        try:
            cert = GetCert.get_cert(data['token'], data['pubkey'], data['signing_url'])
        except Exception:
            # Deliberate catch-all at the endpoint boundary: any failure
            # (missing field, network error, bad reply) becomes a 500 for the
            # SPA. ``except Exception`` (rather than a bare ``except``) lets
            # SystemExit / KeyboardInterrupt propagate.
            logger.exception('Failed to get certificate')
            return ("Unable to get a signed certificate", 500)
        return {'cert': cert}

    @staticmethod
    def get_cert(access_token, pub_key, url):
        """
        Sign a pub key into a cert

        Args:
            access_token: bearer token sent in the ``Authorization`` header.
            pub_key: public key, posted as the ``public_key`` JSON field.
            url: address of the signing endpoint.

        Returns:
            The ``certificate`` field of the JSON reply.

        Raises:
            requests.RequestException: on transport errors or a non-2xx reply.
            KeyError: if the reply has no ``certificate`` field.
        """
        import requests
        headers = {"Authorization": "Bearer %s" % access_token}
        payload = {"public_key": pub_key}
        # NOTE(review): verify=False disables TLS certificate verification, so
        # the bearer token is sent to an unauthenticated peer. Kept as-is to
        # avoid breaking existing deployments; should be replaced by a CA
        # bundle once the signing service certificate is known.
        with requests.Session() as sess:  # close the connection pool when done
            resp = sess.post(url, json=payload, headers=headers, verify=False)
            # Surface HTTP errors directly instead of as a confusing KeyError.
            resp.raise_for_status()
            return resp.json()['certificate']

class TestAuth(Resource):
    """
    Tests whether the backend can login to the selected compute Resource
    """
    def get(self):
        """
        tell the SPA if the TES is logged in
        """
        return 'token' in session


class SSHAgent(Resource):
    """
    Manage the key/certificate pairs held by the session's ssh agent.

    NOTE(review): the ``class`` line and the ``get`` / ``delete`` method
    headers were missing from the file as received (three method bodies were
    run together under a single ``post``). The class name comes from the log
    messages and the '/sshagent' route registration, ``get`` from the
    'SSHAgent.get' log message; ``delete`` is inferred from the body (it
    kills and removes the session) -- confirm against the original source.
    """

    def post(self):
        """
        Add the ``key`` / ``cert`` pair from the JSON body to the ssh session.
        """
        import logging
        logger = logging.getLogger()
        logger.debug('entering SSHAgent.post')
        session.permanent = True
        sshsess = SSHSession.get_sshsession()
        data = request.get_json()
        sshsess.add_keycert(key=data['key'], cert=data['cert'])
        sshsess.refresh()
        logger.debug('leaving SSHAgent.post')

    def get(self):
        """
        Return the certificate contents reported by the ssh session, or an
        empty list when the agent has not been started. Aborts with HTTP 500
        on any error.
        """
        import logging
        import traceback
        logger = logging.getLogger()
        try:
            sshsess = SSHSession.get_sshsession()
            if sshsess.socket is None:
                logger.debug('trying to get the agent contents, but the agent isn\'t started yet')
                return []  # The agent hasn't even been started yet
            return sshsess.get_cert_contents()
        except Exception as e:
            logger.debug('SSHAgent.get: Exception {}'.format(e))
            logger.debug(traceback.format_exc())
            flask_restful.abort(500)

    def delete(self):
        """
        Kill the ssh session and remove it from the session registry.
        """
        sshsess = SSHSession.get_sshsession()
        sshsess.kill()
        SSHSession.remove_sshsession()
        return []
def get_conn_params():
    """
    Return parameters relating to the backend compute service
    Retrieve them from the query string

    Each of the query-string arguments ``identity``, ``interface``, ``path``,
    ``cd``, ``app`` and ``appinstance`` is expected to be JSON encoded.
    ``identity`` is mandatory; ``interface``, ``app`` and ``appinstance``
    default to ``{}``; ``path`` and ``cd`` are parsed as a pair and both
    become ``None`` if either is missing or malformed.

    Returns:
        dict with the keys ``identity``, ``interface``, ``app``,
        ``appinstance``, every key of the interface definition, and
        ``user``, ``host``, ``path`` and ``cd``.

    Raises:
        TypeError, ValueError: if ``identity`` is missing or not valid JSON.
        KeyError: if ``identity`` lacks ``username`` or ``site.host``.
    """
    import logging
    logger = logging.getLogger()
    logger.debug('entering get_conn_params')

    def _optional_json(name, default):
        """Decode query argument *name* as JSON, or return *default*."""
        try:
            return json.loads(request.args.get(name))
        except (TypeError, ValueError):
            # TypeError: argument absent (None); ValueError: malformed JSON.
            return default

    identityparams = json.loads(request.args.get('identity'))
    interfaceparams = _optional_json('interface', {})
    try:
        pathparams = json.loads(request.args.get('path'))
        cdparams = json.loads(request.args.get('cd'))
    except (TypeError, ValueError):
        # path and cd are only meaningful together, so drop both.
        pathparams = None
        cdparams = None
    appparams = _optional_json('app', {})
    appinstanceparams = _optional_json('appinstance', {})

    params = {
        'identity': identityparams,
        'interface': interfaceparams,
        'app': appparams,
        'appinstance': appinstanceparams,
    }
    # Flatten the interface definition into the top level as well; this runs
    # before the assignments below so those always win on a key clash.
    params.update(interfaceparams)
    params['user'] = identityparams['username']
    params['host'] = identityparams['site']['host']
    params['path'] = pathparams
    params['cd'] = cdparams
    logger.debug('leaving get_conn_params')
    return params



class TunnelstatEP(Resource):
    """
    Endpoints used by the WS proxy
    """
    # def put(self, authtok):
    #     """
    #     update the last used time on a tunnel
    #     """
    #     Tunnelstat.update(authtok)

    def get(self, authtok):
        """
        given an authtoken, return the port of the tunnels

        Scans every known ssh session for a tunnel registered under
        ``authtok``. Returns the matching port, or ``None`` when no session
        knows the token or when the scan raises.
        """
        import logging
        logger = logging.getLogger()
        logger.debug('entering TunnelstatEP.get')
        from . import sshsessions
        logger.debug('TunnelstatEP.get: iterating sshsessions {}'.format(authtok))
        try:
            for (sessid, sshsess) in sshsessions.items():
                logger.debug('sshsession id {}'.format(sessid))
                for (tok, tunnelport) in sshsess.port.items():
                    logger.debug('token {}'.format(tok))
                    if tok == authtok:
                        logger.debug("found port {} for authtok {}".format(tunnelport, tok))
                        logger.debug('leaving TunnelstatEP.get')
                        return tunnelport
        except Exception:
            # Best effort: the proxy gets "no port" rather than a 500.
            logger.exception("exception in TunnelstatEP.get")
        # The message used to include the last port iterated over, which
        # belonged to a different token and was misleading.
        logger.debug("No ports found for authtok {}".format(authtok))
        logger.debug('leaving TunnelstatEP.get')
        return None


class JobStat(Resource):
    """
    endpoints to return info on jobs on the backend compute Resource

    NOTE(review): the ``class`` line and two ``try:`` lines were missing from
    the file as received; the class name is taken from the '/stat' route
    registration and the log messages.
    """
    def get(self):
        """
        get info on the job from the backend

        Runs the interface's ``statcmd`` on the login host over ssh and
        returns the decoded JSON job list, with the identity attached to each
        job. Aborts with HTTP 400 on incomplete parameters, identity errors,
        any stderr output or an unparsable reply, and with HTTP 500 on ssh
        session errors.
        """
        import logging
        logger = logging.getLogger()
        logger.debug('enteringing JobStat')
        try:
            params = get_conn_params()
        except Exception:
            # flask_restful.abort only accepts the status positionally; the
            # message must be a keyword argument.
            flask_restful.abort(400, message="connection parameters not correctly defined")
        try:
            sshsess = SSHSession.get_sshsession()
            sshsess.refresh()
        except Exception:
            flask_restful.abort(500, message="Error relating to the ssh sessions")
        try:
            host = params['identity']['site']['host']
        except (TypeError, KeyError):
            flask_restful.abort(400, message="stat: definition of login host incomplete")
        try:
            user = params['identity']['username']
        except (TypeError, KeyError):
            flask_restful.abort(400, message="stat: definition of username incomplete")
        try:
            cmd = params['interface']['statcmd']
        except (TypeError, KeyError):
            flask_restful.abort(400, message="stat: definition of batch interface incomplete")
        logger.debug('ssh sess socket is {}'.format(sshsess.socket))

        try:
            logger.debug('attempting ssh execute {} {} {}'.format(host, user, cmd))
            res = Ssh.execute(sshsess, host=host, user=user, cmd=cmd)
        except SshAgentException as e:
            logger.error(e)
            flask_restful.abort(400, message="Identity error {}".format(e))
        if res['stderr']:  # anything other than '', b'' or None is an error
            logger.error(res['stderr'])
            flask_restful.abort(400, message=res['stderr'].decode())
        try:
            jobs = json.loads(res['stdout'].decode())
            # Attach the identity information to the job before returning it
            for j in jobs:
                j['identity'] = params['identity']
        except Exception as e:
            logger.exception(e)
            # The exception object itself is not JSON serialisable.
            flask_restful.abort(400, message=str(e))
        logger.debug('leaving jobstat gracefully')
        return jobs

class MkDir(Resource):
    """
    Endpoint that creates a directory on the backend via ``Ssh.sftpmkdir``
    """
    def post(self):
        """
        Pass the ``path`` connection parameter and the ``name`` field of the
        JSON body to ``Ssh.sftpmkdir`` for the identity's host and user.
        """
        import logging
        log = logging.getLogger()
        body = request.get_json()
        log.debug('mkdir data')
        log.debug(body)
        conn = get_conn_params()
        ssh_session = SSHSession.get_sshsession()
        target_host = conn['identity']['site']['host']
        target_user = conn['identity']['username']
        parent = conn['path']
        dirname = body['name']
        Ssh.sftpmkdir(ssh_session, host=target_host, user=target_user,
                      path=parent, name=dirname)
        return None

class DirList(Resource):
    """
    Endpoint that lists a directory on the backend
    """
    def get(self):
        """
        List a directory, either with the site's own ``lscmd`` (whose stdout
        is decoded as JSON) or, when the site defines none, with
        ``Ssh.sftpls``.

        The ssh port is the site's ``dtnport`` if present, otherwise "22".
        """
        import logging
        logger = logging.getLogger()
        params = get_conn_params()
        sshsess = SSHSession.get_sshsession()
        site = params['identity']['site']
        sshport = site['dtnport'] if 'dtnport' in site else "22"
        logger.debug('sshport is {}'.format(sshport))

        # Truthiness replaces the old `is not ""` identity test, which only
        # worked by accident of string interning.
        lscmd = site['lscmd'] if 'lscmd' in site else None
        if lscmd:
            path = params['path']
            cd = params['cd']
            if path == "":
                path = "."
            if cd == "":
                cd = "."
            # NOTE(review): path and cd come from the query string and are
            # interpolated into a shell command unquoted. Quoting would stop
            # the remote shell expanding e.g. '~', so it is left as-is here.
            res = Ssh.execute(sshsess, host=site['host'],
                              user=params['identity']['username'],
                              sshport=sshport,
                              cmd="{} {} {}".format(lscmd, path, cd))
            dirls = json.loads(res['stdout'].decode())
            logger.debug(dirls)
        else:
            # NOTE(review): this branch passes path/cd without the "" -> "."
            # normalisation used above -- presumably sftpls handles empty
            # strings itself; confirm.
            dirls = Ssh.sftpls(sshsess, host=site['host'],
                               user=params['identity']['username'],
                               path=params['path'], changepath=params['cd'],
                               sshport=sshport)
        return dirls

class JobCancel(Resource):
    """
    Resource for cancelling a job on the compute backend
    """
    def delete(self, jobid):
        """
        Run the interface's ``cancelcmd`` (formatted with ``jobid``) on the
        identity's host; abort with HTTP 400 if it writes to stderr,
        otherwise return its decoded stdout.
        """
        conn = get_conn_params()
        ssh_session = SSHSession.get_sshsession()
        target_host = conn['identity']['site']['host']
        target_user = conn['identity']['username']
        cancel_command = conn['interface']['cancelcmd'].format(jobid=jobid)
        result = Ssh.execute(ssh_session, host=target_host, user=target_user,
                             cmd=cancel_command)
        errors = result['stderr']
        if errors != '' and errors is not None and errors != b'':
            flask_restful.abort(400, message=errors.decode())
        output = result['stdout']
        return output.decode()


class JobSubmit(Resource):
    """
    Class dealing the starting a new job on the compute backend
    """
    def post(self):
        """starting a job is a post, since it changes the state of the backend

        The JSON body must contain ``app.startscript``, a format string that
        is filled in with the body's own top-level fields. The resulting
        script is piped to the interface's ``submitcmd`` on the identity's
        host.

        Returns the decoded stdout of the submit command; aborts with HTTP 400
        when the job information is incomplete or the command writes to
        stderr.
        """
        import logging
        logger = logging.getLogger()
        params = get_conn_params()
        logger.debug('entering JobSubmit.post {}'.format(params))
        sshsess = SSHSession.get_sshsession()
        data = request.get_json()
        logger.debug(data)
        try:
            script = data['app']['startscript'].format(**data)
        except (TypeError, KeyError, IndexError, ValueError, AttributeError):
            # Missing body, missing keys, or a template that references
            # fields the body does not provide.
            flask_restful.abort(400, message='Incomplete job information was passed to the backend.')
        logger.debug('script formated to {}'.format(script))

        res = Ssh.execute(sshsess, host=params['identity']['site']['host'], user=params['identity']['username'],
                          cmd=params['interface']['submitcmd'], stdin=script)
        if res['stderr']:  # anything other than '', b'' or None is an error
            logger.debug('failed to submit job')
            logger.debug(res['stderr'])
            flask_restful.abort(400, message=res['stderr'].decode())
        return res['stdout'].decode()

def gen_authtok():
    """
    Build a 16 character random token (upper-case ASCII letters and digits,
    drawn from the OS entropy source) for use as a cookie auth token
    """
    import logging
    import random
    import string
    logging.getLogger().debug('generating new authtok')
    alphabet = string.ascii_uppercase + string.digits
    rng = random.SystemRandom()
    picked = [rng.choice(alphabet) for _ in range(16)]
    return ''.join(picked)

        # NOTE(review): this body has no enclosing `class` / `def` line in the
        # file as received. The `api.add_resource(AppUrl, '/appurl')` call at
        # the bottom of the file suggests it is the handler of an `AppUrl`
        # resource whose header lines are missing -- TODO confirm and restore.
        import logging
        logger = logging.getLogger()
        # Both query-string arguments are decoded as JSON.
        # NOTE(review): this reads 'appinst' whereas get_conn_params reads
        # 'appinstance' -- confirm which name the client sends.
        appdef = json.loads(request.args.get('app'))
        logger.debug('appdef {}'.format(appdef))
        inst = json.loads(request.args.get('appinst'))
        logger.debug('appinst {}'.format(inst))
        # URL = configured TWSPROXY prefix + the app's client redirect
        # template filled in with the instance's fields.
        url = "{}{}".format(app.config['TWSPROXY'],appdef['client']['redir'].format(**inst))
        return url

class AppInstance(Resource):
    """
    Endpoint that queries a running application instance for its parameters
    """
    def get(self, username, loginhost, batchhost):
        """Run a command to get things like password and port number
        command is passed as a query string

        The command from the ``cmd`` query argument is executed on
        ``batchhost`` by ssh-ing there from ``loginhost``; its stdout must be
        JSON.

        Returns:
            The decoded JSON produced by the command.

        Raises:
            AppParamsException: if ``cmd`` is missing, stdout is not valid
                JSON, the command wrote to stderr, or the JSON contains an
                ``error`` key.
        """
        import logging
        import shlex
        logger = logging.getLogger()
        sshsess = SSHSession.get_sshsession()
        paramscmd = request.args.get('cmd')
        if not paramscmd:
            # Previously this surfaced as a TypeError from str + None.
            raise AppParamsException("missing value for 'cmd'")
        logger.debug('getting appinstance {} {} {}'.format(username, loginhost, batchhost))
        logger.debug('ssh sess socket is {}'.format(sshsess.socket))
        # batchhost is a URL path segment: quote it so it cannot inject
        # options or extra commands. paramscmd is by design a complete
        # command line supplied by the client and is passed through verbatim.
        cmd = 'ssh -o StrictHostKeyChecking=no -o CheckHostIP=no {batchhost} '.format(
            batchhost=shlex.quote(batchhost)) + paramscmd
        res = Ssh.execute(sshsess, host=loginhost, user=username, cmd=cmd)
        stderr = res['stderr']
        try:
            data = json.loads(res['stdout'].decode())
        except json.decoder.JSONDecodeError as e:
            # stderr may be None/empty; keep the combined output as the message.
            raise AppParamsException((stderr or b'') + res['stdout']) from e
        if stderr:  # tolerates None, unlike len(stderr) > 0
            raise AppParamsException(stderr)
        if 'error' in data:
            raise AppParamsException(data['error'])
        return data
class CreateTunnel(Resource):
    """
    Endpoint that opens an ssh tunnel to an application on a batch host
    """
    @staticmethod
    def validate_connect_params(port, username, host, batchhost):
        """
        Return True when the tunnel parameters look sane, False otherwise.

        NOTE(review): the ``try:`` body was missing from the file as
        received; checking that ``port`` converts to an integer is the
        assumed intent -- confirm against the original source.
        """
        try:
            int(port)
        except (TypeError, ValueError):
            return False
        # This really needs more validation. batchhost was accepted but never
        # checked; it now gets the same whitespace test as the other names.
        for value in (username, host, batchhost):
            if ' ' in value or '\n' in value:
                return False
        return True

    def post(self, username, loginhost, batchhost):
        """
        Create a tunnel using established keys
        parameters for the tunnel (host username port etc)
        will be passed in the body

        The JSON body must contain ``port`` and may contain
        ``internalfirewall`` (default False) and ``localbind`` (default
        True). On success the response is empty and carries the new auth
        token in the ``twsproxyauth`` cookie.

        Raises:
            AppParamsException: if the body or ``port`` is missing, or the
                parameters fail validation.
        """
        import logging
        logger = logging.getLogger()
        logger.debug("Createing tunnel")
        logger.debug("recieved data {}".format(request.data))
        data = request.get_json()
        if data is None:
            # No JSON body at all: report it like any other missing value
            # instead of failing with a TypeError below.
            raise AppParamsException("missing value for 'port'")
        try:
            port = data['port']
        except KeyError as missingdata:
            raise AppParamsException("missing value for {}".format(missingdata))
        if not CreateTunnel.validate_connect_params(port, username, loginhost, batchhost):
            raise AppParamsException("Invalid value: {} {} {} {}".
                                     format(username, loginhost, batchhost, port))
        firewall = data['internalfirewall'] if 'internalfirewall' in data else False
        localbind = data['localbind'] if 'localbind' in data else True
        sshsess = SSHSession.get_sshsession()
        authtok = gen_authtok()
        # logger.debug('JobCreate.create_tunnel: creating a tunnel for authtok {}'.format(authtok))
        Ssh.tunnel(sshsess, port=port, batchhost=batchhost,
                   user=username, host=loginhost,
                   internalfirewall=firewall,
                   localbind=localbind, authtok=authtok)
        response = make_response("")
        # NOTE(review): Flask responses expose `mimetype`, not `mime_type`, so
        # this assignment has no effect on the Content-Type header. Left
        # unchanged because clients may rely on the current header.
        response.mime_type = 'application/json'
        response.set_cookie('twsproxyauth', authtok)
        logger.debug('JobConnect.connect: connecting via redirect with cookie authtok set to  {}'.format(authtok))
        return response



# Route table: bind each Resource class to its URL rule on the shared `api`
# object. The <converter:name> placeholders become keyword arguments of the
# resource's handler methods.
# NOTE(review): no route is registered here for TestAuth or MkDir -- confirm
# whether that is intentional or done elsewhere.
api.add_resource(TunnelstatEP, '/tunnelstat/<string:authtok>')
api.add_resource(GetCert, '/getcert')
api.add_resource(JobStat, '/stat')
api.add_resource(JobCancel, '/cancel/<int:jobid>')
api.add_resource(JobSubmit, '/submit')
api.add_resource(CreateTunnel, '/createtunnel/<string:username>/<string:loginhost>/<string:batchhost>')
api.add_resource(AppInstance, '/appinstance/<string:username>/<string:loginhost>/<string:batchhost>')
api.add_resource(AppUrl, '/appurl')
api.add_resource(SSHAgent,'/sshagent')
api.add_resource(DirList,'/ls')