Skip to content

Commit e289b87

Browse files
committed
8 PSQT values per feature and 8 layer stacks.
1 parent f5fcca9 commit e289b87

File tree

5 files changed

+140
-11
lines changed

5 files changed

+140
-11
lines changed

src/nnue/evaluate_nnue.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ namespace Stockfish::Eval::NNUE {
3535
LargePagePtr<FeatureTransformer> featureTransformer;
3636

3737
// Evaluation function
38-
AlignedPtr<Network> network;
38+
AlignedPtr<Network> network[LayerStacks];
3939

4040
// Evaluation function file name
4141
std::string fileName;
@@ -74,7 +74,8 @@ namespace Stockfish::Eval::NNUE {
7474
void initialize() {
7575

7676
Detail::initialize(featureTransformer);
77-
Detail::initialize(network);
77+
for (std::size_t i = 0; i < LayerStacks; ++i)
78+
Detail::initialize(network[i]);
7879
}
7980

8081
// Read network header
@@ -83,7 +84,7 @@ namespace Stockfish::Eval::NNUE {
8384
std::uint32_t version, size;
8485

8586
version = read_little_endian<std::uint32_t>(stream);
86-
*hashValue = read_little_endian<std::uint32_t>(stream);
87+
*hashValue = read_little_endian<std::uint32_t>(stream);
8788
size = read_little_endian<std::uint32_t>(stream);
8889
if (!stream || version != Version) return false;
8990
architecture->resize(size);
@@ -99,7 +100,8 @@ namespace Stockfish::Eval::NNUE {
99100
if (!read_header(stream, &hashValue, &architecture)) return false;
100101
if (hashValue != HashValue) return false;
101102
if (!Detail::read_parameters(stream, *featureTransformer)) return false;
102-
if (!Detail::read_parameters(stream, *network)) return false;
103+
for (std::size_t i = 0; i < LayerStacks; ++i)
104+
if (!Detail::read_parameters(stream, *(network[i]))) return false;
103105
return stream && stream.peek() == std::ios::traits_type::eof();
104106
}
105107

@@ -127,10 +129,12 @@ namespace Stockfish::Eval::NNUE {
127129
ASSERT_ALIGNED(transformedFeatures, alignment);
128130
ASSERT_ALIGNED(buffer, alignment);
129131

130-
featureTransformer->transform(pos, transformedFeatures);
131-
const auto output = network->propagate(transformedFeatures, buffer);
132+
const std::size_t bucket = (popcount(pos.pieces()) - 1) / 4;
132133

133-
return static_cast<Value>(output[0] / OutputScale);
134+
const auto psqt = featureTransformer->transform(pos, transformedFeatures, bucket);
135+
const auto output = network[bucket]->propagate(transformedFeatures, buffer);
136+
137+
return static_cast<Value>((output[0] + psqt) / OutputScale);
134138
}
135139

136140
// Load eval, from a file stream or a memory stream

src/nnue/nnue_accumulator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ namespace Stockfish::Eval::NNUE {
3030

3131
// Class that holds the result of affine transformation of input features
3232
struct alignas(CacheLineSize) Accumulator {
33-
std::int16_t
34-
accumulation[2][TransformedFeatureDimensions];
33+
std::int16_t accumulation[2][TransformedFeatureDimensions];
34+
std::int32_t psqtAccumulation[2][PSQTBuckets];
3535
AccumulatorState state[2];
3636
};
3737

src/nnue/nnue_architecture.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ namespace Stockfish::Eval::NNUE {
3636

3737
// Number of input feature dimensions after conversion
3838
constexpr IndexType TransformedFeatureDimensions = 512;
39+
constexpr IndexType PSQTBuckets = 8;
40+
constexpr IndexType LayerStacks = 8;
3941

4042
namespace Layers {
4143

src/nnue/nnue_common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
namespace Stockfish::Eval::NNUE {
4747

4848
// Version of the evaluation file
49-
constexpr std::uint32_t Version = 0x7AF32F16u;
49+
constexpr std::uint32_t Version = 0x7AF32F20u;
5050

5151
// Constant used in evaluation value calculation
5252
constexpr int OutputScale = 16;

src/nnue/nnue_feature_transformer.h

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,45 +35,82 @@ namespace Stockfish::Eval::NNUE {
3535
// vector registers.
3636
#define VECTOR
3737

38+
static_assert(PSQTBuckets == 8, "Assumed by the current choice of constants.");
39+
3840
#ifdef USE_AVX512
3941
typedef __m512i vec_t;
42+
typedef __m256i psqt_vec_t;
4043
#define vec_load(a) _mm512_load_si512(a)
4144
#define vec_store(a,b) _mm512_store_si512(a,b)
4245
#define vec_add_16(a,b) _mm512_add_epi16(a,b)
4346
#define vec_sub_16(a,b) _mm512_sub_epi16(a,b)
47+
#define vec_load_psqt(a) _mm256_load_si256(a)
48+
#define vec_store_psqt(a,b) _mm256_store_si256(a,b)
49+
#define vec_add_psqt_32(a,b) _mm256_add_epi32(a,b)
50+
#define vec_sub_psqt_32(a,b) _mm256_sub_epi32(a,b)
51+
#define vec_zero_psqt() _mm256_setzero_si256()
4452
static constexpr IndexType NumRegs = 8; // only 8 are needed
53+
static constexpr IndexType NumPsqtRegs = 1;
4554

4655
#elif USE_AVX2
4756
typedef __m256i vec_t;
57+
typedef __m256i psqt_vec_t;
4858
#define vec_load(a) _mm256_load_si256(a)
4959
#define vec_store(a,b) _mm256_store_si256(a,b)
5060
#define vec_add_16(a,b) _mm256_add_epi16(a,b)
5161
#define vec_sub_16(a,b) _mm256_sub_epi16(a,b)
62+
#define vec_load_psqt(a) _mm256_load_si256(a)
63+
#define vec_store_psqt(a,b) _mm256_store_si256(a,b)
64+
#define vec_add_psqt_32(a,b) _mm256_add_epi32(a,b)
65+
#define vec_sub_psqt_32(a,b) _mm256_sub_epi32(a,b)
66+
#define vec_zero_psqt() _mm256_setzero_si256()
5267
static constexpr IndexType NumRegs = 16;
68+
static constexpr IndexType NumPsqtRegs = 1;
5369

5470
#elif USE_SSE2
5571
typedef __m128i vec_t;
72+
typedef __m128i psqt_vec_t;
5673
#define vec_load(a) (*(a))
5774
#define vec_store(a,b) *(a)=(b)
5875
#define vec_add_16(a,b) _mm_add_epi16(a,b)
5976
#define vec_sub_16(a,b) _mm_sub_epi16(a,b)
77+
#define vec_load_psqt(a) (*(a))
78+
#define vec_store_psqt(a,b) *(a)=(b)
79+
#define vec_add_psqt_32(a,b) _mm_add_epi32(a,b)
80+
#define vec_sub_psqt_32(a,b) _mm_sub_epi32(a,b)
81+
#define vec_zero_psqt() _mm_setzero_si128()
6082
static constexpr IndexType NumRegs = Is64Bit ? 16 : 8;
83+
static constexpr IndexType NumPsqtRegs = 2;
6184

6285
#elif USE_MMX
6386
typedef __m64 vec_t;
87+
typedef std::int32_t psqt_vec_t;
6488
#define vec_load(a) (*(a))
6589
#define vec_store(a,b) *(a)=(b)
6690
#define vec_add_16(a,b) _mm_add_pi16(a,b)
6791
#define vec_sub_16(a,b) _mm_sub_pi16(a,b)
92+
#define vec_load_psqt(a) (*(a))
93+
#define vec_store_psqt(a,b) *(a)=(b)
94+
#define vec_add_psqt_32(a,b) a+b
95+
#define vec_sub_psqt_32(a,b) a-b
96+
#define vec_zero_psqt() 0
6897
static constexpr IndexType NumRegs = 8;
98+
static constexpr IndexType NumPsqtRegs = 8;
6999

70100
#elif USE_NEON
71101
typedef int16x8_t vec_t;
102+
typedef int32x4_t psqt_vec_t;
72103
#define vec_load(a) (*(a))
73104
#define vec_store(a,b) *(a)=(b)
74105
#define vec_add_16(a,b) vaddq_s16(a,b)
75106
#define vec_sub_16(a,b) vsubq_s16(a,b)
107+
#define vec_load_psqt(a) (*(a))
108+
#define vec_store_psqt(a,b) *(a)=(b)
109+
#define vec_add_psqt_32(a,b) vaddq_s32(a,b)
110+
#define vec_sub_psqt_32(a,b) vsubq_s32(a,b)
111+
#define vec_zero_psqt() psqt_vec_t{0}
76112
static constexpr IndexType NumRegs = 16;
113+
static constexpr IndexType NumPsqtRegs = 2;
77114

78115
#else
79116
#undef VECTOR
@@ -89,7 +126,9 @@ namespace Stockfish::Eval::NNUE {
89126

90127
#ifdef VECTOR
91128
static constexpr IndexType TileHeight = NumRegs * sizeof(vec_t) / 2;
129+
static constexpr IndexType PsqtTileHeight = NumPsqtRegs * sizeof(psqt_vec_t) / 4;
92130
static_assert(HalfDimensions % TileHeight == 0, "TileHeight must divide HalfDimensions");
131+
static_assert(PSQTBuckets % PsqtTileHeight == 0, "PsqtTileHeight must divide PSQTBuckets");
93132
#endif
94133

95134
public:
@@ -115,15 +154,18 @@ namespace Stockfish::Eval::NNUE {
115154
biases[i] = read_little_endian<BiasType>(stream);
116155
for (std::size_t i = 0; i < HalfDimensions * InputDimensions; ++i)
117156
weights[i] = read_little_endian<WeightType>(stream);
157+
for (std::size_t i = 0; i < PSQTBuckets * InputDimensions; ++i)
158+
psqtWeights[i] = read_little_endian<PSQTWeightType>(stream);
118159
return !stream.fail();
119160
}
120161

121162
// Convert input features
122-
void transform(const Position& pos, OutputType* output) const {
163+
std::int32_t transform(const Position& pos, OutputType* output, int bucket) const {
123164
update_accumulator(pos, WHITE);
124165
update_accumulator(pos, BLACK);
125166

126167
const auto& accumulation = pos.state()->accumulator.accumulation;
168+
const auto& psqtAccumulation = pos.state()->accumulator.psqtAccumulation;
127169

128170
#if defined(USE_AVX512)
129171
constexpr IndexType NumChunks = HalfDimensions / (SimdWidth * 2);
@@ -231,6 +273,12 @@ namespace Stockfish::Eval::NNUE {
231273
#if defined(USE_MMX)
232274
_mm_empty();
233275
#endif
276+
277+
const auto psqt = (
278+
psqtAccumulation[static_cast<int>(perspectives[0])][bucket]
279+
- psqtAccumulation[static_cast<int>(perspectives[1])][bucket]
280+
) / 2;
281+
return psqt;
234282
}
235283

236284
private:
@@ -246,6 +294,7 @@ namespace Stockfish::Eval::NNUE {
246294
// Gcc-10.2 unnecessarily spills AVX2 registers if this array
247295
// is defined in the VECTOR code below, once in each branch
248296
vec_t acc[NumRegs];
297+
psqt_vec_t psqt[NumPsqtRegs];
249298
#endif
250299

251300
// Look for a usable accumulator of an earlier position. We keep track
@@ -324,12 +373,52 @@ namespace Stockfish::Eval::NNUE {
324373
}
325374
}
326375

376+
for (IndexType j = 0; j < PSQTBuckets / PsqtTileHeight; ++j)
377+
{
378+
// Load accumulator
379+
auto accTilePsqt = reinterpret_cast<psqt_vec_t*>(
380+
&st->accumulator.psqtAccumulation[perspective][j * PsqtTileHeight]);
381+
for (std::size_t k = 0; k < NumPsqtRegs; ++k)
382+
psqt[k] = vec_load_psqt(&accTilePsqt[k]);
383+
384+
for (IndexType i = 0; states_to_update[i]; ++i)
385+
{
386+
// Difference calculation for the deactivated features
387+
for (const auto index : removed[i])
388+
{
389+
const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
390+
auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
391+
for (std::size_t k = 0; k < NumPsqtRegs; ++k)
392+
psqt[k] = vec_sub_psqt_32(psqt[k], columnPsqt[k]);
393+
}
394+
395+
// Difference calculation for the activated features
396+
for (const auto index : added[i])
397+
{
398+
const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
399+
auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
400+
for (std::size_t k = 0; k < NumPsqtRegs; ++k)
401+
psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]);
402+
}
403+
404+
// Store accumulator
405+
accTilePsqt = reinterpret_cast<psqt_vec_t*>(
406+
&states_to_update[i]->accumulator.psqtAccumulation[perspective][j * PsqtTileHeight]);
407+
for (std::size_t k = 0; k < NumPsqtRegs; ++k)
408+
vec_store_psqt(&accTilePsqt[k], psqt[k]);
409+
}
410+
}
411+
327412
#else
328413
for (IndexType i = 0; states_to_update[i]; ++i)
329414
{
330415
std::memcpy(states_to_update[i]->accumulator.accumulation[perspective],
331416
st->accumulator.accumulation[perspective],
332417
HalfDimensions * sizeof(BiasType));
418+
419+
for (std::size_t k = 0; k < PSQTBuckets; ++k)
420+
states_to_update[i]->accumulator.psqtAccumulation[perspective][k] = st->accumulator.psqtAccumulation[perspective][k];
421+
333422
st = states_to_update[i];
334423

335424
// Difference calculation for the deactivated features
@@ -339,6 +428,9 @@ namespace Stockfish::Eval::NNUE {
339428

340429
for (IndexType j = 0; j < HalfDimensions; ++j)
341430
st->accumulator.accumulation[perspective][j] -= weights[offset + j];
431+
432+
for (std::size_t k = 0; k < PSQTBuckets; ++k)
433+
st->accumulator.psqtAccumulation[perspective][k] -= psqtWeights[index * PSQTBuckets + k];
342434
}
343435

344436
// Difference calculation for the activated features
@@ -348,6 +440,9 @@ namespace Stockfish::Eval::NNUE {
348440

349441
for (IndexType j = 0; j < HalfDimensions; ++j)
350442
st->accumulator.accumulation[perspective][j] += weights[offset + j];
443+
444+
for (std::size_t k = 0; k < PSQTBuckets; ++k)
445+
st->accumulator.psqtAccumulation[perspective][k] += psqtWeights[index * PSQTBuckets + k];
351446
}
352447
}
353448
#endif
@@ -383,16 +478,42 @@ namespace Stockfish::Eval::NNUE {
383478
vec_store(&accTile[k], acc[k]);
384479
}
385480

481+
for (IndexType j = 0; j < PSQTBuckets / PsqtTileHeight; ++j)
482+
{
483+
for (std::size_t k = 0; k < NumPsqtRegs; ++k)
484+
psqt[k] = vec_zero_psqt();
485+
486+
for (const auto index : active)
487+
{
488+
const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
489+
auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
490+
491+
for (std::size_t k = 0; k < NumPsqtRegs; ++k)
492+
psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]);
493+
}
494+
495+
auto accTilePsqt = reinterpret_cast<psqt_vec_t*>(
496+
&accumulator.psqtAccumulation[perspective][j * PsqtTileHeight]);
497+
for (std::size_t k = 0; k < NumPsqtRegs; ++k)
498+
vec_store_psqt(&accTilePsqt[k], psqt[k]);
499+
}
500+
386501
#else
387502
std::memcpy(accumulator.accumulation[perspective], biases,
388503
HalfDimensions * sizeof(BiasType));
389504

505+
for (std::size_t k = 0; k < PSQTBuckets; ++k)
506+
accumulator.psqtAccumulation[perspective][k] = 0;
507+
390508
for (const auto index : active)
391509
{
392510
const IndexType offset = HalfDimensions * index;
393511

394512
for (IndexType j = 0; j < HalfDimensions; ++j)
395513
accumulator.accumulation[perspective][j] += weights[offset + j];
514+
515+
for (std::size_t k = 0; k < PSQTBuckets; ++k)
516+
accumulator.psqtAccumulation[perspective][k] += psqtWeights[index * PSQTBuckets + k];
396517
}
397518
#endif
398519
}
@@ -404,9 +525,11 @@ namespace Stockfish::Eval::NNUE {
404525

405526
using BiasType = std::int16_t;
406527
using WeightType = std::int16_t;
528+
using PSQTWeightType = std::int32_t;
407529

408530
alignas(CacheLineSize) BiasType biases[HalfDimensions];
409531
alignas(CacheLineSize) WeightType weights[HalfDimensions * InputDimensions];
532+
alignas(CacheLineSize) PSQTWeightType psqtWeights[InputDimensions * PSQTBuckets];
410533
};
411534

412535
} // namespace Stockfish::Eval::NNUE

0 commit comments

Comments
 (0)