import numpy as np
import pandas as pd

import asteroid_spinprops.ssolib.utils as utils

from fink_utils.sso.spins import (
    estimate_sso_params,
    func_hg1g2_with_spin,
)
from asteroid_spinprops.ssolib.periodest import get_period_estimate


# Filter identifiers belonging to each supported survey.
_SURVEY_FILTER_IDS = {"ZTF": (1, 2), "ATLAS": (3, 4)}

# Highest filter identifier probed when looking up per-filter parameters.
_MAX_FILTER_ID = 6

# Blind-scan period grid: +/- 20 seconds (expressed in days) sampled 20 times.
_PERIOD_SCAN_HALF_WIDTH = 20 / (24 * 60 * 60)
_PERIOD_SCAN_STEPS = 20


def _survey_mask(filter_ids, survey_filter):
    """Return the boolean mask selecting the observations of `survey_filter`."""
    filter_ids = np.asarray(filter_ids)
    if survey_filter is None:
        return filter_ids >= 0
    if survey_filter not in _SURVEY_FILTER_IDS:
        raise ValueError(
            "survey_filter must be None, 'ZTF' or 'ATLAS', not {!r}".format(
                survey_filter
            )
        )
    return np.isin(filter_ids, _SURVEY_FILTER_IDS[survey_filter])


def _first_filter_param(params, prefix):
    """Return `params["<prefix>_<i>"]` for the lowest filter id `i` present."""
    for i in range(1, _MAX_FILTER_ID + 1):
        key = f"{prefix}_{i}"
        if key in params:
            return params[key]
    raise KeyError(
        "No '{}_<filter>' entry found in the SHG1G2 parameters".format(prefix)
    )


def _run_fit(data, filter_mask, model, p0=None, **fit_options):
    """Call `estimate_sso_params` on the masked observations stored in `data`."""

    def column(name):
        # Each cell of the single-row frame holds a whole array of observations.
        return data[name].values[0][filter_mask]

    kwargs = {
        "ra": np.radians(column("ra")),
        "dec": np.radians(column("dec")),
        "model": model,
    }
    if model == "SSHG1G2":
        kwargs["jd"] = column("cjd")
    if p0 is not None:
        # Only forward p0 when given, so the fitter keeps its own default otherwise.
        kwargs["p0"] = p0
    kwargs.update(fit_options)

    return estimate_sso_params(
        column("cmred"),
        column("csigmapsf"),
        np.radians(column("Phase")),
        column("cfid"),
        **kwargs,
    )


def _fit_sshg1g2(data, filter_mask, p0, terminator):
    """Run one SSHG1G2 fit from `p0`, adding sub-solar coordinates if requested."""
    options = {"terminator": terminator}
    if terminator:
        options["ra_s"] = np.radians(data["ra_s"].values[0][filter_mask])
        options["dec_s"] = np.radians(data["dec_s"].values[0][filter_mask])
    return _run_fit(data, filter_mask, "SSHG1G2", p0=p0, **options)


def _blind_scan_sshg1g2(data, filter_mask, shg1g2_params, period_sy, terminator):
    """Fit SSHG1G2 from a grid of poles and periods and return the lowest-rms fit."""
    period_scan = np.linspace(
        period_sy - _PERIOD_SCAN_HALF_WIDTH,
        period_sy + _PERIOD_SCAN_HALF_WIDTH,
        _PERIOD_SCAN_STEPS,
    )
    ra_init, dec_init = utils.generate_initial_points(
        shg1g2_params["alpha0"], shg1g2_params["delta0"], dec_shift=45
    )
    h_start = _first_filter_param(shg1g2_params, "H")

    rms = []
    models = []
    for ra, dec in zip(ra_init, dec_init):
        for trial_period in period_scan:
            p_in = [
                h_start,
                0.15,
                0.15,  # G1,2
                np.radians(ra),
                np.radians(dec),
                trial_period,  # in days
                shg1g2_params["a_b"],
                shg1g2_params["a_c"],
                0.1,  # phi 0
            ]
            candidate = _fit_sshg1g2(data, filter_mask, p_in, terminator)
            try:
                candidate_rms = candidate["rms"]
            except (KeyError, TypeError):
                # Result carries no rms entry: skip this starting point.
                continue
            rms.append(candidate_rms)
            models.append(candidate)

    if not models:
        raise ValueError("Blind scan produced no SSHG1G2 solution with an 'rms' entry")
    return models[int(np.argmin(rms))]


