From 7ebd13f739698503a0234dfb8ac42fac8f10b9f1 Mon Sep 17 00:00:00 2001 From: Mattia Giambirtone Date: Mon, 5 Dec 2022 17:09:09 +0100 Subject: [PATCH] Fixed bugs in automatic types --- src/frontend/compiler.nim | 159 +++++++++++++++++--------------------- src/frontend/meta/ast.nim | 3 + tests/auto.pn | 2 +- 3 files changed, 76 insertions(+), 88 deletions(-) diff --git a/src/frontend/compiler.nim b/src/frontend/compiler.nim index 070ca77..938678f 100644 --- a/src/frontend/compiler.nim +++ b/src/frontend/compiler.nim @@ -63,6 +63,7 @@ type retJumps: seq[int] forwarded: bool location: int + compiled: bool of CustomType: fields: TableRef[string, Type] of Reference, Pointer: @@ -279,11 +280,11 @@ proc compile*(self: Compiler, ast: seq[Declaration], file: string, lines: seq[tu incremental: bool = false, isMainModule: bool = true, disabledWarnings: seq[WarningKind] = @[], showMismatches: bool = false, mode: CompileMode = Debug): Chunk proc expression(self: Compiler, node: Expression, compile: bool = true): Type {.discardable.} -proc statement(self: Compiler, node: Statement, compile: bool = true) -proc declaration(self: Compiler, node: Declaration, compile: bool = true) +proc statement(self: Compiler, node: Statement) +proc declaration(self: Compiler, node: Declaration) proc peek(self: Compiler, distance: int = 0): ASTNode proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool = true, strict: bool = true): Type {.discardable.} -proc varDecl(self: Compiler, node: VarDecl, compile: bool = true) +proc varDecl(self: Compiler, node: VarDecl) proc match(self: Compiler, name: string, kind: Type, node: ASTNode = nil, allowFwd: bool = true): Name proc specialize(self: Compiler, typ: Type, args: seq[Expression]): Type {.discardable.} proc call(self: Compiler, node: CallExpr, compile: bool = true): Type {.discardable.} @@ -303,7 +304,7 @@ proc handlePurePragma(self: Compiler, pragma: Pragma, name: Name) proc handleErrorPragma(self: Compiler, pragma: Pragma, name: Name) proc dispatchPragmas(self: Compiler, name: Name) proc dispatchDelayedPragmas(self: Compiler, name: Name) -proc funDecl(self: Compiler, node: FunDecl, name: Name, compile: bool = true) +proc funDecl(self: Compiler, node: FunDecl, name: Name) proc compileModule(self: Compiler, module: Name) proc generateCall(self: Compiler, fn: Name, args: seq[Expression], line: int) proc prepareFunction(self: Compiler, fn: Name) @@ -1066,7 +1067,7 @@ proc findByName(self: Compiler, name: string): seq[Name] = ## with the given name. Returns all objects that apply. for obj in reversed(self.names): if obj.ident.token.lexeme == name: - if obj.owner != self.currentModule: + if obj.owner.path != self.currentModule.path: if obj.isPrivate or self.currentModule notin obj.exportedTo: continue result.add(obj) @@ -1172,7 +1173,7 @@ proc match(self: Compiler, name: string, kind: Type, node: ASTNode = nil, allowF self.error(msg, node) elif impl.len() > 1: # Forward declarations don't count when looking for a function - impl = filterIt(impl, not it.valueType.forwarded) + impl = filterIt(impl, not it.valueType.forwarded and not it.valueType.isAuto) if impl.len() > 1: # If it's *still* more than one match, then it's an error var msg = &"multiple matching implementations of '{name}' found" @@ -1313,7 +1314,7 @@ proc handleBuiltinFunction(self: Compiler, fn: Type, args: seq[Expression], line self.patchJump(jump) of "LogicalAnd": self.expression(args[0]) - var jump = self.emitJump(JumpIfFalseOrPop, line) + let jump = self.emitJump(JumpIfFalseOrPop, line) self.expression(args[1]) self.patchJump(jump) else: @@ -1989,7 +1990,7 @@ proc assignment(self: Compiler, node: ASTNode, compile: bool = true): Type {.dis elif r.isLet: self.error(&"cannot reassign '{name.token.lexeme}' (value is immutable)", name) self.check(node.value, r.valueType) - self.expression(node.value) + self.expression(node.value, compile) var position = r.position if r.depth < self.depth: self.warning(WarningKind.MutateOuterScope, &"mutation of '{r.ident.token.lexeme}' declared in outer scope ({r.owner.file}.pn:{r.ident.token.line}:{r.ident.token.relPos.start})", nil, node) @@ -2024,43 +2025,35 @@ proc blockStmt(self: Compiler, node: BlockStmt, compile: bool = true) = self.warning(UnreachableCode, &"code after '{last.token.lexeme}' statement is unreachable", nil, last) else: discard - self.declaration(decl, compile) + self.declaration(decl) last = decl self.endScope() -proc ifStmt(self: Compiler, node: IfStmt, compile: bool = true) = +proc ifStmt(self: Compiler, node: IfStmt) = ## Compiles if/else statements for conditional ## execution of code self.check(node.condition, Type(kind: Bool)) self.expression(node.condition) - var jump: int - var jump2: int - if compile: - jump = self.emitJump(JumpIfFalsePop, node.token.line) - self.statement(node.thenBranch, compile) - if compile: - jump2 = self.emitJump(JumpForwards, node.token.line) - self.patchJump(jump) + let jump = self.emitJump(JumpIfFalsePop, node.token.line) + self.statement(node.thenBranch) + let jump2 = self.emitJump(JumpForwards, node.token.line) + self.patchJump(jump) if not node.elseBranch.isNil(): - self.statement(node.elseBranch, compile) - if compile: - self.patchJump(jump2) + self.statement(node.elseBranch) + self.patchJump(jump2) -proc whileStmt(self: Compiler, node: WhileStmt, compile: bool = true) = +proc whileStmt(self: Compiler, node: WhileStmt) = ## Compiles C-style while loops and ## desugared C-style for loops self.check(node.condition, Type(kind: Bool)) let start = self.chunk.code.high() self.expression(node.condition) - var jump: int - if compile: - jump = self.emitJump(JumpIfFalsePop, node.token.line) - self.statement(node.body, compile) - if compile: - self.emitLoop(start, node.token.line) - self.patchJump(jump) + let jump = self.emitJump(JumpIfFalsePop, node.token.line) + self.statement(node.body) + self.emitLoop(start, node.token.line) + self.patchJump(jump) proc generateCall(self: Compiler, fn: Type, args: seq[Expression], line: int) {.used.} = @@ -2161,6 +2154,8 @@ proc prepareAutoFunction(self: Compiler, fn: Name, args: seq[tuple[name: string, var default: Expression var node = FunDecl(fn.node) var fn = deepCopy(fn) + fn.valueType.isAuto = false + fn.valueType.compiled = false self.names.add(fn) for (argument, val) in zip(node.arguments, args): if self.names.high() > 16777215: @@ -2277,10 +2272,10 @@ proc call(self: Compiler, node: CallExpr, compile: bool = true): Type {.discarda if impl.valueType.isAuto: impl = self.prepareAutoFunction(impl, args) result = impl.valueType - self.funDecl(FunDecl(result.fun), impl, compile=compile) + if not impl.valueType.compiled: + self.funDecl(FunDecl(result.fun), impl) result = result.returnType if compile: - # Now we call it self.generateCall(impl, argExpr, node.token.line) of NodeKind.callExpr: # Calling a call expression, like hello()() @@ -2307,8 +2302,7 @@ proc call(self: Compiler, node: CallExpr, compile: bool = true): Type {.discarda if impl.isGeneric: result = self.specialize(result, argExpr) result = result.returnType - if compile: - self.generateCall(impl, argExpr, node.token.line) + self.generateCall(impl, argExpr, node.token.line) # TODO: Calling lambdas on-the-fly (i.e. on the same line) else: let typ = self.infer(node) @@ -2445,53 +2439,48 @@ proc expression(self: Compiler, node: Expression, compile: bool = true): Type {. # TODO -proc awaitStmt(self: Compiler, node: AwaitStmt, compile: bool = true) = +proc awaitStmt(self: Compiler, node: AwaitStmt) = ## Compiles await statements # TODO -proc deferStmt(self: Compiler, node: DeferStmt, compile: bool = true) = +proc deferStmt(self: Compiler, node: DeferStmt) = ## Compiles defer statements # TODO -proc yieldStmt(self: Compiler, node: YieldStmt, compile: bool = true) = +proc yieldStmt(self: Compiler, node: YieldStmt) = ## Compiles yield statements # TODO -proc raiseStmt(self: Compiler, node: RaiseStmt, compile: bool = true) = +proc raiseStmt(self: Compiler, node: RaiseStmt) = ## Compiles raise statements # TODO -proc assertStmt(self: Compiler, node: AssertStmt, compile: bool = true) = +proc assertStmt(self: Compiler, node: AssertStmt) = ## Compiles assert statements # TODO # TODO -proc forEachStmt(self: Compiler, node: ForEachStmt, compile: bool = true) = +proc forEachStmt(self: Compiler, node: ForEachStmt) = ## Compiles foreach loops -proc returnStmt(self: Compiler, node: ReturnStmt, compile: bool = true) = +proc returnStmt(self: Compiler, node: ReturnStmt) = ## Compiles return statements if self.currentFunction.valueType.returnType.isNil() and not node.value.isNil(): self.error("cannot return a value from a void function", node.value) elif not self.currentFunction.valueType.returnType.isNil() and node.value.isNil(): self.error("bare return statement is only allowed in void functions", node) if not node.value.isNil(): - if not self.currentFunction.valueType.isAuto: - self.check(node.value, self.currentFunction.valueType.returnType) - else: - if self.currentFunction.valueType.returnType.kind == Auto: - self.currentFunction.valueType.returnType = self.inferOrError(node.value) - else: - self.check(node.value, self.currentFunction.valueType.returnType) - self.expression(node.value, compile) - if compile: - self.emitByte(OpCode.SetResult, node.token.line) + if self.currentFunction.valueType.returnType.kind == Auto: + self.currentFunction.valueType.returnType = self.inferOrError(node.value) + self.check(node.value, self.currentFunction.valueType.returnType) + self.expression(node.value) + self.emitByte(OpCode.SetResult, node.token.line) # Since the "set result" part and "exit the function" part # of our return mechanism are already decoupled into two # separate opcodes, we perform the former and then jump to @@ -2503,8 +2492,7 @@ proc returnStmt(self: Compiler, node: ReturnStmt, compile: bool = true) = # the function has any local variables or not, this jump might be # patched to jump to the function's PopN/PopC instruction(s) rather # than straight to the return statement - if compile: - self.currentFunction.valueType.retJumps.add(self.emitJump(JumpForwards, node.token.line)) + self.currentFunction.valueType.retJumps.add(self.emitJump(JumpForwards, node.token.line)) proc continueStmt(self: Compiler, node: ContinueStmt, compile: bool = true) = @@ -2570,12 +2558,10 @@ proc exportStmt(self: Compiler, node: ExportStmt, compile: bool = true) = discard -proc breakStmt(self: Compiler, node: BreakStmt, compile: bool = true) = +proc breakStmt(self: Compiler, node: BreakStmt) = ## Compiles break statements. A break statement ## jumps to the end of the loop if node.label.isNil(): - if not compile: - return self.currentLoop.breakJumps.add(self.emitJump(OpCode.JumpForwards, node.token.line)) if self.currentLoop.depth > self.depth: # Breaking out of a loop closes its scope @@ -2594,7 +2580,7 @@ proc breakStmt(self: Compiler, node: BreakStmt, compile: bool = true) = self.error(&"unknown block name '{node.label.token.lexeme}'", node.label) -proc namedBlock(self: Compiler, node: NamedBlockStmt, compile: bool = true) = +proc namedBlock(self: Compiler, node: NamedBlockStmt) = ## Compiles named blocks self.beginScope() var blk = self.namedBlocks[^1] @@ -2606,80 +2592,78 @@ proc namedBlock(self: Compiler, node: NamedBlockStmt, compile: bool = true) = self.warning(UnreachableCode, &"code after '{last.token.lexeme}' statement is unreachable", nil, last) else: discard - if blk.broken and compile: + if blk.broken: blk.breakJumps.add(self.emitJump(OpCode.JumpForwards, node.token.line)) - self.declaration(decl, compile) + self.declaration(decl) last = decl - if compile: - self.patchBreaks() + self.patchBreaks() self.endScope() -proc statement(self: Compiler, node: Statement, compile: bool = true) = +proc statement(self: Compiler, node: Statement) = ## Compiles all statements case node.kind: of exprStmt: let expression = ExprStmt(node).expression let kind = self.infer(expression) - self.expression(expression, compile) + self.expression(expression) if kind.isNil(): # The expression has no type and produces no value, # so we don't have to pop anything discard - elif self.replMode and compile: + elif self.replMode: self.printRepl(kind, expression) - elif compile: + else: self.emitByte(Pop, node.token.line) of NodeKind.namedBlockStmt: self.namedBlocks.add(NamedBlock(start: self.chunk.code.len(), depth: self.depth, breakJumps: @[], name: NamedBlockStmt(node).name.token.lexeme)) - self.namedBlock(NamedBlockStmt(node), compile) + self.namedBlock(NamedBlockStmt(node),) discard self.namedBlocks.pop() of NodeKind.ifStmt: - self.ifStmt(IfStmt(node), compile) + self.ifStmt(IfStmt(node)) of NodeKind.assertStmt: - self.assertStmt(AssertStmt(node), compile) + self.assertStmt(AssertStmt(node)) of NodeKind.raiseStmt: - self.raiseStmt(RaiseStmt(node), compile) + self.raiseStmt(RaiseStmt(node)) of NodeKind.breakStmt: - self.breakStmt(BreakStmt(node), compile) + self.breakStmt(BreakStmt(node)) of NodeKind.continueStmt: - self.continueStmt(ContinueStmt(node), compile) + self.continueStmt(ContinueStmt(node)) of NodeKind.returnStmt: - self.returnStmt(ReturnStmt(node), compile) + self.returnStmt(ReturnStmt(node)) of NodeKind.importStmt: - self.importStmt(ImportStmt(node), compile) + self.importStmt(ImportStmt(node)) of NodeKind.exportStmt: - self.exportStmt(ExportStmt(node), compile) + self.exportStmt(ExportStmt(node)) of NodeKind.whileStmt: # Note: Our parser already desugars # for loops to while loops let loop = self.currentLoop self.currentLoop = Loop(start: self.chunk.code.len(), depth: self.depth, breakJumps: @[]) - self.whileStmt(WhileStmt(node), compile) - if compile: - self.patchBreaks() + self.whileStmt(WhileStmt(node)) + self.patchBreaks() self.currentLoop = loop of NodeKind.forEachStmt: - self.forEachStmt(ForEachStmt(node), compile) + self.forEachStmt(ForEachStmt(node)) of NodeKind.blockStmt: - self.blockStmt(BlockStmt(node), compile) + self.blockStmt(BlockStmt(node)) of NodeKind.yieldStmt: - self.yieldStmt(YieldStmt(node), compile) + self.yieldStmt(YieldStmt(node)) of NodeKind.awaitStmt: - self.awaitStmt(AwaitStmt(node), compile) + self.awaitStmt(AwaitStmt(node)) of NodeKind.deferStmt: - self.deferStmt(DeferStmt(node), compile) + self.deferStmt(DeferStmt(node)) of NodeKind.tryStmt: discard else: self.expression(Expression(node)) -proc varDecl(self: Compiler, node: VarDecl, compile: bool = true) = +proc varDecl(self: Compiler, node: VarDecl) = ## Compiles variable declarations # Our parser guarantees that the variable declaration @@ -2715,7 +2699,7 @@ proc varDecl(self: Compiler, node: VarDecl, compile: bool = true) = name.valueType = typ -proc funDecl(self: Compiler, node: FunDecl, name: Name, compile: bool = true) = +proc funDecl(self: Compiler, node: FunDecl, name: Name) = ## Compiles function declarations if node.token.kind == Operator and node.name.token.lexeme in [".", "="]: self.error(&"Due to compiler limitations, the '{node.name.token.lexeme}' operator cannot be currently overridden", node.name) @@ -2737,6 +2721,7 @@ proc funDecl(self: Compiler, node: FunDecl, name: Name, compile: bool = true) = self.stackIndex = name.position # A function's code is just compiled linearly # and then jumped over + name.valueType.compiled = true jmp = self.emitJump(JumpForwards, node.token.line) name.codePos = self.chunk.code.len() name.valueType.location = name.codePos @@ -2818,14 +2803,14 @@ proc funDecl(self: Compiler, node: FunDecl, name: Name, compile: bool = true) = self.stackIndex = stackIdx -proc declaration(self: Compiler, node: Declaration, compile: bool = true) = +proc declaration(self: Compiler, node: Declaration) = ## Compiles declarations, statements and expressions ## recursively case node.kind: of NodeKind.funDecl: var name = self.declare(node) if not name.valueType.isAuto: - self.funDecl(FunDecl(node), name, compile=compile) + self.funDecl(FunDecl(node), name) if name.isGeneric: # After we're done compiling a generic # function, we pull a magic trick: since, @@ -2851,9 +2836,9 @@ proc declaration(self: Compiler, node: Declaration, compile: bool = true) = # We compile this immediately because we # need to keep the stack in the right state # at runtime - self.varDecl(VarDecl(node), compile) + self.varDecl(VarDecl(node)) else: - self.statement(Statement(node), compile) + self.statement(Statement(node)) proc compile*(self: Compiler, ast: seq[Declaration], file: string, lines: seq[tuple[start, stop: int]], source: string, chunk: Chunk = nil, diff --git a/src/frontend/meta/ast.nim b/src/frontend/meta/ast.nim index 461f8da..89f6758 100644 --- a/src/frontend/meta/ast.nim +++ b/src/frontend/meta/ast.nim @@ -702,6 +702,9 @@ proc `$`*(self: ASTNode): string = of blockStmt: var self = BlockStmt(self) result &= &"""Block([{self.code.join(", ")}])""" + of namedBlockStmt: + var self = NamedBlockStmt(self) + result &= &"""Block(name={self.name}, [{self.code.join(", ")}])""" of whileStmt: var self = WhileStmt(self) result &= &"While(condition={self.condition}, body={self.body})" diff --git a/tests/auto.pn b/tests/auto.pn index 36e7e7b..8d1d5de 100644 --- a/tests/auto.pn +++ b/tests/auto.pn @@ -11,4 +11,4 @@ print(x == 1); print(sum(1, 2) == 3); print(sum(1'i32, 2'i32) == 3'i32); print(sum(1.0, 2.0) == 3.0); -#print(sum(1'i32, 2'i16)); # Will fail to compile +#print(sum(1'i32, 2'i16)); # Will fail to compile \ No newline at end of file