import pyproj
import pandas as pd
import numpy as np

from functools import lru_cache
from shapely import ops
from shapely.wkt import loads  # Load into geometry namespace
from shapely.geometry import (
    Point,
    Polygon,
    LineString,
    MultiPoint,
    MultiPolygon,
    MultiLineString,
)
from numpy.linalg import norm
from collections import deque
from scipy.spatial import KDTree


@lru_cache(maxsize=5)
def project(from_crs, to_crs):
    """Return a cached pyproj transform function between two EPSG codes.

    Parameters
    ----------
    from_crs, to_crs : int
        EPSG codes of the source and target coordinate reference systems.

    Returns
    -------
    callable
        ``Transformer.transform`` with ``always_xy=True``, so coordinates are
        passed and returned in (x, y) / (lon, lat) order. Results for up to
        five (from_crs, to_crs) pairs are memoised.
    """
    source = pyproj.CRS(f"epsg:{from_crs:d}")
    target = pyproj.CRS(f"epsg:{to_crs:d}")
    transformer = pyproj.Transformer.from_crs(source, target, always_xy=True)
    return transformer.transform


def _transform_single(geometry, from_crs=4326, to_crs=2193):
    """Reproject a single shapely geometry from ``from_crs`` to ``to_crs``."""
    transformer = project(from_crs, to_crs)
    return ops.transform(transformer, geometry)


def transform(geometry, from_crs=4326, to_crs=2193):
    """Reproject a geometry, or a sequence of geometries, between EPSG codes.

    A list or tuple of geometries is first combined into one multi-part
    geometry chosen from the type of its first element: Polygons give a
    MultiPolygon, LineStrings a MultiLineString, and anything else a
    MultiPoint.

    Parameters
    ----------
    geometry : shapely geometry, or list/tuple of shapely geometries
    from_crs, to_crs : int
        Source and target EPSG codes.

    Returns
    -------
    shapely geometry
        The reprojected geometry (multi-part when a sequence was given).

    Raises
    ------
    IndexError
        If ``geometry`` is an empty list or tuple.
    """
    if isinstance(geometry, (list, tuple)):
        parts = list(geometry)
        first = parts[0]
        if isinstance(first, Polygon):
            geometry = MultiPolygon(parts)
        elif isinstance(first, LineString):
            geometry = MultiLineString(parts)
        else:
            geometry = MultiPoint(parts)
    return _transform_single(geometry, from_crs, to_crs)


def _build_point_layer(df):
    """Sample every line feature in ``df`` into a table of points.

    For each row the sample points are, in this order:

    * a point ``offset`` (0.1 m) in from the start, used instead of the start
      vertex so that consecutive features do not contribute coincident points;
    * points 5, 10, 15 and 20 m from the start;
    * the interior vertices (the first and last vertices are excluded);
    * points 20, 15, 10 and 5 m from the end;
    * a point 0.1 m in from the end.

    The extra points near each end densify the features there, so a query
    near a gap between consecutive features snaps to the right one. Distances
    are in the units of the geometry's CRS. For features shorter than 20 m
    some normalised positions fall outside [0, 1]. Shapely clamps positions
    beyond the end and measures negative ones back from the end, so those
    samples still lie on the same feature.

    Returns
    -------
    pd.DataFrame
        Columns "id" (the feature's index label in ``df``) and "geometry"
        (one shapely Point per sample).
    """
    geometry, idx = [], []
    offset = 0.1  # Start/end offset in metres.
    for _, row in df.iterrows():
        line = row.geometry
        length = line.length

        # Interior vertices only: the true end vertices are replaced by the
        # offset points added below.
        coords = deque(list(line.coords)[1:-1])

        # Add additional points over the first/last 20 metres to avoid
        # incorrect snapping at gaps.
        for rr in range(20, 0, -5):
            coords.appendleft(
                line.interpolate(rr / length, normalized=True).coords[0]
            )
            coords.append(
                line.interpolate(1 - rr / length, normalized=True).coords[0]
            )

        # Modify the lines so that consecutive lines don't start/end right
        # on top of one another:
        coords.appendleft(
            line.interpolate(offset / length, normalized=True).coords[0]
        )
        coords.append(
            line.interpolate(1 - offset / length, normalized=True).coords[0]
        )

        # Extract the points and id for each feature:
        geometry += [Point(xy) for xy in coords]
        idx += [row.name] * len(coords)

    # Store individual points with corresponding id in a DataFrame:
    return pd.DataFrame({"id": idx, "geometry": geometry})