def get_fit_params(
    data,
    flavor,
    shg1g2_constrained=True,
    blind_scan=True,
    p0=None,
    survey_filter=None,
    alt_spin=False,
    period_in=None,
    terminator=False,
):
    """
    Fit a small solar system object's photometric data using SHG1G2 or SOCCA models.

    This function can perform either a standard SHG1G2 fit or a spin- and
    shape-constrained SOCCA (SSHG1G2) fit, optionally including a blind scan
    over initial pole positions and periods. It supports filtering data by survey.

    Parameters
    ----------
    data : pandas.DataFrame
        Input dataset whose first row holds one array per column:
        - 'cmred': reduced magnitudes
        - 'csigmapsf': uncertainties
        - 'Phase': solar phase angles (deg)
        - 'cfid': filter IDs
        - 'ra', 'dec': coordinates (deg)
        - 'cjd': observation times
        Optional (for terminator fits):
        - 'ra_s', 'dec_s': sub-solar point coordinates (deg)
    flavor : str
        Model type to fit. Must be 'SHG1G2' or 'SSHG1G2'.
    shg1g2_constrained : bool, optional
        Whether to seed the SSHG1G2 fit from a prior SHG1G2 solution. Default True.
    blind_scan : bool, optional
        If True, perform a small grid search over initial pole positions and
        periods. Only used when `shg1g2_constrained` is true. Default True.
    p0 : list, optional
        Initial guess parameters for the fit. Required for SSHG1G2 fits with
        `shg1g2_constrained=False`; optional for SHG1G2 fits.
    survey_filter : str or None, optional
        If 'ZTF' (filter IDs 1 and 2) or 'ATLAS' (filter IDs 3 and 4), only
        data from that survey are fitted. Default None keeps every
        observation with a non-negative filter ID.
    alt_spin : bool, optional
        For constrained SSHG1G2 fits without blind scan, start from the
        flipped spin solution and the alternative sidereal period. Default False.
    period_in : float, optional
        Input synodic period (days) to override automatic estimation. Default None.
    terminator : bool, optional
        If True, include self-shading in the fit. Forwarded to the blind-scan
        and unconstrained SSHG1G2 fits. Default False.

    Returns
    -------
    dict
        Fit result returned by `estimate_sso_params` for the requested flavor.
        For a blind scan, the candidate with the smallest 'rms' entry.

    Notes
    -----
    - For SOCCA fits with `shg1g2_constrained=True`, the function first performs
      a SHG1G2 fit whose H, pole and shape (a/b, a/c) values seed the SOCCA fit.
    - The blind scan varies the initial pole position and samples 20 periods
      within +/- 20 seconds of the synodic period; G1 and G2 start at 0.15.

    Raises
    ------
    ValueError
        If `flavor` is not 'SHG1G2' or 'SSHG1G2', if `survey_filter` is not
        None, 'ZTF' or 'ATLAS', if an unconstrained SSHG1G2 fit is requested
        without `p0`, or if no blind-scan candidate exposes an 'rms' entry.
    KeyError
        If the seeding SHG1G2 solution lacks the per-filter parameters needed
        to initialise the SSHG1G2 fit.
    """
    if flavor not in ("SHG1G2", "SSHG1G2"):
        raise ValueError(
            "Model must either be SHG1G2 or SSHG1G2, not {}".format(flavor)
        )

    filter_mask = _survey_mask(data["cfid"].values[0], survey_filter)

    if flavor == "SHG1G2":
        return _run_fit(data, filter_mask, "SHG1G2", p0=p0)

    if not shg1g2_constrained:
        if p0 is None:
            raise ValueError(
                "p0 is required when shg1g2_constrained=False: "
                "Initialize SSHG1G2 first!"
            )
        return _fit_sshg1g2(data, filter_mask, p0, terminator)

    # Constrained fit: a SHG1G2 solution provides the starting point.
    shg1g2_params = _run_fit(data, filter_mask, "SHG1G2")

    if period_in is None:
        # NOTE(review): residuals are computed on the unfiltered `data`; with a
        # survey_filter the SHG1G2 solution may lack parameters for the other
        # survey's filters - confirm the intended behaviour with callers.
        residuals_dataframe = make_residuals_df(
            data, model_parameters=shg1g2_params
        )
        sg, _, _ = get_period_estimate(residuals_dataframe=residuals_dataframe)
        # Presumably sg[2][0] is the dominant frequency and the factor 2
        # accounts for a double-peaked lightcurve - TODO confirm.
        period_sy = 2 / sg[2][0]
    else:
        period_sy = period_in

    if blind_scan:
        return _blind_scan_sshg1g2(
            data, filter_mask, shg1g2_params, period_sy, terminator
        )

    period_si_t, alt_period_si_t, _ = utils.estimate_sidereal_period(
        data=data, model_parameters=shg1g2_params, synodic_period=period_sy
    )

    if alt_spin:
        period = np.median(alt_period_si_t)
        ra0, de0 = utils.flip_spin(
            shg1g2_params["alpha0"],
            shg1g2_params["delta0"],
        )
    else:
        period = np.median(period_si_t)
        ra0, de0 = shg1g2_params["alpha0"], shg1g2_params["delta0"]
    ra0, de0 = np.radians(ra0), np.radians(de0)

    p0 = [
        _first_filter_param(shg1g2_params, "H"),
        _first_filter_param(shg1g2_params, "G1"),
        _first_filter_param(shg1g2_params, "G2"),
        ra0,
        de0,
        period,
        shg1g2_params["a_b"],
        shg1g2_params["a_c"],
        0.1,  # phi 0
    ]

    # NOTE(review): `terminator` is not forwarded on this path, matching the
    # previous behaviour - confirm whether self-shading should apply here too.
    return _run_fit(data, filter_mask, "SSHG1G2", p0=p0)


