Skip to content
Snippets Groups Projects
apiendpoints.py 25.96 KiB
"""
All the API endpoints for controlling processes and tunnels via SSH
"""
import json
from flask import session, redirect, request, Response, make_response
from flask_restful import Resource
from flask import render_template
import flask_restful

from . import api, app
from .sshwrapper import Ssh, SshAgentException, SftpPermissionException, SftpException, SshCtrlException, SshExecException
from .tunnelstat import SSHSession

class AppParamsException(Exception):
    """Raised when parameters for an application or tunnel request are missing, invalid, or report an error."""

class GetCert(Resource):
    """
    Proxy a certificate-signing request on behalf of the SPA.

    This class is necessary because SSHAuthZ is not reconfigured to support
    CORS, but the TES does support CORS, so the SPA asks the TES to make the
    signing request for it.
    """

    def post(self):
        """
        Take a public key from the JSON body and return a signed certificate.

        The JSON body must contain ``token`` (bearer token for the signing
        service), ``pubkey`` (the public key to sign) and ``signing_url``.
        Returns ``{'cert': <certificate>}``. Any failure is logged and
        reported to the client as HTTP 400.
        """
        import logging
        logger = logging.getLogger()
        try:
            data = request.get_json()
            cert = GetCert.get_cert(data['token'], data['pubkey'], data['signing_url'])
            return {'cert': cert}
        except Exception:
            # ``except Exception`` rather than a bare ``except:`` so that
            # KeyboardInterrupt / SystemExit are not converted into HTTP 400s.
            logger.exception('Failed to get certificate')
            flask_restful.abort(400, message="an error occured generating your certificate")

    @staticmethod
    def get_cert(access_token, pub_key, url):
        """
        Ask the signing service at ``url`` to sign ``pub_key`` into a cert.

        :param access_token: bearer token sent in the Authorization header
        :param pub_key: public key, sent as JSON ``{"public_key": pub_key}``
        :param url: signing endpoint to POST to
        :returns: the ``certificate`` field of the service's JSON response
        :raises KeyError: if the response JSON has no ``certificate`` field;
            network and JSON-decoding errors from ``requests`` propagate too
        """
        import requests
        headers = {"Authorization": "Bearer %s" % access_token}
        payload = {"public_key": pub_key}
        # NOTE(review): verify=False disables TLS certificate verification
        # even though a bearer token is being sent; consider enabling it.
        # The session is used as a context manager so its connection pool is
        # released once the request completes.
        with requests.Session() as sess:
            resp = sess.post(url, json=payload, headers=headers, verify=False)
        return resp.json()['certificate']


class SSHAgent(Resource):
    """
    Manage the ssh-agent belonging to the current SSHSession.
    """
    def post(self):
        """
        Add a key and certificate to the session's agent.

        Expects a JSON body with ``key`` and ``cert``. Marks the Flask session
        permanent, adds the pair to the current SSHSession and refreshes it.
        Returns ``"OK"``; any failure is logged and reported as HTTP 500.
        """
        import logging
        import traceback
        logger = logging.getLogger()
        try:
            session.permanent = True
            sshsess = SSHSession.get_sshsession()
            data = request.get_json()
            sshsess.add_keycert(key=data['key'], cert=data['cert'])
            sshsess.refresh()
            return "OK"
        except Exception as e:
            logger.error(e)
            logger.error(traceback.format_exc())
            logger.error('failed to add ssh key to the agent')
            flask_restful.abort(500, message="failed to add the ssh key to the agent")

    def get(self):
        """
        Return the certificate contents held by the session's agent.

        Returns an empty list when the agent has not been started yet (its
        ``socket`` is None). Failures are logged and reported as HTTP 500.
        """
        import logging
        import traceback
        logger = logging.getLogger()
        try:
            sshsess = SSHSession.get_sshsession()
            if sshsess.socket is None:
                return []  # The agent hasn't even been started yet
            return sshsess.get_cert_contents()
        except Exception as e:
            logger.error('SSHAgent.get: Exception {}'.format(e))
            logger.error(traceback.format_exc())
            flask_restful.abort(500, message="failed to query the ssh-agent")

    def delete(self):
        """
        Call ``kill()`` on the current SSHSession and then remove it.

        Returns an empty list on success. Failures are logged and reported as
        HTTP 500.
        """
        import logging
        import traceback
        # logger must be bound here: without it the except handler below
        # raised NameError instead of returning the intended 500 message.
        logger = logging.getLogger()
        try:
            sshsess = SSHSession.get_sshsession()
            sshsess.kill()
            SSHSession.remove_sshsession()
            return []
        except Exception as e:
            logger.error('SSHAgent.delete: Exception {}'.format(e))
            logger.error(traceback.format_exc())
            flask_restful.abort(500, message="failed to shut down ssh-agent")

