Rework generic replacement mechanism

This commit is contained in:
Mattia Giambirtone 2024-02-09 16:55:48 +01:00
parent 13e92ef014
commit 647512094b
3 changed files with 33 additions and 34 deletions

View File

@ -1069,6 +1069,25 @@ proc match(self: TypeChecker, name: string, sig: TypeSignature, node: ASTNode =
self.error(msg, node)
proc replaceGenerics(self: TypeChecker, typ: Type, generics: TableRef[string, Type]) =
## Recursively replaces all occurrences of the generics in
## the given mapping with their concrete counterpart
case typ.kind:
of TypeKind.Structure:
if typ.isEnum:
self.error("generic enums are currently not supported")
for fieldName in typ.fields.keys():
var fieldType = typ.fields[fieldName]
case fieldType.kind:
of TypeKind.Generic:
typ.fields[fieldName] = generics[fieldName]
else:
discard
else:
self.error(&"unable to perform generic instantiation for object of type {self.stringify(typ)}")
proc specialize(self: TypeChecker, name: Name, args: seq[TypedExpr], node: ASTNode = nil): Type =
## Instantiates a generic type
let
@ -1080,43 +1099,19 @@ proc specialize(self: TypeChecker, name: Name, args: seq[TypedExpr], node: ASTNo
self.error(&"invalid number of arguments supplied for generic instantiation (expecting exactly {expectedCount}, got {len(args)} instead)", node=node)
# Construct a concrete copy of the original generic type
result = typ.deepCopy()
# Create a new hidden scope to declare fresh type variables in
self.beginScope()
var replaced = newTable[string, Type]()
var i = 0
for key in typ.genericTypes.keys():
replaced[key] = self.check(args[i].kind, typ.genericTypes[key], args[i].node)
self.addName(Name(depth: self.scopeDepth,
ident: name.node.genericTypes[key].ident,
isPrivate: true,
module: self.currentModule,
file: self.file,
valueType: replaced[key],
line: node.token.line,
owner: self.currentFunction,
kind: NameKind.Default,
node: name.node,
))
inc(i)
result.genericTypes.clear()
# Note how we do not reset i!
for key in typ.genericValues.keys():
replaced[key] = self.check(args[i].kind, typ.genericValues[key], args[i].node)
self.addName(Name(depth: self.scopeDepth,
ident: name.node.genericValues[key].ident,
isPrivate: true,
module: self.currentModule,
file: self.file,
valueType: replaced[key],
line: node.token.line,
owner: self.currentFunction,
kind: NameKind.Default,
node: name.node,
))
inc(i)
result.genericValues.clear()
# Close the hidden scope once we're done
self.endScope()
# Now replaced contains a mapping from the names of the type variables to
# their respective (concrete) type. All we have to do is recursively replace
# every occurrence of them
self.replaceGenerics(typ, replaced)
@ -1555,7 +1550,7 @@ proc declareGenerics(self: TypeChecker, name: Name) =
file: self.currentModule.file,
depth: self.scopeDepth,
isPrivate: true,
valueType: if constraints.len() > 1: Type(kind: Union, types: constraints) else: constraints[0].kind,
valueType: if constraints.len() > 1: Type(kind: Generic, types: constraints) else: constraints[0].kind,
line: gen.ident.token.line,
)
self.addName(generic)
@ -1574,7 +1569,7 @@ proc declareGenerics(self: TypeChecker, name: Name) =
file: self.currentModule.file,
depth: self.scopeDepth,
isPrivate: true,
valueType: (if constraints.len() > 1: Type(kind: Union, types: constraints) else: constraints[0].kind).unwrapType(),
valueType: (if constraints.len() > 1: Type(kind: Generic, types: constraints) else: constraints[0].kind).unwrapType(),
line: gen.ident.token.line,
)
self.addName(generic)

View File

@ -42,7 +42,8 @@ type
Union,
Function,
Lent,
Const
Const,
Generic
Type* = ref object
## A compile-time type
@ -84,7 +85,7 @@ type
isEnum*: bool
of Reference, Pointer, Lent, Const:
value*: Type
of Union:
of Generic, Union:
types*: seq[tuple[match: bool, kind: Type, value: Expression]]
else:
discard

View File

@ -216,7 +216,7 @@ proc step(self: Parser): Token {.inline.} =
proc error(self: Parser, message: string, token: Token = nil) {.raises: [ParseError].} =
## Raises a ParseError exception
var token = if token.isNil(): self.getCurrentToken() else: token
var token = if token.isNil(): self.peek() else: token
if token.kind == EndOfFile:
token = self.peek(-1)
raise ParseError(msg: message, token: token, line: token.line, file: self.file, parser: self)
@ -1284,8 +1284,11 @@ proc dispatch(self: Parser): ASTNode =
TokenType.Foreach, TokenType.Break, TokenType.Continue, TokenType.Return,
TokenType.Import, TokenType.Export, TokenType.LeftBrace, TokenType.Block:
return self.statement()
of TokenType.Comment:
discard self.step() # TODO
else:
return self.expression()
result = self.expression()
self.expect(Semicolon, "expecting semicolon after expression")
proc findOperators(self: Parser, tokens: seq[Token]) =