
"""
Context for evaluation engine -- carries all of the global variables (schema, graph, etc.)

We might fold the various routines inside context and replace "cntxt: Context" with "self", but we will have to see.

"""
from typing import Dict, Any, Callable, Optional, List, Tuple, Union

from ShExJSG import ShExJ
from ShExJSG.ShExJ import Schema
from pyjsg.jsglib.jsg import isinstance_
from rdflib import Graph, BNode, Namespace

from pyshex.shapemap_structure_and_language.p3_shapemap_structure import nodeSelector, START


class DebugContext:
    """ Holder for the tracing switches that a Context carries around.

    All tracing starts out disabled, with the indentation level at zero.
    """
    def __init__(self):
        # Per-function trace switches -- every one is off until explicitly enabled
        self.trace_matches = False
        self.trace_nodeSatisfies = False
        self.trace_satisfies = False
        # Current nesting depth of the trace output
        self.trace_indent = 0


class _VisitorCenter:
    """ Bookkeeping for one recursive visit: the visitor function, its argument context, and the identifiers
    of the shapes and triple expressions that are in progress or already finished.
    (The name is a pun on "visitor context".)
    """
    def __init__(self, f: Callable[[Any, ShExJ.shapeExpr, "Context"], None], arg_cntxt: Any) \
            -> None:
        # Visitor callback and the opaque context that is handed back to it
        self.f, self.arg_cntxt = f, arg_cntxt
        # Shape expression ids: finished / currently in progress
        self._seen_shapes, self._visiting_shapes = [], []
        # Triple expression ids: finished / currently in progress
        self._seen_tes, self._visiting_tes = [], []

    # ---- shape expressions ----

    def start_visiting_shape(self, id_: str) -> None:
        """ Record that the shape identified by id_ is now being visited """
        self._visiting_shapes.append(id_)

    def actively_visiting_shape(self, id_: str) -> bool:
        """ Return True if the shape identified by id_ is currently being visited """
        in_progress = self._visiting_shapes
        return id_ in in_progress

    def done_visiting_shape(self, id_: str) -> None:
        """ Move one occurrence of id_ from the in-progress shapes to the finished shapes """
        in_progress, finished = self._visiting_shapes, self._seen_shapes
        in_progress.remove(id_)
        finished.append(id_)

    def already_seen_shape(self, id_: str) -> bool:
        """ Return True if the shape identified by id_ has been completely visited """
        finished = self._seen_shapes
        return id_ in finished

    # ---- triple expressions ----

    def start_visiting_te(self, id_: str) -> None:
        """ Record that the triple expression identified by id_ is now being visited """
        self._visiting_tes.append(id_)

    def actively_visiting_te(self, id_: str) -> bool:
        """ Return True if the triple expression identified by id_ is currently being visited """
        in_progress = self._visiting_tes
        return id_ in in_progress

    def done_visiting_te(self, id_: str) -> None:
        """ Move one occurrence of id_ from the in-progress triple expressions to the finished ones """
        in_progress, finished = self._visiting_tes, self._seen_tes
        in_progress.remove(id_)
        finished.append(id_)

    def already_seen_te(self, id_: str) -> bool:
        """ Return True if the triple expression identified by id_ has been completely visited """
        finished = self._seen_tes
        return id_ in finished


def default_external_shape_resolver(_: ShExJ.IRIREF) -> Optional[ShExJ.Shape]:
    """
    Fallback external shape resolver, used when the caller does not supply one.

    :param _: reference to the external shape (ignored)
    :return: always None -- no external shape is ever resolved
    """
    return None