def _load_json_arg(args, name, default):
    """
    Decode the JSON-encoded query-string argument ``name``.

    :param args: mapping of query-string arguments (e.g. ``request.args``)
    :param name: argument to look up
    :param default: value returned when the argument is absent or not valid JSON
    :returns: the decoded value, or ``default``
    """
    try:
        return json.loads(args.get(name))
    except (TypeError, ValueError):
        # TypeError: argument absent (json.loads(None));
        # ValueError: present but malformed (json.JSONDecodeError).
        return default


def get_conn_params():
    """
    Return parameters relating to the backend compute service.

    Reads JSON-encoded query-string arguments. ``identity`` is required and
    must contain ``username`` and ``site.host``. ``interface``, ``app`` and
    ``appinstance`` each default to ``{}``. ``path`` and ``cd`` default to
    None, and both fall back to None if either one is missing or malformed.

    The returned dict has keys ``identity``, ``interface``, ``app`` and
    ``appinstance``. Every key of ``interface`` is then merged in at top
    level, overriding those four. Finally ``user``, ``host``, ``path`` and
    ``cd`` are set.

    :raises TypeError, ValueError: if ``identity`` is absent or not valid JSON
    :raises KeyError: if ``identity`` lacks ``username`` or ``site.host``
    """
    args = request.args
    identityparams = json.loads(args.get('identity'))
    interfaceparams = _load_json_arg(args, 'interface', {})
    try:
        pathparams = json.loads(args.get('path'))
        cdparams = json.loads(args.get('cd'))
    except (TypeError, ValueError):
        # path and cd are treated as a pair: if either is missing or
        # malformed, both fall back to None.
        pathparams = None
        cdparams = None
    appparams = _load_json_arg(args, 'app', {})
    appinstanceparams = _load_json_arg(args, 'appinstance', {})

    params = {
        'identity': identityparams,
        'interface': interfaceparams,
        'app': appparams,
        'appinstance': appinstanceparams,
    }
    params.update(interfaceparams)
    params['user'] = identityparams['username']
    params['host'] = identityparams['site']['host']
    params['path'] = pathparams
    params['cd'] = cdparams
    return params



class TunnelstatEP(Resource):
    """
    Endpoints used by the WS proxy.
    The Frontend should never call this.
    """

    def get(self, authtok):
        """
        Given an auth token, return the port of the matching tunnel.

        Scans every session in ``sshsessions`` and returns the ``port`` entry
        whose key equals ``authtok``. Returns None if no tunnel matches or if
        the scan raises (the failure is logged).
        """
        import logging
        logger = logging.getLogger()
        from . import sshsessions

        try:
            for _sessid, sshsess in sshsessions.items():
                for tok, port in sshsess.port.items():
                    if tok == authtok:
                        return port
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # are not swallowed; lookup errors are still best-effort.
            logger.error("exception in TunnelstatEP.get")
            import traceback
            logger.error(traceback.format_exc())
        return None

class ContactUs(Resource):
    """Persist a contact-us message submitted by the frontend."""
    def post(self):
        """
        Write the JSON request body to a new file in ``app.config['MESSAGES']``.

        The file gets a unique temporary name and is not deleted afterwards
        (``delete=False``), so each message remains in that directory.
        Returns an empty response.
        """
        import tempfile
        import logging
        logger = logging.getLogger()
        data = request.get_json()
        # Logged rather than printed so it follows the app's logging config.
        logger.debug(data)
        # ``with`` guarantees the handle is closed even if the write fails.
        with tempfile.NamedTemporaryFile(mode='w+b', dir=app.config['MESSAGES'], delete=False) as f:
            f.write(json.dumps(data).encode())
        return

