import geopandas as gpd
from geopandas import GeoDataFrame
import pandas as pd
from typing import Dict


def erase(input_gdf, erase_gdf=None):
    """
    Erase erase_gdf from input_gdf.

    Parameters
    ----------
    input_gdf : GeoDataFrame
        Input GeoDataFrame.
    erase_gdf : GeoDataFrame, optional
        GeoDataFrame containing the erase features. If None or empty,
        ``input_gdf`` itself is returned unchanged.

    Returns
    -------
    output : GeoDataFrame
        The parts of the input features that lie outside the union of the
        erase features.
    """
    if erase_gdf is None or len(erase_gdf) == 0:
        # Nothing to erase.
        return input_gdf
    # Dissolve all erase features into a single-row frame. The geometry is
    # wrapped in a list because a bare geometry is not a valid column value
    # on its own (an all-scalar dict cannot build a frame). The CRS is carried
    # over so overlay compares both layers in the same reference system.
    erase_shape = gpd.GeoDataFrame(
        geometry=[erase_gdf.unary_union], crs=erase_gdf.crs
    )
    return gpd.overlay(input_gdf, erase_shape, how='difference')


def spatial_join(target_gdf, join_gdf, op="intersects",
                 cols_agg: Dict[str, set] = None,
                 join_type="one to one", keep_all=True):
    """
    Spatial join two GeoDataFrames.

    Parameters
    ----------
    target_gdf : GeoDataFrame
        The GeoDataFrame whose features receive the joined attributes.
    join_gdf : GeoDataFrame
        The GeoDataFrame to join to the target GeoDataFrame.
    op : string, default 'intersects'
        Binary predicate, one of {'intersects', 'contains', 'within'}. See
        http://shapely.readthedocs.io/en/latest/manual.html#binary-predicates.
    cols_agg : dict, default None
        Dict of ``{column_name: statistics}``, where statistics is a set (or
        list) of strings naming the desired statistics for that column; a
        single string is accepted as one statistic. Names of the statistics
        include: {'first', 'last', 'sum', 'mean', 'median', 'max', 'min',
        'std', 'var', 'count', 'size'}. Output columns are named
        ``f"{column_name}_{statistic}"``. When None, the first matched value
        of every non-geometry column of ``join_gdf`` is kept under its
        original name.
    join_type : string, default 'one to one'
        One of {'one to one', 'one to many'} (case-insensitive). The option
        'one to one' returns only one row for each target feature, whereas
        the option 'one to many' returns one row for each match between a
        target feature and a join feature.
    keep_all : bool, default True
        Whether to keep all features from the target GeoDataFrame (left
        join). If False, only target features with at least one match are
        kept (inner join).

    Returns
    -------
    GeoDataFrame
        For 'one to one': all columns of the target GeoDataFrame plus the
        (aggregated) columns from the join GeoDataFrame, one row per target
        index. For 'one to many': the raw ``gpd.sjoin`` result.

    Raises
    ------
    ValueError
        If ``join_type`` is not 'one to one' or 'one to many'.
    """
    # Validate before running the (potentially expensive) spatial join.
    join_type_key = join_type.lower()
    if join_type_key not in ("one to one", "one to many"):
        raise ValueError("join_type must be either 'one to one' or "
                         "'one to many'")

    how = 'left' if keep_all else 'inner'
    # NOTE(review): newer geopandas renamed sjoin's `op` keyword to
    # `predicate`; keep `op` while the project pins an older geopandas.
    # NOTE(review): columns present in both inputs are presumably suffixed by
    # sjoin, which would break the column selections below — confirm.
    gpd_sjoin = gpd.sjoin(target_gdf, join_gdf, how=how, op=op)

    if join_type_key == "one to many":
        return gpd_sjoin

    # sjoin emits one row per (target, match) pair. Keep one row per target
    # index: dedupe by index, not by row values, so distinct target features
    # that happen to share identical attributes/geometry are not collapsed.
    first_per_target = ~gpd_sjoin.index.duplicated(keep='first')
    target_df = gpd_sjoin.loc[first_per_target, target_gdf.columns]

    if cols_agg is None:
        # Default: first matched value of each non-geometry join column,
        # kept under its original name.
        out_names = [col for col in join_gdf.columns
                     if col != join_gdf.geometry.name]
        agg_spec = {col: ['first'] for col in out_names}
    else:
        # Normalise each spec to a list: a bare string is one statistic, and
        # sets are converted once so the generated names follow the exact
        # order pandas uses for the aggregated columns.
        agg_spec = {col: [stats] if isinstance(stats, str) else list(stats)
                    for col, stats in cols_agg.items()}
        out_names = [f"{col}_{stat}"
                     for col, stats in agg_spec.items()
                     for stat in stats]

    if not agg_spec:
        # Nothing to aggregate (e.g. join_gdf has only a geometry column).
        return target_df

    join_df = gpd_sjoin.groupby(gpd_sjoin.index).agg(agg_spec)
    # Flatten pandas' (column, statistic) MultiIndex into plain names.
    join_df.columns = out_names
    return pd.concat([target_df, join_df], axis=1)


