diff --git a/src/frontend/compiler.nim b/src/frontend/compiler.nim index f30b570..e42dff8 100644 --- a/src/frontend/compiler.nim +++ b/src/frontend/compiler.nim @@ -46,7 +46,7 @@ type UInt32, Int64, UInt64, Float32, Float64, Char, Byte, String, Function, CustomType, Nil, Nan, Bool, Inf, Typevar, Generic, - Reference, Pointer, Any, All, Union + Reference, Pointer, Any, All, Union, Auto Type = ref object ## A wrapper around ## compile-time types @@ -55,6 +55,7 @@ type isLambda: bool isGenerator: bool isCoroutine: bool + isAuto: bool args: seq[tuple[name: string, kind: Type, default: Expression]] returnType: Type builtinOp: string @@ -278,11 +279,11 @@ proc compile*(self: Compiler, ast: seq[Declaration], file: string, lines: seq[tu incremental: bool = false, isMainModule: bool = true, disabledWarnings: seq[WarningKind] = @[], showMismatches: bool = false, mode: CompileMode = Debug): Chunk proc expression(self: Compiler, node: Expression, compile: bool = true): Type {.discardable.} -proc statement(self: Compiler, node: Statement) -proc declaration(self: Compiler, node: Declaration) +proc statement(self: Compiler, node: Statement, compile: bool = true) +proc declaration(self: Compiler, node: Declaration, compile: bool = true) proc peek(self: Compiler, distance: int = 0): ASTNode -proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool = true): Type {.discardable.} -proc varDecl(self: Compiler, node: VarDecl) +proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool = true, strict: bool = true): Type {.discardable.} +proc varDecl(self: Compiler, node: VarDecl, compile: bool = true) proc match(self: Compiler, name: string, kind: Type, node: ASTNode = nil, allowFwd: bool = true): Name proc specialize(self: Compiler, typ: Type, args: seq[Expression]): Type {.discardable.} proc call(self: Compiler, node: CallExpr, compile: bool = true): Type {.discardable.} @@ -302,7 +303,7 @@ proc handlePurePragma(self: Compiler, pragma: Pragma, name: Name) proc handleErrorPragma(self: Compiler, pragma: Pragma, name: Name) proc dispatchPragmas(self: Compiler, name: Name) proc dispatchDelayedPragmas(self: Compiler, name: Name) -proc funDecl(self: Compiler, node: FunDecl, name: Name) +proc funDecl(self: Compiler, node: FunDecl, name: Name, compile: bool = true) proc compileModule(self: Compiler, module: Name) proc generateCall(self: Compiler, fn: Name, args: seq[Expression], line: int) proc prepareFunction(self: Compiler, fn: Name) @@ -703,7 +704,7 @@ proc resolve(self: Compiler, name: string): Name = ## resolved, it is also compiled on-the-fly for obj in reversed(self.names): if obj.ident.token.lexeme == name: - if obj.owner != self.currentModule: + if obj.owner.path != self.currentModule.path: # We don't own this name, but we # may still have access to it if obj.isPrivate: @@ -725,8 +726,6 @@ proc resolve(self: Compiler, name: string): Name = # might not want to also have access to C's and D's # names as they might clash with its own stuff) continue - if obj.kind == Argument and obj.belongsTo != self.currentFunction: - continue result = obj result.resolved = true break @@ -888,6 +887,8 @@ proc toIntrinsic(name: string): Type = ## otherwise if name == "any": return Type(kind: Any) + elif name == "auto": + return Type(kind: Auto) elif name in ["int", "int64", "i64"]: return Type(kind: Int64) elif name in ["uint64", "u64", "uint"]: @@ -966,7 +967,7 @@ proc infer(self: Compiler, node: Expression): Type = return nil case node.kind: of NodeKind.identExpr: - result = self.identifier(IdentExpr(node), compile=false) + result = self.identifier(IdentExpr(node), compile=false, strict=false) of NodeKind.unaryExpr: result = self.unary(UnaryExpr(node), compile=false) of NodeKind.binaryExpr: @@ -1009,7 +1010,7 @@ proc stringify(self: Compiler, typ: Type): string = of Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Float32, Float64, Char, Byte, String, Nil, TypeKind.Nan, Bool, - TypeKind.Inf: + TypeKind.Inf, Auto: result &= ($typ.kind).toLowerAscii() of Pointer: result &= &"ptr {self.stringify(typ.value)}" @@ -1518,7 +1519,8 @@ proc declare(self: Compiler, node: ASTNode): Name {.discardable.} = returnType: nil, # We check it later args: @[], fun: node, - forwarded: node.body.isNil()), + forwarded: node.body.isNil(), + isAuto: false), ident: node.name, node: node, isLet: false, @@ -1526,10 +1528,24 @@ proc declare(self: Compiler, node: ASTNode): Name {.discardable.} = kind: NameKind.Function, belongsTo: self.currentFunction, isReal: true) - self.names.add(fn) - n = fn if node.generics.len() > 0: fn.isGeneric = true + var typ: Type + for argument in node.arguments: + typ = self.infer(argument.valueType) + if not typ.isNil() and typ.kind == Auto: + fn.valueType.isAuto = true + if fn.isGeneric: + self.error("automatic types cannot be used within generics", argument.valueType) + break + typ = self.infer(node.returnType) + if not typ.isNil() and typ.kind == Auto: + fn.valueType.isAuto = true + if fn.isGeneric: + self.error("automatic types cannot be used within generics", node.returnType) + self.names.add(fn) + self.prepareFunction(fn) + n = fn of NodeKind.importStmt: var node = ImportStmt(node) # We change the name of the module internally so that @@ -1583,11 +1599,6 @@ proc declare(self: Compiler, node: ASTNode): Name {.discardable.} = discard # TODO: enums if not n.isNil(): self.dispatchPragmas(n) - case n.kind: - of NameKind.Function: - self.prepareFunction(n) - else: - discard for name in self.findByName(declaredName): if name == n: continue @@ -1914,16 +1925,24 @@ proc binary(self: Compiler, node: BinaryExpr, compile: bool = true): Type {.disc var default: Expression let fn = Type(kind: Function, returnType: Type(kind: Any), args: @[("", self.inferOrError(node.a), default), ("", self.inferOrError(node.b), default)]) let impl = self.match(node.token.lexeme, fn, node) - result = impl.valueType.returnType + result = impl.valueType + if impl.isGeneric: + result = self.specialize(result, @[node.a, node.b]) + result = result.returnType if compile: self.generateCall(impl, @[node.a, node.b], impl.line) -proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool = true): Type {.discardable.} = +proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool = true, strict: bool = true): Type {.discardable.} = ## Compiles access to identifiers var s = name if s.isNil(): - s = self.resolveOrError(node) + if strict: + s = self.resolveOrError(node) + else: + s = self.resolve(node) + if s.isNil() and not strict: + return nil result = s.valueType if not compile: return @@ -1991,7 +2010,7 @@ proc assignment(self: Compiler, node: ASTNode, compile: bool = true): Type {.dis self.error(&"invalid AST node of kind {node.kind} at assignment(): {node} (This is an internal error and most likely a bug)") -proc blockStmt(self: Compiler, node: BlockStmt) = +proc blockStmt(self: Compiler, node: BlockStmt, compile: bool = true) = ## Compiles block statements, which create ## a new local scope self.beginScope() @@ -2003,35 +2022,43 @@ proc blockStmt(self: Compiler, node: BlockStmt) = self.warning(UnreachableCode, &"code after '{last.token.lexeme}' statement is unreachable", nil, last) else: discard - self.declaration(decl) + self.declaration(decl, compile) last = decl self.endScope() -proc ifStmt(self: Compiler, node: IfStmt) = +proc ifStmt(self: Compiler, node: IfStmt, compile: bool = true) = ## Compiles if/else statements for conditional ## execution of code self.check(node.condition, Type(kind: Bool)) self.expression(node.condition) - let jump = self.emitJump(JumpIfFalsePop, node.token.line) - self.statement(node.thenBranch) - let jump2 = self.emitJump(JumpForwards, node.token.line) - self.patchJump(jump) + var jump: int + var jump2: int + if compile: + jump = self.emitJump(JumpIfFalsePop, node.token.line) + self.statement(node.thenBranch, compile) + if compile: + jump2 = self.emitJump(JumpForwards, node.token.line) + self.patchJump(jump) if not node.elseBranch.isNil(): - self.statement(node.elseBranch) - self.patchJump(jump2) + self.statement(node.elseBranch, compile) + if compile: + self.patchJump(jump2) -proc whileStmt(self: Compiler, node: WhileStmt) = +proc whileStmt(self: Compiler, node: WhileStmt, compile: bool = true) = ## Compiles C-style while loops and ## desugared C-style for loops self.check(node.condition, Type(kind: Bool)) let start = self.chunk.code.high() self.expression(node.condition) - let jump = self.emitJump(JumpIfFalsePop, node.token.line) - self.statement(node.body) - self.emitLoop(start, node.token.line) - self.patchJump(jump) + var jump: int + if compile: + jump = self.emitJump(JumpIfFalsePop, node.token.line) + self.statement(node.body, compile) + if compile: + self.emitLoop(start, node.token.line) + self.patchJump(jump) proc generateCall(self: Compiler, fn: Type, args: seq[Expression], line: int) {.used.} = @@ -2082,8 +2109,8 @@ proc prepareFunction(self: Compiler, fn: Name) = let idx = self.stackIndex self.stackIndex = 1 var default: Expression - var i = 0 var node = FunDecl(fn.node) + var i = 0 for argument in node.arguments: if self.names.high() > 16777215: self.error("cannot declare more than 16777215 variables at a time") @@ -2094,7 +2121,7 @@ proc prepareFunction(self: Compiler, fn: Name) = file: fn.file, isConst: false, ident: argument.name, - valueType: self.inferOrError(argument.valueType), + valueType: if not fn.valueType.isAuto: self.inferOrError(argument.valueType) else: Type(kind: Any), codePos: 0, isLet: false, line: argument.name.token.line, @@ -2117,6 +2144,47 @@ proc prepareFunction(self: Compiler, fn: Name) = self.stackIndex = idx +proc prepareAutoFunction(self: Compiler, fn: Name, args: seq[tuple[name: string, kind: Type, default: Expression]]): Name = + ## "Prepares" an automatic function declaration + ## by declaring a concrete version of it along + ## with its arguments + + # First we declare the function's generics, if it has any. + # This is because the function's return type may in itself + # be a generic, so it needs to exist first + # We now declare and typecheck the function's + # arguments + let idx = self.stackIndex + self.stackIndex = 1 + var default: Expression + var node = FunDecl(fn.node) + var fn = deepCopy(fn) + self.names.add(fn) + for (argument, val) in zip(node.arguments, args): + if self.names.high() > 16777215: + self.error("cannot declare more than 16777215 variables at a time") + inc(self.stackIndex) + self.names.add(Name(depth: fn.depth + 1, + isPrivate: true, + owner: fn.owner, + file: fn.file, + isConst: false, + ident: argument.name, + valueType: val.kind, + codePos: 0, + isLet: false, + line: argument.name.token.line, + belongsTo: fn, + kind: NameKind.Argument, + node: argument.name, + position: self.stackIndex + )) + fn.valueType.args = args + fn.position = self.stackIndex + self.stackIndex = idx + return fn + + proc generateCall(self: Compiler, fn: Name, args: seq[Expression], line: int) = ## Small wrapper that abstracts emitting a call instruction ## for a given function @@ -2163,7 +2231,7 @@ proc specialize(self: Compiler, typ: Type, args: seq[Expression]): Type {.discar continue kind = self.inferOrError(args[i]) if typ.name in mapping and not self.compare(kind, mapping[typ.name]): - self.error(&"expecting generic argument '{typ.name}' to be of type {self.stringify(mapping[typ.name])}, got {self.stringify(kind)}") + self.error(&"expecting generic argument '{typ.name}' to be of type {self.stringify(mapping[typ.name])}, got {self.stringify(kind)}", args[i]) mapping[typ.name] = kind result.args[i].kind = kind if not result.returnType.isNil() and result.returnType.kind == Generic: @@ -2200,10 +2268,14 @@ proc call(self: Compiler, node: CallExpr, compile: bool = true): Type {.discarda case node.callee.kind: of NodeKind.identExpr: # Calls like hi() - let impl = self.match(IdentExpr(node.callee).name.lexeme, Type(kind: Function, returnType: Type(kind: All), args: args), node) + var impl = self.match(IdentExpr(node.callee).name.lexeme, Type(kind: Function, returnType: Type(kind: All), args: args), node) result = impl.valueType if impl.isGeneric: result = self.specialize(result, argExpr) + if impl.valueType.isAuto: + impl = self.prepareAutoFunction(impl, args) + result = impl.valueType + self.funDecl(FunDecl(result.fun), impl, compile=compile) result = result.returnType if compile: # Now we call it @@ -2371,45 +2443,53 @@ proc expression(self: Compiler, node: Expression, compile: bool = true): Type {. # TODO -proc awaitStmt(self: Compiler, node: AwaitStmt) = +proc awaitStmt(self: Compiler, node: AwaitStmt, compile: bool = true) = ## Compiles await statements # TODO -proc deferStmt(self: Compiler, node: DeferStmt) = +proc deferStmt(self: Compiler, node: DeferStmt, compile: bool = true) = ## Compiles defer statements # TODO -proc yieldStmt(self: Compiler, node: YieldStmt) = +proc yieldStmt(self: Compiler, node: YieldStmt, compile: bool = true) = ## Compiles yield statements # TODO -proc raiseStmt(self: Compiler, node: RaiseStmt) = +proc raiseStmt(self: Compiler, node: RaiseStmt, compile: bool = true) = ## Compiles raise statements # TODO -proc assertStmt(self: Compiler, node: AssertStmt) = +proc assertStmt(self: Compiler, node: AssertStmt, compile: bool = true) = ## Compiles assert statements # TODO # TODO -proc forEachStmt(self: Compiler, node: ForEachStmt) = +proc forEachStmt(self: Compiler, node: ForEachStmt, compile: bool = true) = ## Compiles foreach loops -proc returnStmt(self: Compiler, node: ReturnStmt) = +proc returnStmt(self: Compiler, node: ReturnStmt, compile: bool = true) = ## Compiles return statements if self.currentFunction.valueType.returnType.isNil() and not node.value.isNil(): self.error("cannot return a value from a void function", node.value) elif not self.currentFunction.valueType.returnType.isNil() and node.value.isNil(): self.error("bare return statement is only allowed in void functions", node) if not node.value.isNil(): - self.expression(node.value) - self.emitByte(OpCode.SetResult, node.token.line) + if not self.currentFunction.valueType.isAuto: + self.check(node.value, self.currentFunction.valueType.returnType) + else: + if self.currentFunction.valueType.returnType.kind == Auto: + self.currentFunction.valueType.returnType = self.inferOrError(node.value) + else: + self.check(node.value, self.currentFunction.valueType.returnType) + self.expression(node.value, compile) + if compile: + self.emitByte(OpCode.SetResult, node.token.line) # Since the "set result" part and "exit the function" part # of our return mechanism are already decoupled into two # separate opcodes, we perform the former and then jump to @@ -2421,17 +2501,19 @@ proc returnStmt(self: Compiler, node: ReturnStmt) = # the function has any local variables or not, this jump might be # patched to jump to the function's PopN/PopC instruction(s) rather # than straight to the return statement - self.currentFunction.valueType.retJumps.add(self.emitJump(JumpForwards, node.token.line)) + if compile: + self.currentFunction.valueType.retJumps.add(self.emitJump(JumpForwards, node.token.line)) -proc continueStmt(self: Compiler, node: ContinueStmt) = +proc continueStmt(self: Compiler, node: ContinueStmt, compile: bool = true) = ## Compiles continue statements. A continue statement ## jumps to the next iteration in a loop if node.label.isNil(): if self.currentLoop.start > 16777215: self.error("too much code to jump over in continue statement") - self.emitByte(Jump, node.token.line) - self.emitBytes(self.currentLoop.start.toTriple(), node.token.line) + if compile: + self.emitByte(Jump, node.token.line) + self.emitBytes(self.currentLoop.start.toTriple(), node.token.line) else: var blocks: seq[NamedBlock] = @[] var found: bool = false @@ -2442,27 +2524,29 @@ proc continueStmt(self: Compiler, node: ContinueStmt) = break if not found: self.error(&"unknown block name '{node.label.token.lexeme}'", node.label) - self.emitByte(Jump, node.token.line) - self.emitBytes(blocks[^1].start.toTriple(), node.token.line) + if compile: + self.emitByte(Jump, node.token.line) + self.emitBytes(blocks[^1].start.toTriple(), node.token.line) -proc importStmt(self: Compiler, node: ImportStmt) = +proc importStmt(self: Compiler, node: ImportStmt, compile: bool = true) = ## Imports a module at compile time self.declare(node) var module = self.names[^1] try: - self.compileModule(module) - # Importing a module automatically exports - # its public names to us - for name in self.findInModule("", module): - name.exportedTo.incl(self.currentModule) + if compile: + self.compileModule(module) + # Importing a module automatically exports + # its public names to us + for name in self.findInModule("", module): + name.exportedTo.incl(self.currentModule) except IOError: self.error(&"could not import '{module.ident.token.lexeme}': {getCurrentExceptionMsg()}") except OSError: self.error(&"could not import '{module.ident.token.lexeme}': {getCurrentExceptionMsg()} [errno {osLastError()}]") -proc exportStmt(self: Compiler, node: ExportStmt) = +proc exportStmt(self: Compiler, node: ExportStmt, compile: bool = true) = ## Exports a name at compile time to ## all modules importing us var name = self.resolveOrError(node.name) @@ -2484,10 +2568,12 @@ proc exportStmt(self: Compiler, node: ExportStmt) = discard -proc breakStmt(self: Compiler, node: BreakStmt) = +proc breakStmt(self: Compiler, node: BreakStmt, compile: bool = true) = ## Compiles break statements. A break statement ## jumps to the end of the loop if node.label.isNil(): + if not compile: + return self.currentLoop.breakJumps.add(self.emitJump(OpCode.JumpForwards, node.token.line)) if self.currentLoop.depth > self.depth: # Breaking out of a loop closes its scope @@ -2506,7 +2592,7 @@ proc breakStmt(self: Compiler, node: BreakStmt) = self.error(&"unknown block name '{node.label.token.lexeme}'", node.label) -proc namedBlock(self: Compiler, node: NamedBlockStmt) = +proc namedBlock(self: Compiler, node: NamedBlockStmt, compile: bool = true) = ## Compiles named blocks self.beginScope() var blk = self.namedBlocks[^1] @@ -2518,79 +2604,80 @@ proc namedBlock(self: Compiler, node: NamedBlockStmt) = self.warning(UnreachableCode, &"code after '{last.token.lexeme}' statement is unreachable", nil, last) else: discard - if blk.broken: + if blk.broken and compile: blk.breakJumps.add(self.emitJump(OpCode.JumpForwards, node.token.line)) - self.declaration(decl) + self.declaration(decl, compile) last = decl - self.patchBreaks() + if compile: + self.patchBreaks() self.endScope() -proc statement(self: Compiler, node: Statement) = +proc statement(self: Compiler, node: Statement, compile: bool = true) = ## Compiles all statements case node.kind: of exprStmt: let expression = ExprStmt(node).expression let kind = self.infer(expression) - self.expression(expression) + self.expression(expression, compile) if kind.isNil(): # The expression has no type and produces no value, # so we don't have to pop anything discard - elif self.replMode: + elif self.replMode and compile: self.printRepl(kind, expression) - else: + elif compile: self.emitByte(Pop, node.token.line) of NodeKind.namedBlockStmt: self.namedBlocks.add(NamedBlock(start: self.chunk.code.len(), depth: self.depth, breakJumps: @[], name: NamedBlockStmt(node).name.token.lexeme)) - self.namedBlock(NamedBlockStmt(node)) - #self.patchBreaks() + self.namedBlock(NamedBlockStmt(node), compile) discard self.namedBlocks.pop() of NodeKind.ifStmt: - self.ifStmt(IfStmt(node)) + self.ifStmt(IfStmt(node), compile) of NodeKind.assertStmt: - self.assertStmt(AssertStmt(node)) + self.assertStmt(AssertStmt(node), compile) of NodeKind.raiseStmt: - self.raiseStmt(RaiseStmt(node)) + self.raiseStmt(RaiseStmt(node), compile) of NodeKind.breakStmt: - self.breakStmt(BreakStmt(node)) + self.breakStmt(BreakStmt(node), compile) of NodeKind.continueStmt: - self.continueStmt(ContinueStmt(node)) + self.continueStmt(ContinueStmt(node), compile) of NodeKind.returnStmt: - self.returnStmt(ReturnStmt(node)) + self.returnStmt(ReturnStmt(node), compile) of NodeKind.importStmt: - self.importStmt(ImportStmt(node)) + self.importStmt(ImportStmt(node), compile) of NodeKind.exportStmt: - self.exportStmt(ExportStmt(node)) + self.exportStmt(ExportStmt(node), compile) of NodeKind.whileStmt: # Note: Our parser already desugars # for loops to while loops let loop = self.currentLoop self.currentLoop = Loop(start: self.chunk.code.len(), depth: self.depth, breakJumps: @[]) - self.whileStmt(WhileStmt(node)) - self.patchBreaks() + self.whileStmt(WhileStmt(node), compile) + if compile: + self.patchBreaks() self.currentLoop = loop of NodeKind.forEachStmt: - self.forEachStmt(ForEachStmt(node)) + self.forEachStmt(ForEachStmt(node), compile) of NodeKind.blockStmt: - self.blockStmt(BlockStmt(node)) + self.blockStmt(BlockStmt(node), compile) of NodeKind.yieldStmt: - self.yieldStmt(YieldStmt(node)) + self.yieldStmt(YieldStmt(node), compile) of NodeKind.awaitStmt: - self.awaitStmt(AwaitStmt(node)) + self.awaitStmt(AwaitStmt(node), compile) of NodeKind.deferStmt: - self.deferStmt(DeferStmt(node)) + self.deferStmt(DeferStmt(node), compile) of NodeKind.tryStmt: discard else: self.expression(Expression(node)) -proc varDecl(self: Compiler, node: VarDecl) = +proc varDecl(self: Compiler, node: VarDecl, compile: bool = true) = ## Compiles variable declarations # Our parser guarantees that the variable declaration @@ -2600,6 +2687,8 @@ proc varDecl(self: Compiler, node: VarDecl) = # Variable has no value: the type declaration # takes over typ = self.inferOrError(node.valueType) + if typ.kind == Auto: + self.error("automatic types require initialization", node) elif node.valueType.isNil: # Variable has no type declaration: the type # of its value takes over @@ -2609,9 +2698,12 @@ proc varDecl(self: Compiler, node: VarDecl) = # a value: the value's type must match the # type declaration let expected = self.inferOrError(node.valueType) - self.check(node.value, expected) - # If this doesn't fail, then we're good - typ = expected + if expected.kind != Auto: + self.check(node.value, expected) + # If this doesn't fail, then we're good + typ = expected + else: + typ = self.infer(node.value) self.expression(node.value) self.emitByte(AddVar, node.token.line) self.declare(node) @@ -2621,7 +2713,7 @@ proc varDecl(self: Compiler, node: VarDecl) = name.valueType = typ -proc funDecl(self: Compiler, node: FunDecl, name: Name) = +proc funDecl(self: Compiler, node: FunDecl, name: Name, compile: bool = true) = ## Compiles function declarations if node.token.kind == Operator and node.name.token.lexeme in [".", "="]: self.error(&"Due to compiler limitations, the '{node.name.token.lexeme}' operator cannot be currently overridden", node.name) @@ -2724,41 +2816,42 @@ proc funDecl(self: Compiler, node: FunDecl, name: Name) = self.stackIndex = stackIdx -proc declaration(self: Compiler, node: Declaration) = +proc declaration(self: Compiler, node: Declaration, compile: bool = true) = ## Compiles declarations, statements and expressions ## recursively case node.kind: of NodeKind.funDecl: var name = self.declare(node) - self.funDecl(FunDecl(node), name) - if name.isGeneric: - # After we're done compiling a generic - # function, we pull a magic trick: since, - # from here on, the user will be able to - # call this with any of the types in the - # generic constraint, we switch every generic - # to a type union (which, conveniently, have an - # identical layout) so that the compiler will - # typecheck the function as if its arguments - # were all types of the constraint at once, - # while still allowing the user to call it with - # any type in said constraint - for i, argument in name.valueType.args: - if argument.kind.kind != Generic: - continue - else: - argument.kind.asUnion = true - if not name.valueType.returnType.isNil() and name.valueType.returnType.kind == Generic: - name.valueType.returnType.asUnion = true + if not name.valueType.isAuto: + self.funDecl(FunDecl(node), name, compile=compile) + if name.isGeneric: + # After we're done compiling a generic + # function, we pull a magic trick: since, + # from here on, the user will be able to + # call this with any of the types in the + # generic constraint, we switch every generic + # to a type union (which, conveniently, have an + # identical layout) so that the compiler will + # typecheck the function as if its arguments + # were all types of the constraint at once, + # while still allowing the user to call it with + # any type in said constraint + for i, argument in name.valueType.args: + if argument.kind.kind != Generic: + continue + else: + argument.kind.asUnion = true + if not name.valueType.returnType.isNil() and name.valueType.returnType.kind == Generic: + name.valueType.returnType.asUnion = true of NodeKind.typeDecl: self.declare(node) of NodeKind.varDecl: # We compile this immediately because we # need to keep the stack in the right state # at runtime - self.varDecl(VarDecl(node)) + self.varDecl(VarDecl(node), compile) else: - self.statement(Statement(node)) + self.statement(Statement(node), compile) proc compile*(self: Compiler, ast: seq[Declaration], file: string, lines: seq[tuple[start, stop: int]], source: string, chunk: Chunk = nil, diff --git a/src/peon/stdlib/builtins/arithmetics.pn b/src/peon/stdlib/builtins/arithmetics.pn index 9619ce8..0a1d9ed 100644 --- a/src/peon/stdlib/builtins/arithmetics.pn +++ b/src/peon/stdlib/builtins/arithmetics.pn @@ -13,12 +13,12 @@ operator `+`*[T: Integer](a, b: T): T { } -operator `+`(a, b: float): float { +operator `+`*(a, b: float): float { #pragma[magic: "AddFloat64", pure] } -operator `+`(a, b: float32): float32 { +operator `+`*(a, b: float32): float32 { #pragma[magic: "AddFloat32", pure] } diff --git a/src/peon/stdlib/builtins/comparisons.pn b/src/peon/stdlib/builtins/comparisons.pn index dca4fd7..505df0b 100644 --- a/src/peon/stdlib/builtins/comparisons.pn +++ b/src/peon/stdlib/builtins/comparisons.pn @@ -12,11 +12,12 @@ operator `<`*[T: UnsignedInteger](a, b: T): bool { } -operator `==`*[T: Number](a, b: T): bool { +operator `==`*[T: Number | inf](a, b: T): bool { #pragma[magic: "Equal", pure] } -operator `!=`*[T: Number](a, b: T): bool { + +operator `!=`*[T: Number | inf](a, b: T): bool { #pragma[magic: "NotEqual", pure] } @@ -90,3 +91,4 @@ operator `<=`*(a, b: float32): bool { #pragma[magic: "Float32LessOrEqual", pure] } + diff --git a/src/peon/stdlib/builtins/values.pn b/src/peon/stdlib/builtins/values.pn index a267d18..6dbc0f0 100644 --- a/src/peon/stdlib/builtins/values.pn +++ b/src/peon/stdlib/builtins/values.pn @@ -65,6 +65,10 @@ type nan* = object { #pragma[magic: "nan"] } +type auto* = object { + #pragma[magic: "auto"] +} + # Some convenience aliases type int* = int64; type float* = float64; diff --git a/tests/auto.pn b/tests/auto.pn new file mode 100644 index 0000000..36e7e7b --- /dev/null +++ b/tests/auto.pn @@ -0,0 +1,14 @@ +# A test for automatic types +import std; + + +fn sum(a, b: auto): auto { + return a + b; +} + +var x: auto = 1; +print(x == 1); +print(sum(1, 2) == 3); +print(sum(1'i32, 2'i32) == 3'i32); +print(sum(1.0, 2.0) == 3.0); +#print(sum(1'i32, 2'i16)); # Will fail to compile