"""Base class for gridworld environment."""

import copy
import warnings

import networkx as nx
import numpy as np

import neugym as ng
from ._agent import _Agent
from ._object import _Object

__all__ = [
    "GridWorld"
]


class GridWorld:
    r"""Base class for gridworld environment.

    ``GridWorld`` environment consists of a ``world``, ``objects``, and one ``agent``.
    The world consists of multiple connected rectangular areas and
    each area is represented by a two-dimensional grid graph,
    in which each node is connected to its four nearest neighbors.
    Each node represents a state in the world, and has an attribute
    ``altitude``. When the agent moves from one state toward another state,
    it will get a reward generated by the altitude change (if there is),
    i.e.

    .. math::

        R_{move} = A_s - A_{s + 1}

    where :math:`R_{move}` is the movement reward, and :math:`A_s` and
    :math:`A_{s + 1}` are the altitudes of the current state :math:`s`
    and the next state :math:`s + 1`.

    At each position(state), the agent can choose from 5 ``actions`` to move towards
    **UP**, **DOWN**, **LEFT**, **RIGHT**, and **STAY** in the same state.
    When the performed movement would make the agent get out of the world,
    the agent would be forced to stay in the same state.

    Objects from which the agent can get a reward are placed at different states,
    and each state can contain at most one object. Each object has its own adjustable
    probability (``prob``) of giving a ``reward`` when the agent reaches
    the state holding it; if the agent fails to get the reward,
    it receives a punishment (``punish``) instead.

    .. math::

        R_{object} =
            \left \{
                \begin{aligned}
                    & reward, P=p \\
                    & punish, P=1-p
                \end{aligned}
            \right.

    Under this situation, the total reward for this step will be
    the movement reward adding the object reward.

    .. math::

        R_{total} = R_{move} + R_{object}

    As soon as the agent reaches an object
    (regardless of whether it received the reward or the punishment),
    the trial is finished and the agent is
    sent back to its start state for the next trial.

    Parameters
    ----------
    origin_shape : tuple of ints (optional, default: None)
        Shape of the world origin. If not provided, the origin will be
        initialized to be only one state (0, 0, 0), otherwise it will
        be a rectangular area of shape ``origin_shape``.

    See Also
    --------
    DelayedRewardGridWorld
    TimeLimitedGridWorld


    Examples
    --------
    Initialize a gridworld environment with only an origin state.

    >>> W = GridWorld()

    W can be grown in several aspects.

    **Areas:**

    Add one area of shape (2, 2).

    >>> W.add_area((2, 2))

    Manually specify start and end state of inter-area path and action to register.

    >>> W.add_area((2, 2),
    ...            access_from=(0, 0, 0), access_to=(1, 1),
    ...            register_action=(-1, 0))

    Remove areas.

    >>> W.remove_area(1)

    Set area altitude.

    >>> W.add_area((3, 3))
    >>> W.set_altitude(1, altitude_mat=np.random.randn(3, 3))

    **Paths:**

    Add additional inter-area paths.

    >>> W = GridWorld()
    >>> W.add_area((2, 2))
    >>> W.add_area((2, 2))
    >>> W.add_path((1, 1, 1), (2, 1, 0), register_action=(0, 1))

    Remove paths.

    >>> W.remove_path((1, 1, 1), (2, 1, 0))

    **Objects:**

    Add objects.

    >>> W = GridWorld()
    >>> W.add_area((2, 2))
    >>> W.add_object((0, 0, 0), reward=1, prob=0.7)
    >>> W.add_object((1, 0, 0), reward=1, prob=0.3, punish=-10)

    Remove objects.

    >>> W.remove_object((1, 0, 0))

    Update object attributes.

    >>> W.update_object((0, 0, 0), reward=10)
    >>> W.update_object((0, 0, 0), reward=1, prob=0.8)

    **Agent:**

    >>> W.init_agent()

    One can also manually set the agent initial state.

    >>> W.add_area((2, 2))
    >>> W.init_agent(init_coord=(1, 1, 1))

    When the agent is initialized, the agent can move in the
    world and get rewards.

    >>> next_state, reward, done = W.step(action=(0, 1))

    **Reset:**

    To reset the environment, first set a reset checkpoint.

    >>> W.set_reset_checkpoint()

    Then the environment can be reset if needed.

    >>> W.reset()
    """

    def __init__(self, origin_shape=None):
        """Initialize a gridworld environment.

        Parameters
        ----------
        origin_shape : tuple of ints (optional, default: None)
            Shape of the world origin (area 0). If not provided, the origin
            is the single state ``(0, 0, 0)``; otherwise it is a rectangular
            grid of shape ``origin_shape`` whose states are labelled
            ``(0, x, y)``. Every origin state starts with altitude 0.

        Examples
        --------
        Initialize a gridworld environment by default.

        >>> W = GridWorld()

        Manually set origin shape.

        >>> W = GridWorld((3, 4))
        """
        # Step counter and number of non-origin areas.
        self._time = 0
        self._num_area = 0

        # Off-grid cell -> destination state for inter-area paths, the list
        # of objects, and the fixed action space (STAY, then the four moves).
        self._alias = {}
        self._objects = []
        self._actions = ((0, 0), (1, 0), (-1, 0), (0, 1), (0, -1))

        # No agent exists until init_agent() is called.
        self._agent = None

        # Reset checkpoint storage; every slot starts out as None.
        self._has_reset_checkpoint = False
        self._reset_state = dict.fromkeys(
            ("world", "time", "num_area", "alias", "objects", "agent"))

        # Build the origin area (area index 0).
        self._world = nx.Graph()
        if origin_shape is None:
            origin_shape = (1, 1)
            self._world.add_node((0, 0, 0))
        else:
            rows, cols = origin_shape
            grid = nx.grid_2d_graph(rows, cols)
            labelled = nx.relabel_nodes(
                grid, {cell: (0,) + tuple(cell) for cell in grid.nodes})
            self._world.update(labelled)

        # Flat terrain: every origin state starts at altitude 0.
        self.set_altitude(0, np.zeros(origin_shape))

    def add_area(self, shape,
                 access_from=None, access_to=None,
                 register_action=None):
        """Add a new area to the world.

        The new area gets index ``num_area + 1``. All of its states start at
        altitude 0, and it is joined to the existing world through a single
        inter-area path from ``access_from`` to ``access_to``. If any step of
        building the area fails, the world graph, the inter-area alias table
        and the area count are restored to their previous values before the
        exception propagates.

        .. note::
            When an inter-area path from ``access_from`` to ``access_to`` with
            action ``register_action`` is built, a reverse path is also built at
            the same time, i.e. the agent can also move from ``access_to`` to
            ``access_from`` with the reverse action of ``register_action`` (e.g.
            the reverse action of **UP(1, 0)** is **DOWN(-1, 0)**).

        Parameters
        ----------
        shape : tuple of ints
            Shape ``(m, n)`` of the new area. Both dimensions must be positive.

        access_from : tuple of ints (optional, default: None)
            Coordinate of one state in the world that the new area will
            connect to, i.e. the start state of the inter-area path.
            If not provided, it will be set to ``(0, 0, 0)`` by default.

        access_to : tuple of ints (optional, default: None)
            Coordinate of one state in the new area, i.e. the end state
            of the inter-area path. Tuple of length 2 is required and the
            area index will be automatically added. If not provided, it will be set
            to ``(0, 0)`` by default.

        register_action : tuple of ints (optional, default: None)
            Register an action to transport the agent from ``access_from`` to
            ``access_to``. If not provided, the first free action (in
            :attr:`actions` order) is used.

        Raises
        ------
        ValueError
            If ``access_from`` is not a state of the world, ``access_to`` is
            not of length 2, ``shape`` has a non-positive dimension, or the
            path cannot be placed (e.g. ``access_to`` lies outside the new
            area or ``register_action`` is not a legal action).
        NeuGymConnectivityError
            If the inter-area path cannot be registered (see :meth:`add_path`).

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 3))
        >>> W.add_area((2, 2))
        >>> W.add_area((3, 3), access_from=(1, 1, 2))
        >>> W.add_area((4, 4), access_from=(0, 0, 0), access_to=(3, 3))
        >>> W.add_area((2, 2),
        ...            access_from=(4, 0, 0), access_to=(1, 1),
        ...            register_action=(0, -1))
        """
        if access_from is None:
            access_from = (0, 0, 0)
        if access_to is None:
            access_to = (0, 0)

        if not self._world.has_node(access_from):
            msg = "'access_from' coordinate " \
                  "{} out of world".format(access_from)
            raise ValueError(msg)

        if len(access_to) != 2:
            msg = "Tuple of length 2 expected for " \
                  "argument 'access_to', got {}".format(len(access_to))
            raise ValueError(msg)

        rows, cols = shape
        if rows < 1 or cols < 1:
            msg = "Area shape must be positive in both " \
                  "dimensions, got {}".format(shape)
            raise ValueError(msg)

        new_idx = self._num_area + 1
        access_to = (new_idx,) + tuple(access_to)

        # Snapshot every piece of state this method mutates so that a failure
        # anywhere below (including inside add_path, which writes aliases)
        # leaves the environment exactly as it was.
        world_backup = copy.deepcopy(self._world)
        alias_backup = dict(self._alias)
        num_area_backup = self._num_area

        try:
            # Create the new area with globally-unique (area, x, y) labels.
            grid = nx.grid_2d_graph(rows, cols)
            new_area = nx.relabel_nodes(
                grid, {cell: (new_idx,) + tuple(cell) for cell in grid.nodes})
            self._world.update(new_area)
            self._num_area = new_idx

            # Add the inter-area connection and flat altitude.
            self.add_path(access_from, access_to, register_action)
            self.set_altitude(new_idx, np.zeros((rows, cols)))
        except Exception:
            self._world = world_backup
            self._alias = alias_backup
            self._num_area = num_area_backup
            raise

    def remove_area(self, area_idx):
        """Remove an area from the world.

        Index for all other areas left will be automatically reset
        (minus one if their original index is larger than the index of the
        removed area). Inter-area aliases and objects are relabelled the same
        way; those belonging to the removed area are dropped.

        .. note::
            - Origin of the world is not allowed to be removed.
            - Removing an area that would make the world no longer connected
              is prohibited; the environment is left unchanged in that case.

        Parameters
        ----------
        area_idx : int
            Index of the area to be removed.

        Raises
        ------
        NeuGymPermissionError
            If ``area_idx`` is 0 (the origin).
        ValueError
            If there is no area with index ``area_idx``.
        NeuGymConnectivityError
            If the remaining world would no longer be connected.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.remove_area(1)
        """
        if area_idx == 0:
            raise ng.NeuGymPermissionError("Not allowed to remove origin area")
        if not 0 < area_idx <= self._num_area:
            msg = "Area {} not found".format(area_idx)
            raise ValueError(msg)

        def shift(coord):
            # Areas after the removed one move down by one index.
            if coord[0] > area_idx:
                return (coord[0] - 1,) + tuple(coord[1:])
            return coord

        # Build the new world on a copy so a connectivity failure leaves the
        # environment untouched.
        new_world = copy.deepcopy(self._world)
        new_world.remove_nodes_from(
            [node for node in self._world.nodes if node[0] == area_idx])

        # Relabel all shifted nodes in one pass. Doing it one node at a time
        # copies the whole graph per node and relies on the removed area's
        # nodes having been visited first.
        mapping = {node: shift(node) for node in new_world.nodes
                   if node[0] > area_idx}
        new_world = nx.relabel_nodes(new_world, mapping)

        if not nx.is_connected(new_world):
            msg = "Not allowed to remove area {}, " \
                  "world would be no longer connected".format(area_idx)
            raise ng.NeuGymConnectivityError(msg)

        self._world = new_world
        self._num_area -= 1

        # Drop aliases touching the removed area and shift the remaining ones.
        self._alias = {
            shift(key): shift(value)
            for key, value in self._alias.items()
            if key[0] != area_idx and value[0] != area_idx
        }

        # Drop objects in the removed area; shift objects in later areas.
        kept_objects = []
        for obj in self._objects:
            if obj.coord[0] == area_idx:
                continue
            if obj.coord[0] > area_idx:
                obj.coord = shift(obj.coord)
            kept_objects.append(obj)
        self._objects = kept_objects

        # NOTE(review): the agent's initial/current state is not remapped here;
        # if it lies in or after the removed area it now refers to a stale
        # coordinate. Confirm whether callers are expected to re-run init_agent().

    def add_path(self, coord_from, coord_to, register_action=None):
        """Add a new inter-area connection.

        The path is stored as a graph edge plus two alias entries: the
        off-grid cell reached from ``coord_from`` with the registered action
        is aliased to ``coord_to``, and the off-grid cell reached from
        ``coord_to`` with the reverse action is aliased to ``coord_from``.

        .. note::
            - Creating a path within the same area is not allowed.
            - When an inter-area path from ``coord_from`` to ``coord_to`` with
              action ``register_action`` is built, a reverse path is also built at
              the same time, i.e. the agent can also move from ``coord_to`` to
              ``coord_from`` with the reverse action of ``register_action`` (e.g.
              the reverse action of **UP(1, 0)** is **DOWN(-1, 0)**).

        Parameters
        ----------
        coord_from : tuple of ints
            Coordinate ``(area_idx, x, y)`` of the path start state.

        coord_to : tuple of ints
            Coordinate ``(area_idx, x, y)`` of the path end state.

        register_action : tuple of ints (optional, default: None)
            Register an action to transport the agent from ``coord_from`` to
            ``coord_to``. If not provided, the first free action (in
            :attr:`actions` order) is used.

        Raises
        ------
        ValueError
            If a coordinate does not have length 3 or is not a state of the
            world, or ``register_action`` is not one of :attr:`actions`.
        NeuGymPermissionError
            If both coordinates belong to the same area.
        NeuGymConnectivityError
            If either state already has 4 connections, no action is free to
            carry the path, or ``register_action`` is already allocated.
        NeuGymOverwriteError
            If the two states are already connected.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_area((3, 3))
        >>> W.add_path((1, 1, 1), (2, 1, 0), register_action=(0, 1))
        """
        # Accept any sequence, but store hashable tuples in aliases and edges.
        coord_from = tuple(coord_from)
        coord_to = tuple(coord_to)

        # Validate lengths before indexing into the coordinates.
        if len(coord_from) != 3:
            msg = "Tuple of length 3 expected for argument " \
                  "'coord_from', got {}".format(len(coord_from))
            raise ValueError(msg)
        if len(coord_to) != 3:
            msg = "Tuple of length 3 expected for argument " \
                  "'coord_to', got {}".format(len(coord_to))
            raise ValueError(msg)

        if coord_from[0] == coord_to[0]:
            msg = "Not allowed to add path within an area"
            raise ng.NeuGymPermissionError(msg)

        if not self._world.has_node(coord_from):
            msg = "'coord_from' coordinate {} out of world".format(coord_from)
            raise ValueError(msg)
        if self._world.degree(coord_from) == 4:
            msg = "Maximum number of connections (4) for position " \
                  "{} reached, not allowed to access from it".format(coord_from)
            raise ng.NeuGymConnectivityError(msg)

        if not self._world.has_node(coord_to):
            msg = "'coord_to' coordinate {} out of world".format(coord_to)
            raise ValueError(msg)
        if self._world.degree(coord_to) == 4:
            msg = "Maximum number of connections (4) for position " \
                  "{} reached, not allowed to access to it".format(coord_to)
            raise ng.NeuGymConnectivityError(msg)

        if (coord_from, coord_to) in self._world.edges:
            msg = "Path already exists between {} and {}".format(coord_from, coord_to)
            raise ng.NeuGymOverwriteError(msg)

        def alias_cells(action):
            """Return the off-grid cells aliased to each end of the path."""
            dx, dy = action
            return ((coord_from[0], coord_from[1] + dx, coord_from[2] + dy),
                    (coord_to[0], coord_to[1] - dx, coord_to[2] - dy))

        # An action is free when neither aliased cell is a real state or an
        # already-registered alias.
        free_actions = [
            action for action in self._actions
            if not any(self._world.has_node(cell) or cell in self._alias
                       for cell in alias_cells(action))
        ]

        if not free_actions:
            msg = "Unable to connect two areas from 'coord_from' {} to 'coord_to' {}, " \
                  "all allowed actions allocated".format(coord_from, coord_to)
            raise ng.NeuGymConnectivityError(msg)

        if register_action is None:
            action = free_actions[0]
        elif register_action not in self._actions:
            msg = "Illegal 'register_action' {}, " \
                  "expected one of {}".format(register_action, self._actions)
            raise ValueError(msg)
        elif register_action not in free_actions:
            msg = "Unable to register action 'register_action' {}, " \
                  "already allocated".format(register_action)
            raise ng.NeuGymConnectivityError(msg)
        else:
            action = register_action

        # Register the action in both directions and link the two states.
        alias_to, alias_from = alias_cells(action)
        self._alias[alias_to] = coord_to
        self._alias[alias_from] = coord_from
        self._world.add_edge(coord_from, coord_to)

    def remove_path(self, coord_from, coord_to):
        """Remove one inter-area connection from the world.

        Both alias entries registered for the path and the corresponding
        graph edge are removed. The path may be given in either direction.

        .. note::
            Removing a path that would make the world no longer connected
            is prohibited.

        Parameters
        ----------
        coord_from : tuple of ints
            Coordinate of the path start state.

        coord_to : tuple of ints
            Coordinate of the path end state.

        Raises
        ------
        ValueError
            If either coordinate does not have length 3.
        NeuGymPermissionError
            If both coordinates belong to the same area.
        NeuGymConnectivityError
            If the path is a bridge, i.e. removing it would disconnect the world.

        Warns
        -----
        RuntimeWarning
            If no registered inter-area path joins the two states.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_area((3, 3))
        >>> W.add_path((1, 1, 1), (2, 1, 0), register_action=(0, 1))
        >>> W.remove_path((1, 1, 1), (2, 1, 0))
        """
        coord_from = tuple(coord_from)
        coord_to = tuple(coord_to)

        # Validate lengths before indexing into the coordinates.
        if len(coord_from) != 3 or len(coord_to) != 3:
            msg = "Tuple of length 3 expected for position coordinate"
            raise ValueError(msg)

        if coord_from[0] == coord_to[0]:
            msg = "Not allowed to remove path within an area"
            raise ng.NeuGymPermissionError(msg)

        # Locate the pair of alias entries registered for this path.
        registered = None
        for dx, dy in self._actions:
            alias_to = (coord_from[0], coord_from[1] + dx, coord_from[2] + dy)
            alias_from = (coord_to[0], coord_to[1] - dx, coord_to[2] - dy)
            if self._alias.get(alias_to) == coord_to and \
                    self._alias.get(alias_from) == coord_from:
                registered = (alias_to, alias_from)
                break

        if registered is None:
            msg = "Inter-area path not found between {} and {}, " \
                  "nothing to do".format(coord_from, coord_to)
            warnings.warn(RuntimeWarning(msg))
            return

        # nx.bridges reports each bridge as a (u, v) tuple whose orientation
        # need not match the argument order here, so compare unordered pairs.
        path = frozenset((coord_from, coord_to))
        if any(frozenset(bridge) == path for bridge in nx.bridges(self._world)):
            msg = "Not allowed to remove path ({}, {}), " \
                  "world would be no longer connected".format(coord_from, coord_to)
            raise ng.NeuGymConnectivityError(msg)

        for key in registered:
            del self._alias[key]
        self._world.remove_edge(coord_from, coord_to)

    def add_object(self, coord, reward, prob, punish=0):
        """Add one object to the world.

        Each state can only have one object; use :meth:`update_object` to
        change the attributes of an existing one.

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state to place the object.

        reward : int or float
            Reward that the object can generate.

        prob : float
            Probability for the object to generate a reward.

        punish : int or float (optional, default: 0)
            Punishment that the object will generate if failed
            to generate a reward.

        Raises
        ------
        ValueError
            If ``coord`` is not a state of the world.
        NeuGymOverwriteError
            If the state already holds an object.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_object((0, 0, 0), reward=1, prob=0.7)
        >>> W.add_object((1, 0, 0), reward=1, prob=0.3, punish=-10)
        """
        if coord not in self._world.nodes:
            msg = "Coordinate {} out of world".format(coord)
            raise ValueError(msg)

        # Enforce the one-object-per-state rule; otherwise a duplicate would be
        # shadowed by the first match in every coordinate lookup.
        if any(coord == obj.coord for obj in self._objects):
            msg = "Object already exists at {}, use 'update_object' " \
                  "to change its attributes".format(coord)
            raise ng.NeuGymOverwriteError(msg)

        self._objects.append(_Object(reward, punish, prob, coord))

    def remove_object(self, coord):
        """Remove the object placed at ``coord``.

        Only the first object whose coordinate equals ``coord`` is removed.

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state whose object will be removed.

        Raises
        ------
        ValueError
            If no object is placed at ``coord``.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_object((0, 0, 0), reward=1, prob=0.7)
        >>> W.remove_object((0, 0, 0))
        """
        position = next(
            (idx for idx, candidate in enumerate(self._objects)
             if coord == candidate.coord),
            None,
        )
        if position is None:
            msg = "No object found at {}".format(coord)
            raise ValueError(msg)
        del self._objects[position]

    def update_object(self, coord, **attr):
        """Reset attributes of the object placed at ``coord``.

        Apart from the coordinate itself, the attributes ``reward``,
        ``prob`` and ``punish`` may be updated. Keywords naming an attribute
        the object does not have are ignored with a ``RuntimeWarning``; all
        other keywords are still applied.

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state whose object attribute will be updated.

        attr : keyword arguments
            Attribute names (``reward``, ``prob``, ``punish``) mapped to their
            new values. If none are given, nothing is changed.

        Raises
        ------
        ValueError
            If no object is placed at ``coord``.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_object((0, 0, 0), reward=1, prob=0.7)
        >>> W.update_object((0, 0, 0), reward=10, prob=0.8, punish=-1)
        """
        target = next((candidate for candidate in self._objects
                       if coord == candidate.coord), None)
        if target is None:
            msg = "No object found at {}".format(coord)
            raise ValueError(msg)

        for name, new_value in attr.items():
            if not hasattr(target, name):
                msg = "'Object' object doesn't have attribute " \
                      "'{}', ignored".format(name)
                warnings.warn(RuntimeWarning(msg))
                continue
            setattr(target, name, new_value)

    def get_object_attribute(self, coord, attr):
        """Return one attribute of the object placed at ``coord``.

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state whose object attribute will be looked for.

        attr : str {'reward', 'prob', 'punish'}
            Object attribute to look for.

        Returns
        -------
        attribute_value : int or float
            Value of object attribute ``attr``.

        Raises
        ------
        ValueError
            If no object is placed at ``coord`` or the object has no
            attribute named ``attr``.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_object((0, 0, 0), reward=1, prob=0.7)
        >>> W.get_object_attribute((0, 0, 0), 'reward')
        1
        >>> W.get_object_attribute((0, 0, 0), 'prob')
        0.7
        """
        found = next((candidate for candidate in self._objects
                      if coord == candidate.coord), None)
        if found is None:
            msg = "No object found at {}".format(coord)
            raise ValueError(msg)
        if not hasattr(found, attr):
            msg = "'Object' object doesn't have attribute '{}'".format(attr)
            raise ValueError(msg)
        return getattr(found, attr)

    def set_altitude(self, area_idx, altitude_mat):
        """Set the altitude of each state for one area.

        Parameters
        ----------
        area_idx : int
            Index of the area to set altitude (0 is the origin).

        altitude_mat : array_like
            A matrix of the same shape as the area. Element ``[x, y]``
            becomes the altitude of state ``(area_idx, x, y)``. Anything
            accepted by ``numpy.asarray`` (e.g. nested lists) is allowed.

        Raises
        ------
        ValueError
            If ``area_idx`` is negative or larger than :attr:`num_area`, or if
            the matrix shape differs from the area shape.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 3))
        >>> mat = np.random.randn(2, 3)
        >>> W.set_altitude(1, altitude_mat=mat)
        """
        if not 0 <= area_idx <= self._num_area:
            msg = "Area {} not found".format(area_idx)
            raise ValueError(msg)

        # Accept nested lists and other array-likes, not just ndarrays.
        altitude_mat = np.asarray(altitude_mat)
        area_shape = self.get_area_shape(area_idx)

        if altitude_mat.shape != area_shape:
            msg = "Mismatch shape between Area({}) {} and " \
                  "altitude matrix {}".format(area_idx,
                                              area_shape,
                                              altitude_mat.shape)
            raise ValueError(msg)

        altitude_mapping = {
            (area_idx, x, y): altitude_mat[x, y]
            for x in range(area_shape[0])
            for y in range(area_shape[1])
        }
        nx.set_node_attributes(self._world, altitude_mapping, 'altitude')

    def get_area_shape(self, area_idx):
        """Get the shape of one area.

        The shape is derived from the largest local ``x`` and ``y``
        coordinates among the states that belong to the area.

        Parameters
        ----------
        area_idx : int
            Index of the area to get its shape.

        Returns
        -------
        shape : tuple of ints
            Shape of the area with index ``area_idx``.

        Raises
        ------
        ValueError
            If ``area_idx`` is negative, larger than :attr:`num_area`, or no
            state of the world belongs to that area.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((3, 10))
        >>> W.get_area_shape(1)
        (3, 10)
        """
        if not 0 <= area_idx <= self._num_area:
            msg = "Area {} not found".format(area_idx)
            raise ValueError(msg)

        local_coords = [(x, y) for area, x, y in self._world.nodes
                        if area == area_idx]
        if not local_coords:
            # An area without states has no meaningful shape.
            msg = "Area {} not found".format(area_idx)
            raise ValueError(msg)

        max_x = max(x for x, _ in local_coords)
        max_y = max(y for _, y in local_coords)
        return max_x + 1, max_y + 1

    def get_area_altitude(self, area_idx):
        """Get the altitude of each state in one area.

        Parameters
        ----------
        area_idx : int
            Index of the area to get its state altitude.

        Returns
        -------
        altitude_matrix : numpy.ndarray
            Altitude matrix of the area with index ``area_idx``.
            Each element in the matrix corresponds to the altitude of one state
            in the area.

        Raises
        ------
        ValueError
            If ``area_idx`` is negative or larger than :attr:`num_area`.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((3, 5))
        >>> mat = np.ones((3, 5))
        >>> W.set_altitude(1, altitude_mat=mat)
        >>> W.get_area_altitude(1)
        array([[1., 1., 1., 1., 1.],
               [1., 1., 1., 1., 1.],
               [1., 1., 1., 1., 1.]])
        """
        if not 0 <= area_idx <= self._num_area:
            msg = "Area {} not found".format(area_idx)
            raise ValueError(msg)

        area_shape = self.get_area_shape(area_idx)
        altitude_mat = np.zeros(area_shape)

        # Build the node -> altitude table once; rebuilding it for every state
        # made this loop quadratic in the size of the world.
        altitudes = nx.get_node_attributes(self._world, 'altitude')
        for area, x, y in self._world.nodes:
            if area == area_idx:
                altitude_mat[x, y] = altitudes[(area, x, y)]

        return altitude_mat

    def init_agent(self, init_coord=None, overwrite=False):
        """Place an agent in the world.

        Parameters
        ----------
        init_coord : tuple of ints (optional, default: None)
            Coordinate of the agent's initial state; ``(0, 0, 0)`` when
            omitted.

        overwrite : bool (optional, default: False)
            Whether an existing agent may be replaced.

        Raises
        ------
        ValueError
            If ``init_coord`` is not a state of the world.
        NeuGymOverwriteError
            If an agent already exists and ``overwrite`` is false.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.init_agent()
        >>> W.add_area((2, 4))
        >>> W.init_agent((1, 1, 3), overwrite=True)
        """
        start = (0, 0, 0) if init_coord is None else init_coord

        if not self._world.has_node(start):
            msg = "Initial state coordinate {} out of world".format(start)
            raise ValueError(msg)

        if self._agent is not None and not overwrite:
            raise ng.NeuGymOverwriteError("Agent already exists, "
                                          "set 'overwrite=True' to overwrite")

        self._agent = _Agent(start)

    def get_agent_state(self, when="current"):
        """Return the agent's current or initial state.

        Parameters
        ----------
        when : str {"current", "init"} (optional, default: "current")
            ``"current"`` for the state the agent occupies now, ``"init"``
            for the state it was initialized at.

        Returns
        -------
        agent_current_state : tuple of ints
            Coordinate of the requested agent state.

        Raises
        ------
        ValueError
            If ``when`` is neither ``"current"`` nor ``"init"``.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.init_agent()
        >>> W.step((1, 0))
        ((1, 0, 0), 0.0, False)
        >>> W.get_agent_state()
        (1, 0, 0)
        >>> W.get_agent_state(when="init")
        (0, 0, 0)
        """
        if when not in ("current", "init"):
            msg = "Unrecognized parameter '{}', 'current' or 'init' expected".format(when)
            raise ValueError(msg)
        state_attr = "current_state" if when == "current" else "init_state"
        return getattr(self._agent, state_attr)

    @property
    def world(self):
        """A copy of the ``world`` graph of the gridworld environment.

        ``GridWorld.world`` is a NetworkX Graph object describing the areas,
        states and their connections. Each node is a state named by its
        global coordinate ``(area_idx, x, y)`` and carries an ``altitude``
        attribute. Each edge connects two states (inter-area connections
        included). The returned graph is a copy, so modifying it does not
        change the environment.

        .. note::
            More information about NetworkX Graph object can be found at
            `networkx.Graph \
            <https://networkx.org/documentation/stable/reference/classes/graph.html>`_

        Returns
        -------
        world : networkx.Graph
            World attribute of the gridworld environment represented by
            a NetworkX Graph object.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> G = W.world
        >>> G.nodes
        NodeView(((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)))

        References
        ----------
        .. [#] NetworkX Documentation: https://networkx.org/
        """
        snapshot = self._world.copy(as_view=False)
        return snapshot

    @property
    def time(self):
        """Gridworld environment time.

        The number of steps the agent has taken in the environment.

        Returns
        -------
        time : int
            Current time of gridworld environment.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 3))
        >>> W.init_agent()
        >>> _ = W.step((1, 0))
        >>> _ = W.step((0, 0))
        >>> W.time
        2
        """
        elapsed_steps = self._time
        return elapsed_steps

    @property
    def num_area(self):
        """Number of areas in the ``world`` of the gridworld environment.

        .. note::
            The origin (area 0) is not counted.

        Returns
        -------
        num_area : int
            Number of non-origin areas in the world.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_area((3, 3))
        >>> W.num_area
        2
        """
        area_count = self._num_area
        return area_count

    @property
    def actions(self):
        """Action space of the gridworld environment.

        Every action is a ``(dx, dy)`` tuple; ``(0, 0)`` means STAY.

        Returns
        -------
        actions : tuple
            Action space of the gridworld environment.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.actions
        ((0, 0), (1, 0), (-1, 0), (0, 1), (0, -1))
        """
        action_space = self._actions
        return action_space

    @property
    def has_reset_checkpoint(self):
        """Whether a reset checkpoint has been set for the environment.

        Returns
        -------
        has_reset_checkpoint : bool
            ``True`` once a reset checkpoint of the environment has been
            created, ``False`` before that.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.has_reset_checkpoint
        False
        >>> W.set_reset_checkpoint()
        >>> W.has_reset_checkpoint
        True
        """
        checkpoint_flag = self._has_reset_checkpoint
        return checkpoint_flag

    def step(self, action):
        """Make the agent move toward direction given by ``action``.

        .. note::
            - If one movement will cause the agent get out of the world,
              the agent will be forced to stay in the same position (state) instead.
            - If the agent reaches a state with an object, no matter whether the agent
              gets a reward or punishment from the object, this trial will end and the
              agent will be transported back to its initial state.

        Parameters
        ----------
        action : tuple of ints \
                 {(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)}
            Direction of the agent movement.

        Returns
        -------
        next_state : tuple of ints
            Next state of the agent after movement.

        reward : int or float
            Reward that the agent gets at through this movement.

        done : bool
            Whether this trial ends.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 3))
        >>> W.init_agent()
        >>> W.step((1, 0))
        ((1, 0, 0), 0.0, False)
        """
        if action not in self._actions:
            msg = "Illegal action {}, should be one of {}".format(action, self._actions)
            raise ValueError(msg)
        else:
            dx, dy = action

        done = False
        reward = 0
        current_state = self._agent.current_state
        next_state = (current_state[0], current_state[1] + dx, current_state[2] + dy)
        if not self._world.has_node(next_state):
            if next_state in self._alias.keys():
                next_state = self._alias[next_state]
            else:
                next_state = current_state

        altitude = nx.get_node_attributes(self._world, 'altitude')
        reward += altitude[current_state] - altitude[next_state]

        for obj in self._objects:
            if obj.coord == next_state:
                reward += obj.get_reward()
                done = True
                break

        self._time += 1
        if done:
            self._agent.reset()
        else:
            self._agent.update(current_state=next_state)

        return next_state, reward, done

    def set_reset_checkpoint(self, overwrite=False):
        """Set environment checkpoint for reset.

        Parameters
        ----------
        overwrite : bool (optional, default: False)
            Whether to overwrite existing checkpoint.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((3, 3))
        >>> W.add_object((1, 2, 1), reward=1, prob=0.7)
        >>> W.set_reset_checkpoint()
        >>> W.has_reset_checkpoint
        True
        """
        if not self._has_reset_checkpoint or overwrite:
            for key in self._reset_state.keys():
                self._reset_state[key] = copy.deepcopy(getattr(self, '_' + key))
                self._has_reset_checkpoint = True
        else:
            raise ng.NeuGymOverwriteError("Reset state already exists, "
                                          "set 'overwrite=True' to overwrite")

    def reset(self):
        """Reset the environment to the checkpoint state.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((3, 3))
        >>> W.add_object((1, 2, 1), reward=1, prob=0.7)
        >>> W.set_reset_checkpoint()
        >>> W.init_agent()
        >>> W.step((1, 0))
        ((1, 0, 0), 0.0, False)
        >>> W.reset()
        """
        if not self._has_reset_checkpoint:
            raise ng.NeuGymCheckpointError(
                "Reset state not found, use 'set_reset_state()' "
                "to set the reset checkpoint first")

        for key, value in self._reset_state.items():
            setattr(self, '_' + key, copy.deepcopy(value))

    def __repr__(self):
        msg = "GridWorld(\n" \
              "\ttime={},\n" \
              "\torigin=Origin([0])(shape={}),\n".format(self._time, self.get_area_shape(0))

        if self._num_area == 0:
            msg += "\tareas=(),\n"
        else:
            msg += "\tareas=(\n"
            for i in range(1, self._num_area + 1):
                msg += "\t\t[{}] Area(shape={})".format(i, self.get_area_shape(i))
                if i != self._num_area:
                    msg += ",\n"
                else:
                    msg += "\n"
            msg += "\t),\n"

        if len(self._objects) == 0:
            msg += "\tobjects=(),\n"
        else:
            msg += "\tobjects=(\n"
            for i, obj in enumerate(self._objects):
                msg += "\t\t[{}] {}".format(i, str(obj))
                if i != len(self._objects) - 1:
                    msg += ",\n"
                else:
                    msg += "\n"
            msg += "\t),\n"

        msg += "\tactions={},\n".format(self._actions)
        msg += "\tagent={},\n".format(str(self._agent))
        msg += "\thas_reset_state={},\n".format(self._has_reset_checkpoint)
        msg += ")"

        return msg
