from mykatlas.typing.typer.base import Typer
from mykatlas.stats import log_lik_R_S_coverage
from mykatlas.typing.typer.base import MIN_LLK

from ga4ghmongo.schema import VariantCall

# Default per-base error rate for VariantTyper. In the homozygous likelihoods
# the absent allele's expected depth is depth * error_rate / 3 (presumably the
# error split across the three other bases -- confirm).
DEFAULT_ERROR_RATE = 0.05
# Default fraction of the expected depth assigned to the alternate allele when
# VariantTyper scores a heterozygous genotype (reference gets 1 - minor_freq).
DEFAULT_MINOR_FREQ = 0.2
from mykatlas.stats import percent_coverage_from_expected_coverage


class VariantTyper(Typer):
    """Genotype a Variant from the coverage of its reference/alternate probes.

    Hom-ref, het and hom-alt genotypes are each scored with
    ``log_lik_R_S_coverage`` and the call is picked by the inherited
    ``likelihoods_to_genotype``. Only a single expected depth (non-mixed
    samples) is supported.
    """

    def __init__(self, expected_depths, contamination_depths=None,
                 error_rate=DEFAULT_ERROR_RATE,
                 minor_freq=DEFAULT_MINOR_FREQ,
                 ignore_filtered=False):
        """
        :param expected_depths: list of expected sample depths. More than one
            entry raises NotImplementedError (mixed samples unsupported).
        :param contamination_depths: list of additional depths contributed by
            contamination; ``None`` (default) means no contamination.
        :param error_rate: error rate used for the absent allele's expected
            depth in homozygous likelihoods.
        :param minor_freq: fraction of the expected depth attributed to the
            alternate allele when scoring a heterozygous genotype.
        :param ignore_filtered: if True, "-/-" calls are returned unchanged
            instead of being forced to 0/0 or 1/1 with a MISSING_WT filter.
        """
        # Build a fresh list per instance; a shared mutable default would leak
        # state between typers if anything ever appended to it.
        if contamination_depths is None:
            contamination_depths = []
        super(VariantTyper, self).__init__(
            expected_depths,
            contamination_depths,
            error_rate,
            ignore_filtered=ignore_filtered)
        self.method = "MAP"
        self.error_rate = error_rate
        self.minor_freq = minor_freq
        self.ignore_filtered = ignore_filtered

        if len(expected_depths) > 1:
            raise NotImplementedError("Mixed samples not handled yet")

    def type(self, variant_probe_coverages, variant=None):
        """
            Takes a list of VariantProbeCoverages and returns a Call for the Variant.
            Note, in the simplest case the list will be of length one. However, we may be typing the
            Variant on multiple backgrounds leading to multiple VariantProbes for a single Variant.

            A single (non-list) VariantProbeCoverage is also accepted. Calls
            whose genotype sums to more than 1 (hom-alt) are preferred, then
            calls summing to exactly 1 (het), then any call; within the chosen
            group the one with the highest ``genotype_conf`` is returned (the
            earliest one on ties).

            :raises ValueError: if an empty list is given.
        """
        if not isinstance(variant_probe_coverages, list):
            variant_probe_coverages = [variant_probe_coverages]
        if not variant_probe_coverages:
            raise ValueError(
                "At least one VariantProbeCoverage is required to type a variant")
        calls = [
            self._type_variant_probe_coverages(variant_probe_coverage, variant)
            for variant_probe_coverage in variant_probe_coverages]
        hom_alt_calls = [c for c in calls if sum(c.genotype) > 1]
        het_calls = [c for c in calls if sum(c.genotype) == 1]
        # First non-empty group wins; max() returns the first maximal element,
        # matching the tie-breaking of a stable reverse sort.
        candidates = hom_alt_calls or het_calls or calls
        return max(candidates, key=lambda c: c.genotype_conf)

    def _type_variant_probe_coverages(
            self, variant_probe_coverage, variant=None):
        """Build a VariantCall for one VariantProbeCoverage.

        Likelihoods are ordered [hom_ref, het, hom_alt]; het is pinned to
        MIN_LLK when contamination is present. If no genotype could be called
        ("-/-") and ``ignore_filtered`` is False, the genotype is forced to
        whichever allele has the higher percent coverage and the call is
        filtered as MISSING_WT.
        """
        variant_probe_coverage = self._check_min_coverage(
            variant_probe_coverage)
        hom_ref_likelihood = self._hom_ref_lik(variant_probe_coverage)
        hom_alt_likelihood = self._hom_alt_lik(variant_probe_coverage)
        if not self.has_contamination():
            het_likelihood = self._het_lik(variant_probe_coverage)
        else:
            het_likelihood = MIN_LLK
        likelihoods = [hom_ref_likelihood, het_likelihood, hom_alt_likelihood]
        gt = self.likelihoods_to_genotype(likelihoods)
        info = {"coverage": variant_probe_coverage.coverage_dict,
                "expected_depths": self.expected_depths,
                "contamination_depths": self.contamination_depths,
                "filter": "PASS"}
        if gt == "-/-" and not self.ignore_filtered:
            if (variant_probe_coverage.alternate_percent_coverage >
                    variant_probe_coverage.reference_percent_coverage):
                gt = "1/1"
            else:
                gt = "0/0"
            info["filter"] = "MISSING_WT"
        return VariantCall.create(
            variant=variant,
            genotype=gt,
            genotype_likelihoods=likelihoods,
            info=info)

    def _check_min_coverage(self, variant_probe_coverage):
        """Penalise alternate percent coverage when its minimum depth is low.

        If the alternate probe's minimum depth is below 10% of the largest
        expected depth, its percent coverage is scaled by 0.9, making it less
        likely to clear the coverage thresholds used when scoring genotypes.

        NOTE(review): this mutates the object passed in, so typing the same
        coverage twice compounds the penalty -- confirm callers expect this.
        """
        if variant_probe_coverage.alternate_min_depth < 0.1 * \
                max(self.expected_depths):
            variant_probe_coverage.alternate_percent_coverage = variant_probe_coverage.alternate_percent_coverage * 0.9
        return variant_probe_coverage

    def _min_homozygous_percent_coverage(self):
        """Percent coverage an allele's probe must reach to be scored homozygous."""
        return 100 * percent_coverage_from_expected_coverage(
            max(self.expected_depths))

    def _max_homozygous_lik(self, present_depth, absent_depth):
        """Best homozygous log-likelihood over all depth hypotheses.

        :param present_depth: median depth of the allele assumed present.
        :param absent_depth: median depth of the allele assumed absent.

        Each expected depth is tried on its own and summed with each
        contamination depth. The present allele is expected at that depth and
        the absent one at ``depth * error_rate / 3``.
        """
        liks = []
        for expected_depth in self.expected_depths:
            depths = [expected_depth]
            depths.extend(
                expected_depth + contamination
                for contamination in self.contamination_depths)
            for depth in depths:
                liks.append(
                    log_lik_R_S_coverage(
                        present_depth,
                        absent_depth,
                        depth,
                        depth * self.error_rate / 3))
        return max(liks)

    def _hom_ref_lik(self, variant):
        """Hom-ref log-likelihood, or MIN_LLK if the reference probe is under-covered."""
        if variant.reference_percent_coverage < \
                self._min_homozygous_percent_coverage():
            return MIN_LLK
        return self._max_homozygous_lik(
            variant.reference_median_depth,
            variant.alternate_median_depth)

    def _hom_alt_lik(self, variant):
        """Hom-alt log-likelihood, or MIN_LLK if the alternate probe is under-covered."""
        if variant.alternate_percent_coverage < \
                self._min_homozygous_percent_coverage():
            return MIN_LLK
        return self._max_homozygous_lik(
            variant.alternate_median_depth,
            variant.reference_median_depth)

    def _het_lik(self, variant):
        """Heterozygous log-likelihood.

        Returns MIN_LLK unless both probes are at 100% coverage. Otherwise the
        alternate allele is expected at ``minor_freq`` of the expected depth
        and the reference at ``1 - minor_freq``. Contamination depths are not
        used here; het is only scored when there is no contamination.
        """
        if (variant.alternate_percent_coverage < 100 or
                variant.reference_percent_coverage < 100):
            return MIN_LLK
        return max(
            log_lik_R_S_coverage(
                variant.alternate_median_depth,
                variant.reference_median_depth,
                expected_depth * self.minor_freq,
                expected_depth * (1 - self.minor_freq))
            for expected_depth in self.expected_depths)