def within_dist(input_gdf, input_id, distance,
                target_gdf=None, output_col=None):
    """
    For each object in the input, test if any object in the target set is
    within a specified distance.

    Note that ``input_gdf`` is modified in place (the output column is added
    to it) and is also returned.

    Parameters
    ----------
    input_gdf : GeoDataFrame
        The input set.
    input_id : str
        The name of the column containing the unique id of the input set.
    distance : int or float
        Distance (in the same unit as the input GeoDataFrame).
    target_gdf : GeoDataFrame, optional
        The target set. If None or empty, every input object gets 0.
    output_col : str, optional
        The name of the output column. Defaults to ``"within_<distance>"``.

    Returns
    -------
    output : GeoDataFrame
        The output value is 1 if there exists any target object within the
        specified distance of the input object and 0 otherwise.
    """
    if output_col is None:
        output_col = "within_" + str(distance)
    input_gdf[output_col] = 0
    if target_gdf is None or len(target_gdf) == 0:
        return input_gdf

    # The buffered shapes must be the *active* geometry of the frame passed
    # to sjoin; adding them as an ordinary column would leave sjoin testing
    # the unbuffered targets. A geometry-only frame also avoids target columns
    # clashing with (and renaming) the input's id column.
    buffered_targets = gpd.GeoDataFrame(
        geometry=target_gdf.buffer(distance), crs=target_gdf.crs
    )
    matched_ids = gpd.sjoin(
        input_gdf, buffered_targets, how='inner', op='intersects'
    )[input_id]
    input_gdf.loc[input_gdf[input_id].isin(matched_ids), output_col] = 1
    return input_gdf


def select_by_location(input_gdf, select_gdf,
                       op='intersects', within_dist=0):
    """
    Select part of the input GeoDataFrame based on its relationship with the
    selecting GeoDataFrame.

    Parameters
    ----------
    input_gdf : GeoDataFrame
        The input GeoDataFrame.
    select_gdf : GeoDataFrame
        The selecting GeoDataFrame. It is not modified.
    op : string, default 'intersects'
        Binary predicate, one of {'intersects', 'contains', 'within',
        'within a distance'}. See
        http://shapely.readthedocs.io/en/latest/manual.html#binary-predicates.
        'within a distance' selects input features that lie within the
        selecting features buffered by ``within_dist``.
    within_dist : int, default 0
        Search distance around the select_gdf. This parameter is only
        useful when op is set to be "within a distance"; a value of 0 means
        no buffering (plain 'within').

    Returns
    -------
    output : GeoDataFrame
        A copy of the selected features from the input GeoDataFrame.
    """
    ops = ['intersects', 'contains', 'within', 'within a distance']
    assert op in ops, f'invalid op parameter {op!r}; expected one of {ops}'
    if op == 'within a distance':
        if within_dist:
            # Buffer a copy so the caller's select_gdf is left untouched.
            select_gdf = select_gdf.copy()
            select_gdf[select_gdf.geometry.name] = select_gdf.buffer(
                within_dist)
        # 'within a distance' is not an sjoin predicate: always translate it,
        # including when within_dist is 0.
        op = 'within'
    matched_index = gpd.sjoin(
        input_gdf, select_gdf, how='inner', op=op
    ).index
    output_gdf = input_gdf.loc[input_gdf.index.isin(matched_index.values), :]
    output_gdf = output_gdf.rename_axis(None, axis=1)
    return output_gdf.copy()
