# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../notebooks/properties/00_mva.ipynb.

# %% auto 0
# Public API of this module (autogenerated from the source notebook, see header).
# NOTE: "maxiumum" is misspelled in two public names; they are kept as-is
# because callers import them by these exact names.
__all__ = ['minvar', 'calc_mva_features', 'calc_maxiumum_variance_direction', 'fit_maxiumum_variance_direction',
           'calc_candidate_mva_features']

# %% ../../../notebooks/properties/00_mva.ipynb 1
import xarray as xr
import numpy as np
import pandas as pd

from lmfit.models import StepModel, ConstantModel
from lmfit import Parameters

from typing import Literal

# %% ../../../notebooks/properties/00_mva.ipynb 3
def minvar(data: np.ndarray):
    """
    Minimum variance analysis of a 3-component vector series.

    Adapted from `pyspedas.cotrans.minvar`. Builds the 3x3 covariance
    (variance) matrix of the samples, diagonalises it, and rotates the data
    into the resulting principal-axis frame.

    Parameters
    -----------
    data:
        array-like of shape (npoints, 3), one vector sample per row.
        NaN entries are replaced by 0 when forming the averages (they still
        count toward the number of samples); rows containing NaN come out
        as NaN in `vrot`.

    Returns
    -------
    vrot:
        (npoints, 3) array, ``data @ v``: the data expressed in the principal
        frame. vrot[:, 0] is the maximum-, vrot[:, 1] the intermediate- and
        vrot[:, 2] the minimum-variance component.
    v:
        (3, 3) array whose columns are the principal axes:
        v[:, 0] maximum, v[:, 1] intermediate, v[:, 2] minimum variance
        direction. The axes form a right-handed system (det(v) > 0) and the
        minimum-variance axis has v[2, 2] >= 0.
    w:
        (3,) array of the absolute eigenvalues (variances) in descending order.
        If all eigenvalues are exactly zero, the IDL-compatible order
        [0, 2, 1] of the ascending eigenvalues is used instead.
    """
    data = np.asarray(data, dtype=float)

    # NaNs are zero-filled, so nanmean below behaves as a plain mean.
    vecavg = np.nanmean(np.nan_to_num(data, nan=0.0), axis=0)

    # Covariance matrix <B_i B_j> - <B_i><B_j>.
    mvamat = np.zeros((3, 3))
    for i in range(3):
        for j in range(3):
            mvamat[i, j] = (
                np.nanmean(np.nan_to_num(data[:, i] * data[:, j], nan=0.0))
                - vecavg[i] * vecavg[j]
            )

    # eigh returns eigenvalues in ascending order with eigenvectors as columns.
    w, v = np.linalg.eigh(mvamat, UPLO="U")

    # Sort by descending absolute eigenvalue.
    w = np.abs(w)
    idx = np.flip(np.argsort(w))

    # IDL compatibility: a completely degenerate matrix uses a fixed ordering.
    if np.sum(w) == 0.0:
        idx = [0, 2, 1]

    w = w[idx]
    v = v[:, idx]

    # Make the frame right-handed. det(v) is the full triple product
    # v0 . (v1 x v2); the former single-term check v[0,0] * (v1 x v2)_x
    # evaluated to 0 whenever v[0, 0] == 0 and then never corrected a
    # left-handed frame. Flipping only the intermediate axis fixes the
    # handedness while leaving the max/min axes untouched.
    if np.linalg.det(v) < 0:
        v[:, 1] = -v[:, 1]

    # Point the minimum-variance axis along +Z (for the FAC system); flipping
    # two axes at once preserves the handedness established above.
    if v[2, 2] < 0:
        v[:, 2] = -v[:, 2]
        v[:, 1] = -v[:, 1]

    vrot = data @ v

    return vrot, v, w

