#
# This file is automatically generated by gen_ast.py
#
# Do not edit by hand!
#
              
# First, we have a complete textual copy of ast_base.py (which should never
# be imported directly in normal use)
              
# ----- Start of ast_base.py copy -----

from enum import Enum, auto
from .tokenizer import *
import re

class Walk(Enum):
    """Events produced while traversing the AST with ``walk``.

    ENTERING and LEAVING bracket a node whose children are walked.
    VISITING is emitted for leaves and for raw textX nodes (the latter are
    still followed by LEAVING). SKIP is the value a consumer may ``send()``
    back in reply to ENTERING to prune that node's children and its LEAVING
    event.
    """
    ENTERING = 1
    VISITING = 2
    LEAVING = 3
    SKIP = 4

# The ZX Spectrum BASIC Grammar is found in spectrum_basic.tx

# Operator binding strengths: a higher number binds more tightly.
# Operators that share a level are grouped with dict.fromkeys; insertion
# order is kept the same as a flat listing.
BINARY_PRECEDENCE = {
    'OR': 2,
    'AND': 3,
    **dict.fromkeys(('=', '<', '>', '<=', '>=', '<>'), 5),
    **dict.fromkeys(('+', '-'), 6),
    **dict.fromkeys(('*', '/'), 8),
    '^': 10,
}

# Prefix operators: unary minus binds tighter than * and /, NOT only
# tighter than AND/OR.
UNARY_PRECEDENCE = dict([('-', 9), ('NOT', 4)])

def precedence(expr):
    """Return the binding strength of *expr*'s operator.

    BinaryOp nodes are looked up in BINARY_PRECEDENCE and UnaryOp nodes in
    UNARY_PRECEDENCE (a KeyError propagates for an unknown operator).
    Anything else is not an operator and yields 0.
    """
    if isinstance(expr, BinaryOp):
        table = BINARY_PRECEDENCE
    elif isinstance(expr, UnaryOp):
        table = UNARY_PRECEDENCE
    else:
        table = None
    return 0 if table is None else table[expr.op]

def is_complex(expr):
    """Report whether *expr* should be bracketed when it is a lone function argument.

    Only binary operations currently count as complex; every other node
    (literals, variables, unary operations, calls) is left bare.
    """
    # Extend this test if more node kinds ever need brackets here.
    return isinstance(expr, BinaryOp)

def needs_parens(expr, parent=None, is_rhs=False):
    """Return True when *expr* must be parenthesized as an operand of *parent*.

    Args:
        expr: the candidate operand. Only BinaryOp/UnaryOp nodes can ever
            need brackets; anything else returns False.
        parent: the operator node that will contain *expr*, or None when
            *expr* is at the top level (never bracketed).
        is_rhs: True when *expr* is the right-hand operand of a binary
            *parent*.

    Rules (precedences come from ``precedence``):
      * an operand that binds more loosely than its parent is bracketed;
      * at equal precedence, both operands of ``^`` are bracketed;
      * at equal precedence, a right operand is bracketed unless the parent
        is one of the associative operators ``+``, ``*``, ``AND``, ``OR``.
        Mixed forms such as ``a + (b - c)`` still print as ``a + b - c``.
        This keeps the old handling of ``-`` and ``/`` and also covers the
        comparison operators. Without it, ``a = (b = c)`` would be printed
        as ``a = b = c``, which reads as ``(a = b) = c``.
    """
    if not isinstance(expr, (BinaryOp, UnaryOp)):
        return False

    if parent is None:
        return False

    expr_prec = precedence(expr)
    parent_prec = precedence(parent)

    # Looser-binding operand: always bracket.
    if expr_prec < parent_prec:
        return True

    if expr_prec == parent_prec:
        # Power: bracket both sides at equal precedence.
        if parent.op == '^':
            return True
        # Left-to-right evaluation: an equal-precedence right operand can
        # only lose its brackets under an associative parent.
        if is_rhs and parent.op not in {'+', '*', 'AND', 'OR'}:
            return True

    return False

# Rather than a visitor pattern, we use a generator-based approach with
# a walk function that yields “visit events” for each node in the tree

def walk(obj):
    """Yield ``(Walk, node)`` events for *obj* and everything beneath it.

    * None produces nothing.
    * Lists and tuples are flattened, and dicts contribute their values;
      these containers produce no events of their own.
    * str/int/float values produce a single VISITING event.
    * Objects with a ``walk`` method delegate to it.
    * Raw textX nodes (identified by ``_tx_attrs``) produce VISITING, then
      events for each attribute listed in ``_tx_attrs``, then LEAVING.
    * Anything else produces a single VISITING event.
    """
    if obj is None:
        return

    if isinstance(obj, (list, tuple, dict)):
        members = obj.values() if isinstance(obj, dict) else obj
        for member in members:
            yield from walk(member)
        return

    if isinstance(obj, (str, int, float)):
        yield (Walk.VISITING, obj)
        return

    if hasattr(obj, "walk"):
        yield from obj.walk()
        return

    yield (Walk.VISITING, obj)
    # raw AST nodes have a _tx_attrs attribute whose keys are the names of the attributes
    if hasattr(obj, "_tx_attrs"):
        for attr_name in obj._tx_attrs:
            yield from walk(getattr(obj, attr_name))
        yield (Walk.LEAVING, obj)

