import numpy as np
import pytest

from pycircstat2 import Circular, load_data
from pycircstat2.distributions import vonmises
from pycircstat2.hypothesis import (
    V_test,
    angular_randomisation_test,
    batschelet_test,
    binomial_test,
    change_point_test,
    chisquare_test,
    circ_anova,
    circ_range_test,
    common_median_test,
    concentration_test,
    harrison_kanji_test,
    kuiper_test,
    omnibus_test,
    one_sample_test,
    rao_homogeneity_test,
    rao_spacing_test,
    rayleigh_test,
    symmetry_test,
    wallraff_test,
    watson_test,
    watson_u2_test,
    watson_williams_test,
    wheeler_watson_test,
)


def test_rayleigh_test():
    """Reproduce the Rayleigh test of Zar (2010), Ch27 Example 1 (p. 667).

    The angles are those of Ch26 Example 2 (dataset "D1"). The test is run
    twice -- from ``n`` and ``r``, then from the raw angles -- and both runs
    are checked against the same reference values.
    """
    angles = load_data("D1", source="zar")["θ"].values[:]
    circ = Circular(data=angles)

    call_variants = (
        {"n": circ.n, "r": circ.r},  # computed directly from r and n
        {"alpha": circ.alpha},  # computed directly from alpha
    )
    for kwargs in call_variants:
        outcome = rayleigh_test(**kwargs)
        np.testing.assert_approx_equal(outcome.z, 5.448, significant=3)
        assert 0.001 < outcome.pval < 0.002


def test_V_test():
    """Reproduce the V test of Zar (2010), Ch27 Example 2 (p. 669).

    Dataset "D7" is tested against a hypothesised angle of 90 degrees, first
    from ``mean``, ``n`` and ``r`` and then from the raw angles; both runs
    must give the same ``V``, ``u`` and p-value bound.
    """
    circ = Circular(data=load_data("D7", source="zar")["θ"].values[:])
    hypothesised = np.deg2rad(90)

    call_variants = (
        # computed directly from r and n
        {"angle": hypothesised, "mean": circ.mean, "n": circ.n, "r": circ.r},
        # computed directly from alpha
        {"alpha": circ.alpha, "angle": hypothesised},
    )
    for kwargs in call_variants:
        outcome = V_test(**kwargs)
        np.testing.assert_approx_equal(outcome.V, 9.498, significant=3)
        np.testing.assert_approx_equal(outcome.u, 4.248, significant=3)
        assert outcome.pval < 0.0005


def test_one_sample_test():
    """Reproduce the one-sample test of Zar (2010), Ch27 Example 3 (p. 669).

    Uses the Ch27 Example 2 data (dataset "D7"). The hypothesised angle of
    90 degrees must not be rejected, whether the test is given the bounds
    ``mean_lb``/``mean_ub`` or the raw angles.
    """
    angles = load_data("D7", source="zar")["θ"].values[:]
    circ = Circular(data=angles, unit="degree")
    hypothesised = np.deg2rad(90)

    call_variants = (
        # computed directly from lb and ub
        {"lb": circ.mean_lb, "ub": circ.mean_ub, "angle": hypothesised},
        # computed directly from alpha
        {"alpha": circ.alpha, "angle": hypothesised},
    )
    for kwargs in call_variants:
        outcome = one_sample_test(**kwargs)
        assert outcome.reject is False


def test_omnibus_test():
    """Check omnibus_test on Zar dataset "D8" and on two large simulated samples."""
    zar_angles = load_data("D8", source="zar")["θ"].values[:]
    circ = Circular(data=zar_angles, unit="degree")

    small_sample = omnibus_test(alpha=circ.alpha, scale=1)
    np.testing.assert_approx_equal(small_sample.pval, 0.0043, significant=2)

    # Large sample sizes used to overflow in the factorial division while
    # computing the p-value (fixed in PR 12).
    from pycircstat2.distributions import circularuniform, vonmises

    n_large = 10_000
    generator = np.random.default_rng(42)
    concentrated = vonmises.rvs(mu=0, kappa=1, size=n_large, random_state=generator)
    uniform = circularuniform.rvs(size=n_large, random_state=generator)

    concentrated_outcome = omnibus_test(alpha=concentrated)
    assert concentrated_outcome.pval < 0.05, (
        "Expected significant p-value for von Mises distribution"
    )
    uniform_outcome = omnibus_test(alpha=uniform)
    assert uniform_outcome.pval > 0.05, (
        "Expected non-significant p-value for uniform distribution"
    )


