# AUTOGENERATED! DO NOT EDIT! File to edit: ../../notebooks/03_mag_plasma.ipynb.

# %% auto 0
# Public names exported by `from <this module> import *`. Like the rest of this
# file it is autogenerated from the notebook named in the header, so change the
# notebook rather than this list.
# NOTE(review): `calc_mag_features` is defined in this module but is not listed
# here — confirm whether it is meant to be public.
__all__ = ['interpolate', 'interpolate2', 'combine_features', 'calc_plasma_parameter_change', 'calc_combined_features',
           'update_events_with_plasma_data', 'update_events_with_temp_data', 'update_events']

# %% ../../notebooks/03_mag_plasma.ipynb 1
import polars as pl
import polars.selectors as cs
from beforerr.polars import decompose_vector, format_time
from space_analysis.plasma.formulary.polars import (
    df_Alfven_speed,
    df_Alfven_current,
    df_inertial_length,
    df_gradient_current,
)
from space_analysis.meta import PlasmaDataset
from .utils.ops import vector_project_pl
from .core.propeties import df_rotation_angle
from .naming import DENSITY_COL, FIT_AMPL_COL
from .utils.naming import standardize_plasma_data
from loguru import logger

# %% ../../notebooks/03_mag_plasma.ipynb 2
def _interpolate(
    df: pl.DataFrame, on="time", method="index", limit=1, limit_direction="both"
):
    """Fill missing values of `df` along the `on` column via pandas interpolation.

    The frame is round-tripped through pandas so that pandas' ``limit`` /
    ``limit_direction`` options can be used.

    Args:
        df: Frame whose missing values should be filled.
        on: Column used as the (sorted) pandas index during interpolation.
        method: pandas interpolation method; ``"index"`` uses the index values
            as the x-coordinate.
        limit: Maximum number of consecutive missing values to fill. Kept at 1
            to improve the confidence of the interpolation (related:
            https://github.com/pola-rs/polars/issues/9616).
        limit_direction: Direction(s) in which ``limit`` is applied.

    Returns:
        A new polars DataFrame, sorted by ``on``, with ``on`` as a regular column.
    """
    indexed = df.to_pandas().set_index(on).sort_index()
    filled = indexed.interpolate(
        method=method, limit=limit, limit_direction=limit_direction
    )
    return pl.from_pandas(filled.reset_index())


def interpolate(df: pl.DataFrame, on="time"):
    """Sort `df` by `on` and fill nulls in every numeric column by interpolating
    against `on`.

    Unlike `_interpolate`, no limit on the number of consecutive nulls is applied.
    """
    ordered = df.sort(on)
    numeric_filled = cs.numeric().interpolate_by(on)
    return ordered.with_columns(numeric_filled)


def interpolate2(df1: pl.DataFrame, df2, **kwargs):
    """Stack `df1` on top of `df2` and interpolate the combined frame.

    The frames are concatenated with ``how="diagonal_relaxed"``, so columns
    missing from either side become nulls, which `interpolate` then fills.
    Extra keyword arguments (e.g. ``on``) are forwarded to `interpolate`.
    """
    stacked = pl.concat([df1, df2], how="diagonal_relaxed")
    return interpolate(stacked, **kwargs)

# %% ../../notebooks/03_mag_plasma.ipynb 3
from fastcore.all import concat  # noqa: E402


def _interpolate_at(
    events: pl.DataFrame, time_col: str, states_data: pl.DataFrame, on: str
) -> pl.DataFrame:
    """Interpolate `states_data` at the distinct times found in `events[time_col]`.

    Returns exactly one row per distinct (non-null) target time, keyed by column
    ``on``. Joining the result back onto `events` therefore cannot duplicate event
    rows. Duplication could otherwise happen when several events share a time, or
    when a target time coincides with a sample already present in `states_data`.
    """
    targets = events.select(pl.col(time_col).alias(on)).unique()
    return (
        interpolate2(targets, states_data, on=on)
        # keep only the requested times, dropping the raw state samples in between
        .join(targets, on=on, how="semi")
        # a target coinciding with a state sample shows up twice; keep one copy
        .unique(subset=on, keep="first", maintain_order=True)
    )


