from Bio import SeqIO
from random import choice
from pathlib import Path
import functools
import polars as pl
from collections import defaultdict

import shutil
import numpy as np

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from tf_bind_transformer.protein_utils import parse_gene_name
from enformer_pytorch import FastaInterval

from pyfaidx import Fasta
import pybedtools

def exists(val):
    """Return True when `val` is anything other than None (falsy values still count as existing)."""
    is_missing = val is None
    return not is_missing

def cast_list(val = None):
    """Coerce `val` into something that can be iterated like a list.

    None becomes a new empty list, a tuple or list is handed back unchanged
    (the very same object, tuples are not converted), and any other value is
    wrapped in a one-element list.
    """
    if not exists(val):
        return []

    if isinstance(val, (tuple, list)):
        return val

    return [val]

def read_bed(path):
    """Read the tab-separated, headerless bed file at `path` into a polars dataframe."""
    read_options = dict(sep = '\t', has_headers = False)
    return pl.read_csv(path, **read_options)

# fetch protein sequences by gene name and uniprot id

class FactorProteinDatasetByUniprotID(Dataset):
    """Look up protein sequences in a folder of fasta files, keyed by uniprot id.

    Only `*.fasta` files directly inside `folder` are considered. The file stem
    is split on periods and its second field is used as the uniprot id, so the
    stem needs at least two period-separated fields.
    """

    def __init__(self, folder):
        super().__init__()
        found = [*Path(folder).glob('*.fasta')]
        assert found, f'no fasta files found at {folder}'

        self.paths = found

        # uniprot id -> fasta path, a later file with the same id replaces an earlier one
        self.index_by_id = dict()

        for fasta_path in found:
            _gene, uid, *_ = fasta_path.stem.split('.')
            self.index_by_id[uid] = fasta_path

    def __len__(self):
        # number of fasta files found, not the number of distinct uniprot ids
        return len(self.paths)

    def __getitem__(self, uid):
        """Return the sequence stored for `uid` as a string, or None for an unknown id."""
        fasta_path = self.index_by_id.get(uid)

        if fasta_path is None:
            return None

        record = SeqIO.read(fasta_path, 'fasta')
        return str(record.seq)

# fetch protein sequences by gene name

class FactorProteinDataset(Dataset):
    """Look up protein sequences in a folder of fasta files, keyed by gene name.

    Only `*.fasta` files directly inside `folder` are considered. The first
    period-separated field of the file stem is the gene name; several files may
    share a gene name, in which case one of them is picked at random per lookup.
    """

    def __init__(self, folder, return_tuple_only = False):
        super().__init__()
        found = [*Path(folder).glob('*.fasta')]
        assert found, f'no fasta files found at {folder}'

        self.paths = found
        self.return_tuple_only = return_tuple_only # whether to return tuple even if there is only one subunit

        # gene name -> list of fasta paths for that gene
        self.index_by_gene = defaultdict(list)

        for fasta_path in found:
            gene_name, _uniprot_id, *_ = fasta_path.stem.split('.')
            self.index_by_gene[gene_name].append(fasta_path)

    def __len__(self):
        # number of fasta files found, not the number of distinct genes
        return len(self.paths)

    def __getitem__(self, unparsed_gene_name):
        """Return the sequence(s) for the gene names parsed out of `unparsed_gene_name`.

        Gene names without any fasta file are reported on stdout and skipped.
        A single resulting sequence is returned as a plain string unless
        `return_tuple_only` is set; otherwise a tuple of strings is returned
        (which can be empty).
        """
        sequences = []

        for gene in parse_gene_name(unparsed_gene_name):
            # defaultdict lookup: an unknown gene yields (and registers) an empty list
            candidates = self.index_by_gene[gene]

            if not candidates:
                print(f'no entries for {gene}')
                continue

            if isinstance(candidates, list):
                fasta_path = choice(candidates)
            else:
                fasta_path = candidates

            record = SeqIO.read(fasta_path, 'fasta')
            sequences.append(str(record.seq))

        sequences = tuple(sequences)
        unwrap_single = len(sequences) == 1 and not self.return_tuple_only

        return sequences[0] if unwrap_single else sequences