def test_batschelet_test():
    """Check batschelet_test on Zar dataset "D8" against a 45-degree angle."""
    angles = load_data("D8", source="zar")["θ"].values[:]
    circ = Circular(data=angles, unit="degree")

    outcome = batschelet_test(angle=np.deg2rad(45), alpha=circ.alpha)

    np.testing.assert_equal(outcome.C, 5)
    np.testing.assert_approx_equal(outcome.pval, 0.00661, significant=3)


def test_chisquare_test():
    """Check chisquare_test on the weights of Zar dataset "D2"."""
    table = load_data("D2", source="zar")
    circ = Circular(data=table["θ"].values[:], w=table["w"].values[:])

    outcome = chisquare_test(circ.w)

    np.testing.assert_approx_equal(outcome.chi2, 66.543, significant=3)
    assert outcome.pval < 0.001


def test_symmetry_test():
    """Check that symmetry_test gives a large p-value on Zar dataset "D9"."""
    angles = load_data("D9", source="zar")["θ"].values[:]
    circ = Circular(data=angles, unit="degree")

    sample_median = float(circ.median)
    outcome = symmetry_test(median=sample_median, alpha=circ.alpha)

    assert outcome.pval > 0.5


def test_watson_williams_test():
    """Check watson_williams_test on Zar datasets "D10" (2 groups) and "D11" (3 groups)."""

    def group(table, label):
        # Wrap the angles of one "sample" label in a Circular object.
        return Circular(data=table[table["sample"] == label]["θ"].values[:])

    # Two-sample case.
    two_groups = load_data("D10", source="zar")
    first, second = group(two_groups, 1), group(two_groups, 2)
    circular_outcome = watson_williams_test([first, second])

    np.testing.assert_approx_equal(circular_outcome.F, 1.61, significant=3)
    np.testing.assert_approx_equal(circular_outcome.pval, 0.22, significant=2)

    # Plain arrays must give the same answer as Circular objects.
    plain_outcome = watson_williams_test([first.alpha, second.alpha])
    np.testing.assert_allclose(plain_outcome.F, circular_outcome.F, rtol=1e-6)
    np.testing.assert_allclose(plain_outcome.pval, circular_outcome.pval, rtol=1e-6)

    # Three-sample case.
    three_groups = load_data("D11", source="zar")
    first = group(three_groups, 1)
    second = group(three_groups, 2)
    third = group(three_groups, 3)

    circular_outcome = watson_williams_test([first, second, third])

    np.testing.assert_approx_equal(circular_outcome.F, 1.86, significant=3)
    np.testing.assert_approx_equal(circular_outcome.pval, 0.19, significant=2)


def test_watson_u2_test():
    """Check watson_u2_test on Zar datasets "D12" (plain) and "D13" (weighted)."""
    # Unweighted data: Circular objects first, then their raw angles.
    table = load_data("D12", source="zar")
    first = Circular(data=table[table["sample"] == 1]["θ"].values[:])
    second = Circular(data=table[table["sample"] == 2]["θ"].values[:])
    circular_outcome = watson_u2_test([first, second])

    np.testing.assert_approx_equal(circular_outcome.U2, 0.1458, significant=3)
    assert 0.1 < circular_outcome.pval < 0.2

    plain_outcome = watson_u2_test([first.alpha, second.alpha])
    np.testing.assert_allclose(plain_outcome.U2, circular_outcome.U2, rtol=1e-6)
    np.testing.assert_allclose(plain_outcome.pval, circular_outcome.pval, rtol=1e-6)

    # Weighted data: each Circular is built from angles plus weights.
    table = load_data("D13", source="zar")
    rows_first = table[table["sample"] == 1]
    rows_second = table[table["sample"] == 2]
    first = Circular(data=rows_first["θ"].values[:], w=rows_first["w"].values[:])
    second = Circular(data=rows_second["θ"].values[:], w=rows_second["w"].values[:])
    circular_outcome = watson_u2_test([first, second])

    np.testing.assert_approx_equal(circular_outcome.U2, 0.0612, significant=3)
    assert circular_outcome.pval > 0.5

    # Repeating every angle by its weight must reproduce the weighted result.
    repeated = [np.repeat(first.alpha, first.w), np.repeat(second.alpha, second.w)]
    plain_outcome = watson_u2_test(repeated)
    np.testing.assert_allclose(plain_outcome.U2, circular_outcome.U2, rtol=1e-6)
    np.testing.assert_allclose(plain_outcome.pval, circular_outcome.pval, rtol=1e-6)


