from tokenize import Token
from typing import Iterable, Tuple, List
from itertools import product

from lmql.ops.token_set import *
from lmql.ops.follow_map import *

# Maps every registered LMQL function name to the qualified name
# ("lmql.<ClassName>") of the class that implements it.
lmql_operation_registry = {}


def LMQLOp(name):
    """Class decorator, used as ``@LMQLOp('function_name')``.

    ``name`` is either a single registry key or a ``list`` of keys. Each key
    is mapped to ``"lmql.<ClassName>"`` in ``lmql_operation_registry``; the
    decorated class itself is returned unchanged.
    """
    def class_transformer(cls):
        qualified_name = f"lmql.{cls.__name__}"
        keys = name if type(name) is list else [name]
        for key in keys:
            lmql_operation_registry[key] = qualified_name
        return cls

    return class_transformer
class Node:
    """Base class of all operation nodes.

    A node holds a list of predecessors (other nodes or constant values) and
    exposes three hooks: ``forward`` (compute the value), ``follow`` (compute
    the follow behaviour) and ``final`` (derive the finality label of the
    result from the finality labels of the operands).
    """

    def __init__(self, predecessors):
        assert type(predecessors) is list, "Predecessors must be a list, not {}".format(type(predecessors))
        self.predecessors = predecessors
        # when True, the evaluation context is passed to forward()/final()
        self.depends_on_context = False

        # assigned later, once the follow map of this node has been computed
        self.follow_map = None

    def execute_predecessors(self, trace, context):
        """Evaluate every predecessor and return the list of their values."""
        return [execute_op(p, trace=trace, context=context) for p in self.predecessors]

    def forward(self, *args, **kwargs):
        """Compute the node value; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        # Build the message from the class *name*: concatenating the class
        # object itself with a str raises TypeError before the intended
        # NotImplementedError can be raised.
        raise NotImplementedError(type(self).__name__ + " does not implement forward()")

    def follow(self, *args, **kwargs):
        """Compute the follow behaviour; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(type(self).__name__ + " does not implement follow()")

    def final(self, args, **kwargs):
        """Return ``"fin"`` if every operand label in ``args`` is ``"fin"``, else ``"var"``."""
        if all([a == "fin" for a in args]):
            return "fin"
        return "var"

    def __nodelabel__(self):
        """Return a display label for this node (the string form of its type)."""
        return str(type(self))

def DynamicTypeDispatch(name, type_map):
    """Create a Node subclass that picks its implementation by argument type.

    ``type_map`` is an iterable of ``(signature, op)`` pairs that is searched
    in order. A signature is either ``"*"`` (matches anything), a tuple of
    types checked pairwise against the call arguments, or a single type
    checked against the argument tuple as a whole. The ``forward``,
    ``follow`` and ``final`` functions of the selected ``op`` class are then
    invoked with the dispatching node as ``self``.

    NOTE(review): every hook dispatches on its own positional arguments, so
    ``final`` selects a handler from whatever is passed to ``final`` rather
    than from the operand values -- confirm this is intended.
    """
    def get_handler(args):
        for signature, op in type_map:
            # catch-all entry
            if signature == "*":
                return op

            if type(signature) is tuple:
                # pairwise type check (zip stops at the shorter sequence)
                matches = all(isinstance(arg, t) for arg, t in zip(args, signature))
            else:
                matches = isinstance(args, signature)

            if matches:
                return op

        raise NotImplementedError("error: no matching implemntation of {} for arguments of type {}".format(name, [type(arg) for arg in args]))

    class TypeDispatchingNode(Node):
        def forward(self, *args, **kwargs):
            handler = get_handler(args)
            return handler.forward(self, *args, **kwargs)

        def follow(self, *args, **kwargs):
            handler = get_handler(args)
            return handler.follow(self, *args, **kwargs)

        def final(self, *args, **kwargs):
            handler = get_handler(args)
            return handler.final(self, *args, **kwargs)

        def __str__(self):
            return "<{}>".format(name)

        def __repr__(self):
            return "<{}>".format(name)

        def __nodelabel__(self):
            return name

    return TypeDispatchingNode

# Sentinel value that stands in for the next, not yet known, token.
NextToken = "<lmql.next>"


def is_next_token(t):
    """Return whether ``t`` equals the ``NextToken`` sentinel."""
    return t == NextToken


def strip_next_token(x):
    """Return ``x`` without the ``NextToken`` sentinel.

    For a list or tuple, every element equal to the sentinel is removed and a
    container of the same kind is returned. For any other value (expected to
    be a string), a single trailing sentinel is cut off.
    """
    if type(x) is list:
        return [item for item in x if not is_next_token(item)]
    if type(x) is tuple:
        return tuple(item for item in x if not is_next_token(item))
    return x[:-len(NextToken)] if x.endswith(NextToken) else x

