Skip to content

Commit 1a71fe4

Browse files
committed
Fix bug with Accum resets on weight loading
No functional change Bench 3271193
1 parent 73e12e2 commit 1a71fe4

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

src/nnue/accumulator.h

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,33 @@ extern ALIGN64 int16_t in_weights[INSIZE * KPSIZE];
3131
extern ALIGN64 int16_t in_biases[KPSIZE];
3232

3333
INLINE NNUEEvaluator* nnue_create_evaluator() {
34+
return align_malloc(sizeof(NNUEEvaluator));
35+
}
3436

35-
NNUEEvaluator* nnue = align_malloc(sizeof(NNUEEvaluator));
37+
INLINE void nnue_reset_evaluator(NNUEEvaluator* ptr) {
3638

3739
#if USE_NNUE
3840

39-
for (size_t i = 0; i < SQUARE_NB; i++) {
40-
memset(nnue->table[i].occupancy, 0, sizeof(nnue->table[i].occupancy));
41-
memcpy(nnue->table[i].accumulator.values[WHITE], in_biases, sizeof(int16_t) * KPSIZE);
42-
memcpy(nnue->table[i].accumulator.values[BLACK], in_biases, sizeof(int16_t) * KPSIZE);
43-
}
41+
// Reset the Finny table Accumulators
42+
for (size_t i = 0; i < SQUARE_NB; i++) {
43+
memset(ptr->table[i].occupancy, 0, sizeof(ptr->table[i].occupancy));
44+
memcpy(ptr->table[i].accumulator.values[WHITE], in_biases, sizeof(int16_t) * KPSIZE);
45+
memcpy(ptr->table[i].accumulator.values[BLACK], in_biases, sizeof(int16_t) * KPSIZE);
46+
}
4447

45-
#endif
48+
// Reset the base of the Accumulator stack
49+
ptr->current = &ptr->stack[0];
50+
ptr->current->accurate[WHITE] = 0;
51+
ptr->current->accurate[BLACK] = 0;
4652

47-
return nnue;
53+
#endif
4854
}
4955

50-
INLINE void nnue_delete_accumulators(NNUEEvaluator* ptr) {
56+
INLINE void nnue_delete_evaluator(NNUEEvaluator* ptr) {
5157
align_free(ptr);
5258
}
5359

60+
5461
INLINE void nnue_pop(Board *board) {
5562
if (USE_NNUE && board->thread != NULL)
5663
--board->thread->nnue->current;

src/thread.c

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ Thread* createThreadPool(int nthreads) {
5858
void deleteThreadPool(Thread *threads) {
5959

6060
for (int i = 0; i < threads->nthreads; i++)
61-
nnue_delete_accumulators(threads[i].nnue);
61+
nnue_delete_evaluator(threads[i].nnue);
6262

6363
free(threads);
6464
}
@@ -100,12 +100,8 @@ void newSearchThreadPool(Thread *threads, Board *board, Limits *limits, TimeMana
100100
memcpy(&threads[i].board, board, sizeof(Board));
101101
threads[i].board.thread = &threads[i];
102102

103-
// Reset the accumulator stack. The table can remain
104-
threads[i].nnue->current = &threads[i].nnue->stack[0];
105-
threads[i].nnue->current->accurate[WHITE] = 0;
106-
threads[i].nnue->current->accurate[BLACK] = 0;
107-
108103
memset(threads[i].nodeStates, 0, sizeof(NodeState) * STACK_SIZE);
104+
nnue_reset_evaluator(threads[i].nnue);
109105
}
110106
}
111107

0 commit comments

Comments
 (0)