"""
Plots the distribution of reads in a BAM file around a given set of features
in a BED file. The user can specify which end (5' or 3') of the reads and of
the features is used as the reference point for the comparison. For example,
assume that the user selects the 5' end of the reads and the 5' end of the
features as reference points. Then a read that maps at position 10 of chr1
will be at a relative position of -5 nt compared to a '+' strand feature
aligned at position 15 of chr1. The same concept is applied to all reads
against all features, and a distribution of relative positions is constructed.
"""


import pysam
import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages


class Feature:
    """A stranded genomic interval; main() builds one per BED line."""

    def __init__(self, chrom, start, end, strand):
        # Values are stored exactly as given; no coordinate conversion here.
        self.chrom, self.start, self.end, self.strand = chrom, start, end, strand

    def __str__(self):
        """Render as ``chrom:start-end-strand``."""
        return f"{self.chrom}:{self.start}-{self.end}-{self.strand}"


def find_pos(start, end, strand, pos):
    """
    Return the coordinate of the requested end ('5p' or '3p') of an interval.

    On the '+' strand the 5' end is ``start`` and the 3' end is ``end``; on
    the '-' strand the two are swapped. Coordinates are returned unchanged,
    so the caller decides whether ``end`` is inclusive.

    Raises ValueError if ``pos`` is not '5p'/'3p' (checked first) or if
    ``strand`` is not '+'/'-'.
    """
    if pos not in ('5p', '3p'):
        raise ValueError('Incorrectly specified position')
    if strand not in ('+', '-'):
        raise ValueError('Incorrectly specified strand')
    # The end coordinate is wanted exactly for (-, 5p) and (+, 3p), i.e. when
    # "is minus strand" and "is 5' end" have the same truth value.
    takes_end = (strand == '-') == (pos == '5p')
    return end if takes_end else start


def relative_pos(feat, read, fpos, rpos):
    """
    Return the signed distance from a feature's reference end to a read's.

    feat: Feature with BED coordinates (0-based start, exclusive end).
    read: aligned read exposing ``reference_start``, ``reference_end`` (one
        past the last aligned base) and ``is_forward``.
    fpos, rpos: '5p' or '3p'; which end of the feature / read to use.

    Both ends are first converted to the 0-based position of a real base,
    then subtracted. The result is oriented along the feature's strand:
    negative means upstream of the feature's reference end, positive means
    downstream.

    Raises ValueError (via find_pos) for an invalid end or strand.
    """
    # BED chromEnd and pysam reference_end are both exclusive, so subtract 1
    # from each. Otherwise the 3' end of '+' features (5' end of '-'
    # features) would be one base past the feature.
    feat_pos = find_pos(feat.start, feat.end - 1, feat.strand, fpos)
    read_strand = '+' if read.is_forward else '-'
    read_pos = find_pos(read.reference_start, read.reference_end - 1,
                        read_strand, rpos)
    rel_pos = read_pos - feat_pos
    # Flip the sign for '-' strand features so upstream is always negative.
    if feat.strand == '-':
        rel_pos = -rel_pos
    return rel_pos


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("-m", "--bam", required=True,
                        help="BAM file with reads. Must be indexed.")
    parser.add_argument("-b", "--bed", required=True,
                        help="BED file with features.")
    parser.add_argument("-u", "--up", type=int, default=100,
                        help="Number of nts to plot upstream of pos. (default: %(default)s)")
    parser.add_argument("-d", "--down", type=int, default=100,
                        help="Number of nts to plot downstream of pos. (default: %(default)s)")
    # choices= rejects typos at parse time instead of raising a ValueError
    # from deep inside the counting loop.
    parser.add_argument("-f", "--fpos", default='5p', choices=['5p', '3p'],
                        help="Reference point for features; one of 5p or 3p (default: %(default)s)")
    parser.add_argument("-r", "--rpos", default='5p', choices=['5p', '3p'],
                        help="Reference point for reads; one of 5p or 3p (default: %(default)s)")
    parser.add_argument("-o", "--pdf", required=True,
                        help="Output pdf file with plot")
    return parser.parse_args()


def _iter_features(bed_path):
    """
    Yield one Feature per data line of a BED6 (or wider) file.

    Blank lines, '#' comments, and 'track'/'browser' header lines are
    skipped. Raises ValueError for a data line with fewer than 6
    tab-separated columns, because the strand (column 6) is required.
    """
    with open(bed_path) as bed:
        for lineno, line in enumerate(bed, start=1):
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            if line.split(None, 1)[0] in ('track', 'browser'):
                continue
            cols = line.split('\t')
            if len(cols) < 6:
                raise ValueError(
                    "{}:{}: expected at least 6 tab-separated columns "
                    "(BED6 with strand), got {}".format(
                        bed_path, lineno, len(cols)))
            yield Feature(cols[0], int(cols[1]), int(cols[2]), cols[5])


def _count_relative_positions(bamfile, features, up, down, fpos, rpos):
    """
    Return a {relative_position: count} histogram over -up..down inclusive.

    Each (read, feature) pair whose relative position falls in the window is
    counted once. Features on contigs missing from the BAM header are skipped.
    """
    hist = {i: 0 for i in range(-up, down + 1)}
    # A read overlaps its own reference end. Any read whose reference end lies
    # within the plotted window around either end of the feature therefore
    # overlaps this padded region, not only reads touching the feature itself.
    pad = max(abs(up), abs(down))
    for feat in features:
        if bamfile.header.get_tid(feat.chrom) == -1:
            continue
        start = max(0, feat.start - pad)  # fetch() rejects negative starts
        end = feat.end + pad + 1
        for read in bamfile.fetch(feat.chrom, start, end):
            # Unmapped reads placed at their mate's position have no
            # reference_end and cannot be positioned.
            if read.is_unmapped:
                continue
            rel_pos = relative_pos(feat, read, fpos, rpos)
            if -up <= rel_pos <= down:
                hist[rel_pos] += 1
    return hist


def main():
    """Build the histogram, print it as a table to stdout, and plot it to PDF."""
    args = _parse_args()

    # Relative positions shown in the table and plot: -up..down inclusive.
    positions = list(range(-args.up, args.down + 1))

    # Loop on the BED features and query the (indexed) BAM file for nearby
    # reads. The context manager guarantees the BAM handle is closed.
    with pysam.AlignmentFile(args.bam, "rb") as bamfile:
        hist = _count_relative_positions(
            bamfile, _iter_features(args.bed),
            args.up, args.down, args.fpos, args.rpos)

    # Print a table with the histogram to stdout.
    print("\t".join(["pos", "count"]))
    for i in positions:
        print("\t".join([str(i), str(hist[i])]))

    # Plot the histogram in a pdf, drawing on the explicit Axes and closing
    # exactly the figure that was created.
    values = [hist[i] for i in positions]
    with PdfPages(args.pdf) as pages:
        fig, ax = plt.subplots(layout='constrained')
        ax.bar(positions, values)
        ax.set_xlabel('Relative position to feature ' + args.fpos + ' end')
        ax.set_ylabel('Read-feature pairs count')
        pages.savefig(fig)
        plt.close(fig)


# Run main() only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
