"""
All the API endpoints for controlling processes and tunnels via SSH
"""
import json
from flask import session, redirect, request, Response, make_response
from flask_restful import Resource
from flask import render_template
import flask_restful

from . import api, app
from .sshwrapper import Ssh, SshAgentException, SftpPermissionException, SftpException
from .tunnelstat import SSHSession


class AppParamsException(Exception):
    """Raised when application/tunnel request parameters are missing, invalid, or report an error."""

class GetCert(Resource):
    """
    Proxy a certificate-signing request from the SPA to a signing service.

    This class is necessary because SSHAuthZ is not being reconfigured to
    support CORS, but the TES does support CORS, so the SPA asks the TES to
    make the signing request on its behalf.
    """

    def post(self):
        """
        Take a public key and return to the SPA a signed certificate.

        Expects a JSON body with 'token', 'pubkey' and 'signing_url' keys.
        Returns {'cert': <certificate>} on success. Any failure is logged and
        answered with HTTP 500 (details are not sent to the client).
        """
        import logging
        logger = logging.getLogger()
        logger.debug('entering GetCert.post')
        try:
            data = request.get_json()
            # NOTE(review): 'signing_url' comes from the client, so the server
            # will POST the bearer token to whatever URL the caller names;
            # consider restricting it to known signing endpoints.
            cert = GetCert.get_cert(data['token'], data['pubkey'], data['signing_url'])
        except Exception:
            # Exception (not a bare except) so SystemExit/KeyboardInterrupt
            # still propagate.
            logger.exception('Failed to get certificate')
            flask_restful.abort(500, message="an error occured generating your certificate")
        return {'cert': cert}

    @staticmethod
    def get_cert(access_token, pub_key, url):
        """
        Ask the signing service at *url* to sign *pub_key* into a certificate.

        :param access_token: sent as "Authorization: Bearer <access_token>"
        :param pub_key: public key to sign, sent as JSON {"public_key": ...}
        :param url: the signing endpoint to POST to
        :returns: the 'certificate' field of the service's JSON response
        """
        import requests
        headers = {"Authorization": "Bearer %s" % access_token}
        payload = {"public_key": pub_key}
        # Using the session as a context manager releases its connection
        # pool instead of leaking one per call.
        with requests.Session() as sess:
            # NOTE(review): verify=False disables TLS certificate checking
            # while a bearer token is being sent; confirm this is intended.
            resp = sess.post(url, json=payload, headers=headers, verify=False)
            data = resp.json()
        return data['certificate']

class TestAuth(Resource):
    """
    Report whether the current Flask session holds a login token.

    No connection to a compute resource is attempted; only the presence of
    the 'token' key in the session is checked.
    """
    def get(self):
        """
        Return True when 'token' is a key of the Flask session, else False.
        """
        logged_in = 'token' in session
        return logged_in


class SSHAgent(Resource):
    """
    Manage keys held by the SSHSession returned by SSHSession.get_sshsession().
    """
    def post(self):
        """
        Add the key/certificate pair from the JSON body ('key', 'cert').

        Marks the Flask session permanent, adds the pair to the SSH session,
        refreshes it, and returns "OK".
        """
        import logging
        logger = logging.getLogger()
        logger.debug('entering SSHAgent.post')
        session.permanent = True
        sshsess = SSHSession.get_sshsession()
        data = request.get_json()
        sshsess.add_keycert(key=data['key'], cert=data['cert'])
        sshsess.refresh()
        logger.debug('leaving SSHAgent.post')
        return "OK"

    def get(self):
        """
        Return the certificate contents loaded in the SSH session.

        Returns [] if the session has no socket yet (agent not started).
        Any exception is logged at debug level and answered with HTTP 500.
        """
        import logging
        logger = logging.getLogger()
        try:
            sshsess = SSHSession.get_sshsession()
            if sshsess.socket is None:
                logger.debug('trying to get the agent contents, but the agent isn\'t started yet')
                return []  # The agent hasn't even been started yet
            return sshsess.get_cert_contents()
        except Exception as e:
            import traceback
            logger.debug('SSHAgent.get: Exception {}'.format(e))
            logger.debug(traceback.format_exc())
            flask_restful.abort(500)

    def delete(self):
        """
        Kill the SSH session, remove it from the registry, and return [].
        """
        sshsess = SSHSession.get_sshsession()
        sshsess.kill()
        SSHSession.remove_sshsession()
        return []