# Classes for the BASIC language

def sane_bytes(s):
    """Convert *s* to ``bytes`` without requiring an explicit encoding.

    * ``str`` is encoded as ASCII (UnicodeEncodeError for other text).
    * ``None`` becomes ``b""``.
    * ``bytes``/``bytearray`` are returned as ``bytes`` directly.
    * Any other iterable is converted item by item (recursively) and the
      pieces are concatenated.
    * ``int`` becomes a single byte (OverflowError outside 0-255).
    * Anything else goes through ``bytes(s)``, i.e. the object's
      ``__bytes__``.
    """
    if isinstance(s, str):
        return s.encode('ascii')
    if s is None:
        return b""
    if isinstance(s, (bytes, bytearray)):
        # Fast path for data that is already binary (token bytes, literals
        # such as b'('). These used to fall into the iterable branch below
        # and be rebuilt one byte at a time, with the same result.
        return bytes(s)
    if hasattr(s, '__iter__'):
        return b"".join(sane_bytes(item) for item in s)
    if isinstance(s, int):
        return s.to_bytes(1, 'big')
    return bytes(s)

def bjoin(items, sep=b""):
    """Concatenate *items* as bytes, placing *sep* between consecutive items.

    Each item is normalized with ``sane_bytes`` first, so strings, ints,
    None, nested iterables and AST nodes can be mixed freely.
    """
    converted = map(sane_bytes, items)
    return sep.join(converted)

def sjoin(items, sep=""):
    """Join ``str()`` of every element of *items*, separated by *sep*."""
    return sep.join(map(str, items))

class ASTNode:
    """Common base for the hand-written (non-textX) AST node classes.

    ``repr()`` is the same as ``str()``. The default ``walk`` reports the
    node as a single VISITING event and does not descend any further.
    """
    def __repr__(self):
        return str(self)

    def walk(self):
        """Default traversal: one VISITING event for this node, nothing else."""
        event = (Walk.VISITING, self)
        yield event

class Statement(ASTNode):
    """Base class for every BASIC statement node; adds no behaviour of its own."""

class BuiltIn(Statement):
    """Represents simple built-in commands with fixed argument patterns.

    ``action`` is stored upper-cased. ``args`` holds the positional
    arguments, where ``None`` marks an omitted optional argument that is
    skipped when rendering. ``sep`` separates arguments in text form; its
    stripped ASCII form separates them in the byte form.

    ``is_expr`` starts as False (statement form: ``ACTION a, b``). When it
    is True (see ``is_expression``), the node renders in function form:
    ``ACTION x`` / ``ACTION (x)`` for one argument, ``ACTION(a, b)`` for
    several.
    """
    def __init__(self, parent, action, *args, sep=", "):
        self.parent = parent
        self.action = action.upper()
        self.args = args
        self.is_expr = False
        self.sep = sep

    def _present_args(self):
        """Return the arguments that were actually supplied (non-None), in order."""
        return [arg for arg in self.args if arg is not None]

    def __str__(self):
        present_args = self._present_args()
        if not present_args:
            # Also covers "all optional args omitted", which used to leave a
            # trailing space in statement form.
            return self.action

        arg_strs = [str(arg) for arg in present_args]
        if not self.is_expr:
            return f"{self.action} {self.sep.join(arg_strs)}"

        if len(present_args) == 1:
            # Single argument: only bracket it if it is complex. Test the
            # argument actually rendered, not blindly args[0].
            if is_complex(present_args[0]):
                return f"{self.action} ({arg_strs[0]})"
            return f"{self.action} {arg_strs[0]}"
        return f"{self.action}({self.sep.join(arg_strs)})"

    def walk(self):
        """Walk method for built-in commands"""
        if (yield (Walk.ENTERING, self)) == Walk.SKIP: return
        yield from walk(self.args)
        yield (Walk.LEAVING, self)

    def __bytes__(self):
        """Return the in-memory representation of the command.

        Uses the same present-argument counting as ``__str__`` so the text
        and byte forms always choose the same shape.
        """
        btoken = token_to_byte(self.action)
        bsep = self.sep.strip().encode('ascii')
        present_args = self._present_args()
        if not self.is_expr:
            return bjoin([btoken, bjoin(present_args, sep=bsep)])
        if not present_args:
            return btoken
        if len(present_args) == 1:
            the_arg = present_args[0]
            if is_complex(the_arg):
                return bjoin([btoken, b'(', the_arg, b')'])
            return bjoin([btoken, the_arg])
        return bjoin([btoken, b'(', bjoin(present_args, sep=bsep), b')'])

