# AUTOGENERATED! DO NOT EDIT! File to edit: notebooks/01_data_processing.ipynb (unless otherwise specified).

# Names exported when this module is star-imported (`from <module> import *`).
__all__ = ['AL_teams', 'add_pitcher_team', 'add_postouts', 'outs_per_inning', 'batters_faced',
           'get_games_pitchers_years', 'preliminary_clean', 'last', 'aggregate_at_bats', 'create_indicator',
           'accumulate', 'feature_engineering']

# Internal Cell
import sqlite3
from typing import List, Optional

import numpy as np
import pandas as pd

# Cell

# Team abbreviations treated as American League (AL) clubs.
# `get_games_pitchers_years` keeps a start only when both the pitcher's team
# and the home team appear in this list.
AL_teams = [
    "MIN",
    "CLE",
    "DET",
    "HOU",
    "BOS",
    "TOR",
    "LAA",
    "BAL",
    "KC",
    "NYY",
    "CWS",
    "TEX",
    "TB",
    "OAK",
    "SEA",
]

# Cell

def add_pitcher_team(row):
    """
    Return the team that is pitching for a single pitch-level record.

    The pitching side is derived from the half-inning label: when
    `"inning_topbot"` is `"Bot"` the away team is returned, in every other
    case the home team is returned.

    * **usage**:

    ```python
    df["pitcher_team"] = df.apply(lambda row: add_pitcher_team(row), axis=1)
    ```

    * input:
        - `row`: `pd.Series`, one record holding the `"inning_topbot"`, `"home_team"` and `"away_team"` labels

    * output:
        - value stored under `"away_team"` or `"home_team"` for that record
    """
    pitching_side = "away_team" if row["inning_topbot"] == "Bot" else "home_team"
    return row[pitching_side]

# Cell

# utility functions for identifying openers



def add_postouts(game_team_df: pd.DataFrame):
    """
    Appends a `"postouts"` column to DataFrame, which is the number of outs at the end of the at-bat.

    The DataFrame is modified in place and also returned.

    * input:
        - `game_team_df`: `pd.DataFrame`, df of pitches thrown in single game by a single team,
          sorted by `at_bat_number`. Must contain the `"outs_when_up"`, `"inning"` and
          `"at_bat_number"` columns. The ordering is required but not verified here.

    * output:
        - `game_team_df`: `pd.DataFrame`, same object as the input, with the added `"postouts"` column.
    """
    # outs after a row = outs_when_up of the following row; the final row has
    # no successor, so it inherits the value from the row before it
    next_outs = game_team_df["outs_when_up"].shift(-1).ffill()
    game_team_df.loc[:, "postouts"] = next_outs

    # if the following row belongs to a different inning, the inning ended here -> 3 outs
    # (forward fill so the final row is compared with a real inning rather than NaN)
    next_inning = game_team_df["inning"].shift(-1).ffill()
    game_team_df.loc[game_team_df["inning"] != next_inning, "postouts"] = 3

    # the final at-bat of the final inning is always marked as ending with 3 outs.
    # NOTE(review): there is no check on the inning number or on how the game
    # ended, so a game finishing mid-inning is marked the same way -- confirm intended.
    last_ab = game_team_df["at_bat_number"].max()
    last_inning = game_team_df["inning"].max()
    is_final_at_bat = (game_team_df["at_bat_number"] == last_ab) & (
        game_team_df["inning"] == last_inning
    )
    game_team_df.loc[is_final_at_bat, "postouts"] = 3

    return game_team_df


def outs_per_inning(x: pd.Series):
    """
    Aggregation helper: total of the row-to-row changes in a `pd.Series`.

    Each value is compared with the value one position earlier (the first
    value is compared with 0) and the differences are summed. Intended for a
    groupby aggregation that counts the outs recorded within an inning.

    * **usage**:

    ```python
    df.groupby(["inning"]).agg({"postouts": outs_per_inning})
    ```

    * input:
        - `x`: `pd.Series`

    * output:
        - sum of the one-step differences in `x`
    """
    previous = x.shift(1).fillna(0)
    step_changes = x - previous
    return step_changes.sum()


def batters_faced(at_bats: pd.Series):
    """
    Count the distinct at-bat numbers in a series in which each at-bat number
    may be repeated (e.g. once per pitch).

    * input:
        - `at_bats`: `pd.Series`, at-bat numbers, repeats allowed

    * output:
        - `int`, number of distinct values (a missing value counts as one distinct value)
    """
    return at_bats.nunique(dropna=False)

# Cell

