# from fastq_downloader.helper import construct_gsm_dict, download_and_verify, rename_files_by_run, merge_files
from fastq_downloader.helper.construct_gsm_dict import infotsv_to_dict, parse_acc_type
from fastq_downloader.helper.download_and_verify import (
    download_and_verify__,
    check_file,
)
from fastq_downloader.helper.rename_files_by_run import (
    write_filename_to_dict,
    actual_rename,
)
from fastq_downloader.helper.merge_files import merge_files
from fastq_downloader.helper.accession_to_links import gsm2link_md5, srx2link_md5
from pathlib import Path
import json
from glob import glob
import os
import sys
import subprocess
import re
import shutil
from hashlib import md5

# Workflow settings, supplied through --config (or a config file):
#   infotsv      tsv file listing accession numbers and their descriptions (required)
#   output_dir   directory that receives the downloaded files (default: ".")
#   refresh_acc  re-check the per-accession JSON files even if .tmp already exists
#   retry        number of download attempts per file (default: 3)
retry = config.get("retry", 3)
refresh_acc = config.get("refresh_acc", False)
output_dir = config.get("output_dir", ".")
infotsv = config["infotsv"]
# scratch area for intermediate JSON, lock files and unmerged downloads;
# kept as a plain string because the rules below build paths by concatenation
tmp_dir = str(Path(output_dir) / ".tmp")

# Locate the Aspera client on PATH. shutil.which avoids spawning a shell and
# lets us fail with a readable message instead of a CalledProcessError when
# ascp is not installed.
ascp_bin = shutil.which("ascp")
if ascp_bin is None:
    raise FileNotFoundError("ascp executable not found in PATH")
# The private key is looked up relative to the binary: <prefix>/bin/ascp ->
# <prefix>/etc/asperaweb_id_dsa.openssh
ascp_privkey = Path(ascp_bin).parent.parent / "etc" / "asperaweb_id_dsa.openssh"
if not ascp_privkey.exists():
    raise FileNotFoundError("ascp private key not found: {}".format(ascp_privkey))


# Parse the accession table and seed one JSON file per accession under
# <tmp_dir>/acc_json; these files are the entry point of the rule graph.
infodict = infotsv_to_dict(infotsv)
infodict = parse_acc_type(infodict)
acc_list = list(infodict.keys())

acc_json_dir = Path(tmp_dir, "acc_json")
acc_json_dir.mkdir(exist_ok=True, parents=True)
for acc, info in infodict.items():
    acc_json_path = acc_json_dir / f"{acc}.json"
    serialized = json.dumps(info)
    if acc_json_path.exists():
        # Existing files are only re-checked when the user asks for a refresh.
        if not refresh_acc:
            continue
        # Leave an unchanged file untouched so its mtime does not trigger
        # a re-run of everything downstream.
        if acc_json_path.read_text(encoding="utf-8") == serialized:
            print(f"{acc} is already up to date", file=sys.stderr)
            continue
    # Missing (first run / newly added accession) or outdated: (re)write it.
    acc_json_path.write_text(serialized, encoding="utf-8")


# Default target: every intermediate JSON stage, the merged fastq directory
# and the merge lock file for each accession.
rule all:
    input:
        # the three per-accession JSON stages share one naming scheme
        expand(
            tmp_dir + "/{stage}/{acc}.json",
            stage=["acc_json", "acc_json_with_ascp", "acc_json_with_rename"],
            acc=acc_list,
        ),
        expand(output_dir + "/merged_fastq/{acc}/", acc=acc_list),
        expand(tmp_dir + "/merged_lock/{acc}.lock", acc=acc_list),


# Add the download information for one accession to its JSON record.
# The record's "type" field selects the lookup helper (srx2link_md5 or
# gsm2link_md5); the helper's result is stored under the "ascp" key.
rule get_dict_with_ascp:
    input:
        tmp_dir + "/acc_json/{acc}.json",
    output:
        tmp_dir + "/acc_json_with_ascp/{acc}.json",
    run:
        with open(input[0]) as f:
            info = json.load(f)
        if info["type"] == "srx":
            info["ascp"] = srx2link_md5(wildcards.acc)
        elif info["type"] == "gsm":
            info["ascp"] = gsm2link_md5(wildcards.acc)
        else:
            # Fail here rather than writing a record without "ascp", which
            # would only surface as a KeyError in a later rule.
            raise ValueError(
                "accession type is not gsm or srx (please note the letter should be lower case)"
            )
        with open(output[0], "w") as f:
            json.dump(info, f)


