#!/usr/bin/python

# module metadata
__author__ = "Sander Granneman"
__copyright__ = "Copyright 2019"
__version__ = "0.6.9"
__credits__ = ["Sander Granneman"]
__maintainer__ = "Sander Granneman"
__email__ = "sgrannem@staffmail.ed.ac.uk"
__status__ = "Production"

##################################################################################
#
#	Coverage
#
#
#	Copyright (c) Sander Granneman 2019
#
#	Permission is hereby granted, free of charge, to any person obtaining a copy
#	of this software and associated documentation files (the "Software"), to deal
#	in the Software without restriction, including without limitation the rights
#	to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#	copies of the Software, and to permit persons to whom the Software is
#	furnished to do so, subject to the following conditions:
#
#	The above copyright notice and this permission notice shall be included in
#	all copies or substantial portions of the Software.
#
#	THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#	IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#	FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#	AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#	LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#	OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
#	THE SOFTWARE.
#
##################################################################################

import sys
import time
import numpy as np
from pyCRAC.Classes.Exceptions import *
from pyCRAC.Methods import getfilename,numpy_overlap,reverse_strand
from pyCRAC.Parsers.ParseAlignments import ParseCountersOutput
from pyCRAC.Classes.NGSFormatWriters import NGSFileWriter
from collections import defaultdict

