Skip to content

Commit fc79013

Browse files
committed
Use AVX-512 for smaller affine transforms and in the feature transformer.
For the feature transformer the code is analogous to AVX2, since there was room for easy adaptation to wider SIMD registers. For the smaller affine transforms, which have a 32-byte stride, we keep 2 columns in one zmm register. We also unroll more aggressively, so that in the end we have to do 16 parallel horizontal additions on ymm slices, each consisting of 4 32-bit integers. The slices are embedded in 8 zmm registers. These changes provide about a 1.5% speedup for AVX-512 builds. Closes #3218 No functional change.
1 parent 3f6451e commit fc79013

File tree

2 files changed

+148
-8
lines changed

2 files changed

+148
-8
lines changed

src/nnue/layers/affine_transform.h

Lines changed: 126 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,21 @@ namespace Eval::NNUE::Layers {
8383
return _mm512_reduce_add_epi32(sum) + bias;
8484
};
8585

86-
[[maybe_unused]] auto m512_haddx4 = [](__m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m128i bias) -> __m128i {
86+
// This function takes
87+
// sum0 = [xmm0a, xmm0b, xmm0c, xmm0d]
88+
// sum1 = [xmm1a, xmm1b, xmm1c, xmm1d]
89+
// sum2 = [xmm2a, xmm2b, xmm2c, xmm2d]
90+
// sum3 = [xmm3a, xmm3b, xmm3c, xmm3d]
91+
// and returns
92+
// ret = [
93+
// reduce_add_epi32(xmm0a), reduce_add_epi32(xmm1a), reduce_add_epi32(xmm2a), reduce_add_epi32(xmm3a),
94+
// reduce_add_epi32(xmm0b), reduce_add_epi32(xmm1b), reduce_add_epi32(xmm2b), reduce_add_epi32(xmm3b),
95+
// reduce_add_epi32(xmm0c), reduce_add_epi32(xmm1c), reduce_add_epi32(xmm2c), reduce_add_epi32(xmm3c),
96+
// reduce_add_epi32(xmm0d), reduce_add_epi32(xmm1d), reduce_add_epi32(xmm2d), reduce_add_epi32(xmm3d)
97+
// ]
98+
[[maybe_unused]] auto m512_hadd128x16_interleave = [](
99+
__m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) -> __m512i {
100+
87101
__m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
88102
__m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
89103

@@ -96,7 +110,13 @@ namespace Eval::NNUE::Layers {
96110
__m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
97111
__m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);
98112

99-
__m512i sum = _mm512_add_epi32(sum0123a, sum0123b);
113+
return _mm512_add_epi32(sum0123a, sum0123b);
114+
};
115+
116+
[[maybe_unused]] auto m512_haddx4 = [m512_hadd128x16_interleave](
117+
__m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m128i bias) -> __m128i {
118+
119+
__m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
100120

101121
__m256i sum256lo = _mm512_castsi512_si256(sum);
102122
__m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
@@ -109,6 +129,58 @@ namespace Eval::NNUE::Layers {
109129
return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
110130
};
111131

132+
// Horizontally reduces 8 zmm accumulators into a single ymm of 8 int32
// results and adds the bias. Each input zmm holds partial sums for outputs
// whose rows are packed two-per-register (32-byte stride), so after the
// interleave step the matching qword lanes of suma/sumb still have to be
// summed pairwise before the final 256-bit fold.
[[maybe_unused]] auto m512_haddx8 = [m512_hadd128x16_interleave](
    __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
    __m512i sum4, __m512i sum5, __m512i sum6, __m512i sum7, __m256i bias) -> __m256i {

  // Collapse each group of 4 zmm registers into one zmm of 16 interleaved
  // 32-bit partial sums (see m512_hadd128x16_interleave).
  __m512i suma = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
  __m512i sumb = m512_hadd128x16_interleave(sum4, sum5, sum6, sum7);

  // Select complementary qword pairs from suma/sumb so that adding the two
  // permutations completes each output's horizontal sum.
  __m512i indices0 = _mm512_setr_epi64(0, 1, 8, 9, 4, 5, 12, 13);
  __m512i indices1 = _mm512_setr_epi64(2, 3, 10, 11, 6, 7, 14, 15);
  __m512i x = _mm512_add_epi32(
      _mm512_permutex2var_epi64(suma, indices0, sumb),
      _mm512_permutex2var_epi64(suma, indices1, sumb));

  // Fold the upper 256 bits onto the lower half, then add the bias.
  __m256i sum256lo = _mm512_castsi512_si256(x);
  __m256i sum256hi = _mm512_extracti64x4_epi64(x, 1);

  return _mm256_add_epi32(_mm256_add_epi32(sum256lo, sum256hi), bias);
};
150+
151+
// Horizontally reduces 4 zmm accumulators into a ymm of 8 int32 results
// and adds the bias. Used where each 256-bit half of a zmm accumulates an
// independent slice, so after the interleave the paired partial sums are
// regrouped with a cross-lane permute and finished with a single hadd of
// the two ymm halves.
// NOTE(review): fixed `=[` -> `= [` to match the surrounding lambdas'
// formatting; no functional change.
[[maybe_unused]] auto m512_hadd256x8 = [m512_hadd128x16_interleave](
    __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m256i bias) -> __m256i {

  __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);

  // Move partial sums belonging to the same output next to each other so
  // _mm256_hadd_epi32 across the two halves completes the reduction.
  __m512i indices = _mm512_setr_epi32(
      0, 4, 8, 12, 2, 6, 10, 14,
      1, 5, 9, 13, 3, 7, 11, 15);
  sum = _mm512_permutexvar_epi32(indices, sum);

  __m256i sum256lo = _mm512_castsi512_si256(sum);
  __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);

  return _mm256_add_epi32(_mm256_hadd_epi32(sum256lo, sum256hi), bias);
};
166+
167+
// Horizontally reduces 8 zmm accumulators into one zmm of 16 int32 results
// and adds the bias. This is the widest variant: it services 16 outputs per
// call for the affine transforms that keep two weight rows per zmm register.
[[maybe_unused]] auto m512_hadd256x16 = [m512_hadd128x16_interleave](
    __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
    __m512i sum4, __m512i sum5, __m512i sum6, __m512i sum7, __m512i bias) -> __m512i {

  // Collapse each group of 4 accumulators into 16 interleaved partial sums.
  __m512i suma = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
  __m512i sumb = m512_hadd128x16_interleave(sum4, sum5, sum6, sum7);

  // Same qword-pairing trick as m512_haddx8: the two permutations pick
  // complementary halves, and their sum completes each horizontal add.
  __m512i indices0 = _mm512_setr_epi64(0, 1, 8, 9, 4, 5, 12, 13);
  __m512i indices1 = _mm512_setr_epi64(2, 3, 10, 11, 6, 7, 14, 15);
  __m512i x = _mm512_add_epi32(
      _mm512_permutex2var_epi64(suma, indices0, sumb),
      _mm512_permutex2var_epi64(suma, indices1, sumb));

  // Final cross-lane shuffle puts the 16 results in output order, then the
  // bias vector is added in one go.
  __m512i indices = _mm512_setr_epi32(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
  return _mm512_add_epi32(_mm512_permutexvar_epi32(indices, x), bias);
};
183+
112184
[[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
113185
#if defined (USE_VNNI)
114186
acc = _mm512_dpbusd_epi32(acc, a, b);
@@ -205,7 +277,58 @@ namespace Eval::NNUE::Layers {
205277

206278
// kOutputDimensions is either 1 or a multiple of kSimdWidth
207279
// because then it is also an input dimension.
208-
if constexpr (kOutputDimensions % 4 == 0)
280+
// Fast path: the whole input fits in a single ymm (kNumChunks256 == 1) and
// the output dimension is a multiple of 16. With a 32-byte input stride two
// consecutive weight rows share one 64-byte zmm register, so 8 zmm
// accumulators cover 16 output neurons per iteration, reduced at the end by
// m512_hadd256x16.
if constexpr (kOutputDimensions % 16 == 0 && kNumChunks256 == 1)
{
    for (IndexType i = 0; i < kOutputDimensions; i += 16)
    {
        // Each offset addresses the first of the two rows packed in a zmm
        // ("01" = rows i, i+1; "23" = rows i+2, i+3; ... the "b" group is
        // the second batch of 8 rows).
        const IndexType offset01a = (i + 0) * kPaddedInputDimensions;
        const IndexType offset23a = (i + 2) * kPaddedInputDimensions;
        const IndexType offset45a = (i + 4) * kPaddedInputDimensions;
        const IndexType offset67a = (i + 6) * kPaddedInputDimensions;
        const IndexType offset01b = (i + 8) * kPaddedInputDimensions;
        const IndexType offset23b = (i + 10) * kPaddedInputDimensions;
        const IndexType offset45b = (i + 12) * kPaddedInputDimensions;
        const IndexType offset67b = (i + 14) * kPaddedInputDimensions;

        // 16 int32 biases / outputs per iteration.
        const __m512i bias = *reinterpret_cast<const __m512i*>(&biases_[i]);
        __m512i* outptr = reinterpret_cast<__m512i*>(&output[i]);

        __m512i sum01a = _mm512_setzero_si512();
        __m512i sum23a = _mm512_setzero_si512();
        __m512i sum45a = _mm512_setzero_si512();
        __m512i sum67a = _mm512_setzero_si512();
        __m512i sum01b = _mm512_setzero_si512();
        __m512i sum23b = _mm512_setzero_si512();
        __m512i sum45b = _mm512_setzero_si512();
        __m512i sum67b = _mm512_setzero_si512();

        // One zmm load pulls in two complete weight rows.
        const auto row01a = *reinterpret_cast<const __m512i*>(&weights_[offset01a]);
        const auto row23a = *reinterpret_cast<const __m512i*>(&weights_[offset23a]);
        const auto row45a = *reinterpret_cast<const __m512i*>(&weights_[offset45a]);
        const auto row67a = *reinterpret_cast<const __m512i*>(&weights_[offset67a]);
        const auto row01b = *reinterpret_cast<const __m512i*>(&weights_[offset01b]);
        const auto row23b = *reinterpret_cast<const __m512i*>(&weights_[offset23b]);
        const auto row45b = *reinterpret_cast<const __m512i*>(&weights_[offset45b]);
        const auto row67b = *reinterpret_cast<const __m512i*>(&weights_[offset67b]);

        // Broadcast the single 256-bit input chunk into both zmm halves so
        // each half multiplies against its own weight row.
        const __m256i in256 = input_vector256[0];
        const __m512i in = _mm512_inserti64x4(_mm512_castsi256_si512(in256), in256, 1);

        m512_add_dpbusd_epi32(sum01a, in, row01a);
        m512_add_dpbusd_epi32(sum23a, in, row23a);
        m512_add_dpbusd_epi32(sum45a, in, row45a);
        m512_add_dpbusd_epi32(sum67a, in, row67a);
        m512_add_dpbusd_epi32(sum01b, in, row01b);
        m512_add_dpbusd_epi32(sum23b, in, row23b);
        m512_add_dpbusd_epi32(sum45b, in, row45b);
        m512_add_dpbusd_epi32(sum67b, in, row67b);

        // Reduce all 8 accumulators to 16 int32 results + bias in one step.
        *outptr = m512_hadd256x16(
            sum01a, sum23a, sum45a, sum67a,
            sum01b, sum23b, sum45b, sum67b, bias);
    }
}
331+
else if constexpr (kOutputDimensions % 4 == 0)
209332
{
210333
for (IndexType i = 0; i < kOutputDimensions; i += 4)
211334
{

src/nnue/nnue_feature_transformer.h

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,13 @@ namespace Eval::NNUE {
127127

128128
const auto& accumulation = pos.state()->accumulator.accumulation;
129129

130-
#if defined(USE_AVX2)
130+
#if defined(USE_AVX512)
  // Each zmm iteration consumes two SIMD-width slices of int16 accumulators
  // and packs them into one register of int8 outputs.
  constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth * 2);
  static_assert(kHalfDimensions % (kSimdWidth * 2) == 0);
  // _mm512_packs_epi16 interleaves within 128-bit lanes; this qword permute
  // control restores linear element order (AVX-512 analogue of the AVX2
  // 0b11011000 control below).
  const __m512i kControl = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7);
  const __m512i kZero = _mm512_setzero_si512();
135+
136+
#elif defined(USE_AVX2)
131137
constexpr IndexType kNumChunks = kHalfDimensions / kSimdWidth;
132138
constexpr int kControl = 0b11011000;
133139
const __m256i kZero = _mm256_setzero_si256();
@@ -154,13 +160,24 @@ namespace Eval::NNUE {
154160
for (IndexType p = 0; p < 2; ++p) {
155161
const IndexType offset = kHalfDimensions * p;
156162

157-
#if defined(USE_AVX2)
163+
#if defined(USE_AVX512)
  auto out = reinterpret_cast<__m512i*>(&output[offset]);
  for (IndexType j = 0; j < kNumChunks; ++j) {
    // Load two consecutive 512-bit slices of the int16 accumulator.
    __m512i sum0 = _mm512_load_si512(
        &reinterpret_cast<const __m512i*>(accumulation[perspectives[p]][0])[j * 2 + 0]);
    __m512i sum1 = _mm512_load_si512(
        &reinterpret_cast<const __m512i*>(accumulation[perspectives[p]][0])[j * 2 + 1]);
    // Saturate to int8, clamp negatives to zero (clipped ReLU), then undo
    // the in-lane interleave of packs_epi16 with the kControl permute.
    _mm512_store_si512(&out[j], _mm512_permutexvar_epi64(kControl,
        _mm512_max_epi8(_mm512_packs_epi16(sum0, sum1), kZero)));
  }
173+
174+
#elif defined(USE_AVX2)
158175
auto out = reinterpret_cast<__m256i*>(&output[offset]);
159176
for (IndexType j = 0; j < kNumChunks; ++j) {
160177
__m256i sum0 = _mm256_load_si256(
161178
&reinterpret_cast<const __m256i*>(accumulation[perspectives[p]][0])[j * 2 + 0]);
162179
__m256i sum1 = _mm256_load_si256(
163-
&reinterpret_cast<const __m256i*>(accumulation[perspectives[p]][0])[j * 2 + 1]);
180+
&reinterpret_cast<const __m256i*>(accumulation[perspectives[p]][0])[j * 2 + 1]);
164181
_mm256_store_si256(&out[j], _mm256_permute4x64_epi64(_mm256_max_epi8(
165182
_mm256_packs_epi16(sum0, sum1), kZero), kControl));
166183
}
@@ -177,9 +194,9 @@ namespace Eval::NNUE {
177194
_mm_store_si128(&out[j],
178195

179196
#ifdef USE_SSE41
180-
_mm_max_epi8(packedbytes, kZero)
197+
_mm_max_epi8(packedbytes, kZero)
181198
#else
182-
_mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
199+
_mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
183200
#endif
184201

185202
);

0 commit comments

Comments
 (0)