def test_wheeler_watson_test():
    """Check wheeler_watson_test on Zar dataset "D12" for Circular and array input."""
    table = load_data("D12", source="zar")
    first, second = (
        Circular(data=table[table["sample"] == label]["θ"].values[:])
        for label in (1, 2)
    )

    circular_outcome = wheeler_watson_test([first, second])
    np.testing.assert_approx_equal(circular_outcome.W, 3.678, significant=3)
    assert 0.1 < circular_outcome.pval < 0.25

    # Raw angle arrays must give the same statistic and p-value.
    plain_outcome = wheeler_watson_test([first.alpha, second.alpha])
    np.testing.assert_allclose(plain_outcome.W, circular_outcome.W, rtol=1e-6)
    np.testing.assert_allclose(plain_outcome.pval, circular_outcome.pval, rtol=1e-6)


def test_wallraff_test():
    """Check wallraff_test on Zar datasets "D14" (angles) and "D15" (clock times)."""
    # Angular data split by sex, tested around 135 degrees.
    table = load_data("D14", source="zar")
    males = Circular(data=table[table["sex"] == "male"]["θ"].values[:])
    females = Circular(data=table[table["sex"] == "female"]["θ"].values[:])
    reference = np.deg2rad(135)

    circular_outcome = wallraff_test(samples=[males, females], angle=reference)
    np.testing.assert_approx_equal(circular_outcome.U, 18.5, significant=3)
    assert circular_outcome.pval > 0.20

    # Raw angle arrays must give the same statistic and p-value.
    plain_outcome = wallraff_test(samples=[males.alpha, females.alpha], angle=reference)
    np.testing.assert_allclose(plain_outcome.U, circular_outcome.U, rtol=1e-6)
    np.testing.assert_allclose(plain_outcome.pval, circular_outcome.pval, rtol=1e-6)

    from pycircstat2.utils import time2float

    # Time-of-day data: times are converted with time2float before use, and
    # one reference time is supplied per sample.
    table = load_data("D15", source="zar")
    male_times = table[table["sex"] == "male"]["time"].values[:]
    female_times = table[table["sex"] == "female"]["time"].values[:]
    males = Circular(data=time2float(male_times))
    females = Circular(data=time2float(female_times))

    time_outcome = wallraff_test(
        angle=np.deg2rad(time2float(["7:55", "8:15"])),
        samples=[males, females],
        verbose=True,
    )
    np.testing.assert_equal(time_outcome.U, 13)
    assert time_outcome.pval > 0.05


def test_kuiper_test():
    """Check kuiper_test on Fisher dataset "B5", loaded with a 180-degree full cycle."""
    angles = load_data("B5", source="fisher")["θ"].values[:]
    circ = Circular(data=angles, unit="degree", full_cycle=180)

    outcome = kuiper_test(alpha=circ.alpha)

    np.testing.assert_approx_equal(outcome.V, 1.5864, significant=3)
    assert outcome.pval > 0.05


def test_watson_test():
    """Check watson_test on the 13-angle pigeon sample with 9999 simulations."""
    pigeon_angles = np.array(
        [20, 135, 145, 165, 170, 200, 300, 325, 335, 350, 350, 350, 355]
    )
    circ = Circular(data=pigeon_angles)

    outcome = watson_test(alpha=circ.alpha, n_simulation=9999)

    np.testing.assert_approx_equal(outcome.U2, 0.137, significant=3)
    assert outcome.pval > 0.10


def test_angular_randomisation_test():
    """Check angular_randomisation_test on two von Mises samples sharing mu and kappa."""
    np.random.seed(42)
    # Draw order matters for reproducibility: the 10-point sample first.
    small = Circular(np.random.vonmises(mu=0, kappa=3, size=10), unit="radian")
    large = Circular(np.random.vonmises(mu=0, kappa=3, size=50), unit="radian")

    circular_outcome = angular_randomisation_test([small, large])
    assert circular_outcome.pval > 0.05, "Expected non-significant p-value"

    # Raw angle arrays must yield the same test statistic.
    plain_outcome = angular_randomisation_test([small.alpha, large.alpha])
    np.testing.assert_allclose(
        plain_outcome.statistic, circular_outcome.statistic, rtol=1e-6
    )