class ColouredBuiltin(BuiltIn):
    """Built-in command that may carry leading colour items.

    Text form: ``ACTION c1; c2; a, b``. The colour list is followed by
    ``;`` and omitted entirely when empty. ``None`` arguments are skipped,
    as in ``BuiltIn``.
    """
    def __init__(self, parent, action, colours, *args):
        super().__init__(parent, action, *args)
        self.colours = colours or []

    def __str__(self):
        parts = [self.action]
        if self.colours:
            colour_strs = [str(c) for c in self.colours]
            parts.append(" ")
            parts.append("; ".join(colour_strs))
            parts.append(";")
        present_args = [arg for arg in self.args if arg is not None]
        if present_args:
            # A space is needed after the keyword or after the colour list's
            # trailing ';'. It used to be emitted only in the colour case,
            # which glued the keyword to its first argument.
            parts.append(" ")
            parts.append(self.sep.join(map(str, present_args)))
        return "".join(parts)

    def walk(self):
        """Walk method for coloured built-in commands"""
        if (yield (Walk.ENTERING, self)) == Walk.SKIP: return
        yield from walk(self.colours)
        yield from walk(self.args)
        yield (Walk.LEAVING, self)

    def __bytes__(self):
        """Return the in-memory representation of the command.

        Omitted (None) arguments are dropped before joining, matching
        ``__str__``, so no stray separator is emitted for them.
        """
        bparts = [token_to_byte(self.action)]
        if self.colours:
            bparts.append(bjoin(self.colours, sep=b";"))
            bparts.append(b";")
        present_args = [arg for arg in self.args if arg is not None]
        if present_args:
            bparts.append(bjoin(present_args, sep=self.sep.strip().encode('ascii')))
        return bjoin(bparts)
        

def nstr(obj, sep="", none=""):
    "Format obj followed by sep, or return none when obj is None"
    return none if obj is None else "{}{}".format(obj, sep)

def speccy_quote(s):
    """Quote a string in ZX Spectrum BASIC format.

    Embedded double quotes are doubled, the result is passed through the
    tokenizer's ``escapes_to_unicode``, and the whole is wrapped in double
    quotes.
    """
    body = escapes_to_unicode(s.replace('"', '""'))
    return '"{}"'.format(body)


# Expression classes

class Expression(ASTNode):
    """Base class for expression nodes; ``is_expression`` checks for it."""

def is_expression(obj):
    """Determine if an object is an expression.

    True for any Expression node. For a BuiltIn, returns its ``is_expr``
    flag (function-style built-ins). False for everything else.
    """
    if isinstance(obj, Expression):
        return True
    return isinstance(obj, BuiltIn) and obj.is_expr


# ----- End of ast_base.py copy -----
              
# Automagically generated code for the AST classes

class Program(ASTNode):
    """Root node holding the program's lines in order."""
    def __init__(self, lines):
        self.lines = lines
    def __str__(self):
        """Render every line as text, separated by newlines."""
        return "\n".join(map(str, self.lines))
    def __bytes__(self):
        """Concatenate the in-memory encodings of all lines (no separator)."""
        return bjoin(self.lines)
    def walk(self):
        """Emit ENTERING, events for each line, then LEAVING (unless SKIP is sent)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        yield from walk(self.lines)
        yield (Walk.LEAVING, self)

class SourceLine(ASTNode):
    """One source line: optional line number and/or label, statements, trailing items.

    ``statements`` are separated by ':' in both forms. ``after`` is rendered
    with ``sjoin``/``bjoin`` after the statements (presumably stray trailing
    characters; confirm against the grammar). The label is text-only and is
    not part of the byte encoding.
    """
    def __init__(self, parent, line_number, label, statements, after):
        self.parent = parent
        self.line_number = line_number
        self.label = label
        self.statements = statements
        self.after = after
    def __bytes__(self):
        """Return the in-memory representation of a SourceLine node"""
        body = bjoin([bjoin(self.statements, b':'), self.after])
        return bjoin([line_to_bytes(self.line_number, body)])
    def walk(self):
        """Walk method for SourceLine nodes"""
        if (yield (Walk.ENTERING, self)) == Walk.SKIP: return
        yield from walk(self.line_number)
        yield from walk(self.label)
        yield from walk(self.statements)
        yield from walk(self.after)
        yield (Walk.LEAVING, self)

    def __str__(self):
        str_statements = ": ".join(str(stmt) for stmt in self.statements)
        after = sjoin(self.after)
        if self.line_number and self.label:
            return f"{self.line_number} {self.label}: {str_statements}{after}"
        if self.line_number:
            return f"{self.line_number}\t{str_statements}{after}"
        if self.label:
            # Short labels are tab-aligned with numbered lines; long ones
            # just get a single space.
            gap = "\t" if len(self.label.name) < 6 else " "
            return f"{self.label}:{gap}{str_statements}{after}"
        # Neither a number nor a label: indent with a tab. `after` is now
        # included here too; this branch used to drop it.
        return f"\t{str_statements}{after}"

class CommentLine(ASTNode):
    """Whole-line comment: its marker character plus text; encodes to no bytes."""
    def __init__(self, parent, char, comment):
        self.parent = parent
        self.char = char
        self.comment = comment
    def __str__(self):
        """Marker character immediately followed by the comment text."""
        return "{}{}".format(self.char, self.comment)
    def __bytes__(self):
        """Comments are not stored in the in-memory program."""
        return b""
    def walk(self):
        """A comment is a leaf: a single VISITING event."""
        yield Walk.VISITING, self

class JankyStatement(Statement):
    """A statement wrapped in stray characters before and/or after it.

    ``before``/``after`` are sequences of junk items, encoded with
    ``echars_to_bytes`` and rendered as text with ``sjoin``. ``actual`` is
    the real statement, or None.
    """
    def __init__(self, parent, before, actual, after):
        self.parent = parent
        self.before = before
        self.actual = actual
        self.after = after
    def __str__(self):
        """Return a string representation of a JankyStatement node"""
        # Previously referenced an undefined name `junk` (NameError).
        return f"{sjoin(self.before)}{nstr(self.actual)}{sjoin(self.after)}"
    def __bytes__(self):
        """Return the in-memory representation of a JankyStatement node"""
        return bjoin([echars_to_bytes(j) for j in self.before] + [self.actual] + [echars_to_bytes(j) for j in self.after])
    def walk(self):
        """Walk method for JankyStatement nodes"""
        if (yield (Walk.ENTERING, self)) == Walk.SKIP: return
        yield from walk(self.before)
        yield from walk(self.actual)
        yield from walk(self.after)
        yield (Walk.LEAVING, self)

class JankyFunctionExpr(Statement):
    """A function-style item wrapped in stray characters before and/or after it.

    ``before``/``after`` are sequences of junk items, encoded with
    ``echars_to_bytes`` and rendered as text with ``sjoin``. ``actual`` is
    the wrapped node, or None.
    """
    def __init__(self, parent, before, actual, after):
        self.parent = parent
        self.before = before
        self.actual = actual
        self.after = after
    def __str__(self):
        """Return a string representation of a JankyFunctionExpr node"""
        # Previously referenced an undefined name `junk` (NameError).
        return f"{sjoin(self.before)}{nstr(self.actual)}{sjoin(self.after)}"
    def __bytes__(self):
        """Return the in-memory representation of a JankyFunctionExpr node"""
        return bjoin([echars_to_bytes(j) for j in self.before] + [self.actual] + [echars_to_bytes(j) for j in self.after])
    def walk(self):
        """Walk method for JankyFunctionExpr nodes"""
        if (yield (Walk.ENTERING, self)) == Walk.SKIP: return
        yield from walk(self.before)
        yield from walk(self.actual)
        yield from walk(self.after)
        yield (Walk.LEAVING, self)

class Let(Statement):
    """``LET var = expr`` assignment."""
    def __init__(self, parent, var, expr):
        self.parent = parent
        self.var = var
        self.expr = expr
    def __str__(self):
        """Text form: ``LET <var> = <expr>``."""
        return "LET {} = {}".format(self.var, self.expr)
    def __bytes__(self):
        """LET token, then the variable, '=', and the expression."""
        token = token_to_byte('LET')
        return token + bjoin([self.var, b'=', self.expr])
    def walk(self):
        """ENTERING, then var and expr, then LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        for field in ("var", "expr"):
            yield from walk(getattr(self, field))
        yield (Walk.LEAVING, self)

