
# generate_pheno_plink_fast.py  (explicit DO/NON-DO via panel_type)
from __future__ import annotations
import os
import io
import math
import logging
from typing import Dict, List, Union

import numpy as np
import pandas as pd

from plinkformatter.plink_utils import (
    generate_bed_bim_fam,
    calculate_kinship_from_pedmap,
    rewrite_pheno_ids_from_fam,
)
from plinkformatter.generate_pheno_plink import extract_pheno_measure


def _norm_id(x) -> str:
    """
    Normalize IDs used to join DO PED V1 with pheno['animal_id'].

    - Missing values (None, NaN, pd.NA, or the string "nan") normalize to ""
      so callers can drop them with a simple emptiness check.
    - Strip whitespace.
    - If numeric, collapse "123", "123.0", "123.000" → "123".
      Pure-integer strings are parsed with int() so long IDs keep every digit
      (float() would round IDs with more than ~15 significant digits).
    - Non-integer numeric IDs are rendered with up to 10 significant digits
      (no scientific notation for typical values).
    - Leave non-numeric IDs (e.g. "DO-123") untouched apart from stripping
      and dropping a trailing ".0".
    """
    # Missing scalars must not become the literal strings "None"/"nan"/"<NA>",
    # which would survive the caller's non-empty filter.
    if x is None or x is pd.NA:
        return ""
    if isinstance(x, float) and math.isnan(x):
        return ""

    s = str(x).strip()
    if s == "":
        return s

    # Exact integer path first: no precision loss for long numeric IDs.
    try:
        return str(int(s))
    except ValueError:
        pass

    try:
        f = float(s)
    except ValueError:
        # Not numeric: fall back to simple cleanup
        return s[:-2] if s.endswith(".0") else s

    if math.isnan(f):
        return ""
    if f.is_integer():
        return str(int(f))
    # Non-integer numeric IDs: avoid scientific notation
    return "%.10g" % f


# ----------------------------- NON-DO PATH ----------------------------- #
def _nondo_to_float(x) -> float:
    """Return ``float(x)``, or NaN when *x* is missing or not numeric."""
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        return np.nan


def _nondo_index_ped(ped_file: str) -> Dict[str, int]:
    """
    Map each PED family ID (cleaned of '?' and spaces) to the byte offset of
    the first line carrying it, so single records can be re-read with seek().
    Blank lines are skipped.
    """
    offsets: Dict[str, int] = {}
    with open(ped_file, "rb") as fh:
        while True:
            pos = fh.tell()
            line = fh.readline()
            if not line:
                break
            stripped = line.strip()
            if not stripped:
                # e.g. a trailing empty line: there is no FID to index.
                continue
            first_tab = line.find(b"\t")
            # Tab-delimited: FID is everything before the first tab (it may
            # contain spaces); otherwise use the first whitespace token.
            fid_bytes = line[:first_tab] if first_tab > 0 else stripped.split()[0]
            name = fid_bytes.decode(errors="replace").replace("?", "").replace(" ", "")
            if name and name not in offsets:
                offsets[name] = pos
    return offsets


def _nondo_read_alleles(fh, offset: int) -> List[str]:
    """
    Re-read the PED record at byte *offset* of binary handle *fh* and return
    its genotype columns (everything after the 6 meta columns) as a flat allele list.

    Both layouts are accepted: tab-delimited with one "A B" pair per column,
    and whitespace-delimited with one allele per column.

    Raises:
        ValueError: fewer than 7 columns, a genotype column holding more than
            two alleles, or an empty / odd allele count.
    """
    fh.seek(offset)
    raw = fh.readline().decode(errors="replace").rstrip("\r\n")
    parts = raw.split("\t")
    if len(parts) <= 6:
        parts = raw.split()
    if len(parts) < 7:
        raise ValueError("Malformed PED: need >=7 columns (6 meta + genotypes)")

    alleles: List[str] = []
    for gp in parts[6:]:
        toks = gp.split()
        if len(toks) > 2:
            raise ValueError(f"Genotype pair not splitable into two alleles: {gp!r}")
        alleles.extend(toks)
    if not alleles or len(alleles) % 2:
        raise ValueError(
            f"Malformed PED: expected an even, non-zero number of alleles, got {len(alleles)}"
        )
    return alleles


