Beautified code and moved visitor methods to camelCase naming convention

This commit is contained in:
Nocturn9x 2022-03-01 14:55:27 +01:00
parent e92e40b6ea
commit 81e95feea4
3 changed files with 96 additions and 86 deletions

View File

@ -48,8 +48,8 @@ type
of NodeKind.Ident: of NodeKind.Ident:
name*: string name*: string
of NodeKind.Call: of NodeKind.Call:
arguments*: seq[AstNode] arguments*: seq[AstNode]
function*: AstNode function*: AstNode
NodeVisitor* = ref object NodeVisitor* = ref object
# A node visitor object # A node visitor object
@ -57,21 +57,20 @@ type
proc `$`*(self: AstNode): string = proc `$`*(self: AstNode): string =
## Stringifies an AST node ## Stringifies an AST node
case self.kind: case self.kind:
of NodeKind.Grouping: of NodeKind.Grouping:
result = &"Grouping({self.expr})" result = &"Grouping({self.expr})"
of NodeKind.Unary: of NodeKind.Unary:
result = &"Unary({$self.unOp.kind}, {$self.operand})" result = &"Unary({$self.unOp.kind}, {$self.operand})"
of NodeKind.Binary: of NodeKind.Binary:
result = &"Binary({$self.left}, {$self.binOp.kind}, {$self.right})" result = &"Binary({$self.left}, {$self.binOp.kind}, {$self.right})"
of NodeKind.Integer: of NodeKind.Integer:
result = &"Integer({$int(self.value)})" result = &"Integer({$int(self.value)})"
of NodeKind.Float: of NodeKind.Float:
result = &"Float({$self.value})" result = &"Float({$self.value})"
of NodeKind.Call: of NodeKind.Call:
result = &"Call({self.function.name}, {self.arguments.join(\", \")})" result = &"Call({self.function.name}, {self.arguments.join(\", \")})"
of NodeKind.Ident: of NodeKind.Ident:
result = &"Identifier({self.name})" result = &"Identifier({self.name})"
proc initNodeVisitor*(): NodeVisitor = proc initNodeVisitor*(): NodeVisitor =
@ -93,21 +92,23 @@ template handleBinary(left, right: AstNode, operator: untyped): AstNode =
template ensureNonZero(node: AstNode) = template ensureNonZero(node: AstNode) =
## Handy template to ensure that a given node's value is not 0 ## Handy template to ensure that a given node's value is not 0
if node.value == 0.0: if node.value == 0.0:
case node.kind: case node.kind:
of NodeKind.Float, NodeKind.Integer: of NodeKind.Float, NodeKind.Integer:
raise newException(MathError, "value can't be zero") raise newException(MathError, "value can't be zero")
else: else:
raise newException(CatchableError, &"invalid node kind '{node.kind}' for ensureNonZero") raise newException(CatchableError,
&"invalid node kind '{node.kind}' for ensureNonZero")
template ensurePositive(node: AstNode) = template ensurePositive(node: AstNode) =
## Handy template to ensure that a given node's value is positive ## Handy template to ensure that a given node's value is positive
if node.value < 0.0: if node.value < 0.0:
case node.kind: case node.kind:
of NodeKind.Float, NodeKind.Integer: of NodeKind.Float, NodeKind.Integer:
raise newException(MathError, "value must be positive") raise newException(MathError, "value must be positive")
else: else:
raise newException(CatchableError, &"invalid node kind '{node.kind}' for ensureNonZero") raise newException(CatchableError,
&"invalid node kind '{node.kind}' for ensureNonZero")
template ensureIntegers(left, right: AstNode) = template ensureIntegers(left, right: AstNode) =
@ -116,7 +117,7 @@ template ensureIntegers(left, right: AstNode) =
raise newException(MathError, "an integer is required") raise newException(MathError, "an integer is required")
template callFunction(fun: untyped, args: varargs[untyped]) = template callFunction(fun: untyped, args: varargs[untyped]) =
## Handy template to call functions ## Handy template to call functions
let r = fun(args) let r = fun(args)
if r is float: if r is float:
@ -126,11 +127,11 @@ template callFunction(fun: untyped, args: varargs[untyped]) =
# Forward declarations # Forward declarations
proc visit_literal(self: NodeVisitor, node: AstNode): AstNode proc visitLiteral(self: NodeVisitor, node: AstNode): AstNode
proc visit_unary(self: NodeVisitor, node: AstNode): AstNode proc visitUnary(self: NodeVisitor, node: AstNode): AstNode
proc visit_binary(self: NodeVisitor, node: AstNode): AstNode proc visitBinary(self: NodeVisitor, node: AstNode): AstNode
proc visit_grouping(self: NodeVisitor, node: AstNode): AstNode proc visitGrouping(self: NodeVisitor, node: AstNode): AstNode
proc visit_call(self: NodeVisitor, node: AstNode): AstNode proc visitCall(self: NodeVisitor, node: AstNode): AstNode
proc accept(self: AstNode, visitor: NodeVisitor): AstNode = proc accept(self: AstNode, visitor: NodeVisitor): AstNode =
@ -138,15 +139,15 @@ proc accept(self: AstNode, visitor: NodeVisitor): AstNode =
## for our AST visitor ## for our AST visitor
case self.kind: case self.kind:
of NodeKind.Integer, NodeKind.Float, NodeKind.Ident: of NodeKind.Integer, NodeKind.Float, NodeKind.Ident:
result = visitor.visit_literal(self) result = visitor.visitLiteral(self)
of NodeKind.Binary: of NodeKind.Binary:
result = visitor.visit_binary(self) result = visitor.visitBinary(self)
of NodeKind.Unary: of NodeKind.Unary:
result = visitor.visit_unary(self) result = visitor.visitUnary(self)
of NodeKind.Grouping: of NodeKind.Grouping:
result = visitor.visit_grouping(self) result = visitor.visitGrouping(self)
of NodeKind.Call: of NodeKind.Call:
result = visitor.visit_call(self) result = visitor.visitCall(self)
proc eval*(self: NodeVisitor, node: AstNode): AstNode = proc eval*(self: NodeVisitor, node: AstNode): AstNode =
@ -154,12 +155,12 @@ proc eval*(self: NodeVisitor, node: AstNode): AstNode =
result = node.accept(self) result = node.accept(self)
proc visit_literal(self: NodeVisitor, node: AstNode): AstNode = proc visitLiteral(self: NodeVisitor, node: AstNode): AstNode =
## Visits a literal AST node (such as integers) ## Visits a literal AST node (such as integers)
result = node # Not that we can do anything else after all, lol result = node # Not that we can do anything else after all, lol
proc visit_call(self: NodeVisitor, node: AstNode): AstNode = proc visitCall(self: NodeVisitor, node: AstNode): AstNode =
## Visits function call expressions ## Visits function call expressions
case node.function.name: case node.function.name:
of "sin": of "sin":
@ -175,7 +176,8 @@ proc visit_call(self: NodeVisitor, node: AstNode): AstNode =
of "log": of "log":
let arg = self.eval(node.arguments[0]) let arg = self.eval(node.arguments[0])
ensureNonZero(arg) ensureNonZero(arg)
callFunction(log, self.eval(node.arguments[0]).value, self.eval(node.arguments[1]).value) callFunction(log, self.eval(node.arguments[0]).value, self.eval(
node.arguments[1]).value)
of "ln": of "ln":
let arg = self.eval(node.arguments[0]) let arg = self.eval(node.arguments[0])
ensureNonZero(arg) ensureNonZero(arg)
@ -209,10 +211,11 @@ proc visit_call(self: NodeVisitor, node: AstNode): AstNode =
of "arccosh": of "arccosh":
callFunction(arccosh, self.eval(node.arguments[0]).value) callFunction(arccosh, self.eval(node.arguments[0]).value)
of "hypot": of "hypot":
callFunction(hypot, self.eval(node.arguments[0]).value, self.eval(node.arguments[1]).value) callFunction(hypot, self.eval(node.arguments[0]).value, self.eval(
node.arguments[1]).value)
proc visit_grouping(self: NodeVisitor, node: AstNode): AstNode = proc visitGrouping(self: NodeVisitor, node: AstNode): AstNode =
## Visits grouping (i.e. parenthesized) expressions. Parentheses ## Visits grouping (i.e. parenthesized) expressions. Parentheses
## have no other meaning than to allow a lower-precedence expression ## have no other meaning than to allow a lower-precedence expression
## where a higher-precedence one is expected so that 2 * (3 + 1) is ## where a higher-precedence one is expected so that 2 * (3 + 1) is
@ -220,7 +223,7 @@ proc visit_grouping(self: NodeVisitor, node: AstNode): AstNode =
return self.eval(node.expr) return self.eval(node.expr)
proc visit_binary(self: NodeVisitor, node: AstNode): AstNode = proc visitBinary(self: NodeVisitor, node: AstNode): AstNode =
## Visits a binary AST node and evaluates it ## Visits a binary AST node and evaluates it
let right = self.eval(node.right) let right = self.eval(node.right)
let left = self.eval(node.left) let left = self.eval(node.left)
@ -236,16 +239,17 @@ proc visit_binary(self: NodeVisitor, node: AstNode): AstNode =
# Modulo is a bit special since we must have integers # Modulo is a bit special since we must have integers
ensureIntegers(left, right) ensureIntegers(left, right)
ensureNonZero(right) ensureNonZero(right)
result = AstNode(kind: NodeKind.Integer, value: float(int(left.value) mod int(right.value))) result = AstNode(kind: NodeKind.Integer, value: float(int(
left.value) mod int(right.value)))
of TokenType.Exp: of TokenType.Exp:
result = handleBinary(left, right, pow) result = handleBinary(left, right, pow)
of TokenType.Mul: of TokenType.Mul:
result = handleBinary(left, right, `*`) result = handleBinary(left, right, `*`)
else: else:
discard # Unreachable discard # Unreachable
proc visit_unary(self: NodeVisitor, node: AstNode): AstNode = proc visitUnary(self: NodeVisitor, node: AstNode): AstNode =
## Visits unary expressions and evaluates them ## Visits unary expressions and evaluates them
let expr = self.eval(node.operand) let expr = self.eval(node.operand)
case node.unOp.kind: case node.unOp.kind:
@ -256,7 +260,7 @@ proc visit_unary(self: NodeVisitor, node: AstNode): AstNode =
of NodeKind.Integer: of NodeKind.Integer:
result = AstNode(kind: NodeKind.Integer, value: -expr.value) result = AstNode(kind: NodeKind.Integer, value: -expr.value)
else: else:
discard # Unreachable discard # Unreachable
of TokenType.Plus: of TokenType.Plus:
case expr.kind: case expr.kind:
of NodeKind.Float: of NodeKind.Float:
@ -264,6 +268,6 @@ proc visit_unary(self: NodeVisitor, node: AstNode): AstNode =
of NodeKind.Integer: of NodeKind.Integer:
result = AstNode(kind: NodeKind.Integer, value: expr.value) result = AstNode(kind: NodeKind.Integer, value: expr.value)
else: else:
discard # Unreachable discard # Unreachable
else: else:
discard # Unreachable discard # Unreachable

