import pandas as pd
import json
from tqdm import tqdm
from os import path
import random
from nltk.tokenize import word_tokenize
import numpy as np
import csv
import swifter
import string


def parse_behaviors(source, target, user2int_path, negative_sampling_ratio):
    """
    Parse behaviors file in training set.

    Users are mapped to integer ids (starting at 1, in order of first
    appearance) and every impression is balanced into training samples: each
    clicked candidate (suffix "1") is paired with ``negative_sampling_ratio``
    randomly shuffled non-clicked candidates (suffix "0") of the same
    impression. Clicked candidates for which not enough unused non-clicked
    candidates remain are dropped.

    Args:
        source (str): Source behaviors file
        target (str): Target behaviors file
        user2int_path (str): Path for saving user2int file
        negative_sampling_ratio (int): Number of non-clicked candidates paired
            with each clicked candidate

    Returns:
        Number of users (size of user2int plus one for the unused id 0)
    """
    print(f"Parse {source}")

    behaviors = pd.read_table(
        source,
        header=None,
        names=["impression_id", "user", "time", "clicked_news", "impressions"],
    )
    # Assign back instead of a chained `column.fillna(..., inplace=True)`,
    # which does not update the frame under pandas Copy-on-Write.
    behaviors["clicked_news"] = behaviors["clicked_news"].fillna(" ")
    behaviors["impressions"] = behaviors["impressions"].str.split()

    user2int = {}
    for user in behaviors["user"]:
        # setdefault only inserts (with the next free id) on first sight.
        user2int.setdefault(user, len(user2int) + 1)

    pd.DataFrame(user2int.items(), columns=["user", "int"]).to_csv(
        user2int_path, sep="\t", index=False
    )
    num_users = 1 + len(user2int)

    behaviors["user"] = behaviors["user"].map(user2int)

    ratio = max(negative_sampling_ratio, 0)

    def sample_pairs(impressions):
        """Pair each clicked candidate with `ratio` shuffled non-clicked ones."""
        if not isinstance(impressions, list):  # missing impressions column
            return []
        positives = [x for x in impressions if x.endswith("1")]
        negatives = [x for x in impressions if x.endswith("0")]
        random.shuffle(negatives)
        if ratio == 0:
            num_pairs = len(positives)
        else:
            # Each pair consumes `ratio` fresh negatives; incomplete pairs are dropped.
            num_pairs = min(len(positives), len(negatives) // ratio)
        return [
            [positives[i], *negatives[i * ratio : (i + 1) * ratio]]
            for i in range(num_pairs)
        ]

    behaviors["impressions"] = pd.Series(
        [
            sample_pairs(impressions)
            for impressions in tqdm(behaviors["impressions"], desc="Balancing data")
        ],
        index=behaviors.index,
        dtype=object,
    )

    # One row per (positive, negatives...) pair; impressions without pairs vanish.
    behaviors = (
        behaviors.explode("impressions")
        .dropna(subset=["impressions"])
        .reset_index(drop=True)
    )
    # Built column by column so that an empty result does not break the assignment.
    behaviors["candidate_news"] = behaviors["impressions"].map(
        lambda pair: " ".join(e.split("-")[0] for e in pair)
    )
    behaviors["clicked"] = behaviors["impressions"].map(
        lambda pair: " ".join(e.split("-")[1] for e in pair)
    )
    behaviors.to_csv(
        target,
        sep="\t",
        index=False,
        columns=["user", "clicked_news", "candidate_news", "clicked"],
    )

    return num_users


def parse_news(
    source,
    target,
    category2int_path,
    word2int_path,
    entity2int_path,
    mode,
    num_words_title,
    num_words_abstract,
    entity_confidence_threshold,
    word_freq_threshold,
    entity_freq_threshold,
):
    """
    Parse news for training set and test set.

    Each news item becomes fixed-length id sequences. The title and abstract
    are lower-cased, tokenized, truncated/zero-padded to ``num_words_title`` /
    ``num_words_abstract`` and mapped through word2int. In parallel, every
    in-vocabulary token that occurs in a surface form of a known entity is
    mapped through entity2int. Id 0 stands for padding and for unknown
    categories, words and entities.

    Args:
        source (str): Source news file
        target  (str): Target news file
        category2int_path (str): Path to category2int file. If mode == 'train': Path to save. If mode == 'test': Path to load from.
        word2int_path (str): Path to word2int file. If mode == 'train': Path to save. If mode == 'test': Path to load from.
        entity2int_path (str): Path to entity2int file. If mode == 'train': Path to save. If mode == 'test': Path to load from.
        mode (str): Either 'train' or 'test'
        num_words_title (int): number of words in title
        num_words_abstract (int): number of words in abstract
        entity_confidence_threshold (float): Entities whose "Confidence" is not
            strictly greater than this value are ignored when tagging tokens
        word_freq_threshold (float): Minimum count for a word to enter
            word2int (train mode only)
        entity_freq_threshold (float): Minimum confidence-weighted occurrence
            count for an entity to enter entity2int (train mode only)

    Returns:
        In 'train' mode, the tuple (num_categories, num_words, num_entities),
        each the vocabulary size plus one for id 0. In 'test' mode, None.

    Raises:
        ValueError: If mode is neither 'train' nor 'test'.
    """
    if mode not in ("train", "test"):
        raise ValueError(f"Wrong mode: {mode!r}, expected 'train' or 'test'")

    print(f"Parse {source}")
    news = pd.read_table(
        source,
        header=None,
        usecols=[0, 1, 2, 3, 4, 6, 7],
        quoting=csv.QUOTE_NONE,
        names=[
            "id",
            "category",
            "subcategory",
            "title",
            "abstract",
            "title_entities",
            "abstract_entities",
        ],
    )  # TODO try to avoid csv.QUOTE_NONE
    # Missing entity lists become empty JSON arrays, any other gap a blank.
    # Assigning back (rather than a chained fillna(inplace=True)) keeps this
    # effective under pandas Copy-on-Write.
    news["title_entities"] = news["title_entities"].fillna("[]")
    news["abstract_entities"] = news["abstract_entities"].fillna("[]")
    news = news.fillna(" ")

    # The helpers below read category2int / word2int / entity2int, which are
    # bound further down in the branch for the current mode.

    def build_local_entity_map(row):
        """Map each lower-cased surface-form word of the row to its entity id."""
        local_entity_map = {}
        # Title entities first, so abstract entities win on shared words.
        for entities in (row.title_entities, row.abstract_entities):
            for e in json.loads(entities):
                if (
                    e["Confidence"] > entity_confidence_threshold
                    and e["WikidataId"] in entity2int
                ):
                    entity_id = entity2int[e["WikidataId"]]
                    for x in " ".join(e["SurfaceForms"]).lower().split():
                        local_entity_map[x] = entity_id
        return local_entity_map

    def encode_text(text, length, local_entity_map):
        """Return (word ids, entity ids) for the first `length` tokens of text."""
        word_ids = [0] * length
        entity_ids = [0] * length
        # Tokens beyond `length` cannot be stored and are dropped.
        for i, w in enumerate(word_tokenize(text.lower())[: max(length, 0)]):
            # Entity ids are only recorded for in-vocabulary words.
            if w in word2int:
                word_ids[i] = word2int[w]
                if w in local_entity_map:
                    entity_ids[i] = local_entity_map[w]
        return word_ids, entity_ids

    def parse_row(row):
        """Convert one news row into its id representation."""
        local_entity_map = build_local_entity_map(row)
        title, title_entities = encode_text(
            row.title, num_words_title, local_entity_map
        )
        abstract, abstract_entities = encode_text(
            row.abstract, num_words_abstract, local_entity_map
        )
        return pd.Series(
            [
                row.id,
                category2int.get(row.category, 0),
                category2int.get(row.subcategory, 0),
                title,
                abstract,
                title_entities,
                abstract_entities,
            ],
            index=[
                "id",
                "category",
                "subcategory",
                "title",
                "abstract",
                "title_entities",
                "abstract_entities",
            ],
        )

    if mode == "train":
        category2int = {}
        word2freq = {}
        entity2freq = {}

        def count_entities(entities):
            """Accumulate occurrence counts weighted by the entity's Confidence."""
            for e in json.loads(entities):
                times = len(e["OccurrenceOffsets"]) * e["Confidence"]
                if times > 0:
                    wikidata_id = e["WikidataId"]
                    entity2freq[wikidata_id] = entity2freq.get(wikidata_id, 0) + times

        for row in news.itertuples(index=False):
            # Categories and subcategories share one id space.
            for category in (row.category, row.subcategory):
                if category not in category2int:
                    category2int[category] = len(category2int) + 1

            for text in (row.title, row.abstract):
                for w in word_tokenize(text.lower()):
                    word2freq[w] = word2freq.get(w, 0) + 1

            count_entities(row.title_entities)
            count_entities(row.abstract_entities)

        # Ids start at 1 and follow first-appearance order; 0 is padding/unknown.
        word2int = {}
        for k, v in word2freq.items():
            if v >= word_freq_threshold:
                word2int[k] = len(word2int) + 1

        entity2int = {}
        for k, v in entity2freq.items():
            if v >= entity_freq_threshold:
                entity2int[k] = len(entity2int) + 1

        parsed_news = news.swifter.apply(parse_row, axis=1)
        parsed_news.to_csv(target, sep="\t", index=False)

        pd.DataFrame(category2int.items(), columns=["category", "int"]).to_csv(
            category2int_path, sep="\t", index=False
        )
        num_categories = 1 + len(category2int)

        pd.DataFrame(word2int.items(), columns=["word", "int"]).to_csv(
            word2int_path, sep="\t", index=False
        )
        num_words = 1 + len(word2int)

        pd.DataFrame(entity2int.items(), columns=["entity", "int"]).to_csv(
            entity2int_path, sep="\t", index=False
        )
        num_entities = 1 + len(entity2int)

        return num_categories, num_words, num_entities

    # mode == "test" (validated at the top): reuse the training vocabularies.
    category2int = dict(pd.read_table(category2int_path).values.tolist())
    # na_filter=False is needed since nan is also a valid word
    word2int = dict(pd.read_table(word2int_path, na_filter=False).values.tolist())
    entity2int = dict(pd.read_table(entity2int_path).values.tolist())

    parsed_news = news.swifter.apply(parse_row, axis=1)
    parsed_news.to_csv(target, sep="\t", index=False)
    return None


def generate_word_embedding(source, target, word2int_path, word_embedding_dim):
    """Generate the word embedding matrix from a pretrained word embedding file.

    Row ``i`` of the saved matrix holds the vector of the word whose id in
    word2int is ``i``. Row 0 (padding) and every word that is not in the
    pretrained file are initialised by sampling N(0, 1).

    Args:
        source (str): Path of pretrained word embedding file, e.g. glove.840B.300d.txt
            (space separated: a word followed by its vector values)
        target (str): Path for saving word embedding (numpy format)
        word2int_path (str): Path to vocabulary file (tab separated, columns word/int)
        word_embedding_dim (int): Number of values per pretrained vector
    """
    # word -> int. na_filter=False is needed since nan (like null, NA, ...)
    # is also a valid word and must not be read as a missing value.
    word2int = pd.read_table(word2int_path, na_filter=False, index_col="word")
    # word -> vector. The same na_filter=False is required here; otherwise such
    # words turn into NaN in the index and never match the vocabulary.
    source_embedding = pd.read_table(
        source,
        index_col=0,
        sep=" ",
        header=None,
        quoting=csv.QUOTE_NONE,
        names=range(word_embedding_dim),
        na_filter=False,
    )
    source_embedding.index.name = "word"
    # Keep only the first vector of a repeated word so every id gets one row.
    source_embedding = source_embedding[
        ~source_embedding.index.duplicated(keep="first")
    ]

    # int -> vector, for vocabulary words present in the pretrained file
    merged = word2int.merge(
        source_embedding, how="inner", left_index=True, right_index=True
    )
    merged = merged.set_index("int")

    # Ids (including padding id 0) that received no pretrained vector
    missed_index = np.setdiff1d(np.arange(len(word2int) + 1), merged.index.values)
    missed_embedding = pd.DataFrame(
        data=np.random.normal(size=(len(missed_index), word_embedding_dim)),
        index=pd.Index(missed_index, name="int"),
    )

    final_embedding = pd.concat([merged, missed_embedding]).sort_index()
    np.save(target, final_embedding.values)

    # Id 0 is padding rather than a word, hence the -1; max() avoids a
    # division by zero for an empty vocabulary.
    print(
        f"Rate of word missed in pretrained embedding: {(len(missed_index)-1)/max(len(word2int), 1):.4f}"
    )


def transform_entity_embedding(source, target, entity2int_path, entity_embedding_dim):
    """Transform entity embedding into a numpy matrix indexed by entity id.

    Row ``i`` of the saved matrix holds the vector of the entity whose id in
    entity2int is ``i``. Row 0 (padding) and entities absent from the
    embedding file are sampled from N(0, 1).

    Args:
        source (str): Path of embedding file (tab separated: entity id followed
            by its vector values)
        target (str): Path of transformed embedding file in numpy format
        entity2int_path (str): Path to entity ids file (tab separated, columns entity/int)
        entity_embedding_dim (int): Number of values per entity vector

    Raises:
        ValueError: If the embedding file has fewer than
            ``entity_embedding_dim`` value columns.
    """
    raw = pd.read_table(source, header=None)
    # Column 0 is the entity id; take exactly `entity_embedding_dim` values
    # after it. Any further columns (presumably empty fields from a trailing
    # tab) are ignored.
    values = raw.iloc[:, 1 : entity_embedding_dim + 1]
    if values.shape[1] != entity_embedding_dim:
        raise ValueError(
            f"{source} has {values.shape[1]} value columns, "
            f"expected {entity_embedding_dim}"
        )
    raw["vector"] = values.values.tolist()
    entity_embedding = raw[[0, "vector"]].rename(columns={0: "entity"})

    entity2int = pd.read_table(entity2int_path)
    merged_df = pd.merge(entity_embedding, entity2int, on="entity")
    entity_embedding_transformed = np.random.normal(
        size=(len(entity2int) + 1, entity_embedding_dim)
    )
    for row in merged_df.itertuples(index=False):
        entity_embedding_transformed[row.int] = row.vector
    np.save(target, entity_embedding_transformed)


def parse_mind(
    train_dir,
    val_dir,
    test_dir,
    glove_dir,
    glove_size,
    negative_sampling_ratio,
    num_words_title,
    num_words_abstract,
    entity_confidence_threshold,
    word_freq_threshold,
    entity_freq_threshold,
    word_embedding_dim,
    entity_embedding_dim,
):
    """Parse MIND dataset.

    The training split produces the parsed behaviors, the parsed news, the
    vocabularies (user2int, category2int, word2int, entity2int) and the
    pretrained word/entity embedding matrices. The validation split, and the
    test split when ``test_dir`` is given, only have their news parsed, using
    the vocabularies stored in ``train_dir``.

    Args:
        train_dir (str): Path to train directory
        val_dir (str): Path to validation directory
        test_dir (str or None): Path to test directory; skipped when None
        glove_dir (str): Path to glove directory
        glove_size: Glove size, used in the file name glove.<size>B.<dim>d.txt
        negative_sampling_ratio (int): Non-clicked candidates per clicked one
        num_words_title (int): number of words in title
        num_words_abstract (int): number of words in abstract
        entity_confidence_threshold (float): Minimum (exclusive) entity confidence
        word_freq_threshold (float): Threshold for word frequency
        entity_freq_threshold (float): Threshold for entity frequency
        word_embedding_dim (int): Word embedding dimension
        entity_embedding_dim (int): Entity embedding dimension

    Returns:
        Tuple (num_users, num_categories, num_words, num_entities)
    """
    # Vocabularies are always built in, and read back from, the training split.
    category2int_file = path.join(train_dir, "category2int.tsv")
    word2int_file = path.join(train_dir, "word2int.tsv")
    entity2int_file = path.join(train_dir, "entity2int.tsv")
    # Trailing positional arguments shared by every parse_news call.
    news_options = (
        num_words_title,
        num_words_abstract,
        entity_confidence_threshold,
        word_freq_threshold,
        entity_freq_threshold,
    )

    # Training split
    num_users = parse_behaviors(
        path.join(train_dir, "behaviors.tsv"),
        path.join(train_dir, "behaviors_parsed.tsv"),
        path.join(train_dir, "user2int.tsv"),
        negative_sampling_ratio,
    )
    num_categories, num_words, num_entities = parse_news(
        path.join(train_dir, "news.tsv"),
        path.join(train_dir, "news_parsed.tsv"),
        category2int_file,
        word2int_file,
        entity2int_file,
        "train",
        *news_options,
    )
    glove_file = f"glove.{glove_size}B.{word_embedding_dim}d.txt"
    generate_word_embedding(
        path.join(glove_dir, glove_file),
        path.join(train_dir, "pretrained_word_embedding.npy"),
        word2int_file,
        word_embedding_dim,
    )
    transform_entity_embedding(
        path.join(train_dir, "entity_embedding.vec"),
        path.join(train_dir, "pretrained_entity_embedding.npy"),
        entity2int_file,
        entity_embedding_dim,
    )

    # Validation split (always) followed by the optional test split
    eval_dirs = [val_dir] if test_dir is None else [val_dir, test_dir]
    for split_dir in eval_dirs:
        parse_news(
            path.join(split_dir, "news.tsv"),
            path.join(split_dir, "news_parsed.tsv"),
            category2int_file,
            word2int_file,
            entity2int_file,
            "test",
            *news_options,
        )

    print(
        f"Set num_words = {num_words} and num_categories = {num_categories} in model config"
    )

    return num_users, num_categories, num_words, num_entities

def parse_news_bert(source, target, column):
    """
    Parse news for using BERT baseline.

    Keep one text column per news item, strip ASCII punctuation from it and
    save the result as a tab-separated file.

    Args:
        source (str): source news file
        target (str): target news file
        column (str): name of the news column whose text represents the news
            (e.g. "title" or "abstract")

    Returns:
        tuple: (news_parsed, news_ids) where news_parsed is a DataFrame with
        columns (news_id, text) and news_ids is the set of its news ids.
    """
    print(f"Parse {source}")
    news = pd.read_table(
        source,
        header=None,
        usecols=[0, 1, 2, 3, 4, 6, 7],
        # The fields are raw text: a stray '"' must not start a quoted field
        # that swallows following columns/rows (same setting as parse_news).
        # Quotes are removed by the punctuation filter below anyway.
        quoting=csv.QUOTE_NONE,
        names=[
            "id",
            "category",
            "subcategory",
            "title",
            "abstract",
            "title_entities",
            "abstract_entities",
        ],
    )
    news = news.fillna(" ")

    # Translation table deleting every ASCII punctuation character, built once.
    remove_punctuation = str.maketrans("", "", string.punctuation)

    parsed_news = pd.DataFrame()
    parsed_news["news_id"] = news["id"]
    parsed_news["text"] = news[column].map(
        lambda text: text.translate(remove_punctuation)
    )

    parsed_news.to_csv(target, sep="\t", index=False)
    return parsed_news, set(parsed_news.news_id)

def parse_behaviors_bert(source, target, news_ids_set):
    """
    Parse behaviors for using BERT baseline.

    For each impression, collect the user's click history (news ids) and the
    candidate news ids together with their click labels.

    Args:
        source (str): source behaviors file
        target (str): target behaviors file (tab separated)
        news_ids_set (set): ids of the parsed news. Currently unused: history
            and candidates are kept even when absent from the news file.

    Returns:
        DataFrame: behaviors_parsed(id, history:<NEWS_IDS>,
        candidate_news:<NEWS_IDS>, labels:<y_true>). A label is None for an
        impression entry without a "-<label>" suffix.
    """
    print(f"Parse {source}")
    behaviors = pd.read_table(
        source,
        header=None,
        names=["impression_id", "user", "time", "clicked_news", "impressions"],
    )
    # Not a chained fillna(inplace=True): that is a no-op under Copy-on-Write.
    clicked_news = behaviors["clicked_news"].fillna(" ")
    # Missing impressions become an empty candidate list instead of NaN.
    impressions = behaviors["impressions"].fillna("").str.split()

    def label_of(impression):
        """Return the label after "-", or None for an unlabeled entry."""
        parts = impression.split("-")
        return parts[1] if len(parts) > 1 else None

    parsed_behaviors = pd.DataFrame()
    parsed_behaviors["id"] = behaviors["impression_id"]
    parsed_behaviors["history"] = clicked_news.str.split()
    parsed_behaviors["candidate_news"] = impressions.map(
        lambda entries: [entry.split("-")[0] for entry in entries]
    )
    parsed_behaviors["labels"] = impressions.map(
        lambda entries: [label_of(entry) for entry in entries]
    )

    parsed_behaviors.to_csv(target, sep="\t", index=False)

    return parsed_behaviors

def parse_mind_bert(
    train_dir,
    val_dir,
    test_dir,
    column,
):
    """Parse the MIND dataset for the BERT baseline.

    For every processed split, news_BERT_parsed.tsv and
    behaviors_BERT_parsed.tsv are written into that split's directory.

    Args:
        train_dir (str): Path to train directory
        val_dir (str): Path to validation directory
        test_dir (str or None): Path to test directory; skipped when None
        column (str): news column whose text represents each news item

    Returns:
        tuple: (train_news, val_news, test_news) parsed news DataFrames;
        test_news is val_news when test_dir is None.
    """
    # Process data for training
    train_news, news_ids_set_train = parse_news_bert(
        path.join(train_dir, "news.tsv"),
        path.join(train_dir, "news_BERT_parsed.tsv"),
        column,
    )
    parse_behaviors_bert(
        path.join(train_dir, "behaviors.tsv"),
        path.join(train_dir, "behaviors_BERT_parsed.tsv"),
        news_ids_set_train,
    )

    # Process data for validation
    val_news, news_ids_set_val = parse_news_bert(
        path.join(val_dir, "news.tsv"),
        path.join(val_dir, "news_BERT_parsed.tsv"),
        column,
    )
    parse_behaviors_bert(
        path.join(val_dir, "behaviors.tsv"),
        path.join(val_dir, "behaviors_BERT_parsed.tsv"),
        news_ids_set_val,
    )

    # Without a test split, the validation news stand in for the test news.
    test_news = val_news
    if test_dir is not None:
        test_news, news_ids_set_test = parse_news_bert(
            path.join(test_dir, "news.tsv"),
            path.join(test_dir, "news_BERT_parsed.tsv"),
            column,
        )
        parse_behaviors_bert(
            path.join(test_dir, "behaviors.tsv"),
            path.join(test_dir, "behaviors_BERT_parsed.tsv"),
            news_ids_set_test,
        )
    return train_news, val_news, test_news

