diff --git a/examples/test.nds b/examples/test.nds index 4f8b210..32e0377 100644 --- a/examples/test.nds +++ b/examples/test.nds @@ -1,17 +1,5 @@ -outer = funct() { - var a = 1; - var b = 2; - var result; - var middle = funct() { - var c = 3; - var d = 4; - var inner = funct() { - print a + c + b + d; - }; - result = inner; - }; - :result = result; -}; +var argcap = funct(n) + :result = funct() print n +; -//expect:10.0 -outer()(); \ No newline at end of file +argcap(5)(); diff --git a/src/ndspkg/compiler.nim b/src/ndspkg/compiler.nim index e64e1c3..6c03ffb 100644 --- a/src/ndspkg/compiler.nim +++ b/src/ndspkg/compiler.nim @@ -414,7 +414,8 @@ proc expString(comp: Compiler) = tkString.genRule(expString, nop, pcNone) -proc addUpvalue(comp: Compiler, index: int): int = +proc addUpvalue(comp: Compiler, local: Local): int = + ## argument: local ## This proc takes an index to a local in the locals table ## and creates an upvalue in every function up until the one ## including this local so that all of the function scopes in between @@ -428,9 +429,11 @@ proc addUpvalue(comp: Compiler, index: int): int = comp.error("Too many closure variables in function.") blk - var scopeIndex = comp.locals[index].depth - var isLocal = true - var upvalIndex: int = index + var scopeIndex = local.depth + 1 + # +1 must be here, because if scope the local is in is a function + # then the upvalues should only be in child functions + var isLocal = true # local means that it's the outermost function that's closing it + var upvalIndex: int = local.index while scopeIndex < comp.scopes.len(): let scope = comp.scopes[scopeIndex] if scope.function: @@ -443,9 +446,10 @@ proc addUpvalue(comp: Compiler, index: int): int = upvalIndex = i break ensure scope.upvalues.add(Upvalue(index: upvalIndex, isLocal: isLocal)) + upvalIndex = scope.upvalues.high() isLocal = false - upvalIndex = scope.upvalues.high() scopeIndex.inc + return upvalIndex proc resolveLocal(comp: Compiler, name: string): tuple[index: int, upvalue: bool] = ## the bool arg specifies whether it found an upvalue @@ -465,7 +469,7 @@ proc resolveLocal(comp: Compiler, name: string): tuple[index: int, upvalue: bool else: # resolveUpvalue local.captured = true - return (comp.addUpvalue(i), true) + return (comp.addUpvalue(local), true) i.dec return (index: -1, upvalue: false) @@ -756,7 +760,7 @@ proc parseList(comp: Compiler) = while comp.current.tokenType != tkRightBracket: comp.expression() count.inc() - if comp.current.tokenType != tkRightBracket or comp.current.tokenType == tkComma: + if comp.current.tokenType != tkRightBracket: comp.consume(tkComma, "Comma expected after list member.") comp.consume(tkRightBracket, "Right bracket expected after list members.") if count > argMax: @@ -776,9 +780,10 @@ proc parseTable(comp: Compiler) = comp.consume(tkEqual, "Equal sign expected after key.") comp.expression() count.inc() - if comp.current.tokenType != tkRightBrace or comp.current.tokenType == tkComma: + if comp.current.tokenType != tkRightBrace: comp.consume(tkComma, "Comma expected after key-value pair.") - comp.consume(tkRightBrace, "Right brace expected after list members.") + + comp.consume(tkRightBrace, "Right brace expected after table members.") if count > argMax: comp.error("Maximum table length exceeded.") comp.writeChunk(1 - 2 * count, opCreateTable) diff --git a/src/ndspkg/types/closure.nim b/src/ndspkg/types/closure.nim index 54ab8d8..7911725 100644 --- a/src/ndspkg/types/closure.nim +++ b/src/ndspkg/types/closure.nim @@ -38,9 +38,10 @@ proc debugStr*[T](clos: Closure[T]): string = let upvalCountStr: string = $clos.upvalueCount result = &"Closure(start: {addrStr}, length: {upvalCountStr}, upvalues: " mixin `$` - for i in 0 .. clos.upvalueCount: + for i in 0 .. clos.upvalueCount-1: if clos.upvalues[i] != nil: - result &= &"{$(clos.upvalues[i].location[])}, " + let upvalStr = $(clos.upvalues[i].location[]) + result &= &"{upvalStr}, " else: result &= ", " result &= ")" diff --git a/src/ndspkg/vm.nim b/src/ndspkg/vm.nim index 6049ecc..2a657ba 100644 --- a/src/ndspkg/vm.nim +++ b/src/ndspkg/vm.nim @@ -29,6 +29,7 @@ type Frame = object stackBottom: int # the absolute index of where 0 inside the frame is returnIp: ptr uint8 + closure: Closure[NdValue] InterpretResult* = enum irOK, irRuntimeError @@ -41,7 +42,6 @@ proc run*(chunk: Chunk): InterpretResult = hadError: bool globals: Table[NdValue, NdValue] frames: Stack[Frame] = newStack[Frame](4) - closures: Stack[Closure[NdValue]] = newStack[Closure[NdValue]](4) openUpvalues: Upvalue[NdValue] = nil proc runtimeError(msg: string) = @@ -76,13 +76,14 @@ proc run*(chunk: Chunk): InterpretResult = frames.add(Frame(stackBottom: stack.high - argcount, returnIp: ip)) ip = funct.asFunct() # jump to the entry point elif funct.isClosure(): - frames.add(Frame(stackBottom: stack.high - argcount, returnIp: ip)) - closures.add(funct.asClosure()) + frames.add(Frame(stackBottom: stack.high - argcount, returnIp: ip, closure: funct.asClosure())) ip = funct.asClosure().getIp() else: error proc captureUpvalue(location: ptr NdValue): Upvalue[NdValue] = + when debugClosures: + write stdout, "CLOSURES - captureUpvalue: " var prev: Upvalue[NdValue] var upvalue = openUpvalues while upvalue != nil and upvalue.location.pgreater(location): @@ -91,9 +92,13 @@ proc run*(chunk: Chunk): InterpretResult = # existing upvalue if upvalue != nil and upvalue.location == location: + when debugClosures: + write stdout, "found existing, returning that.\n" return upvalue # new upvalue + when debugClosures: + write stdout, "creating new.\n" result = newUpvalue(location) result.next = upvalue @@ -130,12 +135,12 @@ proc run*(chunk: Chunk): InterpretResult = echo msg when debugClosures: - if closures.len() > 0: - msg = " Closures: [ " - for i in 0 .. closures.high(): - msg &= debugStr(closures[i]) & " " - msg &= "]" - echo msg + msg = " Closures: [ " + for i in 0 .. frames.high(): + if frames[i].closure != nil: + msg &= debugStr(frames[i].closure) & " " + msg &= "]" + echo msg var ii = ip.pdiff(chunk.code[0].unsafeAddr) - 1 @@ -248,15 +253,15 @@ proc run*(chunk: Chunk): InterpretResult = stack[slot + frameBottom] = stack.peek() of opGetUpvalue: let slot = ip.readDU8() - let val = closures.peek().get(slot).read() + let val = frames.peek().closure.get(slot).read() when debugClosures: echo &"CLOSURES - getupvalue got {val} from slot {slot}" stack.push(val) of opSetUpvalue: let slot = ip.readDU8() when debugClosures: - echo &"CLOSURES - setupvalue is setting {$stack.peek} to slot {slot}, number of slots: {closures.peek().upvalueCount}" - closures.peek().get(slot).write(stack.peek()) + echo &"CLOSURES - setupvalue is setting {$stack.peek} to slot {slot}, number of slots: {frames.peek().closure.upvalueCount}" + frames.peek().closure.get(slot).write(stack.peek()) of opCloseUpvalue: let slot = ip.readDU8() stack[slot + frameBottom].addr.closeUpvalues() @@ -293,7 +298,7 @@ proc run*(chunk: Chunk): InterpretResult = echo &"CLOSURES - opClosure: local upvalue {loc[]} from local slot {slot} to slot {i}" closure.set(i, loc.captureUpvalue()) else: - let val = closures.peek().get(slot) + let val = frames.peek().closure.get(slot) when debugClosures: echo &"CLOSURES - opClosure: non local upvalue {val.location[]} from slot {slot} to slot {i}" closure.set(i, val) @@ -396,7 +401,6 @@ proc run*(chunk: Chunk): InterpretResult = stack.free() frames.free() globals.free() - closures.free() if hadError: irRuntimeError diff --git a/tests/closures.nds b/tests/closures.nds index 93ac35f..c2d7872 100644 --- a/tests/closures.nds +++ b/tests/closures.nds @@ -42,25 +42,43 @@ var f = funct() { :result = x; }; -//expect:5.0 -f()[0](); -f()[1](); -//expect:6.0 -f()[0](); +var inst = f(); -// capturing the result of a function: +//expect:5.0 +inst[0](); +inst[1](); +//expect:6.0 +inst[0](); + +// multiple different labels var f2 = funct() { - var x = { @ftwo - :result = funct() { - print :ftwo; + var x = { @a @b + // this captures the internal value, not whatever it returns is assigned to + :result = @{ + "get" = funct() print :a, + "set" = funct(n) :b = n, }; }; x = 5; }; -//expect:5.0 -f2()(); +var inst2 = f2(); +inst2["get"](); +//expect:nil +inst2["set"](5.2); +inst2["get"](); +//expect:5.2 + +// capturing args + +var argcap = funct(n) + :result = funct() print n +; + +//expect:8.1 +argcap(8.1)(); + // oop: constructors, getters, setters @@ -151,11 +169,12 @@ outer = funct() { }; result = inner; }; - :result = result; + middle(); + :result = funct() :result = result; }; //expect:10.0 -outer()(); +outer()()(); // 4: manipulation of vals from closures @@ -209,12 +228,12 @@ globalGet(); print c; }; - print f(); - print g(); - print h(); - //expect:1 - //expect:2 - //expect:3 + f(); + g(); + h(); + //expect:1.0 + //expect:2.0 + //expect:3.0 }; // bonus: the last one with a list twist @@ -240,6 +259,6 @@ bonus = bonus(); bonus[2](); bonus[0](); bonus[1](); -//expect:3 -//expect:1 -//expect:2 +//expect:3.0 +//expect:1.0 +//expect:2.0 diff --git a/tests/test.nim b/tests/test.nim index 1d397b9..463dec1 100644 --- a/tests/test.nim +++ b/tests/test.nim @@ -5,6 +5,7 @@ import os import re import strutils import osproc +import terminal testHashtables() @@ -24,10 +25,24 @@ proc runTest(path: string) = let success = output == expoutput if not success: echo "Nds test failed: " & path - echo "expected output:" - echo expoutput - echo "got output:" - echo output + + let oupLines = output.split('\n') + let expLines = expoutput.split('\n') + for i in 0 .. oupLines.high(): + let oupLine = oupLines[i] + var expLine = "" + if expLines.len() > i: + expLine = expLines[i] + if oupLine == expLine: + setForegroundColor(fgGreen) + echo oupLine + else: + setForegroundColor(fgRed) + write stdout, oupLine + setForegroundColor(fgDefault) + write stdout, " (expected: " & expLine & ")\n" + + setForegroundColor(fgDefault) else: echo "Test success: " & path