diff --git a/src/nimkalc/objects/ast.nim b/src/nimkalc/objects/ast.nim index 86ec55f..e1f2760 100644 --- a/src/nimkalc/objects/ast.nim +++ b/src/nimkalc/objects/ast.nim @@ -100,6 +100,16 @@ template ensureNonZero(node: AstNode) = raise newException(CatchableError, &"invalid node kind '{node.kind}' for ensureNonZero") +template ensurePositive(node: AstNode) = + ## Handy template to ensure that a given node's value is positive + if node.value < 0.0: + case node.kind: + of NodeKind.Float, NodeKind.Integer: + raise newException(MathError, "value must be positive") + else: + raise newException(CatchableError, &"invalid node kind '{node.kind}' for ensureNonZero") + + template ensureIntegers(left, right: AstNode) = ## Ensures both operands are integers if left.kind != NodeKind.Integer or right.kind != NodeKind.Integer: @@ -159,6 +169,7 @@ proc visit_call(self: NodeVisitor, node: AstNode): AstNode = of "tan": callFunction(tan, self.eval(node.arguments[0]).value) of "sqrt": + ensurePositive(arg) callFunction(sqrt, self.eval(node.arguments[0]).value) of "log": let arg = self.eval(node.arguments[0]) @@ -222,8 +233,8 @@ proc visit_binary(self: NodeVisitor, node: AstNode): AstNode = result = handleBinary(left, right, `/`) of TokenType.Modulo: # Modulo is a bit special since we must have integers - ensureNonZero(right) ensureIntegers(left, right) + ensureNonZero(right) result = AstNode(kind: NodeKind.Integer, value: float(int(left.value) mod int(right.value))) of TokenType.Exp: result = handleBinary(left, right, pow) @@ -245,5 +256,7 @@ proc visit_unary(self: NodeVisitor, node: AstNode): AstNode = result = AstNode(kind: NodeKind.Integer, value: -expr.value) else: discard # Unreachable + of TokenType.Plus: + result = node # Unary + does nothing else: discard # Unreachable