Fix issues with joinThread and fix nps scaling issues by not using atomic node counters

This commit is contained in:
Mattia Giambirtone 2024-05-14 11:20:03 +02:00
parent e54fc56925
commit 5b3c244206
Signed by: nocturn9x
GPG Key ID: B6025DD9B4458B69
2 changed files with 94 additions and 67 deletions

View File

@ -115,8 +115,8 @@ type
searchStart: MonoTime searchStart: MonoTime
hardLimit: MonoTime hardLimit: MonoTime
softLimit: MonoTime softLimit: MonoTime
nodeCount: ptr Atomic[uint64] nodeCount: uint64
maxNodes: ptr Atomic[uint64] maxNodes: uint64
searchMoves: seq[Move] searchMoves: seq[Move]
transpositionTable: ptr TTable transpositionTable: ptr TTable
history: ptr HistoryTable history: ptr HistoryTable
@ -125,8 +125,14 @@ type
# We keep one extra entry so we don't need any special casing # We keep one extra entry so we don't need any special casing
# inside the search function when constructing pv lines # inside the search function when constructing pv lines
pvMoves: array[MAX_DEPTH + 1, array[MAX_DEPTH + 1, Move]] pvMoves: array[MAX_DEPTH + 1, array[MAX_DEPTH + 1, Move]]
# The highest depth we explored to, including extensions
selectiveDepth: int selectiveDepth: int
# Are we the main worker?
isMainWorker: bool isMainWorker: bool
# We keep track of all the worker
# threads' respective search states
# to collect statistics efficiently
children: seq[ptr SearchManager]
proc newSearchManager*(position: Position, positions: seq[Position], transpositions: ptr TTable, proc newSearchManager*(position: Position, positions: seq[Position], transpositions: ptr TTable,
@ -138,19 +144,15 @@ proc newSearchManager*(position: Position, positions: seq[Position], transpositi
searchFlag: ptr Atomic[bool] searchFlag: ptr Atomic[bool]
stopFlag: ptr Atomic[bool] stopFlag: ptr Atomic[bool]
ponderFlag: ptr Atomic[bool] ponderFlag: ptr Atomic[bool]
nodeCounter: ptr Atomic[uint64]
maxNodes: ptr Atomic[uint64]
if mainWorker: if mainWorker:
searchFlag = create(Atomic[bool], sizeof(Atomic[bool])) searchFlag = create(Atomic[bool], sizeof(Atomic[bool]))
stopFlag = create(Atomic[bool], sizeof(Atomic[bool])) stopFlag = create(Atomic[bool], sizeof(Atomic[bool]))
ponderFlag = create(Atomic[bool], sizeof(Atomic[bool])) ponderFlag = create(Atomic[bool], sizeof(Atomic[bool]))
nodeCounter = create(Atomic[uint64], sizeof(Atomic[uint64]))
maxNodes = create(Atomic[uint64], sizeof(Atomic[uint64]))
# If we're not the main worker, we expect the shared atomic metadata to be filled in by the # If we're not the main worker, we expect the shared atomic metadata to be filled in by the
# main worker # main worker
result = SearchManager(board: newChessboard(), transpositionTable: transpositions, stop: stopFlag, result = SearchManager(board: newChessboard(), transpositionTable: transpositions, stop: stopFlag,
searching: searchFlag, pondering: ponderFlag, history: history, nodeCount: nodeCounter, searching: searchFlag, pondering: ponderFlag, history: history,
maxNodes: maxNodes, killers: killers, isMainWorker: mainWorker) killers: killers, isMainWorker: mainWorker)
result.board.position = position result.board.position = position
result.board.positions = positions result.board.positions = positions
for i in 0..MAX_DEPTH: for i in 0..MAX_DEPTH:
@ -168,8 +170,6 @@ proc `destroy=`*(self: var SearchManager) =
dealloc(self.stop) dealloc(self.stop)
dealloc(self.searching) dealloc(self.searching)
dealloc(self.pondering) dealloc(self.pondering)
dealloc(self.maxNodes)
dealloc(self.nodeCount)
else: else:
# This state is thread-local and is fine to # This state is thread-local and is fine to
# destroy *unless* we're the main worker. This # destroy *unless* we're the main worker. This
@ -307,16 +307,23 @@ proc stopPondering*(self: var SearchManager) =
proc log(self: var SearchManager, depth: int) = proc log(self: var SearchManager, depth: int) =
if not self.isMainWorker: if not self.isMainWorker:
# We restrict logging to the main worker. Since # We restrict logging to the main worker to reduce
# all important state is shared across threads using # noise
# atomics, the statistics will still be correct (maybe
# out of date, but correct)
return return
# Using an atomic for such frequently updated counters kills
# performance and cripples nps scaling, so instead we let each
# thread have its own local counters and then aggregate the results
# here
var
nodeCount = self.nodeCount
selDepth = self.selectiveDepth
for child in self.children:
nodeCount += child.nodeCount
selDepth = max(selDepth, child.selectiveDepth)
let let
nodeCount = self.nodeCount[].load()
elapsedMsec = self.elapsedTime().uint64 elapsedMsec = self.elapsedTime().uint64
nps = 1000 * (nodeCount div max(elapsedMsec, 1)) nps = 1000 * (nodeCount div max(elapsedMsec, 1))
var logMsg = &"info depth {depth} seldepth {self.selectiveDepth} time {elapsedMsec} nodes {nodeCount} nps {nps}" var logMsg = &"info depth {depth} seldepth {selDepth} time {elapsedMsec} nodes {nodeCount} nps {nps}"
logMsg &= &" hashfull {self.transpositionTable[].getFillEstimate()}" logMsg &= &" hashfull {self.transpositionTable[].getFillEstimate()}"
if abs(self.bestRootScore) >= mateScore() - MAX_DEPTH: if abs(self.bestRootScore) >= mateScore() - MAX_DEPTH:
if self.bestRootScore > 0: if self.bestRootScore > 0:
@ -343,10 +350,7 @@ proc shouldStop(self: var SearchManager): bool =
if self.timedOut() and not self.isPondering(): if self.timedOut() and not self.isPondering():
# We ran out of time! # We ran out of time!
return true return true
let if self.maxNodes > 0 and self.nodeCount >= self.maxNodes:
nodeCount = self.nodeCount[].load()
maxNodes = self.maxNodes[].load()
if maxNodes > 0 and nodeCount >= maxNodes:
# Ran out of nodes # Ran out of nodes
return true return true
@ -407,7 +411,7 @@ proc qsearch(self: var SearchManager, ply: int, alpha, beta: Score): Score =
if self.board.position.see(move) < 0: if self.board.position.see(move) < 0:
continue continue
self.board.doMove(move) self.board.doMove(move)
self.nodeCount[].atomicInc() inc(self.nodeCount)
let score = -self.qsearch(ply + 1, -beta, -alpha) let score = -self.qsearch(ply + 1, -beta, -alpha)
self.board.unmakeMove() self.board.unmakeMove()
bestScore = max(score, bestScore) bestScore = max(score, bestScore)
@ -565,7 +569,7 @@ proc search(self: var SearchManager, depth, ply: int, alpha, beta: Score, isPV:
let let
extension = self.getSearchExtension(move) extension = self.getSearchExtension(move)
reduction = self.getReduction(move, depth, ply, i, isPV) reduction = self.getReduction(move, depth, ply, i, isPV)
self.nodeCount[].atomicInc() inc(self.nodeCount)
# Find the best move for us (worst move # Find the best move for us (worst move
# for our opponent, hence the negative sign) # for our opponent, hence the negative sign)
var score: Score var score: Score
@ -695,25 +699,9 @@ proc aspirationWindow(self: var SearchManager, score: Score, depth: int): Score
delta = highestEval() delta = highestEval()
proc findBestLine*(self: var SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move], proc findBestLine(self: var SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move],
timePerMove=false, ponder=false): seq[Move] = timePerMove=false, ponder=false): seq[Move] =
## Finds the principal variation in the current position ## Internal, single-threaded search for the principal variation
## and returns it, limiting search time according to
## the remaining time and increment values provided (in
## milliseconds) and only up to maxDepth ply (if maxDepth
## is -1, a reasonable limit is picked). If maxNodes is supplied
## and is nonzero, search will stop once it has analyzed maxNodes
## nodes. If searchMoves is provided and is not empty, search will
## be restricted to the moves in the list. Note that regardless of
## any time limitations or explicit cancellations, the search will
## not stop until it has at least cleared depth one. Search depth
## is always constrained to at most MAX_DEPTH ply from the root. If
## timePerMove is true, the increment is assumed to be zero and the
## remaining time is considered the time limit for the entire search
## (note that soft time management is disabled in that case). If ponder
## is true, the search is performed in pondering mode (i.e. no explicit
## time limit) and can be switched to a regular search by calling the
## stopPondering() procedure
# Apparently negative remaining time is a thing. Welp # Apparently negative remaining time is a thing. Welp
self.maxSearchTime = if not timePerMove: max(1, (timeRemaining div 10) + ((increment div 3) * 2)) else: timeRemaining self.maxSearchTime = if not timePerMove: max(1, (timeRemaining div 10) + ((increment div 3) * 2)) else: timeRemaining
@ -721,8 +709,8 @@ proc findBestLine*(self: var SearchManager, timeRemaining, increment: int64, max
result = @[] result = @[]
var pv: array[256, Move] var pv: array[256, Move]
if self.isMainWorker: if self.isMainWorker:
self.maxNodes[].store(maxNodes)
self.pondering[].store(ponder) self.pondering[].store(ponder)
self.maxNodes = maxNodes
self.searchMoves = searchMoves self.searchMoves = searchMoves
self.searchStart = getMonoTime() self.searchStart = getMonoTime()
self.hardLimit = self.searchStart + initDuration(milliseconds=self.maxSearchTime) self.hardLimit = self.searchStart + initDuration(milliseconds=self.maxSearchTime)
@ -730,8 +718,6 @@ proc findBestLine*(self: var SearchManager, timeRemaining, increment: int64, max
var maxDepth = maxDepth var maxDepth = maxDepth
if maxDepth == -1: if maxDepth == -1:
maxDepth = 60 maxDepth = 60
if self.isMainWorker:
self.searching[].store(true)
# Iterative deepening loop # Iterative deepening loop
var score = Score(0) var score = Score(0)
for depth in 1..min(MAX_DEPTH, maxDepth): for depth in 1..min(MAX_DEPTH, maxDepth):
@ -750,53 +736,94 @@ proc findBestLine*(self: var SearchManager, timeRemaining, increment: int64, max
# anyway # anyway
if getMonoTime() >= self.softLimit and not self.isPondering(): if getMonoTime() >= self.softLimit and not self.isPondering():
break break
if self.isMainWorker:
self.searching[].store(false)
self.stop[].store(false)
for move in pv: for move in pv:
if move == nullMove(): if move == nullMove():
break break
result.add(move) result.add(move)
proc workerFunc(args: tuple[self: SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move], proc workerFunc(args: tuple[self: ptr SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move],
timePerMove, ponder: bool]) {.thread.} = timePerMove, ponder: bool]) {.thread.} =
## Worker that calls findBestLine in a new thread ## Worker that calls findBestLine in a new thread
# Gotta lie to nim's thread analyzer lest it shout at us that we're not # Gotta lie to nim's thread analyzer lest it shout at us that we're not
# GC safe! # GC safe!
{.cast(gcsafe).}: {.cast(gcsafe).}:
var self = args.self discard args.self[].findBestLine(args.timeRemaining, args.increment, args.maxDepth, args.maxNodes, args.searchMoves, args.timePerMove, args.ponder)
discard self.findBestLine(args.timeRemaining, args.increment, args.maxDepth, args.maxNodes, args.searchMoves, args.timePerMove, args.ponder)
# Creating threads is expensive, so there's no need to make new ones for every call # Creating threads is expensive, so there's no need to make new ones for every call
# to our parallel search. Also, nim leaks thread vars: this keeps the resource leaks # to our parallel search. Also, nim leaks thread vars: this keeps the resource leaks
# to a minimum # to a minimum
var workers: seq[ref Thread[tuple[self: SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move], var workers: seq[ref Thread[tuple[self: ptr SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move],
timePerMove, ponder: bool]]] = @[] timePerMove, ponder: bool]]] = @[]
proc parallelSearch*(self: var SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move], proc search*(self: var SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move],
timePerMove=false, ponder=false, numWorkers: int): seq[Move] = timePerMove=false, ponder=false, numWorkers: int): seq[Move] =
## Parallel version of findBestLine(): the search is performed ## Finds the principal variation in the current position
## using the provided number of worker threads using a shared ## and returns it, limiting search time according to
## transposition table. ## the remaining time and increment values provided (in
## milliseconds) and only up to maxDepth ply (if maxDepth
## is -1, a reasonable limit is picked). If maxNodes is supplied
## and is nonzero, search will stop once it has analyzed maxNodes
## nodes. If searchMoves is provided and is not empty, search will
## be restricted to the moves in the list. Note that regardless of
## any time limitations or explicit cancellations, the search will
## not stop until it has at least cleared depth one. Search depth
## is always constrained to at most MAX_DEPTH ply from the root. If
## timePerMove is true, the increment is assumed to be zero and the
## remaining time is considered the time limit for the entire search
## (note that soft time management is disabled in that case). If ponder
## is true, the search is performed in pondering mode (i.e. no explicit
## time limit) and can be switched to a regular search by calling the
## stopPondering() procedure. If numWorkers is > 1, the search is performed
## in parallel using numWorkers threads
while workers.len() + 1 < numWorkers: while workers.len() + 1 < numWorkers:
# We create n - 1 workers because we'll also be searching # We create n - 1 workers because we'll also be searching
# ourselves. We use the lazy SMP approach, so we'll exploit the # ourselves. We use the lazy SMP approach, so we'll exploit the
# other threads just to fill up our transposition table and # other threads just to fill up our transposition table and
# not much else (for now) # not much else (for now)
workers.add(new Thread[tuple[self: SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move], workers.add(new Thread[tuple[self: ptr SearchManager, timeRemaining, increment: int64, maxDepth: int, maxNodes: uint64, searchMoves: seq[Move],
timePerMove, ponder: bool]]) timePerMove, ponder: bool]])
self.searching[].store(true)
for i in 0..<numWorkers - 1: for i in 0..<numWorkers - 1:
# Create a new search manager to send off to a worker thread # Copy the history and killers table, as those are meant to be thread-local
var localSearcher = newSearchManager(self.board.position, self.board.positions, self.transpositionTable, self.history, self.killers, false) var
history = create(HistoryTable, sizeof(HistoryTable))
killers = create(KillersTable, sizeof(KillersTable))
# Copy in the data
for color in PieceColor.White..PieceColor.Black:
for i in Square(0)..Square(63):
for j in Square(0)..Square(63):
history[color][i][j] = self.history[color][i][j]
for i in 0..<MAX_DEPTH:
for j in 0..<NUM_KILLERS:
killers[i][j] = self.killers[i][j]
# Create a new search manager to send off to a worker thread. We store it
# on the heap because we need to access its state from elsewhere for collecting
# statistics
self.children.add(create(SearchManager, sizeof(SearchManager)))
self.children[i][] = newSearchManager(self.board.position, self.board.positions, self.transpositionTable, history, killers, false)
# Fill in our shared atomic metadata # Fill in our shared atomic metadata
localSearcher.stop = self.stop self.children[i].stop = self.stop
localSearcher.pondering = self.pondering self.children[i].pondering = self.pondering
localSearcher.searching = self.searching self.children[i].searching = self.searching
localSearcher.nodeCount = self.nodeCount
localSearcher.maxNodes = self.maxNodes
# Off you go, you little search minion! # Off you go, you little search minion!
createThread(workers[i][], workerFunc, (localSearcher, timeRemaining, increment, maxDepth, maxNodes, searchMoves, timePerMove, ponder)) createThread(workers[i][], workerFunc, (self.children[i], timeRemaining, increment, maxDepth, maxNodes div numWorkers.uint64, searchMoves, timePerMove, ponder))
result = self.findBestLine(timeRemaining, increment, maxDepth, maxNodes, searchMoves, timePerMove, ponder) # We divide maxNodes by the number of workers so that even when searching in parallel, no more than maxNodes nodes
# No need to wait for the threads, they'll finish alongside us anyway # are searched
result = self.findBestLine(timeRemaining, increment, maxDepth, maxNodes div numWorkers.uint64, searchMoves, timePerMove, ponder)
# Wait for all search threads to finish. This isn't technically
# necessary, but it's good practice and will catch bugs in our
# "atomic stop" system
for i in 0..<numWorkers - 1:
if workers[i][].running:
joinThread(workers[i][])
# If we set the atomics any earlier than this, our
# search threads would never stop!
self.searching[].store(false)
self.stop[].store(false)
# Ensure local searchers get destroyed
for child in self.children:
child[].`destroy=`()
dealloc(child)
self.children.setLen(0)

View File

@ -331,8 +331,8 @@ proc bestMove(args: tuple[session: UCISession, command: UCICommand]) {.thread.}
increment = 0 increment = 0
elif timeRemaining == 0: elif timeRemaining == 0:
timeRemaining = int32.high() timeRemaining = int32.high()
var line = session.searchState[].parallelSearch(timeRemaining, increment, command.depth, command.nodes, command.searchmoves, timePerMove, var line = session.searchState[].search(timeRemaining, increment, command.depth, command.nodes, command.searchmoves, timePerMove,
command.ponder, session.workers) command.ponder, session.workers)
if session.printMove[]: if session.printMove[]:
if line.len() == 1: if line.len() == 1:
echo &"bestmove {line[0].toAlgebraic()}" echo &"bestmove {line[0].toAlgebraic()}"