@LMQLOp("SENTENCES")
class Sentences(Node):
    """Splits a text value into a tuple of sentences, each terminated by "."."""

    def forward(self, v):
        """Return the sentences of ``v`` as a tuple."""
        return self.strip(tuple(self.split(v, separator=["."])))

    def strip(self, sentences):
        """Return ``sentences`` as a tuple, dropping a trailing ``()`` entry if present."""
        if len(sentences) == 0:
            return sentences
        if sentences[-1] == ():
            return tuple(sentences[:-1])
        return tuple(sentences)

    def add_end(self, stc, end):
        """Append ``end`` to ``stc`` unless ``stc`` is empty."""
        return stc if len(stc) == 0 else stc + end

    def split(self, v, separator):
        """Cut ``v`` after every character contained in ``separator``.

        The separator stays attached to the piece it ends. A non-empty rest
        after the last separator becomes the final piece. If nothing was
        produced at all, ``("",)`` is returned.
        """
        pieces = []
        pending = ""
        for ch in v:
            if ch in separator:
                pieces.append(pending + ch)
                pending = ""
            else:
                pending += ch

        if len(pending) > 0:
            pieces.append(pending)

        if not pieces:
            return ("",)
        return tuple(pieces)

    def follow(self, x, **kwargs):
        """Build the follow map for the (possibly still generating) text ``x``."""
        v = strip_next_token(x)
        has_next_token = v != x
        sentences = tuple(self.split(v, separator=["."]))

        if not has_next_token:
            return fmap(
                ("*", sentences)
            )

        # the text continues with the next token
        if len(sentences) > 0 and sentences[-1].endswith("."):
            # last sentence is complete: the next token opens a new entry
            continued = sentences + (NextToken,)
        else:
            # last sentence is still open: the next token extends it
            continued = tuple(sentences[:-1] + (sentences[-1] + NextToken,))

        return fmap(
            ("eos", self.strip(sentences)),
            ("*", continued)
        )

    def final(self, x, operands=None, result=None, **kwargs):
        """The finality of the result is the finality of the first operand."""
        return x[0]
    

@LMQLOp("INT")
class IntOp(Node):
    """Checks whether a string value consists of decimal digits only."""

    def forward(self, x):
        """Return None for None or "", otherwise whether ``x`` is all digits.

        A single leading space is ignored before the digit check.
        """
        if x is None: return None
        if x == "": return None

        # check int contains digits only
        if x.startswith(" "):
            x = x[1:]
        if not all([c in "0123456789" for c in x]):
            return False
        else:
            return True

    def follow(self, v, **kwargs):
        """Build the follow map for the (possibly still generating) value ``v``.

        Reads ``context`` from ``kwargs`` to decide between explicit token
        lists and regex-based token sets.
        NOTE(review): ``context`` defaults to None here but is dereferenced
        unconditionally -- callers presumably always pass it; confirm.
        """
        if v is None: return None
        
        has_next_token = v != strip_next_token(v)
        v = strip_next_token(v)

        context = kwargs.get("context", None)
        if context.runtime.prefers_compact_mask:
            # explicit token lists: tokens that may start a number / continue a number
            number_tokens = tset("1","2","3","4","5","6","7","8","9","Ġ2","Ġ3","Ġ4","Ġ5","Ġ0","Ġ6","Ġ7","Ġ8","Ġ9","10","12","50","19","11","20","30","15","14","16","13","25","18","17","24","80","40","22","60","23","29","27","26","28","99","33","70","45","35","64","75","21","38","44","36","32","39","34","37","48","66","55","47","49","65","68","31","67","59","77","58","69","88","46","57","43","42","78","79","90","95","41","56","54","98","76","52","53","51","86","74","89","72","73","96","71","63","62","85","61","97","84","87","94","92","83","93","91","82","81", exact=True)
            number_continuation_tokens = tset("0","1","2","3","4","5","6","7","8","9","00","01","10","12","50","19","11","20","30","15","14","16","13","25","18","17","24","80","40","22","60","23","29","27","26","28","99","33","70","45","35","64","75","21","38","44","36","32","39","34","05","37","48","66","55","47","08","49","09","65","07","02","04","03","68","31","67","59","06","77","58","69","88","46","57","43","42","78","79","90","95","41","56","54","98","76","52","53","51","86","74","89","72","73","96","71","63","62","85","61","97","84","87","94","92","83","93","91","82","81", exact=True)
        else:
            # regex-based token sets for the same two roles
            number_tokens = tset("[ 1-9][0-9]*$", regex=True)
            number_continuation_tokens = tset("[0-9]+$", regex=True)

        # value is not being extended: the result is determined by forward()
        if not has_next_token:
            return fmap(
                ("eos", len(v.strip()) != 0),
                ("*", self.forward(v))
            )

        if len(v) == 0:
            # nothing generated yet: only number-starting tokens keep the result True
            return fmap(
                (number_tokens, True),
                ("*", False)
            )
        else:
            if len(v.strip()) == 0:
                # do not allow empty strings
                return fmap(
                    (number_continuation_tokens, True),
                    ("eos", False),
                    ("*", False)
                )

            # some non-whitespace text exists: digits may continue, or the value may end
            return fmap(
                (number_continuation_tokens, True),
                ("eos", True),
                ("*", False)
            )
        
    def final(self, x, operands=None, result=None, **kwargs):
        """Return "fin" for a False result on an "inc" operand, else defer to Node.final."""
        if result == False and x[0] == "inc":
            return "fin"
        return super().final(x, operands=operands, result=result, **kwargs)

