From af3c7234be3984fc05030d5e623a962c028584d7 Mon Sep 17 00:00:00 2001 From: Mattia Giambirtone Date: Tue, 6 Dec 2022 12:55:05 +0100 Subject: [PATCH] Added template support and related test --- src/frontend/compiler.nim | 201 +++++++++++++++++------------ src/frontend/meta/ast.nim | 3 +- src/frontend/meta/token.nim | 2 +- src/frontend/parser.nim | 28 ++-- src/peon/stdlib/builtins/values.pn | 1 - src/util/symbols.nim | 1 + tests/templates.pn | 11 ++ 7 files changed, 147 insertions(+), 100 deletions(-) create mode 100644 tests/templates.pn diff --git a/src/frontend/compiler.nim b/src/frontend/compiler.nim index 636f6c1..1290c69 100644 --- a/src/frontend/compiler.nim +++ b/src/frontend/compiler.nim @@ -1383,6 +1383,8 @@ proc endScope(self: Compiler) = # Automatic functions do not materialize # at runtime, so their arguments don't either continue + if not name.isReal: + continue inc(popCount) if not name.resolved: case name.kind: @@ -1391,7 +1393,7 @@ proc endScope(self: Compiler) = self.warning(UnusedName, &"'{name.ident.token.lexeme}' is declared but not used (add '_' prefix to silence warning)", name) of NameKind.Argument: if not name.ident.token.lexeme.startsWith("_") and name.isPrivate: - if not name.belongsTo.isNil() and not name.belongsTo.isBuiltin and name.belongsTo.isReal: + if not name.belongsTo.isNil() and not name.belongsTo.isBuiltin and name.belongsTo.isReal and name.belongsTo.resolved: # Builtin functions never use their arguments. We also don't emit this # warning if the function was generated internally by the compiler (for # example as a result of generic specialization) because such objects do @@ -1527,6 +1529,8 @@ proc declare(self: Compiler, node: ASTNode): Name {.discardable.} = kind: NameKind.Function, belongsTo: self.currentFunction, isReal: true) + if node.isTemplate: + fn.valueType.compiled = true if node.generics.len() > 0: fn.isGeneric = true var typ: Type @@ -1968,10 +1972,13 @@ proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool else: discard # Unreachable else: - # Loads a regular variable from the current frame - self.emitByte(LoadVar, s.ident.token.line) - # No need to check for -1 here: we already did a nil check above! - self.emitBytes(s.position.toTriple(), s.ident.token.line) + if not s.belongsTo.isNil() and s.belongsTo.valueType.fun.kind == funDecl and FunDecl(s.belongsTo.valueType.fun).isTemplate: + discard + else: + # Loads a regular variable from the current frame + self.emitByte(LoadVar, s.ident.token.line) + # No need to check for -1 here: we already did a nil check above! + self.emitBytes(s.position.toTriple(), s.ident.token.line) proc assignment(self: Compiler, node: ASTNode, compile: bool = true): Type {.discardable.} = @@ -2119,7 +2126,8 @@ proc prepareFunction(self: Compiler, fn: Name) = belongsTo: fn, kind: NameKind.Argument, node: argument.name, - position: self.stackIndex + position: self.stackIndex, + isReal: not node.isTemplate )) if node.arguments.high() - node.defaults.high() <= node.arguments.high(): # There's a default argument! @@ -2133,6 +2141,8 @@ proc prepareFunction(self: Compiler, fn: Name) = fn.valueType.returnType = self.inferOrError(FunDecl(fn.node).returnType) fn.position = self.stackIndex self.stackIndex = idx + if node.isTemplate: + fn.valueType.compiled = true proc prepareAutoFunction(self: Compiler, fn: Name, args: seq[tuple[name: string, kind: Type, default: Expression]]): Name = @@ -2170,14 +2180,18 @@ proc prepareAutoFunction(self: Compiler, fn: Name, args: seq[tuple[name: string, belongsTo: fn, kind: NameKind.Argument, node: argument.name, - position: self.stackIndex + position: self.stackIndex, + isReal: not node.isTemplate )) + if node.isTemplate: + fn.valueType.compiled = true fn.valueType.args = args fn.position = self.stackIndex self.stackIndex = idx return fn + proc generateCall(self: Compiler, fn: Name, args: seq[Expression], line: int) = ## Small wrapper that abstracts emitting a call instruction ## for a given function @@ -2272,7 +2286,23 @@ proc call(self: Compiler, node: CallExpr, compile: bool = true): Type {.discarda self.funDecl(FunDecl(result.fun), impl) result = result.returnType if compile: - self.generateCall(impl, argExpr, node.token.line) + if impl.valueType.fun.kind == funDecl and FunDecl(impl.valueType.fun).isTemplate: + for arg in reversed(argExpr): + self.expression(arg) + let code = BlockStmt(FunDecl(impl.valueType.fun).body).code + for i, decl in code: + if i < code.high(): + self.declaration(decl) + else: + # The last expression in a template + # is its return type, so we compute + # it, but don't pop it off the stack + if decl.kind == exprStmt: + self.expression(ExprStmt(decl).expression) + else: + self.declaration(decl) + else: + self.generateCall(impl, argExpr, node.token.line) of NodeKind.callExpr: # Calling a call expression, like hello()() var node: Expression = node @@ -2718,84 +2748,85 @@ proc funDecl(self: Compiler, node: FunDecl, name: Name) = return let stackIdx = self.stackIndex self.stackIndex = name.position - # A function's code is just compiled linearly - # and then jumped over - name.valueType.compiled = true - jmp = self.emitJump(JumpForwards, node.token.line) - name.codePos = self.chunk.code.len() - name.valueType.location = name.codePos - # We let our debugger know this function's boundaries - self.chunk.functions.add(self.chunk.code.len().toTriple()) - self.functions.add((start: self.chunk.code.len(), stop: 0, pos: self.chunk.functions.len() - 3, fn: name)) - var offset = self.functions[^1] - let idx = self.chunk.functions.len() - self.chunk.functions.add(0.toTriple()) # Patched it later - self.chunk.functions.add(uint8(node.arguments.len())) - if not node.name.isNil(): - self.chunk.functions.add(name.ident.token.lexeme.len().toDouble()) - var s = name.ident.token.lexeme - if s.len() >= uint16.high().int: - s = node.name.token.lexeme[0..uint16.high()] - self.chunk.functions.add(s.toBytes()) - else: - self.chunk.functions.add(0.toDouble()) - if BlockStmt(node.body).code.len() == 0: - self.error("cannot declare function with empty body") - # Since the deferred array is a linear - # sequence of instructions and we want - # to keep track to whose function's each - # set of deferred instruction belongs, - # we record the length of the deferred - # array before compiling the function - # and use this info later to compile - # the try/finally block with the deferred - # code - var deferStart = self.deferred.len() - var last: Declaration - self.beginScope() - for decl in BlockStmt(node.body).code: - if not last.isNil(): - if last.kind == returnStmt: - self.warning(UnreachableCode, "code after 'return' statement is unreachable") - self.declaration(decl) - last = decl - let typ = self.currentFunction.valueType.returnType - var hasVal: bool = false - case self.currentFunction.valueType.fun.kind: - of NodeKind.funDecl: - hasVal = FunDecl(self.currentFunction.valueType.fun).hasExplicitReturn - of NodeKind.lambdaExpr: - hasVal = LambdaExpr(self.currentFunction.valueType.fun).hasExplicitReturn + if not node.isTemplate: + # A function's code is just compiled linearly + # and then jumped over + name.valueType.compiled = true + jmp = self.emitJump(JumpForwards, node.token.line) + name.codePos = self.chunk.code.len() + name.valueType.location = name.codePos + # We let our debugger know this function's boundaries + self.chunk.functions.add(self.chunk.code.len().toTriple()) + self.functions.add((start: self.chunk.code.len(), stop: 0, pos: self.chunk.functions.len() - 3, fn: name)) + var offset = self.functions[^1] + var idx = self.chunk.functions.len() + self.chunk.functions.add(0.toTriple()) # Patched it later + self.chunk.functions.add(uint8(node.arguments.len())) + if not node.name.isNil(): + self.chunk.functions.add(name.ident.token.lexeme.len().toDouble()) + var s = name.ident.token.lexeme + if s.len() >= uint16.high().int: + s = node.name.token.lexeme[0..uint16.high()] + self.chunk.functions.add(s.toBytes()) else: - discard # Unreachable - if not hasVal and not typ.isNil(): - # There is no explicit return statement anywhere in the function's - # body: while this is not a tremendously useful piece of information - # (since the presence of at least one doesn't mean all control flow - # cases are covered), it definitely is an error worth reporting - self.error("function has an explicit return type, but no return statement was found", node) - hasVal = hasVal and not typ.isNil() - for jump in self.currentFunction.valueType.retJumps: - self.patchJump(jump) - self.endScope() - # Terminates the function's context - self.emitByte(OpCode.Return, self.peek().token.line) - if hasVal: - self.emitByte(1, self.peek().token.line) - else: - self.emitByte(0, self.peek().token.line) - # Currently defer is not functional, so we - # just pop the instructions - for _ in deferStart..self.deferred.high(): - discard self.deferred.pop() - let stop = self.chunk.code.len().toTriple() - self.chunk.functions[idx] = stop[0] - self.chunk.functions[idx + 1] = stop[1] - self.chunk.functions[idx + 2] = stop[2] - offset.stop = self.chunk.code.len() - # Well, we've compiled everything: time to patch - # the jump offset - self.patchJump(jmp) + self.chunk.functions.add(0.toDouble()) + if BlockStmt(node.body).code.len() == 0: + self.error("cannot declare function with empty body") + # Since the deferred array is a linear + # sequence of instructions and we want + # to keep track to whose function's each + # set of deferred instruction belongs, + # we record the length of the deferred + # array before compiling the function + # and use this info later to compile + # the try/finally block with the deferred + # code + var deferStart = self.deferred.len() + var last: Declaration + self.beginScope() + for decl in BlockStmt(node.body).code: + if not last.isNil(): + if last.kind == returnStmt: + self.warning(UnreachableCode, "code after 'return' statement is unreachable") + self.declaration(decl) + last = decl + let typ = self.currentFunction.valueType.returnType + var hasVal: bool = false + case self.currentFunction.valueType.fun.kind: + of NodeKind.funDecl: + hasVal = FunDecl(self.currentFunction.valueType.fun).hasExplicitReturn + of NodeKind.lambdaExpr: + hasVal = LambdaExpr(self.currentFunction.valueType.fun).hasExplicitReturn + else: + discard # Unreachable + if not hasVal and not typ.isNil(): + # There is no explicit return statement anywhere in the function's + # body: while this is not a tremendously useful piece of information + # (since the presence of at least one doesn't mean all control flow + # cases are covered), it definitely is an error worth reporting + self.error("function has an explicit return type, but no return statement was found", node) + hasVal = hasVal and not typ.isNil() + for jump in self.currentFunction.valueType.retJumps: + self.patchJump(jump) + self.endScope() + # Terminates the function's context + self.emitByte(OpCode.Return, self.peek().token.line) + if hasVal: + self.emitByte(1, self.peek().token.line) + else: + self.emitByte(0, self.peek().token.line) + # Currently defer is not functional, so we + # just pop the instructions + for _ in deferStart..self.deferred.high(): + discard self.deferred.pop() + let stop = self.chunk.code.len().toTriple() + self.chunk.functions[idx] = stop[0] + self.chunk.functions[idx + 1] = stop[1] + self.chunk.functions[idx + 2] = stop[2] + offset.stop = self.chunk.code.len() + # Well, we've compiled everything: time to patch + # the jump offset + self.patchJump(jmp) # Restores the enclosing function (if any). # Makes nested calls work (including recursion) self.currentFunction = function diff --git a/src/frontend/meta/ast.nim b/src/frontend/meta/ast.nim index 89f6758..9f4057e 100644 --- a/src/frontend/meta/ast.nim +++ b/src/frontend/meta/ast.nim @@ -264,6 +264,7 @@ type defaults*: seq[Expression] isAsync*: bool isGenerator*: bool + isTemplate*: bool isPure*: bool returnType*: Expression hasExplicitReturn*: bool @@ -737,7 +738,7 @@ proc `$`*(self: ASTNode): string = result &= &"Var(name={self.name}, value={self.value}, const={self.isConst}, private={self.isPrivate}, type={self.valueType}, pragmas={self.pragmas})" of funDecl: var self = FunDecl(self) - result &= &"""FunDecl(name={self.name}, body={self.body}, type={self.returnType}, arguments=[{self.arguments.join(", ")}], defaults=[{self.defaults.join(", ")}], generics=[{self.generics.join(", ")}], async={self.isAsync}, generator={self.isGenerator}, private={self.isPrivate}, pragmas={self.pragmas})""" + result &= &"""FunDecl(name={self.name}, body={self.body}, type={self.returnType}, arguments=[{self.arguments.join(", ")}], defaults=[{self.defaults.join(", ")}], generics=[{self.generics.join(", ")}], async={self.isAsync}, generator={self.isGenerator}, template={self.isTemplate}, private={self.isPrivate}, pragmas={self.pragmas})""" of typeDecl: var self = TypeDecl(self) result &= &"""TypeDecl(name={self.name}, fields={self.fields}, defaults={self.defaults}, private={self.isPrivate}, pragmas={self.pragmas}, generics={self.generics}, parent={self.parent}, ref={self.isRef}, enum={self.isEnum})""" diff --git a/src/frontend/meta/token.nim b/src/frontend/meta/token.nim index 75cf79d..7ead699 100644 --- a/src/frontend/meta/token.nim +++ b/src/frontend/meta/token.nim @@ -38,7 +38,7 @@ type Yield, Defer, Try, Except, Finally, Type, Operator, Case, Enum, From, Ptr, Ref, Object, - Export, Block + Export, Block, Template # Literal types Integer, Float, String, Identifier, diff --git a/src/frontend/parser.nim b/src/frontend/parser.nim index f08f38e..75d4031 100644 --- a/src/frontend/parser.nim +++ b/src/frontend/parser.nim @@ -301,7 +301,7 @@ proc varDecl(self: Parser, isLet: bool = false, isConst: bool = false): Declaration proc parseFunExpr(self: Parser): LambdaExpr proc funDecl(self: Parser, isAsync: bool = false, isGenerator: bool = false, - isLambda: bool = false, isOperator: bool = false): Declaration + isLambda: bool = false, isOperator: bool = false, isTemplate: bool = false): Declaration proc declaration(self: Parser): Declaration proc parse*(self: Parser, tokens: seq[Token], file: string, lines: seq[tuple[start, stop: int]], source: string, persist: bool = false): seq[Declaration] # End of forward declarations @@ -335,7 +335,7 @@ proc primary(self: Parser): Expression = self.expect(RightParen, "unterminated parenthesized expression") of Yield: let tok = self.step() - if self.currentFunction.isNil(): + if self.currentFunction.isNil() or (self.currentFunction.kind == funDecl and FunDecl(self.currentFunction).isTemplate): self.error("'yield' cannot be used outside functions", tok) elif self.currentFunction.token.kind != Generator: # It's easier than doing conversions for lambda/funDecl @@ -349,7 +349,7 @@ proc primary(self: Parser): Expression = result.file = self.file of Await: let tok = self.step() - if self.currentFunction.isNil(): + if self.currentFunction.isNil() or (self.currentFunction.kind == funDecl and FunDecl(self.currentFunction).isTemplate): self.error("'await' cannot be used outside functions", tok) if self.currentFunction.token.kind != Coroutine: self.error("'await' can only be used inside coroutines", tok) @@ -659,7 +659,7 @@ proc breakStmt(self: Parser): Statement = proc deferStmt(self: Parser): Statement = ## Parses defer statements let tok = self.peek(-1) - if self.currentFunction.isNil(): + if self.currentFunction.isNil() or (self.currentFunction.kind == funDecl and FunDecl(self.currentFunction).isTemplate): self.error("'defer' cannot be used outside functions") endOfLine("missing semicolon after 'defer'") result = newDeferStmt(self.expression(), tok) @@ -683,7 +683,7 @@ proc continueStmt(self: Parser): Statement = proc returnStmt(self: Parser): Statement = ## Parses return statements let tok = self.peek(-1) - if self.currentFunction.isNil(): + if self.currentFunction.isNil() or (self.currentFunction.kind == funDecl and FunDecl(self.currentFunction).isTemplate): self.error("'return' cannot be used outside functions") var value: Expression if not self.check(Semicolon): @@ -704,7 +704,7 @@ proc returnStmt(self: Parser): Statement = proc yieldStmt(self: Parser): Statement = ## Parses yield statements let tok = self.peek(-1) - if self.currentFunction.isNil(): + if self.currentFunction.isNil() or (self.currentFunction.kind == funDecl and FunDecl(self.currentFunction).isTemplate): self.error("'yield' cannot be outside functions") elif self.currentFunction.token.kind != Generator: self.error("'yield' can only be used inside generators") @@ -719,7 +719,7 @@ proc yieldStmt(self: Parser): Statement = proc awaitStmt(self: Parser): Statement = ## Parses await statements let tok = self.peek(-1) - if self.currentFunction.isNil(): + if self.currentFunction.isNil() or (self.currentFunction.kind == funDecl and FunDecl(self.currentFunction).isTemplate): self.error("'await' cannot be used outside functions") if self.currentFunction.token.kind != Coroutine: self.error("'await' can only be used inside coroutines") @@ -1033,13 +1033,13 @@ proc parseFunExpr(self: Parser): LambdaExpr = var arguments: seq[tuple[name: IdentExpr, valueType: Expression]] = @[] var defaults: seq[Expression] = @[] result = newLambdaExpr(arguments, defaults, nil, isGenerator=self.peek(-1).kind == Generator, - isAsync=self.peek(-1).kind == Coroutine, token=self.peek(-1), + isAsync=self.peek(-1).kind == Coroutine, token=self.peek(-1), returnType=nil, depth=self.scopeDepth) var parameter: tuple[name: IdentExpr, valueType: Expression] if self.match(LeftParen): self.parseDeclArguments(arguments, parameter, defaults) if self.match(":"): - if self.match([Function, Coroutine, Generator]): + if self.match([Function, Coroutine, Generator, Template]): result.returnType = self.parseFunExpr() else: result.returnType = self.expression() @@ -1082,7 +1082,7 @@ proc parseGenerics(self: Parser, decl: Declaration) = proc funDecl(self: Parser, isAsync: bool = false, isGenerator: bool = false, - isLambda: bool = false, isOperator: bool = false): Declaration = # Can't use just FunDecl because it can also return LambdaExpr! + isLambda: bool = false, isOperator: bool = false, isTemplate: bool = false): Declaration = # Can't use just FunDecl because it can also return LambdaExpr! ## Parses all types of functions, coroutines, generators and operators ## (with or without a name, where applicable) let tok = self.peek(-1) @@ -1128,7 +1128,7 @@ proc funDecl(self: Parser, isAsync: bool = false, isGenerator: bool = false, returnType=nil, depth=self.scopeDepth) if self.match(":"): # Function has explicit return type - if self.match([Function, Coroutine, Generator]): + if self.match([Function, Coroutine, Generator, Template]): # The function's return type is another # function. We specialize this case because # the type declaration for a function lacks @@ -1142,7 +1142,7 @@ proc funDecl(self: Parser, isAsync: bool = false, isGenerator: bool = false, self.parseDeclArguments(arguments, parameter, defaults) if self.match(":"): # Function's return type - if self.match([Function, Coroutine, Generator]): + if self.match([Function, Coroutine, Generator, Template]): returnType = self.parseFunExpr() else: returnType = self.expression() @@ -1154,6 +1154,7 @@ proc funDecl(self: Parser, isAsync: bool = false, isGenerator: bool = false, if self.match(TokenType.Pragma): for pragma in self.parsePragmas(): pragmas.add(pragma) + FunDecl(self.currentFunction).isTemplate = isTemplate FunDecl(self.currentFunction).body = self.blockStmt() else: # This is a forward declaration, so we explicitly @@ -1344,6 +1345,9 @@ proc declaration(self: Parser): Declaration = of Function: discard self.step() result = self.funDecl() + of Template: + discard self.step() + result = self.funDecl(isTemplate=true) of Coroutine: discard self.step() result = self.funDecl(isAsync=true) diff --git a/src/peon/stdlib/builtins/values.pn b/src/peon/stdlib/builtins/values.pn index 6dbc0f0..3b38465 100644 --- a/src/peon/stdlib/builtins/values.pn +++ b/src/peon/stdlib/builtins/values.pn @@ -1,6 +1,5 @@ # Stub type declarations for peon's intrinsic types - type int64* = object { #pragma[magic: "int64"] } diff --git a/src/util/symbols.nim b/src/util/symbols.nim index 69b6d3b..8fcef1b 100644 --- a/src/util/symbols.nim +++ b/src/util/symbols.nim @@ -56,6 +56,7 @@ proc fillSymbolTable*(tokenizer: Lexer) = tokenizer.symbols.addKeyword("ref", TokenType.Ref) tokenizer.symbols.addKeyword("ptr", TokenType.Ptr) tokenizer.symbols.addKeyword("block", TokenType.Block) + tokenizer.symbols.addKeyword("template", TokenType.Template) for sym in [">", "<", "=", "~", "/", "+", "-", "_", "*", "?", "@", ":", "==", "!=", ">=", "<=", "+=", "-=", "/=", "*=", "**=", "!", "%", "&", "|", "^", ">>", "<<"]: diff --git a/tests/templates.pn b/tests/templates.pn new file mode 100644 index 0000000..26e1530 --- /dev/null +++ b/tests/templates.pn @@ -0,0 +1,11 @@ +# A test for templates +import std; + + +template sum[T: Integer](a, b: T): T { + a + b; +} + + +print(sum(1, 2) == 3); # true +print(sum(1'i32, 2'i32) == 3'i32); # true