Skip to content

Commit 8b8a510

Browse files
syzygy1 authored and vondele committed
Use tiling to speed up accumulator refreshes and updates
Perform the update and refresh operations tile by tile in a local array of vectors. By selecting the array size carefully, we achieve that the compiler keeps the whole array in vector registers. Idea and original implementation by @sf-x. STC: https://tests.stockfishchess.org/tests/view/5f623eec912c15f19854b855 LLR: 2.94 (-2.94,2.94) {-0.25,1.25} Total: 4872 W: 623 L: 477 D: 3772 Ptnml(0-2): 14, 350, 1585, 450, 37 LTC: https://tests.stockfishchess.org/tests/view/5f62434e912c15f19854b860 LLR: 2.94 (-2.94,2.94) {0.25,1.25} Total: 25808 W: 1565 L: 1401 D: 22842 Ptnml(0-2): 23, 1186, 10332, 1330, 33 closes #3130 No functional change
1 parent 64a6346 commit 8b8a510

File tree

1 file changed

+127
-110
lines changed

1 file changed

+127
-110
lines changed

src/nnue/nnue_feature_transformer.h

Lines changed: 127 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,68 @@
2929

3030
namespace Eval::NNUE {
3131

32+
// If vector instructions are enabled, we update and refresh the
33+
// accumulator tile by tile such that each tile fits in the CPU's
34+
// vector registers.
35+
#define TILING
36+
37+
#ifdef USE_AVX512
38+
typedef __m512i vec_t;
39+
#define vec_load(a) _mm512_loadA_si512(a)
40+
#define vec_store(a,b) _mm512_storeA_si512(a,b)
41+
#define vec_add_16(a,b) _mm512_add_epi16(a,b)
42+
#define vec_sub_16(a,b) _mm512_sub_epi16(a,b)
43+
static constexpr IndexType kNumRegs = 8; // only 8 are needed
44+
45+
#elif USE_AVX2
46+
typedef __m256i vec_t;
47+
#define vec_load(a) _mm256_loadA_si256(a)
48+
#define vec_store(a,b) _mm256_storeA_si256(a,b)
49+
#define vec_add_16(a,b) _mm256_add_epi16(a,b)
50+
#define vec_sub_16(a,b) _mm256_sub_epi16(a,b)
51+
static constexpr IndexType kNumRegs = 16;
52+
53+
#elif USE_SSE2
54+
typedef __m128i vec_t;
55+
#define vec_load(a) (*(a))
56+
#define vec_store(a,b) *(a)=(b)
57+
#define vec_add_16(a,b) _mm_add_epi16(a,b)
58+
#define vec_sub_16(a,b) _mm_sub_epi16(a,b)
59+
static constexpr IndexType kNumRegs = Is64Bit ? 16 : 8;
60+
61+
#elif USE_MMX
62+
typedef __m64 vec_t;
63+
#define vec_load(a) (*(a))
64+
#define vec_store(a,b) *(a)=(b)
65+
#define vec_add_16(a,b) _mm_add_pi16(a,b)
66+
#define vec_sub_16(a,b) _mm_sub_pi16(a,b)
67+
static constexpr IndexType kNumRegs = 8;
68+
69+
#elif USE_NEON
70+
typedef int16x8_t vec_t;
71+
#define vec_load(a) (*(a))
72+
#define vec_store(a,b) *(a)=(b)
73+
#define vec_add_16(a,b) vaddq_s16(a,b)
74+
#define vec_sub_16(a,b) vsubq_s16(a,b)
75+
static constexpr IndexType kNumRegs = 16;
76+
77+
#else
78+
#undef TILING
79+
80+
#endif
81+
3282
// Input feature converter
3383
class FeatureTransformer {
3484

3585
private:
3686
// Number of output dimensions for one side
3787
static constexpr IndexType kHalfDimensions = kTransformedFeatureDimensions;
3888

89+
#ifdef TILING
90+
static constexpr IndexType kTileHeight = kNumRegs * sizeof(vec_t) / 2;
91+
static_assert(kHalfDimensions % kTileHeight == 0, "kTileHeight must divide kHalfDimensions");
92+
#endif
93+
3994
public:
4095
// Output type
4196
using OutputType = TransformedFeatureType;
@@ -189,57 +244,41 @@ namespace Eval::NNUE {
189244
RawFeatures::AppendActiveIndices(pos, kRefreshTriggers[i],
190245
active_indices);
191246
for (Color perspective : { WHITE, BLACK }) {
247+
#ifdef TILING
248+
for (unsigned j = 0; j < kHalfDimensions / kTileHeight; ++j) {
249+
auto biasesTile = reinterpret_cast<const vec_t*>(
250+
&biases_[j * kTileHeight]);
251+
auto accTile = reinterpret_cast<vec_t*>(
252+
&accumulator.accumulation[perspective][i][j * kTileHeight]);
253+
vec_t acc[kNumRegs];
254+
255+
for (unsigned k = 0; k < kNumRegs; ++k)
256+
acc[k] = biasesTile[k];
257+
258+
for (const auto index : active_indices[perspective]) {
259+
const IndexType offset = kHalfDimensions * index + j * kTileHeight;
260+
auto column = reinterpret_cast<const vec_t*>(&weights_[offset]);
261+
262+
for (unsigned k = 0; k < kNumRegs; ++k)
263+
acc[k] = vec_add_16(acc[k], column[k]);
264+
}
265+
266+
for (unsigned k = 0; k < kNumRegs; k++)
267+
vec_store(&accTile[k], acc[k]);
268+
}
269+
#else
192270
std::memcpy(accumulator.accumulation[perspective][i], biases_,
193-
kHalfDimensions * sizeof(BiasType));
271+
kHalfDimensions * sizeof(BiasType));
272+
194273
for (const auto index : active_indices[perspective]) {
195274
const IndexType offset = kHalfDimensions * index;
196-
#if defined(USE_AVX512)
197-
auto accumulation = reinterpret_cast<__m512i*>(
198-
&accumulator.accumulation[perspective][i][0]);
199-
auto column = reinterpret_cast<const __m512i*>(&weights_[offset]);
200-
constexpr IndexType kNumChunks = kHalfDimensions / kSimdWidth;
201-
for (IndexType j = 0; j < kNumChunks; ++j)
202-
_mm512_storeA_si512(&accumulation[j], _mm512_add_epi16(_mm512_loadA_si512(&accumulation[j]), column[j]));
203-
204-
#elif defined(USE_AVX2)
205-
auto accumulation = reinterpret_cast<__m256i*>(
206-
&accumulator.accumulation[perspective][i][0]);
207-
auto column = reinterpret_cast<const __m256i*>(&weights_[offset]);
208-
constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
209-
for (IndexType j = 0; j < kNumChunks; ++j)
210-
_mm256_storeA_si256(&accumulation[j], _mm256_add_epi16(_mm256_loadA_si256(&accumulation[j]), column[j]));
211-
212-
#elif defined(USE_SSE2)
213-
auto accumulation = reinterpret_cast<__m128i*>(
214-
&accumulator.accumulation[perspective][i][0]);
215-
auto column = reinterpret_cast<const __m128i*>(&weights_[offset]);
216-
constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
217-
for (IndexType j = 0; j < kNumChunks; ++j)
218-
accumulation[j] = _mm_add_epi16(accumulation[j], column[j]);
219-
220-
#elif defined(USE_MMX)
221-
auto accumulation = reinterpret_cast<__m64*>(
222-
&accumulator.accumulation[perspective][i][0]);
223-
auto column = reinterpret_cast<const __m64*>(&weights_[offset]);
224-
constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
225-
for (IndexType j = 0; j < kNumChunks; ++j)
226-
accumulation[j] = _mm_add_pi16(accumulation[j], column[j]);
227275

228-
#elif defined(USE_NEON)
229-
auto accumulation = reinterpret_cast<int16x8_t*>(
230-
&accumulator.accumulation[perspective][i][0]);
231-
auto column = reinterpret_cast<const int16x8_t*>(&weights_[offset]);
232-
constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
233-
for (IndexType j = 0; j < kNumChunks; ++j)
234-
accumulation[j] = vaddq_s16(accumulation[j], column[j]);
235-
236-
#else
237276
for (IndexType j = 0; j < kHalfDimensions; ++j)
238277
accumulator.accumulation[perspective][i][j] += weights_[offset + j];
239-
#endif
240-
241278
}
279+
#endif
242280
}
281+
243282
#if defined(USE_MMX)
244283
_mm_empty();
245284
#endif
@@ -257,29 +296,55 @@ namespace Eval::NNUE {
257296
bool reset[2];
258297
RawFeatures::AppendChangedIndices(pos, kRefreshTriggers[i],
259298
removed_indices, added_indices, reset);
260-
for (Color perspective : { WHITE, BLACK }) {
261299

262-
#if defined(USE_AVX2)
263-
constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
264-
auto accumulation = reinterpret_cast<__m256i*>(
265-
&accumulator.accumulation[perspective][i][0]);
266-
267-
#elif defined(USE_SSE2)
268-
constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
269-
auto accumulation = reinterpret_cast<__m128i*>(
270-
&accumulator.accumulation[perspective][i][0]);
271-
272-
#elif defined(USE_MMX)
273-
constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
274-
auto accumulation = reinterpret_cast<__m64*>(
275-
&accumulator.accumulation[perspective][i][0]);
300+
#ifdef TILING
301+
for (IndexType j = 0; j < kHalfDimensions / kTileHeight; ++j) {
302+
for (Color perspective : { WHITE, BLACK }) {
303+
auto accTile = reinterpret_cast<vec_t*>(
304+
&accumulator.accumulation[perspective][i][j * kTileHeight]);
305+
vec_t acc[kNumRegs];
306+
307+
if (reset[perspective]) {
308+
auto biasesTile = reinterpret_cast<const vec_t*>(
309+
&biases_[j * kTileHeight]);
310+
for (unsigned k = 0; k < kNumRegs; ++k)
311+
acc[k] = biasesTile[k];
312+
} else {
313+
auto prevAccTile = reinterpret_cast<const vec_t*>(
314+
&prev_accumulator.accumulation[perspective][i][j * kTileHeight]);
315+
for (IndexType k = 0; k < kNumRegs; ++k)
316+
acc[k] = vec_load(&prevAccTile[k]);
317+
318+
// Difference calculation for the deactivated features
319+
for (const auto index : removed_indices[perspective]) {
320+
const IndexType offset = kHalfDimensions * index + j * kTileHeight;
321+
auto column = reinterpret_cast<const vec_t*>(&weights_[offset]);
322+
323+
for (IndexType k = 0; k < kNumRegs; ++k)
324+
acc[k] = vec_sub_16(acc[k], column[k]);
325+
}
326+
}
327+
{ // Difference calculation for the activated features
328+
for (const auto index : added_indices[perspective]) {
329+
const IndexType offset = kHalfDimensions * index + j * kTileHeight;
330+
auto column = reinterpret_cast<const vec_t*>(&weights_[offset]);
331+
332+
for (IndexType k = 0; k < kNumRegs; ++k)
333+
acc[k] = vec_add_16(acc[k], column[k]);
334+
}
335+
}
276336

277-
#elif defined(USE_NEON)
278-
constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
279-
auto accumulation = reinterpret_cast<int16x8_t*>(
280-
&accumulator.accumulation[perspective][i][0]);
337+
for (IndexType k = 0; k < kNumRegs; ++k)
338+
vec_store(&accTile[k], acc[k]);
339+
}
340+
}
341+
#if defined(USE_MMX)
342+
_mm_empty();
281343
#endif
282344

345+
#else
346+
for (Color perspective : { WHITE, BLACK }) {
347+
283348
if (reset[perspective]) {
284349
std::memcpy(accumulator.accumulation[perspective][i], biases_,
285350
kHalfDimensions * sizeof(BiasType));
@@ -291,67 +356,19 @@ namespace Eval::NNUE {
291356
for (const auto index : removed_indices[perspective]) {
292357
const IndexType offset = kHalfDimensions * index;
293358

294-
#if defined(USE_AVX2)
295-
auto column = reinterpret_cast<const __m256i*>(&weights_[offset]);
296-
for (IndexType j = 0; j < kNumChunks; ++j)
297-
accumulation[j] = _mm256_sub_epi16(accumulation[j], column[j]);
298-
299-
#elif defined(USE_SSE2)
300-
auto column = reinterpret_cast<const __m128i*>(&weights_[offset]);
301-
for (IndexType j = 0; j < kNumChunks; ++j)
302-
accumulation[j] = _mm_sub_epi16(accumulation[j], column[j]);
303-
304-
#elif defined(USE_MMX)
305-
auto column = reinterpret_cast<const __m64*>(&weights_[offset]);
306-
for (IndexType j = 0; j < kNumChunks; ++j)
307-
accumulation[j] = _mm_sub_pi16(accumulation[j], column[j]);
308-
309-
#elif defined(USE_NEON)
310-
auto column = reinterpret_cast<const int16x8_t*>(&weights_[offset]);
311-
for (IndexType j = 0; j < kNumChunks; ++j)
312-
accumulation[j] = vsubq_s16(accumulation[j], column[j]);
313-
314-
#else
315359
for (IndexType j = 0; j < kHalfDimensions; ++j)
316360
accumulator.accumulation[perspective][i][j] -= weights_[offset + j];
317-
#endif
318-
319361
}
320362
}
321363
{ // Difference calculation for the activated features
322364
for (const auto index : added_indices[perspective]) {
323365
const IndexType offset = kHalfDimensions * index;
324366

325-
#if defined(USE_AVX2)
326-
auto column = reinterpret_cast<const __m256i*>(&weights_[offset]);
327-
for (IndexType j = 0; j < kNumChunks; ++j)
328-
accumulation[j] = _mm256_add_epi16(accumulation[j], column[j]);
329-
330-
#elif defined(USE_SSE2)
331-
auto column = reinterpret_cast<const __m128i*>(&weights_[offset]);
332-
for (IndexType j = 0; j < kNumChunks; ++j)
333-
accumulation[j] = _mm_add_epi16(accumulation[j], column[j]);
334-
335-
#elif defined(USE_MMX)
336-
auto column = reinterpret_cast<const __m64*>(&weights_[offset]);
337-
for (IndexType j = 0; j < kNumChunks; ++j)
338-
accumulation[j] = _mm_add_pi16(accumulation[j], column[j]);
339-
340-
#elif defined(USE_NEON)
341-
auto column = reinterpret_cast<const int16x8_t*>(&weights_[offset]);
342-
for (IndexType j = 0; j < kNumChunks; ++j)
343-
accumulation[j] = vaddq_s16(accumulation[j], column[j]);
344-
345-
#else
346367
for (IndexType j = 0; j < kHalfDimensions; ++j)
347368
accumulator.accumulation[perspective][i][j] += weights_[offset + j];
348-
#endif
349-
350369
}
351370
}
352371
}
353-
#if defined(USE_MMX)
354-
_mm_empty();
355372
#endif
356373

357374
accumulator.computed_accumulation = true;

0 commit comments

Comments (0)