class For(Statement):
    """``FOR var = start TO end [STEP step]`` loop header."""
    def __init__(self, parent, var, start, end, step):
        self.parent = parent
        self.var = var
        self.start = start
        self.end = end
        self.step = step
    def __str__(self):
        """Text form; the STEP clause appears only when ``step`` is truthy."""
        text = f"FOR {self.var} = {self.start} TO {self.end}"
        if self.step:
            text += f" STEP {self.step}"
        return text
    def __bytes__(self):
        """FOR token, var '=' start, TO token, end, optional STEP token and step."""
        token = token_to_byte('FOR')
        parts = [self.var, b'=', self.start, token_to_byte('TO'), self.end]
        if self.step:
            parts += [token_to_byte('STEP'), self.step]
        return token + bjoin(parts)
    def walk(self):
        """ENTERING, then var/start/end/step, then LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        for field in ("var", "start", "end", "step"):
            yield from walk(getattr(self, field))
        yield (Walk.LEAVING, self)

class Next(Statement):
    """``NEXT var`` loop terminator."""
    def __init__(self, parent, var):
        self.parent = parent
        self.var = var
    def __str__(self):
        """Text form: ``NEXT <var>``."""
        return "NEXT {}".format(self.var)
    def __bytes__(self):
        """NEXT token followed by the variable."""
        token = token_to_byte('NEXT')
        return token + bjoin([self.var])
    def walk(self):
        """ENTERING, the variable, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        yield from walk(self.var)
        yield (Walk.LEAVING, self)

class If(Statement):
    """Single-line ``IF condition THEN statements`` with trailing items."""
    def __init__(self, parent, condition, statements, after):
        self.parent = parent
        self.condition = condition
        self.statements = statements
        self.after = after
    def __str__(self):
        """Text form; statements after THEN are separated by ': '."""
        body = ": ".join(map(str, self.statements))
        return f"IF {self.condition} THEN {body}{sjoin(self.after)}"
    def __bytes__(self):
        """IF token, condition, THEN token, ':'-joined statements, trailing items."""
        token = token_to_byte('IF')
        parts = [
            self.condition,
            token_to_byte('THEN'),
            bjoin(self.statements, sep=b':'),
            bjoin(self.after),
        ]
        return token + bjoin(parts)
    def walk(self):
        """ENTERING, condition/statements/after, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        for field in ("condition", "statements", "after"):
            yield from walk(getattr(self, field))
        yield (Walk.LEAVING, self)

class LongIf(Statement):
    """``IF condition`` header with no THEN clause on the same node."""
    def __init__(self, parent, condition):
        self.parent = parent
        self.condition = condition
    def __str__(self):
        """Text form: ``IF <condition>``."""
        return "IF {}".format(self.condition)
    def __bytes__(self):
        """IF token followed by the condition."""
        token = token_to_byte('IF')
        return token + bjoin([self.condition])
    def walk(self):
        """ENTERING, the condition, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        yield from walk(self.condition)
        yield (Walk.LEAVING, self)