def generate_pheno_plink_fast_non_do(
    ped_file: str,
    map_file: str,
    pheno: pd.DataFrame,
    outdir: str,
) -> pd.DataFrame:
    """
    NON-DO behavior (matches Hao; this is your validated path):
      - replicate-level rows, keyed by STRAIN
      - FID = IID = STRAIN
      - PID/MID = 0; SEX=2 if 'f' else 1; PHE=zscore or -9
      - PHENO: 'FID IID zscore value'
      - MAP: '.' rsids -> 'chr_bp'

    For each (measnum, sex) group this writes ``<measnum>.<sex>.map``, ``.ped``
    and ``.pheno`` into ``outdir``. Every replicate row of a strain reuses
    that strain's genotypes from the reference PED. The reference PED may be
    tab-delimited ("A B" pairs per column) or whitespace-delimited (one allele
    per column).

    Returns:
        The phenotype rows actually used (sex in {f, m}, strain present in the
        PED, de-duplicated), or an empty frame when nothing is usable.

    Raises:
        ValueError: missing required pheno columns or a malformed PED record.
    """
    os.makedirs(outdir, exist_ok=True)
    if pheno is None or pheno.empty:
        return pd.DataFrame()

    need = ("strain", "sex", "measnum", "value")
    missing = [c for c in need if c not in pheno.columns]
    if missing:
        raise ValueError(f"pheno missing required columns: {missing} (need {list(need)})")

    ph = pheno.copy()
    ph["strain"] = ph["strain"].astype(str).str.replace(" ", "", regex=False)
    ph = ph[ph["sex"].isin(["f", "m"])].copy()
    if ph.empty:
        return ph

    # MAP sanitize: unnamed variants ('.') get a 'chr_bp' identifier.
    map_df = pd.read_csv(map_file, header=None, sep=r"\s+", engine="python")
    map_df[1] = np.where(
        map_df[1].astype(str) == ".",
        map_df[0].astype(str) + "_" + map_df[3].astype(str),
        map_df[1].astype(str),
    )

    # Ensure zscore column
    if "zscore" not in ph.columns:
        logging.info("[NON-DO] 'zscore' missing; filling NaN (becomes -9).")
        ph["zscore"] = np.nan

    ped_offsets = _nondo_index_ped(ped_file)
    ph = ph[ph["strain"].isin(set(ped_offsets))].reset_index(drop=True)
    if ph.empty:
        return ph

    # Conservative de-duplication: per animal when animal IDs exist,
    # otherwise only exact duplicate replicate rows are collapsed (so
    # replicates of one strain are not merged into a single row).
    if "animal_id" in ph.columns:
        ph = ph.drop_duplicates(subset=["strain", "sex", "measnum", "animal_id"], keep="first")
    else:
        ph = ph.drop_duplicates(subset=["strain", "sex", "measnum", "zscore", "value"], keep="first")

    # One read handle for the whole run; records are fetched by seek().
    with open(ped_file, "rb") as ped_fh:
        for (measnum, sex), df in ph.groupby(["measnum", "sex"], sort=False):
            measnum = int(measnum)
            sex = str(sex)
            sex_code = "2" if sex == "f" else "1"

            map_out = os.path.join(outdir, f"{measnum}.{sex}.map")
            ped_out = os.path.join(outdir, f"{measnum}.{sex}.ped")
            phe_out = os.path.join(outdir, f"{measnum}.{sex}.pheno")

            map_df.to_csv(map_out, sep="\t", index=False, header=False)

            df = df.sort_values(["strain"], kind="stable").reset_index(drop=True)

            with open(ped_out, "w", encoding="utf-8") as f_ped, open(phe_out, "w", encoding="utf-8") as f_ph:
                for strain, sdf in df.groupby("strain", sort=False):
                    # Genotypes are identical for all replicates of a strain:
                    # parse them once, not once per row.
                    geno = " ".join(_nondo_read_alleles(ped_fh, ped_offsets[strain]))

                    for z_raw, v_raw in zip(sdf["zscore"], sdf["value"]):
                        z = _nondo_to_float(z_raw)
                        v = _nondo_to_float(v_raw)
                        phe = f"{z}" if math.isfinite(z) else "-9"

                        f_ped.write(f"{strain} {strain} 0 0 {sex_code} {phe} {geno}\n")
                        f_ph.write(
                            f"{strain} {strain} "
                            f"{(z if math.isfinite(z) else -9)} "
                            f"{(v if math.isfinite(v) else -9)}\n"
                        )

            logging.info("[generate_pheno_plink_fast:NON-DO] wrote %s, %s, %s", ped_out, map_out, phe_out)

    return ph


