import click


@click.group()
def tasks():
    """Homology search toolkit.

    Sub-commands encode fasta sequences into vectors, build faiss indices from
    those vectors, search query vectors against a database, and post-process
    the resulting tsv files.
    """


@tasks.command(help="Parse the tsv file generated by command 'search'. "
                    "If you want to output the result into a specified file, you may "
                    "use the redirection symbol \"> path\".")
@click.option("-i", "--input", type=str, help="tsv file path.", default=None)
def parse(input):
    """Parse a search-result tsv file with ``Results.format.parse_tsv``.

    The command help suggests redirecting stdout (``> path``) to save the
    result, so diagnostics are written to stderr to keep them out of the
    redirected output.
    """
    from .Results import format as fmt

    if not input:
        # err=True: a stdout message would end up inside the user's "> path" file.
        click.echo("Error: No tsv file input, please input a tsv file.", err=True)
        return

    fmt.parse_tsv(input)


@tasks.command(help="Convert tsv files to MSAs.")
@click.option("-i", "--input", type=str, help="tsv file path.", default=None)
@click.option("-o", "--output", type=str, help="the directory for storing results.", default=None)
@click.option("--align", type=bool, help="whether align sequences or not. default: 1(True)", default=True)
@click.option("--outfmt", type=click.Choice(['fasta', 'a3m']), help="output result in specified format. "
                                                                    "default: a3m", default='a3m')
def tsv2msa(input, output, align, outfmt):
    """Turn a search-result tsv file into MSA files.

    All arguments are forwarded unchanged to ``Results.format.tsv_to_msa``;
    ``outfmt`` is restricted by click to ``'fasta'`` or ``'a3m'``.
    """
    from .Results import format as results_format

    if input:
        results_format.tsv_to_msa(input, output, align, outfmt)
    else:
        click.echo("Error: No tsv file input, please input a tsv file.")


def _aggregate_sub_results(temp_outputs, q_info_path, max_num):
    """Merge the per-sub-database result files of ``standard_search``.

    Every existing file in ``temp_outputs`` is read as a tsv and then deleted.
    Hits are grouped by ``query_id``, sorted by ``distance`` (ascending) and,
    when ``max_num`` > 0, truncated to ``max_num`` rows per query. Queries are
    emitted in the order of the ``id`` column of the ``q_info_path`` tsv.

    Returns:
        The merged DataFrame, or ``None`` when no hits were found at all.
    """
    import os
    import pandas as pd

    frames = []
    for path in temp_outputs:
        if not os.path.exists(path):
            continue
        frames.append(pd.read_csv(path, sep='\t'))
        os.remove(path)

    if not frames:
        return None

    merged = pd.concat(frames, axis=0)
    # One groupby pass instead of re-filtering the whole frame once per query id.
    hits_by_query = {
        key: group.sort_values(by='distance', kind='stable')
        for key, group in merged.groupby('query_id', sort=False)
    }

    ordered = []
    for query_id in pd.read_csv(q_info_path, sep='\t')['id']:
        hits = hits_by_query.get(query_id)
        if hits is None:
            continue
        # Every sub-database may contribute up to max_num hits, so the cap has
        # to be re-applied to the merged list to honour the per-query limit.
        if max_num > 0:
            hits = hits.head(max_num)
        ordered.append(hits)

    return pd.concat(ordered, axis=0) if ordered else None


@tasks.command(help="Search query sequences in standard UniRef30 database. The database is divided into several parts "
                    "based on sequence length because of memory limitation. The search will be executed on every "
                    "sub-database and all results will finally be aggregated into one.")
@click.option("--q_vec_path", type=str, help="npy path of query sequences.", default=None)
@click.option("--q_info_path", type=str, help="information of query sequences.", default=None)
@click.option("--db_path", type=str, help="database path", default=None)
@click.option("-o", "--output", type=str, help="output result to specified file. input related path will be used "
                                               "by default.", default=None)
@click.option("-t", "--threshold", type=int, help="threshold that the model use to detect homologies. default: 1",
              default=1)
@click.option("--nprobe", type=int, help="search the first n clusters. if you want to search all clusters, "
                                         "set this to 0. default: 1", default=1)
@click.option("--max_num", type=int, help="max num limitation of homologies for one query sequence. if you want no "
                                          "limitation, set this to 0. default: 500", default=500)
