From 81e95feea4519fbaf382d749d60c2946fb9685bb Mon Sep 17 00:00:00 2001 From: Nocturn9x Date: Tue, 1 Mar 2022 14:55:27 +0100 Subject: [PATCH] Beautified code and moved visitor methods to camelCase naming convention --- src/nimkalc/objects/ast.nim | 106 +++++++++++++++++---------------- src/nimkalc/parsing/lexer.nim | 12 ++-- src/nimkalc/parsing/parser.nim | 64 +++++++++++--------- 3 files changed, 96 insertions(+), 86 deletions(-) diff --git a/src/nimkalc/objects/ast.nim b/src/nimkalc/objects/ast.nim index 4fbf904..232687c 100644 --- a/src/nimkalc/objects/ast.nim +++ b/src/nimkalc/objects/ast.nim @@ -48,8 +48,8 @@ type of NodeKind.Ident: name*: string of NodeKind.Call: - arguments*: seq[AstNode] - function*: AstNode + arguments*: seq[AstNode] + function*: AstNode NodeVisitor* = ref object # A node visitor object @@ -57,21 +57,20 @@ type proc `$`*(self: AstNode): string = ## Stringifies an AST node case self.kind: - of NodeKind.Grouping: - result = &"Grouping({self.expr})" - of NodeKind.Unary: - result = &"Unary({$self.unOp.kind}, {$self.operand})" - of NodeKind.Binary: - result = &"Binary({$self.left}, {$self.binOp.kind}, {$self.right})" - of NodeKind.Integer: - result = &"Integer({$int(self.value)})" - of NodeKind.Float: - result = &"Float({$self.value})" - of NodeKind.Call: - result = &"Call({self.function.name}, {self.arguments.join(\", \")})" - of NodeKind.Ident: - result = &"Identifier({self.name})" - + of NodeKind.Grouping: + result = &"Grouping({self.expr})" + of NodeKind.Unary: + result = &"Unary({$self.unOp.kind}, {$self.operand})" + of NodeKind.Binary: + result = &"Binary({$self.left}, {$self.binOp.kind}, {$self.right})" + of NodeKind.Integer: + result = &"Integer({$int(self.value)})" + of NodeKind.Float: + result = &"Float({$self.value})" + of NodeKind.Call: + result = &"Call({self.function.name}, {self.arguments.join(\", \")})" + of NodeKind.Ident: + result = &"Identifier({self.name})" proc initNodeVisitor*(): NodeVisitor = @@ -93,21 +92,23 @@ template handleBinary(left, right: AstNode, operator: untyped): AstNode = template ensureNonZero(node: AstNode) = ## Handy template to ensure that a given node's value is not 0 if node.value == 0.0: - case node.kind: - of NodeKind.Float, NodeKind.Integer: - raise newException(MathError, "value can't be zero") - else: - raise newException(CatchableError, &"invalid node kind '{node.kind}' for ensureNonZero") + case node.kind: + of NodeKind.Float, NodeKind.Integer: + raise newException(MathError, "value can't be zero") + else: + raise newException(CatchableError, + &"invalid node kind '{node.kind}' for ensureNonZero") template ensurePositive(node: AstNode) = ## Handy template to ensure that a given node's value is positive if node.value < 0.0: - case node.kind: - of NodeKind.Float, NodeKind.Integer: - raise newException(MathError, "value must be positive") - else: - raise newException(CatchableError, &"invalid node kind '{node.kind}' for ensureNonZero") + case node.kind: + of NodeKind.Float, NodeKind.Integer: + raise newException(MathError, "value must be positive") + else: + raise newException(CatchableError, + &"invalid node kind '{node.kind}' for ensureNonZero") template ensureIntegers(left, right: AstNode) = @@ -116,7 +117,7 @@ template ensureIntegers(left, right: AstNode) = raise newException(MathError, "an integer is required") -template callFunction(fun: untyped, args: varargs[untyped]) = +template callFunction(fun: untyped, args: varargs[untyped]) = ## Handy template to call functions let r = fun(args) if r is float: @@ -126,11 +127,11 @@ template callFunction(fun: untyped, args: varargs[untyped]) = # Forward declarations -proc visit_literal(self: NodeVisitor, node: AstNode): AstNode -proc visit_unary(self: NodeVisitor, node: AstNode): AstNode -proc visit_binary(self: NodeVisitor, node: AstNode): AstNode -proc visit_grouping(self: NodeVisitor, node: AstNode): AstNode -proc visit_call(self: NodeVisitor, node: AstNode): AstNode +proc visitLiteral(self: NodeVisitor, node: AstNode): AstNode +proc visitUnary(self: NodeVisitor, node: AstNode): AstNode +proc visitBinary(self: NodeVisitor, node: AstNode): AstNode +proc visitGrouping(self: NodeVisitor, node: AstNode): AstNode +proc visitCall(self: NodeVisitor, node: AstNode): AstNode proc accept(self: AstNode, visitor: NodeVisitor): AstNode = @@ -138,15 +139,15 @@ proc accept(self: AstNode, visitor: NodeVisitor): AstNode = ## for our AST visitor case self.kind: of NodeKind.Integer, NodeKind.Float, NodeKind.Ident: - result = visitor.visit_literal(self) + result = visitor.visitLiteral(self) of NodeKind.Binary: - result = visitor.visit_binary(self) + result = visitor.visitBinary(self) of NodeKind.Unary: - result = visitor.visit_unary(self) + result = visitor.visitUnary(self) of NodeKind.Grouping: - result = visitor.visit_grouping(self) + result = visitor.visitGrouping(self) of NodeKind.Call: - result = visitor.visit_call(self) + result = visitor.visitCall(self) proc eval*(self: NodeVisitor, node: AstNode): AstNode = @@ -154,12 +155,12 @@ proc eval*(self: NodeVisitor, node: AstNode): AstNode = result = node.accept(self) -proc visit_literal(self: NodeVisitor, node: AstNode): AstNode = +proc visitLiteral(self: NodeVisitor, node: AstNode): AstNode = ## Visits a literal AST node (such as integers) - result = node # Not that we can do anything else after all, lol + result = node # Not that we can do anything else after all, lol -proc visit_call(self: NodeVisitor, node: AstNode): AstNode = +proc visitCall(self: NodeVisitor, node: AstNode): AstNode = ## Visits function call expressions case node.function.name: of "sin": @@ -175,7 +176,8 @@ proc visit_call(self: NodeVisitor, node: AstNode): AstNode = of "log": let arg = self.eval(node.arguments[0]) ensureNonZero(arg) - callFunction(log, self.eval(node.arguments[0]).value, self.eval(node.arguments[1]).value) + callFunction(log, self.eval(node.arguments[0]).value, self.eval( + node.arguments[1]).value) of "ln": let arg = self.eval(node.arguments[0]) ensureNonZero(arg) @@ -209,10 +211,11 @@ proc visit_call(self: NodeVisitor, node: AstNode): AstNode = of "arccosh": callFunction(arccosh, self.eval(node.arguments[0]).value) of "hypot": - callFunction(hypot, self.eval(node.arguments[0]).value, self.eval(node.arguments[1]).value) + callFunction(hypot, self.eval(node.arguments[0]).value, self.eval( + node.arguments[1]).value) -proc visit_grouping(self: NodeVisitor, node: AstNode): AstNode = +proc visitGrouping(self: NodeVisitor, node: AstNode): AstNode = ## Visits grouping (i.e. parenthesized) expressions. Parentheses ## have no other meaning than to allow a lower-precedence expression ## where a higher-precedence one is expected so that 2 * (3 + 1) is @@ -220,7 +223,7 @@ proc visit_grouping(self: NodeVisitor, node: AstNode): AstNode = return self.eval(node.expr) -proc visit_binary(self: NodeVisitor, node: AstNode): AstNode = +proc visitBinary(self: NodeVisitor, node: AstNode): AstNode = ## Visits a binary AST node and evaluates it let right = self.eval(node.right) let left = self.eval(node.left) @@ -236,16 +239,17 @@ proc visit_binary(self: NodeVisitor, node: AstNode): AstNode = # Modulo is a bit special since we must have integers ensureIntegers(left, right) ensureNonZero(right) - result = AstNode(kind: NodeKind.Integer, value: float(int(left.value) mod int(right.value))) + result = AstNode(kind: NodeKind.Integer, value: float(int( + left.value) mod int(right.value))) of TokenType.Exp: result = handleBinary(left, right, pow) of TokenType.Mul: result = handleBinary(left, right, `*`) else: - discard # Unreachable + discard # Unreachable -proc visit_unary(self: NodeVisitor, node: AstNode): AstNode = +proc visitUnary(self: NodeVisitor, node: AstNode): AstNode = ## Visits unary expressions and evaluates them let expr = self.eval(node.operand) case node.unOp.kind: @@ -256,7 +260,7 @@ proc visit_unary(self: NodeVisitor, node: AstNode): AstNode = of NodeKind.Integer: result = AstNode(kind: NodeKind.Integer, value: -expr.value) else: - discard # Unreachable + discard # Unreachable of TokenType.Plus: case expr.kind: of NodeKind.Float: @@ -264,6 +268,6 @@ proc visit_unary(self: NodeVisitor, node: AstNode): AstNode = of NodeKind.Integer: result = AstNode(kind: NodeKind.Integer, value: expr.value) else: - discard # Unreachable + discard # Unreachable else: - discard # Unreachable + discard # Unreachable diff --git a/src/nimkalc/parsing/lexer.nim b/src/nimkalc/parsing/lexer.nim index e24aec5..62b0a3e 100644 --- a/src/nimkalc/parsing/lexer.nim +++ b/src/nimkalc/parsing/lexer.nim @@ -29,9 +29,9 @@ const tokens = to_table({ '*': TokenType.Mul, '/': TokenType.Div, '%': TokenType.Modulo, '^': TokenType.Exp, ',': TokenType.Comma}) -# All the identifiers and constants (such as PI) -# Since they're constant we don't even need to bother adding another -# AST node kind, we can just map the name to a float literal ;) + # All the identifiers and constants (such as PI) + # Since they're constant we don't even need to bother adding another + # AST node kind, we can just map the name to a float literal ;) const constants = to_table({ "pi": Token(kind: TokenType.Float, lexeme: "3.141592653589793"), "e": Token(kind: TokenType.Float, lexeme: "2.718281828459045"), @@ -91,7 +91,7 @@ func createToken(self: Lexer, tokenType: TokenType): Token = ## Creates a token object for later use in the parser result = Token(kind: tokenType, lexeme: self.source[self.start..= self.tokens.high() -proc peek(self: Parser): Token = +proc peek(self: Parser): Token = ## Peeks into the tokens list or ## returns an EOF token if we're at ## the end of the input @@ -70,9 +71,9 @@ proc peek(self: Parser): Token = result = endOfFile -proc step(self: Parser): Token = +proc step(self: Parser): Token = ## Consumes a token from the input and - ## steps forward or returns an EOF token + ## steps forward or returns an EOF token ## if we're at the end of the input if not self.done(): result = self.peek() @@ -81,19 +82,19 @@ proc step(self: Parser): Token = result = endOfFile -proc previous(self: Parser): Token = +proc previous(self: Parser): Token = ## Returns the previously consumed ## token result = self.tokens[self.current - 1] -proc check(self: Parser, kind: TokenType): bool = +proc check(self: Parser, kind: TokenType): bool = ## Returns true if the current token matches ## the given type result = self.peek().kind == kind -proc match(self: Parser, kind: TokenType): bool = +proc match(self: Parser, kind: TokenType): bool = ## Checks if the current token matches the ## given type and consumes it if it does, returns ## false otherwise. True is returned if the @@ -105,7 +106,7 @@ proc match(self: Parser, kind: TokenType): bool = result = false -proc match(self: Parser, kinds: varargs[TokenType]): bool = +proc match(self: Parser, kinds: varargs[TokenType]): bool = ## Checks if the current token matches any of the ## given type(s) and consumes it if it does, returns ## false otherwise. True is returned at @@ -116,20 +117,20 @@ proc match(self: Parser, kinds: varargs[TokenType]): bool = result = false -proc error(self: Parser, message: string) = +proc error(self: Parser, message: string) = ## Raises a parsing error with the given message raise newException(ParseError, message) -proc expect(self: Parser, kind: TokenType, message: string) = +proc expect(self: Parser, kind: TokenType, message: string) = ## Checks if the current token matches the given type - ## and consumes it if it does, raises an error + ## and consumes it if it does, raises an error ## with the given message otherwise. if not self.match(kind): self.error(message) -proc primary(self: Parser): AstNode = +proc primary(self: Parser): AstNode = ## Parses primary expressions let value = self.previous() case value.kind: @@ -151,7 +152,7 @@ proc primary(self: Parser): AstNode = self.error(&"invalid token of kind '{value.kind}' in primary expression") -proc call(self: Parser): AstNode = +proc call(self: Parser): AstNode = ## Parses function calls such as sin(2) var expression = self.primary() if self.match(TokenType.LeftParen): @@ -160,7 +161,8 @@ proc call(self: Parser): AstNode = arguments.add(self.binary()) while self.match(TokenType.Comma): arguments.add(self.binary()) - result = AstNode(kind: NodeKind.Call, arguments: arguments, function: expression) + result = AstNode(kind: NodeKind.Call, arguments: arguments, + function: expression) if expression.kind != NodeKind.Ident: self.error(&"can't call object of type {expression.kind}") if len(arguments) != arities[expression.name]: @@ -171,50 +173,54 @@ proc call(self: Parser): AstNode = -proc unary(self: Parser): AstNode = +proc unary(self: Parser): AstNode = ## Parses unary expressions such as -1 case self.step().kind: of TokenType.Minus, TokenType.Plus: - result = AstNode(kind: NodeKind.Unary, unOp: self.previous(), operand: self.unary()) + result = AstNode(kind: NodeKind.Unary, unOp: self.previous(), + operand: self.unary()) else: result = self.call() - -proc pow(self: Parser): AstNode = + +proc pow(self: Parser): AstNode = ## Parses exponentiation result = self.unary() var operator: Token while self.match(TokenType.Exp): operator = self.previous() - result = AstNode(kind: NodeKind.Binary, left: result, right: self.unary(), binOp: operator) + result = AstNode(kind: NodeKind.Binary, left: result, right: self.unary(), + binOp: operator) -proc mul(self: Parser): AstNode = +proc mul(self: Parser): AstNode = ## Parses divisions (including modulo) and ## multiplications result = self.pow() var operator: Token while self.match(TokenType.Div, TokenType.Modulo, TokenType.Mul): operator = self.previous() - result = AstNode(kind: NodeKind.Binary, left: result, right: self.pow(), binOp: operator) + result = AstNode(kind: NodeKind.Binary, left: result, right: self.pow(), + binOp: operator) -proc addition(self: Parser): AstNode = +proc addition(self: Parser): AstNode = ## Parses additions and subtractions result = self.mul() var operator: Token while self.match(TokenType.Plus, TokenType.Minus): operator = self.previous() - result = AstNode(kind: NodeKind.Binary, left: result, right: self.mul(), binOp: operator) + result = AstNode(kind: NodeKind.Binary, left: result, right: self.mul(), + binOp: operator) -proc binary(self: Parser): AstNode = +proc binary(self: Parser): AstNode = ## Parses binary expressions, the highest ## level of expression result = self.addition() -proc parse*(self: Parser, tokens: seq[Token]): AstNode = +proc parse*(self: Parser, tokens: seq[Token]): AstNode = ## Parses a list of tokens into an AST tree self.tokens = tokens self.current = 0