"""
Metamist wrapper to get input samples.
"""

import logging
from collections import defaultdict
from itertools import groupby

from cpg_utils.config import get_config

from .metamist import get_metamist, Sequence, AnalysisType, MetamistError
from .targets import Cohort, Sex, PedigreeInfo


# Module-level cache for the single cohort of this workflow run; filled lazily
# by get_cohort().
_cohort: Cohort | None = None


def get_cohort() -> Cohort:
    """
    Return the cohort object, creating it on first call.

    The cohort is built once via create_cohort() (which queries Metamist) and
    cached at module level, so later calls return the same object.
    """
    global _cohort
    # Compare against None explicitly rather than relying on truthiness: a
    # Cohort that evaluates falsy (e.g. if it defines __len__/__bool__ and is
    # empty) must still be cached instead of being rebuilt on every call.
    if _cohort is None:
        _cohort = create_cohort()
    return _cohort


def create_cohort() -> Cohort:
    """
    Add datasets in the cohort. There exists only one cohort for the workflow run.

    Builds the cohort from the `workflow` config section:
      * datasets: `workflow.datasets` (default `[workflow.dataset]`), minus any
        listed in `workflow.skip_datasets`;
      * samples: every Metamist sample of each dataset, filtered by
        `workflow.only_samples` / `workflow.skip_samples` (matched against
        either the internal or the external sample ID).

    If no samples remain, a warning is logged and the cohort is returned
    without any further Metamist queries. Otherwise, when
    `workflow.sequencing_type` is set, sequencing inputs are attached and
    samples lacking inputs of that type are deactivated; analysis,
    participant and pedigree data are then populated.
    """
    workflow_config = get_config()['workflow']
    analysis_dataset_name = workflow_config['dataset']
    dataset_names = workflow_config.get('datasets', [analysis_dataset_name])
    skip_datasets = workflow_config.get('skip_datasets', [])
    dataset_names = [d for d in dataset_names if d not in skip_datasets]

    skip_samples = workflow_config.get('skip_samples', [])
    only_samples = workflow_config.get('only_samples', [])

    cohort = Cohort()
    for dataset_name in dataset_names:
        dataset = cohort.create_dataset(dataset_name)
        sample_entries = get_metamist().sapi.get_samples(
            body_get_samples={'project_ids': [dataset_name]}
        )
        sample_entries = _filter_samples(
            sample_entries,
            dataset_name,
            skip_samples,
            only_samples,
        )
        for entry in sample_entries:
            dataset.add_sample(
                id=str(entry['id']),
                external_id=str(entry['external_id']),
                meta=entry.get('meta', {}),
            )

    # create_dataset() is called before a dataset's samples are fetched and
    # filtered, so datasets may exist even when no sample survived. Check for
    # samples as well: _populate_alignment_inputs() asserts the cohort has
    # samples, and the other _populate_* steps would have nothing to do.
    if not cohort.get_sample_ids():
        if cohort.get_datasets():
            msg = 'No samples populated'
        else:
            msg = 'No datasets populated'
        if skip_samples or only_samples or skip_datasets:
            msg += ' (after skipping/picking samples)'
        logging.warning(msg)
        return cohort

    if sequencing_type := workflow_config.get('sequencing_type'):
        _populate_alignment_inputs(cohort, sequencing_type)
        _filter_sequencing_type(cohort, sequencing_type)
    _populate_analysis(cohort)
    _populate_participants(cohort)
    _populate_pedigree(cohort)
    return cohort


def _filter_sequencing_type(cohort: Cohort, sequencing_type: str):
    """
    Deactivate samples that lack inputs of the requested sequencing type.

    A sample is deactivated (`active = False`) when it has no sequencing
    entries at all, or when it had alignment inputs but none of them are of
    `sequencing_type`. Alignment inputs of any other type are removed from
    the sample.
    """
    for sample in cohort.get_samples():
        seqs = sample.seq_by_type
        if not seqs:
            logging.warning(
                f'{sample}: skipping because no sequencing inputs found'
            )
            sample.active = False
            continue

        inputs = sample.alignment_input_by_seq_type
        if not inputs:
            # Nothing to narrow down; leave the sample as it is.
            continue

        avail_types = list(seqs.keys())
        kept = {
            seq_type: alignment_input
            for seq_type, alignment_input in inputs.items()
            if seq_type == sequencing_type
        }
        sample.alignment_input_by_seq_type = kept
        if kept:
            continue

        logging.warning(
            f'{sample}: skipping because no inputs with data type '
            f'"{sequencing_type}" found in {avail_types}'
        )
        sample.active = False

