#!/usr/bin/env python

import sys
import os
import pandas as pd
import numpy as np
from argparse import ArgumentParser, RawTextHelpFormatter

from zpca.zpca import counts2tpm, \
                      counts2tpm_many, \
                      determine_number_of_components, \
                      perform_pca, \
                      generate_scree_plot, \
                      generate_pc1_pc2_plot, \
                      generate_pc1_pc3_plot, \
                      generate_pc2_pc3_plot, \
                      generate_pca_3d

def _build_parser():
    """Build the command-line parser for the counts-based PCA script."""
    description = "Perform PCA based on an expression matrix (rows are genes/transcripts, columns are samples)."

    parser = ArgumentParser(
        description=description,
        formatter_class=RawTextHelpFormatter
    )

    parser.add_argument(
        "--counts",
        dest="counts",
        help="Counts table (tsv). The first column should contain the gene/transcript id. The other columns should contain the counts for each sample.",
        required=True,
        metavar="FILE"
    )

    parser.add_argument(
        "--lengths",
        dest="lengths",
        help="Table of feature lengths (tsv). " + os.linesep +
             "The file can have two types of formats." + os.linesep +
             "First option: The first column should contain the gene/transcript id." + os.linesep +
             "The second column should contain the corresponding lengths" + os.linesep +
             "Second option: The first column should contain the gene/transcript id." + os.linesep +
             "The rest of the columns should contain the gene/transcript lengths for each of the samples" + os.linesep +
             "Note that the sample names should be the same the sample names of the counts.",
        required=True,
        metavar="FILE"
    )

    parser.add_argument(
        "--pseudocount",
        dest="pseudocount",
        default=1,
        help="Pseudocount to add in the count table. Default: 1",
        required=False,
    )

    parser.add_argument(
        "--filter-not-expressed",
        dest="filter_not_expressed",
        default=False,
        action='store_true',
        help="Filter not expressed genes/transcripts (0 counts for all samples).",
        required=False,
    )

    parser.add_argument(
        "--out",
        dest="out",
        help="Output directory",
        required=True,
        metavar="DIRECTORY"
    )

    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        dest="verbose",
        default=False,
        required=False,
        help="Verbose"
    )

    return parser


def _log(verbose, message):
    """Write a progress message to stdout when verbose output is enabled."""
    if verbose:
        sys.stdout.write(f"{message} {os.linesep}")


def main():
    """Run PCA on log2-transformed TPM values derived from a counts table.

    Steps:
      1. Parse the command line (prints the help and exits with status 1
         when the script is invoked without any argument).
      2. Read the counts and feature-lengths tables (tab-separated, first
         column used as index).
      3. Optionally drop features whose counts sum to 0 over all samples.
      4. Add the pseudocount, convert to TPM and log2-transform.
      5. Run PCA and write "PCA.tsv", "scree.tsv" and the plots into the
         output directory.

    Exits through argparse (status 2) on invalid command-line input.
    """
    parser = _build_parser()

    # This must run before parse_args(): with required options, parse_args()
    # would otherwise exit with a terse usage error and never show the help.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    # argparse reports problems by exiting (SystemExit), so no try/except
    # is needed (or useful) around this call.
    options = parser.parse_args()
    verbose = options.verbose

    # Validate the pseudocount up front, before any file is read.
    try:
        pseudocount = float(options.pseudocount)
    except ValueError:
        parser.error(f"--pseudocount must be a number, got: {options.pseudocount!r}")

    _log(verbose, f"Creating output directory: {options.out}")
    os.makedirs(options.out, exist_ok=True)

    _log(verbose, f"Reading: {options.counts}")
    df = pd.read_csv(options.counts, header=0, index_col=0, sep="\t")

    _log(verbose, f"Reading: {options.lengths}")
    lengths = pd.read_csv(options.lengths, header=0, index_col=0, sep="\t")

    # The id column is consumed as the index, so a single-length table has
    # exactly one data column; anything wider is a per-sample lengths table.
    per_sample_lengths = lengths.shape[1] > 1

    if options.filter_not_expressed:
        _log(verbose, "Filtering not expressed genes/transcripts")
        df = df[df.sum(axis=1) > 0]

    _log(verbose, f"Adding a pseudocount of {options.pseudocount}")
    df = df + pseudocount

    _log(verbose, "Converting the data to TPM")
    if per_sample_lengths:
        tpm = counts2tpm_many(df, lengths)
    else:
        tpm = counts2tpm(df, lengths)

    _log(verbose, "Log2 transform the data")
    tpm = np.log2(tpm)

    _log(verbose, "Running PCA")
    # PCA is computed on the log2 TPM matrix, not on the raw counts.
    # NOTE(review): assumes counts2tpm/counts2tpm_many return a DataFrame
    # whose columns are the sample names - confirm against zpca.zpca.
    n_components = determine_number_of_components(tpm)
    scaled_data, pca_data, per_var, labels = perform_pca(tpm, n_components)

    _log(verbose, f"Generating scree plot in: {options.out}")
    generate_scree_plot(options.out, per_var, labels)

    _log(verbose, f"Writing PCA data \"PCA.tsv\" and \"scree.tsv\" in: {options.out}")
    # Rows are labelled with the column names of the matrix given to PCA.
    pca_df = pd.DataFrame(pca_data, index=[*tpm], columns=labels)
    explained_variance_df = pd.DataFrame(
        per_var,
        columns=["Percentage of Explained Variance"],
        index=labels
    ).T
    pca_df.to_csv(os.path.join(options.out, "PCA.tsv"), header=True, index=True, sep="\t")
    explained_variance_df.to_csv(os.path.join(options.out, "scree.tsv"), header=True, index=True, sep="\t")

    _log(verbose, f"Generating plots in: {options.out}")
    generate_pc1_pc2_plot(options.out, pca_df, per_var)
    # Plots involving PC3 are only produced when a third component exists.
    if n_components > 2:
        generate_pc2_pc3_plot(options.out, pca_df, per_var)
        generate_pc1_pc3_plot(options.out, pca_df, per_var)
        generate_pca_3d(options.out, pca_df, per_var)

    _log(verbose, "Done")

if __name__ == '__main__':
    # Script entry point: run the pipeline and turn Ctrl-C into a short
    # message on stderr (exit status 0) instead of a traceback.
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write(f"User interrupt!{os.linesep}")
        raise SystemExit(0)