@LMQLOp("TOKENS")
class TokensOp(Node):
    """Tokenizes a string value with the tokenizer of the runtime's model."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # forward() and final() receive the evaluation context as an extra argument
        self.depends_on_context = True

    def forward(self, x, context):
        """Return the tokens of ``x``.

        Returns None for None and ``[]`` for the empty string; otherwise the
        result of ``context.runtime.model.sync_tokenize(x)``.
        """
        if x is None:
            return None
        if x == "":
            return []

        return context.runtime.model.sync_tokenize(x)

    def follow(self, v, context=None, **kwargs):
        """Build the follow map for the (possibly still generating) value ``v``.

        The tokens of the text without the next-token marker are used; when
        the marker was present, ``NextToken`` is appended as a further element.
        """
        if v is None:
            return None

        stripped = strip_next_token(v)
        contains_next_token = v != stripped
        words = self.forward(stripped, context)

        return fmap(
            ("*", (words + [NextToken]) if contains_next_token else words)
        )

    def final(self, x, context, operands=None, result=None, **kwargs):
        """The finality of the result is the finality of the first operand."""
        return x[0]

@LMQLOp("WORDS")
class WordsOp(Node):
    """Splits a string value into its space-separated words."""

    def forward(self, x):
        """Return the non-blank pieces of ``x`` split at single spaces.

        Returns None for None and ``[]`` for the empty string.
        """
        if x is None:
            return None
        if x == "":
            return []

        words = []
        for piece in x.split(" "):
            if piece.strip() != "":
                words.append(piece)
        return words

    def follow(self, v, **kwargs):
        """Build the follow map for the (possibly still generating) value ``v``."""
        if v is None:
            return None

        stripped = strip_next_token(v)
        contains_next_token = v != stripped
        words = self.forward(stripped)

        components = []

        if len(words) > 0 and contains_next_token:
            # allow continuation sub-tokens
            continuation_tokens = tset("[^\u0120].*", regex=True)
            valid_continuations = union(tset("eos"), continuation_tokens)
            extended_last_word = words[:-1] + [words[-1] + NextToken]
            components.append((valid_continuations, extended_last_word))

        # default case: the next token (if any) forms a new word
        components.append(("*", (words + [NextToken]) if contains_next_token else words))

        return fmap(*components)

    def final(self, x, operands=None, result=None, **kwargs):
        """The finality of the result is the finality of the first operand."""
        return x[0]

@LMQLOp("len")
class LenOp(Node):
    """Length of a string, list or tuple value."""

    def forward(self, x):
        """Return ``len(x)``, or None when ``x`` is None."""
        return None if x is None else len(x)

    def follow(self, v, **kwargs):
        """Return the length of ``v``, or a follow map of possible lengths.

        Lists, tuples and strings without the next-token marker have a fixed
        length. For a string that contains the marker, one entry per token
        character length is produced, mapping the tokens of that length to
        the resulting total length.
        """
        if v is None:
            return None
        if type(v) in (list, tuple):
            return len(v)

        assert type(v) is str, "len() can only be applied to strings, lists, or tuples"
        if NextToken not in v:
            return len(v)

        base_length = len(strip_next_token(v))

        len_masks = []
        covered = "∅"
        char_count = 1
        while True:
            tmask = tset(charlen=char_count)
            covered = tmask.union(covered)
            # once 'covered' encompasses all possible tokens, all possible lengths are enumerated
            if len(covered) == VocabularyMatcher.instance().vocab_size:
                break
            if len(tmask) > 0:
                len_masks.append((tmask, base_length + char_count))
            char_count += 1

        return fmap(*len_masks)

    def final(self, x, operands=None, result=None, **kwargs):
        """The finality of the result is the finality of the first operand."""
        return x[0]

class NotOp(Node):
    """Logical negation of a single operand."""

    def forward(self, op):
        """Return the negated truth value of ``op``."""
        negated = not op
        return negated

    def follow(self, v, **kwargs):
        """Return the negated truth value of ``v``."""
        negated = not v
        return negated

class Lt(Node):
    """Less-than comparison ``a < b`` of two operands."""

    def forward(self, *args):
        """Return ``args[0] < args[1]``, or None if any operand is None."""
        for a in args:
            if a is None:
                return None
        return args[0] < args[1]

    def follow(self, *args, **kwargs):
        """Return ``args[0] < args[1]``, or None if any operand is None."""
        for a in args:
            if a is None:
                return None
        return args[0] < args[1]

    def final(self, ops, operands=None, result=None, **kwargs):
        """Look up the finality of the comparison from the operand finalities.

        The table is indexed ``[op2][op1]`` when ``result`` is truthy and
        ``[op1][op2]`` otherwise.
        """
        final_transition_indices = {"inc": 0, "dec": 1, "fin": 2, "var": 3}

        first = final_transition_indices[ops[0]]
        second = final_transition_indices[ops[1]]

        # a < b
        transition_table = [
            # "inc", "dec", "fin", "var"
            ["var", "fin", "fin", "var"],  # inc
            ["var", "var", "var", "var"],  # dec
            ["var", "fin", "fin", "var"],  # fin
            ["var", "var", "var", "var"],  # var
        ]

        row, column = (second, first) if result else (first, second)
        return transition_table[row][column]

def Gt(preds):
    """Build ``a > b`` as an ``Lt`` node over the operands in reverse order."""
    swapped = list(reversed(preds))
    return Lt(swapped)

class EqOp(Node):
    """Equality comparison of the operands."""

    def __init__(self, predecessors):
        super().__init__(predecessors)

    def forward(self, *args):
        """Return whether every operand equals the first one."""
        return all([a == args[0] for a in args])

    def follow(self, *args, **kwargs):
        """Return the comparison result, or a follow map if it depends on the next token.

        Only the first two operands take part in the next-token analysis.
        Returns None if either of them is None.
        """
        op1 = args[0]
        op2 = args[1]

        if op1 is None or op2 is None:
            return None

        if is_next_token(op1):
            if is_next_token(op2):
                return fmap(
                    ("*", True)
                )
            # op1 is exactly the next token: equal iff that token is op2
            return fmap(
                (op2, True),
                ("*", False)
            )
        if is_next_token(op2):
            # op1 cannot be the next token here, that case was handled above
            return fmap(
                (op1, True),
                ("*", False)
            )

        # The prefix analysis below is string-only (it relies on str methods),
        # so BOTH operands have to be strings. Previously op1 was checked
        # twice and op2 never, which crashed on non-string op2 values.
        if type(op1) is str and type(op2) is str:
            len1 = len(strip_next_token(op1))
            len2 = len(strip_next_token(op2))
            op_shorter = op1 if len1 < len2 else op2
            op_longer = op1 if len1 > len2 else op2

            # a still-generating shorter operand vs. a fixed longer operand:
            # treat it as membership in the one-element set [op_longer]
            if strip_next_token(op_longer) == op_longer and strip_next_token(op_shorter) != op_shorter:
                return InOpStrInSet([]).follow(op_shorter, [op_longer])

        return all([a == args[0] for a in args])

    def final(self, operand_final, operands=None, result=None, **kwargs):
        """Derive the finality of the comparison result.

        Non-string operands and truthy results use the default rule of
        ``Node.final``. For a falsy result on string operands, the result is
        "fin" as soon as the operands can no longer become equal.
        """
        if not all(type(o) is str for o in operands):
            return super().final(operand_final, operands=operands, result=result, **kwargs)

        if result: # if equal, then fin iff all operands are fin
            return super().final(operand_final, operands=operands, result=result, **kwargs)

        # NOTE(review): this compares the operand *values* with "fin";
        # presumably operand_final was meant -- kept as is, confirm.
        if all([o == "fin" for o in operands]):
            return "fin"

        # result is False

        # determine longest, fixed prefix of final value
        fixed_value = None
        for o, of in zip(operands, operand_final):
            if of not in ("fin", "inc"):
                continue
            if fixed_value is None or len(fixed_value) < len(o):
                fixed_value = o

        # check that each operand is a prefix of the fixed_value
        for o, of in zip(operands, operand_final):
            if of == "fin":
                if fixed_value != o: return "fin"
            elif of == "inc":
                if not fixed_value.startswith(o): return "fin"

        return super().final(operand_final, operands=operands, result=result, **kwargs)

class SelectOp(Node):
    """Index access ``l[idx]`` on a sequence operand."""

    def forward(self, *args):
        """Return ``args[0][args[1]]``.

        Returns None when either operand is None, or when the index is not
        smaller than the length of the sequence.
        """
        seq = args[0]
        idx = args[1]

        # undefined operands yield an undefined result instead of a TypeError
        if seq is None or idx is None:
            return None
        if len(seq) <= idx:
            return None
        return seq[idx]

    def follow(self, *args, **kwargs):
        """Return a follow map if ``idx`` addresses the last element, else None.

        Returns None when either operand is None.
        """
        l = args[0]
        idx = args[1]

        # a missing index used to raise a TypeError in `idx + 1`
        if l is None or idx is None:
            return None

        if len(l) == idx + 1:
            return fmap(
                ("eos", None),
                ("*", l[idx])
            )
        return None

    def final(self, ops, operands, result, **kwargs):
        """Derive the finality of the selected element."""
        l = ops[0]
        idx = ops[1]

        if idx != "fin":
            return "var"

        if l == "fin":
            return "fin"
        if result is not None and l == "inc":
            return "fin"
        return "var"

class Var(Node):
    """Reads the value of a named variable from the evaluation context."""

    def __init__(self, name):
        super().__init__([])
        self.name = name

        # forward() and final() receive the evaluation context
        self.depends_on_context = True

        # indicates whether the downstream node requires text diff information
        self.diff_aware_read = False

    async def json(self):
        return self.name

    def forward(self, context):
        """Return the variable value, paired with its text diff if requested."""
        value = context.get(self.name, None)
        if not self.diff_aware_read:
            return value
        return (value, context.get_diff(self.name, None))

    def follow(self, context, **kwargs):
        """Build the follow map of the variable, or None if it has no value."""
        value = context.get(self.name, None)
        if value is None:
            return None

        # also return the text diff if required
        if self.diff_aware_read:
            value = (value, context.get_diff(self.name, None))

        # strip_next_token but also supports tuples
        def strip_nt(v):
            if type(v) is tuple:
                text, diff = v[0], v[1]
                return (strip_next_token(text), diff)
            return strip_next_token(v)

        at_eos = PredeterminedFinal(strip_nt(value), "fin")
        return fmap(
            ("eos", at_eos),
            ("*", value),
        )

    def final(self, x, context, operands=None, result=None, **kwargs):
        """The finality of a variable is provided by the context."""
        return context.final(self.name)

    def __repr__(self) -> str:
        return "<Var {}>".format(self.name)

class RawValueOp(Node):
    """Node wrapping a fixed value together with a fixed finality label."""

    def __init__(self, args):
        super().__init__([])

        # args is a (value, final) pair
        raw_value, raw_final = args
        self.value = raw_value
        self.final_value = raw_final

    def forward(self):
        """Return the wrapped value."""
        return self.value

    def follow(self, **kwargs):
        """The wrapped value holds regardless of the next token."""
        unconditional = ("*", self.value)
        return fmap(unconditional)

    def final(self, args, operands=None, result=None, **kwargs):
        """Return the finality label given at construction."""
        return self.final_value

def matching_phrases_suffixes(x, allowed_phrases, allow_full_matches=False):
    """Yield what is still missing for ``x`` to become one of ``allowed_phrases``.

    The next-token marker is removed from ``x`` first. For every phrase that
    starts with the remaining text, the part of the phrase after that text
    is yielded. A phrase that equals the text yields ``""``, but only when
    ``allow_full_matches`` is set.
    """
    prefix = strip_next_token(x)
    consumed = len(prefix)

    for phrase in allowed_phrases:
        if not phrase.startswith(prefix):
            continue
        if len(phrase) > consumed:
            yield phrase[consumed:]
        elif allow_full_matches:
            yield ""

class InOpStrInStr(Node):
    """Substring test ``a in b`` for two string operands."""

    def forward(self, *args):
        """Return ``args[0] in args[1]``, or None if any operand is None."""
        for a in args:
            if a is None:
                return None

        needle, haystack = args[0], args[1]
        return needle in haystack

    def follow(self, *args, **kwargs):
        """Return the test result, or a follow map if it depends on the next token."""
        op1 = strip_next_token(args[0])
        op1_generating = args[0] != op1
        op2 = strip_next_token(args[1])
        op2_generating = args[1] != op2

        assert not op1_generating, "InOpStrInStr: left-hand side operand must not be generating"

        # if op2 is finished, then the result is fixed
        if not op2_generating:
            return op1 in op2

        # if op1 already contained
        if op1 in op2:
            return True

        # op1 is not contained in op2, so check for partial overlap with suffixes of op2:
        # the largest i for which the end of op2 equals the first i characters of op1
        overlap = max(
            (i for i in range(len(op1)) if op2[-i:] == op1[:i]),
            default=0,
        )
        missing = op1[overlap:]
        escaped = missing.replace(".", r"\.").replace("*", r"\*")

        allowed_subtokens = tset(f".*{escaped}.*", regex=True)
        other_tokens = setminus("*", allowed_subtokens)

        return fmap(
            (allowed_subtokens, True),
            (other_tokens, False)
        )

    def final(self, op_final, result=None, **kwargs):
        """A truthy result with a fixed left and an "inc" right operand is "fin"."""
        if result and op_final[1] == "inc" and op_final[0] == "fin":
            return "fin"
        return super().final(op_final, result=result, **kwargs)

class InOpStrInSet(Node):
    """Tests whether a string is (or can still become) one of a set of phrases."""

    def forward(self, *args):
        """Return whether any allowed phrase starts with ``args[0]``.

        Returns None if any operand is None.
        """
        if any(a is None for a in args):
            return None

        x, allowed_phrases = args[0], args[1]

        # any match is enough
        candidates = matching_phrases_suffixes(x, allowed_phrases, allow_full_matches=True)
        return any(True for _ in candidates)

    def follow(self, *args, **kwargs):
        """Return the test result, or a follow map if it depends on the next token."""
        if any(a is None for a in args):
            return None

        x, allowed_phrases = args[0], args[1]

        matches = list(matching_phrases_suffixes(x, allowed_phrases, allow_full_matches=True))
        has_full_match = "" in matches
        suffixes = [s for s in matches if s != ""]

        if len(suffixes) == 0:
            # nothing can be appended any more
            return has_full_match and not x.endswith(NextToken)

        full_remainders = [s + "$" for s in suffixes]
        if has_full_match:
            full_remainders.append("eos")

        return fmap(
            (tset(*full_remainders), True),
            (tset(*suffixes, prefix=True), PredeterminedFinal(False, "var")),
            ("*", PredeterminedFinal(False, "fin"))
        )

    def final(self, args, operands=None, result=None, **kwargs):
        """Derive the finality of the result from the finality of the first operand."""
        x_final = args[0]

        if result is None:
            return "var"

        if result == False:
            decisive = ("inc", "fin")
        else: # result == True
            decisive = ("fin",)

        return "fin" if x_final in decisive else "var"

# "in" operator node: the entries are tried in order, so two string operands
# select InOpStrInStr and everything else falls back to InOpStrInSet ("*").
InOp = DynamicTypeDispatch("InOp", (
    ((str, str), InOpStrInStr),
    ("*", InOpStrInSet),
))

class OrOp(Node):
    """Three-valued logical OR over the operands (True / False / None)."""

    def forward(self, *args):
        """Return True if any operand is True, False if all are False, else None."""
        if any(a == True for a in args):
            return True
        if all(a == False for a in args):
            return False
        return None

    def follow(self, *args, **kwargs):
        """The result does not depend on the next token."""
        value = self.forward(*args)
        return fmap(
            ("*", value)
        )

    def final(self, args, operands=None, result=None, **kwargs):
        """Derive the finality of the disjunction.

        A truthy result is "fin" once some operand is both True and "fin".
        A falsy result is "fin" unless some operand is still "var".
        """
        if result:
            for a, v in zip(args, operands):
                if a == "fin" and v == True:
                    return "fin"
            return "var"

        # not result
        for a in args:
            if a == "var":
                return "var"
        return "fin"

class AndOp(Node):
    """Three-valued logical AND over the operands (True / False / None)."""

    def forward(self, *args):
        """Return False if any operand is False, else None if any is None, else whether all are truthy.

        A single tuple argument is unpacked and treated as the operand list.
        """
        if type(args[0]) is tuple and len(args) == 1:
            args = args[0]

        if any(a == False for a in args):
            return False
        if any(a is None for a in args):
            return None
        return all(args)

    def follow(self, *v, **kwargs):
        """The result does not depend on the next token."""
        value = self.forward(*v)
        return fmap(
            ("*", value)
        )

    def final(self, args, operands=None, result=None, **kwargs):
        """Derive the finality of the conjunction.

        A truthy result is "fin" once all operands are "fin". A falsy result
        is "fin" once some operand is both False and "fin".
        """
        if result:
            return "fin" if all(a == "fin" for a in args) else "var"

        # not result
        for a, v in zip(args, operands):
            if a == "fin" and v == False:
                return "fin"
        return "var"

def seq_starts_with(seq1, seq2):
    """Return whether ``seq1`` begins with all elements of ``seq2``."""
    # count the positions (up to the shorter length) where both sequences agree
    agreeing = sum(1 for left, right in zip(seq1, seq2) if left == right)
    return agreeing == len(seq2)

def remainder(seq: str, phrase: str):
    """Return the part of ``phrase`` that is not yet covered by the end of ``seq``.

    The longest proper prefix of ``phrase`` that ``seq`` ends with is
    determined, and the rest of ``phrase`` after that prefix is returned.
    Returns None if ``seq`` does not end with any non-empty proper prefix of
    ``phrase``.
    """
    overlap = 0
    # only proper prefixes are considered (a complete match is not an overlap);
    # the last hit wins, so `overlap` ends up being the longest one
    for i in range(1, len(phrase)):
        if seq.endswith(phrase[:i]):
            overlap = i

    if overlap == 0:
        return None
    # cut at the overlap, not at the last value of the loop index
    return phrase[overlap:]

@LMQLOp("STARTS_WITH")
class StartsWithOp(Node):
    """Tests whether a string starts with one of a set of allowed phrases."""

    def forward(self, *args):
        """Return whether ``args[0]`` starts with any phrase in ``args[1]``.

        Returns None if any operand is None.
        """
        if any(a is None for a in args):
            return None

        x, allowed_phrases = args[0], args[1]
        return any(x.startswith(phrase) for phrase in allowed_phrases)

    def follow(self, *args, **kwargs):
        """Return the test result, or a follow map if it depends on the next token."""
        if any(a is None for a in args):
            return None

        x, allowed_phrases = args[0], args[1]

        # if there is any full match, then the result is True
        for phrase in allowed_phrases:
            if strip_next_token(x).startswith(phrase):
                return True

        # otherwise check for partial matches with the allowed phrases
        suffixes = list(matching_phrases_suffixes(x, allowed_phrases))
        if len(suffixes) == 0:
            return False

        return fmap(
            (tset(*suffixes), True),
            (ntset(*suffixes), False)
        )

    def final(self, ops_final, result=None, operands=None, **kwargs):
        """Derive the finality of the result from the operand finalities."""
        op1, op2 = ops_final[0], ops_final[1]

        if not (op1 == "inc" and op2 == "fin"):
            return super().final(ops_final, **kwargs)

        # a falsy result on a still empty value can change later
        if len(operands[0]) == 0 and not result:
            return "var"
        return "fin"

    # def final(self, args, operands=None, result=None, pattern: TokenSet=None, **kwargs):
    #     x_final = args[0]

    #     if result is None:
    #         return "var"
    #     elif result == False:
    #         x = operands[0]
    #         allowed_phrases = operands[1]

    #         if x_final == "inc":
    #             # check whether there are suffixes which could be matched
    #             suffixes = list(matching_phrases_suffixes(x, allowed_phrases))
    #             suffixes = [s for s in suffixes if pattern is None or pattern.starts_with(s)]
    #             # print("final() with x_final inc", x, allowed_phrases)
    #             if len(suffixes) > 0: return "var"
    #             else: return "fin"
    #         elif x_final == "fin":
    #             return "fin"
    #         else:
    #             return "var"
    #     else: # result == True
    #         if x_final == "inc" or x_final == "fin":
    #             return "fin"
    #         return "var"

@LMQLOp(["STOPS_AT", "stops_at"])
class StopAtOp(Node):
    """Constraint that a template variable stops right after a stopping phrase."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # tokenizer -> tokenized stopping phrase
        self._tokenized_stopping_phrase_cache = {}

    def execute_predecessors(self, trace, context):
        """Validate the operands, enable diff-aware reads on the variable, then evaluate."""
        var_op: Var = self.predecessors[0]
        assert type(var_op) is Var, "The first argument of STOPS_AT must be a direct reference to a template variable."
        assert type(self.predecessors[1]) is str, "The second argument of STOPS_AT must be a string literal."
        var_op.diff_aware_read = True
        return super().execute_predecessors(trace, context)

    @property
    def variable(self):
        return self.predecessors[0]

    @property
    def stopping_phrase(self):
        return self.predecessors[1]

    async def stopping_phrase_tokenized(self, tokenizer):
        """Return the stopping phrase tokenized by ``tokenizer`` (cached per tokenizer)."""
        cache = self._tokenized_stopping_phrase_cache
        if tokenizer not in cache:
            cache[tokenizer] = (await tokenizer(self.stopping_phrase))
        return cache[tokenizer]

    @staticmethod
    def _phrase_reaches_into_diff(text, diff, phrase):
        """Whether the last occurrence of ``phrase`` in ``text`` ends inside the trailing ``diff``."""
        matched_phrase_index = text.rfind(phrase)
        return matched_phrase_index != -1 and matched_phrase_index + len(phrase) > len(text) - len(diff)

    def forward(self, *args):
        """Return whether the value respects the stopping phrase.

        ``args[0]`` is a ``(text, diff)`` pair and ``args[1]`` the stopping
        phrase. Returns None if an operand or the text is None.
        """
        if any(a is None for a in args):
            return None

        op1, op1_diff = args[0]
        op2 = args[1]

        if op1 is None:
            return None
        if op1_diff is None:
            op1_diff = ""

        op2_in_op1 = self._phrase_reaches_into_diff(op1, op1_diff, op2)
        return not op2_in_op1 or op1.endswith(op2)

    def follow(self, *args, previous_result=None, **kwargs):
        """Build the follow map of the constraint (same operands as ``forward``)."""
        if any(a is None for a in args):
            return None

        raw_text, op1_diff = args[0]
        if raw_text is None:
            return None
        if op1_diff is None:
            op1_diff = ""

        op1 = strip_next_token(raw_text)
        op2 = args[1]

        if not self._phrase_reaches_into_diff(op1, op1_diff, op2):
            return fmap(("*", True))

        # ending with the phrase only counts if no further token follows
        ends_with_stopping_phrase = op1.endswith(op2) and not (op1 != raw_text)
        if len(op1) == 0:
            ends_with_stopping_phrase = True

        return fmap(("*", ends_with_stopping_phrase))

    def final(self, ops_final, operands, result, **kwargs):
        """Derive the finality of the result from the finality of the variable."""
        variable_final = ops_final[0]

        if result:
            return "var" if variable_final == "var" else "fin"

        # not result
        if variable_final == "var" or variable_final == "dec":
            return "var"
        return "fin"