class Context:
    """ Environment for ShExJ evaluation

    Bundles the RDF graph, the ShEx schema, identifier-to-expression lookup tables for the schema, the
    external shape resolver, and the bookkeeping used while evaluating recursive shape references.
    """
    def __init__(self, g: Optional[Graph], s: Schema,
                 external_shape_resolver: Optional[Callable[[ShExJ.IRIREF], Optional[ShExJ.Shape]]]=None,
                 base_namespace: Optional[Namespace]=None) -> None:
        """
        Create a context consisting of an RDF Graph and a ShEx Schema and generate an identifier to
        item map.

        :param g: RDF graph
        :param s: ShExJ Schema instance
        :param external_shape_resolver: External resolution function.  If omitted,
            default_external_shape_resolver is used
        :param base_namespace: Namespace used to resolve identifiers that contain no ':' when the
            identifier maps are built
        """
        self.graph: Graph = g
        self.schema: ShExJ.Schema = s
        self.schema_id_map: Dict[ShExJ.shapeExprLabel, ShExJ.shapeExpr] = {}
        self.te_id_map: Dict[ShExJ.tripleExprLabel, ShExJ.tripleExpr] = {}
        self.external_shape_for = external_shape_resolver if external_shape_resolver \
            else default_external_shape_resolver
        self.base_namespace = base_namespace

        # A list of node selector / shape expression id pairs that are being evaluated.  If we attempt to evaluate
        # an entry for a second time, we, instead, put the entry into the assumptions table.  We start with 'true'
        # and, if the result is 'true' then we count it as success.  If not, we switch to false and try again.
        # Both structures are keyed on (node, shape expression id) -- see start_evaluating / done_evaluating
        self.evaluating: List[Tuple[nodeSelector, Union[ShExJ.shapeExprLabel, str]]] = []
        self.assumptions: Dict[Tuple[nodeSelector, Union[ShExJ.shapeExprLabel, str]], bool] = {}

        # Debugging options
        self.debug_context = DebugContext()

        # Index the start expression first, then every top level shape
        if self.schema.start is not None:
            self._gen_schema_xref(self.schema.start)
        if self.schema.shapes is not None:
            for e in self.schema.shapes:
                self._gen_schema_xref(e)

    def _gen_schema_xref(self, expr: Optional[Union[ShExJ.shapeExprLabel, ShExJ.shapeExpr]]) -> None:
        """
        Generate the schema_id_map

        Registers expr under its (resolved) id if it has one, then recurses into nested shape expressions
        and, for a Shape, into its triple expression.  Labels (references) are not registered.

        :param expr: root shape expression
        """
        if expr is not None and not isinstance_(expr, ShExJ.shapeExprLabel) and 'id' in expr and expr.id is not None:
            # Test and store under the same (resolved) key so a repeat encounter is detected
            # even when base_namespace rewrites the identifier
            resolved_id = self._resolve_relative_uri(expr.id)
            if resolved_id in self.schema_id_map:
                return
            self.schema_id_map[resolved_id] = expr
        if isinstance(expr, (ShExJ.ShapeOr, ShExJ.ShapeAnd)):
            for expr2 in expr.shapeExprs:
                self._gen_schema_xref(expr2)
        elif isinstance(expr, ShExJ.ShapeNot):
            self._gen_schema_xref(expr.shapeExpr)
        elif isinstance(expr, ShExJ.Shape):
            if expr.expression is not None:
                self._gen_te_xref(expr.expression)

    def _resolve_relative_uri(self, ref: ShExJ.shapeExprLabel) -> ShExJ.shapeExprLabel:
        """
        Resolve ref against base_namespace

        :param ref: identifier to resolve
        :return: ref unchanged if it contains a ':' or no base namespace is set, otherwise an IRIREF built
            from the base namespace and ref
        """
        if ':' not in str(ref) and self.base_namespace:
            return ShExJ.IRIREF(str(self.base_namespace[str(ref)]))
        return ref

    def _gen_te_xref(self, expr: Union[ShExJ.tripleExpr, ShExJ.tripleExprLabel]) -> None:
        """
        Generate the triple expression map (te_id_map)

        Registers expr under its (resolved) id if it has one, then recurses into nested triple expressions
        and, for a TripleConstraint, into its value expression.  Labels (references) are not registered.

        :param expr: root triple expression

        """
        if expr is not None and not isinstance_(expr, ShExJ.tripleExprLabel) and 'id' in expr and expr.id is not None:
            # Test and store under the same (resolved) key -- see _gen_schema_xref
            resolved_id = self._resolve_relative_uri(expr.id)
            if resolved_id in self.te_id_map:
                return
            self.te_id_map[resolved_id] = expr
        if isinstance(expr, (ShExJ.OneOf, ShExJ.EachOf)):
            for expr2 in expr.expressions:
                self._gen_te_xref(expr2)
        elif isinstance(expr, ShExJ.TripleConstraint):
            if expr.valueExpr is not None:
                self._gen_schema_xref(expr.valueExpr)

    def tripleExprFor(self, id_: ShExJ.tripleExprLabel) -> ShExJ.tripleExpr:
        """ Return the triple expression that corresponds to id

        :param id_: triple expression label
        :return: the matching entry of te_id_map
        :raises KeyError: if id_ is not in te_id_map
        """
        return self.te_id_map[id_]

    def shapeExprFor(self, id_: Union[ShExJ.shapeExprLabel, START]) -> Optional[ShExJ.shapeExpr]:
        """ Return the shape expression that corresponds to id

        :param id_: shape expression label or START
        :return: the schema start expression if id_ is START, otherwise the matching entry of schema_id_map
            or None if there isn't one
        """
        return self.schema.start if id_ is START else self.schema_id_map.get(id_)

    def visit_shapes(self, expr: ShExJ.shapeExpr, f: Callable[[Any, ShExJ.shapeExpr, "Context"], None], arg_cntxt: Any,
                     visit_center: Optional[_VisitorCenter] = None, follow_inner_shapes: bool=True) -> None:
        """
        Visit expr and all of its "descendant" shapes.

        :param expr: root shape expression
        :param f: visitor function
        :param arg_cntxt: accompanying context for the visitor function
        :param visit_center: Recursive visit context.  (Not normally supplied on an external call)
        :param follow_inner_shapes: Follow nested shapes or just visit on outer level
        """
        if visit_center is None:
            visit_center = _VisitorCenter(f, arg_cntxt)
        # getattr rather than "in": expr may be a label (reference), which carries no id attribute
        has_id = getattr(expr, 'id', None) is not None
        if not has_id or not (visit_center.already_seen_shape(expr.id)
                              or visit_center.actively_visiting_shape(expr.id)):

            # Visit the root expression
            if has_id:
                visit_center.start_visiting_shape(expr.id)
            f(arg_cntxt, expr, self)

            # Traverse the expression and visit its components
            if isinstance(expr, (ShExJ.ShapeOr, ShExJ.ShapeAnd)):
                for expr2 in expr.shapeExprs:
                    self.visit_shapes(expr2, f, arg_cntxt, visit_center, follow_inner_shapes=follow_inner_shapes)
            elif isinstance(expr, ShExJ.ShapeNot):
                self.visit_shapes(expr.shapeExpr, f, arg_cntxt, visit_center, follow_inner_shapes=follow_inner_shapes)
            elif isinstance(expr, ShExJ.Shape):
                if expr.expression is not None and follow_inner_shapes:
                    # The triple expression walk calls back into visit_center.f for every value expression found
                    self.visit_triple_expressions(expr.expression,
                                                  lambda ac, te, cntxt: self._visit_shape_te(te, visit_center),
                                                  arg_cntxt,
                                                  visit_center)
            elif ShExJ.isinstance_(expr, ShExJ.shapeExprLabel):
                if not visit_center.actively_visiting_shape(str(expr)) and follow_inner_shapes:
                    # NOTE(review): the label is marked as "visiting" before the referenced shape is passed to
                    # visit_shapes; if that shape's id compares equal to str(expr) the recursive call will treat
                    # it as in progress and skip it -- confirm that this is the intended behavior
                    visit_center.start_visiting_shape(str(expr))
                    self.visit_shapes(self.shapeExprFor(expr), f, arg_cntxt, visit_center,
                                      follow_inner_shapes=follow_inner_shapes)
                    visit_center.done_visiting_shape(str(expr))
            if has_id:
                visit_center.done_visiting_shape(expr.id)

    def visit_triple_expressions(self, expr: ShExJ.tripleExpr, f: Callable[[Any, ShExJ.tripleExpr, "Context"], None],
                                 arg_cntxt: Any, visit_center: Optional[_VisitorCenter]=None) -> None:
        """
        Visit expr and all of its "descendant" triple expressions.

        :param expr: root triple expression
        :param f: visitor function
        :param arg_cntxt: accompanying context for the visitor function
        :param visit_center: Recursive visit context.  (Not normally supplied on an external call)
        """
        if visit_center is None:
            visit_center = _VisitorCenter(f, arg_cntxt)
        # Same test as visit_shapes.  expr may be a label (reference); "'id' in expr" is not an attribute
        # test on a label and the subsequent expr.id access would fail
        has_id = getattr(expr, 'id', None) is not None
        if not has_id or not visit_center.already_seen_te(expr.id):

            # Visit the root expression
            if has_id:
                visit_center.start_visiting_te(expr.id)
            f(arg_cntxt, expr, self)

            # Visit all of the references
            if isinstance(expr, (ShExJ.EachOf, ShExJ.OneOf)):
                for expr2 in expr.expressions:
                    self.visit_triple_expressions(expr2, f, arg_cntxt, visit_center)
            elif isinstance(expr, ShExJ.TripleConstraint):
                if expr.valueExpr is not None:
                    # NOTE(review): _visit_shape_te only acts on TripleConstraints, so this callback looks like a
                    # no-op for the shape expressions visit_shapes passes to it; _visit_te_shape may have been
                    # intended -- confirm before changing, as it alters which visitor gets invoked
                    self.visit_shapes(expr.valueExpr,
                                      lambda ac, te, cntxt: self._visit_shape_te(te, visit_center),
                                      arg_cntxt,
                                      visit_center)
            elif ShExJ.isinstance_(expr, ShExJ.tripleExprLabel):
                if not visit_center.actively_visiting_te(str(expr)):
                    visit_center.start_visiting_te(str(expr))
                    self.visit_triple_expressions(self.tripleExprFor(expr), f, arg_cntxt, visit_center)
                    visit_center.done_visiting_te(str(expr))
            if has_id:
                visit_center.done_visiting_te(expr.id)

    def _visit_shape_te(self, te: ShExJ.tripleExpr, visit_center: _VisitorCenter) -> None:
        """
        Visit a triple expression that was reached through a shape. This, in turn, is used to visit additional shapes
        that are referenced by a TripleConstraint
        :param te: Triple expression reached through a Shape.expression
        :param visit_center: context used in shape visitor
        """
        if isinstance(te, ShExJ.TripleConstraint) and te.valueExpr is not None:
            visit_center.f(visit_center.arg_cntxt, te.valueExpr, self)

    def _visit_te_shape(self, shape: ShExJ.shapeExpr, visit_center: _VisitorCenter) -> None:
        """
        Visit a shape expression that was reached through a triple expression.  This, in turn, is used to visit
        additional triple expressions that are referenced by the Shape

        :param shape: Shape reached through triple expression traverse
        :param visit_center: context used in shape visitor
        """
        if isinstance(shape, ShExJ.Shape) and shape.expression is not None:
            visit_center.f(visit_center.arg_cntxt, shape.expression, self)

    def start_evaluating(self, n: nodeSelector, s: ShExJ.shapeExpr) -> Optional[bool]:
        """
        Indicate that we are beginning to evaluate n in terms of s.

        Side effect: if s has no id, a fresh blank node identifier is assigned to s.id so that the
        (node, shape) pair can be tracked.

        :param n: nodeSelector to be evaluated
        :param s: expression for node evaluation
        :return: Assumed evaluation result.  If None, evaluation must be performed
        """
        if not s.id:
            s.id = str(BNode())
        key = (n, s.id)
        if key not in self.evaluating:
            # First encounter -- the caller has to do the actual evaluation
            self.evaluating.append(key)
            return None
        elif key not in self.assumptions:
            # Re-entered while still evaluating -- optimistically assume success
            self.assumptions[key] = True
            return True
        else:
            return self.assumptions[key]

    def done_evaluating(self, n: nodeSelector, s: ShExJ.shapeExpr, result: bool) -> bool:
        """
        Indicate that we have completed evaluating n in terms of s.

        :param n: nodeselector that was evaluated
        :param s: expression for node evaluation
        :param result: result of evaluation
        :return: True means that evaluation was successful, False means try the evaluation again
        """
        key = (n, s.id)
        self.evaluating.remove(key)
        if key not in self.assumptions:
            # Nothing was assumed about this pair, so the result stands
            return True
        elif self.assumptions[key] == result:
            # The assumption turned out to be right -- it is no longer needed
            del self.assumptions[key]
            return True
        else:
            # The assumption was wrong -- record False and ask the caller to re-evaluate
            self.assumptions[key] = False
            return False