def test_rao_spacing_test():
    """Check rao_spacing_test on the 13-angle pigeon sample with 9999 simulations."""
    pigeon_angles = np.array(
        [20, 135, 145, 165, 170, 200, 300, 325, 335, 350, 350, 350, 355]
    )
    circ = Circular(data=pigeon_angles)

    outcome = rao_spacing_test(alpha=circ.alpha, n_simulation=9999)

    np.testing.assert_approx_equal(outcome.statistic, 161.92308, significant=3)
    assert 0.05 < outcome.pval < 0.10


def test_randomized_tests_seed_harmonization():
    """Check that an integer seed and an equally seeded Generator agree.

    Each randomized test is run once with ``seed=123`` and once with
    ``seed=np.random.default_rng(123)``; the resulting p-values must be
    exactly equal.
    """
    alpha = np.linspace(0.0, 2 * np.pi, 12, endpoint=False)
    seed_value = 123

    def run_both(test_func, *args, **kwargs):
        # Integer-seeded run first, then a run with a freshly built Generator.
        from_int = test_func(*args, seed=seed_value, **kwargs)
        from_generator = test_func(
            *args, seed=np.random.default_rng(seed_value), **kwargs
        )
        return from_int, from_generator

    from_int, from_generator = run_both(rayleigh_test, alpha=alpha, B=128)
    assert from_int.bootstrap_pval == from_generator.bootstrap_pval

    halves = [alpha[:6], alpha[6:]]
    from_int, from_generator = run_both(
        angular_randomisation_test, halves, n_simulation=128
    )
    assert from_int.pval == from_generator.pval

    # The remaining tests share one call signature.
    for simulated_test in (kuiper_test, watson_test, rao_spacing_test):
        from_int, from_generator = run_both(
            simulated_test, alpha=alpha, n_simulation=256
        )
        assert from_int.pval == from_generator.pval


def test_circ_range_test():
    """Check circ_range_test against reference values for a 24-angle sample."""
    # 12 zeros, one 3.6, six 36.0 values, then the five remaining angles.
    degrees = np.array(
        [0.0] * 12 + [3.6] + [36.0] * 6 + [72.0, 108.0, 108.0, 169.2, 324.0]
    )
    radians = np.deg2rad(degrees)

    outcome = circ_range_test(radians)

    np.testing.assert_approx_equal(outcome.range_stat, 3.581416, significant=5)
    np.testing.assert_approx_equal(outcome.pval, 5.825496e-05, significant=5)


def test_circ_range_test_rejects_degree_input():
    """circ_range_test must raise ValueError for angles given in degrees."""
    degrees = np.array([0.0, 10.0, 20.0])

    with pytest.raises(ValueError):
        circ_range_test(degrees)


def test_binomial_test_uniform():
    """binomial_test on uniform circular data should not reject H0."""
    np.random.seed(42)
    # 100 angles drawn uniformly from [0, 2π).
    uniform_angles = np.random.uniform(0, 2 * np.pi, 100)
    hypothesised_median = np.pi  # expected to be non-significant

    outcome = binomial_test(uniform_angles, hypothesised_median)

    assert 0.05 < outcome.pval < 1.0, (
        f"Unexpected p-value for uniform data: {outcome.pval}"
    )


def test_binomial_test_skewed():
    """binomial_test should reject H0 when the hypothesised median is wrong."""
    np.random.seed(42)
    # 100 angles clustered around π/4.
    clustered_angles = np.random.vonmises(mu=np.pi / 4, kappa=3, size=100)
    wrong_median = np.pi

    outcome = binomial_test(clustered_angles, wrong_median)

    assert outcome.pval < 0.05, f"Expected significant p-value but got {outcome.pval}"


def test_binomial_test_symmetric():
    """binomial_test on data symmetric about π should fail to reject H0."""
    quarter, half = np.pi / 4, np.pi / 2
    symmetric_angles = np.array([-quarter, quarter, half, -half, np.pi])
    hypothesised_median = np.pi  # a valid median for this sample

    outcome = binomial_test(symmetric_angles, hypothesised_median)

    assert outcome.pval > 0.05, f"Unexpected p-value for symmetric data: {outcome.pval}"


def test_binomial_test_extreme_case():
    """binomial_test with every angle equal to the median should give p = 1."""
    hypothesised_median = np.pi
    identical_angles = np.full(20, np.pi)  # 20 copies of π

    outcome = binomial_test(identical_angles, hypothesised_median)

    assert np.isclose(outcome.pval, 1.0), (
        f"Expected p-value of 1 for identical data but got {outcome.pval}"
    )