class OpaqueLambdaOp(Node):
    """Applies a callable (first operand) to the remaining operands."""

    def forward(self, *args):
        """Return ``fct(*rest)``, or None if any operand is None."""
        if any(a is None for a in args):
            return None

        fct, *call_args = args
        return fct(*call_args)

    def follow(self, *v, **kwargs):
        """Return the call result as a follow map, or None if any operand is None."""
        if any(a is None for a in v):
            return None

        fct, *call_args = v
        outcome = fct(*call_args)
        return fmap(
            ("*", outcome)
        )

def create_mask(follow_map, valid, final):
    """Derive the set of allowed next tokens from a follow map.

    Returns "*" when there is no follow map, or when no token was collected
    and the "*" entry does not force a final falsy result. Returns
    ``tset("eos")`` when only end-of-sequence remains possible. Otherwise
    returns the collected token set.

    NOTE(review): ``valid`` is not used in this function, and ``final`` is
    overwritten inside the loop before it is ever read.
    """
    if follow_map is None:
        return "*"
    
    allowed_tokens = tset()
    # result registered for the "*" (otherwise) pattern, if any
    otherwise_result = None

    for pattern, result in follow_map:
        if pattern == "*":
            otherwise_result = result

        # each result is a (value, final) pair; None counts as (None, "var")
        if result is not None:
            value, final = result
        else:
            value = None
            final = "var"

        if value == True or value is None:
            # tokens leading to a True or undetermined value are allowed
            allowed_tokens = union(allowed_tokens,pattern)
        elif value == False and final == ("var",):
            # False, but not final: still allowed
            allowed_tokens = union(allowed_tokens, pattern)
        elif value is None and len(follow_map.components) == 1:
            # NOTE(review): looks unreachable, `value is None` is already
            # handled by the first branch -- confirm
            allowed_tokens = "*"
        elif result == (False, ('fin',)):
            # definitively False: remove these tokens again
            if pattern != "*":
                allowed_tokens = setminus(allowed_tokens, pattern)

    if allowed_tokens == "∅":
        return tset("eos")

    if len(allowed_tokens) == 0:
        if otherwise_result is not None:
            othw_value, othw_final = otherwise_result
        else:
            othw_value, othw_final = None, "var"
        if not othw_value and othw_final == ("fin",):
            return tset("eos")
        else:
            return "*"

    return allowed_tokens