def _coords(df):
    return [gg.coords[0] for gg in df.geometry.tolist()]


def _x_coords(coords):
    return [cc[0] for cc in coords]


def _y_coords(coords):
    return [cc[1] for cc in coords]


def _build_kdtree(df):
    """Build a KD-tree over the (x, y) coordinates of the points in ``df``.

    Only the first coordinate of each geometry in ``df.geometry`` is used,
    and only its x and y components. Tree index i corresponds to the i-th
    geometry of ``df.geometry``.

    Returns
    -------
    scipy.spatial.KDTree
    """
    # Extract the coordinates once and share them between both axes.
    coords = _coords(df)
    xy = np.array(list(zip(_x_coords(coords), _y_coords(coords))), dtype=float)
    # reshape keeps the (n, 2) shape KDTree expects even when df is empty.
    return KDTree(xy.reshape(-1, 2))


class Centreline(object):
    """Snap points onto a table of road centreline features.

    Used to find the displacement value of a point projected onto the
    nearest road feature. The length of the target line is specified and
    does not need to match the length of the geometry feature (this is
    usually the case with RAMM features).

    df can be a GeoDataFrame, or a standard DataFrame containing the RAMM
    carr_way table with the 'sh_direction' and 'sh_element_type' columns
    from the RAMM roadname table. The methods below read its "geometry",
    "road_id", "carrway_start_m", "carrway_end_m" and "length_m" columns,
    and treat each row's index label as its carriageway number.
    """

    def __init__(self, df):
        # The reference crs (an EPSG code) is used when projecting a point
        # onto a line. Query points are reprojected to it before being
        # compared with the feature geometries, so those geometries are
        # presumably stored in this crs. TODO: confirm against callers.
        self.ref_crs = 2193
        self._df_features = df
        # Sample points of every feature, indexed by a KD-tree for
        # nearest-neighbour lookups.
        self._df_points = _build_point_layer(df)
        self._kdtree = _build_kdtree(self._df_points)

    def nearest_feature(self, point, point_crs=4326):
        """Find the feature nearest to ``point`` and the point's offset from it.

        Parameters
        ----------
        point : shapely.geometry.Point
        point_crs : int
            EPSG code of ``point``. The point is reprojected to ``self.ref_crs``
            when the two differ.

        Returns
        -------
        tuple
            ``(id, distance)``. ``id`` is the id of the feature owning the
            nearest sample point. ``distance`` is the perpendicular distance
            from ``point`` to the line through its two nearest sample points,
            or the distance to the nearest sample when those two coincide.
            NOTE(review): the two nearest samples are not required to belong
            to the same feature.
        """
        if point_crs != self.ref_crs:
            point = transform(point, point_crs, self.ref_crs)

        xy = np.asarray(point.coords[0][:2], dtype=float)
        _, ii = self._kdtree.query(xy, 2)
        idx = self._df_points.loc[ii[0], "id"]

        # Calculate offset distance from the two nearest sample points.
        p1 = np.asarray(
            self._df_points.loc[ii[0], "geometry"].coords[0][:2], dtype=float
        )
        p2 = np.asarray(
            self._df_points.loc[ii[1], "geometry"].coords[0][:2], dtype=float
        )
        base = p2 - p1
        to_p1 = p1 - xy
        base_len = np.hypot(base[0], base[1])
        if base_len == 0:
            # Coincident samples (e.g. clamped end points on a short feature)
            # define no line, so fall back to the point-to-point distance.
            distance = float(np.hypot(to_p1[0], to_p1[1]))
        else:
            # |2-D cross product| / base length = distance to the line. This is
            # written out explicitly because np.cross on 2-element vectors is
            # deprecated in NumPy 2.
            distance = float(
                abs(base[0] * to_p1[1] - base[1] * to_p1[0]) / base_len
            )

        return idx, distance

    def displacement(self, point, point_crs=4326):
        """Find the position along the nearest road closest to ``point``.

        Parameters
        ----------
        point : shapely.geometry.Point
        point_crs : int
            EPSG code of ``point``. The point is reprojected to ``self.ref_crs``
            when the two differ.

        Returns
        -------
        tuple
            ``(chainage_m, road_id, carr_way_no, offset_m)``.

            * ``chainage_m`` is the feature's "carrway_start_m" plus the
              point's normalised position along the feature geometry times
              the feature's "length_m".
            * ``carr_way_no`` is the feature's index label.
            * ``offset_m`` is the distance returned by ``nearest_feature``.
        """
        if point_crs != self.ref_crs:
            # The class has no transform method; use the module-level helper.
            point = transform(point, point_crs, self.ref_crs)

        # Find the nearest line feature to the specified point:
        carr_way_no, offset_m = self.nearest_feature(point, self.ref_crs)

        start_m = self._df_features.loc[carr_way_no, "carrway_start_m"]
        length_m = self._df_features.loc[carr_way_no, "length_m"]

        # Normalised (0-1) position of the point projected onto the feature.
        position = self._df_features.geometry[carr_way_no].project(point, True)

        return (
            start_m + position * length_m,
            self._df_features.loc[carr_way_no, "road_id"],
            carr_way_no,
            offset_m,
        )

    def append_geometry(self, df, geometry_type="wkt"):
        """
        Append geometry to dataframe. Dataframe must contain road_id, start_m and end_m.

        ``df`` is modified in place and also returned. Rows for which
        ``extract_geometry`` finds no geometry get None ("wkt") or NaN
        ("easting"/"northing").

        Parameters
        ----------
        df : pd.DataFrame

        geometry_type : str
            Use "wkt" to return a well-known-text string. Use "coord" to return the
            coordinates of the start_m position.

        Raises
        ------
        AttributeError
            If ``geometry_type`` is not "wkt" or "coord". The exception type is
            kept for compatibility with existing callers.
        """
        if geometry_type not in ("wkt", "coord"):
            raise AttributeError(
                f"geometry_type must be 'wkt' or 'coord', not {geometry_type!r}"
            )

        geometry = [
            self.extract_geometry(row["road_id"], row["start_m"], row["end_m"])
            for _, row in df.iterrows()
        ]
        if geometry_type == "wkt":
            df["wkt"] = [None if gg is None else gg.wkt for gg in geometry]
        else:
            coords = [None if gg is None else gg.coords[0] for gg in geometry]
            df["easting"] = [np.nan if cc is None else cc[0] for cc in coords]
            df["northing"] = [np.nan if cc is None else cc[1] for cc in coords]
        return df

    @staticmethod
    def _ref_pos(cway, chainage_m):
        # Normalised position of chainage_m along a carr_way row, clamped to
        # [0, 1]. Shapely's interpolate() measures negative distances back
        # from the end of the line, so an unclamped start_m before the
        # carriageway start would land near the far end.
        pos = (chainage_m - cway["carrway_start_m"]) / cway["length_m"]
        return min(max(pos, 0.0), 1.0)

    @staticmethod
    def _first_vertex_after(line, ref_pos):
        # Index of the first vertex whose normalised position along line is
        # strictly greater than ref_pos, or the last vertex index if none is.
        coords = list(line.coords)
        for index, xy in enumerate(coords):
            if line.project(Point(xy), normalized=True) > ref_pos:
                return index
        return len(coords) - 1

    def extract_geometry(self, road_id, start_m, end_m):
        """Extract the part of the centreline that corresponds to the section
        of interest.

        The carr_way elements of ``road_id`` overlapping (start_m, end_m) are
        joined in order of "carrway_start_m". Consecutive elements are assumed
        to share their end/start vertex. Positions before the start or beyond
        the end of an element are clamped to that element's end points.

        Parameters
        ----------
        road_id : int

        start_m : float

        end_m : float

        Returns
        -------
        sp.geometry.LineString or None
            Geometry object, or None when no carr_way element of ``road_id``
            overlaps the section.

        """

        centreline = self._df_features

        selected_cways = centreline.loc[
            (centreline["road_id"] == road_id)
            & (centreline["carrway_end_m"] > start_m)
            & (centreline["carrway_start_m"] < end_m)
        ].sort_values("carrway_start_m")

        cway_no = selected_cways.index.tolist()
        if not cway_no:
            return None

        # Find start: interpolated on the first carr_way element, followed by
        # that element's vertices lying beyond it.
        first_cway = centreline.loc[cway_no[0]]
        first_line = first_cway["geometry"]
        start_pos = self._ref_pos(first_cway, start_m)
        extracted_coords = [
            first_line.interpolate(start_pos, normalized=True).coords[0]
        ]
        ii = self._first_vertex_after(first_line, start_pos)

        if len(cway_no) == 1:
            # Here both start_m and end_m lie on the same carr_way element.
            end_pos = self._ref_pos(first_cway, end_m)
            jj = self._first_vertex_after(first_line, end_pos)

            # Build geometry:
            extracted_coords += first_line.coords[ii:jj]
            extracted_coords.append(
                first_line.interpolate(end_pos, normalized=True).coords[0]
            )
        else:
            # Find the end point on the last carr_way element.
            last_cway = centreline.loc[cway_no[-1]]
            last_line = last_cway["geometry"]
            end_pos = self._ref_pos(last_cway, end_m)
            jj = self._first_vertex_after(last_line, end_pos)

            # Coords from first carr_way element:
            extracted_coords += first_line.coords[ii:]

            # Complete carr_way elements in between. Their first point is
            # shared with the previous element, so it is skipped.
            for kk in cway_no[1:-1]:
                extracted_coords += centreline.loc[kk]["geometry"].coords[1:]

            # Points from the last carr_way element, ignoring the first point
            # (shared with the previous element), then the interpolated end.
            extracted_coords += last_line.coords[1:jj]
            extracted_coords.append(
                last_line.interpolate(end_pos, normalized=True).coords[0]
            )

        return LineString(extracted_coords)