# Pass the accession record through write_filename_to_dict and store the
# result as the next JSON stage.
rule get_ascp_filename:
    input:
        rules.get_dict_with_ascp.output[0],
    output:
        tmp_dir + "/acc_json_with_rename/{acc}.json",
    run:
        with open(input[0]) as src:
            renamed = write_filename_to_dict(json.load(src))
        with open(output[0], "w") as dst:
            json.dump(renamed, dst)


# Split the accession record into one JSON file per download link, grouped
# into mate sub-directories: R1 / R2 for paired-end files and R_ for
# single-end files. Declared as a checkpoint because the number of files is
# only known after this step has run.
checkpoint split_ascp_dict:
    input:
        rules.get_ascp_filename.output[0],
    output:
        directory(tmp_dir + "/ascp_split/{acc}"),
    run:
        with open(input[0]) as f:
            info = json.load(f)
        for link, ascpinfo in info["ascp"].items():
            ascpinfo["link"] = link
            orig_name = ascpinfo["orig_name"]
            # the mate is inferred from the original file name
            if "_1" in orig_name:
                ascpinfo["mate"] = "R1"
            elif "_2" in orig_name:
                ascpinfo["mate"] = "R2"
            elif "_" not in orig_name:
                ascpinfo["mate"] = "R_"
            else:
                raise ValueError(
                    f"mate is not R1 or R2, and not single end read: {orig_name}"
                )
            mate_dir = Path(output[0], ascpinfo["mate"])
            mate_dir.mkdir(exist_ok=True, parents=True)
            with open(mate_dir / f"{orig_name}.json", "w") as f:
                json.dump(ascpinfo, f)


# Download one R1 file described by a split JSON record, retrying up to
# `retry` times. The lock file is touched by Snakemake once the job succeeds.
# The run block is reused by download_for_r2 and download_for_single.
rule download_for_r1:
    input:
        tmp_dir + "/ascp_split/{acc}/R1/{orig_name}.json",
    output:
        temp(tmp_dir + "/unmerged_fastq/{acc}/R1/{orig_name}"),
        touch(tmp_dir + "/md5_lock/{acc}/R1/{orig_name}.lock"),
    run:
        # Read the job description once, outside the retry loop: a broken
        # JSON file will not fix itself on retry, and orig_name must be
        # bound for the failure message below.
        with open(input[0]) as f:
            info = json.load(f)
        link = info["link"]
        orig_name = info["orig_name"]
        outdir = str(Path(output[0]).parent)

        passed = False
        for i in range(retry):
            try:
                download_and_verify__(link, info, ascp_privkey, outdir)
            except Exception as e:
                # any failure of the download/verification helper is retried
                print(e, file = sys.stderr)
                print(f"retry {i+1}/{retry}", file = sys.stderr)
            else:
                passed = True
                break
        if passed:
            print(f"{orig_name} downloaded successfully", file = sys.stderr)
        else:
            raise Exception(f"{orig_name} failed to download after {retry} retries")


# Same run block as download_for_r1, applied to the R2 mate directory.
use rule download_for_r1 as download_for_r2 with:
    input:
        tmp_dir + "/ascp_split/{acc}/R2/{orig_name}.json",
    output:
        temp(tmp_dir + "/unmerged_fastq/{acc}/R2/{orig_name}"),
        touch(tmp_dir + "/md5_lock/{acc}/R2/{orig_name}.lock"),


# Same run block as download_for_r1, applied to the single-end (R_) directory.
use rule download_for_r1 as download_for_single with:
    input:
        tmp_dir + "/ascp_split/{acc}/R_/{orig_name}.json",
    output:
        temp(tmp_dir + "/unmerged_fastq/{acc}/R_/{orig_name}"),
        touch(tmp_dir + "/md5_lock/{acc}/R_/{orig_name}.lock"),