class BinCollector():
	""" Collects the distribution of reads/clusters (or their mutations) over genomic features.

	Workflow suggested by the methods below: set numberofbins (and optionally annotation), run
	countFeatureDensities(), findOverlap() or findBinOverlap() on a data file, then, for counted data,
	call binFeatureDensities() and printBinCountingResults().

	The gtf argument of the public methods must provide chromosomeGeneCoordIterator(), geneIterCoordinates(),
	strand() and gene2orf() (presumably a pyCRAC GTF parser object -- confirm against callers).
	"""
	def __init__(self):
		""" File input should be the gtf file generated by PyCounters with gene_name attributes in it!"""
		np.seterr(all='ignore')					# ignore zero division errors; NaNs are zeroed when printing
		self.dataarray	  = defaultdict(list)	# (gene,feature number) -> numpy array of counts (per nucleotide, or per bin after binning)
		self.numberofbins = int()				# 0 means "not set yet"
		self.sequence	  = str()
		self.annotation	  = str()				# annotation filter handed to the GTF object; "" is written as "all" in file names
		self.filename	  = str()
		self.range		  = int()
		self.type		  = str()				# set by countFeatureDensities(): intervals, deletions or substitutions
		self.orientation  = "sense"
		self.data		  = None				# parser of the last counted data file; permillion output reads its mapped_reads
		self.cumulativedata = None				# filled in by printBinCountingResults(cumulative=True)
		self.__feature_a  = ["coding","genomic","5UTR","3UTR","TSS","5end","3end","CDSstart","CDSend"]	# one coordinate array per gene
		self.__feature_b  = ["intron","CDS","exon","5ss","3ss"]	# iterated: presumably one coordinate array per intron/exon/...
		self.__types	  = ["intervals","deletions","substitutions"]
		self.__orientations = ["sense","anti_sense"]
		# NOTE: CDSstart/CDSend are meant to be used with annotation "protein_coding". That cannot be checked
		# here because sequence and annotation are only known after construction.

	#### private helpers shared by the overlap/counting methods ####

	def _selectStrand(self,data,orientation,ignorestrand):
		""" returns the strand whose genes are searched for the current data line:
		None when the strand is ignored, the line's own strand for "sense", the opposite strand otherwise. """
		if ignorestrand:
			return None
		if orientation == "sense":
			return data.strand
		return reverse_strand(data.strand)

	def _overlappingGenes(self,gtf,data,strand,cache,ranges,sequence,overlap):
		""" returns the numpy_overlap() result for the current data line, or [] when no genes are known.
		Gene coordinates are fetched from gtf once per (chromosome,strand) and stored in cache. When the GTF
		object raises AssertionError an empty entry is cached, so that chromosome is skipped from then on. """
		key = (data.chromosome,strand)
		if key not in cache:
			try:
				cache[key] = gtf.chromosomeGeneCoordIterator(data.chromosome,annotation=self.annotation,numpy=True,strand=strand,ranges=ranges,sequence=sequence)
			except AssertionError:
				sys.stderr.write("Can't get chromosomeGeneCoordinates for chromosome %s and sequence %s\n" % (data.chromosome,sequence))
				cache[key] = []
		if len(cache[key]) > 0:
			return numpy_overlap(cache[key],data.read_start,data.read_end,overlap=overlap)
		return []

	def _featureCoordinates(self,gtf,gene,sequence,ranges):
		""" returns an iterable of coordinate arrays for gene: a one-element list for __feature_a sequences,
		the GTF object's iterable for __feature_b sequences. Raises LookupError for anything else. """
		if sequence in self.__feature_a:
			return [gtf.geneIterCoordinates(gene,coordinates=sequence,ranges=ranges)]
		if sequence in self.__feature_b:
			return gtf.geneIterCoordinates(gene,coordinates=sequence,ranges=ranges)
		raise LookupError("Could not generate iteration coordinates, please check your input\n")

	def _usableFeature(self,coordinates,minfeatlength,maxfeatlength):
		""" True when coordinates is a numpy array with at least one non-zero value whose length lies within
		[minfeatlength,maxfeatlength] and is not smaller than the number of bins. """
		if not isinstance(coordinates,np.ndarray) or not coordinates.any():
			return False
		length = len(coordinates)
		return minfeatlength <= length <= maxfeatlength and length >= self.numberofbins

	def _hitsSelectedBins(self,gtf,data,genes,sequence,ranges,minfeatlength,maxfeatlength,firstbin,lastbin):
		""" True when the current data line covers at least one nucleotide in bins firstbin..lastbin-1
		of any usable feature of the given genes. """
		for gene in genes:
			for coordinates in self._featureCoordinates(gtf,gene,sequence,ranges):
				if not self._usableFeature(coordinates,minfeatlength,maxfeatlength):
					continue
				if gtf.strand(gene) == "-":
					coordinates = coordinates[::-1]	# minus strand: bin 0 starts at the highest coordinate
				covered = (coordinates < data.read_end) & (coordinates >= data.read_start)
				if any(chunk.any() for chunk in np.array_split(covered,self.numberofbins)[firstbin:lastbin]):
					return True
		return False

	def _openGTFWriter(self,out_file):
		""" returns an NGSFileWriter (writing to out_file, or to NGSFileWriter's default target when out_file
		is not given) after writing the GFF version 2 header to it. """
		if out_file:
			outfile = NGSFileWriter(out_file)
		else:
			outfile = NGSFileWriter()
		outfile.write("##gff-version 2\n# generated by BinCollector, %s\n# %s\n" % (time.ctime(),' '.join(sys.argv)))
		return outfile

	def _writeOverlap(self,outfile,gtf,data,genes):
		""" writes the current data line as GTF, annotated with all overlapping gene names and their ids """
		gene_name = ",".join(genes)
		gene_id = ",".join([gtf.gene2orf(name) for name in genes])
		outfile.writeGTF(data.chromosome,data.source,data.type,data.read_start,data.read_end,data.score,data.strand,gene_name=gene_name,gene_id=gene_id,comments=data.comments)

	#### public methods ####

	def findBinOverlap(self,datafile,gtf,printbins=None,sequence="genomic",orientation="sense",minfeatlength=0,maxfeatlength=100000000,ignorestrand=False,out_file=None,ranges=0,overlap=1):
		""" returns those lines that overlap with specific bin numbers. Has a bedtools-like function.
		Allows you to find those intervals that overlap with, for example, 5' or 3' ends of genomic features.

		Each usable feature is split into numberofbins bins with numpy.array_split; features on the "-" strand
		are reversed first. A data line is written once, as GTF with all overlapping gene names and ids, if it
		covers a nucleotide in bins printbins[0] up to but not including printbins[-1] (Python slice semantics)
		of any overlapping feature.

		datafile:      PyCounters GTF data file, read with ParseCountersOutput.
		gtf:           annotation object (see class docstring).
		printbins:     sequence of bin numbers; only its first and last elements are used. Required.
		sequence:      feature type, one of the __feature_a or __feature_b names.
		orientation:   "sense" searches genes on the line's strand, "anti_sense" on the opposite strand.
		minfeatlength, maxfeatlength: inclusive feature length limits.
		ignorestrand:  pass strand=None to the GTF object (presumably searching both strands).
		out_file:      output file; NGSFileWriter's default target when not given.
		ranges:        passed to the GTF object (presumably flanking nucleotides added to features).
		overlap:       passed to numpy_overlap (presumably the minimum overlap in nucleotides).
		Raises AssertionError for an unsupported orientation/sequence, unset numberofbins or empty printbins.
		"""
		assert orientation in self.__orientations, "The orientation %s is not supported, please choose from the following:\t%s\n" % (orientation,", ".join(self.__orientations))
		assert sequence in self.__feature_a or sequence in self.__feature_b, "Can not determine what sequence coordinates you want to use.\nOptions are:\t%s\tand\t%s\n.\nPlease check your input\n" % (", ".join(self.__feature_a),", ".join(self.__feature_b))
		assert self.numberofbins, "No bin number was provided\n"
		assert printbins is not None and len(printbins) > 0, "No bins were selected, please provide bin numbers with the printbins argument\n"
		self.sequence = sequence
		self.filename = datafile
		self.range	  = ranges
		firstbin,lastbin = printbins[0],printbins[-1]
		data = ParseCountersOutput(datafile)
		outfile = self._openGTFWriter(out_file)
		cache = dict()
		while data.readLineByLine(numpy=True,collectmutations=False):	# one data line per iteration, coordinates as numpy arrays
			strand = self._selectStrand(data,orientation,ignorestrand)
			genes = self._overlappingGenes(gtf,data,strand,cache,ranges,sequence,overlap)
			if not genes:
				continue
			# write each data line at most once, even if several genes/features have hits in the selected bins
			if self._hitsSelectedBins(gtf,data,genes,sequence,ranges,minfeatlength,maxfeatlength,firstbin,lastbin):
				self._writeOverlap(outfile,gtf,data,genes)

	def findOverlap(self,datafile,gtf,sequence="genomic",orientation="sense",minfeatlength=0,maxfeatlength=100000000,ignorestrand=False,out_file=None,ranges=0,overlap=1):
		""" returns those lines that overlap with specific features such as exons, coding sequences, introns, UTRs, etc. Has a bedtools-like function.

		Every data line that overlaps at least one gene is written as GTF with all overlapping gene names and ids.
		minfeatlength and maxfeatlength are accepted for signature compatibility but are not used.
		See findBinOverlap() for the remaining parameters.
		"""
		assert orientation in self.__orientations, "Error! Orientation %s is not recognized. Please choose %s" % (orientation, " or ".join(self.__orientations))
		assert sequence in self.__feature_a or sequence in self.__feature_b, "Can not determine what sequence coordinates you want to use.\nOptions are:\t%s\tand\t%s\n.\nPlease check your input\n" % (", ".join(self.__feature_a),", ".join(self.__feature_b))
		self.sequence = sequence
		self.filename = datafile
		self.range	  = ranges
		data = ParseCountersOutput(datafile)
		outfile = self._openGTFWriter(out_file)
		cache = dict()
		while data.readLineByLine(numpy=True,collectmutations=False):
			strand = self._selectStrand(data,orientation,ignorestrand)
			genes = self._overlappingGenes(gtf,data,strand,cache,ranges,sequence,overlap)
			if genes:
				self._writeOverlap(outfile,gtf,data,genes)

	def countFeatureDensities(self,datafile,gtf,sequence="genomic",feature="intervals",minfeatlength=0,maxfeatlength=100000000,orientation="sense",ignorestrand=False,out_file=None,ranges=0,overlap=1,unique=False):
		""" counts the gene hits without dividing the gene lengths into bins.

		For every usable feature (see findBinOverlap for the length rules) a per-nucleotide numpy array is
		stored in self.dataarray[(gene,feature number)]. Minus-strand features are reversed before counting.
		feature selects what is counted: "intervals" adds read counts over covered nucleotides;
		"substitutions"/"deletions" add them at mutated positions. With unique=True every line counts as 1 read.
		The parser is kept in self.data so printBinCountingResults(permillion=True) can read mapped_reads.
		out_file is not used. The comment in the original code states the data are assumed to be sorted by chromosome.
		"""
		assert feature in self.__types, "The feature %s is not supported, please choose from the following:\t%s\n" % (feature,", ".join(self.__types))
		assert orientation in self.__orientations, "The orientation %s is not supported, please choose from the following:\t%s\n" % (orientation,", ".join(self.__orientations))
		assert sequence in self.__feature_a or sequence in self.__feature_b, "Can not determine what sequence coordinates you want to use.\nOptions are:\t%s\tand\t%s\n.\nPlease check your input\n" % (", ".join(self.__feature_a),", ".join(self.__feature_b))
		self.sequence	 = sequence
		self.filename	 = datafile
		self.range		 = ranges
		self.type		 = feature
		self.orientation = orientation
		data = ParseCountersOutput(datafile)
		self.data = data
		cache = dict()
		while data.readLineByLine(numpy=True,collectmutations=(feature != "intervals")):
			strand = self._selectStrand(data,orientation,ignorestrand)
			genes = set(self._overlappingGenes(gtf,data,strand,cache,ranges,sequence,overlap))
			if not genes:
				continue
			if unique:
				data.number_of_reads = 1
			for gene in genes:
				for nr,coordinates in enumerate(self._featureCoordinates(gtf,gene,sequence,ranges)):
					if not self._usableFeature(coordinates,minfeatlength,maxfeatlength):
						continue
					key = (gene,nr)
					if key not in self.dataarray:
						self.dataarray[key] = np.zeros(len(coordinates))
					if gtf.strand(gene) == "-":
						coordinates = coordinates[::-1]	# minus strand: index 0 is the highest coordinate
					counts = self.dataarray[key]		# updated in place below
					if feature == "intervals":
						counts[(coordinates < data.read_end) & (coordinates >= data.read_start)] += data.number_of_reads
					elif feature == "substitutions":
						if data.substitutions.any():
							counts[np.isin(coordinates,data.substitutions)] += data.number_of_reads
					elif feature == "deletions":
						if data.deletions.any():
							counts[np.isin(coordinates,data.deletions)] += data.number_of_reads

	def binFeatureDensities(self,numberofbins=20):
		""" divides the features in an equal number of bins.

		Uses self.numberofbins when already set, otherwise numberofbins (and stores it). Each array in
		self.dataarray is replaced by the sums of its numpy.array_split chunks; features shorter than the
		number of bins are removed. Raises AssertionError when there is no data or the bin number is not positive.
		"""
		assert self.dataarray, "No data to be binned!\n"
		if not self.numberofbins:
			self.numberofbins = numberofbins
		assert self.numberofbins > 0, "The number of bins should be a positive integer\n"
		for key in sorted(self.dataarray):		# sorted() returns a copy, so deleting entries is safe
			counts = self.dataarray[key]
			if len(counts) >= self.numberofbins:
				self.dataarray[key] = np.array([chunk.sum() for chunk in np.array_split(counts,self.numberofbins)])
			else:
				del self.dataarray[key]

	def __createFileHandles(self,file_type,file_extension="txt"):
		""" returns the default output file name:
		<input file name>_<file_type>_<annotation, or "all">_<sequence>.<file_extension> """
		annotation = self.annotation or "all"
		return "%s_%s_%s_%s.%s" % (getfilename(self.filename),file_type,annotation,self.sequence,file_extension)

	def _printableCounts(self):
		""" yields (gene,counts) for every stored numpy array with at least one non-zero value, in dataarray order """
		for (gene,nr),counts in self.dataarray.items():
			if isinstance(counts,np.ndarray) and counts.any():
				yield gene,counts

	def printBinCountingResults(self,out_file=None,cumulative=False,normalize=False,permillion=False):
		""" prints the binCounter results in a pileup or table format.

		cumulative:  sums the (binned) arrays of all features and writes one tab-separated line of numberofbins
		             values (default file: <input name>_cumulative_densities.pileup); also stored in self.cumulativedata.
		otherwise:   writes one line per feature: gene name followed by its tab-separated values
		             (default file name built from orientation, type, annotation and sequence).
		normalize:   cumulative -> divide by the total count of all features; otherwise -> divide each feature by its own total.
		permillion:  divide by the number of mapped reads in millions (ignored when normalize is set); needs
		             countFeatureDensities() to have been run so the parser's mapped_reads is available.
		Arrays that are all zeros are skipped. The output file is overwritten, not appended to.
		Raises AssertionError when there is no data, no mapped read count (permillion) or no bins (cumulative).
		"""
		assert self.dataarray, "\n\nNo data to be printed. Did you run the binCounter method?\nDid you set the maximum sequence length too high?\nPlease double check your settings\n"
		if permillion:
			assert getattr(self.data,"mapped_reads",None), "Could not calculate hits permillion as I could not tell how many reads were mapped to the genome\n"
			millions = float(self.data.mapped_reads)/1000000.0
		if cumulative:
			assert self.numberofbins, "The data has not been divided into bins yet\n"
			filename = out_file or "%s_%s.pileup" % (getfilename(self.filename),"cumulative_densities")
			if normalize:
				sumofall = np.sum(list(self.dataarray.values()))
			cumulativedata = np.zeros(self.numberofbins)
			for gene,counts in self._printableCounts():
				if normalize:
					values = counts/sumofall
					values[np.isnan(values)] = 0.0
				elif permillion:
					values = counts/millions
				else:
					values = counts
				cumulativedata += values
			self.cumulativedata = cumulativedata
			with open(filename,"w") as outputfile:
				outputfile.write("%s\n" % "\t".join(["%.6f" % i for i in cumulativedata]))
		else:
			filename = out_file or self.__createFileHandles("%s_%s" % (self.orientation,self.type),file_extension="txt")
			with open(filename,"w") as outputfile:
				for gene,counts in self._printableCounts():
					if normalize:
						values = counts/counts.sum()
					elif permillion:
						values = counts/millions
					else:
						values = counts
					values = np.where(np.isnan(values),0.0,values)	# new array: never modifies self.dataarray
					if values.any():
						outputfile.write("%s\t%s\n" % (gene,"\t".join(["%.8f" % i for i in values])))