class ElseIf(Statement):
    """``ELSE IF condition`` clause."""
    def __init__(self, parent, condition):
        self.parent = parent
        self.condition = condition
    def __str__(self):
        """Text form: ``ELSE IF <condition>``."""
        return "ELSE IF {}".format(self.condition)
    def __bytes__(self):
        """ELSE token, IF token, then the condition."""
        token = token_to_byte('ELSE')
        return token + bjoin([token_to_byte('IF'), self.condition])
    def walk(self):
        """ENTERING, the condition, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        yield from walk(self.condition)
        yield (Walk.LEAVING, self)

class Else(Statement):
    """``ELSE statements`` clause with trailing items."""
    def __init__(self, parent, statements, after):
        self.parent = parent
        self.statements = statements
        self.after = after
    def __str__(self):
        """Text form; statements separated by ': '."""
        body = ": ".join(map(str, self.statements))
        return f"ELSE {body}{sjoin(self.after)}"
    def __bytes__(self):
        """ELSE token, ':'-joined statements, trailing items."""
        token = token_to_byte('ELSE')
        return token + bjoin([bjoin(self.statements, sep=b':'), bjoin(self.after)])
    def walk(self):
        """ENTERING, statements and after, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        for field in ("statements", "after"):
            yield from walk(getattr(self, field))
        yield (Walk.LEAVING, self)

class EndIf(Statement):
    """``ENDIF`` terminator; ``keyword`` is kept only for walking."""
    def __init__(self, parent, keyword):
        self.parent = parent
        self.keyword = keyword
    def __str__(self):
        """Always the literal text ``ENDIF``."""
        return "ENDIF"
    def __bytes__(self):
        """Just the ENDIF token."""
        return token_to_byte('ENDIF')
    def walk(self):
        """ENTERING, the keyword, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        yield from walk(self.keyword)
        yield (Walk.LEAVING, self)

class Repeat(Statement):
    """``REPEAT`` loop opener; ``keyword`` is kept only for walking."""
    def __init__(self, parent, keyword):
        self.parent = parent
        self.keyword = keyword
    def __str__(self):
        """Always the literal text ``REPEAT``."""
        return "REPEAT"
    def __bytes__(self):
        """Just the REPEAT token."""
        return token_to_byte('REPEAT')
    def walk(self):
        """ENTERING, the keyword, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        yield from walk(self.keyword)
        yield (Walk.LEAVING, self)

class Until(Statement):
    """``REPEAT UNTIL condition`` loop closer."""
    def __init__(self, parent, condition):
        self.parent = parent
        self.condition = condition
    def __str__(self):
        """Text form: ``REPEAT UNTIL <condition>``."""
        return "REPEAT UNTIL {}".format(self.condition)
    def __bytes__(self):
        """REPEAT token, UNTIL token, then the condition."""
        token = token_to_byte('REPEAT')
        return token + bjoin([token_to_byte('UNTIL'), self.condition])
    def walk(self):
        """ENTERING, the condition, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        yield from walk(self.condition)
        yield (Walk.LEAVING, self)

class While(Statement):
    """``WHILE condition`` clause."""
    def __init__(self, parent, condition):
        self.parent = parent
        self.condition = condition
    def __str__(self):
        """Text form: ``WHILE <condition>``."""
        return "WHILE {}".format(self.condition)
    def __bytes__(self):
        """WHILE token followed by the condition."""
        token = token_to_byte('WHILE')
        return token + bjoin([self.condition])
    def walk(self):
        """ENTERING, the condition, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        yield from walk(self.condition)
        yield (Walk.LEAVING, self)

class Exit(Statement):
    """One or more chained EXIT keywords, optionally followed by a line target."""
    def __init__(self, parent, exits, line):
        self.parent = parent
        self.exits = exits
        self.line = line
    def __str__(self):
        """Exit keywords joined by ':', then ' <line>' when a line is given."""
        text = ":".join(self.exits)
        return f"{text} {self.line}" if self.line else text
    def __bytes__(self):
        """One EXIT token per exit, ':'-separated, then the optional line."""
        token = token_to_byte('EXIT')
        extra = [b':', token_to_byte('EXIT')] * (len(self.exits) - 1)
        if self.line:
            extra.append(self.line)
        return token + bjoin(extra)
    def walk(self):
        """ENTERING, exits and line, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        for field in ("exits", "line"):
            yield from walk(getattr(self, field))
        yield (Walk.LEAVING, self)

class ContinueLoop(Statement):
    """Loop-continue statement: text ``GOTO`` followed by the ``nexts`` keywords."""
    def __init__(self, parent, nexts):
        self.parent = parent
        self.nexts = nexts
    def __str__(self):
        """``GOTO`` then the space-separated ``nexts`` entries."""
        return "GOTO " + " ".join(self.nexts)
    def __bytes__(self):
        """CONTINUELOOP token, then the token for each ``nexts`` entry."""
        token = token_to_byte('CONTINUELOOP')
        return token + bjoin(map(token_to_byte, self.nexts))
    def walk(self):
        """ENTERING, the nexts, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        yield from walk(self.nexts)
        yield (Walk.LEAVING, self)