def build_chainage_layer(centreline, road_id, length_m=1000, width_m=300):
    """Build the chainage-marker label layer for one or more roads.

    Parameters
    ----------
    centreline : Centreline
        Source of the carr_way features and their geometry.
    road_id : int or list
        Road(s) to build markers for.
    length_m : float
        Spacing between consecutive markers.
    width_m : float
        Length of each perpendicular label line.

    Returns
    -------
    pd.DataFrame
        The chainage table with its "wkt" column replaced by label geometries.
    """
    features = _extract_centreline(centreline, road_id)
    base_table = _build_chainage_base_table(features, length_m)
    with_geometry = centreline.append_geometry(base_table)
    return build_label_layer(with_geometry, width_m)


def build_label_layer(df, width_m=300):
    """Return a copy of ``df`` whose "wkt" column holds label geometries.

    Each row's "wkt" text is parsed and replaced by the perpendicular line
    produced by ``_generate_perpendicular_geometry`` for that row's
    "direction" and ``width_m``. After this call the column holds shapely
    geometry objects rather than WKT strings. ``df`` itself is not modified.
    """
    labelled = df.copy()
    # Snapshot the inputs before writing so each row is built from its
    # original WKT text.
    snapshot = list(
        zip(
            labelled.index,
            labelled["wkt"].tolist(),
            labelled["direction"].tolist(),
        )
    )
    for label, wkt_text, direction in snapshot:
        labelled.loc[label, "wkt"] = _generate_perpendicular_geometry(
            loads(wkt_text), direction, width_m
        )
    return labelled