def combine_features(
    events: pl.DataFrame,
    states_data: pl.DataFrame,
    plasma_meta: PlasmaDataset = PlasmaDataset(),
    method: str = "interpolate",
    left_on="t.d_time",
    right_on="time",
    subset=False,
):
    """Join plasma state measurements onto discontinuity events.

    First, the state sample nearest to each event time (``left_on``) is as-of
    joined onto the events. Then, depending on ``method``, the plasma state at the
    event boundaries (``t.d_start`` / ``t.d_end``) is joined with the suffixes
    ``.before`` / ``.after``:

    - ``"interpolate"``: numeric state columns are interpolated against
      ``right_on`` at the exact boundary times.
    - ``"nearest"``: uses the last sample at or before ``t.d_start`` and the first
      sample at or after ``t.d_end``.
    - anything else: only the nearest-sample join is returned.

    Args:
        events: Event table; must contain ``left_on``, ``t.d_start`` and ``t.d_end``.
        states_data: Plasma time series keyed by ``right_on``.
        plasma_meta: Column naming of ``states_data``; only used when ``subset``
            is True.
        method: ``"interpolate"`` (default), ``"nearest"``, or any other value to
            skip the boundary joins.
        left_on: Event time column used for the nearest-sample join.
        right_on: Time column of ``states_data``.
        subset: If True, keep only the density, velocity and temperature columns
            named by ``plasma_meta`` (plus ``right_on``) before joining.

    Returns:
        The joined events DataFrame.
    """
    if subset:
        m = plasma_meta
        subset_cols = concat([m.density_col, m.velocity_cols, m.temperature_col])
        subset_cols = [item for item in subset_cols if item is not None]  # remove None
        subset_cols = subset_cols + [right_on]
        states_data = states_data.select(subset_cols)

    # change time format: see issue: https://github.com/pola-rs/polars/issues/12023
    states_data = states_data.pipe(format_time).sort(right_on)
    events = events.pipe(format_time).sort(left_on)

    # NOTE(review): dropping `<right_on>_right` presumes `events` already has a
    # column named `right_on`, so the right-hand key gets suffixed — confirm.
    df = events.join_asof(
        states_data, left_on=left_on, right_on=right_on, strategy="nearest"
    ).drop(right_on + "_right")

    if method == "interpolate":
        # Use `right_on` consistently: previously a hard-coded "time" column was used
        # here, which only worked for the default `right_on`.
        before_df = _interpolate_at(df, "t.d_start", states_data, on=right_on)
        after_df = _interpolate_at(df, "t.d_end", states_data, on=right_on)
        return (
            df.sort("t.d_start")
            .join(
                before_df,
                left_on="t.d_start",
                right_on=right_on,
                suffix=".before",
            )
            .sort("t.d_end")
            .join(
                after_df,
                left_on="t.d_end",
                right_on=right_on,
                suffix=".after",
            )
        )

    elif method == "nearest":
        return (
            df.sort("t.d_start")
            .join_asof(
                states_data,
                left_on="t.d_start",
                right_on=right_on,
                strategy="backward",
                suffix=".before",
            )
            .sort("t.d_end")
            .join_asof(
                states_data,
                left_on="t.d_end",
                right_on=right_on,
                strategy="forward",
                suffix=".after",
            )
        )
    else:
        return df

