from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List

from unstructured.documents.elements import Element, NarrativeText


class Document(ABC):
    """The base class for all document types. A document consists of an ordered list of pages."""

    def __init__(self, filename: str):
        self.filename: str = filename
        self.pages: List[Page] = list()
        # Lazily-populated cache backing the ``elements`` property.
        self._elements: List[Element] = list()

    def __str__(self):
        """Renders every page, separating pages with a blank line."""
        return "\n\n".join([str(page) for page in self.pages])

    @abstractmethod
    def read(self, inplace: bool = True):  # pragma: no cover
        """Reads the document's contents. Must be implemented by subclasses."""
        pass

    def get_narrative(self) -> List[NarrativeText]:
        """Pulls out all of the narrative text sections from the document.

        Iterates ``self.pages`` directly (not the cached ``elements`` list) and returns the
        ``NarrativeText`` elements in page order.
        """
        return [
            element
            for page in self.pages
            for element in page.elements
            if isinstance(element, NarrativeText)
        ]

    @property
    def elements(self) -> List[Element]:
        """Gets all elements from pages in sequential order.

        The flattened list is cached in ``_elements`` the first time it is non-empty.
        NOTE: the cache is not invalidated if ``pages`` changes afterwards. The cached list
        itself is returned, so mutating the result mutates the cache.
        """
        if not self._elements:
            self._elements = [el for page in self.pages for el in page.elements]
        return self._elements

    @staticmethod
    def _index_of(elements: List[Element], element: Element) -> int:
        """Returns the index in ``elements`` of the first entry whose ``id`` equals ``element.id``.

        Raises:
            ValueError: if no entry has a matching ``id`` (same exception type that
                ``list.index`` raises).
        """
        # Early exit instead of materialising the full list of ids first.
        for idx, candidate in enumerate(elements):
            if candidate.id == element.id:
                return idx
        raise ValueError(f"{element.id!r} is not in list")

    def after_element(self, element: Element) -> Document:
        """Returns a single page document containing all the elements after the given element.

        The element is matched by ``id``; if several elements share that id, the first one
        wins. The given element itself is excluded.

        Raises:
            ValueError: if no element in this document has ``element.id``.
        """
        elements = self.elements
        start_idx = self._index_of(elements, element) + 1
        return self.__class__.from_elements(elements[start_idx:])

    def before_element(self, element: Element) -> Document:
        """Returns a single page document containing all the elements before the given element.

        The element is matched by ``id``; if several elements share that id, the first one
        wins. The given element itself is excluded.

        Raises:
            ValueError: if no element in this document has ``element.id``.
        """
        elements = self.elements
        end_idx = self._index_of(elements, element)
        return self.__class__.from_elements(elements[:end_idx])

    def print_narrative(self):
        """Prints the narrative text sections of the document."""
        print("\n\n".join([str(el) for el in self.get_narrative()]))

    @classmethod
    def from_elements(cls, elements: List[Element]) -> Document:
        """Builds a document of type ``cls`` holding ``elements`` on a single page numbered 1.

        The new document has an empty ``filename``. When ``elements`` is empty, the document
        has no pages at all. ``cls`` must be constructible as ``cls(filename="")``.
        """
        doc = cls(filename="")
        # Copy so the new document never aliases (and later mutates) the caller's list. This
        # also lets any iterable be passed safely: the emptiness check runs on the copy.
        page_elements = list(elements)
        page = Page(number=1)
        page.elements = page_elements
        doc.pages = [page] if page_elements else []
        return doc


class Page(ABC):
    """One page of a document: an ordered sequence of elements.

    Elements are stored in reading order, i.e. the sequence in which a person would
    read the page.
    """

    def __init__(self, number: int):
        self.number: int = number
        self.elements: List[Element] = []

    def __str__(self):
        """Renders the page's elements, separated by a blank line."""
        separator = "\n\n"
        return separator.join(map(str, self.elements))
