#!/usr/bin/env python

"""
Simple script for downloading trained Gaussian Processes
for use with the cross-section evaluation code.
Compatible with Python 2 and 3.

Usage:
    xsec-download-gprocs --gp_dir PATH --process_type TYPE

If no argument PATH is given, the 'gprocs' subdirectory of the current
working directory is used.
If no argument TYPE is given, data will be downloaded for all
supported final-state types.
Use 'xsec-download-gprocs --help' for more details.
"""


from __future__ import print_function

import os
import sys
import time
import errno
import tarfile
import argparse


try:  # if python3
    from urllib.request import urlretrieve
except ImportError:  # if python2
    from urllib import urlretrieve

# Import xsec, first assume we have pip installed it
try:
    import xsec
# Someone is insisting on using our fine programme without pip
# installing
except ImportError:
    # Our current absolute directory
    abs_dir = os.path.dirname(os.path.abspath(__file__))
    # Parent directory containing xsec
    parent_dir = os.path.dirname(abs_dir)
    sys.path.append(parent_dir)
    import xsec

# The release tag in the download URL and in the archive names is taken
# from the version of the xsec package that was imported above
VERSION = xsec.__version__

BASE_URL = "https://github.com/jeriek/xsec/releases/download/{v}/".format(
    v=VERSION
)
PROCESS_TYPES = ["gg", "sg", "ss", "sb", "tb"]
ALL_PROCESS_TYPES = PROCESS_TYPES + ["all"]


# Example URL:
# https://github.com/jeriek/xsec/releases/download/1.0.0/xsec_1.0.0_gprocs_gg.tar.gz

# Archive file name for each process type: {process type: filename}
FILE_DICT = dict(
    (
        ptype,
        "xsec_{version}_gprocs_{process_type}.tar.gz".format(
            version=VERSION, process_type=ptype
        ),
    )
    for ptype in PROCESS_TYPES
)
# Full download address for each process type: {process type: url}
DOWNLOAD_DICT = dict(
    (ptype, BASE_URL + archive_name)
    for ptype, archive_name in FILE_DICT.items()
)


def main():
    """
    Download and unpack the requested Gaussian process archives.

    Parses the command-line options, makes sure the target directory and
    its __init__.py file exist, and then, for every selected process
    type, downloads the archive listed in DOWNLOAD_DICT and extracts it
    into the target directory.

    Raises:
        IOError, OSError: if the target directory or its init file
            cannot be created, or if a download fails. A short hint is
            printed before the exception is re-raised.
        tarfile.TarError: if a downloaded archive cannot be read.
    """

    # Get the script name
    prog_name = "xsec-download-gprocs"

    # Default values, also used when an option is given without a value
    default_gp_dir = os.path.join(os.getcwd(), "gprocs")
    default_process_type = "all"

    # Set up the argument parser
    parser = argparse.ArgumentParser(
        prog=prog_name,
        description="Tool for downloading Gaussian process files for xsec.",
    )

    # Take the download directory as an optional argument.
    # nargs="?" allows a bare "--gp_dir"; without 'const' argparse would
    # then store None, so 'const' is set to the default value.
    parser.add_argument(
        "-g",
        "--gp_dir",
        nargs="?",
        metavar="PATH",
        type=str,
        action="store",
        const=default_gp_dir,
        default=default_gp_dir,
        help="set the path where the downloaded files are stored. "
        "The default path is %(default)s",
    )
    parser.add_argument(
        "-t",
        "--process_type",
        nargs="?",
        metavar="TYPE",
        type=str,
        choices=ALL_PROCESS_TYPES,
        action="store",
        const=default_process_type,
        default=default_process_type,
        help="choose the final-state type for which GP data will be "
        "downloaded. The default choice is '%(default)s'. Other "
        "options: "
        "  'gg' (gluino pair production), "
        "  'sg' (1st/2nd gen. squark--gluino pair production), "
        "  'ss' (1st/2nd gen. squark pair production), "
        "  'sb' (1st/2nd gen. squark--anti-squark pair production), "
        "  'tb' (3rd gen. squark--anti-squark pair production).",
    )

    # Parse the arguments
    args = parser.parse_args()

    # Create data dir and/or init file if not existing
    _prepare_gp_dir(args.gp_dir)

    # Print a new line
    print()

    if args.process_type == "all":
        loop_list = PROCESS_TYPES
    else:
        loop_list = [args.process_type]

    for process in loop_list:
        _download_and_extract(process, args.gp_dir)

    print()
    print(
        "All downloads completed. Data files stored in {0}.".format(
            args.gp_dir
        )
    )
    print()