def _filter_samples(
    entries: list[dict[str, str]],
    dataset_name: str,
    skip_samples: list[str] | None = None,
    only_samples: list[str] | None = None,
) -> list[dict]:
    """
    Apply the only_samples and skip_samples filters.
    """

    filtered_entries = []
    for entry in entries:
        cpgid = entry['id']
        extid = entry['external_id']
        if only_samples:
            if cpgid in only_samples or extid in only_samples:
                logging.info(f'Picking sample: {dataset_name}|{cpgid}|{extid}')
            else:
                continue
        if skip_samples:
            if cpgid in skip_samples or extid in skip_samples:
                logging.info(f'Skipping sample: {dataset_name}|{cpgid}|{extid}')
                continue
        filtered_entries.append(entry)
    return filtered_entries


def _populate_alignment_inputs(
    cohort: Cohort,
    sequencing_type: str,
    check_existence: bool = False,
) -> None:
    """
    Populate sequencing inputs for samples.

    Fetches all sequences (not only the latest) for the cohort's samples
    from Metamist and keeps those whose type equals `sequencing_type`. For
    each sample, every matching sequence is parsed into
    `sample.seq_by_type`, and its alignment input, if any, into
    `sample.alignment_input_by_seq_type`. Samples without a matching
    sequence are logged at INFO level, grouped by dataset.

    `check_existence` is forwarded to `Sequence.parse`.

    Raises AssertionError if the cohort has no samples, and MetamistError if
    a sample has more than one alignment input of the same sequencing type.
    """
    assert cohort.get_sample_ids()
    found_seqs: list[dict] = get_metamist().seqapi.get_sequences_by_sample_ids(
        cohort.get_sample_ids(), get_latest_sequence_only=False
    )
    found_seqs = [seq for seq in found_seqs if str(seq['type']) == sequencing_type]
    found_seqs_by_sid = defaultdict(list)
    for found_seq in found_seqs:
        found_seqs_by_sid[found_seq['sample_id']].append(found_seq)

    all_samples = list(cohort.get_samples())

    # Log samples without sequences. This is a pretty common thing, but the
    # log makes it easier to track down samples that were not processed.
    # Group with dicts rather than itertools.groupby: groupby needs
    # contiguous keys, and its group iterators can only be consumed once
    # (counting them used to exhaust them before the IDs were joined).
    sample_count_by_ds: dict[str, int] = defaultdict(int)
    samples_wo_seq_by_ds: dict[str, list] = defaultdict(list)
    for s in all_samples:
        sample_count_by_ds[s.dataset.name] += 1
        if s.id not in found_seqs_by_sid:
            samples_wo_seq_by_ds[s.dataset.name].append(s)
    if samples_wo_seq_by_ds:
        msg = f'No {sequencing_type} sequencing data found for samples:\n'
        for ds, ds_samples in samples_wo_seq_by_ds.items():
            msg += (
                f'\t{ds}, {len(ds_samples)}/{sample_count_by_ds[ds]} samples: '
                f'{", ".join(s.id for s in ds_samples)}\n'
            )
        logging.info(msg)

    for sample in all_samples:
        for d in found_seqs_by_sid.get(sample.id, []):
            seq = Sequence.parse(d, check_existence=check_existence)
            sample.seq_by_type[seq.sequencing_type] = seq
            if seq.alignment_input:
                if seq.sequencing_type in sample.alignment_input_by_seq_type:
                    raise MetamistError(
                        f'{sample}: found more than 1 alignment input with '
                        f'sequencing type: {seq.sequencing_type}. Check your '
                        f'input provider to make sure there is only one data source '
                        f'of sequencing type per sample.'
                    )
                sample.alignment_input_by_seq_type[
                    seq.sequencing_type
                ] = seq.alignment_input


