import os
import re
from typing import Dict, List, Optional, Tuple

import pandas as pd
import spacy
from PyPDF2 import PdfReader
from spacy.tokens import Doc, Token


def process_tokens(
        doc: str, nlp: spacy.language.Language, stop_words: List[str]
) -> List[Token]:
    """Run a text through the spaCy pipeline and return its filtered tokens.

    The text is parsed with ``nlp`` and the resulting tokens are filtered by
    ``get_filtered_tokens``, so both entry points apply exactly the same
    rules (stop words, punctuation, custom stop lemmas, non-alphabetic tokens).

    Args:
        doc (str): raw text handed to ``nlp``.
        nlp (spacy.language.Language): callable spaCy pipeline used to parse ``doc``.
        stop_words (List[str]): additional lower-case lemmas to discard.

    Returns:
        List[Token]: the surviving tokens (token objects, not strings), in document order.
    """
    return get_filtered_tokens(nlp(doc), stop_words)

def get_filtered_tokens(spacy_text: Doc, stop_words: List[str]) -> List[Token]:
    """Return the tokens of an already parsed document that pass all filters.

    A token is kept only when it is alphabetic (``is_alpha``), is not flagged
    as a stop word (``is_stop``) or punctuation (``is_punct``), and its
    lower-cased lemma is not listed in ``stop_words``.

    Args:
        spacy_text (Doc): parsed document whose tokens are filtered.
        stop_words (List[str]): additional lower-case lemmas to discard.

    Returns:
        List[Token]: the surviving tokens, in document order.
    """
    # Build the lookup once: a set answers membership in O(1), whereas the
    # list would be rescanned for every single token.
    extra_stop_words = frozenset(stop_words)
    return [
        token
        for token in spacy_text
        # Cheap flag checks come first so the lemma is only inspected for
        # tokens that survived them.
        if token.is_alpha
        and not token.is_stop
        and not token.is_punct
        and token.lemma_.lower() not in extra_stop_words
    ]


def process_lemmas(doc: pd.Series) -> List[str]:
    """Return the lower-cased lemma of every token in ``doc``, in order."""
    lemmas = (tok.lemma_ for tok in doc)
    return [lemma.lower() for lemma in lemmas]


def _multiply_ngrams(tokens: List[str]):
    """Generator that repeats n-grams: a token containing a space is yielded three times, any other token once."""
    for item in tokens:
        copies = 3 if " " in item else 1
        for _ in range(copies):
            yield item


def get_table_of_contents(path: str, toc: str = "Table of Contents") -> Tuple[str, int]:
    """Find the first page of a pdf file whose text contains the ``toc`` marker.

    Pages are scanned from the beginning of the document until one of them
    contains ``toc``.

    Args:
        path (str): path of the pdf file.
        toc (str): marker text identifying the table of contents page.

    Returns:
        Tuple[str, int]: extracted text of the matching page and its 1-based
        page number, or ``("", -1)`` when no page could be matched.
    """
    # The context manager closes the file on every exit path, including the
    # early "not found" return below.
    with open(path, "rb") as file:
        reader = PdfReader(file)
        text = ""
        pages_read = 0
        while toc not in text:
            try:
                page = reader.pages[pages_read]
            except Exception:
                # Best effort: running past the last page (or failing to load
                # a page) means the marker was not found.
                return "", -1
            text = page.extract_text() or ""
            pages_read += 1
    return text, pages_read


