nimkalc/src/nimkalc/objects/ast.nim

274 lines
8.9 KiB
Nim

# Copyright 2021 Mattia Giambirtone
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# An Abstract Syntax Tree and node visitor implementation
import token
import error
import strformat
import tables
import math
import strutils
type
  NodeKind* {.pure.} = enum ## Every kind of node the AST can contain
    Grouping, Unary, Binary, Integer,
    Float, Call, Ident

  AstNode* = ref object ## A node of the abstract syntax tree
    case kind*: NodeKind
    of NodeKind.Grouping:
      expr*: AstNode ## The parenthesized subexpression
    of NodeKind.Unary:
      unOp*: Token ## The unary operator
      operand*: AstNode ## The expression the operator is applied to
    of NodeKind.Binary:
      binOp*: Token ## The binary operator
      left*, right*: AstNode ## The two operands
    of NodeKind.Integer, NodeKind.Float:
      # Both kinds store a float64: `kind` only records whether the value
      # is to be treated as an integer or as a float, because using a
      # double precision float for everything is simply easier
      value*: float64
    of NodeKind.Ident:
      name*: string ## The identifier's name
    of NodeKind.Call:
      arguments*: seq[AstNode] ## The argument expressions
      function*: AstNode ## The callee; its `name` selects the function

  NodeVisitor* = ref object ## A node visitor; it carries no state
proc `$`*(self: AstNode): string =
  ## Renders a node and, recursively, its children as a nested
  ## constructor-like string such as ``Grouping(Integer(1))``.
  ## Operators are shown through their token kind, and Integer nodes print
  ## their value converted to `int`.
  result =
    case self.kind
    of NodeKind.Integer:
      "Integer(" & $int(self.value) & ")"
    of NodeKind.Float:
      "Float(" & $self.value & ")"
    of NodeKind.Ident:
      "Identifier(" & self.name & ")"
    of NodeKind.Grouping:
      "Grouping(" & $self.expr & ")"
    of NodeKind.Unary:
      "Unary(" & $self.unOp.kind & ", " & $self.operand & ")"
    of NodeKind.Binary:
      "Binary(" & $self.left & ", " & $self.binOp.kind & ", " &
        $self.right & ")"
    of NodeKind.Call:
      let callee = self.function.name
      let args = self.arguments.join(", ")
      "Call(" & callee & ", " & args & ")"
proc initNodeVisitor*(): NodeVisitor =
  ## Creates a fresh node visitor (a new, field-less ref object)
  NodeVisitor()
template handleBinary(left, right: AstNode, operator: untyped): AstNode =
  ## Applies `operator` to the float values of `left` and `right` and wraps
  ## the outcome in a new node. The result is an Integer node when it is a
  ## whole number inside `int`'s range, and a Float node otherwise. Shared
  ## by the arithmetic branches of visitBinary so the same check is not
  ## copy-pasted in each of them.
  let r = operator(left.value, right.value)
  # Compare against trunc() rather than round-tripping through int().
  # Converting NaN, an infinity, or a value outside int's range to int is
  # not well-defined; these comparisons are simply false for NaN. The range
  # bound also keeps Integer nodes safe to print through int() in `$`.
  # -float(low(int)) is exactly 2^63 (or 2^31 on 32-bit targets).
  if r >= float(low(int)) and r < -float(low(int)) and r == trunc(r):
    # It's a whole number!
    AstNode(kind: NodeKind.Integer, value: r)
  else:
    AstNode(kind: NodeKind.Float, value: r)
template ensureNonZero(node: AstNode) =
  ## Raises MathError when `node` is a numeric (Float or Integer) node whose
  ## value is zero. Any other node kind raises CatchableError, since only
  ## numeric nodes carry a value to check.
  case node.kind
  of NodeKind.Float, NodeKind.Integer:
    if node.value == 0.0:
      raise newException(MathError, "value can't be zero")
  else:
    # The kind is checked first: reading `value` on a non-numeric node
    # would raise a FieldDefect before this branch could ever be reached
    raise newException(CatchableError,
      &"invalid node kind '{node.kind}' for ensureNonZero")
