from typing import Callable

from phylogenie.msa import MSA, Sequence
from phylogenie.tree import Tree


def _parse_newick(newick: str) -> Tree:
    newick = newick.strip()
    stack: list[list[Tree]] = []
    current_children: list[Tree] = []
    current_nodes: list[Tree] = []
    i = 0
    while i < len(newick):

        def _parse_chars(stoppers: list[str]) -> str:
            nonlocal i
            chars = ""
            while newick[i] not in stoppers:
                chars += newick[i]
                i += 1
            return chars

        if newick[i] == "(":
            stack.append(current_nodes)
            current_nodes = []
        else:
            id = _parse_chars([":", ",", ")", ";"])
            branch_length = None
            if newick[i] == ":":
                i += 1
                branch_length = _parse_chars([",", ")", ";"])

            current_node = Tree(
                id=id,
                branch_length=(None if branch_length is None else float(branch_length)),
            )
            for node in current_children:
                current_node.add_child(node)
                current_children = []
            current_nodes.append(current_node)

            if newick[i] == ")":
                current_children = current_nodes
                current_nodes = stack.pop()
            elif newick[i] == ";":
                return current_node

        i += 1

    raise ValueError("Newick string does not end with a semicolon.")


def load_newick(filepath: str) -> Tree | list[Tree]:
    """Load one or more Newick trees from a file, one tree per line.

    Blank and whitespace-only lines are skipped, so stray empty lines
    between or after trees do not abort the load.

    Args:
        filepath: Path of the text file to read.

    Returns:
        The single parsed tree when the file contains exactly one tree,
        otherwise a list of all parsed trees (empty when the file holds
        no trees).

    Raises:
        ValueError: If a non-blank line is not a valid Newick string.
    """
    with open(filepath, "r") as file:
        trees = [_parse_newick(line) for line in file if line.strip()]
    return trees[0] if len(trees) == 1 else trees


def _to_newick(tree: Tree) -> str:
    children_newick = ",".join([_to_newick(child) for child in tree.children])
    newick = tree.id
    if children_newick:
        newick = f"({children_newick}){newick}"
    if tree.branch_length is not None:
        newick += f":{tree.branch_length}"
    return newick


def dump_newick(tree: Tree, filepath: str) -> None:
    """Write ``tree`` to ``filepath`` as a single Newick string.

    The file is opened in write mode, so existing content is replaced. The
    serialized tree is followed by the terminating ``;``; no newline is
    appended.

    Args:
        tree: Root of the tree to write.
        filepath: Destination path.
    """
    with open(filepath, "w") as handle:
        serialized = _to_newick(tree)
        handle.write(serialized + ";")


def load_fasta(
    fasta_file: str, extract_time_from_id: Callable[[str], float] | None = None
) -> MSA:
    """Load sequences from a FASTA file into an ``MSA``.

    Each record is expected to span exactly two lines: a header starting
    with ``>`` followed by one line holding the whole sequence. Blank lines
    where a header is expected are skipped.

    Args:
        fasta_file: Path of the FASTA file to read.
        extract_time_from_id: Optional callable mapping a sequence id to its
            time. When omitted, the text after the last ``|`` in the id is
            parsed as a float; if that fails the time is ``None``.

    Returns:
        An ``MSA`` built from the sequences in file order.

    Raises:
        ValueError: If a non-blank line where a header is expected does not
            start with ``>``, or if a header is not followed by a sequence
            line.
    """
    sequences: list[Sequence] = []
    with open(fasta_file, "r") as f:
        for line in f:
            if not line.strip():
                # Tolerate empty lines between records and at end of file.
                continue
            if not line.startswith(">"):
                raise ValueError(f"Invalid FASTA format: expected '>', got '{line[0]}'")
            seq_id = line[1:].strip()
            time: float | None
            if extract_time_from_id is not None:
                time = extract_time_from_id(seq_id)
            else:
                try:
                    time = float(seq_id.split("|")[-1])
                except ValueError:
                    # The id carries no parseable time; this is best-effort.
                    time = None
            # Use a default so a header on the last line becomes a format
            # error instead of leaking StopIteration to the caller.
            seq_line = next(f, None)
            if seq_line is None:
                raise ValueError(
                    f"Invalid FASTA format: missing sequence for '{seq_id}'"
                )
            sequences.append(Sequence(seq_id, seq_line.strip(), time))
    return MSA(sequences)