class JobStat(Resource):
    """
    endpoints to return info on jobs on the backend compute Resource
    """
    def get(self):
        """
        Run a status command over SSH and return its JSON output.

        Query-string parameters, each JSON-encoded: ``statcmd`` (command to
        run), ``host`` and ``username``. The command's stdout is parsed as
        JSON and returned. Any stderr output, bad parameters or SSH failures
        are reported as ``({'message': ...}, code)`` via ``apiabort``.
        """
        import logging
        import traceback
        logger = logging.getLogger()
        try:
            sshsess = SSHSession.get_sshsession()
            sshsess.refresh()
        except Exception:
            logger.error(traceback.format_exc())
            return apiabort(500, message="Error relating to the ssh sessions")
        try:
            cmd = json.loads(request.args.get('statcmd'))
            host = json.loads(request.args.get('host'))
            user = json.loads(request.args.get('username'))
        except (TypeError, ValueError) as e:
            # TypeError: parameter absent (json.loads(None));
            # ValueError: parameter present but not valid JSON.
            # The traceback is logged server-side rather than sent to the client.
            logger.error(traceback.format_exc())
            return apiabort(400, message="Missing required parameter {}".format(e))

        try:
            logger.debug('attempting ssh execute {} {} {}'.format(host, user, cmd))
            res = Ssh.execute(sshsess, host=host, user=user, cmd=cmd)
            if not (res['stderr'] == '' or res['stderr'] is None or res['stderr'] == b''):
                logger.error(res['stderr'])
                return apiabort(400, message=res['stderr'].decode())
            try:
                return json.loads(res['stdout'].decode())
            except Exception as e:
                logger.error(e)
                logger.error(traceback.format_exc())
                # str(e): an exception object is not JSON-serialisable, so it
                # cannot be used directly as the response message.
                return apiabort(400, message=str(e))
        except SshAgentException as e:
            logger.error(e)
            return apiabort(401, message="Identity error {}".format(e))
        except SshCtrlException as e:
            return apiabort(400, message="We're having difficultly contacting {}. We failed with the message: {}".format(host, e))
        except SshExecException as e:
            return apiabort(400, message="{}".format(e))
        except Exception as e:
            logger.error('JobStat.get: Exception {}'.format(e))
            logger.error(traceback.format_exc())
            return apiabort(500, message="SSH failed in an unexpected way")

class MkDir(Resource):
    """Create a directory on the compute site over SFTP."""
    def post(self):
        """
        Create directory ``name`` (from the JSON body) inside the query's ``path``.

        Connection details come from ``get_conn_params()``. The site's
        ``dtnport`` is used as the SSH port when present, otherwise "22".
        Returns an empty response on success, or ``({'message': ...}, code)``
        with 403 (permission denied) or 500 (any other failure).
        """
        import logging
        log = logging.getLogger()
        try:
            body = request.get_json()
            log.debug('mkdir data')
            log.debug(body)
            conn = get_conn_params()
            sshsess = SSHSession.get_sshsession()
            log.debug('try to call mkdir')
            identity = conn['identity']
            site = identity['site']
            port = site['dtnport'] if 'dtnport' in site else "22"
            try:
                Ssh.sftpmkdir(sshsess,
                              host=site['host'],
                              user=identity['username'],
                              path=conn['path'],
                              name=body['name'],
                              sshport=port)
            except SftpPermissionException:
                return apiabort(403, message="You don't have permission to make a directory there")
            except SftpException:
                return apiabort(500, message="Something went wrong making that directory")
            except Exception:
                import traceback
                log.error(traceback.format_exc())
                return apiabort(500, message="Something went wrong creating that directory, probably a bug")
            return
        except Exception as err:
            import traceback
            log.error(err)
            log.error(traceback.format_exc())
            return apiabort(500, message="mkdir failed in some unexpected way")