def get_conn_params():
    """
    Return parameters relating to the backend compute service.

    They are read from the query string, where each of 'identity',
    'interface', 'path', 'cd', 'app' and 'appinstance' is JSON-encoded.
    'identity' is required and must provide 'username' and ['site']['host'];
    otherwise the request is aborted with HTTP 400. Every other argument is
    decoded independently and falls back to a default ({} for 'interface',
    'app' and 'appinstance'; None for 'path' and 'cd') when it is absent or
    not valid JSON.

    :returns: dict with keys 'identity', 'interface', 'app', 'appinstance',
        then every key of the interface dict merged in at top level, then
        'user', 'host', 'path', 'cd' (these last four override any
        same-named interface key).
    """
    import logging
    logger = logging.getLogger()
    logger.debug('entering get_conn_params')

    def _json_arg(name, default):
        """Decode query argument *name* as JSON, or return *default*."""
        try:
            return json.loads(request.args.get(name))
        except (TypeError, ValueError):
            # TypeError: argument absent (json.loads(None));
            # ValueError: argument present but not valid JSON.
            return default

    try:
        identityparams = json.loads(request.args.get('identity'))
        user = identityparams['username']
        host = identityparams['site']['host']
    except (TypeError, ValueError, KeyError) as e:
        flask_restful.abort(400, message="Missing or invalid identity parameter: {}".format(e))

    interfaceparams = _json_arg('interface', {})
    # 'path' and 'cd' are decoded independently so that a missing 'cd'
    # does not discard a valid 'path' (and vice versa).
    pathparams = _json_arg('path', None)
    cdparams = _json_arg('cd', None)
    appparams = _json_arg('app', {})
    appinstanceparams = _json_arg('appinstance', {})

    params = {
        'identity': identityparams,
        'interface': interfaceparams,
        'app': appparams,
        'appinstance': appinstanceparams,
    }
    # Interface keys are merged at top level; the explicit keys below are
    # assigned afterwards so they win over any same-named interface key.
    params.update(interfaceparams)
    params['user'] = user
    params['host'] = host
    params['path'] = pathparams
    params['cd'] = cdparams

    logger.debug('leaving get_conn_params')

    return params



class TunnelstatEP(Resource):
    """
    Endpoints used by the WS proxy
    """

    def get(self, authtok):
        """
        Given an auth token, return the port of the matching tunnel.

        Searches the ``port`` mapping of every SSH session in ``sshsessions``
        for *authtok*. Returns the port on a match, otherwise None. An error
        raised during the search is logged and also results in None.
        """
        import logging
        logger = logging.getLogger()
        logger.debug('entering TunnelstatEP.get')
        from . import sshsessions
        logger.debug('TunnelstatEP.get: iterating sshsessions {}'.format(authtok))

        try:
            for sessid, sshsess in sshsessions.items():
                logger.debug('sshsession id {}'.format(sessid))
                # Tokens belonging to other sessions are deliberately not
                # logged: they are what this endpoint uses to look up tunnels.
                for tok, port in sshsess.port.items():
                    if tok == authtok:
                        logger.debug("found port {} for authtok {}".format(port, tok))
                        logger.debug('leaving TunnelstatEP.get')
                        return port
        except Exception:
            logger.exception("exception in TunnelstatEP.get")
        logger.debug("No ports found for authtok {}".format(authtok))
        logger.debug('leaving TunnelstatEP.get')
        return None

class JobStat(Resource):
    """
    endpoints to return info on jobs on the backend compute Resource
    """
    def get(self):
        """
        Run the stat command on the backend and return its decoded JSON output.

        Query-string arguments (each JSON-encoded): 'statcmd', 'host',
        'username'.

        Aborts with 500 if the SSH session cannot be fetched or refreshed,
        and with 400 for a missing/invalid argument, an SshAgentException,
        any stderr output from the command, or stdout that is not valid JSON.
        """
        import logging
        logger = logging.getLogger()
        try:
            sshsess = SSHSession.get_sshsession()
            sshsess.refresh()
        except Exception:
            logger.exception('JobStat.get: failed to get or refresh the ssh session')
            flask_restful.abort(500, message="Error relating to the ssh sessions")

        args = {}
        for name in ('statcmd', 'host', 'username'):
            try:
                args[name] = json.loads(request.args.get(name))
            except (TypeError, ValueError):
                # TypeError: argument absent (json.loads(None));
                # ValueError: argument present but not valid JSON.
                flask_restful.abort(400, message="Missing required parameter {}".format(name))
        cmd, host, user = args['statcmd'], args['host'], args['username']

        try:
            logger.debug('attempting ssh execute {} {} {}'.format(host, user, cmd))
            res = Ssh.execute(sshsess, host=host, user=user, cmd=cmd)
        except SshAgentException as e:
            logger.error(e)
            flask_restful.abort(400, message="Identity error {}".format(e))

        stderr = res['stderr']
        if stderr not in ('', None, b''):
            logger.error(stderr)
            # stderr is normally bytes; fall back to str() for anything else.
            errmsg = stderr.decode() if isinstance(stderr, bytes) else str(stderr)
            flask_restful.abort(400, message=errmsg)
        try:
            return json.loads(res['stdout'].decode())
        except Exception as e:
            logger.exception(e)
            # str(e): flask_restful serialises `message` to JSON, and an
            # exception object is not JSON-serialisable.
            flask_restful.abort(400, message=str(e))

