"""
A script to transfer a tree of data files from a remote/source server to a
local/destination computer. This runs on a local Linux machine or the eResearch dtn, on
which the tape archive system is mounted; in our case, this is a machine at Monash.
Prior to running this an ssh key pair must be shared between the systems. See
https://confluence.apps.monash.edu/display/XI/ssh+between+MASSIVE+filesystem+and+ASCI
for details on how to do this between a Monash Linux machine and ASCI
(Australian Synchrotron Compute Infrastructure). Requires Python 3.7 or higher
and uses the fabric module.

Authors:
gary.ruben@monash.edu
michelle.croughan@monash.edu
linda.croton@monash.edu

Note that the current version creates two files, by default in the same directory as
this script:
1. A .log file, named from the experiment name and the start-time timestamp, which
   is a capture of all stdout activity.
2. A Python pickle file (named <experiment_name>.pickle unless overridden with -p)
   that contains the transfer state, from which failed transfers can be restarted
   by passing the resume flag (-r).

Known issues
------------
Note: Some versions of fabric generate a harmless warning, which can be ignored. This
      issue is discussed here: https://github.com/paramiko/paramiko/issues/1369

Notes
-----
This is a possible option for checksumming:
https://stackoverflow.com/q/45819356/
KERNEL_CHECKSUM=$(cpio --to-stdout -i kernel.fat16 < archive.cpio  | sha256sum | awk "{print $1}")

We used the following command to check whether a transfer was successful
immediately prior to a failure of the ASCI filesystem.
The command to count the number of files in a tarball
$ tar -tf Lamb_Lung_Microfil_CT_18011B_right_CT.tar | wc -l
75920

"""
import os
import pathlib
import pickle
import pprint
import re
import shlex
import subprocess
import sys
import textwrap
import time
import warnings
from dataclasses import dataclass
from typing import Optional

import click
from fabric import Connection


def escape_parens(path):
    """Return `path` with every parenthesis prefixed by a backslash.

    This is required to work around a bug in Fabric's Invoke module. See my
    question on Stackoverflow: https://stackoverflow.com/q/63225018/607587
    The recommended workaround, until Fabric fixes the bug, is to just
    "manually escape the parentheses".

    Args:
        path: str
            Path that may contain "(" or ")" characters.

    Returns:
        str: copy of `path` in which each "(" and ")" is preceded by a single
        backslash; all other characters are unchanged.

    """
    # Raw strings are used for the replacements: a backslash followed by a
    # parenthesis is not a valid escape sequence in a normal string literal
    # and newer Python versions warn about it.
    table = str.maketrans({"(": r"\(", ")": r"\)"})
    return path.translate(table)


def escape_path(path):
    """Return `path` with parentheses AND spaces backslash-escaped.

    Parentheses or spaces that are already escaped in `path` are not escaped a
    second time, so applying this function to its own output is a no-op.

    Args:
        path: str
            Path that may contain spaces and/or parentheses, escaped or not.

    Returns:
        str: copy of `path` in which each space, "(" and ")" is preceded by
        exactly one backslash.

    """
    # First undo any existing escaping. The escaped forms are two characters
    # long, so they have to be matched as substrings; a per-character lookup
    # can never find them and would lead to double escaping below.
    unescaped_path = path
    for escaped, plain in ((r"\(", "("), (r"\)", ")"), (r"\ ", " ")):
        unescaped_path = unescaped_path.replace(escaped, plain)

    # Now escape every space and parenthesis exactly once.
    table = str.maketrans({"(": r"\(", ")": r"\)", " ": r"\ "})
    return unescaped_path.translate(table)


@dataclass
class Node:
    """A directory tree node.

    One instance describes a single source directory and the destination
    directory its files are transferred to, together with the transfer state
    that is pickled so that an interrupted run can be resumed.
    """
    src: str                        # source tree node path
    dest: str                       # destination tree node path
    count: Optional[int] = None     # number of files at the node; None until counted
    processed: bool = False         # True iff a node transfer completes


class Logger(object):
    """File-like object that duplicates everything written to it.

    Each message is sent to the stdout stream that was active when the
    instance was created and is also appended to a log file.
    """

    def __init__(self, log_filename):
        # Capture the current stdout first so it can still be written to after
        # sys.stdout has been replaced by this object.
        self.terminal = sys.stdout
        self.log = open(log_filename, "a")

    def write(self, message):
        """Write `message` to the terminal stream, then to the log file."""
        for stream in (self.terminal, self.log):
            stream.write(message)

    def flush(self):
        """Flush the terminal stream, then the log file."""
        for stream in (self.terminal, self.log):
            stream.flush()


