Switch to static move list and print nps after perft completion

This commit is contained in:
Mattia Giambirtone 2024-04-16 09:05:35 +02:00
parent aeaa57aba6
commit 48e2adddc6
2 changed files with 59 additions and 27 deletions

View File

@ -16,6 +16,7 @@ import std/strutils
import std/strformat
import std/times
import std/math
from std/lenientops import `/` # Only needed for perft
type
@ -73,6 +74,11 @@ type
targetSquare*: Location
flags*: uint16
MoveList* = object
## A list of moves
data: array[218, Move]
len: int8
Position* = ref object
## A chess position
@ -112,10 +118,31 @@ type
positions: seq[Position]
iterator items(self: MoveList): Move =
var i = 0
while self.len > i:
yield self.data[i]
inc(i)
func add(self: var MoveList, move: Move) {.inline.} =
self.data[self.len] = move
inc(self.len)
func contains(self: MoveList, move: Move): bool {.inline.} =
for item in self:
if move == item:
return true
return false
func len(self: MoveList): int {.inline.} = self.len
# A bunch of simple utility functions and forward declarations
func emptyPiece*: Piece {.inline.} = Piece(kind: Empty, color: None)
func emptyLocation*: Location {.inline.} = (-1 , -1)
func emptyLocation*: Location {.inline.} = (-1, -1)
func opposite*(c: PieceColor): PieceColor {.inline.} = (if c == White: Black else: White)
proc algebraicToLocation*(s: string): Location {.inline.}
proc makeMove*(self: ChessBoard, move: Move): Move {.discardable.}
@ -124,7 +151,7 @@ func `+`*(a, b: Location): Location = (a.row + b.row, a.col + b.col)
func `-`*(a: Location): Location = (-a.row, -a.col)
func `-`*(a, b: Location): Location = (a.row - b.row, a.col - b.col)
func isValid*(a: Location): bool {.inline.} = a.row in 0..7 and a.col in 0..7
proc generateMoves(self: ChessBoard, location: Location): seq[Move]
proc generateMoves(self: ChessBoard, location: Location, moves: var MoveList)
proc getAttackers*(self: ChessBoard, loc: Location, color: PieceColor): seq[Location]
proc getAttackFor*(self: ChessBoard, source, target: Location): tuple[source, target, direction: Location]
proc isAttacked*(self: ChessBoard, loc: Location, color: PieceColor = None): bool
@ -829,7 +856,7 @@ proc getCheckResolutions(self: ChessBoard, color: PieceColor): seq[Location] =
result.add(location)
proc generatePawnMoves(self: ChessBoard, location: Location): seq[Move] =
proc generatePawnMoves(self: ChessBoard, location: Location, moveList: var MoveList) =
## Generates the possible moves for the pawn in the given
## location
var
@ -911,12 +938,12 @@ proc generatePawnMoves(self: ChessBoard, location: Location): seq[Move] =
if target.row == piece.color.getLastRow():
# Pawn reached the other side of the board: generate all potential piece promotions
for promotionType in [PromoteToKnight, PromoteToBishop, PromoteToRook, PromoteToQueen]:
result.add(Move(startSquare: location, targetSquare: target, flags: promotionType.uint16 or flags))
moveList.add(Move(startSquare: location, targetSquare: target, flags: promotionType.uint16 or flags))
continue
result.add(Move(startSquare: location, targetSquare: target, flags: flags))
moveList.add(Move(startSquare: location, targetSquare: target, flags: flags))
proc generateSlidingMoves(self: ChessBoard, location: Location): seq[Move] =
proc generateSlidingMoves(self: ChessBoard, location: Location, moves: var MoveList) =
## Generates moves for the sliding piece in the given location
let piece = self.grid[location.row, location.col]
assert piece.kind in [Bishop, Rook, Queen], &"generateSlidingMoves called on a {piece.kind}"
@ -972,13 +999,13 @@ proc generateSlidingMoves(self: ChessBoard, location: Location): seq[Move] =
# it and stop going any further
if otherPiece.kind != King:
# Can't capture the king
result.add(Move(startSquare: location, targetSquare: square, flags: Capture.uint16))
moves.add(Move(startSquare: location, targetSquare: square, flags: Capture.uint16))
break
# Target square is empty, keep going
result.add(Move(startSquare: location, targetSquare: square))
moves.add(Move(startSquare: location, targetSquare: square))
proc generateKingMoves(self: ChessBoard, location: Location): seq[Move] =
proc generateKingMoves(self: ChessBoard, location: Location, moves: var MoveList) =
## Generates moves for the king in the given location
var
piece = self.grid[location.row, location.col]
@ -1020,10 +1047,10 @@ proc generateKingMoves(self: ChessBoard, location: Location): seq[Move] =
continue
# Target square is empty or contains an enemy piece:
# All good for us!
result.add(Move(startSquare: location, targetSquare: square, flags: flag.uint16))
moves.add(Move(startSquare: location, targetSquare: square, flags: flag.uint16))
proc generateKnightMoves(self: ChessBoard, location: Location): seq[Move] =
proc generateKnightMoves(self: ChessBoard, location: Location, moves: var MoveList) =
## Generates moves for the knight in the given location
var
piece = self.grid[location.row, location.col]
@ -1039,7 +1066,7 @@ proc generateKnightMoves(self: ChessBoard, location: Location): seq[Move] =
let pinned = self.getPinnedDirections(location)
if pinned.len() > 0:
# Knight is pinned: can't move!
return @[]
return
let checked = self.inCheck()
let resolutions = if not checked: @[] else: self.getCheckResolutions(piece.color)
for direction in directions:
@ -1057,37 +1084,38 @@ proc generateKnightMoves(self: ChessBoard, location: Location): seq[Move] =
if otherPiece.color != None:
# Target square contains an enemy piece: capture
# it
result.add(Move(startSquare: location, targetSquare: square, flags: Capture.uint16))
moves.add(Move(startSquare: location, targetSquare: square, flags: Capture.uint16))
else:
# Target square is empty
result.add(Move(startSquare: location, targetSquare: square))
moves.add(Move(startSquare: location, targetSquare: square))
proc generateMoves(self: ChessBoard, location: Location): seq[Move] =
proc generateMoves(self: ChessBoard, location: Location, moves: var MoveList) =
## Returns the list of possible legal chess moves for the
## piece in the given location
let piece = self.grid[location.row, location.col]
case piece.kind:
of Queen, Bishop, Rook:
return self.generateSlidingMoves(location)
self.generateSlidingMoves(location, moves)
of Pawn:
return self.generatePawnMoves(location)
self.generatePawnMoves(location, moves)
of King:
return self.generateKingMoves(location)
self.generateKingMoves(location, moves)
of Knight:
return self.generateKnightMoves(location)
self.generateKnightMoves(location, moves)
else:
return @[]
discard
proc generateAllMoves*(self: ChessBoard): seq[Move] =
proc generateAllMoves*(self: ChessBoard): MoveList =
## Returns the list of all possible legal moves
## in the current position
var data: array[218, Move]
result = MoveList(len: 0, data: data)
for i in 0..7:
for j in 0..7:
if self.grid[i, j].color == self.getActiveColor():
for move in self.generateMoves((int8(i), int8(j))):
result.add(move)
self.generateMoves((int8(i), int8(j)), result)
proc isAttacked*(self: ChessBoard, loc: Location, color: PieceColor = None): bool =
@ -1716,7 +1744,9 @@ proc undoLastMove*(self: ChessBoard) =
proc isLegal(self: ChessBoard, move: Move): bool {.inline.} =
## Returns whether the given move is legal
return move in self.generateMoves(move.startSquare)
var moves = MoveList()
self.generateMoves(move.startSquare, moves)
return move in moves
proc makeMove*(self: ChessBoard, move: Move): Move {.discardable.} =
@ -1999,11 +2029,13 @@ proc handleGoCommand(board: ChessBoard, command: seq[string]) =
if bulk:
let t = cpuTime()
let nodes = board.perft(ply, divide=true, bulk=true, verbose=verbose).nodes
let tot = cpuTime() - t
echo &"\nNodes searched (bulk-counting: on): {nodes}"
echo &"Time taken: {round(cpuTime() - t, 3)} seconds\n"
echo &"Time taken: {round(tot, 3)} seconds\nNodes per second: {round(nodes / tot).uint64}"
else:
let t = cpuTime()
let data = board.perft(ply, divide=true, verbose=verbose)
let tot = cpuTime() - t
echo &"\nNodes searched (bulk-counting: off): {data.nodes}"
echo &" - Captures: {data.captures}"
echo &" - Checks: {data.checks}"
@ -2012,7 +2044,7 @@ proc handleGoCommand(board: ChessBoard, command: seq[string]) =
echo &" - Castles: {data.castles}"
echo &" - Promotions: {data.promotions}"
echo ""
echo &"Time taken: {round(cpuTime() - t, 3)} seconds"
echo &"Time taken: {tot} seconds\nNodes per second: {round(data.nodes / tot).uint64}"
except ValueError:
echo "Error: go: perft: invalid depth"
else:

View File

@ -26,7 +26,7 @@ def main(args: Namespace) -> int:
stop = timeit.default_timer()
print(f"\r[S] Ran {len(positions)} tests at depth {args.ply} in {stop - start:.2f} seconds ({len(successful)} successful, {len(failed)} failed)\033[K")
if failed and args.show_failures:
print("[S] The following FENs failed to pass the test:", end="")
print("[S] The following FENs failed to pass the test:", end="\n\t")
print("\n\t".join(failed))