#!/usr/bin/env python
#-*- coding: utf-8 -*-
import warnings
warnings.filterwarnings("ignore")
import sys
import argparse
import datetime
import getpass
import os
from dna_features_viewer import GraphicFeature, GraphicRecord
import pandas as pd
import uuid
import subprocess
import matplotlib.pyplot as plt
# Base palette; the first four entries are assigned to the pegRNA components below.
npg_colors = ["#E64B35","#4DBBD5","#00A087","#3C5488","#F39B7F","#464d4f"]
# Colour used for each feature type drawn on the plot, plus the variant highlight.
my_colors = {
	'sgRNA': npg_colors[0],
	'PBS': npg_colors[1],
	'RTT': npg_colors[2],
	'ngRNA': npg_colors[3],
	'variant': "#e6fc3f",
}
"""

Output
--------

The output folder will contain:
1. all pegRNA + ngRNA combinations for the input vcf file
2. top1 pegRNA + ngRNA combination for each variant
3. visualization of the top1s [TODO]
4. a summary file of each variant

"""

def my_args():
	"""Build the command-line parser and parse ``sys.argv``.

	Output
	--------
	argparse.Namespace with the attributes rawX, genome_fasta, figure_type,
	sample_id, output_file_name and output.
	"""
	# the default output folder name embeds the current user name and today's date
	default_output = f"easy_prime_vis_{getpass.getuser()}_{datetime.date.today()}"

	parser = argparse.ArgumentParser(
		description="easy_prime for pegRNA design",
		formatter_class=argparse.ArgumentDefaultsHelpFormatter,
	)
	parser.add_argument('-f','--rawX',required=True,
		help="input rawX format, last column with predicted efficiency")
	parser.add_argument('-s','--genome_fasta',required=True,
		help="input genome sequences")
	parser.add_argument('-t','--figure_type',default="png",
		help="png or pdf")
	parser.add_argument('--sample_id',default=None,
		help="only plot this sample id")
	parser.add_argument('--output_file_name',default=None,
		help="use only when plot one pegRNA, output_file_name.figure_type, will overwrite -t -o options.")
	parser.add_argument('-o','--output',default=default_output,
		help="output dir")

	##------- add parameters above ---------------------
	return parser.parse_args()

def write_file(file_name,message):
	"""Write ``message`` to ``file_name`` in text mode, replacing any existing content.

	Input
	--------
	file_name: path of the file to create or overwrite
	message: string to write
	"""
	# the context manager closes the handle even when write() raises
	with open(file_name,"wt") as out:
		out.write(message)
def get_fasta_single(chr,start,end,genome_fasta=None):
	"""Extract the sequence of one interval by running ``bedtools getfasta``.

	Input
	--------
	chr: sequence name, written as the first BED column
	start, end: interval bounds, written unchanged as the 2nd and 3rd BED columns
	genome_fasta: path of the fasta file passed to ``-fi``

	Output
	--------
	the last whitespace-separated field of the first line of the
	tab-delimited bedtools output, i.e. the sequence

	Raises
	--------
	RuntimeError if bedtools did not produce any output line
	"""
	import tempfile  # local import: only this function needs it

	# A private temporary directory keeps the scratch files out of the working
	# directory and removes them even when bedtools fails or an exception is raised.
	with tempfile.TemporaryDirectory() as tmp_dir:
		bed_path = os.path.join(tmp_dir,"region.bed")
		out_path = os.path.join(tmp_dir,"region.tab")
		with open(bed_path,"wt") as bed:
			bed.write("%s\t%s\t%s"%(chr,start,end))
		# Argument list instead of a shell string: paths containing spaces or
		# shell metacharacters are handed to bedtools verbatim.
		command = ["bedtools","getfasta","-fi",str(genome_fasta),"-bed",bed_path,"-fo",out_path,"-tab"]
		subprocess.call(command)
		first_line = ""
		if os.path.isfile(out_path):
			with open(out_path) as fa:
				first_line = fa.readline()
	fields = first_line.split()
	if not fields:
		raise RuntimeError("bedtools getfasta returned no sequence for %s:%s-%s from %s"%(chr,start,end,genome_fasta))
	return fields[-1]
def get_strand(x):
	"""Map a strand symbol to a number: "+" gives 1, anything else gives -1."""
	return 1 if x == "+" else -1
def find_mutation_pos(pos,ref,alt):
	"""Strip the leading bases shared by ref and alt and shift pos accordingly.

	Handles redundant user specifications, e.g.
	ATTTT -> ATTT becomes T -> ""
	G -> GC becomes "" -> C

	Output
	--------
	(pos advanced by the number of shared leading bases, remaining ref, remaining alt)
	"""
	shared = 0
	# zip stops at the shorter allele, so the comparison never runs past either string
	for ref_base,alt_base in zip(ref,alt):
		if ref_base != alt_base:
			break
		shared += 1
	return pos+shared,ref[shared:],alt[shared:]