class Dim(Statement):
    """``DIM name(d1, d2, ...)`` array declaration."""
    def __init__(self, parent, name, dims):
        self.parent = parent
        self.name = name
        self.dims = dims
    def __str__(self):
        """Text form with dimensions separated by ', '."""
        dims_text = ", ".join(map(str, self.dims))
        return f"DIM {self.name}({dims_text})"
    def __bytes__(self):
        """DIM token, name, then ','-joined dimensions in brackets."""
        token = token_to_byte('DIM')
        return token + bjoin([self.name, b'(', bjoin(self.dims, sep=b','), b')'])
    def walk(self):
        """ENTERING, name and dims, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        for field in ("name", "dims"):
            yield from walk(getattr(self, field))
        yield (Walk.LEAVING, self)

class Data(Statement):
    """``DATA v1, v2, ...`` statement."""
    def __init__(self, parent, items):
        self.parent = parent
        self.items = items
    def __str__(self):
        """Text form with items separated by ', '."""
        return "DATA " + ", ".join(map(str, self.items))
    def __bytes__(self):
        """DATA token followed by the ','-joined items."""
        token = token_to_byte('DATA')
        return token + bjoin(self.items, sep=b',')
    def walk(self):
        """ENTERING, the items, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        yield from walk(self.items)
        yield (Walk.LEAVING, self)

class Read(Statement):
    """``READ v1, v2, ...`` statement."""
    def __init__(self, parent, vars):
        self.parent = parent
        self.vars = vars
    def __str__(self):
        """Text form with variables separated by ', '."""
        return "READ " + ", ".join(map(str, self.vars))
    def __bytes__(self):
        """READ token followed by the ','-joined variables."""
        token = token_to_byte('READ')
        return token + bjoin(self.vars, sep=b',')
    def walk(self):
        """ENTERING, the variables, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        yield from walk(self.vars)
        yield (Walk.LEAVING, self)

class DefFn(Statement):
    """``DEF FN name(p1, p2, ...) = expr`` user-function definition."""
    def __init__(self, parent, name, params, expr):
        self.parent = parent
        self.name = name
        self.params = params
        self.expr = expr
    def __str__(self):
        """Text form with parameters separated by ', '."""
        params_text = ", ".join(map(str, self.params))
        return f"DEF FN {self.name}({params_text}) = {self.expr}"
    def __bytes__(self):
        """DEF FN token, name, bracketed parameters, '=' and the expression.

        Each parameter is followed by six bytes: 14 and five zeros
        (presumably the Spectrum's hidden-number marker and value slot;
        confirm against ROM documentation).
        """
        token = token_to_byte('DEF FN')
        placeholder = bytes((14, 0, 0, 0, 0, 0))
        bparams = bjoin([sane_bytes(p) + placeholder for p in self.params], sep=b',')
        return token + bjoin([self.name, b'(', bparams, b')=', self.expr])
    def walk(self):
        """ENTERING, name/params/expr, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        for field in ("name", "params", "expr"):
            yield from walk(getattr(self, field))
        yield (Walk.LEAVING, self)

class PrintItem(ASTNode):
    """One item of a print list: an optional value and an optional trailing separator."""
    def __init__(self, value, sep):
        self.value = value
        self.sep = sep
    def __str__(self):
        """Value then separator; a None part renders as empty text."""
        return nstr(self.value) + nstr(self.sep)
    def __bytes__(self):
        """Value bytes then separator bytes (None parts contribute nothing)."""
        return bjoin([self.value, self.sep])
    def walk(self):
        """ENTERING, value and sep, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        for field in ("value", "sep"):
            yield from walk(getattr(self, field))
        yield (Walk.LEAVING, self)

class Rem(Statement):
    """``REM`` comment statement."""
    def __init__(self, parent, comment):
        self.parent = parent
        self.comment = comment
    def __str__(self):
        """Text form: ``REM <comment>``."""
        return "REM {}".format(self.comment)
    def __bytes__(self):
        """REM token followed by the comment encoded via ``echars_to_bytes``."""
        token = token_to_byte('REM')
        return token + bjoin([echars_to_bytes(self.comment)])
    def walk(self):
        """A REM is a leaf: a single VISITING event."""
        yield Walk.VISITING, self

class Label(ASTNode):
    """``@name`` label; the leading character of the raw token is dropped on construction."""
    def __init__(self, parent, name):
        self.parent = parent
        self.name = name[1:]
    def __str__(self):
        """Text form: '@' followed by the stored name."""
        return "@{}".format(self.name)
    def __bytes__(self):
        """LABEL token followed by the name."""
        token = token_to_byte('LABEL')
        return token + bjoin([self.name])
    def walk(self):
        """A label is a leaf: a single VISITING event."""
        yield Walk.VISITING, self

class Variable(Expression):
    """Variable reference; spaces and tabs in the name are removed on construction."""
    def __init__(self, parent, name):
        self.parent = parent
        self.name = "".join(ch for ch in name if ch not in " \t")
    def __str__(self):
        """The normalized name."""
        return "{}".format(self.name)
    def __bytes__(self):
        """The name as ASCII bytes (via ``bjoin``)."""
        return bjoin([self.name])
    def walk(self):
        """A variable is a leaf: a single VISITING event."""
        yield Walk.VISITING, self

class Number(Expression):
    """Numeric literal."""
    def __init__(self, parent, value):
        self.parent = parent
        self.value = value
    def __str__(self):
        """The value, formatted as-is."""
        return "{}".format(self.value)
    def __bytes__(self):
        """The value encoded by the tokenizer's ``num_to_bytes``."""
        encoded = num_to_bytes(self.value)
        return bjoin([encoded])
    def walk(self):
        """A number is a leaf: a single VISITING event."""
        yield Walk.VISITING, self

