"""

"""


#  Copyright (c) 2022-2022 Theodor Möser
#  .
#  Licensed under the EUPL-1.2-or-later (the "Licence");
#  .
#  You may not use this work except in compliance with the Licence.
#  You may obtain a copy of the Licence at:
#  .
#  https://joinup.ec.europa.eu/software/page/eupl
#  .
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the Licence is distributed on an "AS IS" basis,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  .
#  See the Licence for the specific language governing
#  permissions and limitations under the Licence.

from typing import Sequence, List, Tuple

from formgram.formgram_procedures.transformations.context_free import to_chomsky_normal_form


def cyk_recognize(grammar: dict, word: tuple) -> bool:
    """Check if the given word can be generated by given context free grammar

    The decision is read off the CYK parse table: the word is accepted if the
    starting symbol of ``grammar`` is contained in the topmost table cell,
    which covers the whole word.

    The empty word has no cell in that table, so it is handled separately by
    looking for an empty production of the starting symbol in the Chomsky
    normal form of the grammar.

    :param grammar: grammar dictionary; its ``"starting_symbol"`` entry is the
        symbol searched for in the table
    :param word: sequence of symbols to test
    :return: ``True`` if the word is recognized, ``False`` otherwise
    """
    if len(word) == 0:
        # Without this guard ``parse_table`` is an empty list and indexing it
        # with ``-1`` raises an IndexError.
        normal_form = to_chomsky_normal_form(grammar)
        # NOTE(review): assumes the normal form keeps a "starting_symbol" entry
        # and retains an empty production on it whenever the empty word is in
        # the language -- confirm against to_chomsky_normal_form.
        start = normal_form["starting_symbol"]
        return any(
            len(left) == 1 and left[0] == start and len(right) == 0
            for left, right in normal_form["productions"]
        )

    parse_table, _ = create_cyk_tables(grammar, word)
    # The last row consists of a single cell spanning the complete word.
    return grammar["starting_symbol"] in parse_table[-1][0]


def create_cyk_forest(grammar: dict, word: Sequence) -> List[dict]:
    """Create a list of all possible parse trees

    The individual parse trees are dictionaries with keys
    * name: the terminal or nonterminal symbol which is used or respectively produced
    * children: a Tuple of tree dictionaries

    This is compatible with the anytree package.

    Whether any tree exists is decided with the parse table rather than the
    backpointer table: the bottom row of the backpointer table is never
    filled, so consulting it would reject every word of length one.

    :param grammar: grammar dictionary; its ``"starting_symbol"`` entry becomes
        the root of every returned tree
    :param word: sequence of symbols to parse
    :return: list of parse trees rooted at the starting symbol; empty if the
        starting symbol is not in the top table cell or if the word is empty
        (the tables have no cell for the empty word)
    """
    parse_table, backpointer_table = create_cyk_tables(grammar, word)
    start = grammar["starting_symbol"]

    # An empty word produces empty tables; there is no top cell to inspect.
    if not parse_table or start not in parse_table[-1][0]:
        return []

    return create_tree_by_recursion(
        backpointer_table, start, word, row=len(word) - 1, col=0
    )


def create_tree_by_recursion(backpointer_table, current, word, row, col) -> List[dict]:
    """Rebuild every subtree rooted at ``current`` for one backpointer table cell

    Each tree is a dictionary with the keys ``"name"`` (the symbol) and
    ``"children"`` (a tuple of tree dictionaries).

    :param backpointer_table: table whose cells hold tuples of the form
        ``(nonterminal, left_child, right_child, split)``
    :param current: nonterminal that has to be the root of the returned trees
    :param word: the parsed word; supplies the leaf symbols
    :param row: row of the cell to start from; row ``0`` is the leaf level
    :param col: column of the cell to start from
    :return: list of all trees for ``current`` at the given cell
    """
    if row == 0:
        # Leaf level: the nonterminal sits directly above the word symbol.
        leaf = {"name": word[col], "children": ()}
        return [{"name": current, "children": (leaf,)}]
    if row < 0:
        # There is no row below the leaf level; nothing is returned for it.
        return None

    subtrees = []
    for head, left_symbol, right_symbol, split in backpointer_table[row][col]:
        if head != current:
            continue
        # The left part lives in row ``split`` at the same column, the right
        # part in the cell covering the remainder of the span.
        for left_tree in create_tree_by_recursion(
            backpointer_table, left_symbol, word, split, col
        ):
            for right_tree in create_tree_by_recursion(
                backpointer_table, right_symbol, word, row - split - 1, col + split + 1
            ):
                subtrees.append(
                    {"name": current, "children": (left_tree, right_tree)}
                )
    return subtrees


def create_cyk_tables(
    grammar: dict, word: Sequence
) -> Tuple[List[List[set]], List[List[set]]]:
    """Create a parse table and backpointer table using the CYK algorithm

    The grammar is first converted to Chomsky normal form. Both tables are
    triangular: row ``r`` has ``len(word) - r`` cells, and cell ``[r][c]``
    belongs to the part of the word that starts at position ``c`` and is
    ``r + 1`` symbols long.

    * parse table cells hold the nonterminals found for that part of the word
    * backpointer table cells hold tuples
      ``(nonterminal, left_child, right_child, split)`` recording how each
      nonterminal was obtained; the bottom row stays empty

    :param grammar: grammar dictionary, passed to ``to_chomsky_normal_form``
    :param word: sequence of symbols to build the tables for
    :return: tuple ``(parse_table, backpointer_table)``
    """
    size = len(word)
    normal_form = to_chomsky_normal_form(grammar)

    parse_table = []
    backpointer_table = []
    for span in range(size):
        parse_table.append([set() for _ in range(size - span)])
        backpointer_table.append([set() for _ in range(size - span)])

    # Separate the productions by the length of their right hand side.
    unit_rules = set()
    binary_rules = set()
    for rule in normal_form["productions"]:
        body_length = len(rule[1])
        if body_length == 1:
            unit_rules.add(rule)
        elif body_length == 2:
            binary_rules.add(rule)

    # Bottom row: every nonterminal that produces the word symbol directly.
    for position, symbol in enumerate(word):
        producers = set()
        for (head,), (terminal,) in unit_rules:
            if terminal == symbol:
                producers.add(head)
        parse_table[0][position] = producers

    # Remaining rows: try every split point of the covered part of the word and
    # look up both halves in the rows that have been filled already.
    for span in range(1, size):
        for start in range(size - span):
            found = parse_table[span][start]
            pointers = backpointer_table[span][start]
            for split in range(span):
                left_cell = parse_table[split][start]
                right_cell = parse_table[span - split - 1][start + split + 1]
                for (head,), (left_symbol, right_symbol) in binary_rules:
                    if left_symbol not in left_cell:
                        continue
                    if right_symbol not in right_cell:
                        continue
                    found.add(head)
                    pointers.add((head, left_symbol, right_symbol, split))

    return parse_table, backpointer_table
