#!/usr/bin/env python

# Argument handling

import argparse
import pkg_resources
import sys

# Top-level parser; the description embeds the installed package version.
parser = argparse.ArgumentParser(description=(
    "Fast protein structure searching using structure graph embeddings. "
    "See https://github.com/jgreener64/progres for documentation and citation information. "
    f"This is version {pkg_resources.get_distribution('progres').version} of the software."
))
# The chosen subcommand name is stored in args.mode (None if none was given).
subparsers = parser.add_subparsers(dest="mode",
    help="the mode to run progres in, run \"progres {mode} -h\" to see help for each")

# "search" subcommand: query one or more structures against an embedded database.
parser_search = subparsers.add_parser("search",
    description="Search one or more queries against a pre-embedded database.",
    help="search one or more queries against a pre-embedded database")
# -q and -l are alternative ways of supplying the queries. A mutually exclusive
# group makes argparse reject giving both instead of one being silently ignored.
# Requiring at least one of them is checked after parsing.
query_group = parser_search.add_mutually_exclusive_group()
query_group.add_argument("-q", "--querystructure",
    help="query structure file in PDB/mmCIF/MMTF/coordinate format")
query_group.add_argument("-l", "--querylist",
    help="text file with one query file path per line")
parser_search.add_argument("-t", "--targetdb", required=True,
    help=("pre-embedded database to search against, either "
          "\"scope95\", \"scope40\", \"cath40\", \"ecod70\" or a file path"))
parser_search.add_argument("-f", "--fileformat",
    choices=["guess", "pdb", "mmcif", "mmtf", "coords"], default="guess",
    help="file format of the query structure(s), by default guessed from the file extension")
# Range checks for -s and -m are done after parsing.
parser_search.add_argument("-s", "--minsimilarity", type=float, default=0.8,
    help="similarity threshold above which to return hits, default 0.8")
parser_search.add_argument("-m", "--maxhits", type=int, default=100,
    help="maximum number of hits to return, default 100")
parser_search.add_argument("-d", "--device", default="cpu",
    help="device to run on, default is \"cpu\"")

# "embed" subcommand: embed a list of structures into a file that can later be
# passed to "search" via -t.
parser_embed = subparsers.add_parser("embed",
    description="Embed a dataset of structures to allow it to be searched against.",
    help="embed a dataset of structures to allow it to be searched against")
parser_embed.add_argument("-l", "--structurelist", required=True,
    help="text file with file path, domain name and optional note per line")
parser_embed.add_argument("-o", "--outputfile", required=True,
    help="output file path for the PyTorch file containing the embeddings")
parser_embed.add_argument("-f", "--fileformat",
    choices=["guess", "pdb", "mmcif", "mmtf", "coords"], default="guess",
    help="file format of the structures, by default guessed from the file extension")
parser_embed.add_argument("-d", "--device", default="cpu",
    help="device to run on, default is \"cpu\"")

args = parser.parse_args()

# Dispatch on the selected subcommand. Invalid or missing options are reported
# through argparse's error(), which prints the usage line plus the message to
# stderr and exits with status 2. This avoids both a raw traceback and a
# silent exit status of 0.
if args.mode == "search":
    # Written as "not 0 <= x <= 1" so that NaN, which fails every comparison,
    # is rejected too.
    if not 0 <= args.minsimilarity <= 1:
        parser_search.error("minsimilarity must be between 0 and 1")
    if args.maxhits < 1:
        parser_search.error("maxhits must be a positive integer")
    if args.querystructure and args.querylist:
        parser_search.error("Only one of -q and -l can be given for structural searching")
    if args.querystructure:
        query = {"querystructure": args.querystructure}
    elif args.querylist:
        query = {"querylist": args.querylist}
    else:
        parser_search.error("One of -q and -l must be given for structural searching")
    # Imported only after validation so bad arguments fail without loading the
    # full package.
    from progres import progres_search_print
    progres_search_print(**query, targetdb=args.targetdb, fileformat=args.fileformat,
                         minsimilarity=args.minsimilarity, maxhits=args.maxhits,
                         device=args.device)
elif args.mode == "embed":
    from progres import progres_embed
    progres_embed(structurelist=args.structurelist, outputfile=args.outputfile,
                  fileformat=args.fileformat, device=args.device)
else:
    parser.error("No mode selected, run \"progres -h\" to see help")
