#!/usr/bin/env python3

#Python Modules
import numpy as np
from subprocess import run, PIPE
import os
import sys
from collections import defaultdict
import csv

#Modules from this package
from stepRNA.general import mini_maxi, replace_ext, check_dir
from stepRNA.processing import make_unique, rm_ref_matches
from stepRNA.commands import left_overhang, right_overhang
from stepRNA.output import make_csv, make_type_csv, write_to_bam, print_hist, refs_counts, Logger

#Scripts to be used...
import stepRNA.remove_exact as remove_exact
import stepRNA.make_unique as make_unique
import stepRNA.stepRNA_run_bowtie as run_bowtie
import stepRNA.index_bowtie as index_bowtie
import stepRNA.stepRNA_cigar_process as cigar_process

#Modules that need to be installed
try:
    from Bio import SeqIO
except ImportError:
    print('Error: Biopython not found, can be installed with\npip3 install biopython', file=sys.stderr)
    sys.exit(1)

try:
    import pysam
except ImportError:
    print('Error: Pysam not found, can be installed with\npip3 install pysam', file=sys.stderr)
    sys.exit(1)

#Set-up arguments...
from argparse import ArgumentParser, SUPPRESS

# Command-line interface definition.
# argparse's built-in -h/--help is switched off (add_help=False) and added
# back by hand so that it is listed inside the 'Optional Arguments' group.
parser = ArgumentParser(
    description=(
        'Align a reference RNA file to read sequences.\n'
        ' Output will be a set of CSV files containing information about the'
        ' length of the reads, number of reads aligned to a reference sequence'
        ' and the length of overhangs of the alignment. \n'
        ' Reference RNA file will be automatically indexed'
    ),
    add_help=False,
)

# Groups appear in the help output in the order they are created here.
optional = parser.add_argument_group('Optional Arguments')
required = parser.add_argument_group('Required Arguments')
flags = parser.add_argument_group('Flags')

#Add back help...
optional.add_argument(
    '-h',
    '--help',
    action='help',
    default=SUPPRESS,
    help='show this help message and exit',
)

# Input files (both mandatory).
required.add_argument(
    '-r', '--reference',
    help='Path to the reference sequences',
    required=True,
)
required.add_argument(
    '-q', '--reads',
    help='Path to the read sequences',
    required=True,
)

# Output naming/location and alignment threshold.
optional.add_argument(
    '-n', '--name',
    help='Prefix for the output files',
)
optional.add_argument(
    '-d', '--directory',
    default=os.curdir,
    help='Directory to store the output files',
)
# -1 is the "not set" value; the help text describes what happens in that case.
optional.add_argument(
    '-m', '--min_score',
    default=-1,
    type=int,
    help='Minimum score to accept, default is the shortest read length',
)

# Boolean switches.
flags.add_argument(
    '-e', '--remove_exact',
    action='store_true',
    help='Remove exact read matches to the reference sequence',
)
flags.add_argument(
    '-u', '--make_unique',
    action='store_true',
    help='Make FASTA headers unique in reference and reads i.e. >Read_1 >Read_2',
)
flags.add_argument(
    '-j', '--write_json',
    action='store_true',
    help='Write count dictionaries to a JSON file',
)
flags.add_argument(
    '-V', '--version',
    action='version',
    version='stepRNA v1.0.0',
    help='Print version number then exit.',
)

#parser._action_groups.append(optional)
#parser._action_groups.append(flags)

args = parser.parse_args()

# Parse arguments...
ref = args.reference
reads = args.reads
min_score = args.min_score
outdir = check_dir(args.directory)
# Without an explicit --name, reuse the reads path minus its extension.
filename = os.path.splitext(reads)[0] if args.name is None else args.name

#Join together output directory and filename to make a prefix...
# Only the basename of `filename` is kept, so every output lands in outdir.
prefix = os.path.join(outdir, os.path.basename(filename))
logger = Logger(prefix + '.log')
logger.write('Output to: {}'.format(outdir))

#Remove exact matches to reference if set...
# `reads` is rebound to whatever remove_exact.main returns, so the later
# steps operate on its result instead of the file given on the command line.
if args.remove_exact:
    logger.write('Removing exact matches to read FASTA')
    reads = remove_exact.main(ref, reads)
    logger.write('Exact matches removed')

#Make unique headers if set...
# Both inputs are rebound to the values returned by make_unique.main.
if args.make_unique:
    logger.write('Making read headers unique...')
    reads = make_unique.main(reads, 'fasta', name='Read')
    logger.write('Unique read headers complete')
    logger.write('Making reference headers unique...')
    ref = make_unique.main(ref, 'fasta', name='Ref')
    logger.write('Unique reference headers complete')
    logger.write('Unique FASTA headers made')


#Build a reference (suppress verbosity)...
logger.write('Building index...')
ref_base = index_bowtie.main(ref)
logger.write('Bowtie index built')

# Run bowtie alignment...
logger.write('Aligning...')
sorted_bam = run_bowtie.main(ref_base, reads, prefix, min_score, logger)
logger.write('Alignment completed')