def test_concentration_identical():
    """concentration_test on two samples with equal kappa should not reject H0."""
    generator = np.random.default_rng(42)
    # Both samples share mu=0, kappa=3 and are drawn from one generator in turn.
    first = vonmises.rvs(mu=0, kappa=3, size=50, random_state=generator)
    second = vonmises.rvs(mu=0, kappa=3, size=50, random_state=generator)

    outcome = concentration_test(first, second)

    assert outcome.pval > 0.05, (
        f"Unexpectedly small p-value: {outcome.pval}, should not reject H0."
    )


def test_concentration_different():
    """concentration_test on samples with kappa 3 vs. 1 should reject H0."""
    generator = np.random.default_rng(123)
    # Higher concentration first, lower concentration second.
    tight = vonmises.rvs(mu=0, kappa=3, size=50, random_state=generator)
    loose = vonmises.rvs(mu=0, kappa=1, size=50, random_state=generator)

    outcome = concentration_test(tight, loose)

    assert outcome.pval < 0.05, f"Expected small p-value, but got {outcome.pval}"


def test_concentration_high_dispersion():
    """concentration_test on two uniformly spread samples should not reject H0."""
    np.random.seed(42)
    # Two samples of 50 angles, each uniform on [0, 2π).
    first = np.random.uniform(0, 2 * np.pi, 50)
    second = np.random.uniform(0, 2 * np.pi, 50)

    outcome = concentration_test(first, second)

    assert outcome.pval > 0.05, (
        f"Unexpectedly small p-value: {outcome.pval}, should not reject H0."
    )


def test_concentration_extreme_case():
    """concentration_test on two kappa=100 samples should not reject H0."""
    generator = np.random.default_rng(42)
    # Both samples are extremely concentrated around mu=0.
    first = vonmises.rvs(mu=0, kappa=100, size=50, random_state=generator)
    second = vonmises.rvs(mu=0, kappa=100, size=50, random_state=generator)

    outcome = concentration_test(first, second)

    assert outcome.pval > 0.05, (
        f"Unexpectedly small p-value: {outcome.pval}, should not reject H0."
    )


def test_rao_homogeneity_identical():
    """rao_homogeneity_test on three samples with equal mu and kappa should not reject H0."""
    samples = []
    # One independently seeded generator per sample.
    for seed in (101, 102, 103):
        generator = np.random.default_rng(seed)
        samples.append(vonmises.rvs(mu=0, kappa=2, size=50, random_state=generator))

    outcome = rao_homogeneity_test(samples)

    assert outcome.pval_polar > 0.05, (
        f"Unexpectedly small p-value: {outcome.pval_polar}"
    )
    assert outcome.pval_disp > 0.05, f"Unexpectedly small p-value: {outcome.pval_disp}"


def test_rao_homogeneity_different_means():
    """rao_homogeneity_test should reject equality of means for mu = 0, π/4, π/2."""
    seed_to_mu = {201: 0.0, 202: np.pi / 4, 203: np.pi / 2}
    samples = []
    # One independently seeded generator per sample; kappa is shared.
    for seed, mu in seed_to_mu.items():
        generator = np.random.default_rng(seed)
        samples.append(vonmises.rvs(kappa=2, mu=mu, size=50, random_state=generator))

    outcome = rao_homogeneity_test(samples)

    assert outcome.pval_polar < 0.05, (
        f"Expected rejection but got p={outcome.pval_polar}"
    )


def test_rao_homogeneity_different_dispersion():
    """rao_homogeneity_test should reject equality of dispersion for kappa = 5, 2, 1."""
    seed_to_kappa = {301: 5, 302: 2, 303: 1}
    samples = []
    # One independently seeded generator per sample; mu is shared.
    for seed, kappa in seed_to_kappa.items():
        generator = np.random.default_rng(seed)
        samples.append(
            vonmises.rvs(mu=0, kappa=kappa, size=50, random_state=generator)
        )

    outcome = rao_homogeneity_test(samples)

    assert outcome.pval_disp < 0.05, f"Expected rejection but got p={outcome.pval_disp}"


def test_rao_homogeneity_small_samples():
    """rao_homogeneity_test must return float p-values for samples of size 5."""
    samples = []
    # One independently seeded generator per five-point sample.
    for seed in (401, 402, 403):
        generator = np.random.default_rng(seed)
        samples.append(vonmises.rvs(mu=0, kappa=3, size=5, random_state=generator))

    outcome = rao_homogeneity_test(samples)

    for pval in (outcome.pval_polar, outcome.pval_disp):
        assert isinstance(pval, float)


