from typing import Any, Dict, List, Optional, Tuple, Set

from alnlp.metrics.span_utils import enumerate_spans


def make_coref_instance(
        sentences: List[List[str]],
        max_span_width: int,
        gold_clusters: Optional[List[List[Tuple[int, int]]]] = None,
        max_sentences: Optional[int] = None,
        remove_singleton_clusters: bool = True,
) -> Dict[str, Any]:
    """
    Build a coreference example (as a plain dict) from a tokenised document.

    # Parameters

    sentences : `List[List[str]]`, required.
        A list of lists representing the tokenised words and sentences in the document.
    max_span_width : `int`, required.
        The maximum width of candidate spans to consider; forwarded to `enumerate_spans`.
    gold_clusters : `Optional[List[List[Tuple[int, int]]]]`, optional (default = None)
        A list of all clusters in the document, represented as word spans with absolute indices
        in the entire document. Each cluster contains some number of spans, which can be nested
        and overlap. Spans may be tuples or two-element lists. If the same span appears in more
        than one cluster, the clusters are merged using `_canonicalize_clusters`.
    max_sentences : `Optional[int]`, optional (default = None)
        The maximum number of sentences in each document to keep. By default keeps all sentences.
        When the document is truncated, gold mentions whose end index (`mention[1]`) is not
        smaller than the number of kept tokens are dropped, and clusters left empty are removed.
    remove_singleton_clusters : `bool`, optional (default = True)
        Some datasets contain clusters that are singletons (i.e. no coreferents). This option allows
        the removal of them.

    # Returns

    A `dict` with the following keys:
        text : `List[str]`
            The normalised words of the (possibly truncated) document, flattened across sentences.
        spans : `List[Tuple[int, int]]`
            Every candidate `(start, end)` span yielded by `enumerate_spans` for each sentence,
            offset so the indices refer to positions in `text`.
        clusters : `Optional[List[List[Tuple[int, int]]]]`
            The canonicalised (and, if requested, singleton-filtered) gold clusters, or `None`
            when `gold_clusters` was not given.
        span_labels : `List[int]`, present only when `gold_clusters` is given
            For each entry of `spans`, the index into `clusters` of the cluster containing that
            span, or -1 if it does not belong to any cluster.
    """
    if max_sentences is not None and len(sentences) > max_sentences:
        sentences = sentences[:max_sentences]
        total_length = sum(len(sentence) for sentence in sentences)

        if gold_clusters is not None:
            # Keep only mentions that still fit inside the truncated document, and
            # drop clusters that lose all of their mentions.
            truncated_clusters = []
            for cluster in gold_clusters:
                kept_mentions = [mention for mention in cluster if mention[1] < total_length]
                if kept_mentions:
                    truncated_clusters.append(kept_mentions)
            gold_clusters = truncated_clusters

    text = [_normalize_word(word) for sentence in sentences for word in sentence]

    cluster_dict: Dict[Tuple[int, int], int] = {}
    if gold_clusters is not None:
        # Mentions may arrive as lists (e.g. parsed from JSON); convert them to tuples
        # first so they are hashable for canonicalisation and the lookup table.
        gold_clusters = _canonicalize_clusters(
            [[tuple(mention) for mention in cluster] for cluster in gold_clusters]
        )
        if remove_singleton_clusters:
            gold_clusters = [cluster for cluster in gold_clusters if len(cluster) > 1]

        for cluster_id, cluster in enumerate(gold_clusters):
            for mention in cluster:
                cluster_dict[tuple(mention)] = cluster_id

    spans: List[Tuple[int, int]] = []
    span_labels: Optional[List[int]] = [] if gold_clusters is not None else None

    # enumerate_spans works per sentence; the running offset turns its indices
    # into absolute positions in the flattened document.
    sentence_offset = 0
    for sentence in sentences:
        for start, end in enumerate_spans(
                sentence, offset=sentence_offset, max_span_width=max_span_width
        ):
            if span_labels is not None:
                span_labels.append(cluster_dict.get((start, end), -1))
            spans.append((start, end))
        sentence_offset += len(sentence)

    fields: Dict[str, Any] = {
        "text": text,
        "spans": spans,
        "clusters": gold_clusters,
    }
    if span_labels is not None:
        fields["span_labels"] = span_labels

    return fields


def _normalize_word(word):
    if word in ("/.", "/?"):
        return word[1:]
    else:
        return word


def _canonicalize_clusters(clusters: List[List[Tuple[int, int]]]) -> List[List[Tuple[int, int]]]:
    """
    The data might include 2 annotated spans which are identical,
    but have different ids. This checks all clusters for spans which are
    identical, and if it finds any, merges the clusters containing the
    identical spans.
    """
    merged_clusters: List[Set[Tuple[int, int]]] = []
    for cluster in clusters:
        cluster_with_overlapping_mention = None
        for mention in cluster:
            # Look at clusters we have already processed to
            # see if they contain a mention in the current
            # cluster for comparison.
            for cluster2 in merged_clusters:
                if mention in cluster2:
                    # first cluster in merged clusters
                    # which contains this mention.
                    cluster_with_overlapping_mention = cluster2
                    break
            # Already encountered overlap - no need to keep looking.
            if cluster_with_overlapping_mention is not None:
                break
        if cluster_with_overlapping_mention is not None:
            # Merge cluster we are currently processing into
            # the cluster in the processed list.
            cluster_with_overlapping_mention.update(cluster)
        else:
            merged_clusters.append(set(cluster))
    return [list(c) for c in merged_clusters]