class MkDir(Resource):
    """
    Create a directory on the compute site via Ssh.sftpmkdir.
    """
    def post(self):
        """
        Create the directory whose name is given in the JSON body ('name').

        The parent path, user and host come from get_conn_params(); the SSH
        port is the site's 'dtnport' when present, otherwise "22".

        Aborts with 400 on SftpPermissionException or SftpException, and
        with 500 (after logging the traceback) on any other exception.
        Returns None on success.
        """
        import logging
        log = logging.getLogger()
        body = request.get_json()
        log.debug('mkdir data')
        log.debug(body)
        params = get_conn_params()
        sshsess = SSHSession.get_sshsession()
        log.debug('try to call mkdir')
        identity = params['identity']
        site = identity['site']
        sshport = site['dtnport'] if 'dtnport' in site else "22"
        try:
            Ssh.sftpmkdir(sshsess,
                          host=site['host'],
                          user=identity['username'],
                          path=params['path'],
                          name=body['name'],
                          sshport=sshport)
        except SftpPermissionException:
            flask_restful.abort(400, message="You don't have permission to make a directory there")
        except SftpException:
            flask_restful.abort(400, message="Something went wrong making that directory")
        except Exception:
            import traceback
            log.error(traceback.format_exc())
            flask_restful.abort(500, message="Something went wrong creating that directory, probably a bug")
        return None

class DirList(Resource):
    """
    Return a directory listing from the compute site.
    """
    def get(self):
        """
        List a directory for the identity given in the query string.

        If the site defines a non-empty 'lscmd', runs "<lscmd> <path> <cd>"
        over SSH and returns its stdout decoded as JSON, aborting with 404
        if that decoding fails. Otherwise the listing comes from Ssh.sftpls.
        The SSH port is the site's 'dtnport' when present, otherwise "22".
        """
        import logging
        logger = logging.getLogger()

        params = get_conn_params()
        sshsess = SSHSession.get_sshsession()
        site = params['identity']['site']
        sshport = site.get('dtnport', "22")
        path = params['path']
        cd = params['cd']
        if path == "":
            path = "."
        if cd == "":
            cd = "."
        lscmd = site.get('lscmd')
        # Truthiness rejects both None and "" without an identity (`is`)
        # comparison against a string literal, which is unreliable.
        if lscmd:
            logger.debug('using ssh.execute with lscmd')
            # NOTE(review): path and cd come from the query string and are
            # interpolated unquoted into a remote command; consider
            # shlex.quote once it is confirmed lscmd does not rely on shell
            # expansion (e.g. of '~').
            res = Ssh.execute(sshsess, host=site['host'], user=params['identity']['username'],
                              sshport=sshport,
                              cmd="{} {} {}".format(lscmd, path, cd))
            try:
                dirls = json.loads(res['stdout'].decode())
            except Exception:
                flask_restful.abort(404, message="You don't have permission to view that directory")
        else:
            logger.debug('using ssh.sftpls')
            # NOTE(review): sftpls receives the raw path/cd, not the
            # "."-normalised values used above; confirm this is intended.
            dirls = Ssh.sftpls(sshsess, host=site['host'],
                               user=params['identity']['username'], path=params['path'],
                               changepath=params['cd'], sshport=sshport)
        return dirls

