From e046981f4b456b130923ea3b6d01cca6f3034a22 Mon Sep 17 00:00:00 2001 From: Mattia Giambirtone Date: Wed, 2 Nov 2022 12:03:14 +0100 Subject: [PATCH] Fix for closures --- src/frontend/compiler.nim | 470 +++++++++++++++++++------------------- tests/closures.pn | 22 +- 2 files changed, 253 insertions(+), 239 deletions(-) diff --git a/src/frontend/compiler.nim b/src/frontend/compiler.nim index e53677e..8aa58d4 100644 --- a/src/frontend/compiler.nim +++ b/src/frontend/compiler.nim @@ -192,7 +192,7 @@ type # be empty) deferred: seq[uint8] # List of closed-over variables - closedOver: seq[Name] + closures: seq[Name] # Compiler procedures called by pragmas compilerProcs: TableRef[string, proc (self: Compiler, pragma: Pragma, node: ASTNode, name: Name)] # Stores line data for error reporting @@ -229,8 +229,9 @@ proc identifier(self: Compiler, node: IdentExpr) proc varDecl(self: Compiler, node: VarDecl, name: Name) proc specialize(self: Compiler, name: Name, args: seq[Expression]): Name proc matchImpl(self: Compiler, name: string, kind: Type, node: ASTNode = nil): Name -proc infer(self: Compiler, node: LiteralExpr): Type -proc infer(self: Compiler, node: Expression): Type +proc infer(self: Compiler, node: LiteralExpr, allowGeneric: bool = false): Type +proc infer(self: Compiler, node: Expression, allowGeneric: bool = false): Type +proc inferOrError[T: LiteralExpr | Expression](self: Compiler, node: T, allowGeneric: bool = false): Type proc findByName(self: Compiler, name: string): seq[Name] proc findByModule(self: Compiler, name: string): seq[Name] proc findByType(self: Compiler, name: string, kind: Type, depth: int = -1): seq[Name] @@ -268,6 +269,7 @@ proc newCompiler*(replMode: bool = false): Compiler = result.lexer.fillSymbolTable() result.parser = newParser() result.isMainModule = false + result.closures = @[] ## Public getters for nicer error formatting @@ -485,73 +487,77 @@ proc fixJumps(self: Compiler, oldLen: int, modifiedAt: int) = self.setJump(jump.offset, self.chunk.code[jump.offset + 1..jump.offset + 3]) -proc resolve(self: Compiler, name: IdentExpr, - depth: int = self.scopeDepth): Name = - ## Traverses self.names backwards and returns the - ## first name object with the given name. Returns - ## nil when the name can't be found. This function - ## has no concept of scope depth, because getStackPos - ## does that job. Note that private names declared in - ## other modules will not be resolved! Another important - ## thing to remember is that if a name is being resolved - ## for the first time, calling this function will also - ## cause its declaration to be compiled - for obj in reversed(self.names): - if obj.ident.token.lexeme == name.token.lexeme: - if obj.owner != self.currentModule: - if obj.isPrivate: - continue - elif obj.exported: - return obj - if not obj.resolved: - obj.resolved = true - obj.codePos = self.chunk.code.len() - case obj.kind: - of NameKind.Var: - self.varDecl(VarDecl(obj.node), obj) - of NameKind.CustomType: - self.typeDecl(TypeDecl(obj.node), obj) - of NameKind.Function: - if not obj.valueType.isGeneric: - self.funDecl(FunDecl(obj.node), obj) - # Generic functions need to be compiled at - # the call site, but regular functions can - # be precompiled as soon as we resolve them - else: - discard - return obj - return nil - - -proc resolve(self: Compiler, name: string, - depth: int = self.scopeDepth): Name = - ## Version of resolve that takes strings instead - ## of AST nodes +proc resolve(self: Compiler, name: string): Name = + ## Traverses all existing namespaces and returns + ## the first object with the given name. Returns + ## nil when the name can't be found. Note that, + ## when a declaration is first resolved, it is + ## also compiled on-the-fly for obj in reversed(self.names): if obj.ident.token.lexeme == name: if obj.owner != self.currentModule: + # We don't own this name, but we + # may still have access to it if obj.isPrivate: + # Name is private in its owner + # module, so we definitely can't + # use it continue elif obj.exported: - return obj - if not obj.resolved: - obj.resolved = true - obj.codePos = self.chunk.code.len() - case obj.kind: - of NameKind.Var: - self.varDecl(VarDecl(obj.node), obj) - of NameKind.CustomType: - self.typeDecl(TypeDecl(obj.node), obj) - of NameKind.Function: - if not obj.valueType.isGeneric: - self.funDecl(FunDecl(obj.node), obj) - # Generic functions need to be compiled at - # the call site, but regular functions can - # be precompiled as soon as we resolve them - else: - discard - return obj - return nil + # The name is public in its owner + # module and said module has explicitly + # exported it to us: we can use it + result = obj + break + # If the name is public but not exported in + # its owner module, then we act as if it's + # private. This is to avoid namespace pollution + # from imports (i.e. if module A imports modules + # C and D and module B imports module A, then B + # might not want to also have access to C's and D's + # names as they might clash with its own stuff) + result = obj + break + if not result.isNil() and not result.resolved: + # There's no reason to compile a declaration + # unless it is used at least once: this way + # not only do we save space if a name is declared + # but never used, but it also makes it easier to + # implement generics. Yay! + result.resolved = true + # Now we just dispatch to one of our functions to + # compile the declaration + case result.kind: + of NameKind.Var: + self.varDecl(VarDecl(result.node), result) + of NameKind.CustomType: + self.typeDecl(TypeDecl(result.node), result) + of NameKind.Function: + # Generic functions need to be compiled at + # the call site because we need to know the + # type of the arguments, but regular functions + # can be precompiled as soon as we resolve them + if not result.valueType.isGeneric: + self.funDecl(FunDecl(result.node), result) + else: + discard + + +proc resolve(self: Compiler, name: IdentExpr): Name = + ## Version of resolve that takes Identifier + ## AST nodes instead of strings + return self.resolve(name.token.lexeme) + + +proc resolveOrError[T: IdentExpr | string](self: Compiler, name: T): Name = + ## Calls self.resolve() and errors out with an appropriate + ## message if it returns nil + result = self.resolve(name) + if result.isNil(): + when T is IdentExpr: + self.error(&"reference to undefined name '{name.token.lexeme}'", name) + when T is string: + self.error(&"reference to undefined name '{name}'") proc getStackPos(self: Compiler, name: Name): int = @@ -560,7 +566,7 @@ proc getStackPos(self: Compiler, name: Name): int = var found = false result = 2 for variable in self.names: - if variable.kind in [NameKind.Module, NameKind.CustomType, NameKind.Enum, NameKind.Function]: + if variable.kind in [NameKind.Module, NameKind.CustomType, NameKind.Enum, NameKind.Function, NameKind.None]: continue elif variable.kind == NameKind.Argument and variable.depth > self.scopeDepth: continue @@ -583,16 +589,12 @@ proc getClosurePos(self: Compiler, name: Name): int = ## environment if not self.currentFunction.valueType.isClosure: return -1 - result = self.currentFunction.valueType.envLen - 1 - var i = result - while i <= result and i >= 0: - if name == self.closedOver[i]: + for i, e in self.closures: + if e == name: return i - dec(result) return -1 - proc compare(self: Compiler, a, b: Type): bool = ## Compares two type objects ## for equality (works with nil!) @@ -728,7 +730,7 @@ proc toIntrinsic(name: string): Type = return nil -proc infer(self: Compiler, node: LiteralExpr): Type = +proc infer(self: Compiler, node: LiteralExpr, allowGeneric: bool = false): Type = ## Infers the type of a given literal expression ## (if the expression is nil, nil is returned) if node.isNil(): @@ -772,10 +774,11 @@ proc infer(self: Compiler, node: LiteralExpr): Type = discard # TODO -proc infer(self: Compiler, node: Expression): Type = +proc infer(self: Compiler, node: Expression, allowGeneric: bool = false): Type = ## Infers the type of a given expression and ## returns it (if the node is nil, nil is ## returned). Always returns a concrete type + ## unless allowGeneric is set to true if node.isNil(): return nil case node.kind: @@ -784,7 +787,7 @@ proc infer(self: Compiler, node: Expression): Type = var name = self.resolve(node) if not name.isNil(): result = name.valueType - if not result.isNil() and result.kind == Generic: + if not result.isNil() and result.kind == Generic and not allowGeneric: if name.belongsTo.isNil(): name = self.resolve(result.name) if not name.isNil(): @@ -799,13 +802,13 @@ proc infer(self: Compiler, node: Expression): Type = let node = UnaryExpr(node) let impl = self.matchImpl(node.operator.lexeme, Type(kind: Function, returnType: Type(kind: Any), args: @[("", self.infer(node.a))]), node) result = impl.valueType.returnType - if result.kind == Generic: + if result.kind == Generic and not allowGeneric: result = self.specialize(impl, @[node.a]).valueType.returnType of binaryExpr: let node = BinaryExpr(node) let impl = self.matchImpl(node.operator.lexeme, Type(kind: Function, returnType: Type(kind: Any), args: @[("", self.infer(node.a)), ("", self.infer(node.b))]), node) result = impl.valueType.returnType - if result.kind == Generic: + if result.kind == Generic and not allowGeneric: result = self.specialize(impl, @[node.a, node.b]).valueType.returnType of {intExpr, hexExpr, binExpr, octExpr, strExpr, falseExpr, trueExpr, infExpr, @@ -851,7 +854,26 @@ proc infer(self: Compiler, node: Expression): Type = result = self.infer(GroupingExpr(node).expression) else: discard # Unreachable - + + +proc inferOrError[T: LiteralExpr | Expression](self: Compiler, node: T, allowGeneric: bool = false): Type = + ## Attempts to infer the type of + ## the given expression and raises an + ## error with an appropriate message if + ## it fails + result = self.infer(node, allowGeneric) + if result.isNil(): + case node.kind: + of identExpr: + self.error(&"reference to undefined name '{IdentExpr(node).token.lexeme}'", node) + of callExpr: + let node = CallExpr(node) + if node.callee.kind == identExpr: + self.error(&"call to undefined function '{IdentExpr(node.callee).token.lexeme}'", node) + else: + self.error("expression has no type", node) + else: + self.error("expression has no type", node) proc typeToStr(self: Compiler, typ: Type): string = @@ -966,18 +988,13 @@ proc matchImpl(self: Compiler, name: string, kind: Type, node: ASTNode = nil): N result = impl[0] + proc check(self: Compiler, term: Expression, kind: Type, allowAny: bool = false) = ## Checks the type of term against a known type. ## Raises an error if appropriate and returns ## otherwise - let k = self.infer(term) - if k.isNil(): - if term.kind == identExpr and self.resolve(IdentExpr(term)).isNil(): - self.error(&"reference to undeclared name '{term.token.lexeme}'", term) - elif term.kind == callExpr and CallExpr(term).callee.kind == identExpr: - self.error(&"call to undeclared function '{CallExpr(term).callee.token.lexeme}'", term) - self.error(&"expecting value of type '{self.typeToStr(kind)}', but expression has no type", term) - elif k.kind == Any and not allowAny: + let k = self.inferOrError(term) + if k.kind == Any and not allowAny: # Any should only be used internally: error! self.error("'all' is not a valid type in this context", term) elif not self.compare(k, kind): @@ -1095,21 +1112,28 @@ proc endScope(self: Compiler) = dec(self.scopeDepth) var names: seq[Name] = @[] var popCount = 0 + if self.scopeDepth == -1 and not self.isMainModule: + # When we're compiling another module, we don't + # close its global scope because self.compileModule() + # needs access to it + return for name in self.names: - if self.scopeDepth == -1 and not self.isMainModule: - continue if name.depth > self.scopeDepth: + if not name.belongsTo.isNil() and not name.belongsTo.resolved: + continue names.add(name) #[if not name.resolved: # TODO: Emit a warning? continue]# if name.owner != self.currentModule and self.scopeDepth > -1: + # Names coming from other modules only go out of scope + # when the global scope is closed (i.e. at the end of + # the module) continue - if name.kind in [NameKind.Var, NameKind.Argument]: - # We don't increase the pop count for some kinds of objects - # because they're not stored the same way as regular variables - # (for types, generics and function declarations) - if name.belongsTo.isNil() or not name.belongsTo.valueType.isBuiltinFunction: + if name.kind == NameKind.Var: + inc(popCount) + elif name.kind == NameKind.Argument: + if not name.belongsTo.valueType.isBuiltinFunction and name.belongsTo.resolved: # We don't pop arguments to builtin functions because those don't # actually have scopes: their arguments are temporaries on the stack inc(popCount) @@ -1118,7 +1142,8 @@ proc endScope(self: Compiler) = # This includes the environments of every other closure that may # have been contained within it, too var i = 0 - var all = flatten(name.valueType) + var envLen = 0 + var lastEnvLen = 0 # Why this? Well, it's simple: if a function returns # a closure, that function becomes a closure too. The # environments of closures are aligned one after the @@ -1130,31 +1155,23 @@ proc endScope(self: Compiler) = # environment is larger than the contained one, which will # guarantee there actually is some value that the contained # function is closing over - var envLen = 0 - var lastEnvLen = 0 - for fn in all: + for fn in flatten(name.valueType): if fn.isClosure and fn.envLen > lastEnvLen: envLen += fn.envLen lastEnvLen = fn.envLen for y in 0.. 1: - # If we're popping less than 65535 variables, then - # we can emit a PopN instruction. This is true for - # 99.99999% of the use cases of the language (who the - # hell is going to use 65 THOUSAND variables?), but - # if you'll ever use more then Peon will emit a PopN instruction - # for the first 65 thousand and change local variables and then - # emit another batch of plain ol' Pop instructions for the rest - self.emitByte(PopN, self.peek().token.line) - self.emitBytes(popCount.toDouble(), self.peek().token.line) - if popCount > uint16.high().int(): - for i in countdown(self.names.high(), popCount - uint16.high().int()): - if self.names[i].depth > self.scopeDepth: - self.emitByte(PopC, self.peek().token.line) + # If we're popping more than one variable, + # we emit a bunch of PopN instructions until + # the pop count is greater than zero + while popCount > 0: + self.emitByte(PopN, self.peek().token.line) + self.emitBytes(popCount.toDouble(), self.peek().token.line) + popCount -= popCount.toDouble().fromDouble().int elif popCount == 1: # We only emit PopN if we're popping more than one value self.emitByte(PopC, self.peek().token.line) @@ -1175,10 +1192,7 @@ proc unpackGenerics(self: Compiler, condition: Expression, list: var seq[tuple[m ## Recursively unpacks a type constraint in a generic type case condition.kind: of identExpr: - let name = self.infer(condition) - if name.isNil(): - self.error(&"cannot infer type of '{IdentExpr(condition).token.lexeme}' in generic declaration", condition) - list.add((accept, name)) + list.add((accept, self.inferOrError(condition))) of binaryExpr: let condition = BinaryExpr(condition) case condition.operator.lexeme: @@ -1232,76 +1246,70 @@ proc declareName(self: Compiler, node: ASTNode, mutable: bool = false): Name = result = self.names[^1] of NodeKind.funDecl: var node = FunDecl(node) - var generics: seq[Name] = @[] - if node.generics.len() > 0: - # We declare the generics before the function so we - # can refer to them later - var constraints: seq[tuple[match: bool, kind: Type]] - for gen in node.generics: - constraints = @[] - self.unpackGenerics(gen.cond, constraints) - self.names.add(Name(depth: self.scopeDepth + 1, - isPrivate: true, - isConst: false, - owner: self.currentModule, - line: node.token.line, - valueType: Type(kind: Generic, name: gen.name.token.lexeme, mutable: false, cond: constraints), - ident: gen.name, - node: node)) - generics.add(self.names[^1]) - let fn = Name(depth: self.scopeDepth, - isPrivate: node.isPrivate, - isConst: false, - owner: self.currentModule, - valueType: Type(kind: Function, - returnType: self.infer(node.returnType), - args: @[], - fun: node, - children: @[]), - ident: node.name, - node: node, - isLet: false, - line: node.token.line, - kind: NameKind.Function, - belongsTo: self.currentFunction) - self.names.add(fn) - var name: Name - for argument in node.arguments: + result = Name(depth: self.scopeDepth, + isPrivate: node.isPrivate, + isConst: false, + owner: self.currentModule, + valueType: Type(kind: Function, + returnType: nil, # We check it later + args: @[], + fun: node, + children: @[]), + ident: node.name, + node: node, + isLet: false, + line: node.token.line, + kind: NameKind.Function, + belongsTo: self.currentFunction) + # First we declare the function's generics, if it has any. + # This is because the function's return type may in itself + # be a generic, so it needs to exist first + var constraints: seq[tuple[match: bool, kind: Type]] = @[] + for gen in node.generics: + self.unpackGenerics(gen.cond, constraints) + self.names.add(Name(depth: result.depth + 1, + isPrivate: true, + valueType: Type(kind: Generic, name: gen.name.token.lexeme, mutable: false, cond: constraints), + codePos: 0, + isLet: false, + line: result.node.token.line, + belongsTo: result, + ident: gen.name, + owner: self.currentModule)) + constraints = @[] + if not node.returnType.isNil(): + result.valueType.returnType = self.inferOrError(node.returnType, allowGeneric=true) + self.names.add(result) + # We now declare and typecheck the function's + # arguments + for argument in FunDecl(result.node).arguments: if self.names.high() > 16777215: self.error("cannot declare more than 16777215 variables at a time") - # wait, no LoadVar? Yes! That's because when calling functions, - # arguments will already be on the stack, so there's no need to - # load them here - name = Name(depth: fn.depth + 1, - isPrivate: true, - owner: self.currentModule, - isConst: false, - ident: argument.name, - valueType: self.infer(argument.valueType), - codePos: 0, - isLet: false, - line: argument.name.token.line, - belongsTo: fn, - kind: NameKind.Argument - ) - self.names.add(name) - # If it's nil, it's an error! - if name.valueType.isNil(): - self.error(&"cannot determine the type of argument '{argument.name.token.lexeme}'", argument.name) - fn.valueType.args.add((argument.name.token.lexeme, name.valueType)) - if generics.len() > 0: - fn.valueType.isGeneric = true - result = fn + self.names.add(Name(depth: result.depth + 1, + isPrivate: true, + owner: self.currentModule, + isConst: false, + ident: argument.name, + valueType: self.inferOrError(argument.valueType, allowGeneric=true), + codePos: 0, + isLet: false, + line: argument.name.token.line, + belongsTo: result, + kind: NameKind.Argument + )) + result.valueType.args.add((self.names[^1].ident.token.lexeme, self.names[^1].valueType)) + if node.generics.len() > 0: + result.valueType.isGeneric = true of NodeKind.importStmt: var node = ImportStmt(node) var name = node.moduleName.token.lexeme.extractFilename().replace(".pn", "") declaredName = name self.names.add(Name(depth: self.scopeDepth, - owner: self.currentModule, - ident: newIdentExpr(Token(kind: Identifier, lexeme: name, line: node.moduleName.token.line)), - line: node.moduleName.token.line, - kind: NameKind.Module, - isPrivate: false + owner: self.currentModule, + ident: newIdentExpr(Token(kind: Identifier, lexeme: name, line: node.moduleName.token.line)), + line: node.moduleName.token.line, + kind: NameKind.Module, + isPrivate: false )) result = self.names[^1] else: @@ -1555,21 +1563,20 @@ proc binary(self: Compiler, node: BinaryExpr) = proc identifier(self: Compiler, node: IdentExpr) = ## Compiles access to identifiers - var s = self.resolve(node) - if s.isNil(): - self.error(&"reference to undeclared name '{node.token.lexeme}'") - elif s.isConst: + var s = self.resolveOrError(node) + if s.isConst: # Constants are always emitted as Load* instructions # no matter the scope depth self.emitConstant(node, self.infer(node)) else: - if s.valueType.kind == Function and s.kind == NameKind.Function: - # Functions have no runtime - # representation: they're just - # a location to jump to + if s.kind == NameKind.Function: + # Functions have no runtime representation, they're just + # a location to jump to, but we pretend they aren't and + # resolve them to their address into our bytecode when + # they're referenced self.emitByte(LoadUInt64, node.token.line) self.emitBytes(self.chunk.writeConstant(s.codePos.toLong()), node.token.line) - elif self.scopeDepth > 0 and not self.currentFunction.isNil() and s.depth != self.scopeDepth and self.scopeOwners[s.depth].owner != self.currentFunction: + elif s.depth > 0 and self.scopeDepth > 0 and not self.currentFunction.isNil() and s.depth != self.scopeDepth and self.scopeOwners[s.depth].owner != self.currentFunction: # Loads a closure variable. Stored in a separate "closure array" in the VM that does not # align its semantics with the call stack. This makes closures work as expected and is # not much slower than indexing our stack (since they're both dynamic arrays at runtime anyway) @@ -1580,11 +1587,11 @@ proc identifier(self: Compiler, node: IdentExpr) = fn.envLen += 1 if fn.parent.isNil(): break - fn = fn.parent + fn = fn.parent s.isClosedOver = true - self.closedOver.add(s) + self.closures.add(s) let stackIdx = self.getStackPos(s).toTriple() - let closeIdx = self.getClosurePos(s).toTriple() + let closeIdx = self.closures.high().toTriple() let oldLen = self.chunk.code.len() # This madness makes it so that we can insert bytecode # at arbitrary offsets into our alredy compiled code and @@ -1599,12 +1606,15 @@ proc identifier(self: Compiler, node: IdentExpr) = self.chunk.lines[self.chunk.getIdx(self.chunk.getLine(s.belongsTo.codePos)) + 1] += 7 self.fixJumps(oldLen, s.belongsTo.codePos) self.fixCFIOffsets(oldLen, s.belongsTo.codePos) + let pos = self.getClosurePos(s) + if pos == -1: + self.error(&"cannot compute closure offset for '{s.ident.token.lexeme}'", s.ident) self.emitByte(LoadClosure, node.token.line) - self.emitBytes(self.getClosurePos(s).toTriple(), node.token.line) + self.emitBytes(pos.toTriple(), node.token.line) else: # Static name resolution, loads value at index in the stack. Very fast. Much wow. self.emitByte(LoadVar, node.token.line) - # No need to check for -1 here: we already did a nil-check above! + # No need to check for -1 here: we already did a nil check above!รน self.emitBytes(self.getStackPos(s).toTriple(), node.token.line) @@ -1614,13 +1624,11 @@ proc assignment(self: Compiler, node: ASTNode) = of assignExpr: let node = AssignExpr(node) let name = IdentExpr(node.name) - var r = self.resolve(name) - if r.isNil(): - self.error(&"assignment to undeclared name '{name.token.lexeme}'", name) - elif r.isConst: - self.error(&"cannot assign to '{name.token.lexeme}' (constant)", name) + var r = self.resolveOrError(name) + if r.isConst: + self.error(&"cannot assign to '{name.token.lexeme}' (value is a constant)", name) elif r.isLet: - self.error(&"cannot reassign '{name.token.lexeme}'", name) + self.error(&"cannot reassign '{name.token.lexeme}' (value is immutable)", name) self.expression(node.value) if not r.isClosedOver: self.emitByte(StoreVar, node.token.line) @@ -1710,17 +1718,8 @@ proc generateCall(self: Compiler, fn: Name, args: seq[Expression], line: int) = return case fn.kind: of NameKind.Var: - # We're trying to call a function assigned to a variable, - # so we resolve it if it's an identifier (lambdas coming soon!) - case VarDecl(fn.node).value.kind: - of identExpr: - let fn = self.matchImpl(IdentExpr(VarDecl(fn.node).value).token.lexeme, fn.valueType) - self.identifier(IdentExpr(VarDecl(fn.node).value)) - else: - discard # TODO + self.identifier(VarDecl(fn.node).name) of NameKind.Function: - # Just a regular function declaration: we can load its address - # normally self.emitByte(LoadUInt64, line) self.emitBytes(self.chunk.writeConstant(fn.codePos.toLong()), line) else: @@ -1794,8 +1793,6 @@ proc callExpr(self: Compiler, node: CallExpr): Name {.discardable.} = dec(i) kind = self.infer(argument) if kind.isNil(): - if argument.kind == identExpr and self.resolve(IdentExpr(argument)).isNil(): - self.error(&"reference to undeclared name '{IdentExpr(argument).name.lexeme}'") if node.callee.kind != identExpr: self.error(&"cannot infer the type of argument {i + 1} in call") else: @@ -1807,13 +1804,12 @@ proc callExpr(self: Compiler, node: CallExpr): Name {.discardable.} = # Calls like hi() result = self.matchImpl(IdentExpr(node.callee).name.lexeme, Type(kind: Function, returnType: Type(kind: Any), args: args), node) if result.valueType.isGeneric: - result = self.specialize(result, argExpr) # We can't instantiate a concrete version # of a generic function without the types # of its arguments, so we wait until the # very last moment to compile it, once # that info is available to us - self.funDecl(FunDecl(result.node), result) + result = self.specialize(result, argExpr) # Now we call it self.generateCall(result, argExpr, node.token.line) of NodeKind.callExpr: @@ -1914,8 +1910,7 @@ proc deferStmt(self: Compiler, node: DeferStmt) = proc returnStmt(self: Compiler, node: ReturnStmt) = ## Compiles return statements - var expected = self.currentFunction.valueType.returnType - self.check(node.value, expected) + self.check(node.value, self.currentFunction.valueType.returnType) if not node.value.isNil(): self.expression(node.value) self.emitByte(OpCode.SetResult, node.token.line) @@ -1924,9 +1919,12 @@ proc returnStmt(self: Compiler, node: ReturnStmt) = # separate opcodes, we perform the former and then jump to # the function's last return statement, which is always emitted # by funDecl() at the end of the function's lifecycle, greatly - # simplifying the whole thing since there's just one return + # simplifying the design, since now there's just one return # instruction to jump to instead of many potential points - # where the function returns from + # where the function returns from. Note that depending on whether + # the function has any local variables or not, this jump might be + # patched to jump to the function's PopN/PopC instruction(s) rather + # than straight to the return statement self.currentFunction.valueType.retJumps.add(self.emitJump(JumpForwards, node.token.line)) @@ -1987,9 +1985,7 @@ proc importStmt(self: Compiler, node: ImportStmt) = proc exportStmt(self: Compiler, node: ExportStmt) = ## Exports a name at compile time to ## all modules importing us - var name = self.resolve(node.name) - if name.isNil(): - self.error(&"reference to undefined name '{node.name.token.lexeme}'") + var name = self.resolveOrError(node.name) if name.isPrivate: self.error("cannot export private names") name.exported = true @@ -2102,23 +2098,27 @@ proc statement(self: Compiler, node: Statement) = proc varDecl(self: Compiler, node: VarDecl, name: Name) = ## Compiles variable declarations - let expected = self.infer(node.valueType) - let actual = self.infer(node.value) - if expected.isNil() and actual.isNil(): - if node.value.kind == identExpr or node.value.kind == callExpr and CallExpr(node.value).callee.kind == identExpr: - var name = node.value.token.lexeme - if node.value.kind == callExpr and CallExpr(node.value).callee.kind == identExpr: - name = CallExpr(node.value).callee.token.lexeme - if self.resolve(name).isNil(): - self.error(&"reference to undeclared name '{name}'") - self.error(&"'{node.name.token.lexeme}' has no type") - if not expected.isNil() and expected.mutable: # I mean, variables *are* already mutable (some of them anyway) - self.error(&"invalid type '{self.typeToStr(expected)}' for var") - elif not self.compare(expected, actual): - if not expected.isNil(): - self.error(&"expected value of type '{self.typeToStr(expected)}', but '{node.name.token.lexeme}' is of type '{self.typeToStr(actual)}'") - if expected.isNil(): - name.valueType = actual + + # Our parser guarantees that the variable declaration + # will have a type declaration or a value (or both) + var typ: Type + if node.value.isNil(): + # Variable has no value: the type declaration + # takes over + typ = self.inferOrError(node.valueType) + elif node.valueType.isNil: + # Variable has no type declaration: the type + # of its value takes over + typ = self.inferOrError(node.value) + else: + # Variable has both a type declaration and + # a value: the value's type must match the + # type declaration + let expected = self.inferOrError(node.valueType) + self.check(node.value, expected) + # If this doesn't fail, then we're good + typ = expected + name.valueType = typ self.expression(node.value) self.emitByte(StoreVar, node.token.line) self.emitBytes(self.getStackPos(name).toTriple(), node.token.line) diff --git a/tests/closures.pn b/tests/closures.pn index c9f72c4..12c77f7 100644 --- a/tests/closures.pn +++ b/tests/closures.pn @@ -1,11 +1,25 @@ # Tests closures -fn makeClosure(n: int): fn: int { +import std; + + +fn makeClosure(x: int): fn: int { fn inner: int { - return n; + return x; } return inner; } -var closure = makeClosure(38); -closure(); \ No newline at end of file +fn makeClosureTwo(y: int): fn: int { + fn inner: int { + return y; + } + return inner; +} + + +var closure = makeClosure(42); +print(closure()); # 42 +print(makeClosureTwo(38)()); # 38 +var closureTwo = makeClosureTwo(420); +print(closureTwo()); # 420 \ No newline at end of file