Added support for automatic types

This commit is contained in:
Mattia Giambirtone 2022-12-05 12:06:24 +01:00
parent c0f358e956
commit 572443a988
5 changed files with 233 additions and 120 deletions

View File

@ -46,7 +46,7 @@ type
UInt32, Int64, UInt64, Float32, Float64, UInt32, Int64, UInt64, Float32, Float64,
Char, Byte, String, Function, CustomType, Char, Byte, String, Function, CustomType,
Nil, Nan, Bool, Inf, Typevar, Generic, Nil, Nan, Bool, Inf, Typevar, Generic,
Reference, Pointer, Any, All, Union Reference, Pointer, Any, All, Union, Auto
Type = ref object Type = ref object
## A wrapper around ## A wrapper around
## compile-time types ## compile-time types
@ -55,6 +55,7 @@ type
isLambda: bool isLambda: bool
isGenerator: bool isGenerator: bool
isCoroutine: bool isCoroutine: bool
isAuto: bool
args: seq[tuple[name: string, kind: Type, default: Expression]] args: seq[tuple[name: string, kind: Type, default: Expression]]
returnType: Type returnType: Type
builtinOp: string builtinOp: string
@ -278,11 +279,11 @@ proc compile*(self: Compiler, ast: seq[Declaration], file: string, lines: seq[tu
incremental: bool = false, isMainModule: bool = true, disabledWarnings: seq[WarningKind] = @[], showMismatches: bool = false, incremental: bool = false, isMainModule: bool = true, disabledWarnings: seq[WarningKind] = @[], showMismatches: bool = false,
mode: CompileMode = Debug): Chunk mode: CompileMode = Debug): Chunk
proc expression(self: Compiler, node: Expression, compile: bool = true): Type {.discardable.} proc expression(self: Compiler, node: Expression, compile: bool = true): Type {.discardable.}
proc statement(self: Compiler, node: Statement) proc statement(self: Compiler, node: Statement, compile: bool = true)
proc declaration(self: Compiler, node: Declaration) proc declaration(self: Compiler, node: Declaration, compile: bool = true)
proc peek(self: Compiler, distance: int = 0): ASTNode proc peek(self: Compiler, distance: int = 0): ASTNode
proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool = true): Type {.discardable.} proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool = true, strict: bool = true): Type {.discardable.}
proc varDecl(self: Compiler, node: VarDecl) proc varDecl(self: Compiler, node: VarDecl, compile: bool = true)
proc match(self: Compiler, name: string, kind: Type, node: ASTNode = nil, allowFwd: bool = true): Name proc match(self: Compiler, name: string, kind: Type, node: ASTNode = nil, allowFwd: bool = true): Name
proc specialize(self: Compiler, typ: Type, args: seq[Expression]): Type {.discardable.} proc specialize(self: Compiler, typ: Type, args: seq[Expression]): Type {.discardable.}
proc call(self: Compiler, node: CallExpr, compile: bool = true): Type {.discardable.} proc call(self: Compiler, node: CallExpr, compile: bool = true): Type {.discardable.}
@ -302,7 +303,7 @@ proc handlePurePragma(self: Compiler, pragma: Pragma, name: Name)
proc handleErrorPragma(self: Compiler, pragma: Pragma, name: Name) proc handleErrorPragma(self: Compiler, pragma: Pragma, name: Name)
proc dispatchPragmas(self: Compiler, name: Name) proc dispatchPragmas(self: Compiler, name: Name)
proc dispatchDelayedPragmas(self: Compiler, name: Name) proc dispatchDelayedPragmas(self: Compiler, name: Name)
proc funDecl(self: Compiler, node: FunDecl, name: Name) proc funDecl(self: Compiler, node: FunDecl, name: Name, compile: bool = true)
proc compileModule(self: Compiler, module: Name) proc compileModule(self: Compiler, module: Name)
proc generateCall(self: Compiler, fn: Name, args: seq[Expression], line: int) proc generateCall(self: Compiler, fn: Name, args: seq[Expression], line: int)
proc prepareFunction(self: Compiler, fn: Name) proc prepareFunction(self: Compiler, fn: Name)
@ -703,7 +704,7 @@ proc resolve(self: Compiler, name: string): Name =
## resolved, it is also compiled on-the-fly ## resolved, it is also compiled on-the-fly
for obj in reversed(self.names): for obj in reversed(self.names):
if obj.ident.token.lexeme == name: if obj.ident.token.lexeme == name:
if obj.owner != self.currentModule: if obj.owner.path != self.currentModule.path:
# We don't own this name, but we # We don't own this name, but we
# may still have access to it # may still have access to it
if obj.isPrivate: if obj.isPrivate:
@ -725,8 +726,6 @@ proc resolve(self: Compiler, name: string): Name =
# might not want to also have access to C's and D's # might not want to also have access to C's and D's
# names as they might clash with its own stuff) # names as they might clash with its own stuff)
continue continue
if obj.kind == Argument and obj.belongsTo != self.currentFunction:
continue
result = obj result = obj
result.resolved = true result.resolved = true
break break
@ -888,6 +887,8 @@ proc toIntrinsic(name: string): Type =
## otherwise ## otherwise
if name == "any": if name == "any":
return Type(kind: Any) return Type(kind: Any)
elif name == "auto":
return Type(kind: Auto)
elif name in ["int", "int64", "i64"]: elif name in ["int", "int64", "i64"]:
return Type(kind: Int64) return Type(kind: Int64)
elif name in ["uint64", "u64", "uint"]: elif name in ["uint64", "u64", "uint"]:
@ -966,7 +967,7 @@ proc infer(self: Compiler, node: Expression): Type =
return nil return nil
case node.kind: case node.kind:
of NodeKind.identExpr: of NodeKind.identExpr:
result = self.identifier(IdentExpr(node), compile=false) result = self.identifier(IdentExpr(node), compile=false, strict=false)
of NodeKind.unaryExpr: of NodeKind.unaryExpr:
result = self.unary(UnaryExpr(node), compile=false) result = self.unary(UnaryExpr(node), compile=false)
of NodeKind.binaryExpr: of NodeKind.binaryExpr:
@ -1009,7 +1010,7 @@ proc stringify(self: Compiler, typ: Type): string =
of Int8, UInt8, Int16, UInt16, Int32, of Int8, UInt8, Int16, UInt16, Int32,
UInt32, Int64, UInt64, Float32, Float64, UInt32, Int64, UInt64, Float32, Float64,
Char, Byte, String, Nil, TypeKind.Nan, Bool, Char, Byte, String, Nil, TypeKind.Nan, Bool,
TypeKind.Inf: TypeKind.Inf, Auto:
result &= ($typ.kind).toLowerAscii() result &= ($typ.kind).toLowerAscii()
of Pointer: of Pointer:
result &= &"ptr {self.stringify(typ.value)}" result &= &"ptr {self.stringify(typ.value)}"
@ -1518,7 +1519,8 @@ proc declare(self: Compiler, node: ASTNode): Name {.discardable.} =
returnType: nil, # We check it later returnType: nil, # We check it later
args: @[], args: @[],
fun: node, fun: node,
forwarded: node.body.isNil()), forwarded: node.body.isNil(),
isAuto: false),
ident: node.name, ident: node.name,
node: node, node: node,
isLet: false, isLet: false,
@ -1526,10 +1528,24 @@ proc declare(self: Compiler, node: ASTNode): Name {.discardable.} =
kind: NameKind.Function, kind: NameKind.Function,
belongsTo: self.currentFunction, belongsTo: self.currentFunction,
isReal: true) isReal: true)
self.names.add(fn)
n = fn
if node.generics.len() > 0: if node.generics.len() > 0:
fn.isGeneric = true fn.isGeneric = true
var typ: Type
for argument in node.arguments:
typ = self.infer(argument.valueType)
if not typ.isNil() and typ.kind == Auto:
fn.valueType.isAuto = true
if fn.isGeneric:
self.error("automatic types cannot be used within generics", argument.valueType)
break
typ = self.infer(node.returnType)
if not typ.isNil() and typ.kind == Auto:
fn.valueType.isAuto = true
if fn.isGeneric:
self.error("automatic types cannot be used within generics", node.returnType)
self.names.add(fn)
self.prepareFunction(fn)
n = fn
of NodeKind.importStmt: of NodeKind.importStmt:
var node = ImportStmt(node) var node = ImportStmt(node)
# We change the name of the module internally so that # We change the name of the module internally so that
@ -1583,11 +1599,6 @@ proc declare(self: Compiler, node: ASTNode): Name {.discardable.} =
discard # TODO: enums discard # TODO: enums
if not n.isNil(): if not n.isNil():
self.dispatchPragmas(n) self.dispatchPragmas(n)
case n.kind:
of NameKind.Function:
self.prepareFunction(n)
else:
discard
for name in self.findByName(declaredName): for name in self.findByName(declaredName):
if name == n: if name == n:
continue continue
@ -1914,16 +1925,24 @@ proc binary(self: Compiler, node: BinaryExpr, compile: bool = true): Type {.disc
var default: Expression var default: Expression
let fn = Type(kind: Function, returnType: Type(kind: Any), args: @[("", self.inferOrError(node.a), default), ("", self.inferOrError(node.b), default)]) let fn = Type(kind: Function, returnType: Type(kind: Any), args: @[("", self.inferOrError(node.a), default), ("", self.inferOrError(node.b), default)])
let impl = self.match(node.token.lexeme, fn, node) let impl = self.match(node.token.lexeme, fn, node)
result = impl.valueType.returnType result = impl.valueType
if impl.isGeneric:
result = self.specialize(result, @[node.a, node.b])
result = result.returnType
if compile: if compile:
self.generateCall(impl, @[node.a, node.b], impl.line) self.generateCall(impl, @[node.a, node.b], impl.line)
proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool = true): Type {.discardable.} = proc identifier(self: Compiler, node: IdentExpr, name: Name = nil, compile: bool = true, strict: bool = true): Type {.discardable.} =
## Compiles access to identifiers ## Compiles access to identifiers
var s = name var s = name
if s.isNil(): if s.isNil():
s = self.resolveOrError(node) if strict:
s = self.resolveOrError(node)
else:
s = self.resolve(node)
if s.isNil() and not strict:
return nil
result = s.valueType result = s.valueType
if not compile: if not compile:
return return
@ -1991,7 +2010,7 @@ proc assignment(self: Compiler, node: ASTNode, compile: bool = true): Type {.dis
self.error(&"invalid AST node of kind {node.kind} at assignment(): {node} (This is an internal error and most likely a bug)") self.error(&"invalid AST node of kind {node.kind} at assignment(): {node} (This is an internal error and most likely a bug)")
proc blockStmt(self: Compiler, node: BlockStmt) = proc blockStmt(self: Compiler, node: BlockStmt, compile: bool = true) =
## Compiles block statements, which create ## Compiles block statements, which create
## a new local scope ## a new local scope
self.beginScope() self.beginScope()
@ -2003,35 +2022,43 @@ proc blockStmt(self: Compiler, node: BlockStmt) =
self.warning(UnreachableCode, &"code after '{last.token.lexeme}' statement is unreachable", nil, last) self.warning(UnreachableCode, &"code after '{last.token.lexeme}' statement is unreachable", nil, last)
else: else:
discard discard
self.declaration(decl) self.declaration(decl, compile)
last = decl last = decl
self.endScope() self.endScope()
proc ifStmt(self: Compiler, node: IfStmt) = proc ifStmt(self: Compiler, node: IfStmt, compile: bool = true) =
## Compiles if/else statements for conditional ## Compiles if/else statements for conditional
## execution of code ## execution of code
self.check(node.condition, Type(kind: Bool)) self.check(node.condition, Type(kind: Bool))
self.expression(node.condition) self.expression(node.condition)
let jump = self.emitJump(JumpIfFalsePop, node.token.line) var jump: int
self.statement(node.thenBranch) var jump2: int
let jump2 = self.emitJump(JumpForwards, node.token.line) if compile:
self.patchJump(jump) jump = self.emitJump(JumpIfFalsePop, node.token.line)
self.statement(node.thenBranch, compile)
if compile:
jump2 = self.emitJump(JumpForwards, node.token.line)
self.patchJump(jump)
if not node.elseBranch.isNil(): if not node.elseBranch.isNil():
self.statement(node.elseBranch) self.statement(node.elseBranch, compile)
self.patchJump(jump2) if compile:
self.patchJump(jump2)
proc whileStmt(self: Compiler, node: WhileStmt) = proc whileStmt(self: Compiler, node: WhileStmt, compile: bool = true) =
## Compiles C-style while loops and ## Compiles C-style while loops and
## desugared C-style for loops ## desugared C-style for loops
self.check(node.condition, Type(kind: Bool)) self.check(node.condition, Type(kind: Bool))
let start = self.chunk.code.high() let start = self.chunk.code.high()
self.expression(node.condition) self.expression(node.condition)
let jump = self.emitJump(JumpIfFalsePop, node.token.line) var jump: int
self.statement(node.body) if compile:
self.emitLoop(start, node.token.line) jump = self.emitJump(JumpIfFalsePop, node.token.line)
self.patchJump(jump) self.statement(node.body, compile)
if compile:
self.emitLoop(start, node.token.line)
self.patchJump(jump)
proc generateCall(self: Compiler, fn: Type, args: seq[Expression], line: int) {.used.} = proc generateCall(self: Compiler, fn: Type, args: seq[Expression], line: int) {.used.} =
@ -2082,8 +2109,8 @@ proc prepareFunction(self: Compiler, fn: Name) =
let idx = self.stackIndex let idx = self.stackIndex
self.stackIndex = 1 self.stackIndex = 1
var default: Expression var default: Expression
var i = 0
var node = FunDecl(fn.node) var node = FunDecl(fn.node)
var i = 0
for argument in node.arguments: for argument in node.arguments:
if self.names.high() > 16777215: if self.names.high() > 16777215:
self.error("cannot declare more than 16777215 variables at a time") self.error("cannot declare more than 16777215 variables at a time")
@ -2094,7 +2121,7 @@ proc prepareFunction(self: Compiler, fn: Name) =
file: fn.file, file: fn.file,
isConst: false, isConst: false,
ident: argument.name, ident: argument.name,
valueType: self.inferOrError(argument.valueType), valueType: if not fn.valueType.isAuto: self.inferOrError(argument.valueType) else: Type(kind: Any),
codePos: 0, codePos: 0,
isLet: false, isLet: false,
line: argument.name.token.line, line: argument.name.token.line,
@ -2117,6 +2144,47 @@ proc prepareFunction(self: Compiler, fn: Name) =
self.stackIndex = idx self.stackIndex = idx
proc prepareAutoFunction(self: Compiler, fn: Name, args: seq[tuple[name: string, kind: Type, default: Expression]]): Name =
## "Prepares" an automatic function declaration
## by declaring a concrete version of it along
## with its arguments
# First we declare the function's generics, if it has any.
# This is because the function's return type may in itself
# be a generic, so it needs to exist first
# We now declare and typecheck the function's
# arguments
let idx = self.stackIndex
self.stackIndex = 1
var default: Expression
var node = FunDecl(fn.node)
var fn = deepCopy(fn)
self.names.add(fn)
for (argument, val) in zip(node.arguments, args):
if self.names.high() > 16777215:
self.error("cannot declare more than 16777215 variables at a time")
inc(self.stackIndex)
self.names.add(Name(depth: fn.depth + 1,
isPrivate: true,
owner: fn.owner,
file: fn.file,
isConst: false,
ident: argument.name,
valueType: val.kind,
codePos: 0,
isLet: false,
line: argument.name.token.line,
belongsTo: fn,
kind: NameKind.Argument,
node: argument.name,
position: self.stackIndex
))
fn.valueType.args = args
fn.position = self.stackIndex
self.stackIndex = idx
return fn
proc generateCall(self: Compiler, fn: Name, args: seq[Expression], line: int) = proc generateCall(self: Compiler, fn: Name, args: seq[Expression], line: int) =
## Small wrapper that abstracts emitting a call instruction ## Small wrapper that abstracts emitting a call instruction
## for a given function ## for a given function
@ -2163,7 +2231,7 @@ proc specialize(self: Compiler, typ: Type, args: seq[Expression]): Type {.discar
continue continue
kind = self.inferOrError(args[i]) kind = self.inferOrError(args[i])
if typ.name in mapping and not self.compare(kind, mapping[typ.name]): if typ.name in mapping and not self.compare(kind, mapping[typ.name]):
self.error(&"expecting generic argument '{typ.name}' to be of type {self.stringify(mapping[typ.name])}, got {self.stringify(kind)}") self.error(&"expecting generic argument '{typ.name}' to be of type {self.stringify(mapping[typ.name])}, got {self.stringify(kind)}", args[i])
mapping[typ.name] = kind mapping[typ.name] = kind
result.args[i].kind = kind result.args[i].kind = kind
if not result.returnType.isNil() and result.returnType.kind == Generic: if not result.returnType.isNil() and result.returnType.kind == Generic:
@ -2200,10 +2268,14 @@ proc call(self: Compiler, node: CallExpr, compile: bool = true): Type {.discarda
case node.callee.kind: case node.callee.kind:
of NodeKind.identExpr: of NodeKind.identExpr:
# Calls like hi() # Calls like hi()
let impl = self.match(IdentExpr(node.callee).name.lexeme, Type(kind: Function, returnType: Type(kind: All), args: args), node) var impl = self.match(IdentExpr(node.callee).name.lexeme, Type(kind: Function, returnType: Type(kind: All), args: args), node)
result = impl.valueType result = impl.valueType
if impl.isGeneric: if impl.isGeneric:
result = self.specialize(result, argExpr) result = self.specialize(result, argExpr)
if impl.valueType.isAuto:
impl = self.prepareAutoFunction(impl, args)
result = impl.valueType
self.funDecl(FunDecl(result.fun), impl, compile=compile)
result = result.returnType result = result.returnType
if compile: if compile:
# Now we call it # Now we call it
@ -2371,45 +2443,53 @@ proc expression(self: Compiler, node: Expression, compile: bool = true): Type {.
# TODO # TODO
proc awaitStmt(self: Compiler, node: AwaitStmt) = proc awaitStmt(self: Compiler, node: AwaitStmt, compile: bool = true) =
## Compiles await statements ## Compiles await statements
# TODO # TODO
proc deferStmt(self: Compiler, node: DeferStmt) = proc deferStmt(self: Compiler, node: DeferStmt, compile: bool = true) =
## Compiles defer statements ## Compiles defer statements
# TODO # TODO
proc yieldStmt(self: Compiler, node: YieldStmt) = proc yieldStmt(self: Compiler, node: YieldStmt, compile: bool = true) =
## Compiles yield statements ## Compiles yield statements
# TODO # TODO
proc raiseStmt(self: Compiler, node: RaiseStmt) = proc raiseStmt(self: Compiler, node: RaiseStmt, compile: bool = true) =
## Compiles raise statements ## Compiles raise statements
# TODO # TODO
proc assertStmt(self: Compiler, node: AssertStmt) = proc assertStmt(self: Compiler, node: AssertStmt, compile: bool = true) =
## Compiles assert statements ## Compiles assert statements
# TODO # TODO
# TODO # TODO
proc forEachStmt(self: Compiler, node: ForEachStmt) = proc forEachStmt(self: Compiler, node: ForEachStmt, compile: bool = true) =
## Compiles foreach loops ## Compiles foreach loops
proc returnStmt(self: Compiler, node: ReturnStmt) = proc returnStmt(self: Compiler, node: ReturnStmt, compile: bool = true) =
## Compiles return statements ## Compiles return statements
if self.currentFunction.valueType.returnType.isNil() and not node.value.isNil(): if self.currentFunction.valueType.returnType.isNil() and not node.value.isNil():
self.error("cannot return a value from a void function", node.value) self.error("cannot return a value from a void function", node.value)
elif not self.currentFunction.valueType.returnType.isNil() and node.value.isNil(): elif not self.currentFunction.valueType.returnType.isNil() and node.value.isNil():
self.error("bare return statement is only allowed in void functions", node) self.error("bare return statement is only allowed in void functions", node)
if not node.value.isNil(): if not node.value.isNil():
self.expression(node.value) if not self.currentFunction.valueType.isAuto:
self.emitByte(OpCode.SetResult, node.token.line) self.check(node.value, self.currentFunction.valueType.returnType)
else:
if self.currentFunction.valueType.returnType.kind == Auto:
self.currentFunction.valueType.returnType = self.inferOrError(node.value)
else:
self.check(node.value, self.currentFunction.valueType.returnType)
self.expression(node.value, compile)
if compile:
self.emitByte(OpCode.SetResult, node.token.line)
# Since the "set result" part and "exit the function" part # Since the "set result" part and "exit the function" part
# of our return mechanism are already decoupled into two # of our return mechanism are already decoupled into two
# separate opcodes, we perform the former and then jump to # separate opcodes, we perform the former and then jump to
@ -2421,17 +2501,19 @@ proc returnStmt(self: Compiler, node: ReturnStmt) =
# the function has any local variables or not, this jump might be # the function has any local variables or not, this jump might be
# patched to jump to the function's PopN/PopC instruction(s) rather # patched to jump to the function's PopN/PopC instruction(s) rather
# than straight to the return statement # than straight to the return statement
self.currentFunction.valueType.retJumps.add(self.emitJump(JumpForwards, node.token.line)) if compile:
self.currentFunction.valueType.retJumps.add(self.emitJump(JumpForwards, node.token.line))
proc continueStmt(self: Compiler, node: ContinueStmt) = proc continueStmt(self: Compiler, node: ContinueStmt, compile: bool = true) =
## Compiles continue statements. A continue statement ## Compiles continue statements. A continue statement
## jumps to the next iteration in a loop ## jumps to the next iteration in a loop
if node.label.isNil(): if node.label.isNil():
if self.currentLoop.start > 16777215: if self.currentLoop.start > 16777215:
self.error("too much code to jump over in continue statement") self.error("too much code to jump over in continue statement")
self.emitByte(Jump, node.token.line) if compile:
self.emitBytes(self.currentLoop.start.toTriple(), node.token.line) self.emitByte(Jump, node.token.line)
self.emitBytes(self.currentLoop.start.toTriple(), node.token.line)
else: else:
var blocks: seq[NamedBlock] = @[] var blocks: seq[NamedBlock] = @[]
var found: bool = false var found: bool = false
@ -2442,27 +2524,29 @@ proc continueStmt(self: Compiler, node: ContinueStmt) =
break break
if not found: if not found:
self.error(&"unknown block name '{node.label.token.lexeme}'", node.label) self.error(&"unknown block name '{node.label.token.lexeme}'", node.label)
self.emitByte(Jump, node.token.line) if compile:
self.emitBytes(blocks[^1].start.toTriple(), node.token.line) self.emitByte(Jump, node.token.line)
self.emitBytes(blocks[^1].start.toTriple(), node.token.line)
proc importStmt(self: Compiler, node: ImportStmt) = proc importStmt(self: Compiler, node: ImportStmt, compile: bool = true) =
## Imports a module at compile time ## Imports a module at compile time
self.declare(node) self.declare(node)
var module = self.names[^1] var module = self.names[^1]
try: try:
self.compileModule(module) if compile:
# Importing a module automatically exports self.compileModule(module)
# its public names to us # Importing a module automatically exports
for name in self.findInModule("", module): # its public names to us
name.exportedTo.incl(self.currentModule) for name in self.findInModule("", module):
name.exportedTo.incl(self.currentModule)
except IOError: except IOError:
self.error(&"could not import '{module.ident.token.lexeme}': {getCurrentExceptionMsg()}") self.error(&"could not import '{module.ident.token.lexeme}': {getCurrentExceptionMsg()}")
except OSError: except OSError:
self.error(&"could not import '{module.ident.token.lexeme}': {getCurrentExceptionMsg()} [errno {osLastError()}]") self.error(&"could not import '{module.ident.token.lexeme}': {getCurrentExceptionMsg()} [errno {osLastError()}]")
proc exportStmt(self: Compiler, node: ExportStmt) = proc exportStmt(self: Compiler, node: ExportStmt, compile: bool = true) =
## Exports a name at compile time to ## Exports a name at compile time to
## all modules importing us ## all modules importing us
var name = self.resolveOrError(node.name) var name = self.resolveOrError(node.name)
@ -2484,10 +2568,12 @@ proc exportStmt(self: Compiler, node: ExportStmt) =
discard discard
proc breakStmt(self: Compiler, node: BreakStmt) = proc breakStmt(self: Compiler, node: BreakStmt, compile: bool = true) =
## Compiles break statements. A break statement ## Compiles break statements. A break statement
## jumps to the end of the loop ## jumps to the end of the loop
if node.label.isNil(): if node.label.isNil():
if not compile:
return
self.currentLoop.breakJumps.add(self.emitJump(OpCode.JumpForwards, node.token.line)) self.currentLoop.breakJumps.add(self.emitJump(OpCode.JumpForwards, node.token.line))
if self.currentLoop.depth > self.depth: if self.currentLoop.depth > self.depth:
# Breaking out of a loop closes its scope # Breaking out of a loop closes its scope
@ -2506,7 +2592,7 @@ proc breakStmt(self: Compiler, node: BreakStmt) =
self.error(&"unknown block name '{node.label.token.lexeme}'", node.label) self.error(&"unknown block name '{node.label.token.lexeme}'", node.label)
proc namedBlock(self: Compiler, node: NamedBlockStmt) = proc namedBlock(self: Compiler, node: NamedBlockStmt, compile: bool = true) =
## Compiles named blocks ## Compiles named blocks
self.beginScope() self.beginScope()
var blk = self.namedBlocks[^1] var blk = self.namedBlocks[^1]
@ -2518,79 +2604,80 @@ proc namedBlock(self: Compiler, node: NamedBlockStmt) =
self.warning(UnreachableCode, &"code after '{last.token.lexeme}' statement is unreachable", nil, last) self.warning(UnreachableCode, &"code after '{last.token.lexeme}' statement is unreachable", nil, last)
else: else:
discard discard
if blk.broken: if blk.broken and compile:
blk.breakJumps.add(self.emitJump(OpCode.JumpForwards, node.token.line)) blk.breakJumps.add(self.emitJump(OpCode.JumpForwards, node.token.line))
self.declaration(decl) self.declaration(decl, compile)
last = decl last = decl
self.patchBreaks() if compile:
self.patchBreaks()
self.endScope() self.endScope()
proc statement(self: Compiler, node: Statement) = proc statement(self: Compiler, node: Statement, compile: bool = true) =
## Compiles all statements ## Compiles all statements
case node.kind: case node.kind:
of exprStmt: of exprStmt:
let expression = ExprStmt(node).expression let expression = ExprStmt(node).expression
let kind = self.infer(expression) let kind = self.infer(expression)
self.expression(expression) self.expression(expression, compile)
if kind.isNil(): if kind.isNil():
# The expression has no type and produces no value, # The expression has no type and produces no value,
# so we don't have to pop anything # so we don't have to pop anything
discard discard
elif self.replMode: elif self.replMode and compile:
self.printRepl(kind, expression) self.printRepl(kind, expression)
else: elif compile:
self.emitByte(Pop, node.token.line) self.emitByte(Pop, node.token.line)
of NodeKind.namedBlockStmt: of NodeKind.namedBlockStmt:
self.namedBlocks.add(NamedBlock(start: self.chunk.code.len(), self.namedBlocks.add(NamedBlock(start: self.chunk.code.len(),
depth: self.depth, depth: self.depth,
breakJumps: @[], breakJumps: @[],
name: NamedBlockStmt(node).name.token.lexeme)) name: NamedBlockStmt(node).name.token.lexeme))
self.namedBlock(NamedBlockStmt(node)) self.namedBlock(NamedBlockStmt(node), compile)
#self.patchBreaks()
discard self.namedBlocks.pop() discard self.namedBlocks.pop()
of NodeKind.ifStmt: of NodeKind.ifStmt:
self.ifStmt(IfStmt(node)) self.ifStmt(IfStmt(node), compile)
of NodeKind.assertStmt: of NodeKind.assertStmt:
self.assertStmt(AssertStmt(node)) self.assertStmt(AssertStmt(node), compile)
of NodeKind.raiseStmt: of NodeKind.raiseStmt:
self.raiseStmt(RaiseStmt(node)) self.raiseStmt(RaiseStmt(node), compile)
of NodeKind.breakStmt: of NodeKind.breakStmt:
self.breakStmt(BreakStmt(node)) self.breakStmt(BreakStmt(node), compile)
of NodeKind.continueStmt: of NodeKind.continueStmt:
self.continueStmt(ContinueStmt(node)) self.continueStmt(ContinueStmt(node), compile)
of NodeKind.returnStmt: of NodeKind.returnStmt:
self.returnStmt(ReturnStmt(node)) self.returnStmt(ReturnStmt(node), compile)
of NodeKind.importStmt: of NodeKind.importStmt:
self.importStmt(ImportStmt(node)) self.importStmt(ImportStmt(node), compile)
of NodeKind.exportStmt: of NodeKind.exportStmt:
self.exportStmt(ExportStmt(node)) self.exportStmt(ExportStmt(node), compile)
of NodeKind.whileStmt: of NodeKind.whileStmt:
# Note: Our parser already desugars # Note: Our parser already desugars
# for loops to while loops # for loops to while loops
let loop = self.currentLoop let loop = self.currentLoop
self.currentLoop = Loop(start: self.chunk.code.len(), self.currentLoop = Loop(start: self.chunk.code.len(),
depth: self.depth, breakJumps: @[]) depth: self.depth, breakJumps: @[])
self.whileStmt(WhileStmt(node)) self.whileStmt(WhileStmt(node), compile)
self.patchBreaks() if compile:
self.patchBreaks()
self.currentLoop = loop self.currentLoop = loop
of NodeKind.forEachStmt: of NodeKind.forEachStmt:
self.forEachStmt(ForEachStmt(node)) self.forEachStmt(ForEachStmt(node), compile)
of NodeKind.blockStmt: of NodeKind.blockStmt:
self.blockStmt(BlockStmt(node)) self.blockStmt(BlockStmt(node), compile)
of NodeKind.yieldStmt: of NodeKind.yieldStmt:
self.yieldStmt(YieldStmt(node)) self.yieldStmt(YieldStmt(node), compile)
of NodeKind.awaitStmt: of NodeKind.awaitStmt:
self.awaitStmt(AwaitStmt(node)) self.awaitStmt(AwaitStmt(node), compile)
of NodeKind.deferStmt: of NodeKind.deferStmt:
self.deferStmt(DeferStmt(node)) self.deferStmt(DeferStmt(node), compile)
of NodeKind.tryStmt: of NodeKind.tryStmt:
discard discard
else: else:
self.expression(Expression(node)) self.expression(Expression(node))
proc varDecl(self: Compiler, node: VarDecl) = proc varDecl(self: Compiler, node: VarDecl, compile: bool = true) =
## Compiles variable declarations ## Compiles variable declarations
# Our parser guarantees that the variable declaration # Our parser guarantees that the variable declaration
@ -2600,6 +2687,8 @@ proc varDecl(self: Compiler, node: VarDecl) =
# Variable has no value: the type declaration # Variable has no value: the type declaration
# takes over # takes over
typ = self.inferOrError(node.valueType) typ = self.inferOrError(node.valueType)
if typ.kind == Auto:
self.error("automatic types require initialization", node)
elif node.valueType.isNil: elif node.valueType.isNil:
# Variable has no type declaration: the type # Variable has no type declaration: the type
# of its value takes over # of its value takes over
@ -2609,9 +2698,12 @@ proc varDecl(self: Compiler, node: VarDecl) =
# a value: the value's type must match the # a value: the value's type must match the
# type declaration # type declaration
let expected = self.inferOrError(node.valueType) let expected = self.inferOrError(node.valueType)
self.check(node.value, expected) if expected.kind != Auto:
# If this doesn't fail, then we're good self.check(node.value, expected)
typ = expected # If this doesn't fail, then we're good
typ = expected
else:
typ = self.infer(node.value)
self.expression(node.value) self.expression(node.value)
self.emitByte(AddVar, node.token.line) self.emitByte(AddVar, node.token.line)
self.declare(node) self.declare(node)
@ -2621,7 +2713,7 @@ proc varDecl(self: Compiler, node: VarDecl) =
name.valueType = typ name.valueType = typ
proc funDecl(self: Compiler, node: FunDecl, name: Name) = proc funDecl(self: Compiler, node: FunDecl, name: Name, compile: bool = true) =
## Compiles function declarations ## Compiles function declarations
if node.token.kind == Operator and node.name.token.lexeme in [".", "="]: if node.token.kind == Operator and node.name.token.lexeme in [".", "="]:
self.error(&"Due to compiler limitations, the '{node.name.token.lexeme}' operator cannot be currently overridden", node.name) self.error(&"Due to compiler limitations, the '{node.name.token.lexeme}' operator cannot be currently overridden", node.name)
@ -2724,41 +2816,42 @@ proc funDecl(self: Compiler, node: FunDecl, name: Name) =
self.stackIndex = stackIdx self.stackIndex = stackIdx
proc declaration(self: Compiler, node: Declaration) = proc declaration(self: Compiler, node: Declaration, compile: bool = true) =
## Compiles declarations, statements and expressions ## Compiles declarations, statements and expressions
## recursively ## recursively
case node.kind: case node.kind:
of NodeKind.funDecl: of NodeKind.funDecl:
var name = self.declare(node) var name = self.declare(node)
self.funDecl(FunDecl(node), name) if not name.valueType.isAuto:
if name.isGeneric: self.funDecl(FunDecl(node), name, compile=compile)
# After we're done compiling a generic if name.isGeneric:
# function, we pull a magic trick: since, # After we're done compiling a generic
# from here on, the user will be able to # function, we pull a magic trick: since,
# call this with any of the types in the # from here on, the user will be able to
# generic constraint, we switch every generic # call this with any of the types in the
# to a type union (which, conveniently, have an # generic constraint, we switch every generic
# identical layout) so that the compiler will # to a type union (which, conveniently, have an
# typecheck the function as if its arguments # identical layout) so that the compiler will
# were all types of the constraint at once, # typecheck the function as if its arguments
# while still allowing the user to call it with # were all types of the constraint at once,
# any type in said constraint # while still allowing the user to call it with
for i, argument in name.valueType.args: # any type in said constraint
if argument.kind.kind != Generic: for i, argument in name.valueType.args:
continue if argument.kind.kind != Generic:
else: continue
argument.kind.asUnion = true else:
if not name.valueType.returnType.isNil() and name.valueType.returnType.kind == Generic: argument.kind.asUnion = true
name.valueType.returnType.asUnion = true if not name.valueType.returnType.isNil() and name.valueType.returnType.kind == Generic:
name.valueType.returnType.asUnion = true
of NodeKind.typeDecl: of NodeKind.typeDecl:
self.declare(node) self.declare(node)
of NodeKind.varDecl: of NodeKind.varDecl:
# We compile this immediately because we # We compile this immediately because we
# need to keep the stack in the right state # need to keep the stack in the right state
# at runtime # at runtime
self.varDecl(VarDecl(node)) self.varDecl(VarDecl(node), compile)
else: else:
self.statement(Statement(node)) self.statement(Statement(node), compile)
proc compile*(self: Compiler, ast: seq[Declaration], file: string, lines: seq[tuple[start, stop: int]], source: string, chunk: Chunk = nil, proc compile*(self: Compiler, ast: seq[Declaration], file: string, lines: seq[tuple[start, stop: int]], source: string, chunk: Chunk = nil,

View File

@ -13,12 +13,12 @@ operator `+`*[T: Integer](a, b: T): T {
} }
operator `+`(a, b: float): float { operator `+`*(a, b: float): float {
#pragma[magic: "AddFloat64", pure] #pragma[magic: "AddFloat64", pure]
} }
operator `+`(a, b: float32): float32 { operator `+`*(a, b: float32): float32 {
#pragma[magic: "AddFloat32", pure] #pragma[magic: "AddFloat32", pure]
} }

View File

@ -12,11 +12,12 @@ operator `<`*[T: UnsignedInteger](a, b: T): bool {
} }
operator `==`*[T: Number](a, b: T): bool { operator `==`*[T: Number | inf](a, b: T): bool {
#pragma[magic: "Equal", pure] #pragma[magic: "Equal", pure]
} }
operator `!=`*[T: Number](a, b: T): bool {
operator `!=`*[T: Number | inf](a, b: T): bool {
#pragma[magic: "NotEqual", pure] #pragma[magic: "NotEqual", pure]
} }
@ -90,3 +91,4 @@ operator `<=`*(a, b: float32): bool {
#pragma[magic: "Float32LessOrEqual", pure] #pragma[magic: "Float32LessOrEqual", pure]
} }

View File

@ -65,6 +65,10 @@ type nan* = object {
#pragma[magic: "nan"] #pragma[magic: "nan"]
} }
type auto* = object {
#pragma[magic: "auto"]
}
# Some convenience aliases # Some convenience aliases
type int* = int64; type int* = int64;
type float* = float64; type float* = float64;

14
tests/auto.pn Normal file
View File

@ -0,0 +1,14 @@
# A test for automatic types
import std;
fn sum(a, b: auto): auto {
return a + b;
}
var x: auto = 1;
print(x == 1);
print(sum(1, 2) == 3);
print(sum(1'i32, 2'i32) == 3'i32);
print(sum(1.0, 2.0) == 3.0);
#print(sum(1'i32, 2'i16)); # Will fail to compile