def _generate_perpendicular_geometry(linestring, direction, width_m):
    """Build a line of length ``width_m`` perpendicular to ``linestring``.

    The perpendicular starts at the first vertex of ``linestring``. Its
    orientation comes from the chord between the first and last vertices
    (for a two-vertex line, the line itself), so lines with more than two
    vertices are accepted.

    The offset vector always has a non-negative x-component and is reversed
    for ``direction == "D"``. This reproduces the earlier ``arctan(-1/slope)``
    formulation without its division by zero on horizontal/vertical lines.
    NOTE(review): with this convention the side the label lands on depends
    on the line's orientation, not on its travel direction. It also flips
    between lines just above and just below horizontal. Confirm this is
    intended.

    Parameters
    ----------
    linestring : shapely.geometry.LineString
    direction : str
        "D" places the perpendicular on the opposite side; any other value
        uses the default side.
    width_m : float
        Length of the perpendicular, in the linestring's CRS units.

    Returns
    -------
    shapely.geometry.LineString

    Raises
    ------
    ValueError
        If the first and last vertices coincide, so no direction is defined.
    """
    coords = np.asarray(linestring.coords, dtype=float)[:, :2]
    pt1 = coords[0]
    run, rise = coords[-1] - pt1
    chord = np.hypot(run, rise)
    if chord == 0:
        raise ValueError("Cannot build a perpendicular to a zero-length line.")
    # Unit normal with x >= 0: (|rise|, -run) for rise >= 0, else (|rise|, run).
    sign = 1.0 if rise >= 0 else -1.0
    offset = width_m * np.array([abs(rise), -run * sign]) / chord
    if direction == "D":
        pt3 = pt1 - offset
    else:
        pt3 = pt1 + offset
    return LineString([pt1, pt3])