def get_games_pitchers_years(df: pd.DataFrame, verbose: bool=True):
    """
    Filter out openers to get all game-pitcher combinations that qualify.

    For every regular-season game (`game_type == "R"`) the first pitcher of
    each team is identified. A start is kept when the pitcher is not an
    "opener" (an opener recorded fewer than 7 outs or faced fewer than 10
    batters) and both the pitcher's team and the home team are in `AL_teams`.

    * input:
        - `df`: `pd.DataFrame`, pitch-level data with (at least) the columns
          `game_type`, `game_pk`, `game_year`, `at_bat_number`, `inning_topbot`,
          `home_team`, `away_team`, `pitcher`, `inning` and `outs_when_up`
        - `verbose`: `bool`, print summary counts when `True`

    * output:
        - `games_pitchers_years`: `list` of `(game_pk, pitcher, year)` tuples, one per eligible start
    """
    # get unique game ids from regular season games
    games = np.sort(df.loc[(df["game_type"]=="R"), "game_pk"].unique())
    if verbose:
        print(f"In this dataset, there are {len(games)} total games.")

    # group once so each game's rows are looked up directly instead of
    # re-scanning the whole frame with a boolean mask for every game
    rows_by_game = df.groupby("game_pk", sort=False)

    # This will be list of tuples for each game and pitcher to analyze
    games_pitchers_years = []

    # identifying "opener" candidates
    for game in games:
        # getting df of game data and saving year
        game_df = rows_by_game.get_group(game)
        year = int(game_df["game_year"].iloc[0])

        # getting sorted (by at bat) df for a specific game
        game_df = game_df.sort_values("at_bat_number", ascending=True)

        # first pitcher for each team is throwing at min(at_bat_number):
        # the home pitcher throws in the "Top" half, the away pitcher in the "Bot" half
        home_pitcher_first_ab = game_df.loc[game_df["inning_topbot"] == "Top", "at_bat_number"].min()
        home_team = game_df["home_team"].head(1).item()
        home_pitcher = game_df.loc[game_df["at_bat_number"] == home_pitcher_first_ab, "pitcher"].head(1).item()

        away_pitcher_first_ab = game_df.loc[game_df["inning_topbot"] == "Bot", "at_bat_number"].min()
        away_team = game_df["away_team"].head(1).item()
        away_pitcher = game_df.loc[game_df["at_bat_number"] == away_pitcher_first_ab, "pitcher"].head(1).item()

        # adding pitcher_team
        game_df.loc[:, "pitcher_team"] = game_df.apply(add_pitcher_team, axis=1)

        # check if either are "openers"
        for pitcher, team in ((home_pitcher, home_team), (away_pitcher, away_team)):

            # adding postouts for entire game for a single team;
            # add_postouts writes into its argument, so hand it an independent copy
            game_team_df = game_df.loc[game_df["pitcher_team"] == team].copy()
            game_team_df = add_postouts(game_team_df)

            # subsetting to get pitches thrown by the starter
            game_team_pitcher_df = game_team_df.loc[game_team_df["pitcher"] == pitcher]

            # getting criteria to check if opener
            outs = game_team_pitcher_df.groupby(["inning"]).agg({"postouts": outs_per_inning}).sum().item()
            n_batters = batters_faced(game_team_pitcher_df["at_bat_number"])
            opener = outs < 7 or n_batters < 10

            # must not be opener, be from an AL team, and be playing in an AL stadium
            if not opener and (team in AL_teams) and (home_team in AL_teams):
                games_pitchers_years.append((game, pitcher, year))

    if verbose:
        print(f"There are {(len(games)*2) - len(games_pitchers_years)} ineligible starts in the dataset (either 'openers' or an NL team).")
        print(f"There are {len(games_pitchers_years)} total eligible game-pitcher combinations in this dataset.")

    return games_pitchers_years

# Cell