# %% ../../../notebooks/properties/00_mva.ipynb 5
def calc_mva_features(data: np.ndarray):
    """
    Compute minimum variance analysis (MVA) features of a vector series.

    The samples are rotated into the MVA frame with `minvar` (components:
    maximum, intermediate, minimum variance), then summary quantities of the
    rotated vectors and their magnitudes are collected.

    Parameters
    ----------
    data : np.ndarray
        (npoints, 3) array of vector samples, passed unchanged to `minvar`.

    Returns
    -------
    result : pd.Series
        Features keyed by:
        - "Vl", "Vn": maximum- and minimum-variance eigenvectors (v[:, 0], v[:, 2]).
        - "b_mag": mean magnitude of the rotated vectors.
        - "b_n": mean of the minimum-variance component.
        - "B.vec.before" / "B.vec.after": first / last rotated vector.
        - "B.before" / "B.after": magnitude of the first / last vector.
        - "db_mag": last minus first magnitude.
        - "bn_over_b": b_n / b_mag.
        - "db_over_b": |db_mag| / b_mag.
        - "db_over_b_max": (max - min magnitude) / b_mag.
        - "dB_lmn": first minus last rotated vector (note: the opposite sign
          convention to "db_mag").
    vrot : np.ndarray
        The (npoints, 3) data rotated into the MVA frame, as returned by `minvar`.
    """
    vrot, v, _ = minvar(data)

    # Maximum (L) and minimum (N) variance eigenvectors.
    Vl = v[:, 0]
    Vn = v[:, 2]

    vec_mag = np.linalg.norm(vrot, axis=1)

    # Per-component change, first minus last sample. This is the opposite sign
    # to `dvec_mag` below; kept as-is because "dB_lmn" is a published feature.
    dvec = vrot[0] - vrot[-1]

    # Mean magnitude and mean normal (minimum-variance) component.
    vec_mag_mean = np.mean(vec_mag)
    vec_n_mean = np.mean(vrot[:, 2])
    VnOverVmag = vec_n_mean / vec_mag_mean

    # Relative magnitude changes: end-to-end (last minus first) and full range.
    dvec_mag = vec_mag[-1] - vec_mag[0]
    dBOverB = np.abs(dvec_mag / vec_mag_mean)
    dBOverB_max = (np.max(vec_mag) - np.min(vec_mag)) / vec_mag_mean

    result = pd.Series(
        {
            "Vl": Vl,
            "Vn": Vn,
            "b_mag": vec_mag_mean,
            "b_n": vec_n_mean,
            "B.vec.before": vrot[0],
            "B.vec.after": vrot[-1],
            "B.before": vec_mag[0],
            "B.after": vec_mag[-1],
            "db_mag": dvec_mag,
            "bn_over_b": VnOverVmag,
            "db_over_b": dBOverB,
            "db_over_b_max": dBOverB_max,
            "dB_lmn": dvec,
        }
    )
    return result, vrot

# %% ../../../notebooks/properties/00_mva.ipynb 6
def calc_maxiumum_variance_direction(data: xr.DataArray, datetime_unit="s", **kwargs):
    """
    Peak absolute time derivative of a 1-D series.

    Parameters
    ----------
    data : xr.DataArray
        Series with a ``"time"`` coordinate, e.g. the maximum-variance
        component of an event.
    datetime_unit : str
        Unit handed to `DataArray.differentiate`, so the derivative is
        expressed per `datetime_unit` (default: per second).
    **kwargs
        Ignored; accepted so callers can forward the same options to either
        this function or the fitting alternative.

    Returns
    -------
    pd.Series
        A single entry ``"d_star"``: the largest |d data / d time|.
    """
    slope = data.differentiate("time", datetime_unit=datetime_unit)
    steepest = abs(slope).max(dim="time").item()
    return pd.Series({"d_star": steepest})