def is_node(op):
    """Return whether ``op`` is an operation node (as opposed to a constant value)."""
    op_type = type(op)
    return issubclass(op_type, Node)

def derive_final(op, trace, context, result):
    """Ask ``op`` for the finality of ``result``.

    For every predecessor, the value and the finality are collected: nodes
    are looked up in ``trace`` (which stores ``(value, final)`` per node),
    constants stand for themselves and are always "fin". ``op.final`` then
    receives the finalities, the context if the op depends on it, and the
    values and the result as keyword arguments.
    """
    predecessor_final = []
    predecessor_values = []

    for p in op.predecessors:
        if is_node(p):
            # for nodes, take value and final value from the trace
            entry = trace[p]
            predecessor_final.append(entry[1])
            predecessor_values.append(entry[0])
        else:
            # for constants, final value is always "fin"
            predecessor_final.append("fin")
            predecessor_values.append(p)

    context_arg = (context,) if op.depends_on_context else ()

    return op.final(predecessor_final, *context_arg, operands=predecessor_values, result=result)

def execute_op_stops_at_only(op: Node, result=None):
    """
    Collects the StopAtOps defined for the query rooted at ``op``.

    Only AndOp and OrOp nodes are descended into. Found ops are appended to
    ``result`` (a fresh list if none is given), which is returned. For any
    other kind of op, a new empty list is returned.
    """
    if result is None:
        result = []

    op_type = type(op)

    if op_type is StopAtOp:
        result.append(op)
        return result

    # TODO: actually STOPS_AT in OR is not really supported yet
    if op_type is AndOp or op_type is OrOp:
        for p in op.predecessors:
            execute_op_stops_at_only(p, result=result)
        return result

    # other ops are no-ops from a STOPS_AT perspective (cannot contain additional STOPS_AT ops)
    # TODO: what about not
    return []

