Fixed bugs in automatic types

This commit is contained in:
Mattia Giambirtone 2022-12-05 17:09:09 +01:00
parent 6caaf7e707
commit 7ebd13f739
3 changed files with 76 additions and 88 deletions

View File

@ -63,6 +63,7 @@ type
retJumps: seq[int]
forwarded: bool
location: int
compiled: bool
of CustomType:
fields: TableRef[string, Type]
of Reference, Pointer:
@ -279,11 +280,11 @@ proc compile*(self: Compiler, ast: seq[Declaration], file: string, lines: seq[tu
incremental: bool = false, isMainModule: bool = true, disabledWarnings: seq[WarningKind] = @[], showMismatches: bool = false,
mode: CompileMode = Debug): Chunk
proc expression(self: Compiler, node: Expression, compile: bool = true): Type {.discardable.}
proc statement(self: Compiler, node: Statement, compile: bool = true)
proc declaration(self: Compiler, node: Declaration, compile: bool = true)
proc statement(self: Compiler, node: Statement)
proc declaration(self: Compiler, node: Declaration)
proc peek(self: Compiler, distance: int = 0): ASTNode
proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool = true, strict: bool = true): Type {.discardable.}
proc varDecl(self: Compiler, node: VarDecl, compile: bool = true)
proc varDecl(self: Compiler, node: VarDecl)
proc match(self: Compiler, name: string, kind: Type, node: ASTNode = nil, allowFwd: bool = true): Name
proc specialize(self: Compiler, typ: Type, args: seq[Expression]): Type {.discardable.}
proc call(self: Compiler, node: CallExpr, compile: bool = true): Type {.discardable.}
@ -303,7 +304,7 @@ proc handlePurePragma(self: Compiler, pragma: Pragma, name: Name)
proc handleErrorPragma(self: Compiler, pragma: Pragma, name: Name)
proc dispatchPragmas(self: Compiler, name: Name)
proc dispatchDelayedPragmas(self: Compiler, name: Name)
proc funDecl(self: Compiler, node: FunDecl, name: Name, compile: bool = true)
proc funDecl(self: Compiler, node: FunDecl, name: Name)
proc compileModule(self: Compiler, module: Name)
proc generateCall(self: Compiler, fn: Name, args: seq[Expression], line: int)
proc prepareFunction(self: Compiler, fn: Name)
@ -1066,7 +1067,7 @@ proc findByName(self: Compiler, name: string): seq[Name] =
## with the given name. Returns all objects that apply.
for obj in reversed(self.names):
if obj.ident.token.lexeme == name:
if obj.owner != self.currentModule:
if obj.owner.path != self.currentModule.path:
if obj.isPrivate or self.currentModule notin obj.exportedTo:
continue
result.add(obj)
@ -1172,7 +1173,7 @@ proc match(self: Compiler, name: string, kind: Type, node: ASTNode = nil, allowF
self.error(msg, node)
elif impl.len() > 1:
# Forward declarations don't count when looking for a function
impl = filterIt(impl, not it.valueType.forwarded)
impl = filterIt(impl, not it.valueType.forwarded and not it.valueType.isAuto)
if impl.len() > 1:
# If it's *still* more than one match, then it's an error
var msg = &"multiple matching implementations of '{name}' found"
@ -1313,7 +1314,7 @@ proc handleBuiltinFunction(self: Compiler, fn: Type, args: seq[Expression], line
self.patchJump(jump)
of "LogicalAnd":
self.expression(args[0])
var jump = self.emitJump(JumpIfFalseOrPop, line)
let jump = self.emitJump(JumpIfFalseOrPop, line)
self.expression(args[1])
self.patchJump(jump)
else:
@ -1989,7 +1990,7 @@ proc assignment(self: Compiler, node: ASTNode, compile: bool = true): Type {.dis
elif r.isLet:
self.error(&"cannot reassign '{name.token.lexeme}' (value is immutable)", name)
self.check(node.value, r.valueType)
self.expression(node.value)
self.expression(node.value, compile)
var position = r.position
if r.depth < self.depth:
self.warning(WarningKind.MutateOuterScope, &"mutation of '{r.ident.token.lexeme}' declared in outer scope ({r.owner.file}.pn:{r.ident.token.line}:{r.ident.token.relPos.start})", nil, node)
@ -2024,43 +2025,35 @@ proc blockStmt(self: Compiler, node: BlockStmt, compile: bool = true) =
self.warning(UnreachableCode, &"code after '{last.token.lexeme}' statement is unreachable", nil, last)
else:
discard
self.declaration(decl, compile)
self.declaration(decl)
last = decl
self.endScope()
proc ifStmt(self: Compiler, node: IfStmt, compile: bool = true) =
proc ifStmt(self: Compiler, node: IfStmt) =
## Compiles if/else statements for conditional
## execution of code
self.check(node.condition, Type(kind: Bool))
self.expression(node.condition)
var jump: int
var jump2: int
if compile:
jump = self.emitJump(JumpIfFalsePop, node.token.line)
self.statement(node.thenBranch, compile)
if compile:
jump2 = self.emitJump(JumpForwards, node.token.line)
self.patchJump(jump)
let jump = self.emitJump(JumpIfFalsePop, node.token.line)
self.statement(node.thenBranch)
let jump2 = self.emitJump(JumpForwards, node.token.line)
self.patchJump(jump)
if not node.elseBranch.isNil():
self.statement(node.elseBranch, compile)
if compile:
self.patchJump(jump2)
self.statement(node.elseBranch)
self.patchJump(jump2)
proc whileStmt(self: Compiler, node: WhileStmt, compile: bool = true) =
proc whileStmt(self: Compiler, node: WhileStmt) =
## Compiles C-style while loops and
## desugared C-style for loops
self.check(node.condition, Type(kind: Bool))
let start = self.chunk.code.high()
self.expression(node.condition)
var jump: int
if compile:
jump = self.emitJump(JumpIfFalsePop, node.token.line)
self.statement(node.body, compile)
if compile:
self.emitLoop(start, node.token.line)
self.patchJump(jump)
let jump = self.emitJump(JumpIfFalsePop, node.token.line)
self.statement(node.body)
self.emitLoop(start, node.token.line)
self.patchJump(jump)
proc generateCall(self: Compiler, fn: Type, args: seq[Expression], line: int) {.used.} =
@ -2161,6 +2154,8 @@ proc prepareAutoFunction(self: Compiler, fn: Name, args: seq[tuple[name: string,
var default: Expression
var node = FunDecl(fn.node)
var fn = deepCopy(fn)
fn.valueType.isAuto = false
fn.valueType.compiled = false
self.names.add(fn)
for (argument, val) in zip(node.arguments, args):
if self.names.high() > 16777215:
@ -2277,10 +2272,10 @@ proc call(self: Compiler, node: CallExpr, compile: bool = true): Type {.discarda
if impl.valueType.isAuto:
impl = self.prepareAutoFunction(impl, args)
result = impl.valueType
self.funDecl(FunDecl(result.fun), impl, compile=compile)
if not impl.valueType.compiled:
self.funDecl(FunDecl(result.fun), impl)
result = result.returnType
if compile:
# Now we call it
self.generateCall(impl, argExpr, node.token.line)
of NodeKind.callExpr:
# Calling a call expression, like hello()()
@ -2307,8 +2302,7 @@ proc call(self: Compiler, node: CallExpr, compile: bool = true): Type {.discarda
if impl.isGeneric:
result = self.specialize(result, argExpr)
result = result.returnType
if compile:
self.generateCall(impl, argExpr, node.token.line)
self.generateCall(impl, argExpr, node.token.line)
# TODO: Calling lambdas on-the-fly (i.e. on the same line)
else:
let typ = self.infer(node)
@ -2445,53 +2439,48 @@ proc expression(self: Compiler, node: Expression, compile: bool = true): Type {.
# TODO
proc awaitStmt(self: Compiler, node: AwaitStmt, compile: bool = true) =
proc awaitStmt(self: Compiler, node: AwaitStmt) =
## Compiles await statements
# TODO
proc deferStmt(self: Compiler, node: DeferStmt, compile: bool = true) =
proc deferStmt(self: Compiler, node: DeferStmt) =
## Compiles defer statements
# TODO
proc yieldStmt(self: Compiler, node: YieldStmt, compile: bool = true) =
proc yieldStmt(self: Compiler, node: YieldStmt) =
## Compiles yield statements
# TODO
proc raiseStmt(self: Compiler, node: RaiseStmt, compile: bool = true) =
proc raiseStmt(self: Compiler, node: RaiseStmt) =
## Compiles raise statements
# TODO
proc assertStmt(self: Compiler, node: AssertStmt, compile: bool = true) =
proc assertStmt(self: Compiler, node: AssertStmt) =
## Compiles assert statements
# TODO
# TODO
proc forEachStmt(self: Compiler, node: ForEachStmt, compile: bool = true) =
proc forEachStmt(self: Compiler, node: ForEachStmt) =
## Compiles foreach loops
proc returnStmt(self: Compiler, node: ReturnStmt, compile: bool = true) =
proc returnStmt(self: Compiler, node: ReturnStmt) =
## Compiles return statements
if self.currentFunction.valueType.returnType.isNil() and not node.value.isNil():
self.error("cannot return a value from a void function", node.value)
elif not self.currentFunction.valueType.returnType.isNil() and node.value.isNil():
self.error("bare return statement is only allowed in void functions", node)
if not node.value.isNil():
if not self.currentFunction.valueType.isAuto:
self.check(node.value, self.currentFunction.valueType.returnType)
else:
if self.currentFunction.valueType.returnType.kind == Auto:
self.currentFunction.valueType.returnType = self.inferOrError(node.value)
else:
self.check(node.value, self.currentFunction.valueType.returnType)
self.expression(node.value, compile)
if compile:
self.emitByte(OpCode.SetResult, node.token.line)
if self.currentFunction.valueType.returnType.kind == Auto:
self.currentFunction.valueType.returnType = self.inferOrError(node.value)
self.check(node.value, self.currentFunction.valueType.returnType)
self.expression(node.value)
self.emitByte(OpCode.SetResult, node.token.line)
# Since the "set result" part and "exit the function" part
# of our return mechanism are already decoupled into two
# separate opcodes, we perform the former and then jump to
@ -2503,8 +2492,7 @@ proc returnStmt(self: Compiler, node: ReturnStmt, compile: bool = true) =
# the function has any local variables or not, this jump might be
# patched to jump to the function's PopN/PopC instruction(s) rather
# than straight to the return statement
if compile:
self.currentFunction.valueType.retJumps.add(self.emitJump(JumpForwards, node.token.line))
self.currentFunction.valueType.retJumps.add(self.emitJump(JumpForwards, node.token.line))
proc continueStmt(self: Compiler, node: ContinueStmt, compile: bool = true) =
@ -2570,12 +2558,10 @@ proc exportStmt(self: Compiler, node: ExportStmt, compile: bool = true) =
discard
proc breakStmt(self: Compiler, node: BreakStmt, compile: bool = true) =
proc breakStmt(self: Compiler, node: BreakStmt) =
## Compiles break statements. A break statement
## jumps to the end of the loop
if node.label.isNil():
if not compile:
return
self.currentLoop.breakJumps.add(self.emitJump(OpCode.JumpForwards, node.token.line))
if self.currentLoop.depth > self.depth:
# Breaking out of a loop closes its scope
@ -2594,7 +2580,7 @@ proc breakStmt(self: Compiler, node: BreakStmt, compile: bool = true) =
self.error(&"unknown block name '{node.label.token.lexeme}'", node.label)
proc namedBlock(self: Compiler, node: NamedBlockStmt, compile: bool = true) =
proc namedBlock(self: Compiler, node: NamedBlockStmt) =
## Compiles named blocks
self.beginScope()
var blk = self.namedBlocks[^1]
@ -2606,80 +2592,78 @@ proc namedBlock(self: Compiler, node: NamedBlockStmt, compile: bool = true) =
self.warning(UnreachableCode, &"code after '{last.token.lexeme}' statement is unreachable", nil, last)
else:
discard
if blk.broken and compile:
if blk.broken:
blk.breakJumps.add(self.emitJump(OpCode.JumpForwards, node.token.line))
self.declaration(decl, compile)
self.declaration(decl)
last = decl
if compile:
self.patchBreaks()
self.patchBreaks()
self.endScope()
proc statement(self: Compiler, node: Statement, compile: bool = true) =
proc statement(self: Compiler, node: Statement) =
## Compiles all statements
case node.kind:
of exprStmt:
let expression = ExprStmt(node).expression
let kind = self.infer(expression)
self.expression(expression, compile)
self.expression(expression)
if kind.isNil():
# The expression has no type and produces no value,
# so we don't have to pop anything
discard
elif self.replMode and compile:
elif self.replMode:
self.printRepl(kind, expression)
elif compile:
else:
self.emitByte(Pop, node.token.line)
of NodeKind.namedBlockStmt:
self.namedBlocks.add(NamedBlock(start: self.chunk.code.len(),
depth: self.depth,
breakJumps: @[],
name: NamedBlockStmt(node).name.token.lexeme))
self.namedBlock(NamedBlockStmt(node), compile)
self.namedBlock(NamedBlockStmt(node),)
discard self.namedBlocks.pop()
of NodeKind.ifStmt:
self.ifStmt(IfStmt(node), compile)
self.ifStmt(IfStmt(node))
of NodeKind.assertStmt:
self.assertStmt(AssertStmt(node), compile)
self.assertStmt(AssertStmt(node))
of NodeKind.raiseStmt:
self.raiseStmt(RaiseStmt(node), compile)
self.raiseStmt(RaiseStmt(node))
of NodeKind.breakStmt:
self.breakStmt(BreakStmt(node), compile)
self.breakStmt(BreakStmt(node))
of NodeKind.continueStmt:
self.continueStmt(ContinueStmt(node), compile)
self.continueStmt(ContinueStmt(node))
of NodeKind.returnStmt:
self.returnStmt(ReturnStmt(node), compile)
self.returnStmt(ReturnStmt(node))
of NodeKind.importStmt:
self.importStmt(ImportStmt(node), compile)
self.importStmt(ImportStmt(node))
of NodeKind.exportStmt:
self.exportStmt(ExportStmt(node), compile)
self.exportStmt(ExportStmt(node))
of NodeKind.whileStmt:
# Note: Our parser already desugars
# for loops to while loops
let loop = self.currentLoop
self.currentLoop = Loop(start: self.chunk.code.len(),
depth: self.depth, breakJumps: @[])
self.whileStmt(WhileStmt(node), compile)
if compile:
self.patchBreaks()
self.whileStmt(WhileStmt(node))
self.patchBreaks()
self.currentLoop = loop
of NodeKind.forEachStmt:
self.forEachStmt(ForEachStmt(node), compile)
self.forEachStmt(ForEachStmt(node))
of NodeKind.blockStmt:
self.blockStmt(BlockStmt(node), compile)
self.blockStmt(BlockStmt(node))
of NodeKind.yieldStmt:
self.yieldStmt(YieldStmt(node), compile)
self.yieldStmt(YieldStmt(node))
of NodeKind.awaitStmt:
self.awaitStmt(AwaitStmt(node), compile)
self.awaitStmt(AwaitStmt(node))
of NodeKind.deferStmt:
self.deferStmt(DeferStmt(node), compile)
self.deferStmt(DeferStmt(node))
of NodeKind.tryStmt:
discard
else:
self.expression(Expression(node))
proc varDecl(self: Compiler, node: VarDecl, compile: bool = true) =
proc varDecl(self: Compiler, node: VarDecl) =
## Compiles variable declarations
# Our parser guarantees that the variable declaration
@ -2715,7 +2699,7 @@ proc varDecl(self: Compiler, node: VarDecl, compile: bool = true) =
name.valueType = typ
proc funDecl(self: Compiler, node: FunDecl, name: Name, compile: bool = true) =
proc funDecl(self: Compiler, node: FunDecl, name: Name) =
## Compiles function declarations
if node.token.kind == Operator and node.name.token.lexeme in [".", "="]:
self.error(&"Due to compiler limitations, the '{node.name.token.lexeme}' operator cannot be currently overridden", node.name)
@ -2737,6 +2721,7 @@ proc funDecl(self: Compiler, node: FunDecl, name: Name, compile: bool = true) =
self.stackIndex = name.position
# A function's code is just compiled linearly
# and then jumped over
name.valueType.compiled = true
jmp = self.emitJump(JumpForwards, node.token.line)
name.codePos = self.chunk.code.len()
name.valueType.location = name.codePos
@ -2818,14 +2803,14 @@ proc funDecl(self: Compiler, node: FunDecl, name: Name, compile: bool = true) =
self.stackIndex = stackIdx
proc declaration(self: Compiler, node: Declaration, compile: bool = true) =
proc declaration(self: Compiler, node: Declaration) =
## Compiles declarations, statements and expressions
## recursively
case node.kind:
of NodeKind.funDecl:
var name = self.declare(node)
if not name.valueType.isAuto:
self.funDecl(FunDecl(node), name, compile=compile)
self.funDecl(FunDecl(node), name)
if name.isGeneric:
# After we're done compiling a generic
# function, we pull a magic trick: since,
@ -2851,9 +2836,9 @@ proc declaration(self: Compiler, node: Declaration, compile: bool = true) =
# We compile this immediately because we
# need to keep the stack in the right state
# at runtime
self.varDecl(VarDecl(node), compile)
self.varDecl(VarDecl(node))
else:
self.statement(Statement(node), compile)
self.statement(Statement(node))
proc compile*(self: Compiler, ast: seq[Declaration], file: string, lines: seq[tuple[start, stop: int]], source: string, chunk: Chunk = nil,

View File

@ -702,6 +702,9 @@ proc `$`*(self: ASTNode): string =
of blockStmt:
var self = BlockStmt(self)
result &= &"""Block([{self.code.join(", ")}])"""
of namedBlockStmt:
var self = NamedBlockStmt(self)
result &= &"""Block(name={self.name}, [{self.code.join(", ")}])"""
of whileStmt:
var self = WhileStmt(self)
result &= &"While(condition={self.condition}, body={self.body})"

View File

@ -11,4 +11,4 @@ print(x == 1);
print(sum(1, 2) == 3);
print(sum(1'i32, 2'i32) == 3'i32);
print(sum(1.0, 2.0) == 3.0);
#print(sum(1'i32, 2'i16)); # Will fail to compile
#print(sum(1'i32, 2'i16)); # Will fail to compile