def send_directory(node, remote_login, src_path):
    """Sends all files in the node.src directory to the node.dest directory
    across an ssh connection.

    The regular files directly inside node.src (no recursion) are counted over
    a Fabric connection. If there is at least one, they are packed on the
    remote side with tar and streamed over ssh into a single tarball in
    node.dest. The tarball is named after the directories trailing src_path,
    joined with underscores, or after the last component of src_path for the
    top-level node.

    On return node.count holds the number of files found and node.processed
    is True.

    Args:
        node: Node object
            Contains source and destination directory information as follows:
            src: full path to a remote node
                 e.g. /data/13660a/asci/input
            dest: full path to a destination node
                  e.g. /home/grub0002/bapcxi/vault/imbl2018
            count: number of files at the remote node
        remote_login: str
            remote login username@url
        src_path: str
            asci src top-level directory

    Raises:
        subprocess.CalledProcessError: if the ssh transfer command exits with
            a non-zero status. node.processed is left unchanged in that case.

    """
    # The directory is quoted for the remote shell so that spaces, parentheses
    # and other shell metacharacters in the path are taken literally. "&&"
    # makes sure nothing runs in the wrong directory if the cd fails.
    remote_cd = f"cd {shlex.quote(node.src)} && "

    # Check if there are any files in the node.
    with Connection(remote_login) as c:
        result = c.run(
            remote_cd + r"nice find -maxdepth 1 -type f -printf '%f\n'",
            echo=True,
        )

    node.count = len(result.stdout.strip().splitlines())

    print(f"Node:{node.src}, file count:{node.count}")
    if node.count == 0:
        # No files at this node, just return
        print("No files to transfer")
        node.processed = True
        return

    # At least one file. Transfer all files to a tarball.
    if node.src == src_path:
        filename = os.path.basename(node.src)
    else:
        filename = node.src.replace(src_path + "/", "").replace("/", "_")

    # NOTE(review): xargs may split a very long file list over several tar
    # invocations, which would write several concatenated archives to the
    # stream - confirm whether this can happen for the largest directories.
    remote_cmd = (
        remote_cd
        + r"nice find -maxdepth 1 -type f -printf '%f\0' | "
        + "xargs -0 tar -cf -"
    )

    # ssh is run directly (no local shell, no intermediate "cat") with its
    # stdout connected to the tarball, so the exit status that is checked is
    # the one reported by ssh for the remote command rather than that of the
    # last stage of a local pipeline.
    tar_path = f"{node.dest}/{filename}.tar"
    with open(tar_path, "wb") as tar_file:
        completed = subprocess.run(
            ["ssh", remote_login, remote_cmd],
            stdout=tar_file,
            stderr=subprocess.PIPE,
        )

    # stdout is the tarball itself; only stderr carries messages. Print it
    # before checking the status so that it is reported even on failure.
    print("stderr:", completed.stderr.decode(errors="replace"))
    completed.check_returncode()

    print(f"Transferred {node.count} files {node.src} -> {node.dest}")

    node.processed = True