class String(Expression):
    """String literal; the first and last characters (the quotes) are dropped on construction."""
    def __init__(self, parent, value):
        self.parent = parent
        self.value = value[1:-1]
    def __str__(self):
        """Re-quoted with ``speccy_quote``."""
        return "{}".format(speccy_quote(self.value))
    def __bytes__(self):
        """The contents encoded by the tokenizer's ``strlit_to_bytes``."""
        encoded = strlit_to_bytes(self.value)
        return bjoin([encoded])
    def walk(self):
        """A string literal is a leaf: a single VISITING event."""
        yield Walk.VISITING, self

class BinValue(ASTNode):
    """``BIN digits`` binary literal."""
    def __init__(self, parent, digits):
        self.parent = parent
        self.digits = digits
    def __str__(self):
        """Text form: ``BIN <digits>``."""
        return "BIN {}".format(self.digits)
    def __bytes__(self):
        """BIN token followed by the digits."""
        token = token_to_byte('BIN')
        return token + bjoin([self.digits])
    def walk(self):
        """A BIN literal is a leaf: a single VISITING event."""
        yield Walk.VISITING, self

class ArrayRef(Expression):
    """Array element reference ``name(s1, s2, ...)``."""
    def __init__(self, parent, name, subscripts):
        self.parent = parent
        self.name = name
        self.subscripts = subscripts
    def __str__(self):
        """Text form with subscripts separated by ', '."""
        subs_text = ", ".join(map(str, self.subscripts))
        return f"{self.name}({subs_text})"
    def __bytes__(self):
        """Name, then ','-joined subscripts in brackets."""
        bsubs = bjoin(self.subscripts, sep=b',')
        return bjoin([self.name, b'(', bsubs, b')'])
    def walk(self):
        """ENTERING, name and subscripts, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        for field in ("name", "subscripts"):
            yield from walk(getattr(self, field))
        yield (Walk.LEAVING, self)

class Fn(Expression):
    """User-function call ``FN name(a1, a2, ...)``."""
    def __init__(self, parent, name, args):
        self.parent = parent
        self.name = name
        self.args = args
    def __str__(self):
        """Text form with arguments separated by ', '."""
        args_text = ", ".join(map(str, self.args))
        return f"FN {self.name}({args_text})"
    def __bytes__(self):
        """FN token, name, then ','-joined arguments in brackets."""
        token = token_to_byte('FN')
        return token + bjoin([self.name, b'(', bjoin(self.args, sep=b','), b')'])
    def walk(self):
        """ENTERING, name and args, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        for field in ("name", "args"):
            yield from walk(getattr(self, field))
        yield (Walk.LEAVING, self)

class InputExpr(ASTNode):
    """An expression inside an input list; bracketed unless it is a String or Number literal."""
    def __init__(self, parent, expr):
        self.parent = parent
        self.expr = expr
    def walk(self):
        """ENTERING, the expression, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        yield from walk(self.expr)
        yield (Walk.LEAVING, self)

    def needs_parens(self):
        """True unless the wrapped expression is a String or Number literal."""
        return not isinstance(self.expr, (String, Number))
    def __str__(self):
        text = str(self.expr)
        return "(" + text + ")" if self.needs_parens() else text
    def __bytes__(self):
        inner = bytes(self.expr)
        return b'(' + inner + b')' if self.needs_parens() else inner

class Slice(ASTNode):
    """String slice ``[min] TO [max]``; either bound may be None (omitted)."""
    def __init__(self, parent, min, max):
        self.parent = parent
        self.min = min
        self.max = max
    def walk(self):
        """Walk method for Slice nodes"""
        if (yield (Walk.ENTERING, self)) == Walk.SKIP: return
        yield from walk(self.min)
        yield from walk(self.max)
        yield (Walk.LEAVING, self)

    def __str__(self):
        """Render the present bounds around ``TO``.

        When both bounds are omitted this is just ``TO``; it used to render
        ``TO None``, disagreeing with ``__bytes__``.
        """
        parts = []
        if self.min is not None:
            parts.append(str(self.min))
        parts.append("TO")
        if self.max is not None:
            parts.append(str(self.max))
        return " ".join(parts)
    def __bytes__(self):
        """Present bounds around the TO token (omitted bounds contribute nothing)."""
        items = [self.min, token_to_byte('TO'), self.max]
        return bjoin([item for item in items if item is not None])

class StringSubscript(Expression):
    """Subscript applied to a string expression: ``expr(index)``.

    A String literal subject is written bare; any other subject is bracketed.
    """
    def __init__(self, expr, index):
        self.expr = expr
        self.index = index
    def __str__(self):
        if isinstance(self.expr, String):
            head = str(self.expr)
        else:
            head = f"({self.expr})"
        return f"{head}({self.index})"
    def __bytes__(self):
        if isinstance(self.expr, String):
            head = [self.expr]
        else:
            head = [b'(', self.expr, b')']
        return bjoin(head + [b'(', self.index, b')'])
    def walk(self):
        """ENTERING, expr and index, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        for field in ("expr", "index"):
            yield from walk(getattr(self, field))
        yield (Walk.LEAVING, self)