# remap dataframe functions

def get_chr_names(ids):
    """Return the set of chromosome names, `chr` followed by the id, for the given ids."""
    return {f'chr{chr_id}' for chr_id in ids}

# chromosome ids the datasets are restricted to: the integers 1 through 22 plus 'X'
CHR_IDS = {*range(1, 23), 'X'}
# the same ids with the `chr` prefix, as used in the first column of the bed files
CHR_NAMES = get_chr_names(CHR_IDS)

def remap_df_add_experiment_target_cell(df, col = 'column_4'):
    """Return a copy of `df` with `experiment`, `target` and `cell_type` columns added.

    The three values are regex-extracted from the string column `col`, which
    presumably holds `<experiment>.<target>.<cell type>` (verify against the
    bed files in use). Every new column is inserted at index 3, so the result
    has `cell_type`, `target`, `experiment` at indices 3, 4 and 5, ahead of the
    columns that previously started at index 3. The input frame is cloned first.
    """
    df = df.clone()

    # (new column name, extraction pattern), listed in insertion order
    derived_columns = (
        ('experiment', r"^([\w\-]+)\.*"),
        ('target', r"[\w\-]+\.([\w\-]+)\.[\w\-]+"),
        ('cell_type', r"^.*\.([\w\-]+)$"),
    )

    for name, pattern in derived_columns:
        extracted = df.select([pl.col(col).str.extract(pattern)])
        series = extracted.rename({col: name}).to_series(0)
        # always index 3: each insertion pushes the previously inserted column one to the right
        df.insert_at_idx(3, series)

    return df

def pl_isin(col, arr):
    """Build a polars expression that is true where column `col` equals any value of `arr`.

    The per-value equality checks are OR-ed together. An empty `arr` matches
    nothing: a constant-false literal is returned, because reducing an empty
    sequence without an initial value would raise `TypeError`.
    """
    equalities = [pl.col(col) == value for value in arr]

    if not equalities:
        return pl.lit(False)

    return functools.reduce(lambda a, b: a | b, equalities)

def pl_notin(col, arr):
    """Build a polars expression that is true where column `col` differs from every value of `arr`.

    The per-value inequality checks are AND-ed together. An empty `arr`
    excludes nothing: a constant-true literal is returned, because reducing an
    empty sequence without an initial value would raise `TypeError`.
    """
    inequalities = [pl.col(col) != value for value in arr]

    if not inequalities:
        return pl.lit(True)

    return functools.reduce(lambda a, b: a & b, inequalities)

def filter_bed_file_by_(bed_file_1, bed_file_2, output_file):
    """Intersect `bed_file_1` with `bed_file_2` using `v = True` and save the result to `output_file`.

    NOTE(review): `v = True` presumably maps to bedtools' `-v` option, i.e. only
    entries of `bed_file_1` without any overlap in `bed_file_2` are kept -
    confirm against the pybedtools documentation.
    """
    # generated by OpenAI Codex

    source = pybedtools.BedTool(bed_file_1)
    exclusion = pybedtools.BedTool(bed_file_2)

    remaining = source.intersect(exclusion, v = True)
    remaining.saveas(output_file)

def filter_df_by_tfactor_fastas(df, folder, derive_target_col = False):
    """Drop the rows of `df` whose target has no matching fasta file under `folder`.

    Args:
        df: polars dataframe with a `target` column (or from which one can be
            derived, see `derive_target_col`).
        folder: directory searched recursively for `*.fasta` files; the first
            period-separated field of each file stem is taken as a gene name.
        derive_target_col: when true, `experiment`, `target` and `cell_type`
            columns are first derived with `remap_df_add_experiment_target_cell`.

    Returns:
        The (possibly derived) dataframe without the rows of any target for
        which at least one parsed gene name has no fasta file.
    """
    if derive_target_col:
        df = remap_df_add_experiment_target_cell(df)

    # NOTE(review): this search is recursive while the FactorProteinDataset classes
    # only look at the top level of their folder - confirm the difference is intended
    present_target_names = {fasta_path.stem.split('.')[0] for fasta_path in Path(folder).glob('**/*.fasta')}

    all_df_targets = df.get_column('target').unique().to_list()

    # a target is unknown as soon as one of its parsed gene names lacks a fasta file;
    # `any` lists each such target once instead of once per missing gene name
    unknown_targets = [
        target
        for target in all_df_targets
        if any(gene not in present_target_names for gene in parse_gene_name(target))
    ]

    if unknown_targets:
        df = df.filter(pl_notin('target', unknown_targets))

    return df