def _build_chainage_base_table(selected, length_m):
    """Build one row per chainage marker for each state-highway section.

    ``selected`` is grouped by road, state highway, reference station and
    direction. Each group gets markers every ``length_m`` from its minimum
    "carrway_start_m" up to (but excluding) its maximum "carrway_end_m".

    Returns
    -------
    pd.DataFrame
        Columns road_id, start_m, end_m, direction and label, with a fresh
        RangeIndex. The frame is empty when ``selected`` yields no groups.
    """
    groupby = ["road_id", "sh_state_hway", "sh_ref_station_no", "sh_direction"]
    frames = []
    for (road, sh, rs, direction), gg in selected.groupby(groupby):
        df_ = pd.DataFrame(
            {
                "road_id": road,
                "start_m": np.arange(*_carrway_start_end_m(gg), length_m),
            }
        )
        # A 1 m section per marker: the label layer only needs its start
        # point and the direction of the road there.
        df_["end_m"] = df_["start_m"] + 1
        df_["direction"] = direction
        df_["label"] = _generate_rsrp_labels(
            sh, rs, direction, df_["start_m"].to_list()
        )
        frames.append(df_)
    if not frames:
        return pd.DataFrame()
    # One concat instead of one per group avoids quadratic copying and
    # concatenating with an empty frame.
    return pd.concat(frames, ignore_index=True)


def _generate_rsrp_labels(sh, rs, direction, rps):
    rs = float(rs)
    rps = [float(rp) for rp in rps]
    return [f"{sh}-{rs:04.0f}/{rp/1000:05.2f}-{direction}" for rp in rps]


def _carrway_start_end_m(df):
    return df.carrway_start_m.min(), df.carrway_end_m.max()


def _extract_centreline(centreline_obj, road_id):
    if isinstance(road_id, int):
        road_id = [road_id]
    return centreline_obj._df_features.loc[
        centreline_obj._df_features["road_id"].isin(road_id)
    ]