class BinaryOp(Expression):
    """Infix binary operation ``lhs op rhs``.

    Operands are bracketed according to ``needs_parens``. ``__str__`` and
    ``__bytes__`` share the same bracketing decisions.
    """
    def __init__(self, op, lhs, rhs):
        self.op = op
        self.lhs = lhs
        self.rhs = rhs
    def walk(self):
        """Walk method for BinaryOp nodes"""
        if (yield (Walk.ENTERING, self)) == Walk.SKIP: return
        yield from walk(self.op)
        yield from walk(self.lhs)
        yield from walk(self.rhs)
        yield (Walk.LEAVING, self)

    def _lhs_needs_parens(self):
        """Whether the left operand (binary or unary) must be bracketed."""
        return isinstance(self.lhs, (BinaryOp, UnaryOp)) and needs_parens(self.lhs, self, False)

    def _rhs_needs_parens(self):
        """Whether the right operand (binary operations only) must be bracketed."""
        return isinstance(self.rhs, BinaryOp) and needs_parens(self.rhs, self, True)

    def __str__(self):
        lhs_str = str(self.lhs)
        if self._lhs_needs_parens():
            lhs_str = f"({lhs_str})"

        rhs_str = str(self.rhs)
        if self._rhs_needs_parens():
            rhs_str = f"({rhs_str})"

        return f"{lhs_str} {self.op} {rhs_str}"
    def __bytes__(self):
        bop = token_to_byte(self.op)
        # Left side. UnaryOp operands are now checked here too, as in
        # __str__. Previously e.g. (NOT a) = b was encoded as NOT a=b, which
        # means NOT (a = b) because NOT binds more loosely than '='.
        blhs = bytes(self.lhs)
        if self._lhs_needs_parens():
            blhs = b'(' + blhs + b')'

        # Right side
        brhs = bytes(self.rhs)
        if self._rhs_needs_parens():
            brhs = b'(' + brhs + b')'

        return blhs + bop + brhs

class UnaryOp(Expression):
    """Prefix unary operation ``op expr``; a binary operand is bracketed per ``needs_parens``."""
    def __init__(self, parent, op, expr):
        self.parent = parent
        self.op = op
        self.expr = expr
    def walk(self):
        """ENTERING, op and expr, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        for field in ("op", "expr"):
            yield from walk(getattr(self, field))
        yield (Walk.LEAVING, self)

    def _operand_needs_parens(self):
        """Whether the operand (binary operations only) must be bracketed."""
        return isinstance(self.expr, BinaryOp) and needs_parens(self.expr, self, False)

    def __str__(self):
        inner = str(self.expr)
        if self._operand_needs_parens():
            inner = "(" + inner + ")"
        # Alphabetic operators (e.g. NOT) need a space before the operand;
        # symbolic ones (e.g. -) are written flush against it.
        gap = " " if self.op.isalpha() else ""
        return f"{self.op}{gap}{inner}"
    def __bytes__(self):
        bop = token_to_byte(self.op)
        bexpr = bytes(self.expr)
        if self._operand_needs_parens():
            bexpr = b'(' + bexpr + b')'
        return bop + bexpr


class Not(UnaryOp):
    """UnaryOp subclass (presumably for NOT); adds no behaviour of its own."""

class Neg(UnaryOp):
    """UnaryOp subclass (presumably for unary minus); adds no behaviour of its own."""

class ChanSpec(ASTNode):
    """Channel specifier ``#chan``."""
    def __init__(self, parent, chan):
        self.parent = parent
        self.chan = chan
    def __str__(self):
        """Text form: '#' followed by the channel."""
        return "#{}".format(self.chan)
    def __bytes__(self):
        """'#' followed by the channel's bytes."""
        return bjoin([b'#', self.chan])
    def walk(self):
        """ENTERING, the channel, LEAVING (SKIP prunes the rest)."""
        reply = yield (Walk.ENTERING, self)
        if reply == Walk.SKIP:
            return
        yield from walk(self.chan)
        yield (Walk.LEAVING, self)

class Colons(Statement):
    """A run of statement separators, kept verbatim."""
    def __init__(self, parent, colons):
        self.parent = parent
        self.colons = colons
    def __str__(self):
        """Return a string representation of a Colons node"""
        return f"{self.colons}"
    def __bytes__(self):
        """Return the in-memory representation of a Colons node"""
        # Was `safe_bytes`, which is not defined in this module; the
        # module's conversion helper is `sane_bytes`.
        return bjoin([sane_bytes(self.colons)])
    def walk(self):
        """Walk method for Colons nodes"""
        yield (Walk.VISITING, self)