# ------------------------------- DO PATH ------------------------------- #
def generate_pheno_plink_fast_do(
    ped_file: str,
    map_file: str,
    pheno: pd.DataFrame,
    outdir: str,
) -> pd.DataFrame:
    """
    DO behavior (matches Hao):
      - animal-level rows keyed by animal_id (no genotype duplication)
      - FID=V1 and IID=V2 from the DO PED, SEX set from pheno, PHE=zscore
      - PHENO: FID=V1, IID=V2, zscore, value
      - Order follows DO PED order

    For every (measnum, sex) group in ``pheno`` this writes
    ``<measnum>.<sex>.map``, ``.ped`` and ``.pheno`` into ``outdir``. The
    MAP/PED/PHENO files of a group that matched no PED line are removed again.

    IMPORTANT:
      - We do *not* pandas-read the PED; we stream it line-by-line, once.
      - We fail loudly if there is ZERO overlap between pheno animal_id
        (normalized) and PED V1 (normalized).

    Returns:
        The filtered phenotype frame (sex in {f, m}, non-empty normalized
        animal_id) with an added ``animal_id_norm`` column.

    Raises:
        ValueError: required pheno columns missing, or no PED line matched
            any phenotyped animal.
    """
    from contextlib import ExitStack, suppress

    def _as_float(x) -> float:
        # Missing / unparsable values become NaN and are written as -9.
        try:
            return float(x)
        except (TypeError, ValueError, OverflowError):
            return np.nan

    os.makedirs(outdir, exist_ok=True)
    if pheno is None or pheno.empty:
        return pd.DataFrame()

    need = ("strain", "sex", "measnum", "value", "animal_id")
    missing = [c for c in need if c not in pheno.columns]
    if missing:
        raise ValueError(f"[generate_pheno_plink_fast_do] pheno missing required columns: {missing} (need {list(need)})")

    ph = pheno.copy()

    # Normalize strain a bit (Hao just checks for "J:DO" membership)
    ph["strain"] = ph["strain"].astype(str).str.replace(" ", "", regex=False)

    # Keep only usable rows (sex f/m, non-empty animal_id)
    ph = ph[ph["sex"].isin(["f", "m"])].copy()
    ph["animal_id_norm"] = ph["animal_id"].map(_norm_id)
    ph = ph[ph["animal_id_norm"].notna() & (ph["animal_id_norm"] != "")].copy()

    if ph.empty:
        logging.info("[generate_pheno_plink_fast_do] No usable rows after filtering sex/animal_id; nothing to write.")
        return ph

    # MAP sanitize (same as NON-DO): unnamed variants ('.') become 'chr_bp'.
    map_df = pd.read_csv(map_file, header=None, sep=r"\s+", engine="python")
    map_df[1] = np.where(
        map_df[1].astype(str) == ".",
        map_df[0].astype(str) + "_" + map_df[3].astype(str),
        map_df[1].astype(str),
    )

    # zscore presence
    if "zscore" not in ph.columns:
        logging.info("[generate_pheno_plink_fast_do] 'zscore' missing; filling NaN (becomes -9).")
        ph["zscore"] = np.nan

    # Build per-(measnum,sex): animal_id_norm -> (zscore, value, sex)
    per_group_maps: Dict[tuple, Dict[str, tuple]] = {}
    for (meas, sex), g in ph.groupby(["measnum", "sex"], sort=False):
        meas = int(meas)
        sex = str(sex)

        m: Dict[str, tuple] = {}
        for aid, z_raw, v_raw in zip(g["animal_id_norm"], g["zscore"], g["value"]):
            if aid in m:
                # Exactly like Hao's slice(match(...)): keep first record per animal
                continue
            m[aid] = (_as_float(z_raw), _as_float(v_raw), sex)

        if m:
            per_group_maps[(meas, sex)] = m

    if not per_group_maps:
        logging.info("[generate_pheno_plink_fast_do] No per-(measnum,sex) groups after dedup; nothing to write.")
        return ph

    # animal -> groups containing it, so each PED line only visits relevant groups.
    groups_by_animal: Dict[str, List[tuple]] = {}
    for key, aid_map in per_group_maps.items():
        for aid in aid_map:
            groups_by_animal.setdefault(aid, []).append(key)

    # Quick sanity log
    logging.info(
        "[generate_pheno_plink_fast_do] groups=%d, total unique animals=%d",
        len(per_group_maps),
        len(groups_by_animal),
    )

    # (map, ped, pheno) output paths per group
    paths: Dict[tuple, tuple] = {
        (meas, sex): tuple(os.path.join(outdir, f"{meas}.{sex}.{ext}") for ext in ("map", "ped", "pheno"))
        for (meas, sex) in per_group_maps
    }
    wrote_any = dict.fromkeys(per_group_maps, False)

    # ExitStack guarantees every opened handle is closed, even if opening a
    # later file or streaming the PED fails part-way through.
    with ExitStack() as stack:
        handles: Dict[tuple, tuple] = {}
        for key, (map_out, ped_out, phe_out) in paths.items():
            # Write MAP once per group
            map_df.to_csv(map_out, sep="\t", index=False, header=False)
            handles[key] = (
                stack.enter_context(open(ped_out, "w", encoding="utf-8")),
                stack.enter_context(open(phe_out, "w", encoding="utf-8")),
            )

        # ONE pass over the PED, fanned out to the groups holding each animal.
        # This keeps order == PED order, like Hao's ped.overlap.id.
        fped = stack.enter_context(open(ped_file, "r", encoding="utf-8", errors="replace"))
        for raw in fped:
            parts = raw.split()
            if len(parts) < 7:
                # blank / malformed / header-like, skip
                continue

            fid, iid = parts[0], parts[1]
            fid_norm = _norm_id(fid)
            keys = groups_by_animal.get(fid_norm)
            if not keys:
                # This animal never appears in any (meas, sex) group
                continue

            for key in keys:
                z, v, sx = per_group_maps[key][fid_norm]

                # Copy meta so we don't mutate 'parts' between groups
                meta = parts[:6]
                meta[4] = "2" if sx == "f" else "1"  # SEX: 2=f, 1=m
                meta[5] = f"{z}" if math.isfinite(z) else "-9"  # PHE: zscore or -9

                f_ped, f_phe = handles[key]
                # PED line: meta + genotype alleles as-is
                f_ped.write(" ".join(meta + parts[6:]) + "\n")
                # PHENO: FID V1, IID V2, z, value
                f_phe.write(
                    f"{fid} {iid} "
                    f"{(z if math.isfinite(z) else -9)} "
                    f"{(v if math.isfinite(v) else -9)}\n"
                )
                wrote_any[key] = True

    for key, (map_out, ped_out, phe_out) in paths.items():
        if wrote_any[key]:
            logging.info("[generate_pheno_plink_fast_do] wrote %s, %s, %s", ped_out, map_out, phe_out)
            continue
        # Remove this group's (empty) outputs, including the MAP, so no
        # half-populated file set is left behind for downstream PLINK2.
        for path in (ped_out, phe_out, map_out):
            with suppress(OSError):
                os.remove(path)

    if not any(wrote_any.values()):
        raise ValueError(
            "[generate_pheno_plink_fast_do] No DO PED/PHENO files were written because there is no overlap "
            "between pheno['animal_id'] (normalized) and PED V1 (normalized). "
            "Check that the DO PED file actually uses the same animal IDs as the measure CSV."
        )

    return ph