def get_paragraphs_df(
        toc: str, pages_shift: int, paragraphs_names: Dict[str, List[str]], end_paragraph: str
) -> pd.DataFrame:
    """Based on the table of contents, returns a dataframe with the paragraphs and the pages they start and end on.

    Every candidate name is looked up in the table of contents with all spaces
    removed on both sides. The digits that follow the name on the first
    matching line are read as its page number. Each located paragraph closes
    the previously recorded one: its start page and name become the previous
    row's ``end_page`` and ``end_text``.

    Args:
        toc (str): table of contents of the pdf file.
        pages_shift (int): number of pages to shift the page numbers.
        paragraphs_names (Dict[str, List[str]]): dictionary with the paragraphs names.
        end_paragraph (str): name of the last paragraph.

    Returns:
        pd.DataFrame: dataframe with the columns ``paragraph``, ``start_page``,
        ``end_page``, ``start_text`` and ``end_text``. A last row that was never
        closed gets ``end_page`` 999 and ``end_text`` None.

    """
    # Strip the spaces once instead of once per candidate name.
    compact_lines = [line.replace(" ", "") for line in toc.split("\n")]
    rows = {"paragraph": [], "start_page": [], "end_page": [], "start_text": [], "end_text": []}
    for key, paragraphs in paragraphs_names.items():
        for paragraph in paragraphs:
            compact_name = paragraph.replace(" ", "")
            matching_line = next(
                (line for line in compact_lines if compact_name in line), None
            )
            if matching_line is None:
                continue
            # The line has no spaces, so the name must be located in its
            # space-free form too; searching for the spaced name would miss
            # and shift the slice to the wrong offset.
            tail = matching_line[matching_line.find(compact_name) + len(compact_name):]
            try:
                start_page = int(re.sub(r"[^0-9]+", "", tail)) + pages_shift
            except ValueError:
                # No digits after the name: this line carries no page number.
                continue
            if rows["start_page"]:
                rows["end_page"].append(start_page)
                rows["end_text"].append(paragraph if start_page != 999 else None)
            if key == end_paragraph:
                # NOTE(review): this only leaves the inner loop; keys listed
                # after ``end_paragraph`` are still processed - confirm intent.
                break
            rows["paragraph"].append(key)
            rows["start_page"].append(start_page)
            rows["start_text"].append(paragraph)
    if len({len(column) for column in rows.values()}) != 1:
        # The last recorded paragraph was never closed: give it an open end.
        rows["end_page"].append(999)
        rows["end_text"].append(None)
    return pd.DataFrame(rows)


def read_pages_from_pdf(path: str, start_page: int, end_page: int) -> str:
    """Reads the text from a pdf file from start_page to end_page.

    Args:
        path (str): path of the pdf file.
        start_page (int): 1-based number of the first page to read.
        end_page (int): 1-based number of the last page to read (inclusive).
            It may exceed the document length; reading stops at the last page.

    Returns:
        str: text of the pages joined together, with newlines removed.
    """
    chunks = []
    # The context manager guarantees the file handle is released.
    with open(path, "rb") as file:
        reader = PdfReader(file)
        # Clamp to the first page: a negative index would otherwise wrap
        # around and read from the end of the document.
        page_index = max(start_page - 1, 0)
        while page_index < end_page:
            try:
                page_text = reader.pages[page_index].extract_text()
            except IndexError:
                # Ran past the last page of the document.
                break
            chunks.append(page_text.replace("\n", ""))
            page_index += 1
    return "".join(chunks)


def read_paragraphs(
        df: pd.DataFrame,
        id_column: str,
        path: str,
        id: str,
        root: str = ""
) -> pd.DataFrame:
    """Reads paragraphs from a pdf file and saves them as txt files.

    For every row of ``df`` the pages between ``start_page`` and ``end_page``
    are read, trimmed to the span from ``start_text`` up to (excluding)
    ``end_text``, cleaned with ``text_cleaning`` and written to
    ``{root}{id}_{paragraph name}.txt``.

    Args:
        df (pd.DataFrame): rows with the columns ``paragraph``, ``start_page``,
            ``end_page``, ``start_text`` and ``end_text``.
        id_column (str): name of the id column in the returned dataframe.
        path (str): path of the pdf file.
        id (str): document id, used in the file names and in ``id_column``.
        root (str): prefix put in front of every file name as is (no path
            separator is added).

    Returns:
        pd.DataFrame: one row per paragraph with the columns ``paragraph``,
        ``id_column`` and ``text_path``.
    """
    # One pass over the name: drop ":,()&" and turn " ", "/" and "-" into "_".
    name_table = str.maketrans({
        ":": "", ",": "", "(": "", ")": "", "&": "",
        " ": "_", "/": "_", "-": "_",
    })
    result_dict = {"paragraph": [], id_column: [], "text_path": []}
    for _, row in df.iterrows():
        file_name = row.paragraph.translate(name_table).replace("__", "_").lower()
        txt_destination = f"{root}{id}_{file_name}.txt"
        if row.start_page is None:
            text = ""
        else:
            text = ""
            # Try the listed start page first, then its neighbours, until
            # some text could be read.
            for offset in (0, -1, 1, 2):
                text = read_pages_from_pdf(path, row.start_page + offset, row.end_page)
                if text != "":
                    break
            start_text = row.start_text
            if isinstance(start_text, str) and start_text:
                # Cut everything before the paragraph title; leave the text
                # untouched when the title does not occur in it.
                _, found, remainder = text.partition(start_text)
                if found:
                    text = start_text + remainder
            if row.end_text is not None:
                text = text.split(row.end_text, 1)[0]
        text = text_cleaning(text)
        with open(txt_destination, "w", encoding="utf-8") as text_file:
            text_file.write(text)
        result_dict["paragraph"].append(row.paragraph)
        result_dict[id_column].append(id)
        result_dict["text_path"].append(txt_destination)
    return pd.DataFrame(result_dict)