View File

@ -29,9 +29,9 @@ const tokens = to_table({
'*': TokenType.Mul, '/': TokenType.Div, '*': TokenType.Mul, '/': TokenType.Div,
'%': TokenType.Modulo, '^': TokenType.Exp, '%': TokenType.Modulo, '^': TokenType.Exp,
',': TokenType.Comma}) ',': TokenType.Comma})
# All the identifiers and constants (such as PI) # All the identifiers and constants (such as PI)
# Since they're constant we don't even need to bother adding another # Since they're constant we don't even need to bother adding another
# AST node kind, we can just map the name to a float literal ;) # AST node kind, we can just map the name to a float literal ;)
const constants = to_table({ const constants = to_table({
"pi": Token(kind: TokenType.Float, lexeme: "3.141592653589793"), "pi": Token(kind: TokenType.Float, lexeme: "3.141592653589793"),
"e": Token(kind: TokenType.Float, lexeme: "2.718281828459045"), "e": Token(kind: TokenType.Float, lexeme: "2.718281828459045"),
@ -91,7 +91,7 @@ func createToken(self: Lexer, tokenType: TokenType): Token =
## Creates a token object for later use in the parser ## Creates a token object for later use in the parser
result = Token(kind: tokenType, result = Token(kind: tokenType,
lexeme: self.source[self.start..<self.current], lexeme: self.source[self.start..<self.current],
) )
proc parseNumber(self: Lexer) = proc parseNumber(self: Lexer) =
@ -140,12 +140,12 @@ proc scanToken(self: Lexer) =
## called iteratively until the source ## called iteratively until the source
## string reaches EOF ## string reaches EOF
var single = self.step() var single = self.step()
if single in [' ', '\t', '\r', '\n']: # We skip whitespaces, tabs and other stuff if single in [' ', '\t', '\r', '\n']: # We skip whitespaces, tabs and other stuff
return return
elif single.isDigit(): elif single.isDigit():
self.parseNumber() self.parseNumber()
elif single in tokens: elif single in tokens:
self.tokens.add(self.createToken(tokens[single])) self.tokens.add(self.createToken(tokens[single]))
elif single.isAlphanumeric() or single == '_': elif single.isAlphanumeric() or single == '_':
self.parseIdentifier() self.parseIdentifier()
else: else:

View File

@ -35,12 +35,13 @@ type
const arities = to_table({"sin": 1, "cos": 1, "tan": 1, "cosh": 1, const arities = to_table({"sin": 1, "cos": 1, "tan": 1, "cosh": 1,
"tanh": 1, "sinh": 1, "arccos": 1, "arcsin": 1, "tanh": 1, "sinh": 1, "arccos": 1, "arcsin": 1,
"arctan": 1, "log": 2, "log10": 1, "ln": 1, "log2": 1, "arctan": 1, "log": 2, "log10": 1, "ln": 1, "log2": 1,
"hypot": 2, "sqrt": 1, "cbrt": 2, "arctanh": 1, "arcsinh": 1, "hypot": 2, "sqrt": 1, "cbrt": 2, "arctanh": 1,
"arcsinh": 1,
"arccosh": 1 "arccosh": 1
}) })
proc initParser*(): Parser = proc initParser*(): Parser =
new(result) new(result)
result.current = 0 result.current = 0
result.tokens = @[] result.tokens = @[]
@ -51,16 +52,16 @@ proc initParser*(): Parser =
proc binary(self: Parser): AstNode proc binary(self: Parser): AstNode
template endOfFile: Token = template endOfFile: Token =
## Creates an EOF token -- utility template ## Creates an EOF token -- utility template
Token(lexeme: "", kind: TokenType.Eof) Token(lexeme: "", kind: TokenType.Eof)
func done(self: Parser): bool = func done(self: Parser): bool =
result = self.current >= self.tokens.high() result = self.current >= self.tokens.high()
proc peek(self: Parser): Token = proc peek(self: Parser): Token =
## Peeks into the tokens list or ## Peeks into the tokens list or
## returns an EOF token if we're at ## returns an EOF token if we're at
## the end of the input ## the end of the input
@ -70,9 +71,9 @@ proc peek(self: Parser): Token =
result = endOfFile result = endOfFile
proc step(self: Parser): Token = proc step(self: Parser): Token =
## Consumes a token from the input and ## Consumes a token from the input and
## steps forward or returns an EOF token ## steps forward or returns an EOF token
## if we're at the end of the input ## if we're at the end of the input
if not self.done(): if not self.done():
result = self.peek() result = self.peek()
@ -81,19 +82,19 @@ proc step(self: Parser): Token =
result = endOfFile result = endOfFile
proc previous(self: Parser): Token = proc previous(self: Parser): Token =
## Returns the previously consumed ## Returns the previously consumed
## token ## token
result = self.tokens[self.current - 1] result = self.tokens[self.current - 1]
proc check(self: Parser, kind: TokenType): bool = proc check(self: Parser, kind: TokenType): bool =
## Returns true if the current token matches ## Returns true if the current token matches
## the given type ## the given type
result = self.peek().kind == kind result = self.peek().kind == kind
proc match(self: Parser, kind: TokenType): bool = proc match(self: Parser, kind: TokenType): bool =
## Checks if the current token matches the ## Checks if the current token matches the
## given type and consumes it if it does, returns ## given type and consumes it if it does, returns
## false otherwise. True is returned if the ## false otherwise. True is returned if the
@ -105,7 +106,7 @@ proc match(self: Parser, kind: TokenType): bool =
result = false result = false
proc match(self: Parser, kinds: varargs[TokenType]): bool = proc match(self: Parser, kinds: varargs[TokenType]): bool =
## Checks if the current token matches any of the ## Checks if the current token matches any of the
## given type(s) and consumes it if it does, returns ## given type(s) and consumes it if it does, returns
## false otherwise. True is returned at ## false otherwise. True is returned at
@ -116,20 +117,20 @@ proc match(self: Parser, kinds: varargs[TokenType]): bool =
result = false result = false
proc error(self: Parser, message: string) = proc error(self: Parser, message: string) =
## Raises a parsing error with the given message ## Raises a parsing error with the given message
raise newException(ParseError, message) raise newException(ParseError, message)
proc expect(self: Parser, kind: TokenType, message: string) = proc expect(self: Parser, kind: TokenType, message: string) =
## Checks if the current token matches the given type ## Checks if the current token matches the given type
## and consumes it if it does, raises an error ## and consumes it if it does, raises an error
## with the given message otherwise. ## with the given message otherwise.
if not self.match(kind): if not self.match(kind):
self.error(message) self.error(message)
proc primary(self: Parser): AstNode = proc primary(self: Parser): AstNode =
## Parses primary expressions ## Parses primary expressions
let value = self.previous() let value = self.previous()
case value.kind: case value.kind:
@ -151,7 +152,7 @@ proc primary(self: Parser): AstNode =
self.error(&"invalid token of kind '{value.kind}' in primary expression") self.error(&"invalid token of kind '{value.kind}' in primary expression")
proc call(self: Parser): AstNode = proc call(self: Parser): AstNode =
## Parses function calls such as sin(2) ## Parses function calls such as sin(2)
var expression = self.primary() var expression = self.primary()
if self.match(TokenType.LeftParen): if self.match(TokenType.LeftParen):
@ -160,7 +161,8 @@ proc call(self: Parser): AstNode =
arguments.add(self.binary()) arguments.add(self.binary())
while self.match(TokenType.Comma): while self.match(TokenType.Comma):
arguments.add(self.binary()) arguments.add(self.binary())
result = AstNode(kind: NodeKind.Call, arguments: arguments, function: expression) result = AstNode(kind: NodeKind.Call, arguments: arguments,
function: expression)
if expression.kind != NodeKind.Ident: if expression.kind != NodeKind.Ident:
self.error(&"can't call object of type {expression.kind}") self.error(&"can't call object of type {expression.kind}")
if len(arguments) != arities[expression.name]: if len(arguments) != arities[expression.name]:
@ -171,50 +173,54 @@ proc call(self: Parser): AstNode =
proc unary(self: Parser): AstNode = proc unary(self: Parser): AstNode =
## Parses unary expressions such as -1 ## Parses unary expressions such as -1
case self.step().kind: case self.step().kind:
of TokenType.Minus, TokenType.Plus: of TokenType.Minus, TokenType.Plus:
result = AstNode(kind: NodeKind.Unary, unOp: self.previous(), operand: self.unary()) result = AstNode(kind: NodeKind.Unary, unOp: self.previous(),
operand: self.unary())
else: else:
result = self.call() result = self.call()
proc pow(self: Parser): AstNode =
proc pow(self: Parser): AstNode =
## Parses exponentiation ## Parses exponentiation
result = self.unary() result = self.unary()
var operator: Token var operator: Token
while self.match(TokenType.Exp): while self.match(TokenType.Exp):
operator = self.previous() operator = self.previous()
result = AstNode(kind: NodeKind.Binary, left: result, right: self.unary(), binOp: operator) result = AstNode(kind: NodeKind.Binary, left: result, right: self.unary(),
binOp: operator)
proc mul(self: Parser): AstNode = proc mul(self: Parser): AstNode =
## Parses divisions (including modulo) and ## Parses divisions (including modulo) and
## multiplications ## multiplications
result = self.pow() result = self.pow()
var operator: Token var operator: Token
while self.match(TokenType.Div, TokenType.Modulo, TokenType.Mul): while self.match(TokenType.Div, TokenType.Modulo, TokenType.Mul):
operator = self.previous() operator = self.previous()
result = AstNode(kind: NodeKind.Binary, left: result, right: self.pow(), binOp: operator) result = AstNode(kind: NodeKind.Binary, left: result, right: self.pow(),
binOp: operator)
proc addition(self: Parser): AstNode = proc addition(self: Parser): AstNode =
## Parses additions and subtractions ## Parses additions and subtractions
result = self.mul() result = self.mul()
var operator: Token var operator: Token
while self.match(TokenType.Plus, TokenType.Minus): while self.match(TokenType.Plus, TokenType.Minus):
operator = self.previous() operator = self.previous()
result = AstNode(kind: NodeKind.Binary, left: result, right: self.mul(), binOp: operator) result = AstNode(kind: NodeKind.Binary, left: result, right: self.mul(),
binOp: operator)
proc binary(self: Parser): AstNode = proc binary(self: Parser): AstNode =
## Parses binary expressions, the highest ## Parses binary expressions, the highest
## level of expression ## level of expression
result = self.addition() result = self.addition()
proc parse*(self: Parser, tokens: seq[Token]): AstNode = proc parse*(self: Parser, tokens: seq[Token]): AstNode =
## Parses a list of tokens into an AST tree ## Parses a list of tokens into an AST tree
self.tokens = tokens self.tokens = tokens
self.current = 0 self.current = 0