# %% ../../notebooks/03_mag_plasma.ipynb 6
def calc_plasma_parameter_change(
    df: pl.DataFrame,
    plasma_meta: PlasmaDataset = PlasmaDataset(),
):
    """Add Alfvén speeds and after-minus-before changes of plasma parameters.

    Reads the ``.before`` / ``.after`` columns produced by `combine_features` and
    the ``.l`` components produced in the ``detail`` branch of
    `calc_combined_features`.

    Adds:
        - ``<temperature_col>.change`` when ``plasma_meta.temperature_col`` is set;
        - ``v.ion.change`` when ``plasma_meta.speed_col`` is set;
        - ``v.Alfven.before.l`` / ``v.Alfven.after.l`` (via `df_Alfven_speed` with
          ``sign=True``) and ``v.Alfven.change.l.fit`` (``sign=False``, from
          ``FIT_AMPL_COL``);
        - ``n.change``, ``v.ion.change.l``, ``B.change`` and ``v.Alfven.change.l``.

    Args:
        df: Events with before/after plasma and magnetic-field columns.
        plasma_meta: Column naming; the density column falls back to
            ``DENSITY_COL``.

    Returns:
        ``df`` with the columns listed above.
    """
    density = plasma_meta.density_col or DENSITY_COL
    density_before = f"{density}.before"
    density_after = f"{density}.after"

    temperature = plasma_meta.temperature_col
    if temperature:
        df = df.with_columns(
            (pl.col(f"{temperature}.after") - pl.col(f"{temperature}.before")).alias(
                f"{temperature}.change"
            )
        )

    if plasma_meta.speed_col:
        # NOTE(review): the ion-speed columns are addressed by the fixed name
        # "v.ion", not by `plasma_meta.speed_col`; confirm the data is
        # standardized to "v.ion" upstream.
        df = df.with_columns(
            (pl.col("v.ion.after") - pl.col("v.ion.before")).alias("v.ion.change")
        )

    # Alfvén speed on each side of the event, from the l-component of the field.
    for n_side, b_side, v_side in (
        (density_before, "B.vec.before.l", "v.Alfven.before.l"),
        (density_after, "B.vec.after.l", "v.Alfven.after.l"),
    ):
        df = df_Alfven_speed(df, density=n_side, B=b_side, col_name=v_side, sign=True)

    # Alfvén speed for the fitted amplitude, using the un-suffixed density column
    # (presumably the value joined at the event time — see `combine_features`).
    df = df_Alfven_speed(
        df,
        B=FIT_AMPL_COL,
        density=density,
        col_name="v.Alfven.change.l.fit",
        sign=False,
    )

    return df.with_columns(
        (pl.col(density_after) - pl.col(density_before)).alias("n.change"),
        (pl.col("v.ion.after.l") - pl.col("v.ion.before.l")).alias("v.ion.change.l"),
        (pl.col("B.after") - pl.col("B.before")).alias("B.change"),
        (pl.col("v.Alfven.after.l") - pl.col("v.Alfven.before.l")).alias(
            "v.Alfven.change.l"
        ),
    )

# %% ../../notebooks/03_mag_plasma.ipynb 7
def calc_mag_features(
    df: pl.DataFrame,
    b_cols: list[str],
    normal_cols: list[str] = ["k_x", "k_y", "k_z"],
    b_norm_col: str = "b_mag",
):
    """Add magnetic-field geometry features to `df`.

    Adds:
        - ``theta_n_b``: computed by `df_rotation_angle` from `b_cols` and
          `normal_cols` (presumably the angle between field and normal — see
          `df_rotation_angle`).
        - ``<col>_norm`` for every column in `b_cols`: that column divided by
          ``b_norm_col``.

    Args:
        df: Input frame containing `b_cols`, `normal_cols` and ``b_norm_col``.
        b_cols: Magnetic-field vector component columns.
        normal_cols: Normal-vector component columns. Defaults to
            ``["k_x", "k_y", "k_z"]`` (not mutated).
        b_norm_col: Field-magnitude column used for normalization. Defaults to
            ``"b_mag"``, matching `calc_combined_features`.

    Returns:
        ``df`` with the columns listed above.
    """
    b_norm = pl.col(b_norm_col)

    return df.pipe(
        df_rotation_angle, b_cols, normal_cols, name="theta_n_b"
    ).with_columns(
        (cs.by_name(b_cols) / b_norm).name.suffix("_norm"),
    )

