Hopefully fix pondering bugs

This commit is contained in:
Mattia Giambirtone 2024-05-13 17:11:24 +02:00
parent 6adae84bca
commit d9e183caf3
Signed by: nocturn9x
GPG Key ID: 37B83AB6C3BE6514
4 changed files with 63 additions and 71 deletions

View File

@ -110,12 +110,7 @@ func getOccupancy*(self: Position): Bitboard {.inline.} =
proc getPawnAttacks*(self: Position, square: Square, attacker: PieceColor): Bitboard {.inline.} = proc getPawnAttacks*(self: Position, square: Square, attacker: PieceColor): Bitboard {.inline.} =
## Returns the locations of the pawns attacking the given square ## Returns the locations of the pawns attacking the given square
try: return self.getBitboard(Pawn, attacker) and getPawnAttacks(attacker, square)
return self.getBitboard(Pawn, attacker) and getPawnAttacks(attacker, square)
except IndexDefect:
echo square
echo square.int
echo square.toBitboard()
proc getKingAttacks*(self: Position, square: Square, attacker: PieceColor): Bitboard {.inline.} = proc getKingAttacks*(self: Position, square: Square, attacker: PieceColor): Bitboard {.inline.} =

View File

@ -107,9 +107,10 @@ type
SearchManager* = object SearchManager* = object
## A simple state storage ## A simple state storage
## for our search ## for our search
searchFlag: ptr Atomic[bool] searching: Atomic[bool]
stopFlag: ptr Atomic[bool] stop: Atomic[bool]
board: Chessboard pondering: Atomic[bool]
board*: Chessboard
bestRootScore: Score bestRootScore: Score
searchStart: MonoTime searchStart: MonoTime
hardLimit: MonoTime hardLimit: MonoTime
@ -126,29 +127,29 @@ type
selectiveDepth: int selectiveDepth: int
proc newSearchManager*(position: Position, positions: seq[Position], transpositions: ptr TTable, stopFlag, searchFlag: ptr Atomic[bool], proc newSearchManager*(position: Position, positions: seq[Position], transpositions: ptr TTable,
history: ptr HistoryTable, killers: ptr KillersTable): SearchManager = history: ptr HistoryTable, killers: ptr KillersTable): SearchManager =
var board = newChessboard() result = SearchManager(board: newChessboard(), transpositionTable: transpositions, stop: Atomic[bool](),
board.position = position searching: Atomic[bool](), pondering: Atomic[bool](), history: history,
board.positions = positions killers: killers)
result = SearchManager(board: board, transpositionTable: transpositions, stopFlag: stopFlag, result.board.position = position
searchFlag: searchFlag, history: history, killers: killers) result.board.positions = positions
for i in 0..MAX_DEPTH: for i in 0..MAX_DEPTH:
for j in 0..MAX_DEPTH: for j in 0..MAX_DEPTH:
result.pvMoves[i][j] = nullMove() result.pvMoves[i][j] = nullMove()
proc isSearching*(self: SearchManager): bool = proc isSearching*(self: var SearchManager): bool =
## Returns whether a search for the best ## Returns whether a search for the best
## move is in progress ## move is in progress
result = self.searchFlag[].load() result = self.searching.load()
proc stop*(self: SearchManager) = proc stop*(self: var SearchManager) =
## Stops the search if it is ## Stops the search if it is
## running ## running
if self.isSearching(): if self.isSearching():
self.stopFlag[].store(true) self.stop.store(true)
proc isKillerMove(self: SearchManager, move: Move, ply: int): bool = proc isKillerMove(self: SearchManager, move: Move, ply: int): bool =
@ -246,7 +247,9 @@ proc reorderMoves(self: SearchManager, moves: var MoveList, ply: int) =
proc timedOut(self: SearchManager): bool = getMonoTime() >= self.hardLimit proc timedOut(self: SearchManager): bool = getMonoTime() >= self.hardLimit
proc cancelled(self: SearchManager): bool = self.stopFlag[].load() proc isPondering*(self: var SearchManager): bool = self.pondering.load()
proc stopPondering*(self: var SearchManager) = self.pondering.store(false)
proc cancelled(self: var SearchManager): bool = self.stop.load()
proc elapsedTime(self: SearchManager): int64 = (getMonoTime() - self.searchStart).inMilliseconds() proc elapsedTime(self: SearchManager): int64 = (getMonoTime() - self.searchStart).inMilliseconds()
@ -272,13 +275,13 @@ proc log(self: SearchManager, depth: int) =
echo logMsg echo logMsg
proc shouldStop(self: SearchManager): bool = proc shouldStop(self: var SearchManager): bool =
## Returns whether searching should ## Returns whether searching should
## stop ## stop
if self.cancelled(): if self.cancelled():
# Search has been cancelled! # Search has been cancelled!
return true return true
if self.timedOut(): if self.timedOut() and not self.isPondering():
# We ran out of time! # We ran out of time!
return true return true
if self.maxNodes > 0 and self.nodeCount >= self.maxNodes: if self.maxNodes > 0 and self.nodeCount >= self.maxNodes:
@ -644,6 +647,7 @@ proc findBestLine*(self: var SearchManager, timeRemaining, increment: int64, max
result = @[] result = @[]
var pv: array[256, Move] var pv: array[256, Move]
self.maxNodes = maxNodes self.maxNodes = maxNodes
self.pondering.store(ponder)
self.searchMoves = searchMoves self.searchMoves = searchMoves
self.searchStart = getMonoTime() self.searchStart = getMonoTime()
self.hardLimit = self.searchStart + initDuration(milliseconds=maxSearchTime) self.hardLimit = self.searchStart + initDuration(milliseconds=maxSearchTime)
@ -651,7 +655,7 @@ proc findBestLine*(self: var SearchManager, timeRemaining, increment: int64, max
var maxDepth = maxDepth var maxDepth = maxDepth
if maxDepth == -1: if maxDepth == -1:
maxDepth = 30 maxDepth = 30
self.searchFlag[].store(true) self.searching.store(true)
# Iterative deepening loop # Iterative deepening loop
var score = Score(0) var score = Score(0)
for depth in 1..min(MAX_DEPTH, maxDepth): for depth in 1..min(MAX_DEPTH, maxDepth):
@ -668,10 +672,10 @@ proc findBestLine*(self: var SearchManager, timeRemaining, increment: int64, max
# Soft time management: don't start a new search iteration # Soft time management: don't start a new search iteration
# if the soft limit has expired, as it is unlikely to complete # if the soft limit has expired, as it is unlikely to complete
# anyway # anyway
if getMonoTime() >= self.softLimit: if getMonoTime() >= self.softLimit and not self.isPondering():
break break
self.searchFlag[].store(false) self.searching.store(false)
self.stopFlag[].store(false) self.stop.store(false)
for move in pv: for move in pv:
if move == nullMove(): if move == nullMove():
break break

View File

@ -81,7 +81,7 @@ proc newTranspositionTable*(size: uint64): TTable =
result.size = numEntries result.size = numEntries
func clear*(self: var TTable) = func clear*(self: var TTable) {.inline.} =
## Clears the transposition table ## Clears the transposition table
## without releasing the memory ## without releasing the memory
## associated with it ## associated with it
@ -97,13 +97,6 @@ func resize*(self: var TTable, newSize: uint64) =
self.size = numEntries self.size = numEntries
func destroy*(self: var TTable) =
## Permanently and irreversibly
## destroys the transposition table
self.data = @[]
self.size = 0
func getIndex(self: TTable, key: ZobristKey): uint64 = func getIndex(self: TTable, key: ZobristKey): uint64 =
## Retrieves the index of the given ## Retrieves the index of the given
## zobrist key in our transposition table ## zobrist key in our transposition table

View File

@ -15,7 +15,6 @@
## Implementation of a UCI compatible server ## Implementation of a UCI compatible server
import std/strutils import std/strutils
import std/strformat import std/strformat
import std/atomics
import board import board
@ -38,6 +37,7 @@ type
## doesn't like sharing references across thread (despite ## doesn't like sharing references across thread (despite
## the fact that it should be safe to do so) ## the fact that it should be safe to do so)
searchState: ptr SearchManager searchState: ptr SearchManager
printMove: ptr bool
# Size of the transposition table (in megabytes) # Size of the transposition table (in megabytes)
hashTableSize: uint64 hashTableSize: uint64
# # Atomic boolean flag to interrupt the search # # Atomic boolean flag to interrupt the search
@ -173,6 +173,8 @@ proc handleUCIGoCommand(session: UCISession, command: seq[string]): UCICommand =
of "infinite": of "infinite":
result.wtime = int32.high() result.wtime = int32.high()
result.btime = int32.high() result.btime = int32.high()
of "ponder":
result.ponder = true
of "wtime": of "wtime":
result.wtime = command[current].parseInt() result.wtime = command[current].parseInt()
of "btime": of "btime":
@ -329,7 +331,6 @@ proc bestMove(args: tuple[session: UCISession, command: UCICommand]) {.thread.}
board.position = session.position board.position = session.position
board.positions = session.history board.positions = session.history
let command = args.command let command = args.command
var searcher = session.searchState[]
var var
timeRemaining = (if session.position.sideToMove == White: command.wtime else: command.btime) timeRemaining = (if session.position.sideToMove == White: command.wtime else: command.btime)
increment = (if session.position.sideToMove == White: command.winc else: command.binc) increment = (if session.position.sideToMove == White: command.winc else: command.binc)
@ -339,11 +340,12 @@ proc bestMove(args: tuple[session: UCISession, command: UCICommand]) {.thread.}
increment = 0 increment = 0
elif timeRemaining == 0: elif timeRemaining == 0:
timeRemaining = int32.high() timeRemaining = int32.high()
var line = searcher.findBestLine(timeRemaining, increment, command.depth, command.nodes, command.searchmoves, timePerMove, command.ponder) var line = session.searchState[].findBestLine(timeRemaining, increment, command.depth, command.nodes, command.searchmoves, timePerMove, command.ponder)
if line.len() == 1: if session.printMove[]:
if line.len() == 1:
echo &"bestmove {line[0].toAlgebraic()}" echo &"bestmove {line[0].toAlgebraic()}"
else: else:
echo &"bestmove {line[0].toAlgebraic()} ponder {line[1].toAlgebraic()}" echo &"bestmove {line[0].toAlgebraic()} ponder {line[1].toAlgebraic()}"
proc startUCISession* = proc startUCISession* =
@ -359,14 +361,13 @@ proc startUCISession* =
# God forbid we try to use atomic ARC like it was intended. Raw pointers # God forbid we try to use atomic ARC like it was intended. Raw pointers
# it is then... sigh # it is then... sigh
var var
stopFlag = cast[ptr Atomic[bool]](alloc0(sizeof(Atomic[bool])))
searchFlag = cast[ptr Atomic[bool]](alloc0(sizeof(Atomic[bool])))
transpositionTable = cast[ptr TTable](alloc0(sizeof(TTable))) transpositionTable = cast[ptr TTable](alloc0(sizeof(TTable)))
historyTable = cast[ptr HistoryTable](alloc0(sizeof(HistoryTable))) historyTable = cast[ptr HistoryTable](alloc0(sizeof(HistoryTable)))
killerMoves = cast[ptr KillersTable](alloc0(sizeof(KillersTable))) killerMoves = cast[ptr KillersTable](alloc0(sizeof(KillersTable)))
transpositionTable[] = newTranspositionTable(session.hashTableSize * 1024 * 1024) transpositionTable[] = newTranspositionTable(session.hashTableSize * 1024 * 1024)
session.searchState = cast[ptr SearchManager](alloc0(sizeof(SearchManager))) session.searchState = cast[ptr SearchManager](alloc0(sizeof(SearchManager)))
session.searchState[] = newSearchManager(session.position, session.history, transpositionTable, stopFlag, searchFlag, historyTable, killerMoves) session.searchState[] = newSearchManager(session.position, session.history, transpositionTable, historyTable, killerMoves)
session.printMove = cast[ptr bool](alloc0(sizeof(bool)))
# Initialize history table # Initialize history table
for color in PieceColor.White..PieceColor.Black: for color in PieceColor.White..PieceColor.Black:
for i in Square(0)..Square(63): for i in Square(0)..Square(63):
@ -408,14 +409,9 @@ proc startUCISession* =
of Debug: of Debug:
session.debug = cmd.on session.debug = cmd.on
of NewGame: of NewGame:
if transpositionTable[].size() == 0: if session.debug:
if session.debug: echo &"info string clearing out TT of size {session.hashTableSize} MiB"
echo &"info string allocating new TT of size {session.hashTableSize} MiB" transpositionTable[].clear()
transpositionTable[] = newTranspositionTable(session.hashTableSize * 1024 * 1024)
else:
if session.debug:
echo &"info string clearing out TT of size {session.hashTableSize} MiB"
transpositionTable[].clear()
# Re-Initialize history table # Re-Initialize history table
for color in PieceColor.White..PieceColor.Black: for color in PieceColor.White..PieceColor.Black:
for i in Square(0)..Square(63): for i in Square(0)..Square(63):
@ -430,24 +426,25 @@ proc startUCISession* =
echo "info string ponder move has ben hit" echo "info string ponder move has ben hit"
if not session.searchState[].isSearching(): if not session.searchState[].isSearching():
continue continue
joinThread(searchThread) session.searchState[].stopPondering()
if session.debug: if session.debug:
echo "info string ponder search stopped" echo "info string switched to normal search"
createThread(searchThread, bestMove, (session, cmd))
if session.debug:
echo "info string search started"
of Go: of Go:
when not defined(historyPenalty): session.printMove[] = true
# Scale our history coefficients if not cmd.ponder and session.searchState[].isPondering():
for color in PieceColor.White..PieceColor.Black: session.searchState[].stopPondering()
for source in Square(0)..Square(63): else:
for target in Square(0)..Square(63): when not defined(historyPenalty):
historyTable[color][source][target] = historyTable[color][source][target] div 2 # Scale our history coefficients
if searchThread.running: for color in PieceColor.White..PieceColor.Black:
joinThread(searchThread) for source in Square(0)..Square(63):
createThread(searchThread, bestMove, (session, cmd)) for target in Square(0)..Square(63):
if session.debug: historyTable[color][source][target] = historyTable[color][source][target] div 2
echo "info string search started" if searchThread.running:
joinThread(searchThread)
createThread(searchThread, bestMove, (session, cmd))
if session.debug:
echo "info string search started"
of Stop: of Stop:
if not session.searchState[].isSearching(): if not session.searchState[].isSearching():
continue continue
@ -474,10 +471,13 @@ proc startUCISession* =
else: else:
discard discard
of Position: of Position:
# Due to the way the whole thing is designed, the if session.searchState[].isPondering():
# position is actually set when the command is parsed session.printMove[] = false
# rather than when it is processed here session.searchState[].stopPondering()
discard session.searchState[].stop()
joinThread(searchThread)
session.searchState[].board.position = session.position
session.searchState[].board.positions = session.history
else: else:
discard discard
except IOError: except IOError: