Compare commits
49 Commits
master
...
chess-mail
Author | SHA1 | Date |
---|---|---|
Mattia Giambirtone | 48e2adddc6 | |
Mattia Giambirtone | aeaa57aba6 | |
Mattia Giambirtone | c9644213fe | |
Mattia Giambirtone | 4d4b12a603 | |
Mattia Giambirtone | 6153112c21 | |
Mattia Giambirtone | 2ada052460 | |
Mattia Giambirtone | f75f7533f5 | |
Mattia Giambirtone | 54a6217bd3 | |
Mattia Giambirtone | 89a96eaf52 | |
Mattia Giambirtone | f65d426ccf | |
Mattia Giambirtone | c1ac5ea5c3 | |
Mattia Giambirtone | 57353c0994 | |
Mattia Giambirtone | 77129855df | |
Mattia Giambirtone | a4954a971b | |
Mattia Giambirtone | 9047e3a53d | |
Mattia Giambirtone | 6e10cbe925 | |
Mattia Giambirtone | 3dca208123 | |
Mattia Giambirtone | 0b9b24b8e1 | |
Mattia Giambirtone | 2c58488c61 | |
Mattia Giambirtone | 75869357cc | |
Mattia Giambirtone | a9a9b917c6 | |
Mattia Giambirtone | c79af07638 | |
Mattia Giambirtone | 29a554d5da | |
Mattia Giambirtone | b9dcde1563 | |
Mattia Giambirtone | c6cc98a296 | |
Mattia Giambirtone | 60c4f28ec0 | |
Mattia Giambirtone | 82a203c98b | |
Mattia Giambirtone | 31f77fa22d | |
Mattia Giambirtone | afff1db88f | |
Mattia Giambirtone | b4ef8b4a2e | |
Mattia Giambirtone | 79477fe077 | |
Mattia Giambirtone | ca498ebc42 | |
Mattia Giambirtone | e782935fd7 | |
Mattia Giambirtone | 17f15e682c | |
Mattia Giambirtone | 942f195ddc | |
Mattia Giambirtone | eb77cf4b89 | |
Mattia Giambirtone | 1610e7b4a6 | |
Mattia Giambirtone | de0864c066 | |
Mattia Giambirtone | 25ebe7f409 | |
Mattia Giambirtone | f1c09e302e | |
Mattia Giambirtone | 56628cac27 | |
Mattia Giambirtone | 9634787746 | |
Mattia Giambirtone | ce57d06f79 | |
Mattia Giambirtone | 3da384f8fc | |
Mattia Giambirtone | 1f940a6e60 | |
Mattia Giambirtone | 85648d883c | |
Mattia Giambirtone | 24d1cd0c82 | |
Mattia Giambirtone | 7737e47f11 | |
Mattia Giambirtone | 9c191adedd |
|
@ -2,3 +2,7 @@
|
|||
nimcache/
|
||||
nimblecache/
|
||||
htmldocs/
|
||||
nim.cfg
|
||||
bin
|
||||
# Python
|
||||
__pycache__
|
||||
|
|
|
@ -7,9 +7,9 @@ in semiconductor technology smart enough to play tic tac toe.
|
|||
## Plans
|
||||
|
||||
- Tic Tac Toe (optimal) -> Done
|
||||
- Connect 4 (optinal) -> WIP
|
||||
- Connect 4 (optinal)
|
||||
- Checkers (optimal?)
|
||||
- Chess
|
||||
- Chess -> WIP
|
||||
|
||||
|
||||
All of these games will be played using decision trees searched using the minimax algorithm (maybe a bit of neural networks too, who knows).
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,60 @@
|
|||
import board as chess
|
||||
import std/strformat
|
||||
import std/strutils
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
when isMainModule:
|
||||
setControlCHook(proc () {.noconv.} = echo ""; quit(0))
|
||||
const fen = "rnbqkbnr/2p/8/8/8/8/P7/RNBQKBNR w KQkq - 0 1"
|
||||
var
|
||||
board = newChessboardFromFEN(fen)
|
||||
canCastle: tuple[queen, king: bool]
|
||||
data: string
|
||||
move: Move
|
||||
|
||||
echo "\x1Bc"
|
||||
while true:
|
||||
canCastle = board.canCastle()
|
||||
echo &"{board.pretty()}"
|
||||
echo &"Turn: {board.getActiveColor()}"
|
||||
echo &"Moves: {board.getMoveCount()} full, {board.getHalfMoveCount()} half"
|
||||
echo &"Can castle:\n - King side: {(if canCastle.king: \"yes\" else: \"no\")}\n - Queen side: {(if canCastle.queen: \"yes\" else: \"no\")}"
|
||||
stdout.write(&"En passant target: ")
|
||||
if board.getEnPassantTarget() != emptyLocation():
|
||||
echo board.getEnPassantTarget().locationToAlgebraic()
|
||||
else:
|
||||
echo "None"
|
||||
stdout.write(&"Check: ")
|
||||
if board.inCheck():
|
||||
echo &"Yes"
|
||||
else:
|
||||
echo "No"
|
||||
stdout.write("\nMove(s) -> ")
|
||||
try:
|
||||
data = readLine(stdin).strip(chars={'\0', ' '})
|
||||
except IOError:
|
||||
echo ""
|
||||
break
|
||||
if data == "undo":
|
||||
echo &"\x1BcUndo: {board.undoLastMove()}"
|
||||
continue
|
||||
if data == "reset":
|
||||
echo &"\x1BcBoard reset"
|
||||
board = newChessboardFromFEN(fen)
|
||||
continue
|
||||
for moveChars in data.split(" "):
|
||||
if len(moveChars) != 4:
|
||||
echo "\x1BcError: invalid move"
|
||||
break
|
||||
try:
|
||||
move = board.makeMove(moveChars[0..1], moveChars[2..3])
|
||||
except ValueError:
|
||||
echo &"\x1BcError: {getCurrentExceptionMsg()}"
|
||||
if move == emptyMove():
|
||||
echo &"\x1BcError: move '{moveChars}' is illegal"
|
||||
break
|
||||
else:
|
||||
echo "\x1Bc"
|
|
@ -0,0 +1,167 @@
|
|||
import re
|
||||
import sys
|
||||
import subprocess
|
||||
from shutil import which
|
||||
from pathlib import Path
|
||||
from argparse import ArgumentParser, Namespace
|
||||
|
||||
|
||||
|
||||
def main(args: Namespace) -> int:
|
||||
if args.silent:
|
||||
print = lambda *_: ...
|
||||
print("Nimfish move validator v0.0.1 by nocturn9x")
|
||||
try:
|
||||
STOCKFISH = (args.stockfish or Path(which("stockfish"))).resolve(strict=True)
|
||||
except Exception as e:
|
||||
print(f"Could not locate stockfish executable -> {type(e).__name__}: {e}")
|
||||
return 2
|
||||
try:
|
||||
NIMFISH = (args.nimfish or (Path.cwd() / "bin" / "nimfish")).resolve(strict=True)
|
||||
except Exception as e:
|
||||
print(f"Could not locate nimfish executable -> {type(e).__name__}: {e}")
|
||||
return 2
|
||||
print(f"Starting Stockfish engine at {STOCKFISH.as_posix()!r}")
|
||||
stockfish_process = subprocess.Popen(STOCKFISH,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE,
|
||||
encoding="u8",
|
||||
text=True,
|
||||
bufsize=1
|
||||
)
|
||||
print(f"Starting Nimfish engine at {NIMFISH.as_posix()!r}")
|
||||
nimfish_process = subprocess.Popen(NIMFISH,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE,
|
||||
encoding="u8",
|
||||
text=True,
|
||||
bufsize=1
|
||||
)
|
||||
print(f"Setting position to {(args.fen if args.fen else 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1')!r}")
|
||||
if args.fen:
|
||||
nimfish_process.stdin.write(f"position fen {args.fen}\n")
|
||||
stockfish_process.stdin.write(f"position fen {args.fen}\n")
|
||||
else:
|
||||
nimfish_process.stdin.write("position startpos\n")
|
||||
stockfish_process.stdin.write("position startpos\n")
|
||||
print(f"Engines started, beginning search to depth {args.ply}")
|
||||
nimfish_process.stdin.write(f"go perft {args.ply} {'bulk' if args.bulk else ''}\n")
|
||||
stockfish_process.stdin.write(f"go perft {args.ply}\n")
|
||||
stockfish_output, stockfish_error = stockfish_process.communicate()
|
||||
nimfish_output, nimfish_error = nimfish_process.communicate()
|
||||
if nimfish_process.returncode != 0:
|
||||
print(f"Nimfish crashed, stderr output below:\n{nimfish_error}")
|
||||
if stockfish_process.returncode != 0:
|
||||
print(f"Stockfish crashed, stderr below:\n{stockfish_error}")
|
||||
if not all([stockfish_process.returncode == 0, nimfish_process.returncode == 0]):
|
||||
return 3
|
||||
positions = {
|
||||
"all": {},
|
||||
"stockfish": {},
|
||||
"nimfish": {}
|
||||
}
|
||||
pattern = re.compile(r"(?P<source>[a-h][1-8])(?P<target>[a-h][1-8])(?P<promotion>b|n|q|r)?:\s(?P<nodes>[0-9]+)", re.MULTILINE)
|
||||
for (source, target, promotion, nodes) in pattern.findall(stockfish_output):
|
||||
move = f"{source}{target}{promotion}"
|
||||
positions["all"][move] = [int(nodes)]
|
||||
positions["stockfish"][move] = int(nodes)
|
||||
for (source, target, promotion, nodes) in pattern.findall(nimfish_output):
|
||||
move = f"{source}{target}{promotion}"
|
||||
if move in positions["all"]:
|
||||
positions["all"][move].append(int(nodes))
|
||||
else:
|
||||
positions["all"][move] = [int(nodes)]
|
||||
positions["nimfish"][move] = int(nodes)
|
||||
|
||||
missing = {
|
||||
# Are in nimfish but not in stockfish
|
||||
"nimfish": [],
|
||||
# Are in stockfish but not in nimfish
|
||||
"stockfish": []
|
||||
}
|
||||
# What mistakes did Nimfish do?
|
||||
mistakes = set()
|
||||
for move, nodes in positions["all"].items():
|
||||
if move not in positions["stockfish"]:
|
||||
missing["nimfish"].append(move)
|
||||
continue
|
||||
elif move not in positions["nimfish"]:
|
||||
missing["stockfish"].append(move)
|
||||
continue
|
||||
if nodes[0] != nodes[1]:
|
||||
mistakes.add(move)
|
||||
mistakes = sorted(list(mistakes))
|
||||
total_nodes = {"stockfish": sum(positions["stockfish"][move] for move in positions["stockfish"]),
|
||||
"nimfish": sum(positions["nimfish"][move] for move in positions["nimfish"])}
|
||||
total_difference = total_nodes["stockfish"] - total_nodes["nimfish"]
|
||||
print(f"Stockfish searched {total_nodes['stockfish']} node{'' if total_nodes['stockfish'] == 1 else 's'}")
|
||||
print(f"Nimfish searched {total_nodes['nimfish']} node{'' if total_nodes['nimfish'] == 1 else 's'}")
|
||||
|
||||
if total_difference > 0:
|
||||
print(f"Nimfish searched {total_difference} fewer node{'' if total_difference == 1 else 's'} than Stockfish")
|
||||
elif total_difference < 0:
|
||||
total_difference = abs(total_difference)
|
||||
print(f"Nimfish searched {total_difference} more node{'' if total_difference == 1 else 's'} than Stockfish")
|
||||
else:
|
||||
print("Node count is identical")
|
||||
pattern = re.compile(r"(?:\s\s-\sCaptures:\s(?P<captures>[0-9]+))\n"
|
||||
r"(?:\s\s-\sChecks:\s(?P<checks>[0-9]+))\n"
|
||||
r"(?:\s\s-\sE\.P:\s(?P<enPassant>[0-9]+))\n"
|
||||
r"(?:\s\s-\sCheckmates:\s(?P<checkmates>[0-9]+))\n"
|
||||
r"(?:\s\s-\sCastles:\s(?P<castles>[0-9]+))\n"
|
||||
r"(?:\s\s-\sPromotions:\s(?P<promotions>[0-9]+))",
|
||||
re.MULTILINE)
|
||||
extra: re.Match | None = None
|
||||
if not args.bulk:
|
||||
extra = pattern.search(nimfish_output)
|
||||
missed_total = len(missing['stockfish']) + len(missing['nimfish'])
|
||||
if missing["stockfish"] or missing["nimfish"] or mistakes:
|
||||
print(f"Found {missed_total} missed move{'' if missed_total == 1 else 's'} and {len(mistakes)} counting mistake{'' if len(mistakes) == 1 else 's'}, more info below: ")
|
||||
if args.bulk:
|
||||
print("Note: Nimfish was run in bulk-counting mode, so a detailed breakdown of each move type is not available. "
|
||||
"To fix this, re-run the program without the --bulk option")
|
||||
if extra:
|
||||
print(f" Breakdown by move type:")
|
||||
print(f" - Captures: {extra.group('captures')}")
|
||||
print(f" - Checks: {extra.group('checks')}")
|
||||
print(f" - En Passant: {extra.group('enPassant')}")
|
||||
print(f" - Checkmates: {extra.group('checkmates')}")
|
||||
print(f" - Castles: {extra.group('castles')}")
|
||||
print(f" - Promotions: {extra.group('promotions')}")
|
||||
|
||||
elif not args.bulk:
|
||||
print("Unable to locate move breakdown in Nimfish output")
|
||||
if missing["stockfish"] or missing["nimfish"]:
|
||||
print("\n Move count breakdown:")
|
||||
if missing["stockfish"]:
|
||||
print(" Legal moves missed: ")
|
||||
for move in missing["stockfish"]:
|
||||
print(f" - {move}: {positions['stockfish'][move]}")
|
||||
if missing["nimfish"]:
|
||||
print("\n Illegal moves generated: ")
|
||||
for move in missing["nimfish"]:
|
||||
print(f" - {move}: {positions['nimfish'][move]}")
|
||||
if mistakes:
|
||||
print("\n Counting mistakes made:")
|
||||
for move in mistakes:
|
||||
missed = positions["stockfish"][move] - positions["nimfish"][move]
|
||||
print(f" - {move}: expected {positions['stockfish'][move]}, got {positions['nimfish'][move]} ({'-' if missed > 0 else '+'}{abs(missed)})")
|
||||
return 1
|
||||
else:
|
||||
print("No discrepancies detected")
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser(description="Automatically compare perft results between Nimfish and Stockfish")
|
||||
parser.add_argument("--fen", "-f", type=str, default="", help="The FEN string of the position to start from (empty string means the initial one). Defaults to ''")
|
||||
parser.add_argument("--ply", "-d", type=int, required=True, help="The depth to stop at, expressed in plys (half-moves)")
|
||||
parser.add_argument("--bulk", action="store_true", help="Enable bulk-counting for Nimfish (much faster)", default=False)
|
||||
parser.add_argument("--stockfish", type=Path, help="Path to the stockfish executable. Defaults to '' (detected automatically)", default=None)
|
||||
parser.add_argument("--nimfish", type=Path, help="Path to the nimfish executable. Defaults to '' (detected automatically)", default=None)
|
||||
parser.add_argument("--silent", action="store_true", help="Disable all output (a return code of 0 means the test was successful)", default=False)
|
||||
sys.exit(main(parser.parse_args()))
|
|
@ -0,0 +1,150 @@
|
|||
1r2k2r/8/8/8/8/8/8/R3K2R b KQk - 0 1
|
||||
1r2k2r/8/8/8/8/8/8/R3K2R w KQk - 0 1
|
||||
2K2r2/4P3/8/8/8/8/8/3k4 w - - 0 1
|
||||
2kr3r/p1ppqpb1/bn2Qnp1/3PN3/1p2P3/2N5/PPPBBPPP/R3K2R b KQ - 3 2
|
||||
2r1k2r/8/8/8/8/8/8/R3K2R b KQk - 0 1
|
||||
2r1k2r/8/8/8/8/8/8/R3K2R w KQk - 0 1
|
||||
2r5/3pk3/8/2P5/8/2K5/8/8 w - - 5 4
|
||||
3k4/3p4/8/K1P4r/8/8/8/8 b - - 0 1
|
||||
3k4/3pp3/8/8/8/8/3PP3/3K4 b - - 0 1
|
||||
3k4/3pp3/8/8/8/8/3PP3/3K4 w - - 0 1
|
||||
3k4/8/8/8/8/8/8/R3K3 w Q - 0 1
|
||||
4k2r/6K1/8/8/8/8/8/8 b k - 0 1
|
||||
4k2r/6K1/8/8/8/8/8/8 w k - 0 1
|
||||
4k2r/8/8/8/8/8/8/4K3 b k - 0 1
|
||||
4k2r/8/8/8/8/8/8/4K3 w k - 0 1
|
||||
4k3/1P6/8/8/8/8/K7/8 w - - 0 1
|
||||
4k3/4p3/4K3/8/8/8/8/8 b - - 0 1
|
||||
4k3/8/8/8/8/8/8/4K2R b K - 0 1
|
||||
4k3/8/8/8/8/8/8/4K2R w K - 0 1
|
||||
4k3/8/8/8/8/8/8/R3K2R b KQ - 0 1
|
||||
4k3/8/8/8/8/8/8/R3K2R w KQ - 0 1
|
||||
4k3/8/8/8/8/8/8/R3K3 b Q - 0 1
|
||||
4k3/8/8/8/8/8/8/R3K3 w Q - 0 1
|
||||
5k2/8/8/8/8/8/8/4K2R w K - 0 1
|
||||
6KQ/8/8/8/8/8/8/7k b - - 0 1
|
||||
6kq/8/8/8/8/8/8/7K w - - 0 1
|
||||
6qk/8/8/8/8/8/8/7K b - - 0 1
|
||||
7k/3p4/8/8/3P4/8/8/K7 b - - 0 1
|
||||
7k/3p4/8/8/3P4/8/8/K7 w - - 0 1
|
||||
7K/7p/7k/8/8/8/8/8 b - - 0 1
|
||||
7K/7p/7k/8/8/8/8/8 w - - 0 1
|
||||
7k/8/1p6/8/8/P7/8/7K b - - 0 1
|
||||
7k/8/1p6/8/8/P7/8/7K w - - 0 1
|
||||
7k/8/8/1p6/P7/8/8/7K b - - 0 1
|
||||
7k/8/8/1p6/P7/8/8/7K w - - 0 1
|
||||
7k/8/8/3p4/8/8/3P4/K7 b - - 0 1
|
||||
7k/8/8/3p4/8/8/3P4/K7 w - - 0 1
|
||||
7k/8/8/p7/1P6/8/8/7K b - - 0 1
|
||||
7k/8/8/p7/1P6/8/8/7K w - - 0 1
|
||||
7k/8/p7/8/8/1P6/8/7K b - - 0 1
|
||||
7k/8/p7/8/8/1P6/8/7K w - - 0 1
|
||||
7k/RR6/8/8/8/8/rr6/7K b - - 0 1
|
||||
7k/RR6/8/8/8/8/rr6/7K w - - 0 1
|
||||
8/1k6/8/5N2/8/4n3/8/2K5 b - - 0 1
|
||||
8/1k6/8/5N2/8/4n3/8/2K5 w - - 0 1
|
||||
8/1n4N1/2k5/8/8/5K2/1N4n1/8 b - - 0 1
|
||||
8/1n4N1/2k5/8/8/5K2/1N4n1/8 w - - 0 1
|
||||
8/2k1p3/3pP3/3P2K1/8/8/8/8 b - - 0 1
|
||||
8/2k1p3/3pP3/3P2K1/8/8/8/8 w - - 0 1
|
||||
8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 1
|
||||
8/3k4/3p4/8/3P4/3K4/8/8 b - - 0 1
|
||||
8/3k4/3p4/8/3P4/3K4/8/8 w - - 0 1
|
||||
8/8/1B6/7b/7k/8/2B1b3/7K b - - 0 1
|
||||
8/8/1B6/7b/7k/8/2B1b3/7K w - - 0 1
|
||||
8/8/1k6/2b5/2pP4/8/5K2/8 b - d3 0 1
|
||||
8/8/1P2K3/8/2n5/1q6/8/5k2 b - - 0 1
|
||||
8/8/2k5/5q2/5n2/8/5K2/8 b - - 0 1
|
||||
8/8/3K4/3Nn3/3nN3/4k3/8/8 b - - 0 1
|
||||
8/8/3k4/3p4/3P4/3K4/8/8 b - - 0 1
|
||||
8/8/3k4/3p4/3P4/3K4/8/8 w - - 0 1
|
||||
8/8/3k4/3p4/8/3P4/3K4/8 b - - 0 1
|
||||
8/8/3k4/3p4/8/3P4/3K4/8 w - - 0 1
|
||||
8/8/4k3/3Nn3/3nN3/4K3/8/8 w - - 0 1
|
||||
8/8/4k3/8/2p5/8/B2P2K1/8 w - - 0 1
|
||||
8/8/7k/7p/7P/7K/8/8 b - - 0 1
|
||||
8/8/7k/7p/7P/7K/8/8 w - - 0 1
|
||||
8/8/8/2k5/2pP4/8/B7/4K3 b - d3 0 3
|
||||
8/8/8/8/8/4k3/4P3/4K3 w - - 0 1
|
||||
8/8/8/8/8/7K/7P/7k b - - 0 1
|
||||
8/8/8/8/8/7K/7P/7k w - - 0 1
|
||||
8/8/8/8/8/8/1k6/R3K3 b Q - 0 1
|
||||
8/8/8/8/8/8/1k6/R3K3 w Q - 0 1
|
||||
8/8/8/8/8/8/6k1/4K2R b K - 0 1
|
||||
8/8/8/8/8/8/6k1/4K2R w K - 0 1
|
||||
8/8/8/8/8/K7/P7/k7 b - - 0 1
|
||||
8/8/8/8/8/K7/P7/k7 w - - 0 1
|
||||
8/8/k7/p7/P7/K7/8/8 b - - 0 1
|
||||
8/8/k7/p7/P7/K7/8/8 w - - 0 1
|
||||
8/k1P5/8/1K6/8/8/8/8 w - - 0 1
|
||||
8/P1k5/K7/8/8/8/8/8 w - - 0 1
|
||||
8/Pk6/8/8/8/8/6Kp/8 b - - 0 1
|
||||
8/Pk6/8/8/8/8/6Kp/8 w - - 0 1
|
||||
8/PPPk4/8/8/8/8/4Kppp/8 b - - 0 1
|
||||
8/PPPk4/8/8/8/8/4Kppp/8 w - - 0 1
|
||||
B6b/8/8/8/2K5/4k3/8/b6B w - - 0 1
|
||||
B6b/8/8/8/2K5/5k2/8/b6B b - - 0 1
|
||||
K1k5/8/P7/8/8/8/8/8 w - - 0 1
|
||||
k7/6p1/8/8/8/8/7P/K7 b - - 0 1
|
||||
k7/6p1/8/8/8/8/7P/K7 w - - 0 1
|
||||
k7/7p/8/8/8/8/6P1/K7 b - - 0 1
|
||||
k7/7p/8/8/8/8/6P1/K7 w - - 0 1
|
||||
K7/8/2n5/1n6/8/8/8/k6N b - - 0 1
|
||||
k7/8/2N5/1N6/8/8/8/K6n b - - 0 1
|
||||
K7/8/2n5/1n6/8/8/8/k6N w - - 0 1
|
||||
k7/8/2N5/1N6/8/8/8/K6n w - - 0 1
|
||||
k7/8/3p4/8/3P4/8/8/7K b - - 0 1
|
||||
k7/8/3p4/8/3P4/8/8/7K w - - 0 1
|
||||
k7/8/3p4/8/8/4P3/8/7K b - - 0 1
|
||||
k7/8/3p4/8/8/4P3/8/7K w - - 0 1
|
||||
k7/8/6p1/8/8/7P/8/K7 b - - 0 1
|
||||
k7/8/6p1/8/8/7P/8/K7 w - - 0 1
|
||||
k7/8/7p/8/8/6P1/8/K7 b - - 0 1
|
||||
k7/8/7p/8/8/6P1/8/K7 w - - 0 1
|
||||
k7/8/8/3p4/4p3/8/8/7K b - - 0 1
|
||||
k7/8/8/3p4/4p3/8/8/7K w - - 0 1
|
||||
K7/8/8/3Q4/4q3/8/8/7k b - - 0 1
|
||||
K7/8/8/3Q4/4q3/8/8/7k w - - 0 1
|
||||
k7/8/8/6p1/7P/8/8/K7 b - - 0 1
|
||||
k7/8/8/6p1/7P/8/8/K7 w - - 0 1
|
||||
k7/8/8/7p/6P1/8/8/K7 b - - 0 1
|
||||
k7/8/8/7p/6P1/8/8/K7 w - - 0 1
|
||||
k7/B7/1B6/1B6/8/8/8/K6b b - - 0 1
|
||||
K7/b7/1b6/1b6/8/8/8/k6B b - - 0 1
|
||||
k7/B7/1B6/1B6/8/8/8/K6b w - - 0 1
|
||||
K7/b7/1b6/1b6/8/8/8/k6B w - - 0 1
|
||||
K7/p7/k7/8/8/8/8/8 b - - 0 1
|
||||
K7/p7/k7/8/8/8/8/8 w - - 0 1
|
||||
n1n5/1Pk5/8/8/8/8/5Kp1/5N1N b - - 0 1
|
||||
n1n5/1Pk5/8/8/8/8/5Kp1/5N1N w - - 0 1
|
||||
n1n5/PPPk4/8/8/8/8/4Kppp/5N1N b - - 0 1
|
||||
n1n5/PPPk4/8/8/8/8/4Kppp/5N1N w - - 0 1
|
||||
r1bqkbnr/pppppppp/n7/8/8/P7/1PPPPPPP/RNBQKBNR w KQkq - 2 2
|
||||
r3k1r1/8/8/8/8/8/8/R3K2R b KQq - 0 1
|
||||
r3k1r1/8/8/8/8/8/8/R3K2R w KQq - 0 1
|
||||
r3k2r/1b4bq/8/8/8/8/7B/R3K2R w KQkq - 0 1
|
||||
r3k2r/8/3Q4/8/8/5q2/8/R3K2R b KQkq - 0 1
|
||||
r3k2r/8/8/8/8/8/8/1R2K2R b Kkq - 0 1
|
||||
r3k2r/8/8/8/8/8/8/1R2K2R w Kkq - 0 1
|
||||
r3k2r/8/8/8/8/8/8/2R1K2R b Kkq - 0 1
|
||||
r3k2r/8/8/8/8/8/8/2R1K2R w Kkq - 0 1
|
||||
r3k2r/8/8/8/8/8/8/4K3 b kq - 0 1
|
||||
r3k2r/8/8/8/8/8/8/4K3 w kq - 0 1
|
||||
r3k2r/8/8/8/8/8/8/R3K1R1 b Qkq - 0 1
|
||||
r3k2r/8/8/8/8/8/8/R3K1R1 w Qkq - 0 1
|
||||
r3k2r/8/8/8/8/8/8/R3K2R b KQkq - 0 1
|
||||
r3k2r/8/8/8/8/8/8/R3K2R w KQkq - 0 1
|
||||
r3k2r/p1pp1pb1/bn2Qnp1/2qPN3/1p2P3/2N5/PPPBBPPP/R3K2R b KQkq - 3 2
|
||||
r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 1
|
||||
r3k3/1K6/8/8/8/8/8/8 b q - 0 1
|
||||
r3k3/1K6/8/8/8/8/8/8 w q - 0 1
|
||||
r3k3/8/8/8/8/8/8/4K3 b q - 0 1
|
||||
r3k3/8/8/8/8/8/8/4K3 w q - 0 1
|
||||
r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10
|
||||
r6r/1b2k1bq/8/8/7B/8/8/R3K2R b KQ - 3 2
|
||||
R6r/8/8/2K5/5k2/8/8/r6R b - - 0 1
|
||||
R6r/8/8/2K5/5k2/8/8/r6R w - - 0 1
|
||||
rnb2k1r/pp1Pbppp/2p5/q7/2B5/8/PPPQNnPP/RNB1K2R w KQ - 3 9
|
||||
rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8
|
||||
rnbqkb1r/ppppp1pp/7n/4Pp2/8/8/PPPP1PPP/RNBQKBNR w KQkq f6 0 3
|
||||
rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
|
|
@ -0,0 +1,43 @@
|
|||
import sys
|
||||
import timeit
|
||||
from pathlib import Path
|
||||
from argparse import Namespace, ArgumentParser
|
||||
from compare_positions import main as test
|
||||
|
||||
|
||||
|
||||
def main(args: Namespace) -> int:
|
||||
print("[S] Starting test suite")
|
||||
successful = []
|
||||
failed = []
|
||||
positions = args.positions.read_text().splitlines()
|
||||
start = timeit.default_timer()
|
||||
longest_fen = max(sorted([len(fen) for fen in positions]))
|
||||
for i, fen in enumerate(positions):
|
||||
fen = fen.strip(" ")
|
||||
fen += " " * (longest_fen - len(fen))
|
||||
sys.stdout.write(f"\r[S] Testing {fen} ({i + 1}/{len(positions)})\033[K")
|
||||
args.fen = fen
|
||||
args.silent = not args.no_silent
|
||||
if test(args) == 0:
|
||||
successful.append(fen)
|
||||
else:
|
||||
failed.append(fen)
|
||||
stop = timeit.default_timer()
|
||||
print(f"\r[S] Ran {len(positions)} tests at depth {args.ply} in {stop - start:.2f} seconds ({len(successful)} successful, {len(failed)} failed)\033[K")
|
||||
if failed and args.show_failures:
|
||||
print("[S] The following FENs failed to pass the test:", end="\n\t")
|
||||
print("\n\t".join(failed))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser(description="Run a set of tests using compare_positions.py")
|
||||
parser.add_argument("--ply", "-d", type=int, required=True, help="The depth to stop at, expressed in plys (half-moves)")
|
||||
parser.add_argument("--bulk", action="store_true", help="Enable bulk-counting for Nimfish (much faster)", default=False)
|
||||
parser.add_argument("--stockfish", type=Path, help="Path to the stockfish executable. Defaults to '' (detected automatically)", default=None)
|
||||
parser.add_argument("--nimfish", type=Path, help="Path to the nimfish executable. Defaults to '' (detected automatically)", default=None)
|
||||
parser.add_argument("--positions", type=Path, help="Location of the file containing FENs to test, one per line. Defaults to 'tests/positions.txt'",
|
||||
default=Path("tests/positions.txt"))
|
||||
parser.add_argument("--no-silent", action="store_true", help="Do not suppress output from compare_positions.py (defaults)", default=False)
|
||||
parser.add_argument("--show-failures", action="store_true", help="Show which FENs failed to pass the test", default=False)
|
||||
sys.exit(main(parser.parse_args()))
|
|
@ -28,9 +28,9 @@ export multibyte
|
|||
type
|
||||
TileKind* = enum
|
||||
## A tile enumeration kind
|
||||
Empty = 0,
|
||||
Self,
|
||||
Enemy
|
||||
TileEmpty = 0,
|
||||
TileX,
|
||||
TileO
|
||||
GameStatus* = enum
|
||||
## A game status enumeration
|
||||
Playing,
|
||||
|
@ -185,18 +185,35 @@ proc get*(self: TicTacToeGame): GameStatus =
|
|||
return Playing
|
||||
|
||||
|
||||
proc winner*(self: TicTacToeGame): TileKind =
|
||||
## Returns the tile of the winner (TileEmpty
|
||||
## is returned if the game is still in progress
|
||||
## or ended in a draw)
|
||||
let status = self.get()
|
||||
if status in [Playing, Draw]:
|
||||
return TileEmpty
|
||||
if status == WinX:
|
||||
return TileX
|
||||
return TileO
|
||||
|
||||
|
||||
proc `$`*(self: TileKind): string =
|
||||
case self:
|
||||
of TileEmpty:
|
||||
return "_"
|
||||
of TileX:
|
||||
return "X"
|
||||
of TileO:
|
||||
return "O"
|
||||
|
||||
|
||||
proc `$`*(self: TicTacToeGame): string =
|
||||
## Stringifies self
|
||||
result &= "-----------\n"
|
||||
for i, row in self.map:
|
||||
result &= "| "
|
||||
for j, e in row:
|
||||
if e == 0:
|
||||
result &= "_"
|
||||
elif e == 1:
|
||||
result &= "X"
|
||||
else:
|
||||
result &= "O"
|
||||
result &= $TileKind(e)
|
||||
if j == 2:
|
||||
result &= " |"
|
||||
else:
|
||||
|
@ -262,12 +279,12 @@ proc generateMoves*(map: Matrix[uint8], turn: TileKind, depth: int = 0): Move =
|
|||
return
|
||||
for row in 0..<map.shape.rows:
|
||||
for col in 0..<map.shape.cols:
|
||||
if TileKind(result.state[row, col]) == Empty:
|
||||
if TileKind(result.state[row, col]) == TileEmpty:
|
||||
var copy = result.state.copy()
|
||||
copy[row, col] = turn.uint8()
|
||||
let index = row * map.shape.cols + col
|
||||
new(result.next[index])
|
||||
result.next[index] = generateMoves(copy, if turn == Self: Enemy else: Self, depth + 1)
|
||||
result.next[index] = generateMoves(copy, if turn == TileX: TileO else: TileX, depth + 1)
|
||||
|
||||
|
||||
# This variant is suboptimal, but useful to build a bot that isn't
|
||||
|
@ -309,22 +326,24 @@ proc findBest*(tree: Move, map: Matrix[int]): Move =
|
|||
]#
|
||||
|
||||
|
||||
proc findBest*(tree: Move, maximize: bool = true, skip: int = 0): Choice =
|
||||
## Finds the best possible move in the
|
||||
## given playing field using minimax
|
||||
## tree search. The first skip best
|
||||
## results (default 0) are skipped.
|
||||
if tree.outcome == WinX:
|
||||
return Choice(move: tree, weight: 10 - tree.depth)
|
||||
elif tree.outcome == WinO:
|
||||
return Choice(move: tree, weight: -10 + tree.depth)
|
||||
elif tree.outcome == Draw:
|
||||
proc findBest*(tree: Move, maximize: bool = true, skip: int = 0, turn: TileKind): Choice =
|
||||
## Finds the best possible move for the given
|
||||
## turn in the given playing field using minimax
|
||||
## tree search. The first skip best results
|
||||
## (default 0) are skipped.
|
||||
if tree.outcome == Draw:
|
||||
return Choice(move: tree, weight: 0)
|
||||
let winner = tree.state.asGame().winner()
|
||||
if winner == turn:
|
||||
return Choice(move: tree, weight: 10 + tree.depth)
|
||||
elif winner != TileEmpty:
|
||||
# Means the other side won
|
||||
return Choice(move: tree, weight: -10 + tree.depth)
|
||||
var choices: seq[Choice] = @[]
|
||||
for i in 0..8:
|
||||
if tree.next[i].isNil():
|
||||
continue
|
||||
choices.add(tree.next[i].findBest(maximize=not maximize))
|
||||
choices.add(tree.next[i].findBest(maximize=not maximize, turn=if turn == TileX: TileO else: TileX))
|
||||
choices[^1].move = tree.next[i]
|
||||
var best: Choice
|
||||
var bestWeight: int = 100
|
||||
|
|
|
@ -32,58 +32,71 @@ template clearScreen =
|
|||
setCursorPos(0, 0)
|
||||
|
||||
|
||||
proc play(treeA, treeB: Move) =
|
||||
proc play(moves: Move) =
|
||||
## Plays a game of tic tac toe
|
||||
## against the user
|
||||
clearScreen()
|
||||
var game = newTicTacToe()
|
||||
var moves = treeA
|
||||
var moves = moves
|
||||
var location: tuple[row, col: int]
|
||||
var index: int
|
||||
var self, enemy: TileKind
|
||||
stdout.styledWrite(fgGreen, styleBright, "Wanna start first? ", fgYellow ,"[Y/n] ")
|
||||
if readLine(stdin).strip(chars={'\n'}).toLowerAscii() in ["n", "no"]:
|
||||
moves = treeB
|
||||
location = where(moves.state, moves.state != game.map, 3).index(Self.uint8)
|
||||
game.place(Self, location.row, location.col)
|
||||
if readLine(stdin).strip(chars={'\n'}).toLowerAscii() notin ["y", "yes"]:
|
||||
self = TileO
|
||||
enemy = TileX
|
||||
else:
|
||||
self = TileX
|
||||
enemy = TileO
|
||||
location = where(moves.state, moves.state != game.map, 3).index(self.uint8)
|
||||
game.place(self, location.row, location.col)
|
||||
clearScreen()
|
||||
styledEcho fgCyan, styleBright, "Computer chose ", fgYellow, $game.map.getIndex(location.row, location.col)
|
||||
else:
|
||||
clearScreen()
|
||||
clearScreen()
|
||||
while game.get() == Playing:
|
||||
styledEcho fgBlue, styleBright, "Tic Tac Bot v1.0"
|
||||
echo game, "\n"
|
||||
styledEcho fgMagenta, styleBright, "You are ", fgBlue, "O"
|
||||
stdout.styledWrite(fgRed, styleBright, "Make your move ", fgBlue, "(", fgYellow, "0", fgGreen, "~", fgYellow, "8", fgBlue, ")", fgRed, ": ")
|
||||
styledEcho fgMagenta, styleBright, "You are ", fgBlue, $TileKind(enemy)
|
||||
stdout.styledWrite(fgRed, styleBright, "Make your move ", fgBlue, "(", fgYellow, "1", fgGreen, "~", fgYellow, "8", fgBlue, ")", fgRed, ": ")
|
||||
flushFile(stdout)
|
||||
try:
|
||||
index = int(parseBiggestInt(readLine(stdin).strip(chars={'\n'})))
|
||||
location = ind2sub(index, game.map.shape)
|
||||
dec(index)
|
||||
except ValueError:
|
||||
clearScreen()
|
||||
styledEcho fgRed, styleBright, "Invalid move"
|
||||
continue
|
||||
if index notin 0..8 or TileKind(game.map[location.row, location.col]) != Empty:
|
||||
if index notin 0..8 or TileKind(game.map[location.row, location.col]) != TileEmpty:
|
||||
clearScreen()
|
||||
styledEcho fgRed, styleBright, "Invalid move"
|
||||
continue
|
||||
game.place(Enemy, location.row, location.col)
|
||||
game.place(enemy, location.row, location.col)
|
||||
clearScreen()
|
||||
if game.get() == WinO:
|
||||
if game.winner() == enemy:
|
||||
echo game, "\n"
|
||||
styledEcho fgGreen, styleBright, "Human wins!"
|
||||
return
|
||||
elif game.get() == Draw:
|
||||
break
|
||||
moves = moves.next[game.map.getIndex(location.row, location.col)].findBest(true).move
|
||||
location = where(moves.state, moves.state != game.map, 3).index(Self.uint8)
|
||||
game.place(Self, location.row, location.col)
|
||||
clearScreen()
|
||||
if game.get() != Draw:
|
||||
styledEcho fgCyan, styleBright, "Computer chose ", fgYellow, $game.map.getIndex(location.row, location.col)
|
||||
if game.get() == WinX:
|
||||
if game.winner() == self:
|
||||
echo game, "\n"
|
||||
styledEcho fgRed, styleBright, "Computer wins!"
|
||||
return
|
||||
if game.get() == Draw:
|
||||
break
|
||||
# Find best move and advance move tree
|
||||
moves = moves.next[game.map.getIndex(location.row, location.col)].findBest(true, turn=self).move
|
||||
location = where(moves.state, moves.state != game.map, 3).index(self.uint8)
|
||||
game.place(self, location.row, location.col)
|
||||
clearScreen()
|
||||
if game.get() != Draw:
|
||||
styledEcho fgCyan, styleBright, "Computer chose ", fgYellow, $(game.map.getIndex(location.row, location.col) + 1)
|
||||
|
||||
if game.winner() == enemy:
|
||||
echo game, "\n"
|
||||
styledEcho fgGreen, styleBright, "Human wins!"
|
||||
return
|
||||
if game.get() == Draw:
|
||||
break
|
||||
echo game
|
||||
styledEcho fgYellow, styleBright, "It's a draw!"
|
||||
|
||||
|
@ -102,46 +115,32 @@ when isMainModule:
|
|||
var path = getCacheDir() / "ttb"
|
||||
path.createDir()
|
||||
path = path / "cache.bin"
|
||||
var movesA: Move
|
||||
var movesB: Move
|
||||
# Since generating two full trees is pretty expensive, we cache
|
||||
# them the first time we generate them (since it's not like they
|
||||
# change anyway) so that we don't have to regenerate them every time
|
||||
var moves: Move
|
||||
# Since generating the tree is pretty expensive, we cache
|
||||
# it the first time we generate it (since it's not like it changes
|
||||
# anyway) so that we don't have to rebuild it every time
|
||||
if not fileExists(path):
|
||||
styledEcho fgCyan, styleBright, "Generating move trees..."
|
||||
# Sadly we need to generate two trees for both cases where either we or
|
||||
# our opponent have the first turn, as the states between them are not
|
||||
# interchangeable at all (trust me, I tried)
|
||||
movesA = generateMoves(build(@[uint8(0), 0, 0, 0, 0, 0, 0, 0, 0]).map, Enemy)
|
||||
movesB = generateMoves(build(@[uint8(0), 0, 0, 0, 0, 0, 0, 0, 0]).map, Self)
|
||||
styledEcho fgCyan, styleBright, "Caching results to disk..."
|
||||
styledEcho fgCyan, styleBright, "Generating move tree"
|
||||
moves = generateMoves(build(@[uint8(0), 0, 0, 0, 0, 0, 0, 0, 0]).map, TileX)
|
||||
styledEcho fgCyan, styleBright, "Caching data to disk..."
|
||||
var fp = open(path, fmWrite)
|
||||
discard fp.writeBytes(movesA.dumpBytes(), 0, 8799135)
|
||||
discard fp.writeBytes(movesB.dumpBytes(), 0, 8799135)
|
||||
discard fp.writeBytes(moves.dumpBytes(), 0, 8799135)
|
||||
fp.close()
|
||||
else:
|
||||
styledEcho fgCyan, styleBright, "Loading previously cached move trees..."
|
||||
styledEcho fgCyan, styleBright, "Loading previously cached move tree"
|
||||
var fp = open(path, fmRead)
|
||||
var data: seq[byte] = @[]
|
||||
for _ in 0..<8799135:
|
||||
data.add(byte(0))
|
||||
discard fp.readBytes(data, 0, 8799135)
|
||||
movesA = data.loadBytes()
|
||||
discard fp.readBytes(data, 0, 8799135)
|
||||
movesB = data.loadBytes()
|
||||
moves = data.loadBytes()
|
||||
fp.close()
|
||||
# Here we pick one of the first 5 best moves so that the bot doesn't
|
||||
# always start with an X in the left corner when it's playing first
|
||||
var best: seq[Move] = @[]
|
||||
for i in 0..4:
|
||||
best.add(movesB.findBest(true, i).move)
|
||||
movesB = sample(best)
|
||||
while true:
|
||||
try:
|
||||
play(movesA, movesB)
|
||||
play(moves)
|
||||
stdout.styledWrite(fgGreen, styleBright, "Again? ", fgYellow ,"[Y/n] ")
|
||||
flushFile(stdout)
|
||||
if readLine(stdin).strip(chars={'\n'}).toLowerAscii() in ["no", "n"]:
|
||||
if readLine(stdin).strip(chars={'\n'}).toLowerAscii() notin ["no", "n"]:
|
||||
break
|
||||
except IOError:
|
||||
break
|
||||
|
|
|
@ -995,30 +995,34 @@ when isMainModule:
|
|||
|
||||
var m = newMatrix[int](@[@[1, 2, 3], @[4, 5, 6]])
|
||||
var k = m.transpose()
|
||||
assert k[2, 1] == m[1, 2], "transpose mismatch"
|
||||
assert all(m.transpose() == k), "transpose mismatch"
|
||||
assert k.sum() == m.sum(), "element sum mismatch"
|
||||
assert all(k.sum(axis=1) == m.sum(axis=0)), "sum over axis mismatch"
|
||||
assert all(k.sum(axis=0) == m.sum(axis=1)), "sum over axis mismatch"
|
||||
doAssert k[2, 1] == m[1, 2], "transpose mismatch"
|
||||
doAssert all(m.transpose() == k), "transpose mismatch"
|
||||
doAssert k.sum() == m.sum(), "element sum mismatch"
|
||||
doAssert all(k.sum(axis=1) == m.sum(axis=0)), "sum over axis mismatch"
|
||||
doAssert all(k.sum(axis=0) == m.sum(axis=1)), "sum over axis mismatch"
|
||||
var y = newMatrix[int](@[1, 2, 3, 4])
|
||||
assert y.sum() == 10, "element sum mismatch"
|
||||
assert (y + y).sum() == 20, "matrix sum mismatch"
|
||||
assert all(m + m == m * 2), "m + m != m * 2"
|
||||
doAssert y.sum() == 10, "element sum mismatch"
|
||||
doAssert (y + y).sum() == 20, "matrix sum mismatch"
|
||||
doAssert all(m + m == m * 2), "m + m != m * 2"
|
||||
var z = newMatrix[int](@[1, 2, 3])
|
||||
assert (m * z).sum() == 46, "matrix multiplication mismatch"
|
||||
assert all(z * z == z.apply(pow, 2, axis = -1, copy=true)), "matrix multiplication mismatch"
|
||||
doAssert (m * z).sum() == 46, "matrix multiplication mismatch"
|
||||
doAssert all(z * z == z.apply(pow, 2, axis = -1, copy=true)), "matrix multiplication mismatch"
|
||||
var x = newMatrix[int](@[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
assert (x < 5).where(x, x * 10).sum() == 360, "where mismatch"
|
||||
assert all((x < 5).where(x, x * 10) == x.where(x < 5, x * 10)), "where mismatch"
|
||||
assert x.max() == 9, "max mismatch"
|
||||
assert x.argmax() == 10, "argmax mismatch"
|
||||
assert all(newMatrix[int](@[12, 23]).dot(newMatrix[int](@[@[11, 22], @[33, 44]])) == newMatrix[int](@[891, 1276]))
|
||||
assert all(newMatrix[int](@[@[1, 2, 3], @[2, 3, 4]]).dot(newMatrix[int](@[1, 2, 3])) == newMatrix[int](@[14, 20]))
|
||||
assert all(m.diag() == newMatrix[int](@[1, 5]))
|
||||
assert all(m.diag(1) == newMatrix[int](@[2, 6]))
|
||||
assert all(m.diag(2) == newMatrix[int](@[3]))
|
||||
assert m.diag(3).len() == 0
|
||||
doAssert (x < 5).where(x, x * 10).sum() == 360, "where mismatch"
|
||||
doAssert all((x < 5).where(x, x * 10) == x.where(x < 5, x * 10)), "where mismatch"
|
||||
doAssert x.max() == 9, "max mismatch"
|
||||
doAssert x.argmax() == 10, "argmax mismatch"
|
||||
doAssert all(newMatrix[int](@[12, 23]).dot(newMatrix[int](@[@[11, 22], @[33, 44]])) == newMatrix[int](@[891, 1276]))
|
||||
doAssert all(newMatrix[int](@[@[1, 2, 3], @[2, 3, 4]]).dot(newMatrix[int](@[1, 2, 3])) == newMatrix[int](@[14, 20]))
|
||||
doAssert all(m.diag() == newMatrix[int](@[1, 5]))
|
||||
doAssert all(m.diag(1) == newMatrix[int](@[2, 6]))
|
||||
doAssert all(m.diag(2) == newMatrix[int](@[3]))
|
||||
doAssert m.diag(3).len() == 0
|
||||
var j = m.fliplr()
|
||||
assert all(j.diag() == newMatrix[int](@[3, 5]))
|
||||
assert all(j.diag(1) == newMatrix[int](@[2, 4]))
|
||||
assert all(j.diag(2) == newMatrix[int](@[1]))
|
||||
doAssert all(j.diag() == newMatrix[int](@[3, 5]))
|
||||
doAssert all(j.diag(1) == newMatrix[int](@[2, 4]))
|
||||
doAssert all(j.diag(2) == newMatrix[int](@[1]))
|
||||
# A little test for the softmax function
|
||||
var mat = newMatrix[float](@[123.0, 456.0, 789.0])
|
||||
mat = mat - mat.max()
|
||||
doAssert (mat.apply(math.exp, axis = -1) / sum(mat.apply(math.exp, axis = -1))).sum() == 1.0
|
Loading…
Reference in New Issue