def test_rao_homogeneity_invalid_input():
    """rao_homogeneity_test must raise ValueError when a sample is not an array."""
    bad_samples = [np.array([0, np.pi / 2]), "invalid_array"]

    with pytest.raises(ValueError):
        rao_homogeneity_test(bad_samples)


def test_change_point_basic():
    """Compare change_point_test on a 15-angle sample with reference values from R."""
    angles = np.array(
        [3.03, 0.28, 3.90, 5.56, 5.77, 5.06, 5.96, 0.16]
        + [0.51, 1.21, 6.03, 1.05, 0.45, 1.47, 6.09]
    )

    outcome = change_point_test(angles)

    # Expected values based on R output.
    expected_floats = {
        "rho": 0.52307,
        "rmax": 2.237654,
        "rave": 1.066862,
        "tmax": 0.602549,
        "tave": 0.460675,
    }
    expected_indices = {"k_r": 6, "k_t": 6}

    # Floating-point statistics are compared with an absolute tolerance.
    for field, reference in expected_floats.items():
        assert np.isclose(getattr(outcome, field), reference, atol=1e-5)
    # Change-point positions must match exactly.
    for field, reference in expected_indices.items():
        assert getattr(outcome, field) == reference


def test_harrison_kanji_test():
    """Smoke-test the Harrison-Kanji two-way ANOVA on seeded random data."""
    np.random.seed(42)
    n_obs = 50
    # Draw order is fixed: angles, then factor P labels, then factor Q labels.
    angles = np.random.vonmises(0, 2, n_obs)
    factor_p = np.random.choice([1, 2, 3], n_obs)
    factor_q = np.random.choice([1, 2], n_obs)

    outcome = harrison_kanji_test(angles, factor_p, factor_q)

    # Three p-values are expected.
    assert len(outcome.p_values) == 3
    # The ANOVA table must list at least three sources.
    assert outcome.anova_table.shape[0] >= 3
    # Every reported p-value must be a valid probability.
    reported = [p for p in outcome.p_values if p is not None]
    assert all(0 <= p <= 1 for p in reported)


