@@ -138,63 +138,31 @@ static void abort_nnue(const char *reason) {
     fflush(stdout); exit(EXIT_FAILURE);
 }
 
+INLINE vepi8 vepi16_relu_packu(vepi16 in0, vepi16 in1) {
+    vepi16 shiftA = vepi16_srai(in0, SHIFT_L0);
+    vepi16 shiftB = vepi16_srai(in1, SHIFT_L0);
+    return vepi16_packu(shiftA, shiftB);
+}
 
-INLINE void maddubs_x4(vepi32 *acc, const vepi8 *inp, const vepi8 *wgt, int i, int j, int k) {
+INLINE void relu_maddubs_x4(vepi32 *acc, const vepi16 *inp, const vepi8 *wgt, int i, int j, int k) {
 
     static const int InChunks = L1SIZE / vepi8_cnt;
 
-    vepi16 sum0 = vepi16_maubs(inp[j + 0], wgt[InChunks * (i * 8 + k) + j + 0]);
-    vepi16 sum1 = vepi16_maubs(inp[j + 1], wgt[InChunks * (i * 8 + k) + j + 1]);
-    vepi16 sum2 = vepi16_maubs(inp[j + 2], wgt[InChunks * (i * 8 + k) + j + 2]);
-    vepi16 sum3 = vepi16_maubs(inp[j + 3], wgt[InChunks * (i * 8 + k) + j + 3]);
+    vepi16 sum0 = vepi16_maubs(vepi16_relu_packu(inp[0], inp[1]), wgt[InChunks * (i * 8 + k) + j + 0]);
+    vepi16 sum1 = vepi16_maubs(vepi16_relu_packu(inp[2], inp[3]), wgt[InChunks * (i * 8 + k) + j + 1]);
+    vepi16 sum2 = vepi16_maubs(vepi16_relu_packu(inp[4], inp[5]), wgt[InChunks * (i * 8 + k) + j + 2]);
+    vepi16 sum3 = vepi16_maubs(vepi16_relu_packu(inp[6], inp[7]), wgt[InChunks * (i * 8 + k) + j + 3]);
 
     vepi16 sumX = vepi16_add(sum0, vepi16_add(sum1, vepi16_add(sum2, sum3)));
     *acc = vepi32_add(*acc, vepi16_madd(vepi16_one, sumX));
 }
 
-
-INLINE void halfkp_relu(NNUEAccumulator *accum, uint8_t *outputs, int turn) {
-
-    // The accumulation of king-piece values has already been computed.
-    // Perform the ReLU operation on each accumuatlor, and place them
-    // such that the side-to-move is first, then the non-side-to-move
-
-    assert(KPSIZE % 64 == 0);
-
-    vepi16 *in_white = (vepi16 *) &accum->values[WHITE];
-    vepi16 *in_black = (vepi16 *) &accum->values[BLACK];
-
-    vepi8 *out_white = (vepi8 *) (turn == WHITE ? outputs : &outputs[KPSIZE]);
-    vepi8 *out_black = (vepi8 *) (turn == BLACK ? outputs : &outputs[KPSIZE]);
-
-    for (int i = 0; i < KPSIZE / vepi8_cnt; i += 2) {
-
-        vepi16 shift0A = vepi16_srai(in_white[(i + 0) * 2 + 0], SHIFT_L0);
-        vepi16 shift0B = vepi16_srai(in_white[(i + 0) * 2 + 1], SHIFT_L0);
-        vepi16 shift1A = vepi16_srai(in_white[(i + 1) * 2 + 0], SHIFT_L0);
-        vepi16 shift1B = vepi16_srai(in_white[(i + 1) * 2 + 1], SHIFT_L0);
-
-        out_white[i + 0] = vepi16_packu(shift0A, shift0B);
-        out_white[i + 1] = vepi16_packu(shift1A, shift1B);
-    }
-
-    for (int i = 0; i < KPSIZE / vepi8_cnt; i += 2) {
-
-        vepi16 shift0A = vepi16_srai(in_black[(i + 0) * 2 + 0], SHIFT_L0);
-        vepi16 shift0B = vepi16_srai(in_black[(i + 0) * 2 + 1], SHIFT_L0);
-        vepi16 shift1A = vepi16_srai(in_black[(i + 1) * 2 + 0], SHIFT_L0);
-        vepi16 shift1B = vepi16_srai(in_black[(i + 1) * 2 + 1], SHIFT_L0);
-
-        out_black[i + 0] = vepi16_packu(shift0A, shift0B);
-        out_black[i + 1] = vepi16_packu(shift1A, shift1B);
-    }
-}
-
-INLINE void quant_affine_relu(int8_t *weights, int32_t *biases, uint8_t *inputs, float *outputs) {
+INLINE void halfkp_relu_quant_affine_relu(int8_t *weights, int32_t *biases, int16_t *us_accum, int16_t *opp_accum, float *outputs) {
 
     assert(L1SIZE % 64 == 0 && L2SIZE % 8 == 0);
+    assert(L1SIZE == KPSIZE * 2);
 
-    const int InChunks  = L1SIZE / vepi8_cnt;
+    const int InChunks  = KPSIZE / vepi8_cnt;
     const int OutChunks = L2SIZE / 8;
 
 #if defined(USE_AVX2) || defined(USE_AVX)
@@ -203,7 +171,8 @@ INLINE void quant_affine_relu(int8_t *weights, int32_t *biases, uint8_t *inputs,
     const vps32 zero = vps32_zero();
 #endif
 
-    const vepi8 *inp = (vepi8 *) inputs;
+    const vepi8 *us  = (vepi8 *) us_accum;
+    const vepi8 *opp = (vepi8 *) opp_accum;
     const vepi8 *wgt = (vepi8 *) weights;
     const vepi32 *bia = (vepi32 *) biases;
     vps32 *const out = (vps32 *) outputs;
@@ -220,14 +189,23 @@ INLINE void quant_affine_relu(int8_t *weights, int32_t *biases, uint8_t *inputs,
         vepi32 acc7 = vepi32_zero();
 
         for (int j = 0; j < InChunks; j += 4) {
-            maddubs_x4(&acc0, inp, wgt, i, j, 0);
-            maddubs_x4(&acc1, inp, wgt, i, j, 1);
-            maddubs_x4(&acc2, inp, wgt, i, j, 2);
-            maddubs_x4(&acc3, inp, wgt, i, j, 3);
-            maddubs_x4(&acc4, inp, wgt, i, j, 4);
-            maddubs_x4(&acc5, inp, wgt, i, j, 5);
-            maddubs_x4(&acc6, inp, wgt, i, j, 6);
-            maddubs_x4(&acc7, inp, wgt, i, j, 7);
+            relu_maddubs_x4(&acc0, &us[j * 2], wgt, i, j, 0);
+            relu_maddubs_x4(&acc1, &us[j * 2], wgt, i, j, 1);
+            relu_maddubs_x4(&acc2, &us[j * 2], wgt, i, j, 2);
+            relu_maddubs_x4(&acc3, &us[j * 2], wgt, i, j, 3);
+            relu_maddubs_x4(&acc4, &us[j * 2], wgt, i, j, 4);
+            relu_maddubs_x4(&acc5, &us[j * 2], wgt, i, j, 5);
+            relu_maddubs_x4(&acc6, &us[j * 2], wgt, i, j, 6);
+            relu_maddubs_x4(&acc7, &us[j * 2], wgt, i, j, 7);
+
+            relu_maddubs_x4(&acc0, &opp[j * 2], wgt + InChunks, i, j, 0);
+            relu_maddubs_x4(&acc1, &opp[j * 2], wgt + InChunks, i, j, 1);
+            relu_maddubs_x4(&acc2, &opp[j * 2], wgt + InChunks, i, j, 2);
+            relu_maddubs_x4(&acc3, &opp[j * 2], wgt + InChunks, i, j, 3);
+            relu_maddubs_x4(&acc4, &opp[j * 2], wgt + InChunks, i, j, 4);
+            relu_maddubs_x4(&acc5, &opp[j * 2], wgt + InChunks, i, j, 5);
+            relu_maddubs_x4(&acc6, &opp[j * 2], wgt + InChunks, i, j, 6);
+            relu_maddubs_x4(&acc7, &opp[j * 2], wgt + InChunks, i, j, 7);
         }
 
         acc0 = vepi32_hadd(acc0, acc1);
@@ -478,7 +456,6 @@ int nnue_evaluate(Thread *thread, Board *board) {
 
     NNUEAccumulator *accum = thread->nnue->current;
 
-    ALIGN64 uint8_t out8[L1SIZE];
     ALIGN64 float outN1[L1SIZE];
     ALIGN64 float outN2[L1SIZE];
 
@@ -505,8 +482,7 @@ int nnue_evaluate(Thread *thread, Board *board) {
     }
 
     // Feed-forward the entire evaluation function
-    halfkp_relu(accum, out8, board->turn);
-    quant_affine_relu(l1_weights, l1_biases, out8, outN1);
+    halfkp_relu_quant_affine_relu(l1_weights, l1_biases, accum->values[board->turn], accum->values[!board->turn], outN1);
     float_affine_relu(l2_weights, l2_biases, outN1, outN2);
     output_transform(l3_weights, l3_biases, outN2, outN1);
 
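For reference, a minimal scalar sketch of what the fused halfkp_relu_quant_affine_relu path is meant to compute: each side's int16 accumulator is shifted right by SHIFT_L0 and saturated to [0, 255] (the ReLU/packu step), then multiplied against the quantized int8 L1 weights, with the bias added and a final ReLU applied. The concrete sizes, the plain row-major weight layout, and the omission of the final float rescaling are assumptions for illustration only; the real SIMD code works on interleaved vector chunks and packs two int16 vectors per int8 vector.

/* Illustrative only: sizes, the row-major weight layout, and the final
 * scaling are assumptions, not the engine's actual interleaved layout. */
#include <stdint.h>

#define KPSIZE   256
#define L1SIZE   (KPSIZE * 2)
#define L2SIZE   32
#define SHIFT_L0 6

/* Shift, ReLU, and saturate one int16 accumulator value to [0, 255],
 * mirroring the vepi16_srai + vepi16_packu step. */
static uint8_t relu_quant(int16_t v) {
    int32_t s = v >> SHIFT_L0;
    if (s < 0)   s = 0;
    if (s > 255) s = 255;
    return (uint8_t) s;
}

/* Scalar version of the fused L1 forward pass: quantized inputs from both
 * accumulator halves feed an int8 affine layer, followed by a ReLU. */
static void fused_l1_forward(const int8_t *weights, const int32_t *biases,
                             const int16_t *us_accum, const int16_t *opp_accum,
                             float *outputs) {

    for (int i = 0; i < L2SIZE; i++) {

        int32_t sum = biases[i];

        for (int j = 0; j < KPSIZE; j++)
            sum += relu_quant(us_accum[j])  * weights[i * L1SIZE + j]
                 + relu_quant(opp_accum[j]) * weights[i * L1SIZE + KPSIZE + j];

        /* The real code also rescales to float here; omitted for clarity. */
        outputs[i] = sum > 0 ? (float) sum : 0.0f;
    }
}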