"""Abstract base class for a surface water network for MODFLOW."""

import pickle

import geopandas
import numpy as np
import pandas as pd
from shapely import wkt
from shapely.geometry import LineString, Point, Polygon, box
from shapely.ops import linemerge

from swn.compat import ignore_shapely_warnings_for_object_array
from swn.core import SurfaceWaterNetwork
from swn.modflow._misc import (
    tile_series_as_frame, transform_data_to_series_or_frame
)
from swn.spatial import compare_crs, get_sindex, visible_wkt


class SwnModflowBase:
    """Abstract class for a surface water network adaptor for MODFLOW.

    Attributes
    ----------
    swn : swn.SurfaceWaterNetwork
        Instance of a SurfaceWaterNetwork.
    model : flopy.modflow.Modflow or flopy.mf6.ModflowGwf
        Reference to flopy model object.
    reaches : geopandas.GeoDataFrame
        Spatial data generated by this class.
    segments : geopandas.GeoDataFrame
        Copied from swn.segments, but with additional columns added.
    diversions : geopandas.GeoDataFrame, pd.DataFrame or None
        Copied from swn.diversions, if set/defined.
    logger : logging.Logger
        Logger to show messages.

    """

    def __init__(self, logger=None):
        """Initialise the adaptor object and its logger.

        Parameters
        ----------
        logger : logging.Logger, optional
            Logger to show messages. When omitted, one named after the
            concrete class is created with ``swn.logger.get_logger``.

        Raises
        ------
        ImportError
            If the ``flopy`` package cannot be found.
        ValueError
            If ``logger`` is given but is not a ``logging.Logger``.
        """
        import importlib.util

        class_name = self.__class__.__name__
        if importlib.util.find_spec("flopy") is None:
            raise ImportError(f"{class_name} requires flopy")
        from swn.logger import get_logger, logging

        if logger is not None and not isinstance(logger, logging.Logger):
            raise ValueError(
                f"expected 'logger' to be Logger; found {type(logger)!r}")
        self.logger = get_logger(class_name) if logger is None else logger
        self.logger.info("creating new %s object", class_name)
        # Datasets start empty; they are filled in later, e.g. by
        # from_swn_flopy or when unpickling via __setstate__.
        self.segments = None
        self.diversions = None
        self.reaches = None

    def __iter__(self):
        """Return object datasets with an iterator."""
        yield "class", self.__class__.__name__
        yield "swn", self.swn
        yield "model", self.model
        yield "segments", self.segments
        yield "diversions", self.diversions
        yield "reaches", self.reaches

    def __setstate__(self, state):
        """Set object attributes from pickle loads."""
        self.__init__()
        if not isinstance(state, dict):
            raise ValueError(f"expected 'dict'; found {type(state)!r}")
        state_class = state.get("class")
        if state_class != self.__class__.__name__:
            raise ValueError("expected state class {!r}; found {!r}"
                             .format(self.__class__.__name__, state_class))
        # Note: swn and model must be set outside of this method
        self.segments = state.pop("segments")
        self.diversions = state.pop("diversions")
        self.reaches = state.pop("reaches")

    def __getstate__(self):
        """Serialize object attributes for pickle dumps."""
        obj = {}
        for k, v in self:
            if k in ("swn", "model"):
                continue
            obj[k] = v
        return obj

    def to_pickle(self, path, protocol=4):
        """Pickle (serialize) non-flopy object data to file.

        Parameters
        ----------
        path : str
            File path where the pickled object will be stored.
        protocol : int
            Default is 4, added in Python 3.4.

        """
        with open(path, "wb") as f:
            pickle.dump(self, f, protocol=protocol)

    @classmethod
    def from_pickle(cls, path, swn=None, model=None):
        """Read a pickled format from a file.

        Parameters
        ----------
        path : str
            File path where the pickled object will be stored.
        swn : swn.SurfaceWaterNetwork, optional
            Instance of a SurfaceWaterNetwork.
        model : flopy.modflow.Modflow or flopy.mf6.ModflowGwf, optional
            Instance of a flopy MODFLOW model.

        """
        with open(path, "rb") as f:
            obj = pickle.load(f)
        if swn is not None:
            obj.swn = swn
        if model is not None:
            obj.model = model
        return obj

    @property
    def swn(self):
        """Surface water network object.

        This propery can only be set once.

        See Also
        --------
        from_pickle : Read object from file.
        """
        try:
            return getattr(self, "_swn", None)
        except AttributeError:
            self.logger.error("swn property not set")

    @swn.setter
    def swn(self, swn):
        # Reject wrong types first, then enforce the set-once rule.
        if not isinstance(swn, SurfaceWaterNetwork):
            raise TypeError(
                "swn property must be an instance of SurfaceWaterNetwork; "
                f"found {type(swn)!r}")
        if getattr(self, "_swn", None) is not None:
            raise AttributeError("swn property can only be set once")
        self._swn = swn

    @property
    def model(self):
        """Flopy model object.

        This propery can be set more than once, but time and most grid
        properties must match. Setting this method also generates
        ``time_index`` and ``grid_cells`` attributes from the model.

        See Also
        --------
        from_pickle : Read object from file.
        """
        try:
            return getattr(self, "_model", None)
        except AttributeError:
            self.logger.error("model property not set")

    @model.setter
    def model(self, model):
        # Validate and attach a flopy model.
        #
        # On the first assignment this also builds ``_stress_df``,
        # ``time_index`` and ``grid_cells``, and caches key model properties
        # in ``_modelcache``. Later assignments must match the cached
        # properties, otherwise AttributeError is raised.
        #
        # The model reference is only stored after all checks pass, so a
        # rejected model never replaces the current one.
        import flopy

        if model is None:
            self.logger.info("unsetting model property")
            self._model = None
            return

        this_class = self.__class__.__name__
        if this_class == "SwnModflow":
            if not isinstance(model, flopy.modflow.Modflow):
                raise ValueError(
                    "model must be a flopy Modflow object; "
                    f"found {type(model)!r}")
            elif not model.has_package("DIS"):
                raise ValueError("DIS package required")
            elif not model.has_package("BAS6"):
                raise ValueError("BAS6 package required")
        elif this_class == "SwnMf6":
            if not (isinstance(model, flopy.mf6.mfmodel.MFModel)):
                raise ValueError(
                    "model must be a flopy.mf6.MFModel object; found " +
                    str(type(model)))
            sim = model.simulation
            if "tdis" not in sim.package_key_dict.keys():
                raise ValueError("TDIS package required")
            if "dis" not in model.package_type_dict.keys():
                raise ValueError("DIS package required")
            if not model.dis.idomain.has_data():
                raise ValueError("DIS idomain has no data")
        else:
            raise NotImplementedError(this_class)
        _model = getattr(self, "_model", None)
        if _model is model:
            self.logger.info("model is same object, checking other metadata")
        elif _model is not None:
            self.logger.info("swapping model object, checking other metadata")

        # Gather properties used by the model cache to decide whether the
        # rest needs to be evaluated
        dis = model.dis
        modelgrid = model.modelgrid
        modeltime = model.modeltime
        if this_class == "SwnModflow":
            domain_label = "ibound"
            domain = model.bas6.ibound[0].array.copy()
            perlen = pd.Series(dis.perlen.array)
        else:  # SwnMf6; any other class was rejected above
            domain_label = "idomain"
            domain = dis.idomain.array[0].copy()
            nper = sim.tdis.nper.data
            perlen = pd.Series(sim.tdis.perioddata.array.perlen)
            if len(perlen) == 1 and nper > 1:
                # expand a single perlen value to all nper stress periods
                perlen = perlen.repeat(nper).reset_index(drop=True)
        modelcache = {
            "perlen": np.array(perlen).tobytes(),
            "time_units": str(modeltime.time_units),
            "domain": domain.tobytes(),
            "modelgrid": str(modelgrid),
            "delr": modelgrid.delr.tobytes(),
            "delc": modelgrid.delc.tobytes(),
        }
        prev_modelcache = getattr(self, "_modelcache", None)
        if prev_modelcache is not None:
            changed = [
                key for key, value in modelcache.items()
                if prev_modelcache.get(key) != value]
            for key in changed:
                self.logger.debug("model setter: %s is different", key)
            if changed:
                raise AttributeError(
                    "model spatial and/or temporal properties are too "
                    "different")
            self.logger.info(
                "model properties are the same, no update required")
            self._model = model
            return

        # Determine which CRS to use (validation happens before any
        # attributes are written, so a failure leaves no partial state)
        self_crs = getattr(self, "crs", None)
        epsg = modelgrid.epsg
        proj4_str = modelgrid.proj4
        if epsg is not None:
            segments_crs, modelgrid_crs, same = compare_crs(self_crs, epsg)
        else:
            segments_crs, modelgrid_crs, same = compare_crs(self_crs,
                                                            proj4_str)
        if (segments_crs is not None and modelgrid_crs is not None and
                not same):
            self.logger.warning(
                "CRS for modelgrid is different: %s vs. %s",
                segments_crs, modelgrid_crs)
        crs = segments_crs or modelgrid_crs
        if getattr(self, "segments", None) is not None:
            # Make sure their extents overlap
            minx, maxx, miny, maxy = modelgrid.extent
            model_bbox = box(minx, miny, maxx, maxy)
            rstats = self.segments.bounds.describe()
            segments_bbox = box(
                    rstats.loc["min", "minx"], rstats.loc["min", "miny"],
                    rstats.loc["max", "maxx"], rstats.loc["max", "maxy"])
            if model_bbox.disjoint(segments_bbox):
                raise ValueError(
                    "modelgrid extent does not cover segments extent")

        # Build stress period DataFrame from modflow model
        stress_df = pd.DataFrame({"perlen": perlen})
        # pd.to_timedelta aligns on perlen's index, which is stress_df's
        stress_df["duration"] = pd.to_timedelta(
            perlen, unit=modeltime.time_units)
        end_time = stress_df["duration"].cumsum()
        stress_df["start_time"] = end_time - stress_df["duration"]
        stress_df["end_time"] = end_time
        model_start_date = pd.to_datetime(modeltime.start_datetime)
        stress_df["start_date"] = model_start_date + stress_df["start_time"]
        stress_df["end_date"] = model_start_date + end_time
        self._stress_df = stress_df  # keep this for debugging
        self.time_index = pd.DatetimeIndex(
            stress_df["start_date"], freq="infer")
        self.time_index.name = None

        # Build polygons for each model grid cell of the top layer
        self.logger.debug("building model grid cell geometries")
        nrow, ncol = domain.shape
        cols, rows = np.meshgrid(np.arange(ncol), np.arange(nrow))
        i = rows.flatten()
        j = cols.flatten()
        grid_df = pd.DataFrame({"i": i, "j": j})
        grid_df.set_index(["i", "j"], inplace=True)
        grid_df[domain_label] = domain.flatten()
        # Note: modelgrid.get_cell_vertices(i, j) is slow!
        xv = modelgrid.xvertices
        yv = modelgrid.yvertices
        # Vertex rings in the order: (i, j), (i, j+1), (i+1, j+1), (i+1, j)
        cell_verts = zip(
            zip(xv[i, j], yv[i, j]),
            zip(xv[i, j + 1], yv[i, j + 1]),
            zip(xv[i + 1, j + 1], yv[i + 1, j + 1]),
            zip(xv[i + 1, j], yv[i + 1, j])
        )
        # Add dataframe of model grid cells to object
        self.grid_cells = geopandas.GeoDataFrame(
            grid_df, geometry=[Polygon(r) for r in cell_verts], crs=crs)

        # All checks passed: attach the model and keep the cache for next time
        self._model = model
        self._modelcache = modelcache

    @classmethod
    def from_swn_flopy(
            cls, swn, model, domain_action="freeze",
            reach_include_fraction=0.2):
        """Create a MODFLOW structure from a surface water network.

        Parameters
        ----------
        swn : swn.SurfaceWaterNetwork
            Instance of a SurfaceWaterNetwork.
        model : flopy.modflow.Modflow or flopy.mf6.ModflowGwf
            Instance of a flopy MODFLOW groundwater flow model.
        domain_action : str, optional
            Action to handle IBOUND or IDOMAIN:
                - ``freeze`` : Freeze domain, but clip streams to fit bounds.
                - ``modify`` : Modify domain to fit streams, where possible.
        reach_include_fraction : float or pandas.Series, optional
            Fraction of cell size used as a threshold distance to determine if
            reaches outside the active grid should be included to a cell.
            Based on the furthest distance of the line and cell geometries.
            Default 0.2 (e.g. for a 100 m grid cell, this is 20 m).

        Returns
        -------
        obj
        """
        this_class = cls.__name__
        if this_class == "SwnModflow":
            domain_label = "ibound"
        elif this_class == "SwnMf6":
            domain_label = "idomain"
        else:
            raise TypeError(f"unsupported subclass {cls!r}")
        if not isinstance(swn, SurfaceWaterNetwork):
            raise ValueError("swn must be a SurfaceWaterNetwork object")
        elif domain_action not in ("freeze", "modify"):
            raise ValueError("domain_action must be one of freeze or modify")
        obj = cls()

        # Assume CRS from swn.segments
        obj.crs = getattr(swn.segments.geometry, "crs", None)
        # Attach a few things to the fresh object
        obj.segments = swn.segments.copy()
        obj.model = model
        obj._swn = swn
        # Copy grid_cells generated from 'model' setter
        dis = model.dis
        grid_cells = obj.grid_cells.copy()
        if domain_action == "freeze":
            sel = grid_cells[domain_label] != 0
            if sel.any():
                # Remove any inactive grid cells from analysis
                grid_cells = grid_cells.loc[sel]
        num_domain_modified = 0
        if this_class == "SwnModflow":
            domain_label = "ibound"
            domain = model.bas6.ibound[0].array.copy()
        elif this_class == "SwnMf6":
            domain_label = "idomain"
            domain = dis.idomain.array[0].copy()
        else:
            raise TypeError(f"unsupported subclass {cls!r}")

        # Determine grid cell size
        col_size = np.median(dis.delr.array)
        if dis.delr.array.min() != dis.delr.array.max():
            obj.logger.warning(
                "assuming constant column spacing %s", col_size)
        row_size = np.median(dis.delc.array)
        if dis.delc.array.min() != dis.delc.array.max():
            obj.logger.warning(
                "assuming constant row spacing %s", row_size)
        cell_size = (row_size + col_size) / 2.0

        # Break up source segments according to the model grid definition
        obj.logger.debug("evaluating reach data on model grid")
        grid_sindex = get_sindex(grid_cells)
        reach_include = swn.segments_series(reach_include_fraction) * cell_size
        # Make an empty DataFrame for reaches
        obj.reaches = pd.DataFrame(columns=["geometry"])
        obj.reaches.insert(1, column="i", value=pd.Series(dtype=int))
        obj.reaches.insert(2, column="j", value=pd.Series(dtype=int))
        empty_reach_df = obj.reaches.copy()  # take this before more added
        obj.reaches.insert(
            1, column="segnum",
            value=pd.Series(dtype=obj.segments.index.dtype))
        obj.reaches.insert(2, column="segndist", value=pd.Series(dtype=float))
        empty_reach_df.insert(3, column="length", value=pd.Series(dtype=float))
        empty_reach_df.insert(4, column="moved", value=pd.Series(dtype=bool))

        # recursive helper function
        def append_reach_df(df, i, j, reach_geom, moved=False):
            if reach_geom.geom_type == "LineString":
                reach_d = {
                    "geometry": reach_geom,
                    "i": i,
                    "j": j,
                    "length": reach_geom.length,
                    "moved": moved,
                }
                with ignore_shapely_warnings_for_object_array():
                    df.loc[len(df.index)] = reach_d
            elif reach_geom.geom_type.startswith("Multi"):
                for sub_reach_geom in reach_geom.geoms:  # recurse
                    append_reach_df(df, i, j, sub_reach_geom, moved)
            else:
                raise NotImplementedError(reach_geom.geom_type)

        # helper function that returns early, if necessary
        def assign_short_reach(reach_df, idx, segnum):
            reach = reach_df.loc[idx]
            reach_geom = reach["geometry"]
            threshold = reach_include[segnum]
            if reach_geom.length > threshold:
                return
            cell_lengths = reach_df.groupby(["i", "j"])["length"].sum()
            this_ij = reach["i"], reach["j"]
            this_cell_length = cell_lengths[this_ij]
            if this_cell_length > threshold:
                return
            grid_geom = grid_cells.at[this_ij, "geometry"]
            # determine if it is crossing the grid once or twice
            grid_points = reach_geom.intersection(grid_geom.exterior)
            split_short = (
                grid_points.geom_type == "Point" or
                (grid_points.geom_type == "MultiPoint" and
                 len(grid_points.geoms) == 2))
            if not split_short:
                return
            matches = []
            # sequence scan on reach_df
            for item in reach_df.itertuples():
                if item[0] == idx or item.moved:
                    continue
                other_cell_length = cell_lengths[item.i, item.j]
                if (item.geometry.distance(reach_geom) < 1e-6 and
                        this_cell_length < other_cell_length):
                    matches.append((item[0], item.geometry))
            if len(matches) == 0:
                # don't merge, e.g. reach does not connect to adjacent cell
                pass
            elif len(matches) == 1:
                # short segment is in one other cell only
                # update new i and j values, keep geometry as it is
                ij1 = tuple(reach_df.loc[matches[0][0], ["i", "j"]])
                reach_df.loc[idx, ["i", "j", "moved"]] = ij1 + (True,)
                # self.logger.debug(
                #    "moved short segment of %s from %s to %s",
                #    segnum, this_ij, ij1)
            elif len(matches) == 2:
                assert grid_points.geom_type == "MultiPoint", grid_points.wkt
                if len(grid_points.geoms) != 2:
                    obj.logger.critical(
                        "expected 2 points, found %s", len(grid_points.geoms))
                # Build a tiny DataFrame of coordinates for this reach
                pts = [Point(c) for c in reach_geom.coords[:]]
                with ignore_shapely_warnings_for_object_array():
                    reach_c = pd.DataFrame({"pt": pts}, dtype=object)
                if len(reach_c) == 2:
                    # If this is a simple line with two coords, split it
                    reach_c.index = [0, 2]
                    ipt = reach_geom.interpolate(0.5, normalized=True)
                    with ignore_shapely_warnings_for_object_array():
                        reach_c.loc[1] = pd.Series({"pt": ipt}, dtype=object)
                    reach_c.sort_index(inplace=True)
                    reach_geom = LineString(list(reach_c["pt"]))  # rebuild
                # first match assumed to be touching the start of the line
                if reach_c.at[0, "pt"].distance(matches[1][1]) < 1e-6:
                    matches.reverse()
                reach_c["d1"] = reach_c["pt"].apply(
                                lambda p: p.distance(matches[0][1]))
                reach_c["d2"] = reach_c["pt"].apply(
                                lambda p: p.distance(matches[1][1]))
                reach_c["dm"] = reach_c[["d1", "d2"]].min(1)
                # try a simple split where distances switch
                ds = reach_c["d1"] < reach_c["d2"]
                cidx = ds[ds].index[-1]
                # ensure it's not the index of either end
                if cidx == 0:
                    cidx = 1
                elif cidx == len(reach_c) - 1:
                    cidx = len(reach_c) - 2
                i1, j1 = list(reach_df.loc[matches[0][0], ["i", "j"]])
                reach_geom1 = LineString(reach_geom.coords[:(cidx + 1)])
                i2, j2 = list(reach_df.loc[matches[1][0], ["i", "j"]])
                reach_geom2 = LineString(reach_geom.coords[cidx:])
                # update the first, append the second
                reach_df.loc[idx, ["i", "j", "length", "moved"]] = \
                    (i1, j1, reach_geom1.length, True)
                reach_df.at[idx, "geometry"] = reach_geom1
                append_reach_df(reach_df, i2, j2, reach_geom2, moved=True)
                # self.logger.debug(
                #   "split and moved short segment of %s from %s to %s and %s",
                #   segnum, this_ij, (i1, j1), (i2, j2))
            else:
                obj.logger.critical(
                    "unhandled assign_short_reach case with %d matches: %s\n"
                    "%s\n%s", len(matches), matches, reach, grid_points.wkt)

        def assign_remaining_reach(reach_df, segnum, rem):
            if rem.geom_type == "LineString":
                threshold = cell_size * 2.0
                if rem.length > threshold:
                    obj.logger.debug(
                        "remaining line segment from %s too long to merge "
                        "(%.1f > %.1f)", segnum, rem.length, threshold)
                    return
                # search full grid for other cells that could match
                if grid_sindex:
                    bbox_match = sorted(grid_sindex.intersection(rem.bounds))
                    sub = grid_cells.geometry.iloc[bbox_match]
                else:  # slow scan of all cells
                    sub = grid_cells.geometry
                assert len(sub) > 0, len(sub)
                matches = []
                for (i, j), grid_geom in sub.iteritems():
                    if grid_geom.touches(rem):
                        matches.append((i, j, grid_geom))
                if len(matches) == 0:
                    return
                threshold = reach_include[segnum]
                # Build a tiny DataFrame for just the remaining coordinates
                pts = [Point(c) for c in rem.coords[:]]
                with ignore_shapely_warnings_for_object_array():
                    rem_c = pd.DataFrame({"pt": pts}, dtype=object)
                if len(matches) == 1:  # merge it with adjacent cell
                    i, j, grid_geom = matches[0]
                    mdist = rem_c["pt"].apply(
                                    lambda p: grid_geom.distance(p)).max()
                    if mdist > threshold:
                        obj.logger.debug(
                            "remaining line segment from %s too far away to "
                            "merge (%.1f > %.1f)", segnum, mdist, threshold)
                        return
                    append_reach_df(reach_df, i, j, rem, moved=True)
                elif len(matches) == 2:  # complex: need to split it
                    if len(rem_c) == 2:
                        # If this is a simple line with two coords, split it
                        rem_c.index = [0, 2]
                        rem_c.loc[1] = pd.Series({
                            "pt": rem.interpolate(0.5, normalized=True)})
                        rem_c.sort_index(inplace=True)
                        rem = LineString(list(rem_c["pt"]))  # rebuild
                    # first match assumed to be touching the start of the line
                    if rem_c.at[0, "pt"].touches(matches[1][2]):
                        matches.reverse()
                    rem_c["d1"] = rem_c["pt"].apply(
                                    lambda p: p.distance(matches[0][2]))
                    rem_c["d2"] = rem_c["pt"].apply(
                                    lambda p: p.distance(matches[1][2]))
                    rem_c["dm"] = rem_c[["d1", "d2"]].min(1)
                    mdist = rem_c["dm"].max()
                    if mdist > threshold:
                        obj.logger.debug(
                            "remaining line segment from %s too far away to "
                            "merge (%.1f > %.1f)", segnum, mdist, threshold)
                        return
                    # try a simple split where distances switch
                    ds = rem_c["d1"] < rem_c["d2"]
                    cidx = ds[ds].index[-1]
                    # ensure it's not the index of either end
                    if cidx == 0:
                        cidx = 1
                    elif cidx == len(rem_c) - 1:
                        cidx = len(rem_c) - 2
                    i1, j1 = matches[0][0:2]
                    rem1 = LineString(rem.coords[:(cidx + 1)])
                    append_reach_df(reach_df, i1, j1, rem1, moved=True)
                    i2, j2 = matches[1][0:2]
                    rem2 = LineString(rem.coords[cidx:])
                    append_reach_df(reach_df, i2, j2, rem2, moved=True)
                else:
                    obj.logger.critical(
                        "how does this happen? Segments from %d touching %d "
                        "grid cells", segnum, len(matches))
            elif rem.geom_type.startswith("Multi"):
                for sub_rem_geom in rem.geoms:  # recurse
                    assign_remaining_reach(reach_df, segnum, sub_rem_geom)
            else:
                raise NotImplementedError(rem.geom_type)

        def do_linemerge(ij, df, drop_reach_ids):
            geom = linemerge(df["geometry"])
            if geom.geom_type == "MultiLineString":
                # workaround for odd floating point issue
                geom = linemerge(
                    [visible_wkt(g) for g in df["geometry"]])
            if geom.geom_type == "LineString":
                drop_reach_ids += list(df.index)
                obj.logger.debug(
                    "merging %d reaches for segnum %s at %s",
                    len(df), segnum, ij)
                i, j = ij
                append_reach_df(reach_df, i, j, geom)
            elif geom.geom_type == "MultiLineString":
                for part in geom.geoms:
                    part_covers = df.geometry.apply(part.covers)
                    if part_covers.sum() > 1:  # recurse
                        do_linemerge(ij, df[part_covers], drop_reach_ids)
                    elif part_covers.sum() == 0:
                        obj.logger.warning(
                            "part %s does not cover any segnum %s at %s",
                            part, segnum, ij)
            else:
                obj.logger.warning(
                    "failed to merge segnum %s at %s: %s", segnum, ij, geom)

        # Looping over each segment breaking down into reaches
        for segnum, line in obj.segments.geometry.iteritems():
            remaining_line = line
            if grid_sindex:
                bbox_match = sorted(grid_sindex.intersection(line.bounds))
                if not bbox_match:
                    continue
                sub = grid_cells.geometry.iloc[bbox_match]
            else:  # slow scan of all cells
                sub = grid_cells.geometry
            # Find all intersections between segment and grid cells
            reach_df = empty_reach_df.copy()
            for (i, j), grid_geom in sub.iteritems():
                reach_geom = grid_geom.intersection(line)
                if reach_geom.is_empty or reach_geom.geom_type == "Point":
                    continue
                # erase some odd floating point issues
                reach_geom = visible_wkt(reach_geom)
                remaining_line = remaining_line.difference(grid_geom)
                append_reach_df(reach_df, i, j, reach_geom)
            # Determine if any remaining portions of the line can be used
            if line is not remaining_line and remaining_line.length > 0:
                assign_remaining_reach(reach_df, segnum, remaining_line)
            # Reassign short reaches to two or more adjacent grid cells
            # starting with the shortest reach
            reach_lengths = reach_df["length"].loc[
                reach_df["length"] < reach_include[segnum]]
            for idx in list(reach_lengths.sort_values().index):
                assign_short_reach(reach_df, idx, segnum)
            # Potentially merge a few reaches for each i,j of this segnum
            drop_reach_ids = []
            for ij, gb in reach_df.copy().groupby(["i", "j"]):
                if len(gb) > 1:
                    gb["geometry"] = gb["geometry"].apply(visible_wkt)
                    do_linemerge(ij, gb, drop_reach_ids)
            if drop_reach_ids:
                reach_df.drop(drop_reach_ids, axis=0, inplace=True)
            # TODO: Some reaches match multiple cells if they share a border
            # Add all reaches for this segment
            for reach in reach_df.itertuples():
                i = reach.i
                j = reach.j
                reach_geom = reach.geometry
                if line.has_z:
                    # intersection(line) does not preserve Z coords,
                    # but line.interpolate(d) works as expected
                    reach_geom = LineString(line.interpolate(
                        line.project(Point(c))) for c in reach_geom.coords)
                # Get a point from the middle of the reach_geom
                reach_mid_pt = reach_geom.interpolate(0.5, normalized=True)
                reach_record = {
                    "geometry": reach_geom,
                    "segnum": segnum,
                    "segndist": line.project(reach_mid_pt, normalized=True),
                    "i": i,
                    "j": j,
                }
                with ignore_shapely_warnings_for_object_array():
                    obj.reaches.loc[len(obj.reaches.index)] = reach_record
                if domain_action == "modify" and domain[i, j] == 0:
                    num_domain_modified += 1
                    domain[i, j] = 1

        if domain_action == "modify":
            if num_domain_modified:
                obj.logger.debug(
                    "updating %d cells from %s array for top layer",
                    num_domain_modified, domain_label.upper())
                if domain_label == "ibound":
                    obj.model.bas6.ibound[0] = domain
                elif domain_label == "idomain":
                    obj.model.dis.idomain.set_data(domain, layer=0)
                obj.reaches = obj.reaches.merge(
                    grid_cells[[domain_label]],
                    left_on=["i", "j"], right_index=True)
                obj.reaches.rename(
                    columns={domain_label: f"prev_{domain_label}"},
                    inplace=True)
            else:
                obj.reaches[f"prev_{domain_label}"] = 1

        # Mark segments that are not used
        obj.segments["in_model"] = True
        outside_model = \
            set(swn.segments.index).difference(obj.reaches["segnum"])
        obj.segments.loc[list(outside_model), "in_model"] = False

        # Evaluate inflow segments that potentially receive flow from outside
        segnums_outside = set(obj.segments[~obj.segments["in_model"]].index)
        if segnums_outside:
            obj.logger.debug(
                "evaluating inflow connections from outside network")
            obj.segments["inflow_segnums"] = obj.segments.from_segnums.apply(
                lambda x: x.intersection(segnums_outside))

        # Consider diversions or SW takes, add more reaches
        has_diversions = swn.diversions is not None
        if has_diversions:
            obj.diversions = swn.diversions.copy()
            obj.reaches["diversion"] = False
            obj.reaches["divid"] = obj.diversions.index.dtype.type()
            # Mark diversions that are not used / outside model
            obj.diversions["in_model"] = True
            outside_model = []
            segnum_s = set(obj.reaches.segnum)
            for divid, from_segnum in obj.diversions.from_segnum.iteritems():
                if from_segnum not in segnum_s:
                    # segnum does not exist -- segment is outside model
                    outside_model.append(divid)
            if outside_model:
                obj.diversions.loc[list(outside_model), "in_model"] = False
                obj.logger.debug(
                    "added %d diversions, ignoring %d that did not connect to "
                    "existing segments",
                    obj.diversions["in_model"].sum(), len(outside_model))
            else:
                obj.logger.debug(
                    "added all %d diversions", len(obj.diversions))
            if swn.has_z:
                empty_geom = wkt.loads("linestring z empty")
            else:
                empty_geom = wkt.loads("linestring empty")
            diversions_in_model = obj.diversions[obj.diversions.in_model]
            is_spatial = (
                isinstance(obj.diversions, geopandas.GeoDataFrame) and
                "geometry" in obj.diversions.columns and
                (~diversions_in_model.is_empty).all())
            for divn in diversions_in_model.itertuples():
                # Use the last upstream reach as a template for a new reach
                reach_d = dict(obj.reaches.loc[
                    obj.reaches.segnum == divn.from_segnum].iloc[-1])
                reach_d.update({
                    "segnum": swn.END_SEGNUM,
                    "segndist": 0.0,
                    "diversion": True,
                    "divid": divn.Index,
                    "geometry": empty_geom,
                })
                # Assign one reach at grid cell
                if is_spatial:
                    # Find grid cell nearest to diversion
                    grid_cells = obj.grid_cells
                    grid_sindex = get_sindex(grid_cells)
                    if grid_sindex:
                        bbox_match = sorted(
                            grid_sindex.nearest(divn.geometry.bounds))
                        # more than one nearest can exist! just take one...
                        num_found = len(bbox_match)
                        grid_cell = grid_cells.iloc[bbox_match[0]]
                    else:  # slow scan of all cells
                        sel = grid_cells.intersects(divn.geometry)
                        num_found = sel.sum()
                        grid_cell = grid_cells.loc[sel].iloc[0]
                    if num_found > 1:
                        obj.logger.warning(
                            "%d grid cells are nearest to diversion %r, "
                            "but only taking the first %s",
                            num_found, divn.Index, grid_cell)
                    i, j = grid_cell.name
                    reach_d.update({"i": i, "j": j})
                    if not divn.geometry.is_empty:
                        with ignore_shapely_warnings_for_object_array():
                            reach_d["geometry"] = divn.geometry
                with ignore_shapely_warnings_for_object_array():
                    obj.reaches.loc[len(obj.reaches) + 1] = reach_d
        else:
            obj.diversions = None

        # Insert k=0, as it is assumed all reaches are on the top layer
        obj.reaches.insert(
            list(obj.reaches.columns).index("i"), column="k", value=0)

        # Now convert from DataFrame to GeoDataFrame
        obj.reaches = geopandas.GeoDataFrame(
                obj.reaches, geometry="geometry", crs=obj.crs)

        # Add information to reaches from segments
        obj.reaches = obj.reaches.merge(
            obj.segments[["sequence"]], "left",
            left_on="segnum", right_index=True)
        # TODO: how to sequence diversions (divid)?
        obj.reaches.sort_values(["sequence", "segndist"], inplace=True)
        del obj.reaches["sequence"]  # segment sequence not used anymore
        # keep "segndist" for interpolation from segment data

        # Add classic ISEG and IREACH, counting from 1
        obj.reaches["iseg"] = 0
        obj.reaches["ireach"] = 0
        iseg = ireach = 0
        prev_segnum = None
        for idx, segnum in obj.reaches["segnum"].iteritems():
            if has_diversions and obj.reaches.at[idx, "diversion"]:
                # Each diversion gets a new segment/reach
                iseg += 1
                ireach = 0
            elif segnum != prev_segnum:
                # Start of a regular segment/reach
                iseg += 1
                ireach = 0
            ireach += 1
            obj.reaches.at[idx, "iseg"] = iseg
            obj.reaches.at[idx, "ireach"] = ireach
            prev_segnum = segnum

        obj.reaches.reset_index(inplace=True, drop=True)
        obj.reaches.index += 1  # flopy series starts at one

        if not hasattr(obj.reaches.geometry, "geom_type"):
            # workaround needed for reaches.to_file()
            obj.reaches.geometry.geom_type = obj.reaches.geom_type

        # each subclass should do more processing with returned object
        return obj

    def set_reach_data_from_segments(
            self, name, value, value_out=None, method=None, log=False):
        """Set reach data based on segment series (or scalar).

        Parameters
        ----------
        name : str
            Name for reach dataset, added to the reaches data frame.
        value : scalar, list, dict or pandas.Series
            Value to assign to the upstream or top end of each segment.
            See :py:meth:`SurfaceWaterNetwork.pair_segments_frame` for details.
        value_out : None (default), scalar, dict or pandas.Series
            If None (default), the value used for the bottom is determined
            using ``method``. This option is normally specified for outlets.
            See :py:meth:`SurfaceWaterNetwork.pair_segments_frame` for details.
        method : str, default None
            This option determines how ``value_out`` values should be
            determined, if not specified. Choose one of:
              - None (default): automatically determine method. If value is
                float-like or ``log=True``, choose ``continuous`` with
                interpolation (if necessary) along reaches, otherwise
                ``constant`` without any interpolation.
              - ``continuous`` : downstream value is evaluated to be
                the same as the upstream value it connects to. This allows a
                continuous network of values along the networks, such as
                elevation. Values for each reach are linearly interpolated.
              - ``constant`` : ``value`` from each segment is used for all
                reaches. No interpolation is performed along reaches.
              - ``additive`` : downstream value is evaluated to be a fraction
                of tributaries that add to the upstream value it connects to.
                Proportions of values for each tributary are preserved, but
                lengths of segments are ignored. Segments with
                different ``value`` and ``value_out`` use interpolation
                along reaches.
        log : bool, default False
            If True and ``method`` is not "constant", apply a log
            transformation applied to interpolation.

        Raises
        ------
        ValueError
            If ``name`` is not a str.
        """
        if not isinstance(name, str):
            raise ValueError("name must be a str type")
        if np.isscalar(value) and value_out is None:
            # A single value applies to every reach; nothing to interpolate
            self.logger.debug(
                "set_reach_data_from_segments: setting scalar %s = %s",
                name, value)
            self.reaches[name] = value
            return
        if method is None:
            if log:
                method = "continuous"
            else:
                value = self._swn.segments_series(value)
                # pandas' check also copes with extension dtypes (e.g.
                # categorical or nullable Float64), which np.issubdtype
                # cannot interpret and would raise TypeError for
                if pd.api.types.is_float_dtype(value):
                    method = "continuous"
                else:
                    method = "constant"
            self.logger.debug(
                "set_reach_data_from_segments: choosing method=%r", method)
        segdat = self._swn.pair_segments_frame(value, value_out, method=method)
        # c1: value at the upstream end, c2: value at the downstream end
        c1, c2 = segdat.columns
        res = self.reaches[["segnum"]].join(segdat[c1], on="segnum")
        self.reaches[name] = res[c1]
        if method == "constant":
            return
        # Determine which segments need to be interpolated
        interp = segdat[c1] != segdat[c2]
        self.logger.debug(
            "set_reach_data_from_segments: interpolating %d segments",
            interp.sum())
        if log:
            segdat = np.log10(segdat)
        for item in segdat[interp].itertuples():
            # item[0] is segnum, item[1] upstream value, item[2] downstream
            sel = self.reaches["segnum"] == item[0]
            # interpolate to mid points of each reach from segment data
            segndist = self.reaches.loc[sel, "segndist"]
            value = (item[2] - item[1]) * segndist + item[1]
            if log:
                value = 10 ** value
            self.reaches.loc[sel, name] = value

    def set_reach_data_from_array(self, name, array):
        """Copy values from a model-sized 2D array onto every reach.

        Each reach receives the array element at its own (i, j) cell.

        Parameters
        ----------
        name : str
            Column of the reaches frame that receives the values.
        array : array_like
            Object with ``ndim`` and ``shape``, of shape (nrow, ncol).

        Raises
        ------
        ValueError
            If ``name`` is not a str, or ``array`` is not a 2D array with
            shape (nrow, ncol).
        TypeError
            If called from a subclass other than SwnModflow or SwnMf6.
        """
        if not isinstance(name, str):
            raise ValueError("'name' must be a str type")
        if not hasattr(array, "ndim"):
            raise ValueError("'array' must be array-like")
        if array.ndim != 2:
            raise ValueError("'array' must have two dimensions")
        dis = self.model.dis
        subclass = self.__class__.__name__
        if subclass == "SwnModflow":
            nrow, ncol = dis.nrow, dis.ncol
        elif subclass == "SwnMf6":
            # flopy mf6 wraps dimensions in data objects
            nrow, ncol = dis.nrow.data, dis.ncol.data
        else:
            raise TypeError(f"unsupported subclass {subclass!r}")
        if array.shape != (nrow, ncol):
            raise ValueError("'array' must have shape (nrow, ncol)")
        rows = self.reaches["i"]
        cols = self.reaches["j"]
        self.reaches.loc[:, name] = array[rows, cols]

    def set_reach_slope(self, method: str = "auto", min_slope=1./1000):
        """Set slope for reaches.

        This method also adds/updates several attributes for reaches.
        The actual data is stored in "slope" for SwnModflow or
        "rgrd" for SwnMf6 classes. A "min_slope" column is always added;
        the "zcoord_ab" method also adds "zcoord_count", "zcoord_min",
        "zcoord_avg", "zcoord_max", "zcoord_first" and "zcoord_last".

        Parameters
        ----------
        method: str, default "auto"
            Method used to evaluate reach slope.
            - "auto": automatically determine method; "zcoord_ab" if the
              surface water network has Z information, otherwise "grid_top".
            - "zcoord_ab": if surface water network has Z information,
              use the start/end elevations to determine elevation drop.
            - "grid_top": evaluate the slope from the top grid of the model.
            - "rch_len": evaluate the elevation change across each cell from
              the top grid of the model, then divide by the reach length.
        min_slope : float or pandas.Series, optional
            Minimum downwards slope imposed on segments. If float, then this is
            a global value, otherwise it is per-segment with a Series.
            Default 1./1000 (or 0.001). Diversions (if present) will use the
            minimum of series.

        Raises
        ------
        ValueError
            If ``method`` is not supported, or "zcoord_ab" is requested
            but the surface water network has no Z coordinates.
        TypeError
            If called from a subclass other than SwnModflow or SwnMf6.
        """
        has_z = self._swn.has_z
        supported_methods = ["auto", "zcoord_ab", "grid_top", "rch_len"]
        if method not in supported_methods:
            raise ValueError(f"{method} not in {supported_methods}")
        if method == "auto":
            method = "zcoord_ab" if has_z else "grid_top"
        if method == "zcoord_ab" and not has_z:
            raise ValueError(
                f"method {method} requested, but surface water network "
                "does not contain Z coordinates")
        class_name = self.__class__.__name__
        if class_name == "SwnModflow":
            grid_name = "slope"
            lentag = "rchlen"
        elif class_name == "SwnMf6":
            grid_name = "rgrd"
            lentag = "rlen"
        else:
            raise TypeError(f"unsupported subclass {class_name!r}")

        self.logger.debug(
            "setting reaches['%s'] with %s method", grid_name, method)
        rchs = self.reaches
        rchs["min_slope"] = np.nan
        self.set_reach_data_from_segments("min_slope", min_slope)
        # with diversions, these reaches will be NaN, so set to min
        sel = rchs.min_slope.isna()
        if sel.any():
            rchs.loc[sel, "min_slope"] = rchs.min_slope[~sel].min()
        rchs[grid_name] = 0.0
        if method == "zcoord_ab":
            def get_zcoords(g):
                # Collect Z values of all vertices, in coordinate order
                if g.is_empty or not g.has_z:
                    return []
                elif g.geom_type == "LineString":
                    return [c[2] for c in g.coords[:]]
                elif g.geom_type == "Point":
                    return [g.z]
                elif g.geom_type.startswith("Multi"):
                    # recurse and flatten
                    t = [get_zcoords(sg) for sg in g.geoms]
                    return [item for slist in t for item in slist]
                else:
                    return []

            zcoords = rchs.geometry.apply(get_zcoords)
            rchs["zcoord_count"] = zcoords.apply(len)
            sel = rchs["zcoord_count"] > 0
            if not sel.any():
                self.logger.error(
                    "no reaches selected to determine slope, either because "
                    "they are not LineString or are EMPTY")
            rchs.loc[sel, "zcoord_min"] = zcoords[sel].apply(min)
            rchs.loc[sel, "zcoord_avg"] = \
                zcoords[sel].apply(sum) / rchs.loc[sel, "zcoord_count"]
            rchs.loc[sel, "zcoord_max"] = zcoords[sel].apply(max)
            rchs.loc[sel, "zcoord_first"] = zcoords[sel].apply(lambda z: z[0])
            rchs.loc[sel, "zcoord_last"] = zcoords[sel].apply(lambda z: z[-1])
            # Calculate gradient based on first/last coordinate
            rchs.loc[sel, grid_name] = (
                (rchs.loc[sel, "zcoord_first"] -
                 rchs.loc[sel, "zcoord_last"]) /
                rchs.loc[sel, "geometry"].length
            )
        elif method in ("grid_top", "rch_len"):
            # Estimate slope from top and grid spacing
            dis = self.model.dis
            col_size = np.median(dis.delr.array)  # column widths (along x)
            row_size = np.median(dis.delc.array)  # row heights (along y)
            # top has shape (nrow, ncol); np.gradient returns derivatives
            # along axis 0 (rows, spaced by DELC) then axis 1 (columns,
            # spaced by DELR), so the spacings must be given in that order
            dz_dy, dz_dx = np.gradient(dis.top.array, row_size, col_size)
            if method == "grid_top":
                grid_slope = np.sqrt(dz_dx ** 2 + dz_dy ** 2)
                self.set_reach_data_from_array(grid_name, grid_slope)
            elif method == "rch_len":
                # elevation change across one cell, then divide by length
                grid_dz = np.sqrt(
                    (dz_dx * col_size) ** 2 + (dz_dy * row_size) ** 2)
                self.reaches.loc[:, grid_name] = (
                    grid_dz[self.reaches["i"], self.reaches["j"]] /
                    self.reaches[lentag])
        # Enforce min_slope when less than min_slope or is NaN
        sel = (rchs[grid_name] < rchs["min_slope"]) | rchs[grid_name].isna()
        if sel.any():
            num = sel.sum()
            self.logger.warning(
                "enforcing min_slope for %d reach%s (%.2f%%)",
                num, "" if num == 1 else "es", 100.0 * num / len(sel))
            rchs.loc[sel, grid_name] = rchs.loc[sel, "min_slope"]

    def _get_segments_inflow(self, data):
        """Sum external flows that enter the model from upstream segments.

        For every in-model segment whose ``from_segnums`` include segments
        outside the model, the matching columns of ``data`` are added
        together to form that segment's inflow. If any inflow is found, the
        sets of contributing outside segnums are written to
        ``segments["inflow_segnums"]``.

        Parameters
        ----------
        data: dict, pandas.Series or pandas.DataFrame
            Time series of flow for segnums either inside or outside model,
            indexed by segnum.

        Returns
        -------
        pandas.DataFrame or pandas.Series
            Inflow indexed by ``self.time_index`` with one column per
            receiving segnum. If ``data`` was resolved to a Series, the first
            row is returned instead, as an unnamed Series.
        """
        time_index = self.time_index
        data = transform_data_to_series_or_frame(data, float, time_index)
        as_series = isinstance(data, pd.Series)
        if as_series:
            data = tile_series_as_frame(data, time_index)
        inflow = pd.DataFrame(dtype=float, index=time_index)

        def _result():
            # Hand back the same kind of object the data resolved to
            if not as_series:
                return inflow
            first_step = inflow.iloc[0]
            first_step.name = None
            return first_step

        if not len(data.columns):
            self.logger.debug("no data used to determine inflow")
            return _result()

        segments = self.segments
        in_model = segments.in_model
        inside = set(segments[in_model].index)
        contributors = pd.Series(
            [set() for _ in segments.index], dtype=object,
            index=segments.index)
        for seg in segments.loc[in_model].itertuples():
            upstream = seg.from_segnums
            if not upstream:
                continue
            external = upstream.difference(inside)
            if not external:
                continue
            total = pd.Series(0.0, index=time_index)
            used = set()
            for from_segnum in external:
                try:
                    total += data[from_segnum]
                    used.add(from_segnum)
                except KeyError:
                    self.logger.warning(
                        "flow from segment %s not provided by inflow data "
                        "(needed for segnum %s)", from_segnum, seg.Index)
            if used:
                inflow[seg.Index] = total
                contributors.at[seg.Index] = used

        num_found = inflow.shape[1]
        if num_found == 0:
            self.logger.info("inflow not found for any segnums")
        else:
            self.logger.info(
                "inflow found for %d segnum%s",
                num_found, "" if num_found == 1 else "s")
            segments["inflow_segnums"] = contributors
        return _result()

    def plot(self, column=None, cmap="viridis_r", colorbar=False):
        """
        Show a map of reaches, marking outlets, inflows and diversions.

        Non-empty LineString reaches are coloured by ``column``. Active grid
        cells (non-zero ibound/idomain) are drawn in the background when
        available. Outlets are navy circles, inflows are green plus markers,
        and diversions are red diamonds.

        Parameters
        ----------
        column : str, optional
            Column from reaches to use with "cmap"; default None will
            colour by the reaches.index.
        cmap : str
            Matplotlib color map; default "viridis_r".
        colorbar : bool
            Show colorbar for "column"; default False.

        Returns
        -------
        AxesSubplot

        """
        import matplotlib.pyplot as plt

        _, ax = plt.subplots()
        ax.set_aspect("equal")

        # Skip reaches with EMPTY geometry; they cannot be drawn
        reaches = self.reaches[~self.reaches.is_empty]
        if column is None:
            # Expose the reach index as the first column and colour by it
            reaches = reaches.reset_index()
            column = reaches.columns[0]
        line_reaches = reaches[reaches.geom_type == "LineString"]
        line_reaches.plot(
            column=column, label="reaches", legend=colorbar, ax=ax, cmap=cmap)

        grid_cells = getattr(self, "grid_cells", None)
        if grid_cells is not None:
            domain_label = {
                "SwnModflow": "ibound",
                "SwnMf6": "idomain",
            }[self.__class__.__name__]
            active = grid_cells[domain_label] != 0
            if active.any():
                grid_cells.loc[active].plot(
                    ax=ax, color="whitesmoke", edgecolor="gainsboro")

        def endpoint(geom, idx):
            # Vertex idx of a LineString; Points pass through; else EMPTY
            if geom.geom_type == "Point":
                return geom
            if geom.geom_type == "LineString":
                return Point(geom.coords[idx])
            return Point()

        if getattr(self, "swn", None) is not None:
            outlet_reaches = reaches[reaches.segnum.isin(self._swn.outlets)]
            last_idx = outlet_reaches.groupby(["segnum"]).ireach.idxmax().values
            outlet_pts = reaches.loc[last_idx, "geometry"].apply(
                lambda geom: endpoint(geom, -1))
            outlet_pts.plot(ax=ax, label="outlet", marker="o", color="navy")

        if "inflow_segnums" in self.segments.columns:
            has_inflow = self.segments.inflow_segnums.map(len) > 0
            inflow_segnums = self.segments.index[has_inflow]
            first_reach = (
                reaches.segnum.isin(inflow_segnums) & (reaches.ireach == 1))
            inflow_pts = reaches.loc[first_reach, "geometry"].apply(
                lambda geom: endpoint(geom, 0))
            inflow_pts.plot(
                ax=ax, label="inflow", marker="P", color="green")

        if self.diversions is not None:
            diversion_pts = reaches.loc[reaches.diversion, "geometry"].apply(
                lambda geom: endpoint(geom, 0))
            diversion_pts.plot(
                ax=ax, label="diversion", marker="D", color="red")

        return ax

    # __________________ SOME ELEVATION METHODS_________________________
    def add_model_topbot_to_reaches(self):
        """
        Sample the model top and layer-1 bottom at each reach's grid cell.

        The sampled values are stored in the "top" and "bot" columns of
        the reaches frame.

        Returns
        -------
        pandas.DataFrame
            The "top" and "bot" columns of the reaches frame.
        """
        dis = self.model.dis
        surfaces = (
            ("top", dis.top.array),
            ("bot", dis.botm.array[0]),  # bottom of the first layer
        )
        for column, surface in surfaces:
            self.set_reach_data_from_array(column, surface)
        return self.reaches[["top", "bot"]]

    def plot_reaches_vs_model(
            self, seg="all", dem=None, plot_bottom=False, draw_lines=True):
        """Plot map of stream elevations relative to model surfaces.

        The elevation of the MODFLOW model projected streams relative to model
        top and layer 1 bottom. As a side effect, "top" and "bot" columns
        are (re)computed on the reaches frame, and a temporary "tmp_tdif"
        column (plus "tmp_bdif" when ``plot_bottom`` is True) is added.

        Parameters
        ----------
        seg : int or str, default "all"
            Segment number (compared against reaches["segnum"]) to plot, or
            "all" to plot every LineString reach. With a single segment, a
            profile plot of the up- and downstream trace is also drawn.
        dem : array_like, default None
            For using as plot background -- assumes same (nrow, ncol)
            dimensions as model layer
        plot_bottom : bool, default False
            Also plot stream bed elevation relative to the bottom of layer 1
        draw_lines: bool, default True
            Draw lines around SFR cells.

        Returns
        -------
        vtop, vbot : ModelPlot objects containing matplotlib fig and axes
            ``vbot`` is None unless ``plot_bottom`` is True.

        Raises
        ------
        TypeError
            If called from a subclass other than SwnModflow or SwnMf6.

        """
        model = self.model  # inherit model from class object
        class_name = self.__class__.__name__
        if class_name == "SwnModflow":
            ib = model.bas6.ibound.array[0]
            strtoptag = "strtop"
        elif class_name == "SwnMf6":
            ib = model.dis.idomain.array[0]
            strtoptag = "rtp"
        else:
            # otherwise ib/strtoptag would be unbound further down
            raise TypeError(f"unsupported subclass {class_name!r}")
        from swn.modflow._modelplot import sfr_plot

        # Ensure reach elevations are up-to-date
        self.add_model_topbot_to_reaches()  # TODO check required first
        # Plot model top (or dem on background)
        dis = model.dis
        if dem is None:
            dem = np.ma.array(dis.top.array, mask=ib == 0)
        # Build sfr raster array from reaches data; fully masked to start,
        # cells become visible only where a reach value is assigned
        sfrar = np.ma.zeros(dis.top.array.shape, "f")
        sfrar.mask = np.ones(sfrar.shape)
        if seg == "all":
            segsel = self.reaches.geom_type == "LineString"
        else:
            # TODO multiple segs?
            segsel = self.reaches["segnum"] == seg
        # Reach elevation relative to model top
        self.reaches['tmp_tdif'] = (self.reaches["top"] -
                                    self.reaches[strtoptag])
        # TODO group by ij first?
        sfrar[
            tuple(self.reaches.loc[segsel, ["i", "j"]].values.T.tolist())
        ] = self.reaches.loc[segsel, 'tmp_tdif'].tolist()
        # Plot reach elevation relative to model top
        if draw_lines:
            lines = self.reaches.loc[segsel, ["geometry", "tmp_tdif"]]
        else:
            lines = None
        vtop = sfr_plot(
            model, sfrar, dem,
            label="str below\ntop (m)",
            lines=lines,
        )
        # If just single segment can plot profile quickly
        if seg != "all":
            self.plot_profile(seg, upstream=True, downstream=True)

        # Same for bottom
        if plot_bottom:
            dembot = np.ma.array(dis.botm.array[0], mask=ib == 0)
            sfrarbot = np.ma.zeros(dis.botm.array[0].shape, "f")
            sfrarbot.mask = np.ones(sfrarbot.shape)
            self.reaches['tmp_bdif'] = (self.reaches[strtoptag] -
                                        self.reaches["bot"])
            sfrarbot[
                tuple(self.reaches.loc[segsel, ["i", "j"]].values.T.tolist())
            ] = self.reaches.loc[segsel, 'tmp_bdif'].tolist()
            if draw_lines:
                lines = self.reaches.loc[segsel, ["geometry", "tmp_bdif"]]
            else:
                lines = None
            vbot = sfr_plot(
                model, sfrarbot, dembot,
                label="str above\nbottom (m)",
                lines=lines,
            )
        else:
            vbot = None
        return vtop, vbot

    def plot_profile(self, segnum, upstream=False, downstream=False):
        """Plot stream top profiles vs model grid top and bottom.

        Reaches of the selected segments are taken in index order and
        plotted against the cumulative distance to each reach midpoint.

        Parameters
        ----------
        segnum : int
            Identifying segment number for plots.
        upstream : bool, default False
            Flag for continuing trace upstream from ``segnum``.
        downstream : bool, default False
            Flag for continuing trace downstream of ``segnum``.

        Returns
        -------
        None

        Raises
        ------
        TypeError
            If called from a subclass other than SwnModflow or SwnMf6.
        """
        class_name = self.__class__.__name__
        if class_name == "SwnModflow":
            strtoptag = "strtop"
            lentag = "rchlen"
        elif class_name == "SwnMf6":
            strtoptag = "rtp"
            lentag = "rlen"
        else:
            # otherwise strtoptag/lentag would be unbound further down
            raise TypeError(f"unsupported subclass {class_name!r}")
        from swn.modflow._modelplot import _profile_plot
        usegs = [segnum]
        dsegs = []
        if upstream:
            usegs = self._swn.query(upstream=segnum)
        if downstream:
            dsegs = self._swn.query(downstream=segnum)
        segs = usegs + dsegs
        if segnum not in segs:
            self.logger.error(
                "something has changed in the code, %s not in %s",
                segnum, segs)
        reaches = self.reaches.loc[self.reaches.segnum.isin(segs)].sort_index()
        # distance to each reach midpoint, accumulated in index order
        reaches["mid_dist"] = reaches[lentag].cumsum() - reaches[lentag] / 2.0
        _profile_plot(reaches, lentag=lentag, x="mid_dist",
                      cols=[strtoptag, "top", "bot"])