def get_downloaded_files(wildcards):
    """Input function for merge_files: list the files to download for one accession.

    Requests the output directory of the split_ascp_dict checkpoint for this
    accession, looks at which mate sub-directories it contains, and expands
    the download and md5-lock path templates over every ``{orig_name}.json``
    found there.

    Args:
        wildcards: Snakemake wildcards; ``acc`` is read directly and all
            wildcards are forwarded to the checkpoint lookup.

    Returns:
        A list of paths: the unmerged fastq files followed by their md5 lock
        files.

    Raises:
        ValueError: if the checkpoint directory holds neither exactly
            ``R1`` + ``R2`` (paired end) nor exactly ``R_`` (single end).
    """
    ascp_checkpoint_dir = checkpoints.split_ascp_dict.get(**wildcards).output[0]
    # Decide the layout from the directory names, not just their count.
    mates = sorted(Path(p).name for p in glob(f"{ascp_checkpoint_dir}/*"))

    def get_output_path(outpath_list):
        result = []
        for outpath in outpath_list:
            # the mate directory (R1, R2 or R_) is embedded in the template
            mate = re.search(r"R[\d_]", outpath)[0]
            orig_names = glob_wildcards(
                os.path.join(
                    tmp_dir, "ascp_split", wildcards.acc, mate, "{orig_name}.json"
                )
            ).orig_name
            result += expand(
                tmp_dir + outpath,
                acc=wildcards.acc,
                orig_name=orig_names,
            )
        return result

    if mates == ["R1", "R2"]:
        print(f"{wildcards.acc} is paired end, start downloading", file = sys.stderr)
        return get_output_path([
            "/unmerged_fastq/{acc}/R1/{orig_name}",
            "/unmerged_fastq/{acc}/R2/{orig_name}",
            "/md5_lock/{acc}/R1/{orig_name}.lock",
            "/md5_lock/{acc}/R2/{orig_name}.lock"
        ])
    if mates == ["R_"]:
        print(f"{wildcards.acc} is single end", file = sys.stderr)
        return get_output_path([
            "/unmerged_fastq/{acc}/R_/{orig_name}",
            "/md5_lock/{acc}/R_/{orig_name}.lock"
        ])
    raise ValueError(
        f"accession {wildcards.acc}: unexpected mate directories {mates} in "
        f"{ascp_checkpoint_dir}; expected R1 and R2, or R_ only"
    )


# Combine the per-run downloads of one accession into its final fastq file(s)
# named after the accession's "desp" field: <desp>_R1/_R2.fastq.gz for paired
# end, <desp>.fastq.gz for single end. A single run is moved into place,
# several runs are concatenated in sorted order.
rule merge_files:
    input:
        [get_downloaded_files],
        accdict=tmp_dir + "/acc_json/{acc}.json",
    output:
        directory(output_dir + "/merged_fastq/{acc}"),
        touch(tmp_dir + "/merged_lock/{acc}.lock")
    run:
        # everything except the trailing accdict entry and the md5 lock
        # markers is a downloaded fastq file
        fastq_paths = [p for p in input[:-1] if "md5_lock" not in p]
        basenames = [Path(p).name for p in fastq_paths]
        # an underscore in the file name marks a paired-end mate file
        has_underscore = ["_" in name for name in basenames]
        if len(set(has_underscore)) != 1:
            raise Exception(
                f"accession {wildcards.acc} the files are not of the same read type\n"
                + f"filename_list: {basenames},\n if_underscore_list: {has_underscore}"
            )

        with open(input.accdict) as f:
            desp = json.load(f)["desp"]
        paired_names = (desp + "_R1.fastq.gz", desp + "_R2.fastq.gz")
        single_name = desp + ".fastq.gz"


        def place(sources, newname):
            # one source: move it; otherwise concatenate the sources in order
            target = output[0] + "/" + newname
            if len(sources) == 1:
                shutil.move(sources[0], target)
                return
            with open(target, "wb") as wfd:
                for src in sources:
                    with open(src, "rb") as fd:
                        shutil.copyfileobj(fd, wfd)


        Path(output[0]).mkdir(exist_ok=True, parents=True)
        if not all(has_underscore):
            place(sorted(fastq_paths), single_name)
        else:
            r1_paths = sorted(p for p in fastq_paths if "_1" in Path(p).name)
            r2_paths = sorted(p for p in fastq_paths if "_2" in Path(p).name)
            if len(r1_paths) != len(r2_paths):
                raise Exception(
                    f"accession {wildcards.acc} r1 and r2 file numbers are not equal"
                )
            for sources, newname in zip((r1_paths, r2_paths), paired_names):
                place(sources, newname)