def generate_random_ranges_from_fasta(
    fasta_file,
    *,
    output_filename = 'random-ranges.bed',
    context_length,
    filter_bed_files = (),
    num_entries_per_key = 10,
    keys = None,
):
    """Write a bed file of random, fixed-length ranges sampled from a fasta file.

    For every chromosome in `CHR_NAMES` that is present in the fasta file,
    `num_entries_per_key` ranges of length `context_length` are drawn uniformly
    at random and written as tab-separated `chromosome, start, end` lines.
    The ranges are then filtered through each file of `filter_bed_files` with
    `filter_bed_file_by_`, and the result is moved to `./{output_filename}`.
    Progress is printed to stdout.

    Args:
        fasta_file: path of the fasta file, opened with pyfaidx.
        output_filename: name of the resulting bed file, relative to the
            current working directory.
        context_length: length of every generated range.
        filter_bed_files: bed files used, one after the other, to filter the
            generated ranges.
        num_entries_per_key: number of ranges generated per chromosome.
        keys: currently unused, kept for interface compatibility.
    """
    import tempfile

    fasta = Fasta(fasta_file)

    # work in a private scratch directory rather than a fixed /tmp path: this is
    # portable, cannot collide with a concurrent run, and is cleaned up on failure
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_file = str(Path(tmp_dir) / 'random-ranges.bed')

        with open(tmp_file, 'w') as f:
            for chr_name in sorted(CHR_NAMES):
                print(f'generating ranges for {chr_name}')

                if chr_name not in fasta:
                    print(f'{chr_name} not found in fasta file')
                    continue

                chromosome_length = len(fasta[chr_name])
                max_start = chromosome_length - context_length

                # no window of the requested length fits, sampling would raise
                if max_start < 0:
                    print(f'{chr_name} is shorter than the context length, skipping')
                    continue

                # the upper bound of randint is exclusive, + 1 keeps the last valid window reachable
                starts = np.random.randint(0, max_start + 1, (num_entries_per_key,))

                for start in starts.tolist():
                    f.write('\t'.join((chr_name, str(start), str(start + context_length))) + '\n')

        # each filter rewrites the scratch file in place
        for file in filter_bed_files:
            filter_bed_file_by_(tmp_file, file, tmp_file)

        shutil.move(tmp_file, f'./{output_filename}')

    print('success')

# dataset for remap data - all peaks