def preliminary_clean(df: pd.DataFrame, g: int, p: int):
    """
    Before aggregating, perform a preliminary cleaning of dataset

    Adds the `"pitcher_team"` and `"postouts"` columns, fills missing
    `"events"` with `""`, replaces `"post_bat_score"` with its next-row value,
    and adds the next-row columns `"post_on_1b"`, `"post_on_2b"`, `"post_on_3b"`
    and `"post_opposite_hand"`. The last row of each shifted column is forward filled.

    * inputs:
        - `df`: `pd.DataFrame`, DataFrame of pitch-level data from eligible game-pitcher combos
        - `g`: `int`, unique game id
        - `p`: `int`, unique pitcher id

    * output:
        - `df`: `pd.DataFrame`, cleaned DataFrame, of pitch-level data from single pitcher in single game

    * raises:
        - `IndexError` if pitcher `p` has no rows in game `g`
    """
    # subsetting to get individual game (whole game, not just pitcher `p`, so
    # postouts can be derived from the team's following at-bats).
    # add_postouts expects rows ordered by at_bat_number; a stable sort
    # guarantees that and leaves already-ordered input untouched. sort_values
    # also returns a new frame, so the caller's `df` is never written to.
    game_df = df.loc[df["game_pk"] == g].sort_values(
        "at_bat_number", ascending=True, kind="mergesort"
    )

    # adding pitcher_team
    game_df["pitcher_team"] = game_df.apply(add_pitcher_team, axis=1)

    # finding team that pitcher is on
    pitcher_team = game_df.loc[game_df["pitcher"] == p, "pitcher_team"].iloc[0]

    # adding postouts for entire game for a single team
    # (copy: add_postouts writes into the frame it is given)
    game_team_df = game_df.loc[game_df["pitcher_team"] == pitcher_team].copy()
    game_team_df = add_postouts(game_team_df)

    # subsetting to get pitches thrown by the starter
    # (copy: new columns are assigned below)
    game_team_pitcher_df = game_team_df.loc[game_team_df["pitcher"] == p].copy()

    # filling missing events with empty string so can aggregate easily
    game_team_pitcher_df["events"] = game_team_pitcher_df["events"].fillna("")

    # post_bat_score is not actually score after at-bat, needs to be lagged
    game_team_pitcher_df["post_bat_score"] = game_team_pitcher_df["post_bat_score"].shift(-1).ffill()

    # post runners on (need to lag -> this info is known in between at-bats)
    for base in (1, 2, 3):
        # 1 when the base column holds a positive value, 0 when missing or not positive
        occupied = game_team_pitcher_df[f"on_{base}b"].fillna(0).gt(0).astype(int)
        game_team_pitcher_df[f"post_on_{base}b"] = occupied.shift(-1).ffill()

    # if next batter opposite handed
    opposite_hand = (game_team_pitcher_df["stand"] != game_team_pitcher_df["p_throws"]).astype(int)
    game_team_pitcher_df["post_opposite_hand"] = opposite_hand.shift(-1).ffill()

    return game_team_pitcher_df

# Cell


def last(x: pd.Series):
    """
    Return the value at the final position of a `pd.Series`.
    Handy as an aggregation function in `pd.DataFrame.groupby.agg`.

    * input:
        - `x`: `pd.Series`

    * output:
        - the value stored at the last position of `x`
    """
    final_position = len(x) - 1
    return x.iloc[final_position]


def aggregate_at_bats(df: pd.DataFrame, at_bat_aggs: dict):
    """
    Collapse pitch-level statcast rows into one row per at-bat.
    Expects `df` in the shape produced by `preliminary_clean`.

    Rows are grouped by game, pitcher, batter and at-bat number, aggregated
    with `at_bat_aggs`, ordered by `at_bat_number`, and the grouping keys are
    turned back into ordinary columns.

    * input:
        - `df`: `pd.DataFrame`, Statcast pitch-level data (just went through preliminary clean)
        - `at_bat_aggs`: `dict`, mapping of column name to aggregation, passed to `groupby(...).agg`

    * output:
        - `agged_df`: `pd.DataFrame`, at-bat level aggregated DataFrame
    """
    group_keys = ["game_pk", "pitcher", "batter", "at_bat_number"]
    per_at_bat = df.groupby(by=group_keys).agg(at_bat_aggs)
    in_game_order = per_at_bat.sort_values(by="at_bat_number")
    agged_df = in_game_order.reset_index()
    return agged_df


# Cell

# helper feature engineering funcs


def create_indicator(
    df,
    col: str = "events",
    indicators: Optional[List] = None,
    indicator_col_names: Optional[List] = None,
):
    """
    In the [statcast data](https://baseballsavant.mlb.com/statcast_search) the
    ["events"](https://baseballsavant.mlb.com/csv-docs#events) column
    is a textual recording of the event that occurred in the at bat.

    For every requested value, adds a 0/1 column that is 1 on the rows where
    `df[col]` equals that value. `df` is modified in place and also returned.

    * inputs:
        - `df`:, `pd.DataFrame`, at-bat level statcast data
        - `col`: `str`, the column populated with indicators
        - `indicators`: `list`, the categorical variables to turn to indicators
          (default: none, `df` is returned unchanged)
        - `indicator_col_names`: `list`, alternative (optional) names for the new indicators;
          when omitted or empty, the indicator values themselves are used as column names

    * outputs:
        - `df`, `pd.DataFrame`, mutated DataFrame containing the new indicators

    * raises:
        - `ValueError` if `indicator_col_names` is given but its length differs from `indicators`
    """
    # None defaults instead of shared mutable list defaults
    indicators = [] if indicators is None else list(indicators)

    if not indicator_col_names:
        indicator_col_names = indicators
    else:
        indicator_col_names = list(indicator_col_names)
        # zip() would silently drop the unmatched tail, skipping requested indicators
        if len(indicator_col_names) != len(indicators):
            raise ValueError(
                f"indicator_col_names has {len(indicator_col_names)} entries "
                f"but indicators has {len(indicators)}"
            )

    for indicator, indicator_col_name in zip(indicators, indicator_col_names):
        df[indicator_col_name] = 0
        df.loc[(df[col] == indicator), indicator_col_name] = 1
    return df


