#!/usr/bin/env python
#-*- coding: utf-8 -*-
import warnings
warnings.filterwarnings("ignore")
import sys
import argparse
import datetime
import getpass
import os
from dna_features_viewer import GraphicFeature, GraphicRecord
import pandas as pd
import uuid
import subprocess
import matplotlib.pyplot as plt
# NPG-style palette; the first four entries are assigned to the feature types below.
npg_colors = [
	"#E64B35",
	"#4DBBD5",
	"#00A087",
	"#3C5488",
	"#F39B7F",
	"#464d4f",
]
# Colour used for each feature type drawn on the record (zip stops after the four labels).
my_colors = dict(zip(("sgRNA", "PBS", "RTT", "ngRNA"), npg_colors))
# Highlight colour for the variant position.
my_colors["variant"] = "#e6fc3f"
"""

Output
--------

The output folder will contain:
1. all pegRNA + ngRNA combinations for the input vcf file
2. the top1 pegRNA + ngRNA combination for each variant
3. visualization of the top1s [TODO]
4. a summary file of each variant

"""

def my_args():
	"""Parse the command-line arguments of this script.

	Output
	--------
	argparse.Namespace with three attributes:
	rawX: path given with -f/--rawX (required)
	genome_fasta: path given with -s/--genome_fasta (required)
	output: output directory; defaults to easy_prime_vis_<user>_<today's date>
	"""
	mainParser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,description="easy_prime for pegRNA design")
	# the current user name and date make the default output folder unique per user and day
	username = getpass.getuser()
	default_output = f"easy_prime_vis_{username}_{datetime.date.today()}"

	mainParser.add_argument('-f','--rawX',  help="input rawX format, last column with predicted efficiency",required=True)
	# the help text previously described a YAML parameter file, which this option is not
	mainParser.add_argument('-s','--genome_fasta',  help="genome fasta file used to extract the reference sequence",required=True)
	mainParser.add_argument('-o','--output',  help="output dir",default=default_output)

	##------- add parameters above ---------------------
	return mainParser.parse_args()

def write_file(file_name,message):
	"""Write a text message to a file, replacing any existing content.

	Input
	--------
	file_name: path of the file to create or overwrite
	message: string written verbatim (no newline is appended)
	"""
	# the context manager closes the handle even when write() raises
	with open(file_name,"wt") as out:
		out.write(message)
def get_fasta_single(chr,start,end,genome_fasta=None):
	"""Extract the sequence of one genomic interval with ``bedtools getfasta``.

	Input
	--------
	chr: chromosome / contig name written to the BED record
	start, end: interval coordinates, written as-is to a one-line BED record
	genome_fasta: path of the fasta file handed to ``bedtools getfasta -fi``

	Output
	--------
	the last whitespace-separated field of the first line of the ``-tab`` output,
	i.e. the extracted sequence

	Raises
	--------
	ValueError: genome_fasta was not given
	subprocess.CalledProcessError: bedtools exited with a non-zero status
	RuntimeError: bedtools wrote no sequence for the requested interval
	"""
	import tempfile  # local import: only this function needs it

	if genome_fasta is None:
		raise ValueError("genome_fasta is required to extract a sequence")

	# a private temporary directory is removed automatically, even on failure,
	# and keeps the intermediate files out of the working directory
	with tempfile.TemporaryDirectory() as tmp_dir:
		out_bed = os.path.join(tmp_dir,"region.bed")
		out_fa = os.path.join(tmp_dir,"region.tab")
		with open(out_bed,"wt") as bed:
			bed.write("%s\t%s\t%s"%(chr,start,end))
		# argument list without a shell: paths with spaces or shell
		# metacharacters are passed to bedtools unchanged
		command = ["bedtools","getfasta","-fi",str(genome_fasta),"-bed",out_bed,"-fo",out_fa,"-tab"]
		subprocess.run(command,check=True)
		with open(out_fa) as fa:
			first_line = fa.readline()

	fields = first_line.split()
	if not fields:
		raise RuntimeError("bedtools getfasta returned no sequence for %s:%s-%s"%(chr,start,end))
	return fields[-1]