class JobCancel(Resource):
    """
    Terminate a job on the compute backend.
    """
    def delete(self, jobid):
        """
        Run the interface's 'cancelcmd' (formatted with *jobid*) over SSH.

        Aborts with 400 and the decoded stderr if the command wrote anything
        to stderr; otherwise returns the decoded stdout.
        """
        params = get_conn_params()
        sshsess = SSHSession.get_sshsession()
        identity = params['identity']
        outcome = Ssh.execute(sshsess,
                              host=identity['site']['host'],
                              user=identity['username'],
                              cmd=params['interface']['cancelcmd'].format(jobid=jobid))
        stderr = outcome['stderr']
        if stderr not in ('', None, b''):
            flask_restful.abort(400, message=stderr.decode())
        return outcome['stdout'].decode()


class JobSubmit(Resource):
    """
    Start a new job on the compute backend.
    """
    def post(self):
        """
        Submit a job; this is a POST because it changes the backend's state.

        The job script is data['app']['startscript'] formatted with the whole
        JSON body. It is piped as stdin to the interface's 'submitcmd'.
        Aborts with 400 if the script cannot be formatted or if the submit
        command writes to stderr; otherwise returns the decoded stdout.
        """
        import logging
        log = logging.getLogger()
        params = get_conn_params()
        log.debug('entering JobSubmit.post {}'.format(params))
        sshsess = SSHSession.get_sshsession()
        payload = request.get_json()
        log.debug(payload)
        try:
            template = payload['app']['startscript']
            script = template.format(**payload)
        except Exception as exc:
            import traceback
            diagnostics = (
                exc,
                traceback.format_exc(),
                'formating data',
                payload,
                'end formating data',
                'body',
                request.data,
                'end body',
            )
            for item in diagnostics:
                log.error(item)
            flask_restful.abort(400, message='Incomplete job information was passed to the backend.')
        log.debug('script formated to {}'.format(script))

        identity = params['identity']
        outcome = Ssh.execute(sshsess,
                              host=identity['site']['host'],
                              user=identity['username'],
                              cmd=params['interface']['submitcmd'],
                              stdin=script)
        stderr = outcome['stderr']
        if stderr not in ('', None, b''):
            log.debug('failed to submit job')
            log.debug(stderr)
            flask_restful.abort(400, message=stderr.decode())
        return outcome['stdout'].decode()

def gen_authtok():
    """
    Return a fresh 16-character token of uppercase ASCII letters and digits.

    Characters are drawn with random.SystemRandom, which is backed by the
    OS randomness source, so the token is suitable as an auth cookie value.
    """
    import random
    import string
    import logging
    logging.getLogger().debug('generating new authtok')
    alphabet = string.ascii_uppercase + string.digits
    rng = random.SystemRandom()
    picks = [rng.choice(alphabet) for _ in range(16)]
    return ''.join(picks)

class AppUrl(Resource):
    """
    Build the URL template at which an application instance is reached.
    """
    def get(self):
        """
        Return "{twsproxy}/" followed by app['client']['redir'] formatted
        with the 'appinst' values.

        Both 'app' and 'appinst' are JSON-encoded query arguments. The
        'twsproxy' key of the instance is forced to the literal placeholder
        '{twsproxy}', so that placeholder survives formatting (presumably to
        be substituted later by the client).
        """
        import logging
        logger = logging.getLogger()
        app_def = json.loads(request.args.get('app'))
        logger.debug('appdef {}'.format(app_def))
        instance = json.loads(request.args.get('appinst'))
        instance['twsproxy'] = '{twsproxy}'
        logger.debug('appinst {}'.format(instance))
        redirect_path = app_def['client']['redir'].format(**instance)
        return '{twsproxy}/' + redirect_path