@click.command()
@click.argument("remote_login")
@click.argument("experiment_name")
@click.argument("src_path", type=click.Path())
@click.argument("dest_path", type=click.Path())
@click.option("-p","pickle_filename", help="Pickle filename, e.g. 'foo.pickle' (default = experiment_name.pickle)")
@click.option("-r","resume",is_flag=True, help="If True, continue from current pickle state")
@click.option("-d","display_pickle_file",is_flag=True, help="If True, just show the pickle file state")
def main(
    remote_login,
    experiment_name,
    src_path,
    dest_path,
    pickle_filename,
    resume,
    display_pickle_file
):
    """
    \b
    Example
    -------
    $ python asci_to_vault.py gary.ruben@monash.edu@sftp1.synchrotron.org.au 15223 /data/15223/asci/input /home/gruben/vault/vault/IMBL/IMBL_2019_Nov_Croton/input

    A script to transfer a tree of data files from a remote/source server to a
    local/destination computer. This runs on a local Linux machine or the eResearch dtn, on
    which the tape archive system is mounted; in our case, this is a machine at Monash.
    Prior to running this an ssh key pair must be shared between the systems. See
    https://confluence.apps.monash.edu/display/XI/Australian+Synchrotron
    for details on how to do this between a Monash Linux machine and ASCI
    (Australian Synchrotron Compute Infrastructure). Requires Python 3.7 or higher
    and uses the fabric module.

    Note that current version creates two files, by default in the same directory
    as this script
    1. A .log file named based on the experiment name and the start-time timestamp
    which is a capture of all stdout activity.
    2. A Python pickle file (experiment_name.pickle unless -p is given) that
    contains the transfer state from which failed transfers can be restarted by
    passing the resume flag (-r).

    """
    # Validate with explicit exceptions; an assert would be stripped when
    # Python runs with -O.
    if not 4 <= len(experiment_name) <= 6:
        raise click.BadParameter(
            "experiment name must be 4 to 6 characters long",
            param_hint="experiment_name",
        )

    if re.fullmatch(r"[a-zA-Z0-9_\-\.@]+@[a-zA-Z0-9_\-\.]+", remote_login) is None:
        raise click.BadParameter(
            "Invalid form for login address",
            param_hint="remote_login",
        )

    script_dir = os.path.dirname(__file__)

    if pickle_filename is None:
        pickle_filename = experiment_name + ".pickle"

    # A bare file name (no directory part) is placed next to this script.
    if os.path.dirname(pickle_filename) == "":
        pickle_filename = os.path.join(script_dir, pickle_filename)

    timestamp = time.strftime("%Y-%m-%d-%H%M%S")
    log_filename = os.path.join(script_dir, f"{experiment_name}-{timestamp}.log")

    # Possible file name formats:
    #   /data/<experiment number>/asci/input
    #   /data/<experiment number>/asci/output
    #   input
    #   output
    #   output/username/working/
    #   output/username/working
    # Relative forms are taken to be relative to /data/<experiment_name>/asci/.
    src_components = src_path.split("/")[:5]
    if src_components[0] != "":
        src_path = os.path.join(f"/data/{experiment_name}/asci/", *src_components)

    # Drop any trailing "/" so that the prefix substitution that builds the
    # destination paths below keeps the separator between dest_path and the
    # sub-directory. A path made only of slashes (or empty) is left alone.
    stripped_src_path = src_path.rstrip("/")
    if stripped_src_path:
        src_path = stripped_src_path

    sys.stdout = Logger(log_filename)       # Log all stdout to a log file

    print(textwrap.dedent(f"""
        remote_login = {remote_login}
        experiment_name = {experiment_name}
        src_path = {src_path}
        dest_path = {dest_path}
        pickle_filename = {pickle_filename}
        resume = {resume}
        display_pickle_file = {display_pickle_file}
    """))

    if resume or display_pickle_file:
        # Read the saved transfer state from the locally pickled tree object.
        # NOTE: unpickling can execute arbitrary code; only point this at
        # pickle files written by this script.
        with open(pickle_filename, "rb") as f:
            tree = pickle.load(f)
        print("tree:")
        pprint.pprint(tree)

        if display_pickle_file:
            sys.exit()

        # Resuming: reset nodes at the end of the list with count==0 to
        # unprocessed. This is done because we observed a failure that
        # mistakenly reported source tree nodes to have 0 files, so force a
        # recheck of those.
        for node in reversed(tree):
            if node.count != 0:
                break
            node.processed = False
    else:
        # Get the directory tree from the remote server as a list. The path is
        # quoted so that the remote shell takes it literally.
        with Connection(remote_login) as c:
            result = c.run(f"find {shlex.quote(src_path)} -type d")
        remote_dirs = result.stdout.strip().splitlines()

        # Create a tree data structure that represents both source and
        # destination tree paths. Every path reported by find starts with
        # src_path, so only that first occurrence is substituted.
        tree = [
            Node(src, src.replace(src_path, dest_path, 1))
            for src in remote_dirs
        ]

    # Transfer all directory tree nodes.
    for i, node in enumerate(tree):
        if not node.processed:
            pathlib.Path(node.dest).mkdir(mode=0o770, parents=True, exist_ok=True)
            send_directory(node, remote_login, src_path)

        # pickle the tree to keep a record of the processed state.
        with open(pickle_filename, "wb") as f:
            pickle.dump(tree, f)

        print(f"Processed {i + 1} of {len(tree)} directory tree nodes")


if __name__ == "__main__":
    main()