@click.option("-v", "--verbose", type=bool, help="display details. default: 1(True)", default=True)
def standard_search(q_vec_path, q_info_path, db_path, output, threshold, nprobe, max_num,
                    verbose):
    """Search queries against every length-bucketed part of the standard database.

    For each bucket name ``{i}_{i + 100}`` (i = 0, 100, ..., 1000) the files
    ``{db_path}_{name}.index`` and ``{db_path}_{name}_info.tsv`` are searched
    with ``pipeline``; missing indices are skipped with a warning. The
    per-bucket results are merged into ``<output>_result.tsv`` (or
    ``<q_vec_path without .npy>_result.tsv`` when ``-o`` is omitted).
    """
    import os
    from .utils import TimeCounter
    from .search import pipeline

    var_dict = {'q_vec_path': q_vec_path,
                'q_info_path': q_info_path,
                'db_path': db_path}

    for k, v in var_dict.items():
        if not v:
            click.echo(f"Error: No {k} input, please input {k}.")
            return

    if output:
        output = output + '_result.tsv'
    else:
        # Strip only a trailing ".npy"; str.replace left extension-less paths
        # unchanged, which made the result overwrite the query vector file.
        stem = q_vec_path[:-len(".npy")] if q_vec_path.endswith(".npy") else q_vec_path
        output = stem + '_result.tsv'

    # Sub-database suffixes: sequence-length ranges 0_100, 100_200, ..., 1000_1100.
    db_names = [f"{i}_{i + 100}" for i in range(0, 1100, 100)]
    aggregation_list = []
    for name in db_names:
        db_index_path = f"{db_path}_{name}.index"
        db_info_path = f"{db_path}_{name}_info.tsv"
        temp_output = output + name

        if not os.path.exists(db_index_path):
            click.echo(f"Warning: index '{db_index_path}' doesn't exist, so ignored. It means that the model will not "
                       f"search sequences of length {name}, which may affect the accuracy of final result.")
            continue

        pipeline(q_vec_path, db_index_path, q_info_path, db_info_path, temp_output,
                 t=threshold, nprobe=nprobe, max_num=max_num, verbose=verbose)
        aggregation_list.append(temp_output)

    with TimeCounter("Aggregating results...", verbose):
        result = _aggregate_sub_results(aggregation_list, q_info_path, max_num)
        if result is None:
            # pd.concat([]) used to raise here; report instead of crashing.
            click.echo("Warning: no homologies were found for any query, so no result file was written.")
            return
        result.to_csv(output, sep='\t', index=False)


@tasks.command(help="Search query sequences in user customized database")
@click.option("--q_vec_path", type=str, help="npy path of query sequences.", default=None)
@click.option("--q_info_path", type=str, help="information of query sequences.", default=None)
@click.option("--db_index_path", type=str, help="faiss index of database", default=None)
@click.option("--db_info_path", type=str, help="information of sequences in database.", default=None)
@click.option("-o", "--output", type=str, help="output result to specified file. input related path will be used "
                                               "by default.", default=None)
@click.option("-t", "--threshold", type=int, help="threshold that the model use to detect homologies. default: 1",
              default=1)
@click.option("--nprobe", type=int, help="search the first n clusters. if you want to search all clusters, "
                                         "set this to 0. default: 1", default=1)
@click.option("--max_num", type=int, help="max num limitation of homologies for one query sequence. if you want no "
                                          "limitation, set this to 0. default: 500", default=500)
@click.option("-v", "--verbose", type=bool, help="display details. default: 1(True)", default=True)
def customized_search(q_vec_path, q_info_path, db_index_path, db_info_path, output, threshold, nprobe, max_num,
                      verbose):
    """Search query vectors against a single user-provided faiss index.

    All four input paths are required. The result is written to
    ``<output>_result.tsv``, or ``<q_vec_path without .npy>_result.tsv`` when
    ``-o`` is omitted.
    """
    from .search import pipeline

    var_dict = {'q_vec_path': q_vec_path,
                'q_info_path': q_info_path,
                'db_index_path': db_index_path,
                'db_info_path': db_info_path}
    for k, v in var_dict.items():
        if not v:
            click.echo(f"Error: No {k} input, please input {k}.")
            return

    if output:
        output = output + '_result.tsv'
    else:
        # Strip only a trailing ".npy"; str.replace left extension-less paths
        # unchanged, which made the result overwrite the query vector file.
        stem = q_vec_path[:-len(".npy")] if q_vec_path.endswith(".npy") else q_vec_path
        output = stem + '_result.tsv'

    pipeline(q_vec_path, db_index_path, q_info_path, db_info_path, output,
             t=threshold, nprobe=nprobe, max_num=max_num, verbose=verbose)


