import pandas as pd
import networkx as nx
from collections import Counter, defaultdict
from pathlib import Path
from configs import *

def parse_files(x):
    """Split a '|'-delimited file-list cell into bare file names.

    Each token may carry a trailing ``:suffix`` (everything after the last
    colon), which is discarded. Missing values -- NaN/None (per ``pd.isna``),
    empty or whitespace-only strings, and the strings "nan", "none" and
    "<na>" in any letter case -- yield an empty list. Tokens that are empty
    after stripping, or whose name part is empty, are skipped.
    """
    if pd.isna(x):
        return []
    text = str(x).strip()
    if text == "" or text.lower() in {"nan", "none", "<na>"}:
        return []
    # Strip whitespace around the token, drop the last ':suffix', strip again.
    names = (tok.strip().rsplit(':', 1)[0].strip() for tok in text.split('|'))
    return [name for name in names if name]

def _task_signature(inputs, outputs):
    """Return a hashable (inputs, outputs) key that ignores order and repeats."""
    return (tuple(sorted(set(inputs))), tuple(sorted(set(outputs))))


def _max_level_width(G):
    """Return the largest number of nodes sharing one level of the DAG *G*.

    A node's level is 0 when it has no predecessors, otherwise one more than
    the highest level among its predecessors. Returns 0 for an empty graph.
    """
    levels = {}
    for n in nx.topological_sort(G):
        preds = list(G.predecessors(n))
        levels[n] = max(levels[p] for p in preds) + 1 if preds else 0
    return max(Counter(levels.values()).values()) if levels else 0


def analyze_dag(df):
    """Build the task dependency DAG for one workflow and summarise its shape.

    Tasks are canonicalised by their (input files, output files) signature.
    All rows with the same de-duplicated, sorted signature collapse into one
    node, identified by the smallest ``task_id`` among them. An edge p -> c
    is added whenever canonical task p writes a file that canonical task c
    reads. Edges require p != c, so a task reading its own output adds no
    self-loop. Files that no task writes are treated as external inputs and
    add no edges.

    Args:
        df: DataFrame with at least the columns ``task_id`` (int-convertible),
            ``input_files`` and ``output_files`` (cells parsed by
            ``parse_files``).

    Returns:
        dict with the following keys:
        "Nodes", "Edges";
        "Depth": edges on the longest path;
        "Width": largest number of nodes sharing one level, where a node's
            level is its longest distance from a source;
        "Max Indegree", "Max Outdegree", "Sources", "Sinks";
        "Components": weakly connected components.

    Raises:
        ValueError: if the dependency graph contains a cycle.
    """
    df = df[['task_id', 'input_files', 'output_files']].copy()
    df['in_list'] = df['input_files'].apply(parse_files)
    df['out_list'] = df['output_files'].apply(parse_files)

    # Compute each row's signature once and keep it with the row. Rows are
    # then mapped to their canonical node by signature rather than by looking
    # up task_id, so a task_id repeated with different files cannot have its
    # files attributed to another row's node.
    rows = [
        (int(r.task_id), r.in_list, r.out_list,
         _task_signature(r.in_list, r.out_list))
        for r in df.itertuples(index=False)
    ]

    # Canonical id of a signature = smallest task_id carrying it.
    sig_to_canon = {}
    for tid, _, _, sig in rows:
        sig_to_canon[sig] = min(tid, sig_to_canon.get(sig, tid))

    files_prod = defaultdict(set)
    files_cons = defaultdict(set)
    for _, ins, outs, sig in rows:
        ctid = sig_to_canon[sig]
        for f in outs:
            files_prod[f].add(ctid)
        for f in ins:
            files_cons[f].add(ctid)

    G = nx.DiGraph()
    G.add_nodes_from(sig_to_canon.values())

    # Do NOT fail on external inputs: a file with no producer comes from
    # outside the subgraph and simply contributes no edges.
    for f, consumers in files_cons.items():
        for p in files_prod.get(f, ()):
            for c in consumers:
                if p != c:
                    G.add_edge(p, c)

    if not nx.is_directed_acyclic_graph(G):
        raise ValueError('graph contains a cycle')

    edges = G.number_of_edges()
    indeg = [d for _, d in G.in_degree()]
    outdeg = [d for _, d in G.out_degree()]

    return {
        "Nodes": G.number_of_nodes(),
        "Edges": edges,
        "Depth": nx.dag_longest_path_length(G) if edges else 0,
        "Width": _max_level_width(G),
        "Max Indegree": max(indeg, default=0),
        "Max Outdegree": max(outdeg, default=0),
        "Sources": indeg.count(0),
        "Sinks": outdeg.count(0),
        "Components": nx.number_weakly_connected_components(G),
    }

def run_many(paths):
    """Compute DAG metrics for several workflow CSV files.

    Args:
        paths: mapping of workflow name -> path of a CSV that has the columns
            ``task_id`` (read as int64), ``input_files`` and ``output_files``.

    Returns:
        DataFrame indexed by "Workflow", with one row per entry of *paths* in
        iteration order. Its columns are the metrics returned by
        ``analyze_dag``. An empty mapping yields an empty frame with the same
        columns and index name.

    Raises:
        ValueError: if a workflow's graph is rejected by ``analyze_dag``
            (e.g. it contains a cycle). The message is prefixed with the
            workflow name. Errors from ``pd.read_csv`` propagate unchanged.
    """
    cols = ["Workflow", "Nodes", "Edges", "Depth", "Width",
            "Max Indegree", "Max Outdegree", "Sources", "Sinks", "Components"]
    rows = []
    for name, p in paths.items():
        df = pd.read_csv(
            p,
            usecols=['task_id', 'input_files', 'output_files'],
            dtype={'task_id': 'int64'},   # do NOT force str on file cols
            low_memory=False,
        )
        try:
            m = analyze_dag(df)
        except ValueError as err:
            # Say which workflow failed; the exception type stays ValueError.
            raise ValueError(f"{name}: {err}") from err
        m['Workflow'] = name
        rows.append(m)
    # columns= fixes the schema even when rows is empty. Indexing
    # pd.DataFrame([]) by cols would raise KeyError in that case.
    return pd.DataFrame(rows, columns=cols).set_index("Workflow")

if __name__ == "__main__":
    # Each workflow's task_subgraphs.csv lives at
    # <LOGS_DIR>/fault_prevention/<workflow>/repeat1/baseline/csv-files/.
    PATHS = {
        wf: Path(LOGS_DIR) / "fault_prevention" / wf / "repeat1" / "baseline"
        / "csv-files" / "task_subgraphs.csv"
        for wf in WORKFLOWS
    }
    print(run_many(PATHS))