class AppInstance(Resource):
    """
    Fetch runtime details (e.g. password, port) of an application instance.
    """
    def get(self, username, loginhost, batchhost):
        """Run a command to get things like password and port number.

        The command is passed as the 'cmd' query-string argument and is run
        on *batchhost* as *username*, with *loginhost* as the bastion. Its
        stdout must be JSON; the decoded value is returned as-is.

        Aborts with HTTP 500 if the command cannot be executed or if its
        stdout cannot be decoded as JSON.
        """
        sshsess = SSHSession.get_sshsession()
        paramscmd = request.args.get('cmd')
        import logging
        logger = logging.getLogger()
        logger.debug('getting appinstance {} {} {}'.format(username, loginhost, batchhost))
        logger.debug('ssh sess socket is {}'.format(sshsess.socket))
        try:
            res = Ssh.execute(sshsess, host=batchhost, bastion=loginhost, user=username, cmd=paramscmd)
        except Exception:
            # Log first: flask_restful.abort raises, so nothing after it runs.
            import traceback
            logger.error(traceback.format_exc())
            message = "The server couldn't execute to {} to get the necessary info".format(batchhost)
            flask_restful.abort(500, message=message)
        # NOTE(review): stderr output and an 'error' key in the decoded data
        # are not treated as failures; the decoded stdout is returned as-is.
        # Decide whether either should abort the request (ssh may write
        # benign warnings to stderr).
        try:
            return json.loads(res['stdout'].decode())
        except ValueError:
            # ValueError covers both json.JSONDecodeError and
            # UnicodeDecodeError from .decode().
            logger.error('stderr: {} stdout: {}'.format(res['stderr'], res['stdout']))
            message = "I'm having trouble using ssh to find out about that application"
            flask_restful.abort(500, message=message)

class CreateTunnel(Resource):
    """
    Create an SSH tunnel to an application port on a batch host.
    """

    @staticmethod
    def validate_connect_params(port, username, host, batchhost):
        """
        Sanity-check tunnel parameters before they are handed to Ssh.tunnel.

        :param port: port from the request body (int or numeric string)
        :param username: remote user name
        :param host: login host
        :param batchhost: batch host passed to Ssh.tunnel
        :returns: True if *port* converts to an int in 1..65535 and none of
            *username*, *host* or *batchhost* contains whitespace, else False.
        """
        try:
            intport = int(port)
        except (TypeError, ValueError, OverflowError):
            return False
        if not 1 <= intport <= 65535:
            return False
        # This really needs more validation; at minimum reject whitespace,
        # which could split a value into extra arguments if (presumably)
        # these end up on an ssh command line.
        for value in (username, host, batchhost):
            if any(ch.isspace() for ch in value):
                return False
        return True

    def post(self, username, loginhost, batchhost):
        """
        Create a tunnel using established keys.

        User and hosts come from the URL; the JSON body carries 'port' and
        the optional 'internalfirewall' (default False) and 'localbind'
        (default True) flags. A freshly generated auth token is passed to
        Ssh.tunnel and returned to the client in the 'twsproxyauth' cookie.

        :raises AppParamsException: if 'port' is missing or any parameter
            fails validate_connect_params.
        """
        import logging
        logger = logging.getLogger()
        logger.debug("Createing tunnel")
        logger.debug("recieved data {}".format(request.data))
        data = request.get_json()
        try:
            port = data['port']
        except KeyError as missingdata:
            raise AppParamsException("missing value for {}".format(missingdata))
        if not CreateTunnel.validate_connect_params(port, username, loginhost, batchhost):
            raise AppParamsException("Invalid value: {} {} {} {}".
                                     format(username, loginhost, batchhost, port))
        firewall = data.get('internalfirewall', False)
        localbind = data.get('localbind', True)
        sshsess = SSHSession.get_sshsession()
        authtok = gen_authtok()
        Ssh.tunnel(sshsess, port=port, batchhost=batchhost,
                   user=username, host=loginhost,
                   internalfirewall=firewall,
                   localbind=localbind, authtok=authtok)
        response = make_response("")
        # Werkzeug's Response attribute is `mimetype`; an unknown attribute
        # such as `mime_type` would be silently ignored.
        response.mimetype = 'application/json'
        response.set_cookie('twsproxyauth', authtok)
        logger.debug('JobConnect.connect: connecting via redirect with cookie authtok set to  {}'.format(authtok))
        return response



# (Resource class, URL rule) pairs registered with the flask_restful Api,
# in registration order.
# NOTE(review): TestAuth is not registered here; confirm that is intentional.
for _resource, _rule in (
    (TunnelstatEP, '/tunnelstat/<string:authtok>'),
    (GetCert, '/getcert'),
    (JobStat, '/stat'),
    (JobCancel, '/cancel/<int:jobid>'),
    (JobSubmit, '/submit'),
    (CreateTunnel, '/createtunnel/<string:username>/<string:loginhost>/<string:batchhost>'),
    (AppInstance, '/appinstance/<string:username>/<string:loginhost>/<string:batchhost>'),
    (AppUrl, '/appurl'),
    (SSHAgent, '/sshagent'),
    (DirList, '/ls'),
    (MkDir, '/mkdir'),
):
    api.add_resource(_resource, _rule)
del _resource, _rule