# %% ../../../notebooks/properties/00_mva.ipynb 9
def fit_maxiumum_variance_direction(
    data: xr.DataArray, datetime_unit="s", return_best_fit: bool = False, **kwargs
):
    """
    Fit a logistic step plus a constant to a 1-D series and report its steepest slope.

    The model is lmfit's ``StepModel(form="logistic") + ConstantModel()``,
    fitted against the elapsed time since the first sample, measured in
    `datetime_unit`. The reported ``d_star`` is ``amplitude / (4 * sigma)``,
    the slope of the logistic step at its center.

    Parameters
    ----------
    data : xr.DataArray
        1-D series with a datetime64 ``"time"`` coordinate (e.g. the
        maximum-variance component of an event).
    datetime_unit : str
        numpy timedelta unit used to turn times into numbers (default seconds).
    return_best_fit : bool
        If True, also include the fitted curve (``"fit.best_fit"``) and the
        input times (``"fit.time"``) in the result.
    **kwargs
        Ignored; accepted so callers can forward a shared set of options.

    Returns
    -------
    pd.Series
        ``fit.vars.amplitude``, ``fit.vars.sigma``, ``t.d_time`` (absolute
        time of the step center), ``d_star``, ``fit.vars.c``,
        ``fit.stat.rsquared`` and ``fit.stat.chisqr``. With fewer than 4
        samples no fit is attempted: numeric entries are NaN and
        ``t.d_time`` is NaT.

    Note:
        - see `datetime_to_numeric` in `xarray.core.duck_array_ops` for more details about converting datetime to numeric
        - Xarray uses the numpy dtypes datetime64[ns] and timedelta64[ns] to represent datetime data.
    """
    time = data["time"].values
    y = data.values

    # The model has 4 free parameters (center, amplitude, sigma, center_y), so
    # fewer samples cannot constrain it. This check runs before y[0]/y[-1] or
    # min(time) are touched, so empty input yields a NaN result instead of
    # raising.
    if len(y) < 4:
        amplitude = sigma = rsquared = chisqr = c = best_fit = np.nan
        d_time = np.datetime64("NaT", "ns")
    else:
        t0 = min(time)
        # Elapsed time since the first sample, in units of `datetime_unit`.
        x = (time - t0) / np.timedelta64(1, datetime_unit)

        x_min, x_max = min(x), max(x)
        x_width = x_max - x_min

        # Logistic step (0 -> amplitude) plus a constant offset c.
        mod = StepModel(form="logistic") + ConstantModel()

        params = Parameters()
        init_amplitude = y[-1] - y[0]

        params.add("center", value=(x_max + x_min) / 2.0)
        params.add(
            "amplitude",
            value=init_amplitude,
            max=abs(init_amplitude) * 2.0,
            min=-abs(init_amplitude) * 2.0,
        )
        params.add("sigma", value=x_width / 7.0, min=0)

        # Parameterise the offset through the step midpoint
        # center_y = c + amplitude / 2, confined to within amplitude/8 of the
        # mean of the end values. An explicit initial value is required:
        # lmfit otherwise starts a value-less bounded parameter at its lower
        # bound, where the bounds transform has zero gradient.
        int_center_y = (y[-1] + y[0]) / 2.0
        c1 = 1 / 8
        params.add(
            "center_y",
            value=int_center_y,
            min=int_center_y - c1 * np.abs(init_amplitude),
            max=int_center_y + c1 * np.abs(init_amplitude),
        )
        params.add("c", expr="center_y - amplitude / 2.0")
        # NOTE(review): if y[-1] == y[0] the amplitude and center_y bounds
        # collapse to min == max; lmfit may reject that — confirm and handle.

        out = mod.fit(y, params, x=x)
        amplitude = out.params["amplitude"].value
        sigma = out.params["sigma"].value
        center = out.params["center"].value
        c = out.params["c"].value
        rsquared = out.rsquared
        chisqr = out.chisqr
        best_fit = out.best_fit

        # Convert the fitted center (in `datetime_unit`) back to an absolute time.
        d_time = t0 + center * np.timedelta64(1, datetime_unit).astype(
            "timedelta64[ns]"
        )

    # Slope of the logistic step at its center.
    max_df = amplitude / (4 * sigma)

    result = pd.Series(
        {
            "fit.vars.amplitude": amplitude,
            "fit.vars.sigma": sigma,
            "t.d_time": d_time,
            "d_star": max_df,
            "fit.vars.c": c,
            "fit.stat.rsquared": rsquared,
            "fit.stat.chisqr": chisqr,
        }
    )
    if return_best_fit:
        result["fit.best_fit"] = best_fit
        result["fit.time"] = time
    return result

# %% ../../../notebooks/properties/00_mva.ipynb 10
# def calc_ts_mva_features(data, method=Literal["fit", "derivative"], **kwargs):
#     mva_features, vrot = calc_mva_features(data.to_numpy())
#     event_data_l = xr.DataArray(vrot[:, 0], dims=["time"], coords={"time": data.time})

#     if method == "fit":
#         result = fit_maxiumum_variance_direction(event_data_l, **kwargs)
#     elif method == "derivative":
#         result = calc_maxiumum_variance_direction(event_data_l, **kwargs)
#     return pd.concat([mva_features, result])


def calc_candidate_mva_features(
    event,
    data: xr.DataArray,
    method: Literal["fit", "derivative"] = "fit",
    **kwargs,
):
    """
    Compute MVA features plus a steepest-change estimate for one candidate event.

    Parameters
    ----------
    event : mapping
        Must provide ``"t.d_start"`` and ``"t.d_end"``; the event window is
        ``data.sel(time=slice(event["t.d_start"], event["t.d_end"]))``.
    data : xr.DataArray
        Vector time series with a ``"time"`` dimension. The selected window is
        converted with ``to_numpy()`` and passed to `calc_mva_features`
        (assumed to come out as (npoints, 3) — TODO confirm dimension order).
    method : {"fit", "derivative"}
        How to estimate the steepest change of the maximum-variance (L)
        component: ``"fit"`` uses `fit_maxiumum_variance_direction`,
        ``"derivative"`` uses `calc_maxiumum_variance_direction`.
        Defaults to ``"fit"``.
    **kwargs
        Forwarded to the selected estimator.

    Returns
    -------
    pd.Series
        The MVA features followed by the estimator's features.

    Raises
    ------
    ValueError
        If `method` is neither ``"fit"`` nor ``"derivative"``.
    """
    event_data = data.sel(time=slice(event["t.d_start"], event["t.d_end"]))

    mva_features, vrot = calc_mva_features(event_data.to_numpy())

    # Maximum-variance (L) component as a time series on the event's time axis.
    event_data_l = xr.DataArray(
        vrot[:, 0], dims=["time"], coords={"time": event_data.time}
    )

    if method == "fit":
        result = fit_maxiumum_variance_direction(event_data_l, **kwargs)
    elif method == "derivative":
        result = calc_maxiumum_variance_direction(event_data_l, **kwargs)
    else:
        raise ValueError(
            f"Unknown method {method!r}; expected 'fit' or 'derivative'."
        )
    return pd.concat([mva_features, result])