def _prepare_gp_dir(gp_dir):
    """
    Create gp_dir and an empty __init__.py inside it where missing.

    On a permission error a hint is printed before the exception is
    re-raised; all other errors are re-raised unchanged.
    """
    data_init_filepath = os.path.join(gp_dir, "__init__.py")
    try:
        # mkdir_p accepts an already existing directory
        mkdir_p(gp_dir)
        if not os.path.exists(data_init_filepath):
            # Create empty init file
            open(data_init_filepath, "a").close()
    # IOError is raised by open() on Python 2; on Python 3 it is an
    # alias of OSError, so both must be handled by the same clause
    except (IOError, OSError) as exc:
        if exc.errno == errno.EACCES:
            print(
                "Permission denied when trying to write the "
                "file {f}. Do you have write permission for the "
                "directory {dir}?".format(f=data_init_filepath, dir=gp_dir)
            )
            print()
        raise


def _download_and_extract(process, gp_dir):
    """
    Download the archive for one process type and unpack it in gp_dir.

    The archive is fetched to a temporary file, which is removed again
    once extraction has finished or failed.
    """
    url = DOWNLOAD_DICT[process]
    print("-- Downloading file", FILE_DICT[process], "from\n", url, ":\n")
    # Force writing to terminal; otherwise stdout is buffered first
    sys.stdout.flush()

    # Download compressed file to a temporary location (since
    # filename=None)
    try:
        tmp_filename, _ = urlretrieve(
            url, filename=None, reporthook=download_progress_hook
        )
    except IOError:
        print(
            "Download failed. Are you sure you're connected to "
            "the internet?"
        )
        print()
        raise

    try:
        print("\n\n  ... download completed")
        print("  ... starting file extraction (", tmp_filename, ")")
        sys.stdout.flush()
        # NOTE(review): extractall() trusts the member paths stored in
        # the archive. The archives come from the project's own release
        # page; consider an extraction filter if that ever changes.
        with tarfile.open(tmp_filename, "r:gz") as tar:
            tar.extractall(gp_dir)
        print("  ... file extraction completed.")
        sys.stdout.flush()
    finally:
        # Best-effort cleanup of the temporary download; a failure to
        # delete it must not mask an extraction error
        try:
            os.remove(tmp_filename)
        except OSError:
            pass


def mkdir_p(path):
    """
    Create the directory 'path' together with any missing parent
    directories, as with "mkdir -p".

    A directory that already exists at 'path' is not treated as an
    error. Every other OSError from os.makedirs is passed on to the
    caller.
    """
    try:
        os.makedirs(path)
    except OSError as err:
        already_a_directory = (
            err.errno == errno.EEXIST and os.path.isdir(path)
        )
        if not already_a_directory:
            raise


def download_progress_hook(count, block_size, total_size):
    """
    Visualise download progress.
    Used with the urllib.urlretrieve reporthook.

    Args:
        count: number of blocks reported so far. A call with count == 0
            only records the start time and prints nothing.
        block_size: size of one block in bytes.
        total_size: total size of the download in bytes; a value <= 0
            is treated as "unknown" and only the accumulated size is
            shown.

    The start time is kept in the module-level variable 'start_time',
    which is set by the call with count == 0.
    """
    # Compute time since start of download
    global start_time
    if count == 0:
        start_time = time.time()
        return
    duration = time.time() - start_time

    bytes_per_mb = 1024.0 ** 2

    # Compute combined size of downloaded chunks, in bytes. The last
    # block is usually only partly filled, so count * block_size can
    # overshoot the real size; cap it when the total is known.
    progress_size = float(count * block_size)
    if total_size > 0:
        progress_size = min(progress_size, float(total_size))
    progress_mb = progress_size / bytes_per_mb

    # Compute average download speed in MB/s. Two calls can fall on the
    # same clock tick, so guard against a zero duration.
    speed = progress_mb / duration if duration > 0 else 0.0

    # When the HTTP response has 'Content-Length' in the header, display
    # progress bar and percentage of download completion. Else, when
    # 'Transfer-Encoding: chunked' is used in the header, no information
    # about total size is available, so only display accumulated
    # download size.
    if total_size > 0:
        percent = min(int(progress_size * 100.0 / total_size), 100)
        # Setup toolbar
        toolbar_width = 18
        progress_symbol = "/"
        n_progress_lines = int(percent * toolbar_width / 100.0)

        # Output to screen: progress bar, some useful info
        # (\r returns cursor to start of line to overwrite old info,
        # \b backspace, \t tab)
        sys.stdout.write(
            "\r[%s] \t\b\b %s%%  --  %.1f MB of %.1f MB, %.1f s, %.1f MB/s"
            % (
                progress_symbol * n_progress_lines
                + " " * (toolbar_width - n_progress_lines),
                percent,
                progress_mb,
                total_size / bytes_per_mb,
                duration,
                speed,
            )
        )
        sys.stdout.flush()

    else:
        sys.stdout.write(
            "\r  ... downloaded %.1f MB in %.1f s, avg speed: %.1f MB/s"
            % (progress_mb, duration, speed)
        )
        sys.stdout.flush()


# Entry point: call main() only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