class DirList(Resource):
    """List the contents of a directory on the compute site."""
    def get(self):
        """
        Return a directory listing for the query's ``path`` and ``cd``.

        The site's ``dtnport`` is used as the SSH port when present, otherwise
        "22". If the site defines an ``lscmd`` that is neither None nor "",
        that command is run over SSH with path and cd appended. An empty path
        or cd is replaced by "." for this command. Its stdout is parsed as
        JSON. Otherwise the listing comes from ``Ssh.sftpls``, which receives
        the raw ``path``/``cd`` values.
        Errors are returned as ``({'message': ...}, code)``.
        """
        import logging
        import traceback
        logger = logging.getLogger()
        try:
            params = get_conn_params()
            sshsess = SSHSession.get_sshsession()
            site = params['identity']['site']
            sshport = site.get('dtnport', "22")
            path = params['path']
            cd = params['cd']
            if path == "":
                path = "."
            if cd == "":
                cd = "."
            try:
                lscmd = site.get('lscmd')
                # Equality (not ``is not``) comparison against "": identity
                # checks on str literals are implementation-dependent.
                if lscmd is not None and lscmd != "":
                    logger.debug('using ssh.execute with lscmd')
                    # NOTE(review): path and cd come from the query string and
                    # are interpolated unquoted into the remote command line;
                    # paths with spaces or shell metacharacters will misbehave.
                    res = Ssh.execute(sshsess, host=site['host'], user=params['identity']['username'],
                                      sshport=sshport,
                                      cmd="{} {} {}".format(lscmd, path, cd))
                    try:
                        dirls = json.loads(res['stdout'].decode())
                    except Exception:
                        return apiabort(401, message="You don't have permission to view that directory")
                else:
                    logger.debug('using ssh.sftpls')
                    dirls = Ssh.sftpls(sshsess, host=site['host'],
                                       user=params['identity']['username'], path=params['path'],
                                       changepath=params['cd'], sshport=sshport)
                return dirls
            except SshCtrlException as e:
                return apiabort(400, message="We're having difficultly contacting {}. We failed with the message {}".format(site['host'], e))
        except Exception as e:
            logger.error(e)
            logger.error(traceback.format_exc())
            return apiabort(500, message="dirlist failed in some unexpected way")

class JobCancel(Resource):
    """
    Terminate a job on the compute backend
    """
    def delete(self, jobid):
        """
        Terminate job ``jobid`` by running the interface's ``cancelcmd``.

        ``cancelcmd`` is formatted with ``jobid=jobid`` and run over SSH on
        the identity's site host. Returns the command's decoded stdout. If the
        command writes to stderr, or SSH fails, returns
        ``({'message': ...}, code)`` instead.
        """
        import logging
        import traceback
        # logger must be bound here: previously the outer except handler
        # referenced an undefined name and raised NameError.
        logger = logging.getLogger()
        try:
            params = get_conn_params()
            sshsess = SSHSession.get_sshsession()
            try:
                res = Ssh.execute(sshsess, host=params['identity']['site']['host'], user=params['identity']['username'],
                                  cmd=params['interface']['cancelcmd'].format(jobid=jobid))
                if not (res['stderr'] == '' or res['stderr'] is None or res['stderr'] == b''):
                    return apiabort(400, message=res['stderr'].decode())
                return res['stdout'].decode()
            except SshCtrlException as e:
                return apiabort(400, message="We're having difficultly contacting {}. We failed with the message {}".format(params['identity']['site']['host'], e))
        except Exception as e:
            logger.error(e)
            logger.error(traceback.format_exc())
            return apiabort(500, message="jobcancel failed in some unexpected way")



class JobSubmit(Resource):
    """
    Submit a new job to the compute backend.
    """
    def post(self):
        """
        Submit a job; a POST because it changes the state of the backend.

        The JSON body's ``app.startscript`` is formatted with the body itself
        to build the job script. The script is fed on stdin to the
        interface's ``submitcmd``, run over SSH on the identity's site host.
        Returns the command's decoded stdout, or ``({'message': ...}, code)``
        on failure.
        """
        import logging
        log = logging.getLogger()
        try:
            conn = get_conn_params()
            log.debug('entering JobSubmit.post {}'.format(conn))
            sshsess = SSHSession.get_sshsession()
            body = request.get_json()
            log.debug(body)
            try:
                script = body['app']['startscript'].format(**body)
            except Exception as err:
                import traceback
                diagnostics = (err, traceback.format_exc(),
                               'formating data', body, 'end formating data',
                               'body', request.data, 'end body')
                for item in diagnostics:
                    log.error(item)
                return apiabort(400, message='Incomplete job information was passed to the backend.')

            identity = conn['identity']
            host = identity['site']['host']
            try:
                res = Ssh.execute(sshsess, host=host, user=identity['username'],
                                  cmd=conn['interface']['submitcmd'], stdin=script)
            except SshCtrlException as err:
                return apiabort(400, message="We're having difficultly contacting {}. We failed with the message {}".format(host, err))

            stderr = res['stderr']
            if stderr == '' or stderr is None or stderr == b'':
                return res['stdout'].decode()
            log.debug('failed to submit job')
            log.debug(stderr)
            return apiabort(400, message=stderr.decode())
        except Exception as err:
            import traceback
            log.error(err)
            log.error(traceback.format_exc())
            return apiabort(500, message="JobSubmission failed in some unexpected way")