def get_strand(x):
	"""Translate a strand symbol into a numeric orientation.

	Returns 1 when x equals "+", and -1 for every other value.
	"""
	return 1 if x == "+" else -1

def plot_main(df,output=None,genome_fasta=None,**kwargs):
	"""Given one instance of easy-prime prediction (rawX format), generate DNA visualization

	Input
	--------
	df: the data frame contains 4 rows: RTT, PBS, sgRNA, ngRNA, indexed by the
		pegRNA id. Columns read here: CHROM, start, end, strand, type, POS,
		REF, ALT and predicted_efficiency.
	output: directory in which <pegRNA id>.pdf is saved
	genome_fasta: genome fasta handed to get_fasta_single
	kwargs: ignored; allows callers to pass extra options (e.g. **vars(args))

	"""
	pegRNA_id = df.index.tolist()[0]
	variant_id = pegRNA_id.split("_")[0]
	# .iloc[0] is explicitly positional; the frame is indexed by pegRNA id, so a
	# plain [0] is not a label and only works through a deprecated fallback
	chrom = df['CHROM'].iloc[0]
	# plotting window: start is rounded down to a multiple of 10 and shifted by
	# -1, end is rounded down to a multiple of 10 and extended by 10
	start = df['start'].min()
	start -= start%10
	start -= 1
	end = df['end'].max()
	end -= end%10
	end += 10

	variant_pos = df['POS'].min()
	ref = df['REF'].iloc[0]
	alt = df['ALT'].iloc[0]
	predicted_efficiency = df['predicted_efficiency'].iloc[0]*100
	# variant position relative to the plotting window
	pos = variant_pos-start
	sequence = get_fasta_single(chrom,start,end,genome_fasta).upper()

	# one graphic feature per row, in window-relative coordinates
	feature_list = [
		GraphicFeature(start=r['start']-start, end=r['end']-start,
			strand=get_strand(r['strand']),
			color=my_colors[r['type']],label=r['type'])
		for _,r in df.iterrows()
	]
	record = GraphicRecord(sequence=sequence, features=feature_list)

	ax, _ = record.plot(figure_width=int(len(sequence)/5))
	fig = ax.figure
	try:
		record.plot_sequence(ax)
		# highlight the variant column across the whole height of the axes
		ax.fill_between((pos-1.5, pos-0.5), +1000, -1000, alpha=0.5,color=my_colors['variant'])
		# relabel the ticks with genomic coordinates; only the first label
		# carries the chromosome name
		locs = ax.get_xticks()
		new_labels = []
		for n,loc in enumerate(locs):
			coord = int(start+loc+1)
			new_labels.append("%s %s"%(chrom,coord) if n == 0 else coord)
		ax.set_xticks(locs)
		ax.set_xticklabels(new_labels)
		ax.set_title("ID: %s, CHR: %s, POS: %s, REF: %s, ALT: %s \n Predicted efficiency: %.1f"%(variant_id,chrom,variant_pos,ref,alt,predicted_efficiency)+"%")
		fig.savefig(f'{output}/{pegRNA_id}.pdf', bbox_inches='tight')
	finally:
		# pyplot keeps every figure alive until it is closed; without this the
		# figures pile up when the function is called once per pegRNA
		plt.close(fig)


def main():
	"""Command-line entry point: save one PDF per unique pegRNA id of the rawX file.

	Exits with an error message when the genome fasta does not exist.
	"""
	args = my_args()
	if not os.path.isfile(args.genome_fasta):
		# nothing can be plotted without the reference sequence, so stop here
		# instead of failing later for every row
		sys.exit(f"genome fasta NOT FOUND: {args.genome_fasta}")
	df = pd.read_csv(args.rawX,index_col=0)
	# no shell involved, so output paths with spaces or metacharacters are safe
	os.makedirs(args.output,exist_ok=True)
	for i in df.index.unique().tolist():
		print (f"Processing {i}...")
		# the list selector always yields a DataFrame, even when the id has a
		# single row (df.loc[i] would return a Series in that case)
		plot_main(df.loc[[i]],**vars(args))


# Run the command-line workflow only when executed as a script, not on import.
if __name__ == "__main__":
	main()





