def test_harrison_kanji_vs_pycircstat():
    """Compare PyCircStat2 `harrison_kanji_test` with original PyCircStat `hktest`.

    A local copy of the original ``hktest`` serves as the reference. Both
    implementations receive the same seeded random data, and their p-values
    and ANOVA tables must agree to within ``atol=1e-6``.
    """

    def hktest(alpha, idp, idq, inter=True, fn=None):
        """Reference implementation, copied and fixed from pycircstat.hktest.

        Parameters
        ----------
        alpha : angles (the dependent variable).
        idp, idq : group labels for the first and second factor.
        inter : whether to include the interaction term.
        fn : names of the two factors; defaults to ``["A", "B"]``.

        Returns
        -------
        (pval, table) : a 3-tuple of p-values (factor 1, factor 2,
        interaction) and a pandas DataFrame indexed by "Source".
        """
        import pandas as pd
        from scipy import special, stats

        from pycircstat2.descriptive import circ_kappa, circ_mean, circ_r

        if fn is None:
            fn = ["A", "B"]
        p = len(np.unique(idp))
        q = len(np.unique(idq))
        df = pd.DataFrame({fn[0]: idp, fn[1]: idq, "dependent": alpha})
        n = len(df)
        tr = n * circ_r(np.asarray(df["dependent"].values))
        kk = circ_kappa(tr / n)

        # both factors
        gr = df.groupby(fn)
        cn = gr.count()
        cr = gr.agg(circ_r) * cn
        cn = cn.unstack(fn[1])
        cr = cr.unstack(fn[1])

        # factor A
        gr = df.groupby(fn[0])
        pn = gr.count()["dependent"]
        pr = gr.agg(circ_r)["dependent"] * pn
        pm = gr.agg(circ_mean)["dependent"]
        # factor B
        gr = df.groupby(fn[1])
        qn = gr.count()["dependent"]
        qr = gr.agg(circ_r)["dependent"] * qn
        qm = gr.agg(circ_mean)["dependent"]

        if kk > 2:  # large kappa
            # effect of factor 1
            eff_1 = sum(pr**2 / cn.sum(axis=1)) - tr**2 / n
            df_1 = p - 1
            ms_1 = eff_1 / df_1

            # effect of factor 2
            eff_2 = sum(qr**2.0 / cn.sum(axis=0)) - tr**2 / n
            df_2 = q - 1
            ms_2 = eff_2 / df_2

            # total effect
            eff_t = n - tr**2 / n
            df_t = n - 1
            m = cn.values[:].mean()

            if inter:
                # correction factor for improved F statistic
                beta = 1 / (1 - 1 / (5 * kk) - 1 / (10 * (kk**2)))
                # residual effects
                eff_r = n - (cr**2.0 / cn).values[:].sum()
                df_r = p * q * (m - 1)
                ms_r = eff_r / df_r

                # interaction effects
                eff_i = (
                    (cr**2.0 / cn).values[:].sum()
                    - sum(qr**2.0 / qn)
                    - sum(pr**2.0 / pn)
                    + tr**2 / n
                )
                df_i = (p - 1) * (q - 1)
                ms_i = eff_i / df_i
                # interaction test statistic
                FI = ms_i / ms_r
                pI = 1 - stats.f.cdf(FI, df_i, df_r)
            else:
                # residual effect
                eff_r = n - sum(qr**2.0 / qn) - sum(pr**2.0 / pn) + tr**2 / n
                df_r = (p - 1) * (q - 1)
                ms_r = eff_r / df_r

                # interaction effects
                eff_i = None
                df_i = None
                ms_i = None

                # interaction test statistic
                FI = None
                pI = np.nan
                beta = 1

            F1 = beta * ms_1 / ms_r
            p1 = 1 - stats.f.cdf(F1, df_1, df_r)

            F2 = beta * ms_2 / ms_r
            p2 = 1 - stats.f.cdf(F2, df_2, df_r)

        else:  # small kappa
            # correction factor
            # special.iv is Modified Bessel function of the first kind of real order
            rr = special.iv(1, kk) / special.iv(0, kk)
            f = 2 / (1 - rr**2)

            chi1 = f * (sum(pr**2.0 / pn) - tr**2 / n)
            df_1 = 2 * (p - 1)
            p1 = 1 - stats.chi2.cdf(chi1, df=df_1)

            chi2 = f * (sum(qr**2.0 / qn) - tr**2 / n)
            df_2 = 2 * (q - 1)
            p2 = 1 - stats.chi2.cdf(chi2, df=df_2)

            chiI = f * (
                (cr**2.0 / cn).values[:].sum()
                - sum(pr**2.0 / pn)
                - sum(qr**2.0 / qn)
                + tr**2 / n
            )
            df_i = (p - 1) * (q - 1)
            pI = stats.chi2.sf(chiI, df=df_i)

        pval = (p1.squeeze(), p2.squeeze(), pI.squeeze())

        # The table layout depends on which branch ran: an F-based ANOVA
        # table for large kappa, a chi-square table for small kappa.
        if kk > 2:
            table = pd.DataFrame(
                {
                    "Source": fn + ["Interaction", "Residual", "Total"],
                    "DoF": [df_1, df_2, df_i, df_r, df_t],
                    "SS": [eff_1, eff_2, eff_i, eff_r, eff_t],
                    "MS": [ms_1, ms_2, ms_i, ms_r, np.nan],
                    "F": [F1.squeeze(), F2.squeeze(), FI, np.nan, np.nan],
                    "p": list(pval) + [np.nan, np.nan],
                }
            )
            table = table.set_index("Source")
        else:
            table = pd.DataFrame(
                {
                    "Source": fn + ["Interaction"],
                    "DoF": [df_1, df_2, df_i],
                    "chi2": [chi1.squeeze(), chi2.squeeze(), chiI.squeeze()],
                    "p": pval,
                }
            )
            table = table.set_index("Source")

        return pval, table

    # Seed the global RNG so the inputs -- and therefore which kappa branch
    # is exercised -- do not depend on test ordering or selection.
    np.random.seed(42)
    alpha = np.random.vonmises(0, 2, 50)
    idp = np.random.choice([1, 2, 3], 50)
    idq = np.random.choice([1, 2], 50)

    # Run original PyCircStat test
    pval_orig, table_orig = hktest(alpha, idp, idq)

    # Run PyCircStat2 version
    result_new = harrison_kanji_test(alpha, idp, idq)
    pval_new = result_new.p_values
    table_new = result_new.anova_table

    # Compare p-values
    assert np.allclose(pval_orig, pval_new, atol=1e-6), (
        f"P-values mismatch:\nOriginal: {pval_orig}\nNew: {pval_new}"
    )

    # Compare ANOVA table values (ignoring index differences)
    table_orig_values = table_orig.to_numpy()
    table_new_values = table_new.to_numpy()

    assert np.allclose(
        table_orig_values, table_new_values, atol=1e-6, equal_nan=True
    ), f"ANOVA tables differ:\nOriginal:\n{table_orig}\nNew:\n{table_new}"


