Skip to content

Commit 10b7bdc

Browse files
authored
3.3.0: support avx512 and vnni512 instructions (#272)
bench: 2703171
1 parent eb425a9 commit 10b7bdc

File tree

9 files changed

+97
-90
lines changed

9 files changed

+97
-90
lines changed

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ ENDIF()
2424
IF (DEFINED USE_AVX2)
2525
ADD_DEFINITIONS(-DUSE_AVX2=${USE_AVX2})
2626
ENDIF()
27+
IF (DEFINED USE_AVX512)
28+
ADD_DEFINITIONS(-DUSE_AVX512=${USE_AVX512})
29+
ENDIF()
30+
IF (DEFINED USE_VNNI)
31+
ADD_DEFINITIONS(-DUSE_VNNI=${USE_VNNI})
32+
ENDIF()
2733

2834
IF (DEFINED _BTYPE)
2935
ADD_DEFINITIONS(-D_BTYPE=${_BTYPE})

src/evaluate.cpp

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,12 @@ std::pair<std::int32_t, bool> Transformer::transform(Position & pos, std::uint8_
231231
if (abs(pt) > PSQT_THRESHOLD * WEIGHTS_SCALE)
232232
return { pt, true };
233233

234-
#if defined(USE_AVX2)
234+
#if defined(USE_AVX512)
235+
std::uint32_t chunks = HalfDimensions / (SIMD_WIDTH * 2);
236+
static_assert(HalfDimensions % (SIMD_WIDTH * 2) == 0);
237+
const __m512i control = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7);
238+
const __m512i zero = _mm512_setzero_si512();
239+
#elif defined(USE_AVX2)
235240
std::uint32_t chunks = HalfDimensions / SIMD_WIDTH;
236241
constexpr int control = 0b11011000;
237242
const __m256i zero = _mm256_setzero_si256();
@@ -240,7 +245,15 @@ std::pair<std::int32_t, bool> Transformer::transform(Position & pos, std::uint8_
240245
for (auto side : sides) {
241246
std::uint32_t offset = HalfDimensions * side;
242247

243-
#if defined(USE_AVX2)
248+
#if defined(USE_AVX512)
249+
auto out = reinterpret_cast<__m512i*>(&outBuffer[offset]);
250+
for (std::uint32_t j = 0; j < chunks; ++j) {
251+
__m512i sum0 = _mm512_load_si512(&reinterpret_cast<const __m512i*>(acc[sides[side]])[j * 2 + 0]);
252+
__m512i sum1 = _mm512_load_si512(&reinterpret_cast<const __m512i*>(acc[sides[side]])[j * 2 + 1]);
253+
_mm512_store_si512(&out[j], _mm512_permutexvar_epi64(control,_mm512_max_epi8(_mm512_packs_epi16(sum0, sum1), zero)));
254+
}
255+
256+
#elif defined(USE_AVX2)
244257
auto out = reinterpret_cast<__m256i*>(&outBuffer[offset]);
245258
for (std::uint32_t j = 0; j < chunks; ++j) {
246259
__m256i sum0 = _mm256_loadA_si256(&reinterpret_cast<const __m256i*>(acc[sides[side]][0])[j * 2 + 0]);
@@ -339,7 +352,13 @@ inline void Transformer::refresh(Position & pos) {
339352
for (std::size_t k = 0; k < PSQT_BUCKETS; ++k)
340353
accumulator.psqtAccumulation[c][k] += psqts[indexes[index] * PSQT_BUCKETS + k];
341354

342-
#if defined(USE_AVX2)
355+
#if defined(USE_AVX512)
356+
auto accumulation = reinterpret_cast<__m512i*>(&accumulator.accumulation[c][0][0]);
357+
auto column = reinterpret_cast<const __m512i*>(&weights[offset]);
358+
constexpr std::uint32_t chunks = HalfDimensions / SIMD_WIDTH;
359+
for (std::uint32_t j = 0; j < chunks; ++j)
360+
_mm512_storeA_si512(&accumulation[j], _mm512_add_epi16(_mm512_loadA_si512(&accumulation[j]), column[j]));
361+
#elif defined(USE_AVX2)
343362
auto accumulation = reinterpret_cast<__m256i*>(&accumulator.accumulation[c][0][0]);
344363
auto column = reinterpret_cast<const __m256i*>(&weights[offset]);
345364

@@ -362,7 +381,13 @@ inline std::int32_t * Layer<OutputDimensions, InputDimensions>::propagate(std::u
362381

363382
auto output = reinterpret_cast<std::int32_t*>(outBuffer);
364383

365-
#if defined(USE_AVX2)
384+
#if defined(USE_AVX512)
385+
std::uint32_t chunks = InputDimensions / (SIMD_WIDTH * 2);
386+
const auto input_vector = reinterpret_cast<const __m512i*>(features);
387+
#if !defined(USE_VNNI)
388+
const __m512i ones = _mm512_set1_epi16(1);
389+
#endif
390+
#elif defined(USE_AVX2)
366391
std::uint32_t chunks = InputDimensions / SIMD_WIDTH;
367392
const __m256i ones = _mm256_set1_epi16(1);
368393
const auto input_vector = reinterpret_cast<const __m256i*>(features);
@@ -371,7 +396,34 @@ inline std::int32_t * Layer<OutputDimensions, InputDimensions>::propagate(std::u
371396
for (std::uint32_t i = 0; i < OutputDimensions; ++i) {
372397
const std::uint32_t offset = i * InputDimensions;
373398

374-
#if defined(USE_AVX2)
399+
#if defined(USE_AVX512)
400+
__m512i sum = _mm512_setzero_si512();
401+
const auto row = reinterpret_cast<const __m512i*>(&weights[offset]);
402+
for (std::uint32_t j = 0; j < chunks; ++j) {
403+
#if defined(USE_VNNI)
404+
sum = _mm512_dpbusd_epi32(sum, _mm512_loadA_si512(&input_vector[j]), _mm512_load_si512(&row[j]));
405+
#else
406+
__m512i product = _mm512_maddubs_epi16(_mm512_loadA_si512(&input_vector[j]), _mm512_load_si512(&row[j]));
407+
product = _mm512_madd_epi16(product, ones);
408+
sum = _mm512_add_epi32(sum, product);
409+
#endif
410+
}
411+
412+
if (InputDimensions != chunks * SIMD_WIDTH * 2) {
413+
const auto iv256 = reinterpret_cast<const __m256i*>(&input_vector[chunks]);
414+
const auto row256 = reinterpret_cast<const __m256i*>(&row[chunks]);
415+
#if defined(USE_VNNI)
416+
__m256i product256 = _mm256_dpbusd_epi32(_mm512_castsi512_si256(sum), _mm256_loadA_si256(&iv256[0]), _mm256_load_si256(&row256[0]));
417+
sum = _mm512_inserti32x8(sum, product256, 0);
418+
#else
419+
__m256i product256 = _mm256_maddubs_epi16(_mm256_loadA_si256(&iv256[0]), _mm256_load_si256(&row256[0]));
420+
sum = _mm512_add_epi32(sum, _mm512_cvtepi16_epi32(product256));
421+
#endif
422+
}
423+
424+
output[i] = _mm512_reduce_add_epi32(sum) + biases[i];
425+
426+
#elif defined(USE_AVX2)
375427
__m256i sum = _mm256_setzero_si256();
376428
const auto row = reinterpret_cast<const __m256i*>(&weights[offset]);
377429
for (std::uint32_t j = 0; j < chunks; ++j) {

src/evaluate.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,7 @@ const EVAL VAL_K = 20000;
4242
#define LAYERED_NETWORKS 8
4343
#define WEIGHTS_SCALE 16
4444
#define PSQT_THRESHOLD 1400
45-
46-
#if defined(USE_AVX2)
4745
#define SIMD_WIDTH 32
48-
#endif
4946

5047
class Transformer
5148
{

src/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
#if !defined(UNIT_TEST)
2525
int main(int argc, const char* argv[])
2626
{
27-
static_assert(USE_AVX2 == 1, "AVX2 is currently the only supported build type");
27+
static_assert(USE_AVX2 == 1, "AVX2 is the minimum supported build type");
2828

2929
//
3030
// initialize igel

src/makefile

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@ NNFLAGS = -DEVALFILE=\"$(EVALFILE)\"
1515
LIBS = -std=c++17 -mpopcnt -pthread
1616
WARN = -Wall
1717
OPTIM = -O3 -march=native -flto
18-
DEFS = -DNDEBUG -DEVAL_NNUE=1 -D_BTYPE=0 -DSYZYGY_SUPPORT=TRUE
18+
DEFS = -DNDEBUG -D_BTYPE=0 -DSYZYGY_SUPPORT=TRUE
1919

2020
ifneq ($(findstring __AVX2__, $(GCCDEFINES)),)
2121
LIBS += -mavx2
2222
DEFS += -DUSE_AVX2=1
2323
endif
2424

25-
ifneq ($(findstring __AVX512__, $(GCCDEFINES)),)
26-
LIBS += -mavx512bw
27-
DEFS += -DUSE_AVX512=1
25+
ifneq ($(findstring __AVX512VNNI__, $(GCCDEFINES)),)
26+
LIBS += -mavx512vnni
27+
DEFS += -DUSE_AVX512=1 -DUSE_VNNI=1
2828
endif
2929

3030
CFLAGS = $(WARN) $(LIBS) $(OPTIM) $(NNFLAGS)

src/misc.cpp

Lines changed: 0 additions & 50 deletions
This file was deleted.

src/misc.h

Lines changed: 0 additions & 27 deletions
This file was deleted.

src/types.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ typedef I64 NODES;
7777
#endif
7878
#endif
7979

80+
#if defined(USE_AVX512)
81+
#if defined(__GNUC__ ) && (__GNUC__ < 9) && defined(_WIN32)
82+
#define _mm512_loadA_si512 _mm512_loadu_si512
83+
#define _mm512_storeA_si512 _mm512_storeu_si512
84+
#else
85+
#define _mm512_loadA_si512 _mm512_load_si512
86+
#define _mm512_storeA_si512 _mm512_store_si512
87+
#endif
88+
#endif
89+
8090
enum
8191
{
8292
NOPIECE = 0,

src/uci.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,25 @@
3232
#include <sstream>
3333

3434
const std::string VERSION = "3.3.0";
35+
const std::string ARCHITECTURE = " 64 "
3536

37+
#if _BTYPE==0
38+
"POPCNT "
39+
#else
40+
"BMI2 "
41+
#endif
42+
43+
#if defined(USE_AVX512)
44+
"AVX512"
45+
#if defined(USE_VNNI)
46+
" VNNI512"
47+
#endif
48+
#elif defined(USE_AVX2)
49+
"AVX2"
50+
#endif
51+
;
52+
53+
/*
3654
#if defined(ENV64BIT)
3755
#if defined(_BTYPE)
3856
#if _BTYPE==0
@@ -46,6 +64,7 @@ const std::string VERSION = "3.3.0";
4664
#else
4765
const std::string ARCHITECTURE = " CUSTOM";
4866
#endif
67+
*/
4968

5069
#if defined(__linux__) && !defined(__ANDROID__)
5170
const int MIN_HASH_SIZE = 2;

0 commit comments

Comments
 (0)