Implemented ACTUAL parametric polymorphism

This commit is contained in:
Mattia Giambirtone 2022-11-28 18:21:38 +01:00
parent 62b8bae0fc
commit 540feb0c91
4 changed files with 116 additions and 48 deletions

View File

@ -144,6 +144,9 @@ type
# Who is this name exported to? (Only makes sense if isPrivate
# equals false)
exportedTo: HashSet[Name]
# Does this name come from user code (true), or was it
# generated internally by the compiler (false)?
isReal: bool
Loop = object
## A "loop object" used
@ -232,7 +235,7 @@ type
# List of disabled warnings
disabledWarnings: seq[WarningKind]
# Whether to show detailed info about type
# mismatches when we dispatch with matchImpl()
# mismatches when we dispatch with match()
showMismatches: bool
# Are we compiling in debug mode?
mode: CompileMode
@ -261,7 +264,7 @@ proc peek(self: Compiler, distance: int = 0): ASTNode
proc identifier(self: Compiler, node: IdentExpr, name: Name = nil)
proc varDecl(self: Compiler, node: VarDecl)
proc specialize(self: Compiler, name: Name, args: seq[Expression]): Name
proc matchImpl(self: Compiler, name: string, kind: Type, node: ASTNode = nil, allowFwd: bool = true): Name
proc match(self: Compiler, name: string, kind: Type, node: ASTNode = nil, args: seq[Expression] = @[], allowFwd: bool = true): Name
proc getItemExpr(self: Compiler, node: GetItemExpr, compile: bool = true): Type {.discardable.}
proc infer(self: Compiler, node: LiteralExpr, allowGeneric: bool = false): Type
proc infer(self: Compiler, node: Expression, allowGeneric: bool = false): Type
@ -637,6 +640,10 @@ proc compileDecl(self: Compiler, name: Name) =
if name.resolved:
return
name.resolved = true
if name.isGeneric:
# We typecheck generics at declaration time,
# so they're already compiled
return
# Now we just dispatch to one of our functions to
# compile the declaration
case name.kind:
@ -757,6 +764,7 @@ proc compare(self: Compiler, a, b: Type): bool =
# The nil code here is for void functions (when
# we compare their return types)
result = false
if a.isNil():
result = b.isNil() or b.kind == All
elif b.isNil():
@ -858,21 +866,40 @@ proc compare(self: Compiler, a, b: Type): bool =
if self.compare(c1.kind, c2.kind):
# Here return is fine, because there's
# no more checks after this one!
return c1.match == c2.match
if c1.match != c2.match:
return false
else:
return false
# We only return at the end because when matching
# generics we want to match *all* constraints at
# once, not just find the first match
return true
else:
for constraint in a.cond:
if self.compare(constraint.kind, b):
return constraint.match
if not constraint.match:
return false
else:
return false
return true
elif b.kind == Generic:
if a.kind == Generic:
for c1 in a.cond:
for c2 in b.cond:
if self.compare(c1.kind, c2.kind):
return c1.match == c2.match
if c1.match != c2.match:
return false
else:
return false
return true
else:
for constraint in b.cond:
if self.compare(constraint.kind, a):
return constraint.match
if not constraint.match:
return false
else:
return false
return true
# TODO: Is this ok?
else:
result = false
@ -882,7 +909,7 @@ proc toIntrinsic(name: string): Type =
## Converts a string to an intrinsic
## type if it is valid and returns nil
## otherwise
if name == "all":
if name == "any":
return Type(kind: Any)
elif name in ["int", "int64", "i64"]:
return Type(kind: Int64)
@ -975,6 +1002,8 @@ proc infer(self: Compiler, node: Expression, allowGeneric: bool = false): Type =
## unless allowGeneric is set to true
if node.isNil():
return nil
if node.isLiteral():
return self.infer(LiteralExpr(node), allowGeneric)
case node.kind:
of identExpr:
let node = IdentExpr(node)
@ -995,14 +1024,14 @@ proc infer(self: Compiler, node: Expression, allowGeneric: bool = false): Type =
of unaryExpr:
let node = UnaryExpr(node)
var default: Expression
let impl = self.matchImpl(node.operator.lexeme, Type(kind: Function, returnType: Type(kind: Any), args: @[("", self.infer(node.a), default)]), node)
let impl = self.match(node.operator.lexeme, Type(kind: Function, returnType: Type(kind: Any), args: @[("", self.infer(node.a), default)]), node)
result = impl.valueType.returnType
if result.kind == Generic and not allowGeneric:
result = self.specialize(impl, @[node.a]).valueType.returnType
of binaryExpr:
let node = BinaryExpr(node)
var default: Expression
let impl = self.matchImpl(node.operator.lexeme, Type(kind: Function, returnType: Type(kind: Any), args: @[("", self.infer(node.a), default), ("", self.infer(node.b), default)]), node)
let impl = self.match(node.operator.lexeme, Type(kind: Function, returnType: Type(kind: Any), args: @[("", self.infer(node.a), default), ("", self.infer(node.b), default)]), node)
result = impl.valueType.returnType
if result.kind == Generic and not allowGeneric:
result = self.specialize(impl, @[node.a, node.b]).valueType.returnType
@ -1027,7 +1056,15 @@ proc infer(self: Compiler, node: Expression, allowGeneric: bool = false): Type =
if not resolved.isNil():
case resolved.valueType.kind:
of Function:
result = resolved.valueType.returnType
if resolved.isGeneric and not allowGeneric:
var args: seq[Expression] = @[]
for argument in node.arguments.positionals:
args.add(argument)
for argument in node.arguments.keyword:
args.add(argument.value)
result = self.specialize(resolved, args).valueType.returnType
else:
result = resolved.valueType.returnType
else:
result = resolved.valueType
else:
@ -1160,11 +1197,19 @@ proc findAtDepth(self: Compiler, name: string, depth: int): seq[Name] {.used.} =
result.add(obj)
proc matchImpl(self: Compiler, name: string, kind: Type, node: ASTNode = nil, allowFwd: bool = true): Name =
proc match(self: Compiler, name: string, kind: Type, node: ASTNode = nil, args: seq[Expression] = @[], allowFwd: bool = true): Name =
## Tries to find a matching function implementation
## compatible with the given type and returns its
## name object
var impl = self.findByType(name, kind)
var impl: seq[Name] = @[]
var temp: Name
for obj in self.findByName(name):
if obj.isGeneric:
temp = self.specialize(obj, args)
else:
temp = obj
if self.compare(kind, temp.valueType):
impl.add(temp)
if impl.len() == 0:
var msg = &"failed to find a suitable implementation for '{name}'"
let names = self.findByName(name)
@ -1173,7 +1218,7 @@ proc matchImpl(self: Compiler, name: string, kind: Type, node: ASTNode = nil, al
if names.len() > 1:
msg &= "s"
if self.showMismatches:
msg &= ": "
msg &= " :"
for name in names:
msg &= &"\n - in {relativePath(name.file, getCurrentDir())}:{name.ident.token.line}:{name.ident.token.relPos.start} -> {self.typeToStr(name.valueType)}"
if name.valueType.kind != Function:
@ -1193,16 +1238,16 @@ proc matchImpl(self: Compiler, name: string, kind: Type, node: ASTNode = nil, al
else:
msg = &"call to undefined function '{name}'"
self.error(msg, node)
if impl.len() > 1:
elif impl.len() > 1:
# Forward declarations don't count when looking for a function
impl = filterIt(impl, not it.valueType.forwarded)
if impl.len() > 1:
# If it's *still* more than one match, then it's an error
var msg = &"multiple matching implementations of '{name}' found\n"
var msg = &"multiple matching implementations of '{name}' found"
if self.showMismatches:
msg &= ":"
for fn in reversed(impl):
msg &= &"- in {relativePath(fn.file, getCurrentDir())}, line {fn.line} of type {self.typeToStr(fn.valueType)}\n"
msg &= &"\n- in {relativePath(fn.file, getCurrentDir())}, line {fn.line} of type {self.typeToStr(fn.valueType)}"
else:
msg &= " (compile with --showMismatches for more details)"
self.error(msg, node)
@ -1219,7 +1264,7 @@ proc check(self: Compiler, term: Expression, kind: Type, allowAny: bool = false)
let k = self.inferOrError(term)
if k.kind == Any and not allowAny:
# Any should only be used internally: error!
self.error("'all' is not a valid type in this context", term)
self.error("'any' is not a valid type in this context", term)
elif not self.compare(k, kind):
self.error(&"expecting value of type {self.typeToStr(kind)}, got {self.typeToStr(k)}", term)
@ -1332,7 +1377,7 @@ proc patchForwardDeclarations(self: Compiler) =
var impl: Name
var pos: array[8, uint8]
for (forwarded, position) in self.forwarded:
impl = self.matchImpl(forwarded.ident.token.lexeme, forwarded.valueType, allowFwd=false)
impl = self.match(forwarded.ident.token.lexeme, forwarded.valueType, allowFwd=false)
if position == 0:
continue
pos = impl.codePos.toLong()
@ -1378,7 +1423,7 @@ proc endScope(self: Compiler) =
# Arguments to builtin functions become temporaries on the
# stack and are popped automatically
continue
if not name.belongsTo.resolved:
if not name.belongsTo.resolved or not name.isReal:
# Function hasn't been compiled yet,
# so we can't get rid of its arguments
# (it may need them later)
@ -1391,8 +1436,11 @@ proc endScope(self: Compiler) =
self.warning(UnusedName, &"'{name.ident.token.lexeme}' is declared but not used (add '_' prefix to silence warning)", name)
of NameKind.Argument:
if not name.ident.token.lexeme.startsWith("_") and name.isPrivate:
if not name.belongsTo.valueType.isBuiltinFunction:
# Builtin functions never use their arguments
if not name.belongsTo.valueType.isBuiltinFunction and name.belongsTo.isReal:
# Builtin functions never use their arguments. We also don't emit this
# warning if the function was generated internally by the compiler (for
# example as a result of generic specialization) because such objects do
# not exist in the user's code and are likely duplicated anyway
self.warning(UnusedName, &"argument '{name.ident.token.lexeme}' is unused (add '_' prefix to silence warning)", name)
else:
discard
@ -1444,7 +1492,7 @@ proc unpackGenerics(self: Compiler, condition: Expression, list: var seq[tuple[m
self.error("invalid type constraint in generic declaration", condition)
proc declareName(self: Compiler, node: ASTNode, mutable: bool = false) =
proc declareName(self: Compiler, node: ASTNode, mutable: bool = false): Name {.discardable.} =
## Statically declares a name into the current scope.
## "Declaring" a name only means updating our internal
## list of identifiers so that further calls to resolve()
@ -1473,7 +1521,8 @@ proc declareName(self: Compiler, node: ASTNode, mutable: bool = false) =
line: node.token.line,
belongsTo: self.currentFunction,
kind: NameKind.Var,
node: node
node: node,
isReal: true
))
n = self.names[^1]
if mutable:
@ -1497,7 +1546,8 @@ proc declareName(self: Compiler, node: ASTNode, mutable: bool = false) =
isLet: false,
line: node.token.line,
kind: NameKind.Function,
belongsTo: self.currentFunction)
belongsTo: self.currentFunction,
isReal: true)
self.names.add(fn)
n = fn
if node.generics.len() > 0:
@ -1519,7 +1569,8 @@ proc declareName(self: Compiler, node: ASTNode, mutable: bool = false) =
ident: node.moduleName,
line: node.moduleName.token.line,
kind: NameKind.Module,
isPrivate: false
isPrivate: false,
isReal: true
))
n = self.names[^1]
declaredName = self.names[^1].ident.token.lexeme
@ -1534,7 +1585,7 @@ proc declareName(self: Compiler, node: ASTNode, mutable: bool = false) =
for name in self.findByName(declaredName):
if name == n:
continue
# We don't check for name clashes with functions because matchImpl does that
# We don't check for name clashes with functions because match does that
elif name.kind in [NameKind.Var, NameKind.Module, NameKind.CustomType, NameKind.Enum]:
if name.owner != self.currentModule:
if name.isPrivate:
@ -1547,6 +1598,7 @@ proc declareName(self: Compiler, node: ASTNode, mutable: bool = false) =
self.error(&"re-declaration of {declaredName} is not allowed (previously declared in {name.owner.ident.token.lexeme}:{name.ident.token.line}:{name.ident.token.relPos.start})")
elif name.depth < self.depth:
self.warning(WarningKind.ShadowOuterScope, &"'{declaredName}' shadows a name from an outer scope")
return n
proc emitLoop(self: Compiler, begin: int, line: int) =
@ -1688,7 +1740,7 @@ proc beginProgram(self: Compiler): int =
codePos: 0,
ident: newIdentExpr(Token(lexeme: self.file, kind: Identifier)),
resolved: true,
line: -1)
line: 1)
self.names.add(mainModule)
self.currentModule = mainModule
# Every peon program has a hidden entry point in
@ -1715,7 +1767,7 @@ proc beginProgram(self: Compiler): int =
ident: newIdentExpr(Token(lexeme: "", kind: Identifier)),
kind: NameKind.Function,
resolved: true,
line: -1)
line: 1)
self.names.add(main)
self.emitByte(LoadUInt64, 1)
self.emitBytes(self.chunk.writeConstant(main.codePos.toLong()), 1)
@ -1817,29 +1869,22 @@ proc literal(self: Compiler, node: ASTNode) =
self.error(&"invalid AST node of kind {node.kind} at literal(): {node} (This is an internal error and most likely a bug!)")
proc callUnaryOp(self: Compiler, fn: Name, op: UnaryExpr) {.inline.} =
## Emits the code to call a unary operator
self.generateCall(fn, @[op.a], fn.line)
proc callBinaryOp(self: Compiler, fn: Name, op: BinaryExpr) {.inline.} =
## Emits the code to call a binary operator
self.generateCall(fn, @[op.a, op.b], fn.line)
proc unary(self: Compiler, node: UnaryExpr) {.inline.} =
## Compiles all unary expressions
var default: Expression
let fn = Type(kind: Function,
returnType: Type(kind: Any),
args: @[("", self.inferOrError(node.a), default)])
self.callUnaryOp(self.matchImpl(node.token.lexeme, fn), node)
let impl = self.match(node.token.lexeme, fn, node, @[node.a])
self.generateCall(impl, @[node.a], impl.line)
proc binary(self: Compiler, node: BinaryExpr) {.inline.} =
## Compiles all binary expressions
var default: Expression
self.callBinaryOp(self.matchImpl(node.token.lexeme, Type(kind: Function, returnType: Type(kind: Any), args: @[("", self.inferOrError(node.a), default), ("", self.inferOrError(node.b), default)]), node), node)
let fn = Type(kind: Function, returnType: Type(kind: Any), args: @[("", self.inferOrError(node.a), default), ("", self.inferOrError(node.b), default)])
let impl = self.match(node.token.lexeme, fn, node, @[node.a, node.b])
self.generateCall(impl, @[node.a, node.b], impl.line)
proc identifier(self: Compiler, node: IdentExpr, name: Name = nil) =
@ -2092,6 +2137,7 @@ proc specialize(self: Compiler, name: Name, args: seq[Expression]): Name =
var kind: Type
result = deepCopy(name)
result.isGeneric = false
result.isReal = false
case name.kind:
of NameKind.Function:
# This first loop checks if a user tries to reassign a generic's
@ -2112,7 +2158,7 @@ proc specialize(self: Compiler, name: Name, args: seq[Expression]): Name =
owner: self.currentModule,
file: self.file,
isConst: false,
ident: newIdentExpr(Token(lexeme: argTuple.name)),
ident: newIdentExpr(Token(kind: Identifier, line: argExpr.token.line, lexeme: argTuple.name)),
valueType: argTuple.kind,
codePos: 0,
isLet: false,
@ -2120,7 +2166,7 @@ proc specialize(self: Compiler, name: Name, args: seq[Expression]): Name =
belongsTo: result,
kind: NameKind.Argument
))
if result.valueType.returnType.kind == Generic:
if not result.valueType.returnType.isNil() and result.valueType.returnType.kind == Generic:
result.valueType.returnType = mapping[result.valueType.returnType.name]
else:
discard # TODO: Custom user-defined types
@ -2133,7 +2179,7 @@ proc callExpr(self: Compiler, node: CallExpr): Name {.discardable.} =
var default: Expression
var kind: Type
for i, argument in node.arguments.positionals:
kind = self.infer(argument) # we don't use inferOrError so that we can raise a more appropriate error message
kind = self.infer(argument, allowGeneric=false) # We don't use inferOrError so that we can raise a more appropriate error message
if kind.isNil():
if argument.kind == NodeKind.identExpr:
self.error(&"reference to undeclared name '{argument.token.lexeme}'", argument)
@ -2141,7 +2187,7 @@ proc callExpr(self: Compiler, node: CallExpr): Name {.discardable.} =
args.add(("", kind, default))
argExpr.add(argument)
for i, argument in node.arguments.keyword:
kind = self.infer(argument.value)
kind = self.infer(argument.value, allowGeneric=false)
if kind.isNil():
if argument.value.kind == NodeKind.identExpr:
self.error(&"reference to undeclared name '{argument.value.token.lexeme}'", argument.value)
@ -2151,7 +2197,7 @@ proc callExpr(self: Compiler, node: CallExpr): Name {.discardable.} =
case node.callee.kind:
of identExpr:
# Calls like hi()
result = self.matchImpl(IdentExpr(node.callee).name.lexeme, Type(kind: Function, returnType: Type(kind: All), args: args), node)
result = self.match(IdentExpr(node.callee).name.lexeme, Type(kind: Function, returnType: Type(kind: All), args: args), node, argExpr)
# Now we call it
self.generateCall(result, argExpr, node.token.line)
of NodeKind.callExpr:
@ -2604,7 +2650,12 @@ proc declaration(self: Compiler, node: Declaration) =
## right away, but rather only when they're referenced
## the first time
case node.kind:
of NodeKind.funDecl, NodeKind.typeDecl:
of NodeKind.funDecl:
var name = self.declareName(node)
if name.isGeneric:
# We typecheck generics immediately
self.funDecl(FunDecl(node), name)
of NodeKind.typeDecl:
self.declareName(node)
of NodeKind.varDecl:
# We compile this immediately because we

View File

@ -12,6 +12,10 @@ fn print*(x: int) {
#pragma[magic: "PrintInt64"]
}
fn print*(x: int32) {
#pragma[magic: "PrintInt32"]
}
fn print*(x: uint64) {
#pragma[magic: "PrintUInt64"]

View File

@ -2,11 +2,11 @@
import std;
fn sum[T: int | int32](a, b: T): T {
fn sum[T: any](a, b: T): T {
return a + b;
}
print(sum(1, 2)); # Prints 3
print(sum(1'i32, 2'i32)); # Also prints 3!
# print(sum(1'i16, 2'i16)); # Will not work if uncommented!
# print(sum(1'i16, 2'i16)); # Will not work if uncommented: print is not defined for i16!

13
tests/generics2.pn Normal file
View File

@ -0,0 +1,13 @@
# Tests more stuff about generics. This test should fail to compile
fn identity(x: int32): int32 {
return x;
}
fn nope[T: int32 | int16](x: T) {
return identity(x);
}
nope(5'i16);