#!/usr/bin/env python

import sys
import os
import pandas as pd
import numpy as np
from argparse import ArgumentParser, RawTextHelpFormatter

from zpca.zpca import counts2tpm, \
                      counts2tpm_many, \
                      determine_number_of_components, \
                      perform_pca, \
                      generate_scree_plot, \
                      generate_pc1_pc2_plot, \
                      generate_pc1_pc3_plot, \
                      generate_pc2_pc3_plot, \
                      generate_pca_3d

def _build_parser():
    """Build the command-line parser for the TPM-based PCA script."""
    parser = ArgumentParser(
        description=(
            "Perform PCA based on a TPM expression matrix "
            "(rows are genes/transcripts, columns are samples)."
        ),
        formatter_class=RawTextHelpFormatter,
    )

    parser.add_argument(
        "--tpm",
        dest="tpm",
        help="TPM table (tsv).",
        required=True,
        metavar="FILE",
    )

    # type=float lets argparse reject non-numeric values up front with a
    # usage error instead of a ValueError traceback later in the run.
    parser.add_argument(
        "--tpm-filter",
        dest="tpm_filter",
        type=float,
        default=0,
        help="Filter genes/transcripts with mean expression less than or equal to the provided filter. Default: 0",
        required=False,
    )

    parser.add_argument(
        "--tpm-pseudocount",
        dest="tpm_pseudocount",
        type=float,
        default=1,
        help="Pseudocount to add in the tpm table. Default: 1",
        required=False,
    )

    parser.add_argument(
        "--out",
        dest="out",
        help="Output directory",
        required=True,
        metavar="FILE",
    )

    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        dest="verbose",
        default=False,
        required=False,
        help="Verbose",
    )

    return parser


def _log(verbose, message):
    """Write a progress message to stdout when verbose output is enabled."""
    if verbose:
        sys.stdout.write(f"{message} {os.linesep}")


def main():
    """Run PCA on a TPM table and write the result tables and plots.

    Reads the tab-separated table given by ``--tpm`` (first column is the
    index, first row the header), keeps rows whose mean is strictly greater
    than ``--tpm-filter``, adds ``--tpm-pseudocount``, log2-transforms the
    values, and runs ``perform_pca`` with the number of components returned
    by ``determine_number_of_components``.

    Outputs written into the ``--out`` directory: ``PCA.tsv``,
    ``scree.tsv``, and the plots produced by the ``generate_*`` helpers
    (the PC3-based and 3D plots only when more than two components exist).

    Exit status:
        1 -- invoked without any arguments (full help is printed).
        0 -- only one component is available; empty ``PCA.tsv`` and
             ``scree.tsv`` files are written and a message goes to stderr.
        argparse's own usage error status for invalid arguments.
    """
    parser = _build_parser()

    # This check must come before parse_args(): with required options,
    # parse_args() exits on its own when nothing is passed, so the
    # friendlier full-help path would otherwise never be reached.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    options = parser.parse_args()
    verbose = options.verbose

    _log(verbose, f"Creating output directory: {options.out}")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(options.out, exist_ok=True)

    pca_path = os.path.join(options.out, "PCA.tsv")
    scree_path = os.path.join(options.out, "scree.tsv")

    _log(verbose, f"Reading: {options.tpm}")
    df = pd.read_csv(options.tpm, header=0, index_col=0, sep="\t")

    _log(
        verbose,
        f"Filtering transcripts/genes with mean expression less or equal to {options.tpm_filter}",
    )
    # float() is kept because argparse does not apply `type` to the
    # non-string (int) defaults.
    df = df[df.mean(axis=1) > float(options.tpm_filter)]

    _log(verbose, f"Adding TPM pseudocount of {options.tpm_pseudocount}")
    df = df + float(options.tpm_pseudocount)

    _log(verbose, "Log2 transform the data")
    df = np.log2(df)

    n_components = determine_number_of_components(df)

    if n_components == 1:
        # Still create (empty) output tables before exiting successfully.
        # NOTE(review): presumably so downstream steps find the expected
        # files -- confirm against the callers of this script.
        for placeholder in (pca_path, scree_path):
            with open(placeholder, "w"):
                pass
        sys.stderr.write(f"Too few samples for PCA {os.linesep}")
        sys.stderr.write(f"Exiting {os.linesep}")
        sys.exit(0)

    # The first returned value (scaled data) is not used by this script.
    _scaled_data, pca_data, per_var, labels = perform_pca(df, n_components)

    _log(verbose, f"Generating scree plot in: {options.out}")
    generate_scree_plot(options.out, per_var, labels)

    _log(
        verbose,
        f"Writing PCA data \"PCA.tsv\" and \"scree.tsv\" in: {options.out}",
    )
    # Rows of the PCA table are labelled with the input table's column names.
    pca_df = pd.DataFrame(pca_data, index=list(df.columns), columns=labels)
    explained_variance_df = pd.DataFrame(
        per_var,
        columns=["Percentage of Explained Variance"],
        index=labels,
    ).T
    pca_df.to_csv(pca_path, header=True, index=True, sep="\t")
    explained_variance_df.to_csv(scree_path, header=True, index=True, sep="\t")

    _log(verbose, f"Generating plots in: {options.out}")

    generate_pc1_pc2_plot(options.out, pca_df, per_var)
    if n_components > 2:
        generate_pc2_pc3_plot(options.out, pca_df, per_var)
        generate_pc1_pc3_plot(options.out, pca_df, per_var)
        generate_pca_3d(options.out, pca_df, per_var)

    _log(verbose, "Done")

# Script entry point: run the command-line interface, and on Ctrl-C report
# the interruption on stderr and exit with status 0 instead of a traceback.
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write(f"User interrupt!{os.linesep}")
        sys.exit(0)