Fixes to function calls and attempt to fix closures inside variables

This commit is contained in:
Mattia Giambirtone 2022-12-02 13:35:54 +01:00
parent 8fb90bb0ef
commit 142e575497
4 changed files with 172 additions and 87 deletions

View File

@ -767,6 +767,7 @@ proc dispatch*(self: var PeonVM) =
of LoadUInt8:
self.push(uint64(self.constReadUInt8(int(self.readLong()))))
of LoadString:
# Loads the string's pointer onto the stack
self.push(cast[uint64](self.constReadString(int(self.readLong()), int(self.readLong()))))
# We cast instead of converting because, unlike with integers,
# we don't want nim to touch any of the bits of the underlying

View File

@ -50,7 +50,6 @@ type
Type = ref object
## A wrapper around
## compile-time types
isBuiltin: bool
case kind: TypeKind:
of Function:
isLambda: bool
@ -66,6 +65,7 @@ type
parent: Type
retJumps: seq[int]
forwarded: bool
location: int
of CustomType:
fields: TableRef[string, Type]
of Reference, Pointer:
@ -149,6 +149,8 @@ type
# Has the compiler generates this name internally or
# does it come from user code?
isReal: bool
# Is this name a builtin?
isBuiltin: bool
Loop = object
## A "loop object" used
@ -263,13 +265,13 @@ proc expression(self: Compiler, node: Expression, compile: bool = true): Type {.
proc statement(self: Compiler, node: Statement)
proc declaration(self: Compiler, node: Declaration)
proc peek(self: Compiler, distance: int = 0): ASTNode
proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool = true): Name {.discardable.}
proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool = true): Type {.discardable.}
proc varDecl(self: Compiler, node: VarDecl)
proc match(self: Compiler, name: string, kind: Type, node: ASTNode = nil, allowFwd: bool = true): Name
proc call(self: Compiler, node: CallExpr, compile: bool = true): Type {.discardable.}
proc getItemExpr(self: Compiler, node: GetItemExpr, compile: bool = true): Type {.discardable.}
proc unary(self: Compiler, node: UnaryExpr, compile: bool = true): Name {.discardable.}
proc binary(self: Compiler, node: BinaryExpr, compile: bool = true): Name {.discardable.}
proc unary(self: Compiler, node: UnaryExpr, compile: bool = true): Type {.discardable.}
proc binary(self: Compiler, node: BinaryExpr, compile: bool = true): Type {.discardable.}
proc infer(self: Compiler, node: LiteralExpr): Type
proc infer(self: Compiler, node: Expression): Type
proc inferOrError(self: Compiler, node: Expression): Type
@ -287,6 +289,7 @@ proc funDecl(self: Compiler, node: FunDecl, name: Name)
proc compileModule(self: Compiler, module: Name)
proc generateCall(self: Compiler, fn: Name, args: seq[Expression], line: int)
proc prepareFunction(self: Compiler, fn: Name)
proc lambdaExpr(self: Compiler, node: LambdaExpr, compile: bool = true): Type {.discardable.}
# End of forward declarations
@ -328,7 +331,7 @@ proc getSource*(self: Compiler): string = self.source
## Utility functions
proc `$`*(self: Name): string = $self[]
proc `$`(self: Type): string = $self[]
#proc `$`(self: Type): string = $self[]
proc hash(self: Name): Hash = self.ident.token.lexeme.hash()
@ -612,6 +615,8 @@ proc fixNames(self: Compiler, where, oldLen: int) =
for name in self.names:
if name.codePos > where:
name.codePos += offset
if name.valueType.kind == Function:
name.valueType.location += offset
proc insertAt(self: Compiler, where: int, opcode: OpCode, data: openarray[uint8]): int =
@ -730,7 +735,7 @@ proc getStackPos(self: Compiler, name: Name): int =
# temporaries. There is no stack frame for builtins, so we skip
# these names too
elif variable.kind == Argument:
if variable.belongsTo.valueType.isBuiltin:
if variable.belongsTo.isBuiltin:
continue
elif not variable.belongsTo.resolved:
continue
@ -989,26 +994,28 @@ proc infer(self: Compiler, node: Expression): Type =
if node.isNil():
return nil
case node.kind:
of identExpr:
result = self.identifier(IdentExpr(node), compile=false).valueType
of unaryExpr:
result = self.unary(UnaryExpr(node), compile=false).valueType.returnType
of binaryExpr:
result = self.binary(BinaryExpr(node), compile=false).valueType.returnType
of {intExpr, hexExpr, binExpr, octExpr,
strExpr, falseExpr, trueExpr, floatExpr
of NodeKind.identExpr:
result = self.identifier(IdentExpr(node), compile=false)
of NodeKind.unaryExpr:
result = self.unary(UnaryExpr(node), compile=false)
of NodeKind.binaryExpr:
result = self.binary(BinaryExpr(node), compile=false)
of {NodeKind.intExpr, NodeKind.hexExpr, NodeKind.binExpr, NodeKind.octExpr,
NodeKind.strExpr, NodeKind.falseExpr, NodeKind.trueExpr, NodeKind.floatExpr
}:
result = self.infer(LiteralExpr(node))
of NodeKind.callExpr:
result = self.call(CallExpr(node), compile=false).returnType
of refExpr:
result = self.call(CallExpr(node), compile=false)
of NodeKind.refExpr:
result = Type(kind: Reference, value: self.infer(Ref(node).value))
of ptrExpr:
of NodeKind.ptrExpr:
result = Type(kind: Pointer, value: self.infer(Ptr(node).value))
of groupingExpr:
of NodeKind.groupingExpr:
result = self.infer(GroupingExpr(node).expression)
of NodeKind.getItemExpr:
result = self.getItemExpr(GetItemExpr(node), compile=false)
of NodeKind.lambdaExpr:
result = self.lambdaExpr(LambdaExpr(node), compile=false)
else:
discard # TODO
@ -1260,8 +1267,6 @@ proc handleBuiltinFunction(self: Compiler, fn: Type, args: seq[Expression], line
}.to_table()
if fn.builtinOp == "print":
var typ = self.expression(args[0], compile=false)
if typ.kind == Function:
typ = typ.returnType
case typ.kind:
of Int64:
self.emitByte(PrintInt64, line)
@ -1291,6 +1296,8 @@ proc handleBuiltinFunction(self: Compiler, fn: Type, args: seq[Expression], line
self.emitByte(PrintNan, line)
of Inf:
self.emitByte(PrintInf, line)
of Function:
self.emitByte(PrintHex, line)
else:
self.error("invalid type for built-in 'print'", args[0])
return
@ -1380,7 +1387,7 @@ proc endScope(self: Compiler) =
if name.kind notin [NameKind.Var, NameKind.Argument]:
continue
elif name.kind == NameKind.Argument:
if name.belongsTo.valueType.isBuiltin:
if name.belongsTo.isBuiltin:
# Arguments to builtin functions become temporaries on the
# stack and are popped automatically
continue
@ -1397,7 +1404,7 @@ proc endScope(self: Compiler) =
self.warning(UnusedName, &"'{name.ident.token.lexeme}' is declared but not used (add '_' prefix to silence warning)", name)
of NameKind.Argument:
if not name.ident.token.lexeme.startsWith("_") and name.isPrivate:
if not name.belongsTo.valueType.isBuiltin and name.belongsTo.isReal:
if not name.belongsTo.isBuiltin and name.belongsTo.isReal:
# Builtin functions never use their arguments. We also don't emit this
# warning if the function was generated internally by the compiler (for
# example as a result of generic specialization) because such objects do
@ -1595,7 +1602,7 @@ proc declare(self: Compiler, node: ASTNode): Name {.discardable.} =
for name in self.findByName(declaredName):
if name == n:
continue
# We don't check for name clashes with functions because match does that
# We don't check for name clashes with functions because self.match() does that
elif name.kind in [NameKind.Var, NameKind.Module, NameKind.CustomType, NameKind.Enum]:
if name.owner != self.currentModule:
if name.isPrivate:
@ -1640,7 +1647,7 @@ proc handleMagicPragma(self: Compiler, pragma: Pragma, name: Name) =
elif pragma.args[0].kind != strExpr:
self.error("'magic' pragma: wrong type of argument (constant string expected)")
elif name.node.kind == NodeKind.funDecl:
name.valueType.isBuiltin = true
name.isBuiltin = true
name.valueType.builtinOp = pragma.args[0].token.lexeme[1..^2]
elif name.node.kind == NodeKind.typeDecl:
name.valueType = pragma.args[0].token.lexeme[1..^2].toIntrinsic()
@ -1648,7 +1655,7 @@ proc handleMagicPragma(self: Compiler, pragma: Pragma, name: Name) =
self.error("'magic' pragma: wrong argument value", pragma.args[0])
if name.valueType.kind == All:
self.error("don't even think about it (compiler-chan is angry at you)", pragma)
name.valueType.isBuiltin = true
name.isBuiltin = true
else:
self.error("'magic' pragma is not valid in this context")
@ -1669,7 +1676,7 @@ proc handlePurePragma(self: Compiler, pragma: Pragma, name: Name) =
case name.node.kind:
of NodeKind.funDecl:
FunDecl(name.node).isPure = true
of lambdaExpr:
of NodeKind.lambdaExpr:
LambdaExpr(name.node).isPure = true
else:
self.error("'pure' pragma is not valid in this context")
@ -1683,7 +1690,7 @@ proc dispatchPragmas(self: Compiler, name: Name) =
case name.node.kind:
of NodeKind.funDecl, NodeKind.typeDecl, NodeKind.varDecl:
pragmas = Declaration(name.node).pragmas
of lambdaExpr:
of NodeKind.lambdaExpr:
pragmas = LambdaExpr(name.node).pragmas
else:
discard # Unreachable
@ -1795,19 +1802,26 @@ proc beginProgram(self: Compiler): int =
## End of utility functions
proc literal(self: Compiler, node: ASTNode) =
proc literal(self: Compiler, node: ASTNode, compile: bool = true): Type {.discardable.} =
## Emits instructions for literals such
## as singletons, strings and numbers
case node.kind:
of trueExpr:
self.emitByte(LoadTrue, node.token.line)
result = Type(kind: Bool)
if compile:
self.emitByte(LoadTrue, node.token.line)
of falseExpr:
self.emitByte(LoadFalse, node.token.line)
result = Type(kind: Bool)
if compile:
self.emitByte(LoadFalse, node.token.line)
of strExpr:
self.emitConstant(LiteralExpr(node), Type(kind: String))
result = Type(kind: String)
if compile:
self.emitConstant(LiteralExpr(node), Type(kind: String))
of intExpr:
let y = IntExpr(node)
let kind = self.infer(y)
result = kind
if kind.kind in [Int64, Int32, Int16, Int8]:
var x: int
try:
@ -1819,11 +1833,13 @@ proc literal(self: Compiler, node: ASTNode) =
try:
discard parseBiggestUInt(y.literal.lexeme, x)
except ValueError:
self.error("integer value out of range")
self.emitConstant(y, kind)
self.error("integer value out of range")
if compile:
self.emitConstant(y, kind)
of hexExpr:
var x: int
var y = HexExpr(node)
result = self.infer(y)
try:
discard parseHex(y.literal.lexeme, x)
except ValueError:
@ -1834,10 +1850,12 @@ proc literal(self: Compiler, node: ASTNode) =
relPos: (start: y.token.relPos.start, stop: y.token.relPos.start + len($x))
)
)
self.emitConstant(node, self.infer(y))
if compile:
self.emitConstant(node, result)
of binExpr:
var x: int
var y = BinExpr(node)
result = self.infer(y)
try:
discard parseBin(y.literal.lexeme, x)
except ValueError:
@ -1848,10 +1866,12 @@ proc literal(self: Compiler, node: ASTNode) =
relPos: (start: y.token.relPos.start, stop: y.token.relPos.start + len($x))
)
)
self.emitConstant(node, self.infer(y))
if compile:
self.emitConstant(node, result)
of octExpr:
var x: int
var y = OctExpr(node)
result = self.infer(y)
try:
discard parseOct(y.literal.lexeme, x)
except ValueError:
@ -1862,51 +1882,54 @@ proc literal(self: Compiler, node: ASTNode) =
relPos: (start: y.token.relPos.start, stop: y.token.relPos.start + len($x))
)
)
self.emitConstant(node, self.infer(y))
if compile:
self.emitConstant(node, result)
of floatExpr:
var x: float
var y = FloatExpr(node)
result = self.infer(y)
try:
discard parseFloat(y.literal.lexeme, x)
except ValueError:
self.error("floating point value out of range")
self.emitConstant(y, self.infer(y))
if compile:
self.emitConstant(y, result)
of awaitExpr:
var y = AwaitExpr(node)
self.expression(y.expression)
self.emitByte(OpCode.Await, node.token.line)
discard # TODO
else:
self.error(&"invalid AST node of kind {node.kind} at literal(): {node} (This is an internal error and most likely a bug!)")
proc unary(self: Compiler, node: UnaryExpr, compile: bool = true): Name {.discardable.} =
proc unary(self: Compiler, node: UnaryExpr, compile: bool = true): Type {.discardable.} =
## Compiles all unary expressions
var default: Expression
let fn = Type(kind: Function,
returnType: Type(kind: Any),
args: @[("", self.inferOrError(node.a), default)])
result = self.match(node.token.lexeme, fn, node)
let impl = self.match(node.token.lexeme, fn, node)
result = impl.valueType.returnType
if compile:
self.generateCall(result, @[node.a], result.line)
self.generateCall(impl, @[node.a], impl.line)
proc binary(self: Compiler, node: BinaryExpr, compile: bool = true): Name {.discardable.} =
proc binary(self: Compiler, node: BinaryExpr, compile: bool = true): Type {.discardable.} =
## Compiles all binary expressions
var default: Expression
let fn = Type(kind: Function, returnType: Type(kind: Any), args: @[("", self.inferOrError(node.a), default), ("", self.inferOrError(node.b), default)])
result = self.match(node.token.lexeme, fn, node)
let impl = self.match(node.token.lexeme, fn, node)
result = impl.valueType.returnType
if compile:
self.generateCall(result, @[node.a, node.b], result.line)
self.generateCall(impl, @[node.a, node.b], impl.line)
proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool = true): Name {.discardable.} =
proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool = true): Type {.discardable.} =
## Compiles access to identifiers
var s = name
if s.isNil():
s = self.resolveOrError(node)
var t = self.findByType(s.ident.token.lexeme, Type(kind: All))
s = t[0] # Shadowing!
result = s
result = s.valueType
if not compile:
return
if s.isConst:
@ -1919,8 +1942,8 @@ proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool
# resolve them to their address into our bytecode when
# they're referenced
self.emitByte(LoadUInt64, node.token.line)
self.emitBytes(self.chunk.writeConstant(s.codePos.toLong()), node.token.line)
elif s.valueType.isBuiltin:
self.emitBytes(self.chunk.writeConstant(s.valueType.location.toLong()), node.token.line)
elif s.isBuiltin:
case s.ident.token.lexeme:
of "nil":
self.emitByte(LoadNil, node.token.line)
@ -1990,7 +2013,15 @@ proc assignment(self: Compiler, node: ASTNode, compile: bool = true): Name {.dis
self.emitByte(StoreClosure, node.token.line)
self.emitBytes(self.getClosurePos(r).toTriple(), node.token.line)
of setItemExpr:
discard # TODO
let node = SetItemExpr(node)
let name = IdentExpr(node.name)
var r = self.resolveOrError(name)
if r.isConst:
self.error(&"cannot assign to '{name.token.lexeme}' (value is a constant)", name)
elif r.isLet:
self.error(&"cannot reassign '{name.token.lexeme}' (value is immutable)", name)
if r.valueType.kind != CustomType:
self.error("only types have fields", node)
else:
self.error(&"invalid AST node of kind {node.kind} at assignment(): {node} (This is an internal error and most likely a bug)")
@ -2038,11 +2069,11 @@ proc whileStmt(self: Compiler, node: WhileStmt) =
self.patchJump(jump)
# TODO: This will be needed for lambdas
proc generateCall(self: Compiler, fn: Type, args: seq[Expression], line: int) {.used.} =
## Version of generateCall that takes Type objects
## instead of Name objects. The function is assumed
## to be on the stack
## instead of Name objects (used for lambdas and
## consequent calls). The function's address is
## assumed to be on the stack
self.emitByte(LoadUInt64, line)
self.emitBytes(self.chunk.writeConstant(0.toLong()), line)
let pos = self.chunk.consts.len() - 8
@ -2124,7 +2155,7 @@ proc generateCall(self: Compiler, fn: Name, args: seq[Expression], line: int) =
## Small wrapper that abstracts emitting a call instruction
## for a given function
self.dispatchDelayedPragmas(fn)
if fn.valueType.isBuiltin:
if fn.isBuiltin:
self.handleBuiltinFunction(fn.valueType, args, line)
return
case fn.kind:
@ -2212,6 +2243,7 @@ proc call(self: Compiler, node: CallExpr, compile: bool = true): Type {.discarda
result = impl.valueType
if impl.isGeneric:
result = self.specialize(result, argExpr)
result = result.returnType
if compile:
# Now we call it
self.generateCall(impl, argExpr, node.token.line)
@ -2226,18 +2258,21 @@ proc call(self: Compiler, node: CallExpr, compile: bool = true): Type {.discarda
node = CallExpr(node).callee
# Now that we know how many call expressions we
# need to compile, we start from the outermost
# one (which is at the end because we went all
# the way back to the first one earlier) and work
# our way to the innermost call
for exp in reversed(all):
self.call(exp, compile)
# one and work our way to the innermost call
for exp in all:
result = self.call(exp, compile)
#echo result
#result = result.returnType
if compile and result.kind == Function:
self.generateCall(result, argExpr, node.token.line)
result = result.returnType
# TODO: Calling lambdas on-the-fly (i.e. on the same line)
else:
let typ = self.infer(node)
if typ.isNil():
self.error(&"expression has no type")
self.error(&"expression has no type", node)
else:
self.error(&"object of type '{self.stringify(typ)}' is not callable")
self.error(&"object of type '{self.stringify(typ)}' is not callable", node)
proc getItemExpr(self: Compiler, node: GetItemExpr, compile: bool = true): Type {.discardable.} =
@ -2263,6 +2298,59 @@ proc getItemExpr(self: Compiler, node: GetItemExpr, compile: bool = true): Type
self.error("invalid syntax", node)
proc lambdaExpr(self: Compiler, node: LambdaExpr, compile: bool = true): Type {.discardable.} =
## Compiles lambda functions as expressions
result = Type(kind: Function, isLambda: true, fun: node)
self.beginScope()
var constraints: seq[tuple[match: bool, kind: Type]] = @[]
for gen in node.generics:
self.unpackGenerics(gen.cond, constraints)
self.names.add(Name(depth: self.depth,
isPrivate: true,
valueType: Type(kind: Generic, name: gen.name.token.lexeme, cond: constraints),
codePos: 0,
isLet: false,
line: node.token.line,
belongsTo: nil, # TODO
ident: gen.name,
owner: self.currentModule,
file: self.file))
constraints = @[]
var default: Expression
var i = 0
for argument in node.arguments:
if self.names.high() > 16777215:
self.error("cannot declare more than 16777215 variables at a time")
self.names.add(Name(depth: self.depth,
isPrivate: true,
owner: self.currentModule,
file: self.currentModule.file,
isConst: false,
ident: argument.name,
valueType: self.inferOrError(argument.valueType),
codePos: 0,
isLet: false,
line: argument.name.token.line,
belongsTo: nil, # TODO
kind: NameKind.Argument,
node: argument.name
))
if node.arguments.high() - node.defaults.high() <= node.arguments.high():
# There's a default argument!
result.args.add((self.names[^1].ident.token.lexeme, self.names[^1].valueType, node.defaults[i]))
inc(i)
else:
# This argument has no default
result.args.add((self.names[^1].ident.token.lexeme, self.names[^1].valueType, default))
# The function needs a return type too!
if not node.returnType.isNil():
result.returnType = self.inferOrError(node.returnType)
if not compile:
return
# TODO
self.endScope()
proc expression(self: Compiler, node: Expression, compile: bool = true): Type {.discardable.} =
## Compiles all expressions
case node.kind:
@ -2276,27 +2364,29 @@ proc expression(self: Compiler, node: Expression, compile: bool = true): Type {.
# the node to its true type because that type information
# would be lost in the call anyway. The differentiation
# happens in self.assignment()
of setItemExpr, assignExpr:
of NodeKind.setItemExpr, NodeKind.assignExpr:
return self.assignment(node, compile).valueType
of identExpr:
return self.identifier(IdentExpr(node), compile=compile).valueType
of unaryExpr:
of NodeKind.identExpr:
return self.identifier(IdentExpr(node), compile=compile)
of NodeKind.unaryExpr:
# Unary expressions such as ~5 and -3
return self.unary(UnaryExpr(node), compile).valueType
of groupingExpr:
return self.unary(UnaryExpr(node), compile)
of NodeKind.groupingExpr:
# Grouping expressions like (2 + 1)
return self.expression(GroupingExpr(node).expression, compile)
of binaryExpr:
of NodeKind.binaryExpr:
# Binary expressions such as 2 ^ 5 and 0.66 * 3.14
return self.binary(BinaryExpr(node)).valueType
of intExpr, hexExpr, binExpr, octExpr, strExpr, falseExpr, trueExpr,
floatExpr:
return self.binary(BinaryExpr(node))
of NodeKind.intExpr, NodeKind.hexExpr, NodeKind.binExpr, NodeKind.octExpr,
NodeKind.strExpr, NodeKind.falseExpr, NodeKind.trueExpr, NodeKind.floatExpr:
# Since all of these AST nodes share the
# same overall structure and the kind
# field is enough to tell one from the
# other, why bother with specialized
# cases when one is enough?
self.literal(node)
return self.literal(node, compile)
of NodeKind.lambdaExpr:
return self.lambdaExpr(LambdaExpr(node), compile)
else:
self.error(&"invalid AST node of kind {node.kind} at expression(): {node} (This is an internal error and most likely a bug)")
@ -2555,13 +2645,14 @@ proc funDecl(self: Compiler, node: FunDecl, name: Name) =
function.valueType.children.add(name.valueType)
name.valueType.parent = function.valueType
self.currentFunction = name
if self.currentFunction.valueType.isBuiltin:
if self.currentFunction.isBuiltin:
self.currentFunction = function
return
# A function's code is just compiled linearly
# and then jumped over
jmp = self.emitJump(JumpForwards, node.token.line)
name.codePos = self.chunk.code.len()
name.valueType.location = name.codePos
# We let our debugger know this function's boundaries
self.chunk.functions.add(self.chunk.code.high().toTriple())
self.functions.add((start: self.chunk.code.high(), stop: 0, pos: self.chunk.functions.len() - 3, fn: name))

View File

@ -9,6 +9,6 @@ fn clock*: float {
}
fn print*[T: Number | string | bool | nan | inf](x: T) {
fn print*[T: any](x: T) {
#pragma[magic: "print"]
}

View File

@ -1,5 +1,5 @@
# Tests closures
# import std;
import std;
fn makeClosure(x: int): fn: int {
@ -10,14 +10,7 @@ fn makeClosure(x: int): fn: int {
}
fn makeClosureTwo(y: int): fn: int {
fn inner: int {
return y;
}
return inner;
}
makeClosureTwo(38)();
print(makeClosure(38)() == 38); # true;
var closure = makeClosure(42);
print(closure);
#closure(); # TODO: Fix