# Copyright (c) 2015. Mount Sinai School of Medicine
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function, division, absolute_import
from collections import Counter, OrderedDict

from typechecks import require_iterable_of

from .collection import Collection
from .common import memoize
from .effects import MutationEffect, NonsilentCodingMutation
from .effect_ordering import (
    effect_priority,
    effect_sort_key,
    top_priority_effect,
    transcript_effect_priority_dict
)

class EffectCollection(Collection):
    """
    Collection of MutationEffect objects and helpers for grouping or filtering
    them.

    Grouping and filtering are delegated to the `Collection` base class
    (`groupby`, `filter`, `filter_above_threshold`); this class only supplies
    the effect-specific key functions and priority logic.
    """

    def __init__(
            self,
            effects,
            path=None,
            distinct=False,
            sort_key=None):
        """Construct an EffectCollection from a sequence of MutationEffects.

        Parameters
        ----------
        effects : iterable
            MutationEffect objects; checked with `require_iterable_of` before
            being handed to the base class.

        path : str, optional
            File path from which we loaded variants which gave rise to these
            effects.

        distinct : bool
            Don't keep repeated effects

        sort_key : callable, optional
            Forwarded unchanged to `Collection.__init__`.
        """
        require_iterable_of(effects, MutationEffect)
        Collection.__init__(
            self,
            elements=effects,
            path=path,
            distinct=distinct,
            sort_key=sort_key)

    def groupby_variant(self):
        """Group effects by their `variant` attribute."""
        return self.groupby(key_fn=lambda effect: effect.variant)

    def groupby_gene(self):
        """Group effects by their `gene` attribute."""
        return self.groupby(key_fn=lambda effect: effect.gene)

    def groupby_gene_name(self):
        """Group effects by their `gene_name` attribute."""
        return self.groupby(key_fn=lambda effect: effect.gene_name)

    def groupby_gene_id(self):
        """Group effects by their `gene_id` attribute."""
        return self.groupby(key_fn=lambda effect: effect.gene_id)

    def groupby_transcript(self):
        """Group effects by their `transcript` attribute."""
        return self.groupby(key_fn=lambda effect: effect.transcript)

    def groupby_transcript_name(self):
        """Group effects by their `transcript_name` attribute."""
        return self.groupby(key_fn=lambda effect: effect.transcript_name)

    def groupby_transcript_id(self):
        """Group effects by their `transcript_id` attribute."""
        return self.groupby(key_fn=lambda effect: effect.transcript_id)

    def filter_by_transcript_expression(
            self,
            transcript_expression_dict,
            min_expression_value=0.0):
        """
        Filters effects to those which have an associated transcript whose
        expression value in the transcript_expression_dict argument is above
        min_expression_value. The comparison itself is performed by
        `Collection.filter_above_threshold`, keyed on `effect.transcript_id`.

        Parameters
        ----------
        transcript_expression_dict : dict
            Dictionary mapping Ensembl transcript IDs to expression estimates
            (either FPKM or TPM)

        min_expression_value : float
            Threshold above which we'll keep an effect in the result collection
        """
        return self.filter_above_threshold(
            key_fn=lambda effect: effect.transcript_id,
            value_dict=transcript_expression_dict,
            threshold=min_expression_value)

    def filter_by_gene_expression(
            self,
            gene_expression_dict,
            min_expression_value=0.0):
        """
        Filters effects to those which have an associated gene whose
        expression value in the gene_expression_dict argument is above
        min_expression_value. The comparison itself is performed by
        `Collection.filter_above_threshold`, keyed on `effect.gene_id`.

        Parameters
        ----------
        gene_expression_dict : dict
            Dictionary mapping Ensembl gene IDs to expression estimates
            (either FPKM or TPM)

        min_expression_value : float
            Threshold above which we'll keep an effect in the result collection
        """
        return self.filter_above_threshold(
            key_fn=lambda effect: effect.gene_id,
            value_dict=gene_expression_dict,
            threshold=min_expression_value)

    def filter_by_effect_priority(self, min_priority_class):
        """
        Create a new EffectCollection containing only effects whose priority
        is at least as high as the priority of the given effect class
        (effects with exactly that priority are kept).

        Parameters
        ----------
        min_priority_class : type
            Effect class used to look up the minimum priority; it must be a
            key of `transcript_effect_priority_dict`.
        """
        min_priority = transcript_effect_priority_dict[min_priority_class]
        return self.filter(
            lambda effect: effect_priority(effect) >= min_priority)

    def drop_silent_and_noncoding(self):
        """
        Create a new EffectCollection containing only non-silent coding
        effects, i.e. instances of NonsilentCodingMutation.
        """
        return self.filter(
            lambda effect: isinstance(effect, NonsilentCodingMutation))

    def detailed_string(self):
        """
        Create a long string with all transcript effects for each mutation,
        grouped by gene (if a mutation affects multiple genes).

        Returns a single newline-joined string. Each variant starts a new
        section; within a variant, effects are listed per gene ID from
        highest to lowest priority.
        """
        lines = []
        # TODO: annoying to always write `groupby_result.items()`,
        # consider making a GroupBy class which iterates over pairs
        # and also common helper methods like `map_values`.
        for variant, variant_effects in self.groupby_variant().items():
            lines.append("\n%s" % variant)

            gene_effects_groups = variant_effects.groupby_gene_id()
            for (gene_id, gene_effects) in gene_effects_groups.items():
                # effects without a gene ID get no "Gene:" header line
                if gene_id:
                    gene_name = variant.ensembl.gene_name_of_gene_id(gene_id)
                    lines.append("  Gene: %s (%s)" % (gene_name, gene_id))
                # place transcript effects with more significant impact
                # on top (e.g. FrameShift should go before NoncodingTranscript)
                for effect in sorted(
                        gene_effects,
                        key=effect_priority,
                        reverse=True):
                    lines.append("  -- %s" % effect)

            # if we only printed one effect for this variant then
            # it's redundant to print it again as the highest priority effect
            if len(variant_effects) > 1:
                best = variant_effects.top_priority_effect()
                lines.append("  Highest Priority Effect: %s" % best)
        return "\n".join(lines)

    @memoize
    def top_priority_effect(self):
        """Highest priority MutationEffect among all elements of this
        collection.

        The choice (including how ties between equal-priority effects are
        broken) is delegated to `effect_ordering.top_priority_effect`.
        """
        # Inside a method body the bare name resolves to the module-level
        # function imported from effect_ordering, not to this method.
        return top_priority_effect(self.elements)

    # TODO: find a way to express these kinds of methods without
    # duplicating every single groupby_* method
    @memoize
    def top_priority_effect_per_variant(self):
        """Highest priority effect for each unique variant.

        Returns an OrderedDict mapping each variant to its top effect, in the
        iteration order of `groupby_variant()`.
        """
        return OrderedDict(
            (variant, top_priority_effect(variant_effects))
            for (variant, variant_effects)
            in self.groupby_variant().items())

    @memoize
    def top_priority_effect_per_transcript_id(self):
        """Highest priority effect for each unique transcript ID.

        Returns an OrderedDict mapping each transcript ID to its top effect,
        in the iteration order of `groupby_transcript_id()`.
        """
        return OrderedDict(
            (transcript_id, top_priority_effect(transcript_effects))
            for (transcript_id, transcript_effects)
            in self.groupby_transcript_id().items())

    @memoize
    def top_priority_effect_per_gene_id(self):
        """Highest priority effect for each unique gene ID.

        Returns an OrderedDict mapping each gene ID to its top effect, in the
        iteration order of `groupby_gene_id()`.
        """
        return OrderedDict(
            (gene_id, top_priority_effect(gene_effects))
            for (gene_id, gene_effects)
            in self.groupby_gene_id().items())

    def effect_expression(self, expression_levels):
        """
        Parameters
        ----------
        expression_levels : dict
            Dictionary mapping transcript IDs to length-normalized expression
            levels (either FPKM or TPM)

        Returns dictionary mapping each transcript effect to an expression
        quantity. Effects that don't have an associated transcript
        (e.g. Intergenic) will not be included. Effects whose transcript ID
        is missing from `expression_levels` are included with a value of 0.0.
        """
        return OrderedDict(
            (effect, expression_levels.get(effect.transcript.id, 0.0))
            for effect in self
            if effect.transcript is not None)

    def top_expression_effect(self, expression_levels):
        """
        Return effect whose transcript has the highest expression level.

        Returns None only when no effect in the collection has an associated
        transcript. Effects whose transcript is absent from
        `expression_levels` count as expression 0.0 and may still be
        returned. Ties in expression are broken by `effect_sort_key`, the
        ordering helper imported from `effect_ordering`.

        Parameters
        ----------
        expression_levels : dict
            Dictionary mapping transcript IDs to length-normalized expression
            levels (either FPKM or TPM)
        """
        effect_expression_dict = self.effect_expression(expression_levels)
        if not effect_expression_dict:
            return None

        def key_fn(effect_fpkm_pair):
            """
            Sort effects primarily by their expression level
            and secondarily by `effect_sort_key`.
            """
            (effect, fpkm) = effect_fpkm_pair
            return (fpkm, effect_sort_key(effect))

        top_effect, _ = max(effect_expression_dict.items(), key=key_fn)
        return top_effect

    @memoize
    def gene_counts(self):
        """
        Count how many effects in this collection carry each gene name.

        Returns a Counter mapping `effect.gene_name` to the number of effects
        with that name.
        """
        return Counter(effect.gene_name for effect in self)
