Fix issues with joinThread and fix nps scaling issues by not using atomic node counters

This commit is contained in:
Mattia Giambirtone 2024-05-14 11:20:03 +02:00
parent e54fc56925
commit 5b3c244206
Signed by: nocturn9x
GPG Key ID: B6025DD9B4458B69
2 changed files with 94 additions and 67 deletions

View File

@ -115,8 +115,8 @@ type
searchStart: MonoTime
hardLimit: MonoTime
softLimit: MonoTime
nodeCount: ptr Atomic[uint64]
maxNodes: ptr Atomic[uint64]
nodeCount: uint64
maxNodes: uint64
searchMoves: seq[Move]
transpositionTable: ptr TTable
history: ptr HistoryTable
@ -125,8 +125,14 @@ type
# We keep one extra entry so we don't need any special casing
# inside the search function when constructing pv lines
pvMoves: array[MAX_DEPTH + 1, array[MAX_DEPTH + 1, Move]]
# The highest depth we explored to, including extensions
selectiveDepth: int
# Are we the main worker?
isMainWorker: bool
# We keep track of all the worker
# threads' respective search states
# to collect statistics efficiently
children: seq[ptr SearchManager]
proc newSearchManager*(position: Position, positions: seq[Position], transpositions: ptr TTable,
@ -138,19 +144,15 @@ proc newSearchManager*(position: Position, positions: seq[Position], transpositi
searchFlag: ptr Atomic[bool]
stopFlag: ptr Atomic[bool]
ponderFlag: ptr Atomic[bool]
nodeCounter: ptr Atomic[uint64]
maxNodes: ptr Atomic[uint64]
if mainWorker:
searchFlag = create(Atomic[bool], sizeof(Atomic[bool]))
stopFlag = create(Atomic[bool], sizeof(Atomic[bool]))
ponderFlag = create(Atomic[bool], sizeof(Atomic[bool]))
nodeCounter = create(Atomic[uint64], sizeof(Atomic[uint64]))
maxNodes = create(Atomic[uint64], sizeof(Atomic[uint64]))
# If we're not the main worker, we expect the shared atomic metadata to be filled in by the
# main worker
result = SearchManager(board: newChessboard(), transpositionTable: transpositions, stop: stopFlag,
searching: searchFlag, pondering: ponderFlag, history: history, nodeCount: nodeCounter,
maxNodes: maxNodes, killers: killers, isMainWorker: mainWorker)
searching: searchFlag, pondering: ponderFlag, history: history,
killers: killers, isMainWorker: mainWorker)
result.board.position = position
result.board.positions = positions
for i in 0..MAX_DEPTH:
@ -168,8 +170,6 @@ proc `destroy=`*(self: var SearchManager) =
dealloc(self.stop)
dealloc(self.searching)
dealloc(self.pondering)
dealloc(self.maxNodes)
dealloc(self.nodeCount)
else:
# This state is thread-local and is fine to
# destroy *unless* we're the main worker. This
@ -307,16 +307,23 @@ proc stopPondering*(self: var SearchManager) =
proc log(self: var SearchManager, depth: int) =
if not self.isMainWorker:
# We restrict logging to the main worker. Since
# all important state is shared across threads using
# atomics, the statistics will still be correct (maybe
# out of date, but correct)
# We restrict logging to the main worker to reduce
# noise
return
# Using an atomic for such frequently updated counters kills
# performance and cripples nps scaling, so instead we let each
# thread have its own local counters and then aggregate the results
# here
var
nodeCount = self.nodeCount
selDepth = self.selectiveDepth
for child in self.children:
nodeCount += child.nodeCount
selDepth = max(selDepth, child.selectiveDepth)
let
nodeCount = self.nodeCount[].load()
elapsedMsec = self.elapsedTime().uint64
nps = 1000 * (nodeCount div max(elapsedMsec, 1))
var logMsg = &"info depth {depth} seldepth {self.selectiveDepth} time {elapsedMsec} nodes {nodeCount} nps {nps}"
var logMsg = &"info depth {depth} seldepth {selDepth} time {elapsedMsec} nodes {nodeCount} nps {nps}"
logMsg &= &" hashfull {self.transpositionTable[].getFillEstimate()}"
if abs(self.bestRootScore) >= mateScore() - MAX_DEPTH:
if self.bestRootScore > 0:
@ -343,10 +350,7 @@ proc shouldStop(self: var SearchManager): bool =
if self.timedOut() and not self.isPondering():
# We ran out of time!
return true
let
nodeCount = self.nodeCount[].load()
maxNodes = self.maxNodes[].load()
if maxNodes > 0 and nodeCount >= maxNodes:
if self.maxNodes > 0 and self.nodeCount >= self.maxNodes:
# Ran out of nodes
return true
@ -407,7 +411,7 @@ proc qsearch(self: var SearchManager, ply: int, alpha, beta: Score): Score =
if self.board.position.see(move) < 0:
continue
self.board.doMove(move)
self.nodeCount[].atomicInc()
inc(self.nodeCount)
let score = -self.qsearch(ply + 1, -beta, -alpha)
self.board.unmakeMove()
bestScore = max(score, bestScore)
@ -565,7 +569,7 @@ proc search(self: var SearchManager, depth, ply: int, alpha, beta: Score, isPV:
let
extension = self.getSearchExtension(move)
reduction = self.getReduction(move, depth, ply, i, isPV)
self.nodeCount[].atomicInc()
inc(self.nodeCount)
# Find the best move for us (worst move
# for our opponent, hence the negative sign)
var score: Score
@ -695,25 +699,9 @@ proc aspirationWindow(self: var SearchManager, score: Score, depth: int): Score
delta = highestEval()
proc findBestLine*(self: var SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move],
proc findBestLine(self: var SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move],
timePerMove=false, ponder=false): seq[Move] =
## Finds the principal variation in the current position
## and returns it, limiting search time according to
## the remaining time and increment values provided (in
## milliseconds) and only up to maxDepth ply (if maxDepth
## is -1, a reasonable limit is picked). If maxNodes is supplied
## and is nonzero, search will stop once it has analyzed maxNodes
## nodes. If searchMoves is provided and is not empty, search will
## be restricted to the moves in the list. Note that regardless of
## any time limitations or explicit cancellations, the search will
## not stop until it has at least cleared depth one. Search depth
## is always constrained to at most MAX_DEPTH ply from the root. If
## timePerMove is true, the increment is assumed to be zero and the
## remaining time is considered the time limit for the entire search
## (note that soft time management is disabled in that case). If ponder
## is true, the search is performed in pondering mode (i.e. no explicit
## time limit) and can be switched to a regular search by calling the
## stopPondering() procedure
## Internal, single-threaded search for the principal variation
# Apparently negative remaining time is a thing. Welp
self.maxSearchTime = if not timePerMove: max(1, (timeRemaining div 10) + ((increment div 3) * 2)) else: timeRemaining
@ -721,8 +709,8 @@ proc findBestLine*(self: var SearchManager, timeRemaining, increment: int64, max
result = @[]
var pv: array[256, Move]
if self.isMainWorker:
self.maxNodes[].store(maxNodes)
self.pondering[].store(ponder)
self.maxNodes = maxNodes
self.searchMoves = searchMoves
self.searchStart = getMonoTime()
self.hardLimit = self.searchStart + initDuration(milliseconds=self.maxSearchTime)
@ -730,8 +718,6 @@ proc findBestLine*(self: var SearchManager, timeRemaining, increment: int64, max
var maxDepth = maxDepth
if maxDepth == -1:
maxDepth = 60
if self.isMainWorker:
self.searching[].store(true)
# Iterative deepening loop
var score = Score(0)
for depth in 1..min(MAX_DEPTH, maxDepth):
@ -750,53 +736,94 @@ proc findBestLine*(self: var SearchManager, timeRemaining, increment: int64, max
# anyway
if getMonoTime() >= self.softLimit and not self.isPondering():
break
if self.isMainWorker:
self.searching[].store(false)
self.stop[].store(false)
for move in pv:
if move == nullMove():
break
result.add(move)
proc workerFunc(args: tuple[self: SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move],
proc workerFunc(args: tuple[self: ptr SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move],
timePerMove, ponder: bool]) {.thread.} =
## Worker that calls findBestLine in a new thread
# Gotta lie to nim's thread analyzer lest it shout at us that we're not
# GC safe!
{.cast(gcsafe).}:
var self = args.self
discard self.findBestLine(args.timeRemaining, args.increment, args.maxDepth, args.maxNodes, args.searchMoves, args.timePerMove, args.ponder)
discard args.self[].findBestLine(args.timeRemaining, args.increment, args.maxDepth, args.maxNodes, args.searchMoves, args.timePerMove, args.ponder)
# Creating threads is expensive, so there's no need to make new ones for every call
# to our parallel search. Also, nim leaks thread vars: this keeps the resource leaks
# to a minimum
var workers: seq[ref Thread[tuple[self: SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move],
var workers: seq[ref Thread[tuple[self: ptr SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move],
timePerMove, ponder: bool]]] = @[]
proc parallelSearch*(self: var SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move],
proc search*(self: var SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move],
timePerMove=false, ponder=false, numWorkers: int): seq[Move] =
## Parallel version of findBestLine(): the search is performed
## using the provided number of worker threads using a shared
## transposition table.
## Finds the principal variation in the current position
## and returns it, limiting search time according to
## the remaining time and increment values provided (in
## milliseconds) and only up to maxDepth ply (if maxDepth
## is -1, a reasonable limit is picked). If maxNodes is supplied
## and is nonzero, search will stop once it has analyzed maxNodes
## nodes. If searchMoves is provided and is not empty, search will
## be restricted to the moves in the list. Note that regardless of
## any time limitations or explicit cancellations, the search will
## not stop until it has at least cleared depth one. Search depth
## is always constrained to at most MAX_DEPTH ply from the root. If
## timePerMove is true, the increment is assumed to be zero and the
## remaining time is considered the time limit for the entire search
## (note that soft time management is disabled in that case). If ponder
## is true, the search is performed in pondering mode (i.e. no explicit
## time limit) and can be switched to a regular search by calling the
## stopPondering() procedure. If numWorkers is > 1, the search is performed
## in parallel using numWorkers threads
while workers.len() + 1 < numWorkers:
# We create n - 1 workers because we'll also be searching
# ourselves. We use the lazy SMP approach, so we'll exploit the
# other threads just to fill up our transposition table and
# not much else (for now)
workers.add(new Thread[tuple[self: SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move],
workers.add(new Thread[tuple[self: ptr SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move],
timePerMove, ponder: bool]])
self.searching[].store(true)
for i in 0..<numWorkers - 1:
# Create a new search manager to send off to a worker thread
var localSearcher = newSearchManager(self.board.position, self.board.positions, self.transpositionTable, self.history, self.killers, false)
# Copy the history and killers table, as those are meant to be thread-local
var
history = create(HistoryTable, sizeof(HistoryTable))
killers = create(KillersTable, sizeof(KillersTable))
# Copy in the data
for color in PieceColor.White..PieceColor.Black:
for i in Square(0)..Square(63):
for j in Square(0)..Square(63):
history[color][i][j] = self.history[color][i][j]
for i in 0..<MAX_DEPTH:
for j in 0..<NUM_KILLERS:
killers[i][j] = self.killers[i][j]
# Create a new search manager to send off to a worker thread. We store it
# on the heap because we need to access its state from elsewhere for collecting
# statistics
self.children.add(create(SearchManager, sizeof(SearchManager)))
self.children[i][] = newSearchManager(self.board.position, self.board.positions, self.transpositionTable, history, killers, false)
# Fill in our shared atomic metadata
localSearcher.stop = self.stop
localSearcher.pondering = self.pondering
localSearcher.searching = self.searching
localSearcher.nodeCount = self.nodeCount
localSearcher.maxNodes = self.maxNodes
self.children[i].stop = self.stop
self.children[i].pondering = self.pondering
self.children[i].searching = self.searching
# Off you go, you little search minion!
createThread(workers[i][], workerFunc, (localSearcher, timeRemaining, increment, maxDepth, maxNodes, searchMoves, timePerMove, ponder))
result = self.findBestLine(timeRemaining, increment, maxDepth, maxNodes, searchMoves, timePerMove, ponder)
# No need to wait for the threads, they'll finish alongside us anyway
createThread(workers[i][], workerFunc, (self.children[i], timeRemaining, increment, maxDepth, maxNodes div numWorkers.uint64, searchMoves, timePerMove, ponder))
# We divide maxNodes by the number of workers so that even when searching in parallel, no more than maxNodes nodes
# are searched
result = self.findBestLine(timeRemaining, increment, maxDepth, maxNodes div numWorkers.uint64, searchMoves, timePerMove, ponder)
# Wait for all search threads to finish. This isn't technically
# necessary, but it's good practice and will catch bugs in our
# "atomic stop" system
for i in 0..<numWorkers - 1:
if workers[i][].running:
joinThread(workers[i][])
# If we set the atomics any earlier than this, our
# search threads would never stop!
self.searching[].store(false)
self.stop[].store(false)
# Ensure local searchers get destroyed
for child in self.children:
child[].`destroy=`()
dealloc(child)
self.children.setLen(0)

View File

@ -331,8 +331,8 @@ proc bestMove(args: tuple[session: UCISession, command: UCICommand]) {.thread.}
increment = 0
elif timeRemaining == 0:
timeRemaining = int32.high()
var line = session.searchState[].parallelSearch(timeRemaining, increment, command.depth, command.nodes, command.searchmoves, timePerMove,
command.ponder, session.workers)
var line = session.searchState[].search(timeRemaining, increment, command.depth, command.nodes, command.searchmoves, timePerMove,
command.ponder, session.workers)
if session.printMove[]:
if line.len() == 1:
echo &"bestmove {line[0].toAlgebraic()}"