class RemapAllPeakDataset(Dataset):
    """Dataset of positive (label 1) examples, one per row of a remap peak dataframe.

    The peaks come from either `bed_file` or `remap_df` (exactly one of the
    two). They are optionally subsampled, restricted to the chromosomes of
    `CHR_IDS` (further narrowed by `filter_chromosome_ids`), restricted to
    targets with a fasta file in `factor_fasta_folder`, and finally filtered by
    the include / exclude lists for targets and cell types. Remaining keyword
    arguments are forwarded to `FastaInterval`.
    """

    def __init__(
        self,
        *,
        factor_fasta_folder,
        bed_file = None,
        remap_df = None,
        filter_chromosome_ids = None,
        exclude_targets = None,
        include_targets = None,
        exclude_cell_types = None,
        include_cell_types = None,
        remap_df_frac = 1.,
        **kwargs
    ):
        super().__init__()
        assert exists(remap_df) ^ exists(bed_file), 'either remap bed file or remap dataframe must be passed in'

        peaks = remap_df if exists(remap_df) else read_bed(bed_file)

        # optional random subsample, taken before any filtering
        if remap_df_frac < 1:
            peaks = peaks.sample(frac = remap_df_frac)

        # supported chromosomes, optionally narrowed down by the caller
        if exists(filter_chromosome_ids):
            chr_ids = CHR_IDS.intersection(set(filter_chromosome_ids))
        else:
            chr_ids = CHR_IDS

        peaks = peaks.filter(pl_isin('column_1', get_chr_names(chr_ids)))
        peaks = filter_df_by_tfactor_fastas(peaks, factor_fasta_folder, derive_target_col = True)

        self.factor_ds = FactorProteinDataset(factor_fasta_folder)

        # filter by inclusion and exclusion lists, first for targets and then for cell types
        # (<all available values> intersect <include values>) subtract <exclude values>

        column_filters = (
            ('target', include_targets, exclude_targets),
            ('cell_type', include_cell_types, exclude_cell_types),
        )

        for column, included, excluded in column_filters:
            included = cast_list(included)
            excluded = cast_list(excluded)

            if included:
                peaks = peaks.filter(pl_isin(column, included))

            if excluded:
                peaks = peaks.filter(pl_notin(column, excluded))

        assert len(peaks) > 0, 'dataset is empty by filter criteria'

        self.df = peaks
        self.fasta = FastaInterval(**kwargs)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, ind):
        """Return `(seq, aa_seq, context_str, value, label)` for the peak at row `ind`.

        `value` wraps the eighth field of the row in a one-element tensor and
        `label` is always a one-element tensor holding 1.
        """
        # fields 3 to 5 are the derived cell_type / target / experiment columns, the
        # combined `<experiment>.<target>.<cell type>` string at field 6 is parsed instead
        chr_name, begin, end, _, _, _, combined_name, reading, *_ = self.df.row(ind)

        experiment, target, *cell_type_parts = combined_name.split('.')
        cell_type = '.'.join(cell_type_parts) # handle edge case where cell type contains periods

        seq = self.fasta(chr_name, begin, end)
        aa_seq = self.factor_ds[target]
        context_str = f'{cell_type} | {experiment}'

        value = torch.Tensor([reading])
        label = torch.Tensor([1.])

        return seq, aa_seq, context_str, value, label