def get_model_points(data, params):
    """
    Evaluate the SHG1G2 model filter by filter.

    The observations stored in the first row of `data` are grouped by their
    filter ID; for every group the model is evaluated with that filter's
    H, G1, G2 values and the shared oblateness and pole coordinates.

    Parameters
    ----------
    data : pandas.DataFrame
        Dataset whose first row holds one array per column, with at least:
        - 'Phase' : solar phase angles (deg)
        - 'ra' : right ascension (deg)
        - 'dec' : declination (deg)
        - 'cfid' : filter IDs (int)
    params : dict
        Model parameters containing keys:
        - 'H_i', 'G1_i', 'G2_i' for each filter i present in the data
        - 'R' : oblateness
        - 'alpha0', 'delta0' : pole coordinates in degrees

    Returns
    -------
    tuple of lists
        - list of numpy.ndarray : modeled magnitudes, one array per filter
          (filters in sorted order).
        - list of numpy.ndarray : positions of the corresponding observations
          in the original arrays.
    """
    filter_ids = data["cfid"].values[0]
    positions = np.arange(len(filter_ids))

    modeled_per_filter = []
    positions_per_filter = []

    for fid in np.unique(filter_ids):
        selected = filter_ids == fid

        # Per-filter photometric parameters followed by the shared ones.
        photometric = [params[f"{name}_{fid}"] for name in ("H", "G1", "G2")]
        oblateness = params["R"]
        pole_ra = np.radians(params["alpha0"])
        pole_dec = np.radians(params["delta0"])

        # Geometry of the selected observations, converted to radians.
        geometry = [
            np.radians(data[column].values[0][selected])
            for column in ("Phase", "ra", "dec")
        ]

        modeled = func_hg1g2_with_spin(
            geometry, *photometric, oblateness, pole_ra, pole_dec
        )

        positions_per_filter.append(positions[selected])
        modeled_per_filter.append(modeled)

    return modeled_per_filter, positions_per_filter


def get_residuals(data, params):
    """
    Return observed minus modeled magnitudes for every observation.

    The per-filter model values from `get_model_points` are flattened,
    put back in the original observation order, and subtracted from the
    observed reduced magnitudes.

    Parameters
    ----------
    data : pandas.DataFrame
        Dataset whose first row holds one array per column, with at least:
        - 'cmred' : observed reduced magnitudes
        - 'Phase' : solar phase angles (deg)
        - 'ra' : right ascension (deg)
        - 'dec' : declination (deg)
        - 'cfid' : filter IDs (int)
    params : dict
        Model parameters in the form expected by `get_model_points`
        (per-filter H, G1, G2, plus oblateness and pole coordinates).

    Returns
    -------
    numpy.ndarray
        Residuals (observed - modeled magnitudes), ordered like the
        original observations.
    """
    modeled_groups, position_groups = get_model_points(data, params)

    # Sorting on the original positions undoes the per-filter grouping.
    table = pd.DataFrame(
        {"mpoints": utils.flatten_list(modeled_groups)},
        index=utils.flatten_list(position_groups),
    ).sort_index()

    table["observation"] = data["cmred"].values[0]
    difference = table["observation"] - table["mpoints"]
    return difference.values


def make_residuals_df(data, model_parameters):
    """
    Build a per-observation table of model values, observations and residuals.

    Parameters
    ----------
    data : pandas.DataFrame
        Dataset whose first row holds one array per column, with at least:
        - 'cmred' : observed reduced magnitudes
        - 'csigmapsf' : photometric uncertainties
        - 'Phase' : solar phase angles (deg)
        - 'ra' : right ascension (deg)
        - 'dec' : declination (deg)
        - 'cfid' : filter IDs (int)
        - 'cjd' : observation times
    model_parameters : dict
        Model parameters in the form expected by `get_model_points`
        (per-filter H, G1, G2, plus oblateness and pole coordinates).

    Returns
    -------
    pandas.DataFrame
        DataFrame indexed by observation position, with columns:
        - 'mpoints' : modeled magnitudes
        - 'mred' : observed reduced magnitudes
        - 'sigma' : observational uncertainties
        - 'filters' : filter IDs
        - 'jd' : observation times
        - 'residuals' : mred - mpoints
    """
    modeled_groups, position_groups = get_model_points(
        data=data, params=model_parameters
    )
    modeled_flat = utils.flatten_list(modeled_groups)
    positions_flat = utils.flatten_list(position_groups)

    # Sorting on the original positions undoes the per-filter grouping.
    table = pd.DataFrame({"mpoints": modeled_flat}, index=positions_flat).sort_index()

    # Output column name -> input column it is copied from.
    copied_columns = (
        ("mred", "cmred"),
        ("sigma", "csigmapsf"),
        ("filters", "cfid"),
        ("jd", "cjd"),
    )
    for target, source in copied_columns:
        table[target] = data[source].values[0]

    table["residuals"] = table["mred"] - table["mpoints"]
    return table
