import copy
import numpy as np
import matplotlib.pyplot as plt

from .Logical import LogicalCircuit, LogicalStatevector, logical_state_fidelity
from .Execution import execute_circuits

from qiskit import QuantumCircuit
from qiskit.quantum_info import Statevector, DensityMatrix, state_fidelity
from qiskit_addon_utils.slicing import slice_by_depth

from typing import TYPE_CHECKING
from typing import Iterable

def create_qec_schedules(circuit, schedule_type="optimized", interval=5, constraint_model=None):
    """Build a QEC-cycle insertion schedule for ``circuit``.

    Positions are expressed as indices into the circuit's non-trivial
    operations, i.e. every entry of ``circuit.data`` whose operation is neither
    a ``barrier`` nor a ``measure``.

    Args:
        circuit: Circuit to schedule. Each entry of ``circuit.data`` must
            expose ``operation.name``. For ``"random"`` the circuit is also
            passed to ``slice_by_depth`` to obtain its depth.
        schedule_type: One of ``"optimized"``, ``"random"`` or ``"fixed"``.
        interval: Spacing (in non-trivial operations) between QEC cycles for
            the ``"fixed"`` schedule. Must be positive.
        constraint_model: Accepted so all schedule types share one call
            signature; it is not used by this function.

    Returns:
        ``None`` for ``"optimized"`` (the caller is expected to run the
        optimizer on the logical circuit itself). Otherwise a dict
        ``{0: indices}`` where ``indices`` is a sorted list of ints.
        NOTE(review): key ``0`` presumably identifies the first logical
        qubit -- confirm against ``insert_qec_cycles``.

    Raises:
        ValueError: If ``schedule_type`` is unknown, or if ``interval`` is not
            positive for the ``"fixed"`` schedule.
    """
    if schedule_type == "optimized":
        # Nothing to compute here: the optimizer lives on LogicalCircuit.
        return None
    if schedule_type not in ("random", "fixed"):
        raise ValueError(f"Unknown schedule_type: {schedule_type}")

    # Number of operations a QEC cycle could be inserted after.
    circuit_length = sum(
        1
        for inst in circuit.data
        if inst.operation.name not in ("barrier", "measure")
    )

    if schedule_type == "fixed":
        if interval <= 0:
            # range() would reject 0 with an opaque message and silently yield
            # nothing for negative steps; fail loudly instead.
            raise ValueError(f"interval must be a positive integer, got {interval}")
        return {0: list(range(interval, circuit_length, interval))}

    # schedule_type == "random"
    if circuit_length <= 1:
        # No interior position is available to draw from.
        return {0: []}

    # Only the random schedule needs the depth, so slice lazily.
    circuit_depth = len(slice_by_depth(circuit, 1))
    num_qec = max(1, circuit_depth // 4)
    chosen = np.random.choice(
        range(1, circuit_length),
        size=min(num_qec, circuit_length - 1),
        replace=False,
    )
    return {0: sorted(int(i) for i in chosen)}

def compare_qec_schedules(
    benchmark_circuits: dict[str, QuantumCircuit | LogicalCircuit] | Iterable[QuantumCircuit | LogicalCircuit],
    qecc,
    constraint_model,
    fixed_intervals=(3, 5, 7),
    backend="aer_simulator",
    hardware_model=None,
    coupling_map=None,
    method="density_matrix",
    shots=1024,
):
    """Benchmark QEC scheduling strategies against a no-QEC baseline.

    For every benchmark circuit this runs, in order: the physical circuit
    without QEC, the optimizer-chosen schedule, a random schedule, and one
    fixed-interval schedule per entry of ``fixed_intervals``.

    Args:
        benchmark_circuits: Either a mapping ``name -> circuit`` or a plain
            iterable of circuits. For a plain iterable the circuit's ``name``
            attribute is used, falling back to ``circuit_<position>``.
        qecc: Keyword arguments forwarded to
            ``LogicalCircuit.from_physical_circuit``.
        constraint_model: Passed to ``optimize_qec_cycle_indices`` and to
            ``create_qec_schedules`` for the random schedule.
        fixed_intervals: Intervals to evaluate with the ``"fixed"`` schedule.
        backend: Forwarded to ``execute_circuits``.
        hardware_model: Forwarded to ``execute_circuits``.
        coupling_map: Currently NOT forwarded (see the note in the body).
        method: Forwarded to ``execute_circuits``; ``"density_matrix"`` selects
            a ``DensityMatrix`` reference state, anything else a
            ``Statevector``.
        shots: Forwarded to ``execute_circuits``.

    Returns:
        ``{"circuits": {name: circuit_results}, "summary": {}}``. Each
        ``circuit_results`` holds ``exact_state``, ``circuit_depth``, a
        ``methods`` dict (``no_qec``, ``optimized``, ``random``,
        ``fixed_<k>``; each with ``fidelity`` and ``num_qec``) and whichever
        ratio keys could be computed (``opt_vs_no_qec``, ``random_vs_no_qec``,
        ``fixed_<k>_vs_no_qec``, ``opt_vs_random``, ``opt_vs_fixed_<k>``).
        A fidelity is ``None`` when it could not be determined. ``summary``
        is not populated by this function.
    """
    results = {
        "circuits": {},
        "summary": {}
    }

    # Iterated several times per circuit, so materialise once (this also
    # makes generators safe to pass).
    fixed_intervals = list(fixed_intervals)

    # NOTE(review): the ``coupling_map`` argument is accepted but every
    # execution has always been run with ``coupling_map=None``. That behaviour
    # is preserved here; confirm whether the argument should be forwarded.
    run_kwargs = dict(
        backend=backend,
        hardware_model=hardware_model,
        coupling_map=None,
        method=method,
        shots=shots,
    )

    for circuit_name, circuit in _named_circuits(benchmark_circuits):
        print(f"\nProcessing {circuit_name}...")

        # Remove measurements for exact state calculation
        circuit_no_meas = circuit.remove_final_measurements(inplace=False)

        circuit_depth = len(slice_by_depth(circuit_no_meas, 1))
        print(f"  Circuit depth: {circuit_depth}")

        if method == "density_matrix":
            exact_state = DensityMatrix(circuit_no_meas)
        else:
            exact_state = Statevector(circuit_no_meas)

        methods = {}
        circuit_results = {
            "exact_state": exact_state,
            "circuit_depth": circuit_depth,
            "methods": methods
        }

        # 0. NO QEC BASELINE
        print("  Running no-QEC baseline (physical circuit)...")
        circuit_phys = copy.deepcopy(circuit)
        if "measure" not in circuit_phys.count_ops():
            circuit_phys.measure_all()
        result_no_qec = execute_circuits(circuit_phys, **run_kwargs)[0]
        fidelity_no_qec = calculate_fidelity(result_no_qec, exact_state, method)
        methods["no_qec"] = {
            "fidelity": fidelity_no_qec,
            "num_qec": 0
        }

        # 1. OPTIMIZED SCHEDULING
        print("  Running optimized scheduling...")
        methods["optimized"] = _run_logical_schedule(
            circuit_no_meas,
            qecc,
            lambda lqc: lqc.optimize_qec_cycle_indices(constraint_model=constraint_model),
            run_kwargs,
        )

        # 2. RANDOM SCHEDULING
        print("  Running random scheduling...")
        methods["random"] = _run_logical_schedule(
            circuit_no_meas,
            qecc,
            lambda _lqc: create_qec_schedules(circuit_no_meas, "random", constraint_model=constraint_model),
            run_kwargs,
        )

        # 3. FIXED INTERVAL SCHEDULING
        for interval in fixed_intervals:
            print(f"  Running fixed-{interval} scheduling...")
            methods[f"fixed_{interval}"] = _run_logical_schedule(
                circuit_no_meas,
                qecc,
                lambda _lqc, interval=interval: create_qec_schedules(circuit_no_meas, "fixed", interval=interval),
                run_kwargs,
            )

        fidelity_opt = methods["optimized"]["fidelity"]
        fidelity_random = methods["random"]["fidelity"]

        # Improvement ratios relative to the no-QEC baseline, followed by the
        # optimized schedule relative to each alternative. A ratio is stored
        # only when it is well defined.
        ratios = [
            ("opt_vs_no_qec", fidelity_opt, fidelity_no_qec),
            ("random_vs_no_qec", fidelity_random, fidelity_no_qec),
        ]
        for interval in fixed_intervals:
            fid_fixed = methods[f"fixed_{interval}"]["fidelity"]
            ratios.append((f"fixed_{interval}_vs_no_qec", fid_fixed, fidelity_no_qec))
        ratios.append(("opt_vs_random", fidelity_opt, fidelity_random))
        for interval in fixed_intervals:
            fid_fixed = methods[f"fixed_{interval}"]["fidelity"]
            ratios.append((f"opt_vs_fixed_{interval}", fidelity_opt, fid_fixed))

        for key, numerator, denominator in ratios:
            ratio = _fidelity_ratio(numerator, denominator)
            if ratio is not None:
                circuit_results[key] = ratio

        results["circuits"][circuit_name] = circuit_results

    return results


def _named_circuits(benchmark_circuits):
    """Yield ``(name, circuit)`` pairs from a mapping or a plain iterable."""
    if hasattr(benchmark_circuits, "items"):
        yield from benchmark_circuits.items()
        return

    seen = set()
    for position, circuit in enumerate(benchmark_circuits):
        name = getattr(circuit, "name", None) or f"circuit_{position}"
        if name in seen:
            # Keep result keys unique when several circuits share a name.
            name = f"{name}_{position}"
        seen.add(name)
        yield name, circuit


def _run_logical_schedule(circuit_no_meas, qecc, schedule_fn, run_kwargs):
    """Encode, schedule, execute and score one QEC strategy.

    ``schedule_fn`` receives the freshly built logical circuit and returns the
    QEC cycle indices to insert. Returns the ``{"fidelity", "num_qec"}`` entry
    stored under ``methods``.
    """
    lqc = LogicalCircuit.from_physical_circuit(circuit_no_meas, **qecc)
    qec_indices = schedule_fn(lqc)
    print(qec_indices)
    lqc.insert_qec_cycles(qec_cycle_indices=qec_indices)
    lqc.measure_all()

    result = execute_circuits(lqc, **run_kwargs)[0]
    logical_counts = lqc.get_logical_counts(result.get_counts())

    return {
        "fidelity": _logical_zero_fidelity(logical_counts),
        "num_qec": len(qec_indices.get(0, []))
    }


def _logical_zero_fidelity(logical_counts):
    """Return the share of ``"0"`` outcomes among ``"0"``/``"1"`` logical counts.

    Only the keys ``"0"`` and ``"1"`` are read.
    NOTE(review): this presumably assumes a single logical qubit whose ideal
    outcome is ``0`` -- confirm for the benchmark circuits in use.
    Returns ``None`` when neither key has any counts, rather than dividing by
    zero.
    """
    zeros = logical_counts.get("0", 0.0)
    ones = logical_counts.get("1", 0.0)
    total = zeros + ones
    if total <= 0:
        return None
    return zeros / total


def _fidelity_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or ``None`` if it is not well defined."""
    if numerator is None or not denominator or denominator <= 0:
        return None
    return numerator / denominator

def calculate_fidelity(result, exact_state, method):
    """Estimate the fidelity between a simulation ``result`` and ``exact_state``.

    The function first tries to read a simulated state matching ``method``
    (``"statevector"`` or ``"density_matrix"``) from ``result``. If that is
    unavailable or fails, it falls back to measurement counts and rebuilds a
    statevector whose amplitudes are the square roots of the observed
    probabilities (so all phase information is lost).

    Args:
        result: Execution result. Probed via ``hasattr`` for
            ``get_statevector``, ``data()``, ``results`` and ``get_counts``.
        exact_state: Reference state passed to ``logical_state_fidelity``.
        method: ``"statevector"``, ``"density_matrix"``, or anything else to go
            straight to the counts-based estimate.

    Returns:
        The fidelity, or ``None`` when neither a state nor usable counts could
        be obtained from ``result``.
    """
    if method == "statevector":
        try:
            sv = None
            if hasattr(result, 'get_statevector'):
                sv = result.get_statevector()
            elif hasattr(result, 'results'):
                sv = getattr(result.results[0].data, 'statevector', None)
            if sv is not None:
                noisy_sv = sv if isinstance(sv, Statevector) else Statevector(sv)
                return logical_state_fidelity(exact_state, noisy_sv)
        except Exception:
            # Deliberately best-effort: any failure falls through to the
            # counts-based estimate below.
            pass

    elif method == "density_matrix":
        try:
            dm = None
            if hasattr(result, 'data') and callable(result.data):
                dm = result.data().get("density_matrix", None)
            elif hasattr(result, 'results'):
                dm = getattr(result.results[0].data, "density_matrix", None)
            if dm is not None:
                noisy_dm = dm if isinstance(dm, DensityMatrix) else DensityMatrix(dm)
                return logical_state_fidelity(exact_state, noisy_dm)
        except Exception:
            # Deliberately best-effort: any failure falls through to the
            # counts-based estimate below.
            pass

    # Fallback: reconstruct a statevector from counts.
    counts = _extract_counts(result)
    if isinstance(counts, list):
        counts = counts[0] if counts else None
    if not counts:
        return None

    probs = _counts_to_probabilities(counts, getattr(exact_state, "num_qubits", None))
    if probs is None:
        return None

    noisy_sv = Statevector(np.sqrt(probs).astype(complex))
    try:
        return logical_state_fidelity(exact_state, noisy_sv)
    except Exception:
        # last resort: plain state_fidelity
        return state_fidelity(exact_state, noisy_sv)


def _extract_counts(result):
    """Return the counts found on ``result``, or ``None`` if none are reachable."""
    try:
        if hasattr(result, 'get_counts'):
            return result.get_counts()
        if hasattr(result, 'results'):
            data = result.results[0].data
            if hasattr(data, 'get_counts'):
                return data.get_counts()
            if hasattr(data, 'counts'):
                return data.counts
    except Exception:
        return None
    return None


def _counts_to_probabilities(counts, num_qubits_hint=None):
    """Convert a counts mapping into a normalised probability vector.

    Keys are parsed as binary strings (spaces removed), with a best-effort
    hexadecimal fallback; unparseable keys are skipped. The vector is sized
    from the widest key seen so that no key can index out of bounds. For
    hexadecimal keys, whose text does not reveal the register width,
    ``num_qubits_hint`` is used when it is larger than the minimum width.

    Returns ``None`` when no key carries positive weight.
    """
    entries = []
    width = 0
    saw_hex_key = False
    for key, weight in counts.items():
        text = str(key).replace(" ", "")
        try:
            index = int(text, 2)
            bits = len(text)
        except ValueError:
            # best-effort hex fallback
            try:
                index = int(text, 16)
            except ValueError:
                continue
            bits = index.bit_length()
            saw_hex_key = True
        if index < 0:
            # A negative index would silently wrap around the array.
            continue
        entries.append((index, float(weight)))
        width = max(width, bits, index.bit_length())

    if saw_hex_key and isinstance(num_qubits_hint, int) and num_qubits_hint > width:
        width = num_qubits_hint

    probs = np.zeros(1 << width, dtype=float)
    total = 0.0
    for index, weight in entries:
        probs[index] += weight
        total += weight

    if total <= 0:
        return None

    probs /= total
    probs = np.clip(probs, 0.0, 1.0)
    norm = probs.sum()
    if norm <= 0:
        return None
    return probs / norm

def plot_scheduling_comparison(results, save_path="qec_scheduling_comparison.png"):
    """Plot infidelities and no-QEC improvement ratios per benchmark circuit.

    Draws two grouped bar charts side by side: absolute infidelity
    (``1 - fidelity``) for every method, and the fidelity ratio of each QEC
    method relative to the no-QEC baseline.

    Args:
        results: Dict with a ``"circuits"`` mapping. Each circuit entry needs
            a ``"methods"`` dict of ``{method: {"fidelity": ...}}`` and may
            carry ratio keys ``opt_vs_no_qec``, ``random_vs_no_qec`` and
            ``fixed_<k>_vs_no_qec``. A missing or ``None`` fidelity is plotted
            as fidelity 0; a missing or ``None`` ratio is plotted as 1.0.
        save_path: File the figure is written to.

    Returns:
        The matplotlib figure. The figure is also saved to ``save_path`` and
        shown via ``plt.show()``.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    circuits = results["circuits"]
    circuit_names = list(circuits.keys())
    all_methods = set()
    for circuit_data in circuits.values():
        all_methods.update(circuit_data["methods"].keys())

    def fixed_sort_key(method_name):
        # Sort "fixed_<k>" by numeric k so fixed_10 follows fixed_7, with any
        # non-numeric suffix placed last in lexicographic order.
        suffix = method_name[len("fixed_"):]
        if suffix.isdecimal():
            return (0, int(suffix), method_name)
        return (1, 0, method_name)

    # Order methods: no_qec first, then optimized, random, fixed
    methods = [m for m in ("no_qec", "optimized", "random") if m in all_methods]
    methods.extend(sorted((m for m in all_methods if m.startswith("fixed_")), key=fixed_sort_key))

    # Color scheme
    color_map = {
        "no_qec": '#9E9E9E',
        "optimized": '#2E7D32',
        "random": '#1976D2',
        "fixed_3": '#D32F2F',
        "fixed_5": '#F57C00',
        "fixed_7": '#7B1FA2'
    }

    # Plot A: Absolute Infidelities
    x = np.arange(len(circuit_names))
    # max(..., 1) keeps the function usable when there is nothing to plot.
    width = 0.8 / max(len(methods), 1)

    for i, method in enumerate(methods):
        fidelities = []
        for circuit in circuit_names:
            fid = circuits[circuit]["methods"].get(method, {}).get("fidelity", 0)
            fidelities.append(fid if fid is not None else 0)

        infidelities = 1.0-np.array(fidelities)

        # Centre the group of bars on each circuit's tick.
        offset = (i - len(methods)/2 + 0.5) * width
        color = color_map.get(method, '#757575')
        label = "No QEC (Baseline)" if method == "no_qec" else method.replace('_', ' ').title()
        bars = ax1.bar(x + offset, infidelities, width,
                      label=label,
                      color=color, alpha=0.8)

        # Add value labels
        for bar, val in zip(bars, infidelities):
            if val > 0:
                height = bar.get_height()
                ax1.text(bar.get_x() + bar.get_width()/2., height,
                        f'{val:.3f}', ha='center', va='bottom', fontsize=7)

    ax1.set_xlabel('Benchmark Circuit', fontsize=12)
    ax1.set_ylabel('Infidelity', fontsize=12)
    ax1.set_title('QEC Scheduling Comparison: Absolute Infidelities', fontsize=14, fontweight='bold')
    ax1.set_xticks(x)
    ax1.set_xticklabels(circuit_names, rotation=45, ha='right')
    ax1.legend(loc='lower left', fontsize=9)
    ax1.grid(True, alpha=0.3, axis='y')
    # ax1.set_ylim(0, 1.1)
    # ax1.set_yscale("log")

    # Plot B: Improvement over No-QEC Baseline
    improvement_methods = [m for m in methods if m != "no_qec"]
    x2 = np.arange(len(circuit_names))
    # Guard against results that only contain the baseline.
    width2 = 0.8 / max(len(improvement_methods), 1)

    for i, method in enumerate(improvement_methods):
        # The optimized schedule is stored under an abbreviated key; every
        # other method uses "<method>_vs_no_qec".
        ratio_key = "opt_vs_no_qec" if method == "optimized" else f"{method}_vs_no_qec"

        improvements = []
        for circuit in circuit_names:
            ratio = circuits[circuit].get(ratio_key, 1.0)
            improvements.append(ratio if ratio is not None else 1.0)

        offset = (i - len(improvement_methods)/2 + 0.5) * width2
        color = color_map.get(method, '#757575')
        label = method.replace('_', ' ').title()
        bars = ax2.bar(x2 + offset, improvements, width2,
                      label=label,
                      color=color, alpha=0.8)

        # Add value labels
        for bar, val in zip(bars, improvements):
            height = bar.get_height()
            ax2.text(bar.get_x() + bar.get_width()/2., height,
                    f'{val:.2f}x', ha='center', va='bottom', fontsize=7)

    # Add horizontal line at y=1 for reference (no improvement)
    ax2.axhline(y=1.0, color='black', linestyle='--', alpha=0.5, label='No Improvement')

    ax2.set_xlabel('Benchmark Circuit', fontsize=12)
    ax2.set_ylabel('Fidelity Improvement over No-QEC', fontsize=12)
    ax2.set_title('QEC Benefit: Improvement over Physical Circuit', fontsize=14, fontweight='bold')
    ax2.set_xticks(x2)
    ax2.set_xticklabels(circuit_names, rotation=45, ha='right')
    ax2.legend(loc='upper left', fontsize=9)
    ax2.grid(True, alpha=0.3, axis='y')
    # ax2.set_yscale("log")

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\nPlot saved to {save_path}")
    return fig