def execute_op(op: Node, trace=None, context=None, return_final=False):
    """Evaluate ``op`` and return its value.

    Args:
        op: the node to evaluate; any non-node value is a constant and is
            returned as is.
        trace: dict mapping already evaluated nodes to ``(value, final)``.
            It is filled during evaluation; a fresh dict is used when omitted.
        context: evaluation context, passed on to the predecessors and to
            ops that depend on it.
        return_final: if True, return ``(value, final)`` instead of the value.
    """
    # for constant dependencies, just return their value
    if not is_node(op):
        return op

    # the default trace=None used to fail on the lookup below
    if trace is None:
        trace = {}

    # only evaluate each operation once
    if op in trace:
        cached_result, cached_final = trace[op]
        # honour return_final for cached results as well
        if return_final:
            return cached_result, cached_final
        return cached_result

    # compute predecessor values
    inputs = op.execute_predecessors(trace, context)

    if op.depends_on_context:
        inputs += (context,)

    result = op.forward(*inputs)
    is_final = derive_final(op, trace, context, result)

    trace[op] = (result, is_final)

    if return_final:
        return result, is_final

    return result

def digest(expr, context, follow_context, no_follow=False):
    """Evaluate ``expr`` and compute the follow map of every evaluated node.

    Returns ``(value, final, trace)``. A missing expression yields
    ``(True, "fin", {})``. With ``no_follow`` set, the follow maps are not
    computed.
    """
    if expr is None:
        return True, "fin", {}

    trace = {}
    expr_value, is_final = execute_op(expr, trace=trace, context=context, return_final=True)

    if no_follow:
        return expr_value, is_final, trace

    def predecessor_follow_map(p):
        # use * -> value, for constant value predecessor nodes
        if is_node(p):
            return p.follow_map
        return fmap(("*", (p, ("fin",))))

    for op, value in trace.items():
        # determine follow map of predecessors
        if len(op.predecessors) == 0:
            # empty argtuple translates to no follow input
            intm = all_fmap((ArgTuple(), ["fin"]))
        else:
            intm = fmap_product(*[predecessor_follow_map(p) for p in op.predecessors])

        # apply follow map and store it on the node
        op.follow_map = follow_apply(intm, op, value, context=follow_context)

    return expr_value, is_final, trace
