from .metadata import main, method, data, methods, comments
from .parser import parse
from .dumper import dump
from .normalizer import normalize
import pandas as pd
import numpy as np
import logging

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

def merge_dicts(*dicts):
    """Return a new dict combining *dicts* from left to right.

    On duplicate keys the value from the later mapping wins. None of the
    inputs is modified.
    """
    merged = dict()
    for mapping in dicts:
        merged.update(mapping)
    return merged

def sections_to_geotech_set(sections, merge=False, id_col="investigation_point"):
    """Combine a list of per-borehole sections into three flat tables.

    Args:
        sections: list of dicts, one per borehole. ``"main"`` is required and
            is a list of row dicts; the borehole id is read from
            ``section["main"][0][id_col]``. ``"data"`` (a DataFrame) and
            ``"method"`` (anything ``pd.DataFrame`` accepts) are optional.
        merge: when True, all ``main`` rows of a borehole are merged into a
            single row (later rows win on duplicate keys); otherwise every
            ``main`` row becomes its own row.
        id_col: name of the column that identifies the borehole. It is
            written into all three output tables.

    Returns:
        dict with ``"main"``, ``"data"`` and ``"method"`` DataFrames.

    Raises:
        AssertionError: if two sections share the same id.
        KeyError / IndexError: if a section has no (or an empty) ``"main"``.
    """
    # Every section needs "main": it is the only place the id comes from.
    ids = [borehole["main"][0][id_col] for borehole in sections]
    if len(set(ids)) != len(sections):
        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        raise AssertionError("%s is not unique for each borehole" % id_col)

    if not sections:
        # pd.concat rejects an empty list; return empty tables instead.
        return {
            "main": pd.DataFrame([], columns=[id_col]),
            "data": pd.DataFrame([], columns=[id_col]),
            "method": pd.DataFrame([], columns=[id_col]),
        }

    if merge:
        main = pd.DataFrame([merge_dicts(*borehole["main"]) for borehole in sections])
    else:
        main = pd.concat([
            pd.DataFrame(borehole["main"]).assign(**{id_col: section_id})
            for borehole, section_id in zip(sections, ids)])

    # Use id_col (not a hard-coded column name) so the tables can be split
    # back apart with the same id_col.
    data = pd.concat([
        borehole["data"].assign(**{id_col: section_id})
        if "data" in borehole else pd.DataFrame([], columns=[id_col])
        for borehole, section_id in zip(sections, ids)])
    method = pd.concat([
        pd.DataFrame(borehole["method"]).assign(**{id_col: section_id})
        if "method" in borehole else pd.DataFrame([], columns=[id_col])
        for borehole, section_id in zip(sections, ids)])
    return {"main": main, "data": data, "method": method}

def geotech_set_to_sections(geotech, id_col="investigation_point"):
    """Split flat geotechnical tables back into per-borehole sections.

    Args:
        geotech: dict that may hold ``"main"``, ``"data"`` and ``"method"``
            DataFrames, each with an ``id_col`` column. A missing key and a
            ``None`` value are treated the same way.
        id_col: column that identifies the borehole.

    Returns:
        One dict per unique id in ``main`` (in order of first appearance):
        ``"main"`` is a list of row dicts, ``"data"`` is the matching slice
        of the data table (an empty DataFrame when there is no data table),
        and ``"method"`` is a list of row Series (empty list when there is
        no method table). Returns ``[]`` when there is no main table.
    """
    main = geotech.get("main")
    if main is None:
        # Section ids come from main; without it there are no sections.
        return []
    data = geotech.get("data")
    method = geotech.get("method")

    sections = []
    for section_id in main[id_col].unique():
        main_rows = main[main[id_col] == section_id]
        section = {"main": [row.to_dict() for _, row in main_rows.iterrows()]}
        if data is not None:
            section["data"] = data[data[id_col] == section_id]
        else:
            section["data"] = pd.DataFrame()
        if method is not None:
            section["method"] = [
                row for _, row in method[method[id_col] == section_id].iterrows()]
        else:
            # Same type as the populated case (a list of rows).
            section["method"] = []
        sections.append(section)
    return sections