# %% ../../notebooks/03_mag_plasma.ipynb 8
def calc_combined_features(
    df: pl.DataFrame,
    detail: bool = True,
    b_norm_col="b_mag",
    normal_cols: list[str] = ["k_x", "k_y", "k_z"],
    Vl_cols=["Vl_x", "Vl_y", "Vl_z"],
    Vn_cols=["Vn_x", "Vn_y", "Vn_z"],
    thickness_cols=["L_k"],
    current_cols=["j0_k"],
    plasma_meta: PlasmaDataset | None = None,
):
    """Calculate the combined (magnetic field + plasma) features of the discontinuity.

    Args:
        df (pl.DataFrame): Events already joined with plasma data (see
            `combine_features`). Must contain the velocity and density columns named
            by `plasma_meta`, plus ``duration``, ``d_star`` and ``b_norm_col``.
        detail (bool, optional): Also project the ``.before`` / ``.after`` velocities
            and fields onto the l-direction and compute the plasma parameter changes
            (`calc_plasma_parameter_change`). Defaults to True.
        b_norm_col (str, optional): Field-magnitude column used for the Alfvén speed.
            Defaults to "b_mag".
        normal_cols (list[str], optional): Normal vector of the discontinuity plane.
            Defaults to ["k_x", "k_y", "k_z"].
        Vl_cols (list, optional): Maximum variance direction vector of the magnetic
            field. Defaults to ["Vl_x", "Vl_y", "Vl_z"].
        Vn_cols (list, optional): Minimum variance direction vector of the magnetic
            field. Defaults to ["Vn_x", "Vn_y", "Vn_z"].
        thickness_cols (list, optional): Columns divided by ``ion_inertial_length``
            to produce ``<col>_norm``. Defaults to ["L_k"].
        current_cols (list, optional): Columns divided by ``j_Alfven`` to produce
            ``<col>_norm``. Defaults to ["j0_k"].
        plasma_meta (PlasmaDataset, optional): Column naming of the plasma data.
            ``None`` (default) means ``PlasmaDataset()``, matching the defaults of the
            other functions in this module.

    Returns:
        pl.DataFrame: ``df`` with the projected velocities (``v_l``, ``v_n``,
        ``v_k``), thickness ``L_k``, current ``j0_k``, the normalized columns, and, if
        ``detail``, the before/after/change columns.
    """
    # Previously a None default crashed on `plasma_meta.velocity_cols` below.
    if plasma_meta is None:
        plasma_meta = PlasmaDataset()

    # Normalization columns; presumably the default output names of
    # `df_inertial_length` and `df_Alfven_current` — confirm there.
    length_norm = pl.col("ion_inertial_length")
    current_norm = pl.col("j_Alfven")

    vec_cols = plasma_meta.velocity_cols
    density_col = plasma_meta.density_col

    result = (
        df.pipe(vector_project_pl, vec_cols, Vl_cols, name="v_l")
        .pipe(vector_project_pl, vec_cols, Vn_cols, name="v_n")
        .pipe(vector_project_pl, vec_cols, normal_cols, name="v_k")
        .with_columns(
            pl.col("v_n").abs(),
            pl.col("v_k").abs(),
        )
        # Thickness along the normal k. The n-direction from MVA is not properly
        # determined, so n/mn-based thickness and current are intentionally omitted.
        # The duration may also be imprecise for the `max distance` method.
        .with_columns(
            L_k=pl.col("v_k") * pl.col("duration"),
        )
        .pipe(
            df_gradient_current, B_gradient="d_star", speed="v_k", col_name="j0_k"
        )  # TODO: d_star corresponding to dB/dt, which direction is not exactly perpendicular to the k direction
        .pipe(df_inertial_length, density=density_col)
        .pipe(df_Alfven_speed, B=b_norm_col, density=density_col)
        .pipe(df_Alfven_current, density=density_col)
        .with_columns(
            (cs.by_name(thickness_cols) / length_norm).name.suffix("_norm"),
            (cs.by_name(current_cols) / current_norm).name.suffix("_norm"),
        )
    )

    if detail:
        result = (
            result.pipe(
                vector_project_pl,
                [_ + ".before" for _ in vec_cols],
                Vl_cols,
                name="v.ion.before.l",
            )
            .pipe(
                vector_project_pl,
                [_ + ".after" for _ in vec_cols],
                Vl_cols,
                name="v.ion.after.l",
            )
            .pipe(decompose_vector, "B.vec.before", suffixes=[".l", ".m", ".n"])
            .pipe(decompose_vector, "B.vec.after", suffixes=[".l", ".m", ".n"])
            .pipe(calc_plasma_parameter_change, plasma_meta=plasma_meta)
        )

    return result

