#!/usr/bin/env python3

from functools import wraps

# Support code

class Token:
    """A lexical token.

    Equality and hashing consider only *name*, so a bare terminal token
    (e.g. NUM) matches a valued instance (e.g. NUM(3)).
    """

    def __init__(self, name, value=None):
        self.name = name
        self.value = value  # optional payload, e.g. the literal number

    def __call__(self, value):
        # Token singletons double as factories: NUM(3) builds a valued copy.
        return Token(self.name, value)

    def __repr__(self):
        return f'<Token {self!s}>'

    def __str__(self):
        if self.value is None:
            return self.name
        else:
            return f'{self.name}({self.value})'

    def __eq__(self, x):
        # Deliberately ignores `value`: NUM(1) == NUM(2) (see sanity checks).
        if not isinstance(x, Token):
            return False

        return self.name == x.name

    def __hash__(self):
        # BUG FIX: hash() takes a single argument — hash(Token, self.name)
        # raised TypeError on any hashing attempt (set/dict membership).
        # Hash on the name alone, consistent with __eq__.
        return hash(self.name)


# One singleton per terminal symbol of the grammar.
PLUS, MINUS, STAR, SLASH, ID, NUM, LPAREN, RPAREN, EOF = (
    Token(n) for n in ('plus', 'minus', 'star', 'slash',
                       'id', 'num', 'lparen', 'rparen', 'EOF')
)

# Quick sanity checks
assert PLUS == PLUS
assert PLUS != MINUS
assert PLUS in (PLUS, MINUS, STAR)
assert PLUS not in (STAR, SLASH, ID)
assert NUM(1) == NUM(2) # NB: This is as intended here


class Tokenizer:
    """Cursor over a token stream, terminated by a synthetic EOF token."""

    def __init__(self, source, tracer):
        self.tracer = tracer            # supplies the indent for trace output
        self.source = source + [ EOF ]
        self.pos = 0                    # index of the current token

    def current(self):
        """Return the token under the cursor without consuming it."""
        return self.source[self.pos]

    def advance(self):
        """Consume the current token.

        O(1): the previous version re-sliced the whole list on every
        step (`self.source = self.source[1:]`), O(n) per advance.
        """
        self.pos += 1

    def eat(self, typ):
        """Consume the current token if it matches *typ*, else raise."""
        cur = self.current()

        if cur == typ:
            print(f'{self.tracer.indent}{cur}')
            self.advance()
        else:
            # Use the already-fetched token instead of re-indexing the list.
            raise Exception(f'Expected {typ}, got {cur}')

    def error(self, msg='Something went wrong'):
        """Abort parsing with a generic diagnostic."""
        raise Exception(f'Error: {msg}')

class Tracer:
    """Decorator that prints an indented call trace of decorated functions."""

    def __init__(self):
        self.stack = []   # names of the rules currently being expanded
        self.depth = 0    # == len(self.stack); cached for indent

    def __call__(self, f):
        name = f.__name__

        @wraps(f)
        def wrapped(*args, **kwargs):
            print(f'{self.indent}{name}:')
            self._push(name)
            try:
                return f(*args, **kwargs)
            finally:
                # BUG FIX: pop even when f raises (e.g. a parse error) —
                # previously stack/depth leaked on exceptions, so any
                # later trace output was mis-indented.
                self._pop()

        return wrapped

    @property
    def indent(self):
        """Two spaces per open frame."""
        return '  ' * self.depth

    def _push(self, name):
        self.stack.append(name)
        self.depth += 1

    def _pop(self):
        self.stack.pop()
        self.depth -= 1


tracer = Tracer()


#
#       1 + (2 + 3) * (4 - 5) + 6
#

# BUG FIX: the stream previously read "..., NUM, RPAREN(5), ..." — the value 5
# was attached to the wrong token. It happened to parse anyway because Token
# equality ignores values, but the stream did not encode the expression above.
source = [ NUM(1), PLUS, LPAREN, NUM(2), PLUS, NUM(3), RPAREN, STAR, LPAREN, NUM(4), MINUS, NUM(5), RPAREN, PLUS, NUM(6) ]

t = Tokenizer(source, tracer)

# Module-level shorthands used by the grammar functions below.
cur = t.current
advance = t.advance
eat = t.eat
error = t.error

## Recursive descent parser

@tracer
def S():
    """S -> E : start symbol."""
    if cur() not in [ ID, NUM, LPAREN ]:
        error()
    E()

@tracer
def E():
    """E -> T E1"""
    if cur() not in [ ID, NUM, LPAREN ]:
        error()
    T()
    E1()

@tracer
def T():
    """T -> F T1"""
    if cur() not in [ ID, NUM, LPAREN ]:
        error()
    F()
    T1()

@tracer
def F():
    """F -> id | num | ( E )"""
    tok = cur()

    if tok == LPAREN:
        eat(LPAREN)
        E()
        eat(RPAREN)
    elif tok == ID:
        eat(ID)
    elif tok == NUM:
        eat(NUM)
    else:
        error()

@tracer
def E1():
    """E1 -> + T E1 | - T E1 | epsilon  (epsilon on RPAREN / EOF)"""
    tok = cur()

    if tok == PLUS or tok == MINUS:
        eat(tok)
        T()
        E1()
    elif tok not in [ RPAREN, EOF ]:
        error()

@tracer
def T1():
    """T1 -> * F T1 | / F T1 | epsilon  (epsilon on + - ) / EOF)"""
    tok = cur()

    if tok == STAR or tok == SLASH:
        eat(tok)
        F()
        T1()
    elif tok not in [ PLUS, MINUS, RPAREN, EOF ]:
        error()

# Parse the token stream from the start symbol; raises on malformed input.
S()
