from AnnotatedSentence.ViewLayerType import ViewLayerType
from AnnotatedTree.TreeBankDrawable cimport TreeBankDrawable
from AnnotatedTree.ParseTreeDrawable cimport ParseTreeDrawable
from AnnotatedSentence.AnnotatedSentence cimport AnnotatedSentence
from AnnotatedSentence.AnnotatedWord cimport AnnotatedWord
from DisambiguationCorpus.DisambiguatedWord cimport DisambiguatedWord
from DisambiguationCorpus.DisambiguationCorpus cimport DisambiguationCorpus


cdef class TreeDisambiguationCorpusGenerator:

    # Treebank loaded once in the constructor and read by generate().
    cdef TreeBankDrawable __tree_bank

    def __init__(self,
                 folder: str,
                 pattern: str):
        """
        Constructor for the TreeDisambiguationCorpusGenerator, which takes as input the data directory and the
        pattern for the training files to include. The constructor loads the treebank from the given directory,
        including only the files that match the given pattern.

        PARAMETERS
        ----------
        folder : str
            Directory where the treebank files reside.
        pattern : str
            Pattern of the tree files to be included in the treebank. Use "." for all files.
        """
        self.__tree_bank = TreeBankDrawable(folder, pattern)

    cpdef DisambiguationCorpus generate(self):
        """
        Creates a morphological disambiguation corpus from the treebank. For each parse tree for which
        layerAll(ViewLayerType.INFLECTIONAL_GROUP) is true, calls generateAnnotatedSentence and converts every
        AnnotatedWord of the resulting sentence into a DisambiguatedWord built from the word's name and parse.
        Words that are not AnnotatedWord instances are skipped. Parse trees failing the layer check contribute
        no sentence.

        RETURNS
        -------
        DisambiguationCorpus
            Created disambiguation corpus.
        """
        cdef DisambiguationCorpus corpus
        cdef int i, j
        cdef ParseTreeDrawable parse_tree
        cdef AnnotatedSentence sentence, disambiguation_sentence
        cdef AnnotatedWord annotated_word
        cdef object word
        corpus = DisambiguationCorpus()
        for i in range(self.__tree_bank.size()):
            parse_tree = self.__tree_bank.get(i)
            if parse_tree.layerAll(ViewLayerType.INFLECTIONAL_GROUP):
                sentence = parse_tree.generateAnnotatedSentence()
                disambiguation_sentence = AnnotatedSentence()
                for j in range(sentence.wordCount()):
                    # Fetch into an untyped variable first: assigning directly to the
                    # AnnotatedWord-typed variable makes Cython type-check (and raise
                    # TypeError) before the isinstance guard could skip the word.
                    word = sentence.getWord(j)
                    if isinstance(word, AnnotatedWord):
                        annotated_word = word
                        disambiguation_sentence.addWord(DisambiguatedWord(annotated_word.getName(),
                                                                         annotated_word.getParse()))
                corpus.addSentence(disambiguation_sentence)
        return corpus