template ensurePositive(node: AstNode) =
  ## Raises MathError when `node` is a numeric (Float or Integer) node with
  ## a negative value; zero is accepted. Any other node kind raises
  ## CatchableError, since only numeric nodes carry a value to check.
  case node.kind
  of NodeKind.Float, NodeKind.Integer:
    if node.value < 0.0:
      raise newException(MathError, "value must be positive")
  else:
    # The kind is checked first: reading `value` on a non-numeric node
    # would raise a FieldDefect before this branch could ever be reached
    raise newException(CatchableError,
      &"invalid node kind '{node.kind}' for ensurePositive")
template ensureIntegers(left, right: AstNode) =
  ## Raises MathError unless both `left` and `right` are Integer nodes.
  ## `right` is only inspected when `left` is already an Integer node.
  let bothIntegers = left.kind == NodeKind.Integer and
    right.kind == NodeKind.Integer
  if not bothIntegers:
    raise newException(MathError, "an integer is required")
template callFunction(fun: untyped, args: varargs[untyped]) =
  ## Calls `fun(args)` and stores the outcome as a node in the enclosing
  ## routine's `result`. A `float` return becomes a Float node; any other
  ## return type becomes an Integer node with the value widened to float.
  ## The branch is picked at compile time from `fun`'s return type.
  let outcome = fun(args)
  when outcome is float:
    result = AstNode(kind: NodeKind.Float, value: outcome)
  else:
    result = AstNode(kind: NodeKind.Integer, value: float(outcome))
# Forward declarations: accept() dispatches to these visit procs before
# their bodies appear, and several of them call back into eval()
proc visitGrouping(self: NodeVisitor, node: AstNode): AstNode
proc visitUnary(self: NodeVisitor, node: AstNode): AstNode
proc visitBinary(self: NodeVisitor, node: AstNode): AstNode
proc visitCall(self: NodeVisitor, node: AstNode): AstNode
proc visitLiteral(self: NodeVisitor, node: AstNode): AstNode
proc accept(self: AstNode, visitor: NodeVisitor): AstNode =
  ## The "accept" half of the visitor pattern: hands `self` to the visit
  ## proc matching its kind and returns whatever that proc produces.
  ## Integer, Float and Ident nodes are all leaves handled by visitLiteral.
  result =
    case self.kind
    of NodeKind.Grouping:
      visitor.visitGrouping(self)
    of NodeKind.Unary:
      visitor.visitUnary(self)
    of NodeKind.Binary:
      visitor.visitBinary(self)
    of NodeKind.Call:
      visitor.visitCall(self)
    of NodeKind.Integer, NodeKind.Float, NodeKind.Ident:
      visitor.visitLiteral(self)
proc eval*(self: NodeVisitor, node: AstNode): AstNode =
  ## Evaluates `node` with this visitor by dispatching it through `accept`,
  ## returning the node produced by the matching visit proc
  accept(node, self)
proc visitLiteral(self: NodeVisitor, node: AstNode): AstNode =
  ## Visits a leaf node (Integer, Float or Ident). There is nothing left to
  ## evaluate in a leaf, so the very same node is handed back unchanged.
  node
proc evalCallArguments(self: NodeVisitor, node: AstNode,
                       expected: int): seq[AstNode] =
  ## Evaluates the arguments of the call expression `node` in order, each
  ## exactly once, after checking that there are exactly `expected` of
  ## them. A mismatch raises CatchableError, so a wrong argument count is
  ## reported instead of escaping as an out-of-bounds IndexDefect.
  let given = node.arguments.len
  if given != expected:
    let name = node.function.name
    raise newException(CatchableError,
      &"'{name}' takes {expected} argument(s), got {given}")
  for argument in node.arguments:
    result.add(self.eval(argument))