# %% ../../notebooks/03_mag_plasma.ipynb 9
def update_events_with_plasma_data(
    events: pl.DataFrame,
    plasma_data: pl.LazyFrame | pl.DataFrame | None,
    plasma_meta: PlasmaDataset,
    **kwargs,
):
    """Attach plasma measurements and derived plasma features to `events`.

    Extra keyword arguments are routed by name:
        - those accepted by `combine_features` (e.g. ``method``, ``left_on``,
          ``right_on``, ``subset``) go to `combine_features`;
        - those accepted by `calc_combined_features` (e.g. ``detail``,
          ``normal_cols``) go to `calc_combined_features`.
    Names unknown to both are passed to `combine_features`, which rejects them
    with ``TypeError``.

    Args:
        events: Event table.
        plasma_data: Plasma time series (lazy or eager), or None if unavailable.
        plasma_meta: Column naming of the plasma data.
        **kwargs: Options for the two steps, routed as described above.

    Returns:
        The enriched events. When `plasma_data` is None, `events` is returned
        unchanged after an info log.
    """
    import inspect

    if plasma_data is None:
        logger.info("Plasma data is not available.")
        return events

    # Forwarding every kwarg to both steps made any step-specific option raise
    # TypeError in the other step, so split them by each function's signature.
    combine_params = inspect.signature(combine_features).parameters
    calc_params = inspect.signature(calc_combined_features).parameters
    combine_kwargs = {
        k: v for k, v in kwargs.items() if k in combine_params or k not in calc_params
    }
    calc_kwargs = {k: v for k, v in kwargs.items() if k in calc_params}

    if isinstance(plasma_data, pl.LazyFrame):
        plasma_data = plasma_data.collect()

    events = combine_features(
        events,
        plasma_data,
        plasma_meta=plasma_meta,
        **combine_kwargs,
    )

    return calc_combined_features(
        events,
        plasma_meta=plasma_meta,
        **calc_kwargs,
    )

# %% ../../notebooks/03_mag_plasma.ipynb 10
def update_events_with_temp_data(
    events: pl.DataFrame,
    ion_temp_data: pl.LazyFrame | pl.DataFrame | None,
    e_temp_data: pl.LazyFrame | pl.DataFrame | None,
):
    """Attach ion and electron temperature measurements to `events`.

    Each available source is as-of joined on ``t.d_time`` (events) against
    ``time`` (source). The join uses polars' default ``backward`` strategy, i.e.
    the last sample at or before the event time. Sources given as None are
    skipped with an info log. Lazy sources are collected; eager DataFrames are
    used as-is.

    Args:
        events: Event table containing ``t.d_time``.
        ion_temp_data: Ion temperature time series, or None.
        e_temp_data: Electron temperature time series, or None.

    Returns:
        The events, sorted by ``t.d_time``, with the temperature columns joined.
    """
    left_on = "t.d_time"
    right_on = "time"

    events = events.pipe(format_time).sort(left_on)

    sources = (
        (ion_temp_data, "Ion temperature data is not available."),
        (e_temp_data, "Electron temperature data is not available."),
    )
    for temp_data, missing_msg in sources:
        if temp_data is None:
            logger.info(missing_msg)
            continue

        temp_data = temp_data.pipe(format_time).sort(right_on)
        # join_asof needs an eager frame; only lazy inputs need collecting.
        if isinstance(temp_data, pl.LazyFrame):
            temp_data = temp_data.collect()

        # NOTE(review): dropping `time_right` presumes `events` already carries a
        # `time` column that clashes with the source's key — confirm.
        events = events.join_asof(
            temp_data, left_on=left_on, right_on=right_on
        ).drop(right_on + "_right")

    return events

# %% ../../notebooks/03_mag_plasma.ipynb 11
def update_events(
    events, plasma_data, plasma_meta, ion_temp_data, e_temp_data, **kwargs
):
    """Run the full enrichment pipeline on `events`.

    1. Pass `plasma_data` through `standardize_plasma_data` (skipped when it is
       None, which the next step treats as "no plasma data available").
    2. Attach plasma measurements and derived features with
       `update_events_with_plasma_data`, forwarding `**kwargs`.
    3. Attach ion and electron temperatures with `update_events_with_temp_data`.

    Returns:
        The enriched events DataFrame.
    """
    # `update_events_with_plasma_data` explicitly supports None, so don't try to
    # standardize a missing dataset.
    if plasma_data is not None:
        plasma_data = standardize_plasma_data(plasma_data, meta=plasma_meta)
    events = update_events_with_plasma_data(events, plasma_data, plasma_meta, **kwargs)
    return update_events_with_temp_data(events, ion_temp_data, e_temp_data)