# Private aliases of the imported functions. SGFData.dump and
# SGFData.normalize call these, which keeps those call sites unambiguous
# next to methods of the same name.
_dump_function, _normalize_function = dump, normalize

class SGFData(object):
    """Set of SGF boreholes held as flat pandas tables.

    ``model_dict`` maps ``"main"``, ``"data"`` and ``"method"`` to
    DataFrames; rows of one borehole share the value in column ``id_col``.
    The per-borehole ``sections`` list is rebuilt from these tables on every
    access, and assigning to ``sections`` rebuilds the tables.

    Constructor forms:

    * ``SGFData()`` / ``SGFData(id_col=...)`` -- empty set.
    * ``SGFData(a_dict)`` -- ``a_dict`` is used as ``model_dict`` as-is.
    * ``SGFData(a_list)`` -- tables are built from a list of sections.
    * any other arguments -- forwarded to ``parse(*arg, **kw)``.
    """

    def __new__(cls, *arg, **kw):
        self = object.__new__(cls)
        # Pop id_col first so it is never forwarded to parse().
        self.id_col = kw.pop("id_col", "investigation_point")
        self.model_dict = {}
        if arg and isinstance(arg[0], dict):
            self.model_dict = arg[0]
        elif arg and isinstance(arg[0], list):
            self.model_dict = sections_to_geotech_set(arg[0], id_col=self.id_col)
        elif arg or kw:
            # Only parse when something besides id_col was supplied, so
            # SGFData(id_col=...) yields an empty set instead of calling
            # parse() with no input.
            self.model_dict = sections_to_geotech_set(parse(*arg, **kw), id_col=self.id_col)
        return self

    def dump(self, *arg, **kw):
        """Write the boreholes with the module-level dump function.

        ``self.sections`` is passed first, followed by all given arguments;
        the dump function's return value is passed through.
        """
        return _dump_function(self.sections, *arg, **kw)

    def normalize(self):
        """Return a new, normalized object; ``self`` keeps its tables.

        The normalizer's return value is ignored, so it is expected to
        modify the sections list in place. The data tables are copied first
        (presumably so the normalizer's in-place edits cannot reach the
        stored tables -- TODO confirm).
        """
        sections = self.sections
        for section in sections:
            if "data" in section:
                section["data"] = section["data"].copy()
        _normalize_function(sections)
        # Keep id_col; otherwise the sections would be re-indexed by the
        # default "investigation_point" column.
        return type(self)(sections, id_col=self.id_col)

    @property
    def sections(self):
        """Per-borehole view of the tables (rebuilt on every access)."""
        return geotech_set_to_sections(self.model_dict, id_col=self.id_col)

    @sections.setter
    def sections(self, sections):
        self.model_dict = sections_to_geotech_set(sections, id_col=self.id_col)

    @property
    def main(self):
        """The ``main`` table, or None if absent."""
        return self.model_dict.get("main", None)

    @main.setter
    def main(self, a):
        self.model_dict["main"] = a

    @property
    def data(self):
        """The ``data`` table, or None if absent."""
        return self.model_dict.get("data", None)

    @data.setter
    def data(self, a):
        self.model_dict["data"] = a

    @property
    def method(self):
        """The ``method`` table, or None if absent."""
        return self.model_dict.get("method", None)

    @method.setter
    def method(self, a):
        self.model_dict["method"] = a

    def __repr__(self):
        main, data = self.main, self.data
        res = [
            "Geotechnical data",
            "===================",
            "Soundings: %s" % (len(main) if main is not None else 0,),
            "Depths: %s" % (len(data) if data is not None else 0,),
            "==================="]

        coords = ["x_coordinate", "y_coordinate"]
        # Guard the summary so repr never raises on tables that lack
        # coordinates (or have no rows).
        if main is not None and not main.empty and all(c in main.columns for c in coords):
            res.append(repr(main[coords].describe().loc[["min", "max"]]))
        else:
            res.append("")

        if data is not None:
            for col in ("depth", "feed_thrust_force"):
                if col in data.columns:
                    res.append(repr(pd.DataFrame(data[col].describe())))

        return "\n".join(res)

    def _require_main(self):
        """Return the main table; raise ValueError if it is missing or empty."""
        main = self.main
        if main is None or main.empty:
            raise ValueError("No boreholes: the main table is missing or empty.")
        return main

    def sample_dtm(self, raster, overwrite=True):
        """Replace the sections with dtm.sample_z_coordinate_from_dtm's result.

        ``raster`` and ``overwrite`` are forwarded together with
        ``self.projection``.
        """
        from . import dtm
        self.sections = dtm.sample_z_coordinate_from_dtm(self.sections, self.projection, raster=raster, overwrite=overwrite)

    def sample_terrainy_dtm(self, raster_name, overwrite=True):
        """Set ``main["z_coordinate"]`` by sampling a terrainy raster.

        ``overwrite`` is accepted for signature parity with ``sample_dtm``
        but is not used here: z values are always replaced. Boreholes not
        covered by any raster tile get NaN.

        Raises:
            ValueError: if there are no boreholes or no single projection.
        """
        import terrainy
        import geopandas as gpd

        self._require_main()

        conn = terrainy.connect(raster_name)

        # Buffer 1m, or we can get problems with boreholes right at the tile boundary...
        tiles = gpd.GeoDataFrame(
            geometry=[polygon for x_idx, y_idx, polygon in conn.get_tile_bounds(self.area.buffer(1), 1)],
            crs=conn.get_crs())

        positions = self.positions.to_crs(tiles.crs)
        positions["tile"] = -1
        for idx, tile in enumerate(tiles.geometry):
            positions.loc[positions.within(tile), "tile"] = idx
        # A point exactly on an edge shared by two tiles is not strictly
        # "within" either; give it the first tile it touches instead of
        # leaving -1, which has no row in ``tiles``.
        for idx, tile in enumerate(tiles.geometry):
            positions.loc[(positions.tile == -1) & positions.intersects(tile), "tile"] = idx

        unassigned = positions.tile == -1
        if unassigned.any():
            logger.warning("%s borehole(s) are outside every raster tile; their z_coordinate will be NaN",
                           int(unassigned.sum()))

        # Pre-create the column so it exists even if nothing is sampled.
        positions["topo"] = np.nan
        # Coordinates do not change per tile; compute them once.
        xy = np.column_stack((positions.geometry.x, positions.geometry.y))
        tile_idxs = [tile_idx for tile_idx in positions.tile.unique() if tile_idx != -1]
        for count, tile_idx in enumerate(tile_idxs, 1):
            logger.info("Working on tile %s of %s", count, len(tile_idxs))
            filt = (positions.tile == tile_idx).to_numpy()
            with conn.open_tile(tiles.loc[tile_idx].geometry.bounds, 1) as dataset:
                positions.loc[filt, "topo"] = [v[0] for v in dataset.sample(xy[filt, :])]

        # positions shares main's index, so this assignment aligns row-wise.
        self.main["z_coordinate"] = positions.topo

    @property
    def projection(self):
        """Integer projection shared by all boreholes, or None.

        None is returned when there is no main table, no ``projection``
        column, or more than one distinct projection.
        """
        main = self.main
        if main is None or "projection" not in main.columns:
            return None
        projections = main.projection.unique()
        if len(projections) != 1:
            return None
        return int(projections[0])

    @property
    def positions(self):
        """GeoDataFrame of borehole points (indexed like ``main``).

        Raises:
            ValueError: if there are no boreholes, or no single projection.
        """
        import geopandas as gpd

        main = self._require_main()

        projection = self.projection
        if projection is None:
            raise ValueError("SGF file has boreholes in multiple projections, or projection not specified.")

        positions = gpd.GeoDataFrame(
            geometry=gpd.points_from_xy(main.x_coordinate, main.y_coordinate),
            index=main.index)
        return positions.set_crs(projection)

    @property
    def area(self):
        """Returns the convex hull of all borehole positions"""
        import geopandas as gpd

        positions = self.positions
        # Reuse the computed positions rather than building them twice.
        return gpd.GeoDataFrame(geometry=[positions.unary_union.convex_hull]).set_crs(positions.crs)

    @property
    def bounds(self):
        """Bounding box of ``area`` (its GeoDataFrame ``bounds``)."""
        return self.area.bounds
