import os
import sys
import math
import gzip
import shutil
from collections import defaultdict, OrderedDict

import concurrent.futures
from Bio import SeqIO
from Bio.Seq import Seq

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

def classify_peptides(reference,
                      peptide_fasta,
                      output_dir,
                      vcf_path=None,
                      database_path=None,
                      num_threads=1):
    """
    Classify peptides into categories based on canonical proteome, alternative splicing,
    mutated antigens (mutanome), alternative ORFs, amino acid misincorporations, and unaligned.

    Required arguments
    ------------------
    reference : str
        Reference name (hg19, hg38, mm10, mm39). Case-insensitive.
    peptide_fasta : str
        Path to peptide FASTA file.
    output_dir : str
        Output directory path.

    Optional arguments
    ------------------
    vcf_path : str or None
        Path to VCF or VCF.GZ file. If None or file does not exist, no SNVs are used.
    database_path : str or None
        Path to a directory containing precomputed database FASTA files
        (canonicalProteome.fa, alternativeSplicing.fa, mutanome.fa,
         mutatedCanonicalTranscriptome.fa, mutatedAlternativeTranslatome.fa,
         mutatedAlternativeORFeome.fa). If None, a new database directory
        called "database" will be created under output_dir and populated.
        If the provided directory is invalid or missing any required file,
        it is ignored and a new database is built from scratch.
    num_threads : int
        Number of threads to use for amino acid misincorporation search (step 7).
        If <= 1, runs single-threaded.
    """

    # -------------------------- simple step progress bar --------------------------
    steps = [
        "Filter VCF to exome",
        "Setup and load transcriptome/CDS/knownCanonical",
        "Build canonical / non-canonical transcript sets",
        "Generate canonical proteome and classify canonical peptides",
        "Generate alternative splicing proteome and classify peptides",
        "Apply SNVs, generate mutanome and classify neoantigens",
        "Generate alternative ORFs and classify peptides",
        "Identify amino acid misincorporations",
        "Write unaligned peptides and pie chart",
        "Finalize",
    ]
    total_steps = len(steps)
    current_step = 0

    def report_step_done():
        # use nonlocal to modify current_step inside nested function
        nonlocal current_step
        current_step += 1
        bar_len = 40
        filled = int(bar_len * float(current_step) / float(total_steps))
        bar = "#" * filled + "-" * (bar_len - filled)
        msg = "[{bar}] {i}/{total} - {desc}".format(
            bar=bar,
            i=current_step,
            total=total_steps,
            desc=steps[current_step - 1],
        )
        print(msg, file=sys.stderr)
        sys.stderr.flush()

    # -------------------------- helpers / normalizers ----------------------------

    def safe_translate_nt(nt_seq):
        """
        Translate a nucleotide sequence into a protein string with Biopython.

        The input is converted with str() and upper-cased, then any trailing
        1-2 nt that do not form a full codon are dropped so no partial codon is
        ever translated. Translation runs with to_stop=False, i.e. it does not
        stop at the first stop codon. Returns "" for None, empty input, or input
        shorter than one codon.
        """
        if nt_seq is None:
            return ""
        nt_seq = str(nt_seq).upper()
        # length rounded down to a whole number of codons
        usable = len(nt_seq) - len(nt_seq) % 3
        if usable < 3:
            return ""
        return str(Seq(nt_seq[:usable]).translate(to_stop=False))

    def normalize_chrom(chrom):
        """
        Drop a leading 'chr' prefix so VCF and GFF chromosome names compare equal.
        The input is converted with str() first.
        Example: 'chr1' -> '1', '1' -> '1'
        """
        name = str(chrom)
        return name[3:] if name.startswith("chr") else name

    def normalize_gff_tx_id(tx_id):
        """
        Reduce a GFF transcript ID to the bare ID used in the FASTA files.

        Keeps the first whitespace-separated token, then drops everything up to
        and including its first ':' (if any).
        Example: 'transcript:ENST00000335137.4' -> 'ENST00000335137.4'
        """
        token = tx_id.split()[0]
        _prefix, colon, tail = token.partition(":")
        return tail if colon else token

    # ----------------------------- basic paths & setup -----------------------------
    reference = str(reference).lower()
    if reference not in ("hg19", "hg38", "mm10", "mm39"):
        raise ValueError("Unsupported reference: {} (expected hg19/hg38/mm10/mm39)".format(reference))

    output_dir = os.path.abspath(output_dir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Decide on database directory and whether to build it
    required_db_files = [
        "alternativeSplicing.fa",
        "mutatedAlternativeORFeome.fa",
        "canonicalProteome.fa",
        "mutatedAlternativeTranslatome.fa",
        "mutanome.fa",
        "mutatedCanonicalTranscriptome.fa",
    ]

    build_database = True

    if database_path is not None:
        candidate = os.path.abspath(database_path)
        if os.path.isdir(candidate):
            missing = [f for f in required_db_files if not os.path.exists(os.path.join(candidate, f))]
            if not missing:
                database_dir = candidate
                build_database = False
            else:
                print(
                    "[WARN] Provided database_path '{}' is missing required files: {}. "
                    "Ignoring it and rebuilding database from scratch.".format(
                        candidate, ", ".join(missing)
                    ),
                    file=sys.stderr,
                )
                database_dir = os.path.join(output_dir, "database")
                os.makedirs(database_dir, exist_ok=True)
        else:
            print(
                "[WARN] Provided database_path '{}' is not an existing directory. "
                "Ignoring it and rebuilding database from scratch.".format(candidate),
                file=sys.stderr,
            )
            database_dir = os.path.join(output_dir, "database")
            os.makedirs(database_dir, exist_ok=True)
    else:
        database_dir = os.path.join(output_dir, "database")
        if not os.path.exists(database_dir):
            os.makedirs(database_dir)

    # Resolve genome directory relative to this file so installed package works
    here = os.path.dirname(os.path.abspath(__file__))
    genome_root = os.path.join(here, "genome", reference)

    # helper to find a single file by glob-like pattern
    def find_single_file(prefix, suffix):
        """
        Return the path of a file in genome_root whose name starts with *prefix*
        and ends with *suffix*.

        If several files match, the lexicographically smallest full path is
        returned so the choice is deterministic.
        Examples:
          prefix="gencode", suffix=".{}.gff".format(reference)
          prefix="knownCanonical", suffix=".{}.list".format(reference)

        Raises IOError when no file matches.
        """
        matches = sorted(
            os.path.join(genome_root, fname)
            for fname in os.listdir(genome_root)
            if fname.startswith(prefix) and fname.endswith(suffix)
        )
        if not matches:
            raise IOError("Could not find file {}*{} in {}".format(prefix, suffix, genome_root))
        return matches[0]

    # Files described in the prompt
    transcriptome_fa = os.path.join(genome_root, "transcriptome.{}.fa".format(reference))
    cds_bed = os.path.join(genome_root, "transcriptome.{}.cds.bed".format(reference))
    known_canonical_list = os.path.join(genome_root, "knownCanonical.{}.list".format(reference))
    if not os.path.exists(known_canonical_list):
        # allow versioned variants like knownCanonical.vX.hg38.list
        known_canonical_list = find_single_file("knownCanonical", ".{}.list".format(reference))

    gff_path = find_single_file("gencode", ".{}.gff".format(reference))

    # ----------------------------- tiny helpers -----------------------------------

    def read_canonical_ids(path):
        """Return the set of first whitespace-separated tokens of every non-blank line in *path*."""
        with open(path) as handle:
            return {row.split()[0] for row in handle if row.strip()}

    canonical_ids = read_canonical_ids(known_canonical_list)

    def fasta_id_core(rec_id):
        """
        Return the versionless core of a FASTA record ID: the first
        whitespace-separated token, cut at its first '.'.
        Example: 'ENST00000335137.4 extra' -> 'ENST00000335137'
        """
        first_token = rec_id.split()[0]
        core, _dot, _rest = first_token.partition(".")
        return core

    def read_peptide_fasta(path):
        """
        Read peptides from a FASTA file into an OrderedDict of record id -> sequence.

        Each sequence is stripped and upper-cased. Records whose sequence ends up
        empty are skipped. A repeated id keeps its first position but takes the
        later record's sequence.
        """
        peptides = OrderedDict()
        for record in SeqIO.parse(path, "fasta"):
            sequence = str(record.seq).strip().upper()
            if sequence:
                peptides[record.id] = sequence
        return peptides

    def write_fasta_single_line(records, path):
        """
        Write records to *path* as FASTA: one header line and one sequence line each.

        records may be a dict (id -> value) or any iterable of (id, value) pairs.
        value is either a sequence, or a 2-tuple (seq, source_id), e.g. a peptide
        plus the transcript/protein ID it matched. The 2-tuple form produces the
        header:
          >peptideID | sourceID
        Sequences are converted with str() and stripped of surrounding whitespace.
        The file is overwritten.
        """
        pairs = records.items() if isinstance(records, dict) else records
        with open(path, "w") as out:
            for rid, val in pairs:
                if isinstance(val, tuple) and len(val) == 2:
                    seq = val[0]
                    header = f"{rid} | {val[1]}"
                else:
                    seq = val
                    header = rid
                out.write(f">{header}\n{str(seq).strip()}\n")

    def load_transcriptome(path):
        """
        Read a nucleotide FASTA file.

        Returns dict: transcript_id (first whitespace token of the record id) ->
        upper-cased sequence string. A later record with the same id overwrites
        an earlier one.
        """
        return {
            record.id.split()[0]: str(record.seq).upper()
            for record in SeqIO.parse(path, "fasta")
        }

    def load_cds_bed(path):
        """
        Load CDS segments from a transcript-coordinate BED-like file.

        Each data line is tab-separated: transcript_id, start, end (any further
        columns are ignored). Lines that are blank, start with '#', have fewer
        than three columns, or have non-integer start/end are skipped.

        Coordinates: the start column is decremented by 1 and the end column is
        used as-is. The file is therefore treated as 1-based inclusive and
        converted to 0-based half-open [start, end) intervals for Python slicing.
        NOTE(review): a standard 0-based BED file would be shifted one base
        upstream by this conversion. Confirm the producer of the file writes
        1-based starts.

        Returns a defaultdict(list): transcript_id -> list of (start, end),
        sorted by start.
        """
        cds = defaultdict(list)
        with open(path) as fh:
            for line in fh:
                if not line.strip() or line.startswith("#"):
                    continue
                fields = line.rstrip("\n").split("\t")
                if len(fields) < 3:
                    continue
                try:
                    start = int(fields[1]) - 1  # 1-based inclusive start -> 0-based
                    end = int(fields[2])        # inclusive end == half-open end
                except ValueError:
                    continue
                cds[fields[0]].append((start, end))
        # sort each transcript's segments by start
        for segments in cds.values():
            segments.sort(key=lambda seg: seg[0])
        return cds

    def translate_cds(transcripts, cds_map, tx_ids_subset=None):
        """
        Translate the CDS of each transcript into a protein sequence.

        transcripts: dict transcript_id -> nucleotide sequence
        cds_map: dict transcript_id -> list of (start, end), 0-based half-open.
                 Segments are concatenated in list order.
        tx_ids_subset: optional collection of transcript_ids to restrict to

        A segment is skipped if it falls outside the transcript (start < 0 or
        end > len) or is empty/inverted. A transcript is omitted if it has no
        cds_map entry, no usable segment, or an empty translation.

        Returns dict: transcript_id -> amino-acid sequence (via safe_translate_nt).
        """
        proteins = {}
        for tx_id, seq in transcripts.items():
            if tx_ids_subset is not None and tx_id not in tx_ids_subset:
                continue
            if tx_id not in cds_map:
                continue
            seq_len = len(seq)
            pieces = [
                seq[start:end]
                for start, end in cds_map[tx_id]
                if 0 <= start < end <= seq_len
            ]
            if not pieces:
                continue
            protein = safe_translate_nt("".join(pieces))
            if protein:
                proteins[tx_id] = protein
        return proteins

    def find_exact_matches(peptides, proteome_fasta_path, step_label=None):
        """
        Find peptides that occur verbatim inside any protein of a FASTA proteome.

        peptides: dict pep_id -> seq
        proteome_fasta_path: FASTA with protein sequences
        step_label: if not None, a progress line tagged with this label is
                    printed to stderr every 100 peptides

        A protein's reference ID is the first whitespace token of its FASTA id,
        cut at the first '_'. If several proteins contain a peptide, the one
        that appears first in the FASTA file is reported.

        Return: (matches_dict, remaining_peptides_dict)
        where:
          matches_dict: pep_id -> (pep_seq, ref_entry_id)
          remaining_peptides_dict: pep_id -> pep_seq
        Both preserve the iteration order of *peptides*.
        """
        from bisect import bisect_right

        ref_ids = []
        seqs = []
        for rec in SeqIO.parse(proteome_fasta_path, "fasta"):
            full_id = rec.id.split()[0]                 # drop whitespace stuff
            ref_ids.append(full_id.split("_", 1)[0])    # first chunk before '_'
            seqs.append(str(rec.seq).upper())

        # Search one newline-joined string instead of looping over every protein
        # in Python for every peptide. Sequences parsed by SeqIO contain no '\n',
        # so a peptide without '\n' cannot match across a protein boundary. The
        # lowest offset returned by str.find then maps back (via bisect on the
        # start offsets) to the first protein in file order containing the
        # peptide, which is the same answer as a per-protein scan.
        haystack = "\n".join(seqs)
        starts = []
        offset = 0
        for seq in seqs:
            starts.append(offset)
            offset += len(seq) + 1

        matched = OrderedDict()
        remaining = OrderedDict()
        total = float(len(peptides)) if peptides else 1.0
        idx = 0

        for pid, pep_seq in peptides.items():
            idx += 1
            hit = -1
            if ref_ids and "\n" not in pep_seq:
                hit = haystack.find(pep_seq)
            if hit >= 0:
                matched[pid] = (pep_seq, ref_ids[bisect_right(starts, hit) - 1])
            else:
                remaining[pid] = pep_seq

            if step_label is not None and idx % 100 == 0:
                frac = idx / total
                bar_len = 30
                filled = int(bar_len * frac)
                bar = "#" * filled + "-" * (bar_len - filled)
                msg = "[{bar}] {done}/{tot} peptides - {label}".format(
                    bar=bar,
                    done=idx,
                    tot=int(total),
                    label=step_label,
                )
                print(msg, file=sys.stderr)
                sys.stderr.flush()

        return matched, remaining

    def hamming_distance(s1, s2):
        """
        Return the Hamming distance between two equal-length sequences, capped at 2.

        Counting stops as soon as a second mismatch is found, so any distance
        >= 2 is reported as 2. Returns None if the lengths differ.
        """
        if len(s1) != len(s2):
            return None
        mismatches = 0
        for left, right in zip(s1, s2):
            if left == right:
                continue
            mismatches += 1
            if mismatches == 2:
                return mismatches
        return mismatches

    def find_hamming_leq1_matches(peptides,
                                  proteome_fasta_path,
                                  step_label=None,
                                  num_threads=1):
        """
        Find peptides that match some window of a protein with at most one
        substitution (Hamming distance <= 1).

        peptides: dict pep_id -> seq
        proteome_fasta_path: FASTA with protein sequences. A protein's reference
            ID is the first whitespace token of its record id, cut at the first '_'.
        step_label: if not None, a progress line is printed to stderr every
            100 peptides.
        num_threads: if > 1, peptides are checked on a ThreadPoolExecutor of this
            size; otherwise they are checked sequentially.

        Return: (matches_dict, remaining_peptides_dict)
          matches_dict: pep_id -> (pep_seq, ref_entry_id), where ref_entry_id is
              the first protein (file order) containing any variant
          remaining_peptides_dict: pep_id -> pep_seq
        Both preserve the iteration order of *peptides*, because map and
        executor.map yield results in submission order.

        Strategy:
          - For each peptide, generate all sequences at Hamming distance <=1
            over AA_ALPHABET (including the original).
          - Use Python's substring search (`in`) over protein sequences.
        NOTE: the search is pure-Python string work. In CPython it most likely
        holds the GIL, so extra threads may give little speedup.
        """

        # Load all proteins once as (normalized_id, seq)
        proteins = []
        for rec in SeqIO.parse(proteome_fasta_path, "fasta"):
            full_id = rec.id.split()[0]
            norm_id = full_id.split("_", 1)[0]
            proteins.append((norm_id, str(rec.seq).upper()))

        # Amino acid alphabet – adjust if you have non-standard residues
        AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWY*"

        def generate_hamming_leq1_variants(seq):
            """Return the set of all sequences at Hamming distance <= 1 from seq (seq included)."""
            variants = {seq}
            for i, orig in enumerate(seq):
                prefix, suffix = seq[:i], seq[i + 1:]
                for aa in AA_ALPHABET:
                    if aa != orig:
                        variants.add(prefix + aa + suffix)
            return variants

        def worker(item):
            """
            Check a single peptide for Hamming distance <=1 to any window
            in any protein, using variant generation + substring search.

            Returns (pid, pep_seq, found_bool, ref_entry_id_or_None).
            """
            pid, pep_seq = item
            pep_len = len(pep_seq)
            if pep_len == 0:
                return pid, pep_seq, False, None

            variants = generate_hamming_leq1_variants(pep_seq)

            # Proteins are tried in file order; the first one containing any variant wins
            for prot_id, prot_seq in proteins:
                # quick length check – skip too-short proteins
                if pep_len > len(prot_seq):
                    continue
                for v in variants:
                    if v in prot_seq:
                        return pid, pep_seq, True, prot_id
            return pid, pep_seq, False, None

        matched = OrderedDict()
        remaining = OrderedDict()
        total = float(len(peptides)) if peptides else 1.0

        def collect(results):
            """Sort worker results into matched/remaining, reporting progress as they arrive."""
            for idx, (pid, pep_seq, found, ref_entry_id) in enumerate(results, start=1):
                if found:
                    matched[pid] = (pep_seq, ref_entry_id)
                else:
                    remaining[pid] = pep_seq

                if step_label is not None and idx % 100 == 0:
                    frac = idx / total
                    bar_len = 30
                    filled = int(bar_len * frac)
                    bar = "#" * filled + "-" * (bar_len - filled)
                    msg = "[{bar}] {done}/{tot} peptides - {label}".format(
                        bar=bar,
                        done=idx,
                        tot=int(total),
                        label=step_label,
                    )
                    print(msg, file=sys.stderr)
                    sys.stderr.flush()

        items = list(peptides.items())
        if num_threads is None or num_threads <= 1:
            collect(map(worker, items))
        else:
            # Consume results while the pool is still alive. Iterating only after
            # the `with` block exits (which waits for every task) made the
            # progress output appear only once all work had finished.
            # (chunksize is omitted: it has no effect for ThreadPoolExecutor.)
            with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
                collect(executor.map(worker, items))

        return matched, remaining
    
    # mark: step 1 done after loading transcriptome/CDS/knownCanonical etc
    # (we finish this after VCF loading + transcriptome/CDS load below)

    # ----------------------------- VCF handling + exome filter -------------------
    # Only needed if we are actually building the database. If we are reusing
    # an existing database, we skip reading the VCF entirely.
    if build_database:
        snvs = []  # list of (chrom, pos_int, ref, alt)

        exome_bed = os.path.join(genome_root, f"exome.{reference}.bed")

        def load_exome_bed(path):
            """
            Load exome BED intervals into a per-chromosome index for fast
            position lookup.

            Intervals are treated as 1-based, inclusive [start, end]. The columns
            are used as written, with no coordinate conversion.
            NOTE(review): if the file is a standard 0-based BED, each interval
            would accept one extra base on the left. Confirm the convention.

            A line is skipped if it is blank, starts with '#', has fewer than
            three columns, has non-integer coordinates, or has start > end.
            Chromosome names are normalized with normalize_chrom.

            Overlapping intervals on a chromosome are merged. The lookup only
            inspects the interval with the greatest start <= pos, so without
            merging, a position inside a long interval that also contains a
            later-starting shorter interval would be wrongly rejected.

            Returns: dict chrom -> (starts_list, ends_list), both sorted and
            non-overlapping.
            """
            exome_raw = defaultdict(list)
            with open(path) as fh:
                for line in fh:
                    line = line.strip()
                    if not line or line.startswith("#"):
                        continue
                    fields = line.split("\t")
                    if len(fields) < 3:
                        continue
                    chrom = normalize_chrom(fields[0])
                    try:
                        start = int(fields[1])
                        end = int(fields[2])
                    except ValueError:
                        continue
                    if start <= end:
                        exome_raw[chrom].append((start, end))

            exome_index = {}
            for chrom, intervals in exome_raw.items():
                intervals.sort(key=lambda x: x[0])
                starts = []
                ends = []
                for s, e in intervals:
                    if starts and s <= ends[-1]:
                        # overlaps the previous merged interval: extend it if needed
                        if e > ends[-1]:
                            ends[-1] = e
                    else:
                        starts.append(s)
                        ends.append(e)
                exome_index[chrom] = (starts, ends)
            return exome_index

        def pos_in_exome(chrom, pos, exome_index):
            """
            Return True if the 1-based position *pos* on *chrom* falls inside an
            exome interval.

            exome_index maps chrom -> (sorted starts, ends), as built by
            load_exome_bed. Only the interval with the largest start <= pos is
            checked. If exome_index is None every position is accepted;
            chromosomes absent from the index are rejected.
            """
            from bisect import bisect_right
            if exome_index is None:
                return True  # no exome file → accept all
            if chrom not in exome_index:
                return False
            starts, ends = exome_index[chrom]
            # number of intervals starting at or before pos
            slot = bisect_right(starts, pos)
            return slot > 0 and pos <= ends[slot - 1]

        # Try to load exome index if file exists
        exome_index = None
        if os.path.exists(exome_bed):
            exome_index = load_exome_bed(exome_bed)

        def parse_vcf_line(line):
            """
            Parse one VCF data line into the single-nucleotide variants it carries.

            Header or blank lines, lines with fewer than five columns, and lines
            with a non-integer POS yield []. The chromosome is normalized with
            normalize_chrom. REF and each comma-separated ALT are upper-cased.

            Only substitutions where REF and ALT are each a single A/C/G/T base
            and ALT != REF are kept. Placeholder or symbolic alleles such as '.',
            '*' or 'N' are also one character long, but they are not bases and
            must not be written into transcript sequences.

            Returns: list[(chrom, pos_int, ref, alt)] with pos as the 1-based VCF POS.
            """
            line = line.strip()
            if not line or line.startswith("#"):
                return []
            fields = line.split("\t")
            if len(fields) < 5:
                return []
            chrom = normalize_chrom(fields[0])
            try:
                pos = int(fields[1])  # VCF POS is 1-based
            except ValueError:
                return []
            ref = fields[3].upper()
            if ref not in {"A", "C", "G", "T"}:
                return []
            out = []
            for alt in fields[4].split(","):
                alt = alt.upper()
                if alt in {"A", "C", "G", "T"} and alt != ref:
                    out.append((chrom, pos, ref, alt))
            return out

        def vcf_line_iterator(path):
            """
            Lazily yield the data lines of a VCF file, gzip-compressed when the
            path ends in '.gz'.

            Blank lines and lines starting with '#' are skipped. Yielded lines
            keep their trailing newline. The file is opened on the first next()
            call and closed when the generator finishes or is closed.
            """
            handle = gzip.open(path, "rt") if path.endswith(".gz") else open(path, "r")
            with handle:
                yield from (
                    raw for raw in handle
                    if raw.strip() and not raw.startswith("#")
                )

        if vcf_path is not None and os.path.exists(vcf_path):
            vcf_path_abs = os.path.abspath(vcf_path)

            # Single-threaded streaming parse + exome filter
            for line in vcf_line_iterator(vcf_path_abs):
                for chrom, pos, ref, alt in parse_vcf_line(line):
                    if pos_in_exome(chrom, pos, exome_index):
                        snvs.append((chrom, pos, ref, alt))
        else:
            snvs = []
    else:
        # Not building database: we won't use SNVs at all.
        snvs = []

    report_step_done()  # step 1 done
    
    # ----------------------------- transcriptome and CDS -------------------------

    # Load whole transcriptome once
    transcriptome = load_transcriptome(transcriptome_fa)
    cds_map = load_cds_bed(cds_bed)

    report_step_done()  # step 2 done

    # Build canonical / non-canonical transcript ID sets (full IDs as in FASTA)
    canonical_tx_ids = set()
    noncanonical_tx_ids = set()
    for tx_id in transcriptome.keys():
        core = fasta_id_core(tx_id)
        if core in canonical_ids:
            canonical_tx_ids.add(tx_id)
        else:
            noncanonical_tx_ids.add(tx_id)

    report_step_done()  # step 3 done

    # ----------------------------- canonical proteome ----------------------------

    canonical_proteome_fa_tmp = os.path.join(database_dir, "canonicalProteome.fa")
    if build_database:
        canonical_proteins = translate_cds(transcriptome, cds_map, tx_ids_subset=canonical_tx_ids)
        write_fasta_single_line(canonical_proteins, canonical_proteome_fa_tmp)

    # Load peptides
    peptides_all = read_peptide_fasta(peptide_fasta)

    # Find canonical matches
    canonical_hits, peptides_remaining = find_exact_matches(
        peptides_all,
        canonical_proteome_fa_tmp,
        step_label="canonical classification"
    )
    canonical_out_fa = os.path.join(output_dir, "canonicalProteome.fa")
    if canonical_hits:
        write_fasta_single_line(canonical_hits, canonical_out_fa)

    report_step_done()  # step 4 done

    # -------------------------- alternative splicing proteome --------------------

    alt_splicing_proteome_fa_tmp = os.path.join(database_dir, "alternativeSplicing.fa")
    if build_database:
        alt_splicing_proteins = translate_cds(transcriptome, cds_map, tx_ids_subset=noncanonical_tx_ids)
        write_fasta_single_line(alt_splicing_proteins, alt_splicing_proteome_fa_tmp)

    alt_splicing_hits, peptides_remaining = find_exact_matches(
        peptides_remaining,
        alt_splicing_proteome_fa_tmp,
        step_label="alternative splicing classification"
    )
    alt_splicing_out_fa = os.path.join(output_dir, "alternativeSplicing.fa")
    if alt_splicing_hits:
        write_fasta_single_line(alt_splicing_hits, alt_splicing_out_fa)

    report_step_done()  # step 5 done

    # ----------------------------- mutated antigens (mutanome) -------------------

    mutated_canonical_tx_fa_tmp = os.path.join(database_dir, "mutatedCanonicalTranscriptome.fa")
    mutanome_fa_tmp = os.path.join(database_dir, "mutanome.fa")

    if build_database:
        # Parse GFF to map genomic SNVs to transcript coordinates
        # We build transcript -> chrom, strand, exon intervals (genomic coordinates, 1-based closed)
        transcript_exons = defaultdict(list)
        transcript_strand = {}
        transcript_chrom = {}

        def parse_gff_attributes(attr_str):
            """
            Parse a GFF attribute column into a dict.

            Items are separated by ';' and may take either of two forms:
              - 'key=value': the value has surrounding whitespace and double
                quotes stripped (the key is kept as written);
              - 'key "value"': the first whitespace token is the key and the
                second token, with double quotes stripped, is the value.
            Items matching neither form are ignored. A repeated key keeps its
            last value.
            """
            parsed = {}
            for chunk in attr_str.strip().split(";"):
                chunk = chunk.strip()
                if not chunk:
                    continue
                if "=" in chunk:
                    key, _eq, raw_value = chunk.partition("=")
                    parsed[key] = raw_value.strip().strip('"')
                    continue
                tokens = chunk.split()
                if len(tokens) >= 2:
                    parsed[tokens[0]] = tokens[1].strip('"')
            return parsed

        with open(gff_path) as fh:
            for line in fh:
                if not line.strip() or line.startswith("#"):
                    continue
                fields = line.rstrip("\n").split("\t")
                if len(fields) < 9:
                    continue
                chrom, source, feature, start, end, score, strand, frame, attrs_str = fields

                # Use only exon features to avoid double counting exon+CDS
                if feature.lower() != "exon":
                    continue

                try:
                    start = int(start)
                    end = int(end)
                except ValueError:
                    continue

                attrs = parse_gff_attributes(attrs_str)

                # Be strict: use only transcript_id / transcriptId, to match gff_to_bed
                tx_id = attrs.get("transcript_id") or attrs.get("transcriptId")
                if tx_id is None:
                    # No transcript-level ID → skip; prevents grouping by exon ID
                    continue

                tx_id = normalize_gff_tx_id(tx_id)
                chrom_norm = normalize_chrom(chrom)

                transcript_exons[tx_id].append((start, end))
                transcript_strand[tx_id] = strand
                transcript_chrom[tx_id] = chrom_norm

        # Sort exons
        for tx_id in list(transcript_exons.keys()):
            transcript_exons[tx_id].sort(key=lambda x: x[0])

        # reverse complement helper (unchanged)
        complement_map = {
            "A": "T",
            "T": "A",
            "C": "G",
            "G": "C",
            "a": "t",
            "t": "a",
            "c": "g",
            "g": "c"
        }

        def complement_base(b):
            """Return the complement of an A/C/G/T base (either case); any other character is returned unchanged."""
            if b in complement_map:
                return complement_map[b]
            return b

        # Precompute exon ordering & lengths per transcript for transcript coord mapping
        exon_order_cache = {}  # tx_id -> (ordered_exons, total_len, ordered_exons_desc)
        for tx_id, exons in transcript_exons.items():
            exons_sorted = sorted(exons, key=lambda x: x[0])
            total_len = 0
            for s, e in exons_sorted:
                total_len += (e - s + 1)
            exons_desc = list(reversed(exons_sorted))
            exon_order_cache[tx_id] = (exons_sorted, total_len, exons_desc)

        # NEW: index SNVs by chromosome
        snvs_by_chrom = defaultdict(list)
        for chrom, pos, ref, alt in snvs:
            snvs_by_chrom[chrom].append((pos, ref, alt))
        for chrom in snvs_by_chrom:
            snvs_by_chrom[chrom].sort(key=lambda x: x[0])

        def apply_snvs_to_transcript(tx_id):
            """
            Apply the SNVs of the transcript's chromosome that fall inside its exons.

            Each genomic SNV position is mapped to a 0-based transcript index
            using the transcript's GFF exons:
              - '+' strand: exons are walked in ascending order.
              - '-' strand: exons are walked in descending order, and ALT/REF
                are complemented.
            NOTE(review): this assumes the transcriptome FASTA sequence equals
            the concatenated exons in that orientation. Confirm this against how
            the transcriptome was built.

            If the expected REF base differs from the current transcript base,
            the SNV is skipped and a warning is printed to stderr.

            Returns: (tx_id, list_of_chars_sequence). The list is empty if tx_id
            is not in the transcriptome. It is the unmodified sequence if there
            is no exon info or no SNV on the chromosome.
            """
            from bisect import bisect_left

            # Only canonical transcripts that exist in transcriptome are mutated
            if tx_id not in transcriptome:
                return tx_id, []

            seq_list = list(transcriptome[tx_id])
            chrom = transcript_chrom.get(tx_id)
            if chrom is None or chrom not in snvs_by_chrom:
                # no SNVs on this chromosome → return original seq
                return tx_id, seq_list

            if tx_id not in exon_order_cache:
                # no exon info for this transcript
                return tx_id, seq_list

            exons_sorted, _total_len, exons_desc = exon_order_cache[tx_id]
            if not exons_sorted:
                return tx_id, seq_list
            strand = transcript_strand.get(tx_id, "+")

            # snvs_by_chrom[chrom] is sorted by position. Only SNVs inside the
            # transcript's genomic span can fall in one of its exons, so bisect
            # to that slice instead of scanning the whole chromosome for every
            # transcript. The probes are 1-tuples: (p, ref, alt) < (x,) exactly
            # when p < x, so the comparison stays monotone over a list sorted by p.
            chrom_snvs = snvs_by_chrom[chrom]
            span_start = exons_sorted[0][0]
            span_end = max(e for _s, e in exons_sorted)
            lo = bisect_left(chrom_snvs, (span_start,))
            hi = bisect_left(chrom_snvs, (span_end + 1,))

            # SNVs already filtered to exome + indexed by chrom
            for pos, ref, alt in chrom_snvs[lo:hi]:
                if strand == "+":
                    offset = 0
                    within = False
                    for s, e in exons_sorted:
                        if pos < s:
                            break
                        if pos > e:
                            offset += (e - s + 1)
                        else:
                            offset += (pos - s)
                            within = True
                            break
                    if not within:
                        continue
                    tx_index = offset  # 0-based index in transcript sequence
                    alt_base = alt
                    expected_ref = ref
                else:
                    # minus strand: transcript 5'->3' is reverse complement,
                    # so exons in descending order
                    offset_from_5prime = 0
                    within = False
                    for s, e in exons_desc:
                        if pos > e:
                            # position is more 5' than this exon on minus strand
                            continue
                        if pos < s:
                            offset_from_5prime += (e - s + 1)
                        else:
                            offset_from_5prime += (e - pos)
                            within = True
                            break
                    if not within:
                        continue
                    tx_index = offset_from_5prime
                    alt_base = complement_base(alt)
                    expected_ref = complement_base(ref)

                if 0 <= tx_index < len(seq_list):
                    current_ref = seq_list[tx_index].upper()
                    if expected_ref and current_ref != expected_ref:
                        # Warn but do not mutate
                        print(
                            f"[WARN] Ref base mismatch for {tx_id} at transcript index {tx_index}: "
                            f"expected {expected_ref}, saw {current_ref} (chrom {chrom}, pos {pos})",
                            file=sys.stderr
                        )
                        continue
                    seq_list[tx_index] = alt_base

            return tx_id, seq_list

        # NEW: parallel apply SNVs per transcript (canonical only)
        mutated_transcripts = {}

        canonical_tx_ids_list = [tx for tx in canonical_tx_ids if tx in transcriptome]

        if num_threads is not None and num_threads > 1:
            with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
                for tx_id, seq_list in executor.map(apply_snvs_to_transcript, canonical_tx_ids_list, chunksize=50):
                    if seq_list:
                        mutated_transcripts[tx_id] = seq_list
        else:
            # single-threaded fallback
            for tx_id in canonical_tx_ids_list:
                tx_id_res, seq_list = apply_snvs_to_transcript(tx_id)
                if seq_list:
                    mutated_transcripts[tx_id_res] = seq_list

        # Convert mutated_transcripts back to strings and write mutatedCanonicalTranscriptome.fa
        mutated_canonical_tx_dict = {}
        for tx_id, seq_list in mutated_transcripts.items():
            mutated_canonical_tx_dict[tx_id] = "".join(seq_list)
        write_fasta_single_line(mutated_canonical_tx_dict, mutated_canonical_tx_fa_tmp)

        # Reload mutated canonical transcriptome into dict for translation
        # IMPORTANT: mutanome = canonicalProteome with SNVs applied,
        # so we must translate ONLY CDS regions, just like canonicalProteome.
        mutated_canonical_transcripts = load_transcriptome(mutated_canonical_tx_fa_tmp)

        # Translate CDS of mutated canonical transcripts
        # If there are no SNVs, mutated_canonical_transcripts == canonical transcripts,
        # so mutanome_proteins == canonical_proteins (as desired).
        mutanome_proteins = translate_cds(
            mutated_canonical_transcripts,
            cds_map,
            tx_ids_subset=None  # translated for all mutated canonical transcripts present in cds_map
        )

        write_fasta_single_line(mutanome_proteins, mutanome_fa_tmp)
    else:
        # Load precomputed mutated canonical transcriptome from database
        mutated_canonical_transcripts = load_transcriptome(mutated_canonical_tx_fa_tmp)

    neoantigen_hits, peptides_remaining = find_exact_matches(
        peptides_remaining,
        mutanome_fa_tmp,
        step_label="mutanome classification"
    )
    neoantigen_out_fa = os.path.join(output_dir, "neoantigen.fa")
    if neoantigen_hits:
        write_fasta_single_line(neoantigen_hits, neoantigen_out_fa)

    report_step_done()  # step 6 done

    # -------------------------- alternative ORFs (3 reading frames) --------------

    alt_orf_fa_translatome_tmp = os.path.join(database_dir, "mutatedAlternativeTranslatome.fa")
    alt_orf_fa_orfeome_tmp = os.path.join(database_dir, "mutatedAlternativeORFeome.fa")

    if build_database:
        # Translate all three frames of mutatedCanonicalTranscriptome
        alt_orf_records = {}
        for tx_id, nt_seq in mutated_canonical_transcripts.items():
            for frame in (0, 1, 2):
                sub_seq = nt_seq[frame:]
                aa_seq = safe_translate_nt(sub_seq)
                if not aa_seq:
                    continue
                rid = "{}_frame{}".format(tx_id, frame)
                alt_orf_records[rid] = aa_seq

        write_fasta_single_line(alt_orf_records, alt_orf_fa_translatome_tmp)
        # keep both filenames as per prompt terminology
        write_fasta_single_line(alt_orf_records, alt_orf_fa_orfeome_tmp)

    alternative_orf_hits, peptides_remaining = find_exact_matches(
        peptides_remaining,
        alt_orf_fa_translatome_tmp,
        step_label="alternative ORF classification"
    )
    alternative_orf_out_fa = os.path.join(output_dir, "alternativeReadingFrame.fa")
    if alternative_orf_hits:
        write_fasta_single_line(alternative_orf_hits, alternative_orf_out_fa)

    report_step_done()  # step 7 done

    # -------------------------- amino acid misincorporations ----------------------

    misincorporation_hits, peptides_remaining = find_hamming_leq1_matches(
        peptides_remaining,
        alt_orf_fa_orfeome_tmp,
        step_label="amino acid misincorporation search",
        num_threads=num_threads,
    )
    misincorporation_out_fa = os.path.join(output_dir, "aminoAcidMisincorporation.fa")
    if misincorporation_hits:
        write_fasta_single_line(misincorporation_hits, misincorporation_out_fa)

    report_step_done()  # step 8 done

    # -------------------------- unaligned peptides --------------------------------

    unaligned_out_fa = os.path.join(output_dir, "unknown.fa")
    if peptides_remaining:
        write_fasta_single_line(peptides_remaining, unaligned_out_fa)

    # -------------------------- pie chart of category counts ----------------------

    counts = OrderedDict()
    counts["canonical"] = len(canonical_hits)
    counts["alternativeSplicing"] = len(alt_splicing_hits)
    counts["neoantigen"] = len(neoantigen_hits)
    counts["alternativeReadingFrame"] = len(alternative_orf_hits)
    counts["aminoAcidMisincorporation"] = len(misincorporation_hits)
    counts["unknown"] = len(peptides_remaining)

    # -------------------------- write pieChart.tsv ----------------------
    tsv_path = os.path.join(output_dir, "pieChart.tsv")
    with open(tsv_path, "w") as tsv:
        tsv.write("Category\tCount\n")
        for cat, cnt in counts.items():
            tsv.write(f"{cat}\t{cnt}\n")

    # -------------------------- produce pie chart (hex colors) --------------------
    category_keys = [
        "canonical",
        "alternativeSplicing",
        "neoantigen",
        "alternativeReadingFrame",
        "aminoAcidMisincorporation",
        "unknown",
    ]

    legend_labels = [
        "canonical proteome",
        "alternative splicing",
        "neoantigen",
        "alternative reading frame",
        "amino acid misincorporation",
        "unknown",
    ]

    # Hexadecimal colors
    colors = [
        "#263b81",  
        "#0578a6",  
        "#64cdf6",  
        "#d71f26",  
        "#f493a9",  
        "#e5e5e5",  
    ]

    # Sizes in fixed order
    sizes_all = [counts[k] for k in category_keys]

    nonzero_indices = [i for i, s in enumerate(sizes_all) if s > 0]

    if nonzero_indices:
        pie_sizes = [sizes_all[i] for i in nonzero_indices]
        pie_colors = [colors[i] for i in nonzero_indices]

        matplotlib.rcParams['font.family'] = 'Arial'
        matplotlib.rcParams['font.size'] = 14

        max_label_len = max(len(lbl) for lbl in legend_labels)
        legend_width = max(2.0, 0.10 * max_label_len)

        pie_width = 4.0
        pie_height = 4.0
        fig_width = pie_width + legend_width
        fig_height = pie_height

        fig = plt.figure(figsize=(fig_width, fig_height))
        pie_ax_width_fraction = pie_width / fig_width
        ax = fig.add_axes([0.0, 0.0, pie_ax_width_fraction, 1.0])

        wedges, _ = ax.pie(
            pie_sizes,
            colors=pie_colors,
            startangle=90,
            counterclock=False
        )
        ax.axis('equal')

        legend_handles = [
            Patch(facecolor=colors[i], edgecolor='none')
            for i in range(len(category_keys))
        ]

        ax.legend(
            legend_handles,
            legend_labels,
            loc='center left',
            bbox_to_anchor=(1.02, 0.5),
            fontsize=14,
            frameon=False
        )

        pie_path = os.path.join(output_dir, "pieChart.pdf")
        fig.savefig(pie_path, format="pdf", dpi=1200, bbox_inches='tight')
        plt.close(fig)

    report_step_done()  # step 9 done

    # -------------------------- finalize -----------------------------------------

    report_step_done()  # step 10 done

