import pandas as pd
import numpy as np

def _row_specs() -> list:
    """Return the table rows in their exact visual order.

    Each entry is a ``(Row, Value, template)`` triple: the first two items are
    the row's index labels, the third is the CSV ``template`` it is filled from.
    """
    specs = [("Baseline", "", "baseline")]
    specs += [("Prune Depth", str(d), f"prune-depth-{d}") for d in range(2, 7)]
    specs += [("Replica Count", str(r), f"replica-count-{r}") for r in range(2, 7)]
    # Iterate integer percentages: int(p * 100) truncates a float product, so
    # any value landing just below an integer would silently mislabel a row.
    specs += [
        ("Checkpoint Percentage", f"{pct}%", f"checkpoint-percentage-{pct / 100:.1f}")
        for pct in range(10, 101, 10)
    ]
    # Explicit hybrid combinations required in output; label == template name.
    specs += [
        ("hybrid", combo, combo)
        for combo in ("PD2-RC2-CP0.1", "PD3-RC3-CP0.2", "PD4-RC4-CP0.3")
    ]
    return specs


def build_matrix(csv_path: str) -> pd.DataFrame:
    """Build the results matrix exactly as laid out in the table design.

    Reads ``csv_path`` (only the ``metric``, ``workflow``, ``template`` and
    ``mean`` columns), keeps the four metrics in ``metrics_map`` and arranges
    their means into a DataFrame. Rows follow the fixed visual order of the
    table, indexed by ``(Row, Value)``. Columns are a ``(Metric, Workflow)``
    MultiIndex.

    Storage peak values are divided by 1024 (CSV MB -> table GB). Every value
    is rounded to a whole number with the built-in ``round`` (half-to-even)
    and stored as nullable ``Int64``. A cell is ``<NA>`` when no CSV row
    matches it or when its mean is missing, non-numeric or non-finite. When
    several CSV rows share a (template, metric, workflow), the first is used.

    Args:
        csv_path: Path to the results CSV.

    Returns:
        The assembled matrix with nullable-integer columns.
    """
    df = pd.read_csv(
        csv_path,
        usecols=["metric", "workflow", "template", "mean"],  # ignore other columns like 'std'
        low_memory=False,
    )

    # Map CSV metrics -> table sections (dict order is the column order).
    metrics_map = {
        "recovery-task-count": "Recovery Task Count",
        "graph-makespan": "Makespan [s]",
        "storage-consumption-peak": "Storage Peak [GB]",  # CSV in MB -> convert to GB
        "pruning-overhead": "Pruning Overhead [s]",
    }
    workflows = ["TopEFT", "RSTriPhoton", "DV5"]
    row_specs = _row_specs()

    mdf = df[df["metric"].isin(list(metrics_map))].copy()
    # Force a float column. Stray text becomes NaN instead of breaking the
    # division. An all-integer column would otherwise reject the float
    # assignment below (deprecated incompatible-dtype setitem in pandas).
    mdf["mean"] = pd.to_numeric(mdf["mean"], errors="coerce").astype("float64")

    # Convert storage peak from MB to GB.
    is_storage = mdf["metric"] == "storage-consumption-peak"
    mdf.loc[is_storage, "mean"] = mdf.loc[is_storage, "mean"] / 1024.0

    # One mean per (template, metric, workflow). The first CSV row wins,
    # which matches the original first-match lookup. A single dict replaces
    # rescanning the frame for every cell.
    first = mdf.drop_duplicates(subset=["template", "metric", "workflow"], keep="first")
    lookup = dict(
        zip(zip(first["template"], first["metric"], first["workflow"]), first["mean"])
    )

    def cell(tmpl, m_key, w):
        """Rounded mean for one cell, or NaN when it is absent or non-finite."""
        value = lookup.get((tmpl, m_key, w))
        # round() raises on NaN (ValueError) and inf (OverflowError); leave blank.
        if value is None or not np.isfinite(value):
            return np.nan
        return round(float(value))

    # Column order must match from_product below: metric outer, workflow inner.
    values = [
        [cell(tmpl, m_key, w) for m_key in metrics_map for w in workflows]
        for _, _, tmpl in row_specs
    ]

    cols = pd.MultiIndex.from_product(
        [list(metrics_map.values()), workflows],
        names=["Metric", "Workflow"],
    )
    idx = pd.MultiIndex.from_tuples([(a, b) for a, b, _ in row_specs], names=["Row", "Value"])
    out = pd.DataFrame(values, index=idx, columns=cols, dtype="float64")

    # Cast to nullable integers to keep blanks as NA while showing 0 decimals.
    return out.astype("Int64")

if __name__ == "__main__":
    import sys

    # Optional CLI overrides: script.py [input.csv] [output.csv].
    # With no arguments the original defaults are used.
    csv_path = sys.argv[1] if len(sys.argv) > 1 else "results.csv"
    out_path = sys.argv[2] if len(sys.argv) > 2 else "end_to_end_matrix.csv"

    matrix = build_matrix(csv_path)
    # to_string() prints every row/column; plain print() may elide columns.
    print(matrix.to_string())
    matrix.to_csv(out_path, na_rep="")  # empty string for missing cells