# ------------------------------- WRAPPER ------------------------------- #
def generate_pheno_plink_fast(
    ped_file: str,
    map_file: str,
    pheno: pd.DataFrame,
    outdir: str,
    ncore: int = 1,
    *,
    panel_type: str = "NON_DO",   # <-- explicit control
) -> pd.DataFrame:
    """
    Dispatch to the DO or NON-DO implementation based on an explicit panel_type.

    Args:
        ped_file, map_file: reference PLINK PED/MAP inputs.
        pheno: phenotype rows; an empty/None frame writes nothing.
        outdir: output directory (created if missing).
        ncore: accepted for API compatibility; neither implementation uses it.
        panel_type: "DO" or "NON_DO" (default NON_DO, preserving current
            behavior). Case, surrounding whitespace and '-' vs '_' are
            ignored, and "NONDO" is also accepted. Any other value falls back
            to NON-DO with a warning.

    Returns:
        The phenotype rows used by the selected implementation.
    """
    if pheno is None or pheno.empty:
        os.makedirs(outdir, exist_ok=True)
        return pd.DataFrame()

    pt = str(panel_type or "NON_DO").strip().upper().replace("-", "_")
    if pt == "DO":
        logging.info("[generate_pheno_plink_fast] using DO panel_type")
        return generate_pheno_plink_fast_do(ped_file, map_file, pheno, outdir)

    if pt not in ("NON_DO", "NONDO"):
        # Still NON-DO, but a typo like "D0" must not go unnoticed.
        logging.warning(
            "[generate_pheno_plink_fast] unrecognized panel_type=%r; falling back to NON-DO",
            panel_type,
        )
    logging.info(
        "[generate_pheno_plink_fast] panel_type=%r => NON-DO path",
        panel_type,
    )
    return generate_pheno_plink_fast_non_do(ped_file, map_file, pheno, outdir)


