Skip to content

Commit b1b13a5

Browse files
authored
Merge HalfKP relu with the L1 affine transform [+1.626%] (#212)
Remove intermediate stores between activating the accumulator and L1 affine transform. Strictly a speedup. Elo | 7.41 +- 4.86 (95%) SPRT | 10.0+0.10s Threads=1 Hash=8MB LLR | 3.08 (-2.94, 2.94) [0.00, 3.00] Games | N: 9242 W: 2280 L: 2083 D: 4879 Penta | [30, 907, 2569, 1066, 49] http://chess.grantnet.us/test/36685/ Bench 2492187
1 parent a05cb91 commit b1b13a5

File tree

2 files changed

+34
-58
lines changed

2 files changed

+34
-58
lines changed

src/nnue/nnue.c

Lines changed: 33 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -138,63 +138,31 @@ static void abort_nnue(const char *reason) {
138138
fflush(stdout); exit(EXIT_FAILURE);
139139
}
140140

141+
INLINE vepi8 vepi16_relu_packu(vepi16 in0, vepi16 in1) {
142+
vepi16 shiftA = vepi16_srai(in0, SHIFT_L0);
143+
vepi16 shiftB = vepi16_srai(in1, SHIFT_L0);
144+
return vepi16_packu(shiftA, shiftB);
145+
}
141146

142-
INLINE void maddubs_x4(vepi32 *acc, const vepi8 *inp, const vepi8 *wgt, int i, int j, int k) {
147+
INLINE void relu_maddubs_x4(vepi32 *acc, const vepi16 *inp, const vepi8 *wgt, int i, int j, int k) {
143148

144149
static const int InChunks = L1SIZE / vepi8_cnt;
145150

146-
vepi16 sum0 = vepi16_maubs(inp[j+0], wgt[InChunks * (i * 8 + k) + j + 0]);
147-
vepi16 sum1 = vepi16_maubs(inp[j+1], wgt[InChunks * (i * 8 + k) + j + 1]);
148-
vepi16 sum2 = vepi16_maubs(inp[j+2], wgt[InChunks * (i * 8 + k) + j + 2]);
149-
vepi16 sum3 = vepi16_maubs(inp[j+3], wgt[InChunks * (i * 8 + k) + j + 3]);
151+
vepi16 sum0 = vepi16_maubs(vepi16_relu_packu(inp[0], inp[1]), wgt[InChunks * (i * 8 + k) + j + 0]);
152+
vepi16 sum1 = vepi16_maubs(vepi16_relu_packu(inp[2], inp[3]), wgt[InChunks * (i * 8 + k) + j + 1]);
153+
vepi16 sum2 = vepi16_maubs(vepi16_relu_packu(inp[4], inp[5]), wgt[InChunks * (i * 8 + k) + j + 2]);
154+
vepi16 sum3 = vepi16_maubs(vepi16_relu_packu(inp[6], inp[7]), wgt[InChunks * (i * 8 + k) + j + 3]);
150155

151156
vepi16 sumX = vepi16_add(sum0, vepi16_add(sum1, vepi16_add(sum2, sum3)));
152157
*acc = vepi32_add(*acc, vepi16_madd(vepi16_one, sumX));
153158
}
154159

155-
156-
INLINE void halfkp_relu(NNUEAccumulator *accum, uint8_t *outputs, int turn) {
157-
158-
// The accumulation of king-piece values has already been computed.
159-
// Perform the ReLU operation on each accumulator, and place them
160-
// such that the side-to-move is first, then the non-side-to-move
161-
162-
assert(KPSIZE % 64 == 0);
163-
164-
vepi16 *in_white = (vepi16 *) &accum->values[WHITE];
165-
vepi16 *in_black = (vepi16 *) &accum->values[BLACK];
166-
167-
vepi8 *out_white = (vepi8 *) (turn == WHITE ? outputs : &outputs[KPSIZE]);
168-
vepi8 *out_black = (vepi8 *) (turn == BLACK ? outputs : &outputs[KPSIZE]);
169-
170-
for (int i = 0; i < KPSIZE / vepi8_cnt; i += 2) {
171-
172-
vepi16 shift0A = vepi16_srai(in_white[(i + 0) * 2 + 0], SHIFT_L0);
173-
vepi16 shift0B = vepi16_srai(in_white[(i + 0) * 2 + 1], SHIFT_L0);
174-
vepi16 shift1A = vepi16_srai(in_white[(i + 1) * 2 + 0], SHIFT_L0);
175-
vepi16 shift1B = vepi16_srai(in_white[(i + 1) * 2 + 1], SHIFT_L0);
176-
177-
out_white[i+0] = vepi16_packu(shift0A, shift0B);
178-
out_white[i+1] = vepi16_packu(shift1A, shift1B);
179-
}
180-
181-
for (int i = 0; i < KPSIZE / vepi8_cnt; i += 2) {
182-
183-
vepi16 shift0A = vepi16_srai(in_black[(i + 0) * 2 + 0], SHIFT_L0);
184-
vepi16 shift0B = vepi16_srai(in_black[(i + 0) * 2 + 1], SHIFT_L0);
185-
vepi16 shift1A = vepi16_srai(in_black[(i + 1) * 2 + 0], SHIFT_L0);
186-
vepi16 shift1B = vepi16_srai(in_black[(i + 1) * 2 + 1], SHIFT_L0);
187-
188-
out_black[i+0] = vepi16_packu(shift0A, shift0B);
189-
out_black[i+1] = vepi16_packu(shift1A, shift1B);
190-
}
191-
}
192-
193-
INLINE void quant_affine_relu(int8_t *weights, int32_t *biases, uint8_t *inputs, float *outputs) {
160+
INLINE void halfkp_relu_quant_affine_relu(int8_t *weights, int32_t *biases, int16_t *us_accum, int16_t *opp_accum, float *outputs) {
194161

195162
assert(L1SIZE % 64 == 0 && L2SIZE % 8 == 0);
163+
assert(L1SIZE == KPSIZE * 2);
196164

197-
const int InChunks = L1SIZE / vepi8_cnt;
165+
const int InChunks = KPSIZE / vepi8_cnt;
198166
const int OutChunks = L2SIZE / 8;
199167

200168
#if defined(USE_AVX2) || defined(USE_AVX)
@@ -203,7 +171,8 @@ INLINE void quant_affine_relu(int8_t *weights, int32_t *biases, uint8_t *inputs,
203171
const vps32 zero = vps32_zero();
204172
#endif
205173

206-
const vepi8 *inp = (vepi8 *) inputs;
174+
const vepi8 *us = (vepi8 *) us_accum;
175+
const vepi8 *opp = (vepi8 *) opp_accum;
207176
const vepi8 *wgt = (vepi8 *) weights;
208177
const vepi32 *bia = (vepi32 *) biases;
209178
vps32 *const out = (vps32 *) outputs;
@@ -220,14 +189,23 @@ INLINE void quant_affine_relu(int8_t *weights, int32_t *biases, uint8_t *inputs,
220189
vepi32 acc7 = vepi32_zero();
221190

222191
for (int j = 0; j < InChunks; j += 4) {
223-
maddubs_x4(&acc0, inp, wgt, i, j, 0);
224-
maddubs_x4(&acc1, inp, wgt, i, j, 1);
225-
maddubs_x4(&acc2, inp, wgt, i, j, 2);
226-
maddubs_x4(&acc3, inp, wgt, i, j, 3);
227-
maddubs_x4(&acc4, inp, wgt, i, j, 4);
228-
maddubs_x4(&acc5, inp, wgt, i, j, 5);
229-
maddubs_x4(&acc6, inp, wgt, i, j, 6);
230-
maddubs_x4(&acc7, inp, wgt, i, j, 7);
192+
relu_maddubs_x4(&acc0, &us [j * 2], wgt, i, j, 0);
193+
relu_maddubs_x4(&acc1, &us [j * 2], wgt, i, j, 1);
194+
relu_maddubs_x4(&acc2, &us [j * 2], wgt, i, j, 2);
195+
relu_maddubs_x4(&acc3, &us [j * 2], wgt, i, j, 3);
196+
relu_maddubs_x4(&acc4, &us [j * 2], wgt, i, j, 4);
197+
relu_maddubs_x4(&acc5, &us [j * 2], wgt, i, j, 5);
198+
relu_maddubs_x4(&acc6, &us [j * 2], wgt, i, j, 6);
199+
relu_maddubs_x4(&acc7, &us [j * 2], wgt, i, j, 7);
200+
201+
relu_maddubs_x4(&acc0, &opp[j * 2], wgt + InChunks, i, j, 0);
202+
relu_maddubs_x4(&acc1, &opp[j * 2], wgt + InChunks, i, j, 1);
203+
relu_maddubs_x4(&acc2, &opp[j * 2], wgt + InChunks, i, j, 2);
204+
relu_maddubs_x4(&acc3, &opp[j * 2], wgt + InChunks, i, j, 3);
205+
relu_maddubs_x4(&acc4, &opp[j * 2], wgt + InChunks, i, j, 4);
206+
relu_maddubs_x4(&acc5, &opp[j * 2], wgt + InChunks, i, j, 5);
207+
relu_maddubs_x4(&acc6, &opp[j * 2], wgt + InChunks, i, j, 6);
208+
relu_maddubs_x4(&acc7, &opp[j * 2], wgt + InChunks, i, j, 7);
231209
}
232210

233211
acc0 = vepi32_hadd(acc0, acc1);
@@ -478,7 +456,6 @@ int nnue_evaluate(Thread *thread, Board *board) {
478456

479457
NNUEAccumulator *accum = thread->nnue->current;
480458

481-
ALIGN64 uint8_t out8[L1SIZE];
482459
ALIGN64 float outN1[L1SIZE];
483460
ALIGN64 float outN2[L1SIZE];
484461

@@ -505,8 +482,7 @@ int nnue_evaluate(Thread *thread, Board *board) {
505482
}
506483

507484
// Feed-forward the entire evaluation function
508-
halfkp_relu(accum, out8, board->turn);
509-
quant_affine_relu(l1_weights, l1_biases, out8, outN1);
485+
halfkp_relu_quant_affine_relu(l1_weights, l1_biases, accum->values[board->turn], accum->values[!board->turn], outN1);
510486
float_affine_relu(l2_weights, l2_biases, outN1, outN2);
511487
output_transform (l3_weights, l3_biases, outN2, outN1);
512488

src/uci.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
#include "types.h"
2525

26-
#define VERSION_ID "14.38"
26+
#define VERSION_ID "14.39"
2727

2828
#ifndef LICENSE_OWNER
2929
#define LICENSE_OWNER "Unlicensed"

0 commit comments

Comments
 (0)