@tasks.command(help="Convert sequences into vectors")
@click.option("-i", "--input", type=str, help="fasta path.", default=None)
@click.option("-o", "--output", type=str, help="output result to specified file. input related path will be used "
                                               "by default.", default=None)
@click.option("--devices", type=str, help="choose which GPUs will be used. e.g. '0,1,2,3'. if not specified, all GPUs "
                                          "will be used by default. if no GPU available, please set this to 'cpu'",
              default=None)
@click.option("--batch", type=int, help="batch size adopted. default: 128", default=128)
@click.option("--model", type=str, help="model parameter file. e.g. 'your_path/model_name.pt'", default=None)
@click.option("--cover", type=bool, help="whether cover the output file when it already exists. default: 1(True)",
              default=True)
@click.option("-v", "--verbose", type=bool, help="display details. default: 1(True)", default=True)
def build(input, output, devices, model, batch, cover, verbose):
    """Encode the sequences of a fasta file into vectors with ``fasta2vec``.

    ``devices`` may be omitted (every visible CUDA device is used, or the CPU
    when none is visible), ``'cpu'`` (case-insensitive), or a comma-separated
    list of GPU ids such as ``'0,1'``. When ``-o`` is omitted the fasta path
    itself is passed as ``output``; presumably ``fasta2vec`` derives the
    result file name from it — confirm in ``vector_construction``.
    """
    import torch
    from .vector_construction import fasta2vec

    if not input:
        click.echo("Error: No fasta file input, please input a fasta file.")
        return

    if not model:
        click.echo("Error: No model path input, please input a model path.")
        return

    if not devices:
        cnt = torch.cuda.device_count()
        devices = list(range(cnt)) if cnt != 0 else ['cpu']
    elif devices.strip().lower() == 'cpu':
        devices = ['cpu']
    else:
        raw_devices = devices
        try:
            # Skip empty items so inputs like "0,1," or "0, 1" are accepted.
            devices = [int(d) for d in raw_devices.split(',') if d.strip()]
        except ValueError:
            devices = []
        if not devices:
            click.echo(f"Error: invalid devices '{raw_devices}', please input comma-separated GPU ids "
                       f"(e.g. '0,1') or 'cpu'.")
            return

    if not output:
        output = input

    fasta2vec(devices, model, input, output, batch_size=batch, cover=cover, verbose=verbose)


@tasks.command(help="Construct indices of given vectors.")
@click.option("-i", "--input", type=str, help="npy path.", default=None)
@click.option("-o", "--output", type=str, help="specified index path. input related path will be used "
                                               "by default.", default=None)
@click.option("--param", type=str, help="faiss index parameters. for more information, please check "
                                        "https://github.com/facebookresearch/faiss/wiki/Faiss-indexes. "
                                        "default: IVFx, SQ4", default=None)
@click.option("-v", "--verbose", type=bool, help="display details. default: 1(True)", default=True)
def vec2index(input, output, param, verbose):
    """Build an L2 faiss index from the vectors stored in a ``.npy`` file.

    The index is written to ``<output>.index``, or to the input path with its
    trailing ``.npy`` replaced by ``.index`` when ``-o`` is omitted. Without
    ``--param`` the factory string ``'IVF{n}, SQ4'`` is used, where
    ``n = ceil(num_vectors / 256)``.
    """
    import faiss
    import numpy as np
    from math import ceil
    from .utils import TimeCounter
    from .Faiss.index_construction import contruct_faiss_index

    if not input:
        click.echo("Error: No npy file input, please input a npy file.")
        return

    if output:
        output = output + '.index'
    else:
        # Strip only a trailing ".npy"; str.replace left extension-less paths
        # unchanged, which made the index overwrite the input vectors.
        stem = input[:-len(".npy")] if input.endswith(".npy") else input
        output = stem + '.index'

    vectors = np.load(input)
    if vectors.ndim != 2 or vectors.shape[0] == 0:
        # An empty array would otherwise produce the invalid factory string "IVF0, SQ4".
        click.echo(f"Error: expected a non-empty 2-D array of vectors in '{input}', got shape {vectors.shape}.")
        return

    # Take the dimension from the data instead of hard-coding 1280, so vectors
    # produced by models with other embedding sizes can be indexed too.
    dim, measure = vectors.shape[1], faiss.METRIC_L2
    if not param:
        n = ceil(vectors.shape[0] / 256)
        param = f'IVF{n}, SQ4'

    with TimeCounter("Constructing index...", verbose):
        contruct_faiss_index(vectors, output, dim, measure, param)


if __name__ == '__main__':
    # Script entry point: hand argv to the click group's main loop.
    tasks.main()