def test_circ_anova():
    """Exercise circ_anova (F-test and LRT) on seeded von Mises groups.

    Covers three groups with different mean directions, three copies of one
    group, very small groups, an invalid ``method`` and a single-group input.
    """
    # Seeded for reproducibility; the draw order below must not change.
    np.random.seed(42)
    mean_directions = (0, np.pi / 4, np.pi / 2)

    # Three groups of 50 angles, one per mean direction.
    samples = [
        np.random.vonmises(mu=mu, kappa=5, size=50) for mu in mean_directions
    ]

    # F-test
    f_outcome = circ_anova(samples, method="F-test")
    assert f_outcome.method == "F-test"
    assert 0 <= f_outcome.pval <= 1, "F-test p-value out of range"
    assert f_outcome.df == (2, 147, 149), (
        f"F-test degrees of freedom mismatch: {f_outcome.df}"
    )
    assert f_outcome.SS is not None and f_outcome.MS is not None

    # Likelihood ratio test (LRT)
    lrt_outcome = circ_anova(samples, method="LRT")
    assert lrt_outcome.method == "LRT"
    assert 0 <= lrt_outcome.pval <= 1, "LRT p-value out of range"
    assert lrt_outcome.df == 2, f"LRT degrees of freedom mismatch: {lrt_outcome.df}"

    # Edge case: the same group passed three times.
    shared_group = np.random.vonmises(mu=0, kappa=5, size=50)
    identical_outcome = circ_anova([shared_group] * 3, method="F-test")
    assert identical_outcome.pval > 0.05, (
        "F-test should not reject H0 for identical groups"
    )

    # Edge case: groups of only five angles each.
    small_samples = [
        np.random.vonmises(mu=mu, kappa=5, size=5) for mu in mean_directions
    ]
    small_outcome = circ_anova(small_samples, method="F-test")
    assert 0 <= small_outcome.pval <= 1, "Small-sample p-value out of range"

    # An unknown method name must be rejected.
    with pytest.raises(ValueError, match="Invalid method. Choose 'F-test' or 'LRT'."):
        circ_anova(samples, method="INVALID")

    # A single group must be rejected.
    with pytest.raises(ValueError, match="At least two groups are required for ANOVA."):
        circ_anova(samples[:1])


def test_equal_median_identical_samples():
    """common_median_test must not reject H₀ for three identical groups."""
    shared_values = [0.1, 0.2, 0.3, 1.5, 1.6]
    # Three separate arrays holding the same angles.
    groups = [np.array(shared_values) for _ in range(3)]

    outcome = common_median_test(groups)

    assert outcome.reject is False
    assert not np.isnan(outcome.common_median)


def test_equal_median_different_samples():
    """common_median_test must reject H₀ and report a NaN common median for shifted groups."""
    groups = [
        np.array([0.1, 0.2, 0.3, 1.5, 1.6]),
        np.array([2.2, 2.3, 2.4, 3.1, 3.2]),
        np.array([3.5, 3.6, 3.7, 4.2, 4.3]),
    ]

    outcome = common_median_test(groups)

    assert outcome.reject is True
    assert np.isnan(outcome.common_median)


def test_equal_median_large_sample():
    """common_median_test must not reject H₀ for three large samples sharing mu and kappa."""
    np.random.seed(42)
    # Three consecutive draws of 500 angles from the same von Mises law.
    groups = [np.random.vonmises(mu=0, kappa=2, size=500) for _ in range(3)]

    outcome = common_median_test(groups)

    assert outcome.reject is False
    assert not np.isnan(outcome.common_median)


def test_equal_median_small_sample():
    """common_median_test must cope with two three-point samples without rejecting H₀."""
    groups = [
        np.array([0.1, 0.2, 0.3]),
        np.array([0.15, 0.25, 0.35]),
    ]

    outcome = common_median_test(groups)

    assert outcome.reject is False
    assert not np.isnan(outcome.common_median)