class NegativePeakDataset(Dataset):
    """Dataset of negative (label 0) examples built from a bed file of ranges.

    Every item pairs one range of the negative dataframe with an experiment,
    target and cell type, each drawn independently at random from the values
    found in the remap dataframe (after restricting it to targets with a fasta
    file in `factor_fasta_folder` and applying the include / exclude lists).

    The remap peaks come from either `remap_bed_file` or `remap_df`, the
    negative ranges from either `negative_bed_file` or `negative_df` (exactly
    one of each pair). Negative ranges are restricted to the chromosomes of
    `CHR_IDS`, further narrowed by `filter_chromosome_ids`. Remaining keyword
    arguments are forwarded to `FastaInterval`.
    """

    def __init__(
        self,
        *,
        factor_fasta_folder,
        negative_bed_file = None,
        remap_bed_file = None,
        remap_df = None,
        negative_df = None,
        filter_chromosome_ids = None,
        exclude_targets = None,
        include_targets = None,
        exclude_cell_types = None,
        include_cell_types = None,
        **kwargs
    ):
        super().__init__()
        assert exists(remap_df) ^ exists(remap_bed_file), 'either remap bed file or remap dataframe must be passed in'
        assert exists(negative_df) ^ exists(negative_bed_file), 'either negative bed file or negative dataframe must be passed in'

        # instantiate dataframes if not passed in

        if not exists(remap_df):
            remap_df = read_bed(remap_bed_file)

        neg_df = negative_df
        if not exists(negative_df):
            neg_df = read_bed(negative_bed_file)

        # filter remap dataframe

        remap_df = filter_df_by_tfactor_fastas(remap_df, factor_fasta_folder, derive_target_col = True)

        # filter negative ranges by chromosome

        dataset_chr_ids = CHR_IDS

        if exists(filter_chromosome_ids):
            dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids))

        neg_df = neg_df.filter(pl_isin('column_1', get_chr_names(dataset_chr_ids)))

        assert len(neg_df) > 0, 'dataset is empty by filter criteria'

        self.neg_df = neg_df

        # pools that the random experiment / cell type / target of each item are drawn from

        self.experiments = remap_df['experiment'].unique().to_list()
        self.cell_types = remap_df['cell_type'].unique().to_list()
        self.targets = remap_df['target'].unique().to_list()

        # (<all available targets> intersect <include targets>) subtract <exclude targets>

        include_targets = cast_list(include_targets)
        exclude_targets = cast_list(exclude_targets)

        if include_targets:
            self.targets = list(set(self.targets).intersection(set(include_targets)))

        if exclude_targets:
            self.targets = list(set(self.targets) - set(exclude_targets))

        assert self.targets, 'targets cannot be empty for negative set'

        # same logic for cell types

        include_cell_types = cast_list(include_cell_types)
        exclude_cell_types = cast_list(exclude_cell_types)

        if include_cell_types:
            self.cell_types = list(set(self.cell_types).intersection(set(include_cell_types)))

        if exclude_cell_types:
            self.cell_types = list(set(self.cell_types) - set(exclude_cell_types))

        # fail here rather than with an IndexError from `choice` on the first item access
        assert self.cell_types, 'cell types cannot be empty for negative set'

        self.factor_ds = FactorProteinDataset(factor_fasta_folder)
        self.fasta = FastaInterval(**kwargs)

    def __len__(self):
        return len(self.neg_df)

    def __getitem__(self, ind):
        """Return `(seq, aa_seq, context_str, value, label)` for the negative range at row `ind`.

        `value` and `label` are both one-element tensors holding 0.
        """
        # only the first three fields are used, any further bed columns are ignored
        chr_name, begin, end, *_ = self.neg_df.row(ind)

        experiment = choice(self.experiments)
        target = choice(self.targets)
        cell_type = choice(self.cell_types)

        seq = self.fasta(chr_name, begin, end)
        aa_seq = self.factor_ds[target]
        context_str = f'{cell_type} | {experiment}'

        value = torch.Tensor([0.])
        label = torch.Tensor([0.])

        return seq, aa_seq, context_str, value, label

# dataloader related functions

def collate_fn(data):
    """Collate a list of `(seq, aa_seq, context_str, value, label)` items into one batch.

    Sequences and values are stacked along a new leading dimension, labels are
    concatenated along their existing first dimension, and the amino acid
    sequences and context strings are returned as tuples.
    """
    seqs, aa_seqs, contexts, values, labels = zip(*data)

    batched_seqs = torch.stack(seqs)
    batched_values = torch.stack(values, dim = 0)
    batched_labels = torch.cat(labels, dim = 0)

    return batched_seqs, tuple(aa_seqs), tuple(contexts), batched_values, batched_labels

def collate_dl_outputs(*dl_outputs):
    """Merge the batches of several dataloaders into a single batch.

    Every argument is one batch, a sequence of fields. Fields at the same
    position are combined across batches: tensors are concatenated along the
    first dimension, anything else is flattened by one level into a tuple.

    Returns:
        A tuple with one combined entry per field, empty when no batches are given.
    """
    merged = []

    for field in zip(*dl_outputs):
        if isinstance(field[0], torch.Tensor):
            combined = torch.cat(field, dim = 0)
        else:
            # materialize as a tuple, matching what `collate_fn` produces; a lazy
            # generator could only be consumed once and has no length
            combined = tuple(sub_el for el in field for sub_el in el)

        merged.append(combined)

    return tuple(merged)

def cycle(loader):
    """Yield the items of `loader` endlessly, iterating it anew each time it is exhausted."""
    while True:
        yield from loader

def get_dataloader(ds, cycle_iter = False, **kwargs):
    """Return an iterator over batches of `ds`, collated with `collate_fn`.

    Extra keyword arguments are forwarded to `DataLoader`. With `cycle_iter`
    the iterator restarts the dataloader forever, otherwise it is a plain
    single-pass iterator.
    """
    loader = DataLoader(ds, collate_fn = collate_fn, **kwargs)

    if cycle_iter:
        return cycle(loader)

    return iter(loader)