def accumulate(df: pd.DataFrame, col: str, agg_func:str="cumsum"):
    """
    Add a running (cumulative) version of a single column.

    The result is stored in a new column named `f"{col}_{agg_func}"`.
    `df` is modified in place and also returned.

    * input:
        - `df`: `pd.DataFrame`, DataFrame of at-bat level statcast data
        - `col`: `str`, column to accumulate
        - `agg_func`: `str`, string recognized by Pandas as a function; must start with `"cum"`

    * output:
        - `df`: `pd.DataFrame`, mutated DataFrame with additional accumulated columns

    * raises:
        - `Warning` (raised as an exception) when `agg_func` does not start with `"cum"`
    """
    is_cumulative = agg_func.startswith("cum")
    if not is_cumulative:
        raise Warning(
            "Are you sure you want to accumulate with a non-cumulative aggregation function?"
        )

    target_col = f"{col}_{agg_func}"
    df[target_col] = df[col].agg(agg_func)
    return df


# Cell


def feature_engineering(df: pd.DataFrame):
    """
    Build the model features from at-bat level statcast data.
    `df` is expected in the shape returned by `aggregate_at_bats`, one row per at-bat in order.

    Columns read: `balls`, `strikes`, `postouts`, `post_fld_score`, `post_bat_score`,
    `post_on_1b`, `post_on_2b`, `post_on_3b`, `pitch_number`, `pitcher_team`,
    `game_year` and `events`.

    * input:
        - `df`: `pd.DataFrame`, at-bat level statcast data

    * output:
        - `df`: `pd.DataFrame`, mutated to have many new features.
    """
    n_at_bats = len(df)
    events = ["strikeout", "walk", "single", "double", "triple", "home_run"]

    # running ball / strike totals; the +1 keeps the ratio's denominator
    # non-zero while no balls have been recorded
    df["cum_balls"] = df["balls"].cumsum()
    df["cum_strikes"] = df["strikes"].cumsum()
    df["cum_sb_ratio"] = df["cum_strikes"] / (df["cum_balls"] + 1)

    # flag at-bats that finished the inning (three outs afterwards)
    df["end_inning"] = df["postouts"].apply(lambda outs: 1 if (outs == 3) else 0)

    # each at-bat advances the "times through the order" counter by one ninth
    one_ninth = 1 / 9
    df["times_thru_order"] = [one_ninth * position for position in range(1, n_at_bats + 1)]

    # fielding-team score minus batting-team score
    df["score_diff"] = df["post_fld_score"] - df["post_bat_score"]

    # number of occupied bases after the at-bat
    base_cols = [f"post_on_{base}b" for base in (1, 2, 3)]
    df["post_total_runners"] = df[base_cols].sum(axis=1)

    # tying run or leading run on base: score_diff of 0 or 1 with at least one runner on
    close_score = df["score_diff"].isin((0, 1))
    runner_aboard = df["post_total_runners"] >= 1
    df["tying_run_on"] = (close_score & runner_aboard).astype(int)

    # running pitch count
    df["pitch_total"] = df["pitch_number"].cumsum()

    # unique category for each team-year combo (for embeddings later)
    year_label = df["game_year"].astype(int).astype(str)
    df["pitcher_team_year"] = df["pitcher_team"] + "_" + year_label

    # one 0/1 indicator column per event type
    df = create_indicator(df, col="events", indicators=events)

    # running count of each event type
    for event in events:
        df = accumulate(df, col=event, agg_func="cumsum")

    # cumulative bases: walks and singles count 1, doubles 2, triples 3, home runs 4
    bases = df["walk_cumsum"] + df["single_cumsum"]
    for weight, hit in ((2, "double"), (3, "triple"), (4, "home_run")):
        bases = bases + (weight * df[f"{hit}_cumsum"])
    df["bases_cumsum"] = bases

    return df