def gen_authtok():
    """
    Return a random 16-character token suitable for an auth cookie.

    Characters are drawn from uppercase ASCII letters and digits using
    ``random.SystemRandom``, which draws from the operating system's
    randomness source rather than the default Mersenne Twister.
    """
    import logging
    import random
    import string
    logging.getLogger().debug('generating new authtok')
    alphabet = string.ascii_uppercase + string.digits
    rng = random.SystemRandom()
    chars = [rng.choice(alphabet) for _ in range(16)]
    return ''.join(chars)

def apiabort(code, message):
    """
    Build an error response tuple for a resource to return (no exception raised).

    :param code: HTTP status code
    :param message: value stored under the ``message`` key of the body
    :returns: ``({'message': message}, code)``
    """
    body = dict(message=message)
    return body, code

class AppUrl(Resource):
    """Build the URL path through which a running application is reached."""
    def get(self):
        """
        Return ``"{twsproxy}/<redir>"`` for the app and instance in the query string.

        ``app`` and ``appinst`` are JSON-encoded query parameters. The app's
        ``client.redir`` template is formatted with the instance values. The
        ``{twsproxy}`` placeholder is kept literally in the result
        (presumably substituted later by the caller).
        Errors are returned as ``({'message': ...}, 500)``.
        """
        import logging
        log = logging.getLogger()
        try:
            appdef = json.loads(request.args.get('app'))
            log.debug('appdef {}'.format(appdef))
            inst = json.loads(request.args.get('appinst'))
            # Map the placeholder onto itself so formatting leaves it intact.
            inst['twsproxy'] = '{twsproxy}'
            log.debug('appinst {}'.format(inst))
            redir = appdef['client']['redir'].format(**inst)
            return "{twsproxy}/" + redir
        except Exception as err:
            import traceback
            log.error(err)
            log.error(traceback.format_exc())
            return apiabort(500, message="AppUrl failed in some unexpected way")

class AppLaunch(Resource):
    """
    Launch a local client program (e.g. a VNC viewer) for a running application.

    NOTE(review): the command comes from the ``app`` query parameter and is
    executed on the host running this server, presumably why registration
    is gated on the ENABLELAUNCH config flag.
    """
    def get(self):
        """
        Run ``app.client.cmd`` formatted with the ``appinst`` values.

        Both ``app`` and ``appinst`` are JSON-encoded query parameters. The
        formatted command is split on whitespace and run without a shell.
        Returns None when it exits with status 0. Otherwise returns
        ``({'message': ...}, 500)``, using the command's stderr (or a generic
        message if stderr is empty).
        """
        import logging
        import subprocess
        import traceback
        logger = logging.getLogger()
        logger.debug('in AppLaunch')
        try:
            appdef = json.loads(request.args.get('app'))
            inst = json.loads(request.args.get('appinst'))
            cmd = appdef['client']['cmd'].format(**inst).split()
            logger.debug('run cmd {}'.format(cmd))
            try:
                p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                (stdout, stderr) = p.communicate()
            except FileNotFoundError:
                return apiabort(500, message="Unable to find a vncviewer")
            if p.returncode != 0:
                # stderr is bytes: decode it so the message is JSON-serialisable,
                # and fall back to a generic message when nothing was written.
                msg = stderr.decode(errors='replace').strip()
                if not msg:
                    msg = "Unable to start the vncviewer"
                return apiabort(500, message=msg)
            return None
        except Exception as e:
            logger.error(e)
            logger.error(traceback.format_exc())
            return apiabort(500, message="AppLaunch failed in some unexpected way")