proc visitCall(self: NodeVisitor, node: AstNode): AstNode =
  ## Visits a function call expression and evaluates it.
  ##
  ## `log(x, base)` and `hypot(x, y)` take two arguments; every other
  ## supported function takes one. A wrong argument count or an unknown
  ## function name raises CatchableError, and each argument is evaluated
  ## exactly once. Domain errors raise MathError:
  ## - `sqrt` rejects negative arguments.
  ## - `ln`, `log2`, `log10` and `log` reject arguments (and, for `log`,
  ##   bases) that are zero or negative.
  ## - `log` also rejects a base of 1.
  let name = node.function.name

  template unary(fun: untyped) =
    # Single-argument function without any domain restriction
    callFunction(fun, self.evalCallArguments(node, 1)[0].value)

  template ensureLogDomain(arg: AstNode) =
    # Logarithms are only defined for strictly positive values; zero keeps
    # reporting the same "can't be zero" error as before
    ensureNonZero(arg)
    ensurePositive(arg)

  case name
  of "sin": unary(sin)
  of "cos": unary(cos)
  of "tan": unary(tan)
  of "sinh": unary(sinh)
  of "cosh": unary(cosh)
  of "tanh": unary(tanh)
  of "cbrt": unary(cbrt)
  of "arcsin": unary(arcsin)
  of "arccos": unary(arccos)
  of "arctan": unary(arctan)
  of "arcsinh": unary(arcsinh)
  of "arccosh": unary(arccosh)
  of "arctanh": unary(arctanh)
  of "sqrt":
    let x = self.evalCallArguments(node, 1)[0]
    ensurePositive(x)
    callFunction(sqrt, x.value)
  of "ln":
    let x = self.evalCallArguments(node, 1)[0]
    ensureLogDomain(x)
    callFunction(ln, x.value)
  of "log2":
    let x = self.evalCallArguments(node, 1)[0]
    ensureLogDomain(x)
    callFunction(log2, x.value)
  of "log10":
    let x = self.evalCallArguments(node, 1)[0]
    ensureLogDomain(x)
    callFunction(log10, x.value)
  of "log":
    let args = self.evalCallArguments(node, 2)
    ensureLogDomain(args[0])
    ensureLogDomain(args[1])
    if args[1].value == 1.0:
      # log(x, 1) is ln(x) / ln(1), a division by zero
      raise newException(MathError, "logarithm base can't be 1")
    callFunction(log, args[0].value, args[1].value)
  of "hypot":
    let args = self.evalCallArguments(node, 2)
    callFunction(hypot, args[0].value, args[1].value)
  else:
    # Without this branch an unknown name silently returned nil
    raise newException(CatchableError, &"unknown function '{name}'")
proc visitGrouping(self: NodeVisitor, node: AstNode): AstNode =
  ## Visits a parenthesized expression. Parentheses only let a
  ## lower-precedence expression appear where a higher-precedence one is
  ## expected, so that 2 * (3 + 1) differs from 2 * 3 + 1. Evaluating the
  ## group therefore just means evaluating the expression inside it.
  eval(self, node.expr)
proc visitBinary(self: NodeVisitor, node: AstNode): AstNode =
  ## Evaluates both operands of a binary expression (right one first), then
  ## applies the operator. Division and modulo reject a zero divisor via
  ## ensureNonZero, and modulo also requires two Integer operands via
  ## ensureIntegers. An operator not handled below yields nil.
  let rhs = self.eval(node.right)
  let lhs = self.eval(node.left)
  case node.binOp.kind
  of TokenType.Plus:
    return handleBinary(lhs, rhs, `+`)
  of TokenType.Minus:
    return handleBinary(lhs, rhs, `-`)
  of TokenType.Mul:
    return handleBinary(lhs, rhs, `*`)
  of TokenType.Div:
    ensureNonZero(rhs)
    return handleBinary(lhs, rhs, `/`)
  of TokenType.Exp:
    return handleBinary(lhs, rhs, pow)
  of TokenType.Modulo:
    # Modulo only makes sense between integers: the stored floats are
    # converted to int, reduced with `mod`, then widened back to float
    ensureIntegers(lhs, rhs)
    ensureNonZero(rhs)
    let remainder = int(lhs.value) mod int(rhs.value)
    return AstNode(kind: NodeKind.Integer, value: float(remainder))
  else:
    discard # Unreachable
proc visitUnary(self: NodeVisitor, node: AstNode): AstNode =
  ## Evaluates the operand of a unary expression, then applies the operator.
  ## Unary minus negates the value; unary plus copies it into a fresh node.
  ## The operand's kind (Integer or Float) is kept. Any other operator or
  ## operand kind yields nil.
  let operand = self.eval(node.operand)
  let op = node.unOp.kind
  if op != TokenType.Minus and op != TokenType.Plus:
    return # Unreachable
  if operand.kind notin {NodeKind.Float, NodeKind.Integer}:
    return # Unreachable
  let negate = op == TokenType.Minus
  let value = if negate: -operand.value else: operand.value
  if operand.kind == NodeKind.Float:
    result = AstNode(kind: NodeKind.Float, value: value)
  else:
    result = AstNode(kind: NodeKind.Integer, value: value)