def process_all_documents(
        directory_path: str,
        id_column: str,
        paragraphs_names: Dict[str, List[str]],
        save_txt: str,
        end_paragraph: str,
        toc_str: str = "Table of Contents",
        pages_shift: Optional[int] = None,
) -> pd.DataFrame:
    """Process documents from directory_path with
    table of contents with paragraph names and pages

    Args:
        directory_path (str): directory with documents to process
        id_column (str): name of the id column
        paragraphs_names (Dict[str, List[str]]): key - name of paragraph that should
        be displayed in the final df, value - list of possible names of this paragraph in toc
        save_txt (str): path to directory where txt files should be saved
        end_paragraph (str): last paragraph of the text that should not be present in final df
        toc_str (str, optional): name of table of contents in documents. Defaults to "Table of Contents".
        pages_shift (int, optional): difference between page number in table of contents and in pdf file.
            Defaults to None which will be interpreted as pages_shift = page of toc

    Returns:
        pd.DataFrame: data frame with the columns ``paragraph``, ``id_column`` and
        ``text_path``, one row per extracted paragraph of every pdf file
    """
    pdf_names = [name for name in os.listdir(directory_path) if name.endswith(".pdf")]
    frames = []
    for doc in pdf_names:
        # Build the path once so both readers open the same file, whether or
        # not directory_path ends with a separator.
        pdf_path = os.path.join(directory_path, doc)
        toc, toc_page = get_table_of_contents(pdf_path, toc_str)
        # Only fall back to the toc page when no shift was given; an explicit
        # shift of 0 must be honoured.
        shift = toc_page if pages_shift is None else pages_shift
        paragraphs_df = get_paragraphs_df(toc, shift, paragraphs_names, end_paragraph)
        doc_id = os.path.splitext(doc)[0]
        frames.append(read_paragraphs(paragraphs_df, id_column, pdf_path, doc_id, save_txt))
    if not frames:
        return pd.DataFrame({"paragraph": [], id_column: [], "text_path": []})
    # Concatenate once instead of re-copying the growing frame per document.
    return pd.concat(frames, ignore_index=True)


def text_cleaning(text):
    """Strip URLs, numbered headlines, stray numbers, picture captions and tables from extracted text.

    The steps run in a fixed order; each one works on the output of the
    previous one. The cleaned string is returned.
    """
    # URLs: scheme, "//", host and any number of path segments.
    cleaned = re.sub(r'\w+:\/{2}[\d\w-]+(\.[\d\w-]+)*(?:(?:\/[^\s/]*))*', '', text, flags=re.MULTILINE)
    for pattern in (
        r'((\d\.)+\d) +[A-Z]([a-z]|\s|,)+',  # headlines: dotted number followed by capitalised text
        r'((\d)+ +)+\(\d+\)',                # random numbers: digit groups ending in "(digits)"
        r' \d+ [A-Z](\w|\s|,)+.',            # picture descriptions: number followed by capitalised text
    ):
        cleaned = re.sub(pattern, '', cleaned)

    # Tables: blank every sentence that opens with whitespace + "Table <n>",
    # together with the sentence right after it.
    parts = cleaned.split('. ')
    opens_table = [re.match(r'\s+Table \d+', part) is not None for part in parts]
    kept = [
        '' if opens_table[idx] or (idx > 0 and opens_table[idx - 1]) else part
        for idx, part in enumerate(parts)
    ]
    cleaned = '. '.join(kept)

    # Collapse whitespace runs into one space.
    cleaned = re.sub(r'\s{2,}', ' ', cleaned)
    # Remove any single character standing between two spaces, spaces included
    # (the dot is a wildcard here, so this is not limited to full stops).
    return re.sub(r' . ', '', cleaned)