# ----------------------------- Orchestrator ---------------------------- #
def fast_prepare_pylmm_inputs(
    ped_file: str,
    map_file: str,
    measure_id_directory: str,
    measure_ids: List,
    outdir: str,
    ncore: int,
    plink2_path: str,
    *,
    panel_type: str = "NON_DO",
    ped_pheno_field: str = "zscore",
    maf_threshold: Union[float, None] = None,
) -> None:
    """
    Orchestrate extraction and PLINK file generation (public API + panel_type):

      1) Extract phenotype rows for requested measure_ids.
      2) generate_pheno_plink_fast(..., panel_type=...) -> writes <meas>.<sex>.ped/.map/.pheno
      3) PLINK2 from --pedmap -> BED/BIM/FAM   (geno 0.1, mind 0.1)
      4) Rewrite PHENO IIDs from .fam (exact FID/IID order + suffixes)
      5) Kinship from --pedmap (square .rel)

    Args:
        ncore: forwarded to generate_pheno_plink_fast, which does not use it.
        panel_type: "DO" or "NON_DO"; selects the PED/PHENO generation path.
        ped_pheno_field: currently unused. The PED PHE column written by
            generate_pheno_plink_fast is always the zscore. Kept for API
            compatibility.
        maf_threshold: forwarded to generate_bed_bim_fam.

    Measure ids sharing the prefix before the first '_' map to the same output
    files. Each prefix is processed once, so PLINK and the in-place PHENO
    rewrite never run twice on the same files.
    """
    os.makedirs(outdir, exist_ok=True)

    pheno = extract_pheno_measure(measure_id_directory, measure_ids)
    if pheno is None or pheno.empty:
        logging.info("[fast_prepare_pylmm_inputs] no phenotype rows extracted; nothing to do.")
        return

    used = generate_pheno_plink_fast(
        ped_file=ped_file,
        map_file=map_file,
        pheno=pheno,
        outdir=outdir,
        ncore=ncore,
        panel_type=panel_type,
    )
    if used is None or used.empty:
        logging.info("[fast_prepare_pylmm_inputs] no usable phenotypes after PED/MAP intersection; nothing to do.")
        return

    # Order-preserving de-duplication of output prefixes.
    base_ids = list(dict.fromkeys(str(mid).split("_", 1)[0] for mid in measure_ids))

    for base_id in base_ids:
        for sex in ("f", "m"):
            out_prefix = os.path.join(outdir, f"{base_id}.{sex}")
            ped_path = f"{out_prefix}.ped"
            map_path = f"{out_prefix}.map"

            # Groups with no usable animals have no PED/MAP on disk.
            if not (os.path.exists(ped_path) and os.path.exists(map_path)):
                continue

            logging.info("[fast_prepare_pylmm_inputs] make BED/BIM/FAM for %s.%s", base_id, sex)
            generate_bed_bim_fam(
                plink2_path=plink2_path,
                ped_file=ped_path,
                map_file=map_path,
                output_prefix=out_prefix,
                relax_mind_threshold=False,
                maf_threshold=maf_threshold,
                sample_keep_path=None,
                autosomes_only=False,
            )

            # Align PHENO IIDs to FAM IIDs (strict 1:1; no dedup)
            fam_path = f"{out_prefix}.fam"
            pheno_path = f"{out_prefix}.pheno"
            rewrite_pheno_ids_from_fam(pheno_path, fam_path, pheno_path)

            logging.info("[fast_prepare_pylmm_inputs] compute kinship for %s.%s (from --pedmap)", base_id, sex)
            calculate_kinship_from_pedmap(
                plink2_path=plink2_path,
                pedmap_prefix=out_prefix,
                kin_prefix=f"{out_prefix}.kin",
            )


