Added a rough functions implementation

nocturn9x 2021-03-11 20:02:51 +01:00
parent cd80d3babf
commit 219c5c9ac1
6 changed files with 177 additions and 74 deletions

@ -1,14 +1,27 @@
 # NimKalc - A math parsing library
 NimKalc is a simple implementation of a recursive-descent top-down parser that can evaluate
-mathematical expressions. Notable mentions are support for common mathematical constants (pi, tau, euler's number, etc),
-functions (`sin`, `cos`, `tan`...), equation-solving algos using newton's method and scientific notation numbers (such as `2e5`)
+mathematical expressions.
+__Disclaimer__: This library is __in beta__ and is not fully tested yet. It will be soon, though
+Features:
+- Support for mathematical constants (`pi`, `tau` and `e` right now)
+- Supported functions:
+  - `sin`
+  - `cos`
+  - `tan`
+  - `sqrt`
+  - `root` (for generic roots, takes the base and the argument)
+  - `log` (logarithm in base `e`)
+  - `logN` (logarithm in a given base, second argument)
+- Parentheses can be used to enforce different precedence levels
+- Easy API for tokenization, parsing and evaluation of AST nodes
 ## Current limitations
-- No functions (coming soon)
 - No equation-solving (coming soon)
-- The parsing is a bit weird because `2 2` will parse the first 2 and just stop instead of erroring out (FIXME)
+- The parsing is a bit weird because something like `2 2` will parse the first 2 and just stop instead of erroring out (FIXME)
 ## How to use it
@ -18,9 +31,8 @@ NimKalc parses mathematical expressions following this process:
 - Generate an AST
 - Visit the nodes
-Each of these steps can be run separately, but for convenience a wrapper
-`eval` procedure has been defined which takes in a string and returns a
-single AST node containing the result of the given expression.
+Each of these steps can be run separately, but for convenience a wrapper `eval` procedure has been defined which takes in a string
+and returns a single AST node containing the result of the given expression.
 ## Supported operators
@ -28,30 +40,39 @@ Beyond the classical 4 operators (`+`, `-`, `/` and `*`), NimKalc supports:
 - `%` for modulo division
 - `^` for exponentiation
 - unary `-` for negation
+- Arbitrarily nested parentheses (__not__ empty ones!) to enforce precedence
 ## Exceptions
-NimKalc defines 2 exceptions:
-- `ParseError` is used when the expression is invalid
+NimKalc defines various exceptions:
+- `NimKalcException` is a generic superclass for all errors
+- `ParseError` is used when the expression is syntactically invalid
 - `MathError` is used when there is an arithmetical error such as division by 0 or domain errors (e.g. `log(0)`)
+- `EvaluationError` is used when the runtime evaluation of an expression fails (e.g. trying to call something that isn't a function)
 ## Design
 NimKalc treats all numerical values as `float` to simplify the implementation of the underlying operators. To tell integers
 from floating point numbers the `AstNode` object has a `kind` discriminant which will be equal to `NodeKind.Integer` for ints
-and `NodeKind.Float` for decimals. It is advised that you take this into account when using the library
+and `NodeKind.Float` for decimals. It is advised that you take this into account when using the library, since integers might
+start losing precision when converted from their float counterpart due to the difference of the two types. Everything should
+be fine as long as the value doesn't exceed 2 ^ 53, though
 __Note__: The string representation of integer nodes won't show the decimal part for clarity
+Some other notable design choices (due to the underlying simplicity of the language we parse) are as follows:
+- Identifiers are checked when tokenizing, since they're all constant
+- Mathematical constants are immediately mapped to their real values when tokenizing with no intermediate steps or tokens
+- Type errors (such as trying to call an integer) are detected statically at parse time
 ## String representations
 All of NimKalc's objects implement the `$` operator and are therefore printable. Integer nodes will look like `Integer(x)`, while
 floats are represented with `Float(x.x)`. Unary operators print as `Unary(operator, right)`, while binary operators print as `Binary(left, operator, right)`.
 Parenthesized expressions print as `Grouping(expr)`, where `expr` is the expression enclosed in parentheses (as an AST node, obviously).
-Token objects will print as `Token(kind, lexeme)`: an example for the number 2 would be `Token(Integer, '2')`
+Token objects will print as `Token(kind, lexeme)`: an example for the number 2 would be `Token(Integer, '2')`. Function calls print like `Call(name, args)`
+where `name` is the function name and `args` is a `seq[AstNode]` representing the function's arguments
 ## Example
@ -115,14 +136,13 @@ when isMainModule:
 ```
-__Note__: If you don't need the intermediate representations shown here (tokens, AST) you can just `import nimkalc` and use
+__Note__: If you don't need the intermediate representations shown here (tokens/AST) you can just `import nimkalc` and use
 the `eval` procedure, which takes in a string and returns the evaluated result as a primary AST node like so:
 ```nim
 import nimkalc
 echo eval("2+2") # Prints Integer(4)
 ```
 ## Installing
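
The Design note added to the README in this commit says integer values stay safe as long as they do not exceed 2 ^ 53. The following is a minimal standalone sketch of why that bound applies to `float64`. It does not use NimKalc itself, and `roundTrips` is a hypothetical helper introduced only for illustration.

```nim
import math

# 2 ^ 53 is the point past which float64 can no longer represent
# every consecutive integer.
let limit = pow(2.0, 53.0)   # 9007199254740992.0

proc roundTrips(x: float64): bool =
    ## True when converting to a 64-bit integer and back gives the same value
    float64(int64(x)) == x

echo roundTrips(limit)        # true
echo limit + 1.0 == limit     # true: 2^53 + 1 is not representable
echo limit - 1.0 == limit     # false: values below the limit are still exact
```

Past that limit, adjacent integers collapse onto the same float, which is the precision loss the README warns about.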

@ -19,13 +19,16 @@ import error
 import strformat
 import tables
 import math
+import strutils
 type
     NodeKind* {.pure.} = enum
+        # An enum for all kinds of AST nodes
         Grouping, Unary, Binary, Integer,
-        Float
+        Float, Call, Ident
     AstNode* = ref object
+        # An AST node object
         case kind*: NodeKind
             of NodeKind.Grouping:
                 expr*: AstNode
@ -42,6 +45,11 @@ type
                 # using a double precision float for everything
                 # is just easier
                 value*: float64
+            of NodeKind.Ident:
+                name*: string
+            of NodeKind.Call:
+                arguments*: seq[AstNode]
+                function*: AstNode
     NodeVisitor* = ref object
         # A node visitor object
@ -64,35 +72,10 @@ proc `$`*(self: AstNode): string =
             result = &"Integer({$int(self.value)})"
         of NodeKind.Float:
             result = &"Float({$self.value})"
-# Forward declarations
-proc visit_literal(self: NodeVisitor, node: AstNode): AstNode
-proc visit_unary(self: NodeVisitor, node: AstNode): AstNode
-proc visit_binary(self: NodeVisitor, node: AstNode): AstNode
-proc visit_grouping(self: NodeVisitor, node: AstNode): AstNode
-proc accept(self: AstNode, visitor: NodeVisitor): AstNode =
-    case self.kind:
-        of NodeKind.Integer, NodeKind.Float:
-            result = visitor.visit_literal(self)
-        of NodeKind.Binary:
-            result = visitor.visit_binary(self)
-        of NodeKind.Unary:
-            result = visitor.visit_unary(self)
-        of NodeKind.Grouping:
-            result = visitor.visit_grouping(self)
-proc eval*(self: NodeVisitor, node: AstNode): AstNode =
-    ## Evaluates an AST node
-    result = node.accept(self)
-proc visit_literal(self: NodeVisitor, node: AstNode): AstNode =
-    ## Visits a literal AST node (such as integers)
-    result = node # Not that we can do anything else after all, lol
+        of NodeKind.Call:
+            result = &"Call({self.function.name}, {self.arguments})"
+        of NodeKind.Ident:
+            result = &"Identifier({self.name})"
 template handleBinary(left, right: AstNode, operator: untyped): AstNode =
@ -106,18 +89,14 @@ template handleBinary(left, right: AstNode, operator: untyped): AstNode =
         AstNode(kind: NodeKind.Float, value: r)
-template rightOpNonZero(node: AstNode, opType: string) =
-    ## Handy template to make sure that the given AST node matches
-    ## a condition from
+template ensureNonZero(node: AstNode) =
+    ## Handy template to ensure that a given node's value is not 0
     if node.value == 0.0:
         case node.kind:
-            of NodeKind.Float:
-                raise newException(MathError, "float " & opType & " by 0")
-            of NodeKind.Integer:
-                raise newException(MathError, "integer " & opType & " by 0")
+            of NodeKind.Float, NodeKind.Integer:
+                raise newException(MathError, &"{($node.kind).toLowerAscii()} can't be zero")
             else:
-                raise newException(CatchableError, &"invalid node kind '{node.kind}' for rightOpNonZero")
+                raise newException(CatchableError, &"invalid node kind '{node.kind}' for ensureNonZero")
 template ensureIntegers(left, right: AstNode) =
@ -126,6 +105,73 @@ template ensureIntegers(left, right: AstNode) =
         raise newException(MathError, "an integer is required")
+# Forward declarations
+proc visit_literal(self: NodeVisitor, node: AstNode): AstNode
+proc visit_unary(self: NodeVisitor, node: AstNode): AstNode
+proc visit_binary(self: NodeVisitor, node: AstNode): AstNode
+proc visit_grouping(self: NodeVisitor, node: AstNode): AstNode
+proc visit_call(self: NodeVisitor, node: AstNode): AstNode
+proc accept(self: AstNode, visitor: NodeVisitor): AstNode =
+    ## Implements the accept part of the visitor pattern
+    ## for our AST visitor
+    case self.kind:
+        of NodeKind.Integer, NodeKind.Float, NodeKind.Ident:
+            result = visitor.visit_literal(self)
+        of NodeKind.Binary:
+            result = visitor.visit_binary(self)
+        of NodeKind.Unary:
+            result = visitor.visit_unary(self)
+        of NodeKind.Grouping:
+            result = visitor.visit_grouping(self)
+        of NodeKind.Call:
+            result = visitor.visit_call(self)
+proc eval*(self: NodeVisitor, node: AstNode): AstNode =
+    ## Evaluates an AST node
+    result = node.accept(self)
+proc visit_literal(self: NodeVisitor, node: AstNode): AstNode =
+    ## Visits a literal AST node (such as integers)
+    result = node # Not that we can do anything else after all, lol
+proc visit_call(self: NodeVisitor, node: AstNode): AstNode =
+    ## Visits function call expressions
+    var args: seq[AstNode] = @[]
+    for arg in node.arguments:
+        args.add(self.eval(arg))
+    if node.function.name == "sin":
+        let r = sin(args[0].value)
+        if r is float:
+            result = AstNode(kind: NodeKind.Float, value: r)
+        else:
+            result = AstNode(kind: NodeKind.Integer, value: float(r))
+    if node.function.name == "cos":
+        let r = cos(args[0].value)
+        if r is float:
+            result = AstNode(kind: NodeKind.Float, value: r)
+        else:
+            result = AstNode(kind: NodeKind.Integer, value: float(r))
+    if node.function.name == "tan":
+        let r = tan(args[0].value)
+        if r is float:
+            result = AstNode(kind: NodeKind.Float, value: r)
+        else:
+            result = AstNode(kind: NodeKind.Integer, value: float(r))
+proc visit_grouping(self: NodeVisitor, node: AstNode): AstNode =
+    ## Visits grouping (i.e. parenthesized) expressions. Parentheses
+    ## have no other meaning than to allow a lower-precedence expression
+    ## where a higher-precedence one is expected so that 2 * (3 + 1) is
+    ## different from 2 * 3 + 1
+    return self.eval(node.expr)
 proc visit_binary(self: NodeVisitor, node: AstNode): AstNode =
     ## Visits a binary AST node and evaluates it
     let right = self.eval(node.right)
@ -136,11 +182,11 @@ proc visit_binary(self: NodeVisitor, node: AstNode): AstNode =
         of TokenType.Minus:
             result = handleBinary(left, right, `-`)
         of TokenType.Div:
-            rightOpNonZero(right, "division")
+            ensureNonZero(right)
             result = handleBinary(left, right, `/`)
         of TokenType.Modulo:
             # Modulo is a bit special since we must have integers
-            rightOpNonZero(right, "modulo")
+            ensureNonZero(right)
             ensureIntegers(left, right)
             result = AstNode(kind: NodeKind.Integer, value: float(int(left.value) mod int(right.value)))
         of TokenType.Exp:
@ -165,11 +211,3 @@ proc visit_unary(self: NodeVisitor, node: AstNode): AstNode =
                     discard # Unreachable
         else:
             discard # Unreachable
-proc visit_grouping(self: NodeVisitor, node: AstNode): AstNode =
-    ## Visits grouping (i.e. parenthesized) expressions. Parentheses
-    ## have no other meaning than to allow a lower-precedence expression
-    ## where a higher-precedence one is expected so that 2 * (3 + 1) is
-    ## different from 2 * 3 + 1
-    return self.eval(node.expr)
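
The new `visit_call` repeats the same branch for each function name, and its `r is float` test is a compile-time type check that always holds for a `float64` result, so as written the `Integer` branch never runs. Below is a hedged sketch of one way the dispatch could be condensed. `Node`, `Kind`, `applyFunction` and `toNode` are simplified stand-ins invented for this example, not the commit's types, and the "no fractional part means Integer" rule is an assumption about the intended behaviour.

```nim
import math

type
    # Simplified stand-ins for NodeKind/AstNode (illustrative only)
    Kind = enum
        nkInteger, nkFloat
    Node = object
        kind: Kind
        value: float64

proc applyFunction(name: string, x: float64): float64 =
    ## Maps a function name to the matching routine from the math module
    case name
    of "sin": result = sin(x)
    of "cos": result = cos(x)
    of "tan": result = tan(x)
    else: raise newException(ValueError, "unknown function '" & name & "'")

proc toNode(r: float64): Node =
    ## Assumed rule: a result with no fractional part becomes an Integer
    ## node, anything else becomes a Float node (checked at runtime)
    if r == floor(r):
        result = Node(kind: nkInteger, value: r)
    else:
        result = Node(kind: nkFloat, value: r)

echo toNode(applyFunction("cos", 0.0)).kind   # nkInteger
echo toNode(0.5).kind                         # nkFloat
```

Keeping the name lookup in a single `case` means adding a function touches one place instead of duplicating a whole `if` block.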

@ -15,6 +15,9 @@
 type
-    ParseError* = object of CatchableError
+    NimKalcException* = object of CatchableError
+    ParseError* = object of NimKalcException
         ## A parsing exception
-    MathError* = object of ArithmeticDefect
+    MathError* = object of NimKalcException
+        ## An arithmetic error
+    EvaluationError* = object of NimKalcException
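
With every error now deriving from `NimKalcException`, a caller can catch the whole family in one `except` branch. The sketch below redeclares the hierarchy locally so that it runs on its own. `failWith` and `safeMessage` are hypothetical helpers, not part of the library.

```nim
type
    # Local copy of the hierarchy introduced by this commit
    NimKalcException = object of CatchableError
    ParseError = object of NimKalcException
    MathError = object of NimKalcException

proc failWith(kind: string) =
    ## Hypothetical helper that raises one of the two error kinds
    if kind == "parse":
        raise newException(ParseError, "unexpected token")
    else:
        raise newException(MathError, "division by 0")

proc safeMessage(kind: string): string =
    ## Catches any error of the family through the common base type
    try:
        failWith(kind)
        result = "ok"
    except NimKalcException as e:
        result = e.msg

echo safeMessage("parse")   # unexpected token
echo safeMessage("math")    # division by 0
```

Because `MathError` no longer inherits from `ArithmeticDefect`, it is also covered by a plain `except CatchableError` handler.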

@ -22,8 +22,10 @@ type
         # Operators
         Plus, Minus, Div, Exp, Modulo,
         Mul, RightParen, LeftParen,
+        # Identifiers
+        Ident,
         # Other
-        Eof
+        Eof, Comma
     Token* = object
         # A token object
         lexeme*: string

@ -27,15 +27,18 @@ const tokens = to_table({
     '(': TokenType.LeftParen, ')': TokenType.RightParen,
     '-': TokenType.Minus, '+': TokenType.Plus,
     '*': TokenType.Mul, '/': TokenType.Div,
-    '%': TokenType.Modulo, '^': TokenType.Exp})
+    '%': TokenType.Modulo, '^': TokenType.Exp,
+    ',': TokenType.Comma})
 # All the identifiers and constants (such as PI)
 # Since they're constant we don't even need to bother adding another
 # AST node kind, we can just map the name to a float literal ;)
-const identifiers = to_table({
+const constants = to_table({
     "pi": Token(kind: TokenType.Float, lexeme: "3.141592653589793"),
     "e": Token(kind: TokenType.Float, lexeme: "2.718281828459045"),
     "tau": Token(kind: TokenType.Float, lexeme: "6.283185307179586")
 })
+# Since also math functions are hardcoded, we can use an array
+const functions = ["sin", "cos", "tan"]
 type
@ -88,6 +91,8 @@ func createToken(self: Lexer, tokenType: TokenType): Token =
 proc parseNumber(self: Lexer) =
     ## Parses numeric literals
     var kind = TokenType.Int
+    var scientific: bool = false
+    var sign: bool = false
     while true:
         if self.peek().isDigit():
             discard self.step()
@ -99,6 +104,11 @@ proc parseNumber(self: Lexer) =
             # Scientific notation
             kind = TokenType.Float
             discard self.step()
+            scientific = true
+        elif self.peek().toLowerAscii() in {'-', '+'} and scientific and not sign:
+            # So we can parse stuff like 2e-5
+            sign = true
+            discard self.step()
         else:
             break
     self.tokens.add(self.createToken(kind))
@ -111,8 +121,10 @@ proc parseIdentifier(self: Lexer) =
     while self.peek().isAlphaNumeric() or self.peek() in {'_', }:
         discard self.step()
     var text: string = self.source[self.start..<self.current]
-    if text.toLowerAscii() in identifiers:
-        self.tokens.add(identifiers[text])
+    if text.toLowerAscii() in constants:
+        self.tokens.add(constants[text])
+    elif text.toLowerAscii() in functions:
+        self.tokens.add(self.createToken(TokenType.Ident))
     else:
         raise newException(ParseError, &"Unknown identifier '{text}'")
@ -138,6 +150,8 @@ proc lex*(self: Lexer, source: string): seq[Token] =
     ## Lexes a source string, converting a stream
     ## of characters into a series of tokens
     self.source = source
+    self.tokens = @[]
+    self.current = 0
     while not self.done():
         self.start = self.current
         self.scanToken()
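
Judging only from the visible part of the `parseNumber` hunks, the new `scientific and not sign` condition accepts a sign anywhere after the exponent marker, so an input such as `2e5-3` would apparently be lexed as one literal. The sketch below shows a stricter rule where a sign is only consumed directly after `e`. `scanNumber` is a hypothetical standalone function, not the lexer's API, and its handling of the decimal point is an assumption because that part of the original is outside the diff.

```nim
import strutils

proc scanNumber(src: string, start: int): int =
    ## Returns the index one past the numeric literal beginning at `start`
    var i = start
    # Integer part
    while i < src.len and src[i].isDigit():
        inc i
    # Optional fractional part (assumed form: '.' followed by digits)
    if i < src.len and src[i] == '.':
        inc i
        while i < src.len and src[i].isDigit():
            inc i
    # Optional exponent: 'e', an optional sign right after it, then digits
    if i < src.len and src[i] in {'e', 'E'}:
        var j = i + 1
        if j < src.len and src[j] in {'+', '-'}:
            inc j
        if j < src.len and src[j].isDigit():
            while j < src.len and src[j].isDigit():
                inc j
            i = j   # Commit the exponent only when it has digits
    result = i

echo "2e-5"[0 ..< scanNumber("2e-5", 0)]     # 2e-5
echo "2e5-3"[0 ..< scanNumber("2e5-3", 0)]   # 2e5
```

Tying the sign to the position right after `e` leaves the `-` in `2e5-3` available as a binary operator.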

@ -20,6 +20,7 @@ import ../objects/error
 import parseutils
 import strformat
+import tables
 {.experimental: "implicitDeref".}
@ -31,6 +32,9 @@ type
         current: int
+const arities = to_table({"sin": 1, "cos": 1, "tan": 1})
 proc initParser*(): Parser =
     new(result)
     result.current = 0
@ -134,17 +138,39 @@ proc primary(self: Parser): AstNode =
             let expression = self.binary()
             self.expect(TokenType.RightParen, "unexpected EOL")
             result = AstNode(kind: NodeKind.Grouping, expr: expression)
+        of TokenType.Ident:
+            result = AstNode(kind: NodeKind.Ident, name: value.lexeme)
         else:
             self.error(&"invalid token of kind '{value.kind}' in primary expression")
+proc call(self: Parser): AstNode =
+    ## Parses function calls such as sin(2)
+    var expression = self.primary()
+    if self.match(TokenType.LeftParen):
+        if expression.kind != NodeKind.Ident:
+            self.error(&"object of type '{expression.kind}' is not callable")
+        var arguments: seq[AstNode] = @[]
+        if not self.check(TokenType.RightParen):
+            arguments.add(self.binary())
+            while self.match(TokenType.Comma):
+                arguments.add(self.binary())
+        result = AstNode(kind: NodeKind.Call, arguments: arguments, function: expression)
+        if len(arguments) != arities[expression.name]:
+            self.error(&"Wrong number of arguments supplied to function '{expression.name}': expected {arities[expression.name]}, got {len(arguments)}")
+        self.expect(TokenType.RightParen, "unclosed function call")
+    else:
+        result = expression
 proc unary(self: Parser): AstNode =
     ## Parses unary expressions such as -1
     case self.step().kind:
         of TokenType.Minus:
             result = AstNode(kind: NodeKind.Unary, unOp: self.previous(), operand: self.unary())
         else:
-            result = self.primary()
+            result = self.call()
 proc pow(self: Parser): AstNode =
proc pow(self: Parser): AstNode = proc pow(self: Parser): AstNode =
@ -181,10 +207,10 @@ proc binary(self: Parser): AstNode =
     result = self.addition()
 proc parse*(self: Parser, tokens: seq[Token]): AstNode =
     ## Parses a list of tokens into an AST tree
     self.tokens = tokens
+    self.current = 0
     result = self.binary()