def _populate_analysis(cohort: Cohort) -> None:
    """
    Check existing GVCF and CRAM analysis entries against expected paths.

    For each dataset, queries Metamist for GVCF and CRAM analyses keyed by
    sample ID. When a sample has an analysis with a non-empty output, asserts
    that the output equals the path the workflow expects
    (`sample.make_gvcf_path().path` / `sample.make_cram_path().path`).
    Nothing is stored on the samples.

    Raises AssertionError, with (recorded, expected) as the message, if a
    recorded output path differs from the expected one.
    """
    for dataset in cohort.get_datasets():
        gvcf_by_sid = get_metamist().find_analyses_by_sid(
            dataset.get_sample_ids(),
            analysis_type=AnalysisType.GVCF,
            dataset=dataset.name,
        )
        cram_by_sid = get_metamist().find_analyses_by_sid(
            dataset.get_sample_ids(),
            analysis_type=AnalysisType.CRAM,
            dataset=dataset.name,
        )
        for sample in dataset.get_samples():
            if (analysis := gvcf_by_sid.get(sample.id)) and analysis.output:
                expected = sample.make_gvcf_path().path
                assert analysis.output == expected, (analysis.output, expected)
            if (analysis := cram_by_sid.get(sample.id)) and analysis.output:
                expected = sample.make_cram_path().path
                # Report both paths, consistent with the GVCF check above.
                assert analysis.output == expected, (analysis.output, expected)


def _populate_participants(cohort: Cohort) -> None:
    """
    Populate Participant entries.

    For each dataset, fetches the external-participant-ID to
    internal-sample-ID mapping from Metamist and sets `participant_id` on
    each sample that maps to a participant ID that is non-empty after
    whitespace stripping.
    """
    for dataset in cohort.get_datasets():
        papi = get_metamist().papi
        rows = papi.get_external_participant_id_to_internal_sample_id(dataset.name)
        # Each row is read as: participant ID first, then its sample IDs.
        # A later row wins if a sample ID appears more than once.
        participant_by_sid = {
            sid: row[0].strip() for row in rows for sid in row[1:]
        }

        for sample in dataset.get_samples():
            participant_id = participant_by_sid.get(sample.id)
            if participant_id:
                sample.participant_id = participant_id


def _populate_pedigree(cohort: Cohort) -> None:
    """
    Populate pedigree data for samples.

    Reads PED entries from Metamist for each dataset and attaches a
    PedigreeInfo to every sample whose participant ID has an entry. Parents
    are looked up across the whole cohort by participant ID; a parent not in
    the cohort becomes None. Samples with no entry are logged and left as
    they are. Finally, logs per dataset how many samples have pedigree info.
    """
    # A later sample wins if participant IDs collide.
    sample_by_participant_id = {s.participant_id: s for s in cohort.get_samples()}

    for dataset in cohort.get_datasets():
        ped_entry_by_participant_id = {
            str(entry['individual_id']): entry
            for entry in get_metamist().get_ped_entries(dataset=dataset.name)
        }

        for sample in dataset.get_samples():
            pid = sample.participant_id
            if pid not in ped_entry_by_participant_id:
                logging.warning(f'No pedigree data for participant {pid}')
                continue

            entry = ped_entry_by_participant_id[pid]
            mom = sample_by_participant_id.get(str(entry['maternal_id']))
            dad = sample_by_participant_id.get(str(entry['paternal_id']))
            sample.pedigree = PedigreeInfo(
                sample=sample,
                fam_id=entry['family_id'],
                mom=mom,
                dad=dad,
                sex=Sex.parse(str(entry['sex'])),
                phenotype=entry['affected'] or '0',
            )

    for dataset in cohort.get_datasets():
        n_with_ped = sum(1 for s in dataset.get_samples() if s.pedigree)
        logging.info(
            f'{dataset.name}: found pedigree info for {n_with_ped} '
            f'samples out of {len(dataset.get_samples())}'
        )