def plot_main(df,output=None,genome_fasta=None,figure_type="png",output_file_name=None,**kwargs):
	"""Given one instance of easy-prime prediction (rawX format), generate DNA visualization

	Input
	--------
	df: the data frame contains 4 rows: RTT, PBS, sgRNA, ngRNA. Its index holds the
		pegRNA id and the columns CHROM, start, end, POS, REF, ALT, strand, type and
		predicted_efficiency are read.
	output: folder the figure is written to when output_file_name is None
	genome_fasta: genome fasta used to fetch the plotted sequence
	figure_type: file extension of the figure written into output
	output_file_name: if given, the figure is saved to exactly this path instead
	kwargs: ignored; lets callers pass a full argument namespace

	Output
	--------
	None, the figure is saved to disk and then closed
	"""
	pegRNA_id = df.index.tolist()[0]
	variant_id = pegRNA_id.split("_")[0]
	# .iloc[0] rather than [0]: the index holds pegRNA ids, and integer keys on a
	# non-integer index only work through a positional fallback that pandas deprecated.
	chrom = df['CHROM'].iloc[0]
	ref = df['REF'].iloc[0]
	alt = df['ALT'].iloc[0]
	predicted_efficiency = df['predicted_efficiency'].iloc[0]*100
	variant_pos = df['POS'].min()

	# plotting window: start is rounded down to a multiple of 10 and moved back by
	# one (never below 0), end is moved up to the next multiple of 10
	start = df['start'].min()
	start -= start%10
	start -= 1
	start = max(start,0)
	end = df['end'].max()
	end -= end%10
	end += 10

	correct_pos,new_ref,new_alt = find_mutation_pos(variant_pos,ref,alt)
	pos = variant_pos-start
	pos_rel_correct = correct_pos-start
	sequence = get_fasta_single(chrom,start,end,genome_fasta).upper()

	# one arrow per row, positioned relative to the window start
	feature_list = []
	for _,row in df.iterrows():
		feature_list.append(GraphicFeature(
			start=row['start']-start, end=row['end']-start,
			strand=get_strand(row['strand']),
			color=my_colors[row['type']], label=row['type']))
	record = GraphicRecord(sequence=sequence, features=feature_list)

	ax, _ = record.plot(figure_width=int(len(sequence)/5))
	try:
		record.plot_sequence(ax)

		# highlight every base of the trimmed ref allele; when nothing is left of
		# ref after trimming, highlight the single base at the reported POS
		if len(new_ref) > 0:
			highlight = [pos_rel_correct+k for k in range(len(new_ref))]
		else:
			highlight = [pos]
		for p in highlight:
			ax.fill_between((p-1.5, p-0.5), +1000, -1000, alpha=0.5,color=my_colors['variant'])

		# relabel the ticks with genomic coordinates; only the first one carries the chromosome
		locs = ax.get_xticks()
		new_labels = [int(start+i+1) for i in locs]
		if new_labels:
			new_labels[0] = "%s %s"%(chrom,new_labels[0])
		ax.set_xticks(locs)
		ax.set_xticklabels(new_labels)

		dPAM=""
		if "dPAM" in pegRNA_id:
			dPAM = "PAM disruption"
		if "PE3b" in pegRNA_id:
			dPAM += ", PE3b"
		if dPAM!="":
			dPAM += "\n"
		myTitle = "ID: %s, CHR: %s, POS: %s, REF: %s, ALT: %s \n %s Predicted efficiency: %.1f"%(variant_id,chrom,variant_pos,ref,alt,dPAM,predicted_efficiency)+"%"
		ax.set_title(myTitle)

		if output_file_name is not None:
			figure_path = output_file_name
		else:
			figure_path = f'{output}/{pegRNA_id}.{figure_type}'
		ax.figure.savefig(figure_path, bbox_inches='tight')
	finally:
		# release the figure; otherwise every plotted pegRNA stays alive in pyplot
		plt.close(ax.figure)


def main():
	"""Command-line entry point: plot every pegRNA of the rawX file (or only --sample_id).

	Exits with an error message when the genome fasta does not exist.
	"""
	args = my_args()
	if not os.path.isfile(args.genome_fasta):
		# nothing can be plotted without the genome, so stop here instead of
		# failing later while extracting the sequence
		sys.exit(f"genome fasta NOT FOUND: {args.genome_fasta}")
	df = pd.read_csv(args.rawX,index_col=0)
	# no shell involved, so directory names with spaces or special characters work
	os.makedirs(args.output,exist_ok=True)

	sample_ids = df.index.unique().tolist()
	if args.sample_id is not None:
		sample_ids = [i for i in sample_ids if i == args.sample_id]
		if not sample_ids:
			print (f"sample id NOT FOUND in {args.rawX}: {args.sample_id}")
	for i in sample_ids:
		print (f"Processing {i}...")
		# list indexer: always yields a DataFrame, even when the id has a single row
		plot_main(df.loc[[i]],**vars(args))


# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
	main()





