class AppInstance(Resource):
    """Query a running job for application details (e.g. password, port)."""
    def get(self, username, loginhost, batchhost, jobid):
        """Run a command to get things like password and port number.

        The command is passed JSON-encoded in the ``cmd`` query parameter and
        is formatted with ``jobid``. It is run over SSH on ``batchhost`` as
        ``username``, with ``loginhost`` passed as the bastion. Its stdout must
        be JSON. The parsed data is returned, with status 400 if it contains an
        ``error`` key. Other failures return ``({'message': ...}, code)``.
        """
        import logging
        import traceback
        logger = logging.getLogger()
        try:
            sshsess = SSHSession.get_sshsession()
            paramscmd = json.loads(request.args.get('cmd')).format(jobid=jobid)
            logger.debug('getting appinstance {} {} {}'.format(username, loginhost, batchhost))
            logger.debug('ssh sess socket is {}'.format(sshsess.socket))
            logger.debug('paramscmd is {}'.format(paramscmd))
            try:
                res = Ssh.execute(sshsess, host=batchhost, bastion=loginhost, user=username, cmd=paramscmd)
            except Exception:
                logger.error(traceback.format_exc())
                message = "The server couldn't execute to {} to get the necessary info".format(batchhost)
                return apiabort(400, message=message)
            try:
                data = json.loads(res['stdout'].decode())
            except json.decoder.JSONDecodeError:
                logger.error(res['stderr'] + res['stdout'])
                message = "I'm having trouble using ssh to find out about that application"
                return apiabort(400, message=message)
            # stdout is authoritative: anything on stderr alongside valid JSON
            # is ignored.
            if 'error' in data:
                return data, 400
            return data
        except Exception as e:
            logger.error(e)
            logger.error(traceback.format_exc())
            return apiabort(500, message="AppInstance failed in some unexpected way")

class CreateTunnel(Resource):
    """Create an SSH tunnel to a port on a batch host and return the local port."""
    @staticmethod
    def validate_connect_params(port, username, host, batchhost):
        """
        Minimal sanity check of tunnel parameters.

        Returns False if ``port`` cannot be converted with ``int()``, or if
        any of ``username``, ``host`` or ``batchhost`` contains a space or a
        newline. Returns True otherwise. This really needs more validation.
        """
        try:
            int(port)
        except Exception:
            return False
        for value in (username, host, batchhost):
            if ' ' in value or '\n' in value:
                return False
        return True

    def post(self, username, loginhost, batchhost):
        """
        Create a tunnel using established keys.

        The JSON body must contain ``port`` (the remote port). It may contain
        ``internalfirewall`` (default False) and ``localbind`` (default True).
        On success returns a JSON body ``{"localport": <port>}`` and sets the
        ``twsproxyauth`` cookie to a fresh auth token. Failures are returned
        as ``({'message': ...}, 500)``.
        """
        import logging
        import traceback
        logger = logging.getLogger()
        try:
            data = request.get_json()
            try:
                port = data['port']
            except KeyError as missingdata:
                raise AppParamsException("missing value for {}".format(missingdata))
            if not CreateTunnel.validate_connect_params(port, username, loginhost, batchhost):
                raise AppParamsException("Invalid value: {} {} {} {}".
                                         format(username, loginhost, batchhost, port))
            firewall = data.get('internalfirewall', False)
            localbind = data.get('localbind', True)
            sshsess = SSHSession.get_sshsession()
            authtok = gen_authtok()
            localport, _pids = Ssh.tunnel(sshsess, port=port, batchhost=batchhost,
                                          user=username, host=loginhost,
                                          internalfirewall=firewall,
                                          localbind=localbind, authtok=authtok)
            response = make_response(json.dumps({'localport': localport}), 200)
            # ``mimetype`` (not ``mime_type``) is the Response attribute that
            # controls the Content-Type header.
            response.mimetype = 'application/json'
            response.set_cookie('twsproxyauth', authtok)
            return response
        except Exception as e:
            logger.error(e)
            logger.error(traceback.format_exc())
            return apiabort(500, message="CreateTunnel failed in some unexpected way")



# Route registrations. AppLaunch executes a command taken from the request
# on this host, so it is only exposed when ENABLELAUNCH is set and truthy.
api.add_resource(TunnelstatEP, '/tunnelstat/<string:authtok>')
api.add_resource(GetCert, '/getcert')
api.add_resource(JobStat, '/stat')
api.add_resource(JobCancel, '/cancel/<int:jobid>')
api.add_resource(JobSubmit, '/submit')
api.add_resource(CreateTunnel, '/createtunnel/<string:username>/<string:loginhost>/<string:batchhost>')
api.add_resource(AppInstance, '/appinstance/<string:username>/<string:loginhost>/<string:batchhost>/<int:jobid>')
api.add_resource(AppUrl, '/appurl')
if app.config.get('ENABLELAUNCH'):
    api.add_resource(AppLaunch, '/applaunch')
api.add_resource(SSHAgent, '/sshagent')
api.add_resource(DirList, '/ls')
api.add_resource(MkDir, '/mkdir')
api.add_resource(ContactUs, '/contactus')