# Copyright 2021 Mattia Giambirtone # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # An Abstract Syntax Tree and node visitor implementation import token import error import strformat import tables import math import strutils type NodeKind* {.pure.} = enum # An enum for all kinds of AST nodes Grouping, Unary, Binary, Integer, Float, Call, Ident AstNode* = ref object # An AST node object case kind*: NodeKind of NodeKind.Grouping: expr*: AstNode of NodeKind.Unary: unOp*: Token operand*: AstNode of NodeKind.Binary: binOp*: Token left*: AstNode right*: AstNode of NodeKind.Integer, NodeKind.Float: # The kind makes us differentiate between # floats and integers, but for our purposes # using a double precision float for everything # is just easier value*: float64 of NodeKind.Ident: name*: string of NodeKind.Call: arguments*: seq[AstNode] function*: AstNode NodeVisitor* = ref object # A node visitor object proc `$`*(self: AstNode): string = ## Stringifies an AST node case self.kind: of NodeKind.Grouping: result = &"Grouping({self.expr})" of NodeKind.Unary: result = &"Unary({$self.unOp.kind}, {$self.operand})" of NodeKind.Binary: result = &"Binary({$self.left}, {$self.binOp.kind}, {$self.right})" of NodeKind.Integer: result = &"Integer({$int(self.value)})" of NodeKind.Float: result = &"Float({$self.value})" of NodeKind.Call: result = &"Call({self.function.name}, {self.arguments.join(\", \")})" of NodeKind.Ident: result = &"Identifier({self.name})" proc initNodeVisitor*(): NodeVisitor = ## Initializes a node visitor new(result) template handleBinary(left, right: AstNode, operator: untyped): AstNode = ## Handy template that avoids us the hassle of copy-pasting ## the same checks over and over again in the visitor let r = operator(left.value, right.value) if float(int(r)) == r: ## It's a whole number! AstNode(kind: NodeKind.Integer, value: r) else: AstNode(kind: NodeKind.Float, value: r) template ensureNonZero(node: AstNode) = ## Handy template to ensure that a given node's value is not 0 if node.value == 0.0: case node.kind: of NodeKind.Float, NodeKind.Integer: raise newException(MathError, "value can't be zero") else: raise newException(CatchableError, &"invalid node kind '{node.kind}' for ensureNonZero") template ensurePositive(node: AstNode) = ## Handy template to ensure that a given node's value is positive if node.value < 0.0: case node.kind: of NodeKind.Float, NodeKind.Integer: raise newException(MathError, "value must be positive") else: raise newException(CatchableError, &"invalid node kind '{node.kind}' for ensureNonZero") template ensureIntegers(left, right: AstNode) = ## Ensures both operands are integers if left.kind != NodeKind.Integer or right.kind != NodeKind.Integer: raise newException(MathError, "an integer is required") template callFunction(fun: untyped, args: varargs[untyped]) = ## Handy template to call functions let r = fun(args) if r is float: result = AstNode(kind: NodeKind.Float, value: r) else: result = AstNode(kind: NodeKind.Integer, value: float(r)) # Forward declarations proc visitLiteral(self: NodeVisitor, node: AstNode): AstNode proc visitUnary(self: NodeVisitor, node: AstNode): AstNode proc visitBinary(self: NodeVisitor, node: AstNode): AstNode proc visitGrouping(self: NodeVisitor, node: AstNode): AstNode proc visitCall(self: NodeVisitor, node: AstNode): AstNode proc accept(self: AstNode, visitor: NodeVisitor): AstNode = ## Implements the accept part of the visitor pattern ## for our AST visitor case self.kind: of NodeKind.Integer, NodeKind.Float, NodeKind.Ident: result = visitor.visitLiteral(self) of NodeKind.Binary: result = visitor.visitBinary(self) of NodeKind.Unary: result = visitor.visitUnary(self) of NodeKind.Grouping: result = visitor.visitGrouping(self) of NodeKind.Call: result = visitor.visitCall(self) proc eval*(self: NodeVisitor, node: AstNode): AstNode = ## Evaluates an AST node result = node.accept(self) proc visitLiteral(self: NodeVisitor, node: AstNode): AstNode = ## Visits a literal AST node (such as integers) result = node # Not that we can do anything else after all, lol proc visitCall(self: NodeVisitor, node: AstNode): AstNode = ## Visits function call expressions case node.function.name: of "sin": callFunction(sin, self.eval(node.arguments[0]).value) of "cos": callFunction(cos, self.eval(node.arguments[0]).value) of "tan": callFunction(tan, self.eval(node.arguments[0]).value) of "sqrt": let arg = self.eval(node.arguments[0]) ensurePositive(arg) callFunction(sqrt, self.eval(node.arguments[0]).value) of "log": let arg = self.eval(node.arguments[0]) ensureNonZero(arg) callFunction(log, self.eval(node.arguments[0]).value, self.eval( node.arguments[1]).value) of "ln": let arg = self.eval(node.arguments[0]) ensureNonZero(arg) callFunction(ln, self.eval(node.arguments[0]).value) of "log2": let arg = self.eval(node.arguments[0]) ensureNonZero(arg) callFunction(log2, self.eval(node.arguments[0]).value) of "log10": let arg = self.eval(node.arguments[0]) ensureNonZero(arg) callFunction(log10, self.eval(node.arguments[0]).value) of "cbrt": callFunction(cbrt, self.eval(node.arguments[0]).value) of "tanh": callFunction(sinh, self.eval(node.arguments[0]).value) of "sinh": callFunction(tanh, self.eval(node.arguments[0]).value) of "cosh": callFunction(cosh, self.eval(node.arguments[0]).value) of "arcsin": callFunction(arcsin, self.eval(node.arguments[0]).value) of "arccos": callFunction(arccos, self.eval(node.arguments[0]).value) of "arctan": callFunction(arctan, self.eval(node.arguments[0]).value) of "arctanh": callFunction(arctanh, self.eval(node.arguments[0]).value) of "arcsinh": callFunction(arcsinh, self.eval(node.arguments[0]).value) of "arccosh": callFunction(arccosh, self.eval(node.arguments[0]).value) of "hypot": callFunction(hypot, self.eval(node.arguments[0]).value, self.eval( node.arguments[1]).value) proc visitGrouping(self: NodeVisitor, node: AstNode): AstNode = ## Visits grouping (i.e. parenthesized) expressions. Parentheses ## have no other meaning than to allow a lower-precedence expression ## where a higher-precedence one is expected so that 2 * (3 + 1) is ## different from 2 * 3 + 1 return self.eval(node.expr) proc visitBinary(self: NodeVisitor, node: AstNode): AstNode = ## Visits a binary AST node and evaluates it let right = self.eval(node.right) let left = self.eval(node.left) case node.binOp.kind: of TokenType.Plus: result = handleBinary(left, right, `+`) of TokenType.Minus: result = handleBinary(left, right, `-`) of TokenType.Div: ensureNonZero(right) result = handleBinary(left, right, `/`) of TokenType.Modulo: # Modulo is a bit special since we must have integers ensureIntegers(left, right) ensureNonZero(right) result = AstNode(kind: NodeKind.Integer, value: float(int( left.value) mod int(right.value))) of TokenType.Exp: result = handleBinary(left, right, pow) of TokenType.Mul: result = handleBinary(left, right, `*`) else: discard # Unreachable proc visitUnary(self: NodeVisitor, node: AstNode): AstNode = ## Visits unary expressions and evaluates them let expr = self.eval(node.operand) case node.unOp.kind: of TokenType.Minus: case expr.kind: of NodeKind.Float: result = AstNode(kind: NodeKind.Float, value: -expr.value) of NodeKind.Integer: result = AstNode(kind: NodeKind.Integer, value: -expr.value) else: discard # Unreachable of TokenType.Plus: case expr.kind: of NodeKind.Float: result = AstNode(kind: NodeKind.Float, value: expr.value) of NodeKind.Integer: result = AstNode(kind: NodeKind.Integer, value: expr.value) else: discard # Unreachable else: discard # Unreachable