#Cigar process...
# `prefix` already starts with outdir, so it must not be joined with outdir a
# second time: for a relative output directory that would point at
# <outdir>/<outdir>/..., while for an absolute one the result is unchanged.
# NOTE(review): assumes cigar_process.main writes into
# prefix + '_AlignmentFiles' (it is only handed `prefix`) - confirm.
fpath = prefix + '_AlignmentFiles'
if os.path.isdir(fpath):
    # Clear files left by a previous run; removal is best-effort.
    logger.write('Removing contents in {}'.format(fpath))
    for f in os.listdir(fpath):
        try:
            os.remove(os.path.join(fpath, f))
        except OSError:
            # Report through logger.write, the method used everywhere else in
            # this script, and carry on with the remaining files.
            logger.write('Could not remove {}'.format(f))

logger.write('Processing Cigar strings...')
right_dic, left_dic, type_dic, read_len_dic, refs_read_dic = cigar_process.main(sorted_bam, prefix, args.write_json)
logger.write('Cigar strings processed')

# Count unique references
# One entry per overhang length, filled from the files in the alignment
# directory; lengths with no file default to 0.
right_unique_dic = defaultdict(int)
left_unique_dic = defaultdict(int)
# `prefix` already contains outdir, so it is not joined with outdir again.
fpath = prefix + '_AlignmentFiles'
for f in os.listdir(fpath):
    # Files with 'passed' in their name are skipped.
    if 'passed' in f:
        continue
    # The file name is split on '_': the second-to-last field is the integer
    # overhang length and the third-to-last names the end (5prime/3prime).
    fields = f.split('_')
    key = int(fields[-2])
    side = fields[-3]
    if '5prime' in side:
        left_unique_dic[key] = refs_counts(os.path.join(fpath, f), unique=True)
    if '3prime' in side:
        right_unique_dic[key] = refs_counts(os.path.join(fpath, f), unique=True)



#Write the overhang information to CSV files, one section heading per table...
# NOTE(review): both make_csv calls pass the dictionaries as [right, left],
# yet the first labels the columns '5prime','3prime' and the second
# '3prime','5prime'. The unique counts above are filled with 5prime -> left
# and 3prime -> right, so one of the two header orders is probably swapped -
# confirm against make_csv / cigar_process before changing either.
overhang_counts = [right_dic, left_dic]
unique_overhang_counts = [right_unique_dic, left_unique_dic]

logger.write('\n## Overhang counts ##')
make_csv(
    overhang_counts,
    prefix + '_overhang.csv',
    ['Overhang','5prime','3prime'],
)

logger.write('\n## Unique overhang counts ##')
make_csv(
    unique_overhang_counts,
    prefix + '_unique_overhang.csv',
    ['Overhang','3prime','5prime'],
)

logger.write('\n## Overhang types ##')
make_type_csv(
    type_dic,
    prefix + '_overhang_type.csv',
    ['Classification', 'count'],
)

logger.write('\n## Read lengths ##')
make_type_csv(
    read_len_dic,
    prefix + '_passenger_length.csv',
    ['sRNA_read', 'passenger_count'],
    sort=True,
)
# This last table is written with show=False.
make_type_csv(
    refs_read_dic,
    prefix + '_passenger_number.csv',
    ['Passenger_length', 'number'],
    show=False,
)
# Blank line on stdout after the tables.
print()

def make_hist(csv_in):
    """Print percentage histograms for the two count columns of a CSV file.

    The first row of the file is skipped as a header. Every following row
    contributes its first field as a label and its second and third fields
    as integer counts. Each count column is rescaled to percentages of that
    column's total and handed to ``print_hist``; the second field is reported
    under the 'LHS' heading and the third under the 'RHS' heading.

    Parameters
    ----------
    csv_in : str
        Path to a comma-delimited file with at least three columns.

    Raises
    ------
    ValueError
        If a count field cannot be converted to ``int``.
    IndexError
        If a non-empty row has fewer than three fields.
    """
    keys = []
    left_counts = []
    right_counts = []
    # newline='' is what the csv module asks for when opening files to read.
    with open(csv_in, newline='') as summary:
        csv_reader = csv.reader(summary, delimiter=',')
        # Discard the header; the default keeps an empty file from raising.
        next(csv_reader, None)
        for row in csv_reader:
            if not row:
                # csv.reader yields [] for blank lines; nothing to record.
                continue
            keys.append(row[0])
            left_counts.append(int(row[1]))
            right_counts.append(int(row[2]))

    def percentages(counts):
        # A column that sums to zero (no rows, or all-zero counts) is reported
        # as all-zero percentages instead of dividing by zero.
        total = sum(counts)
        if not total:
            return [0.0 for _ in counts]
        return [100 * count / total for count in counts]

    left_dens = percentages(left_counts)
    right_dens = percentages(right_counts)

    #Print histogram of overhangs to terminal...
    logger.write('LHS Overhang Histogram')
    print_hist(left_dens, keys)

    logger.write('RHS Overhang Histogram')
    print_hist(right_dens, keys)

# Log a heading and print the histograms, first for the all-reads table and
# then for the unique-reads table written earlier in the script.
for heading, suffix in (
    ('\nAll aligned reads', '_overhang.csv'),
    ('\nUnique aligned reads', '_unique_overhang.csv'),
):
    logger.write(heading)
    make_hist(prefix + suffix)
