Skip to content

Commit fd98546

Browse files
committed
Add NEON support (#621)
Results of Reckless-neon vs Reckless-main (4+0.04, 1t, 16MB, UHO_Lichess_4852_v1.epd): Elo: 18.01 +/- 7.37, nElo: 35.87 +/- 14.65 LOS: 100.00 %, DrawRatio: 54.12 %, PairsRatio: 1.53 Games: 2162, Wins: 623, Losses: 511, Draws: 1028, Points: 1137.0 (52.59 %) Ptnml(0-2): [6, 190, 585, 286, 14], WL/DD Ratio: 1.12 LLR: 2.96 (100.4%) (-2.94, 2.94) [0.00, 5.00] Bench: 3016642
1 parent 0dd5b9a commit fd98546

File tree

3 files changed

+211
-49
lines changed

3 files changed

+211
-49
lines changed

src/nnue.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ use crate::{
1313
use accumulator::{AccumulatorCache, PstAccumulator};
1414

1515
mod forward {
16-
#[cfg(target_feature = "avx2")]
16+
#[cfg(any(target_feature = "avx2", target_feature = "neon"))]
1717
mod vectorized;
18-
#[cfg(target_feature = "avx2")]
18+
#[cfg(any(target_feature = "avx2", target_feature = "neon"))]
1919
pub use vectorized::*;
2020

21-
#[cfg(not(target_feature = "avx2"))]
21+
#[cfg(not(any(target_feature = "avx2", target_feature = "neon")))]
2222
mod scalar;
23-
#[cfg(not(target_feature = "avx2"))]
23+
#[cfg(not(any(target_feature = "avx2", target_feature = "neon")))]
2424
pub use scalar::*;
2525
}
2626

@@ -35,9 +35,14 @@ mod simd {
3535
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
3636
pub use avx2::*;
3737

38-
#[cfg(all(not(target_feature = "avx2"), not(target_feature = "avx512f")))]
38+
#[cfg(all(target_feature = "neon", not(any(target_feature = "avx2", target_feature = "avx512f"))))]
39+
mod neon;
40+
#[cfg(all(target_feature = "neon", not(any(target_feature = "avx2", target_feature = "avx512f"))))]
41+
pub use neon::*;
42+
43+
#[cfg(not(any(target_feature = "avx512f", target_feature = "avx2", target_feature = "neon")))]
3944
mod scalar;
40-
#[cfg(all(not(target_feature = "avx2"), not(target_feature = "avx512f")))]
45+
#[cfg(not(any(target_feature = "avx512f", target_feature = "avx2", target_feature = "neon")))]
4146
pub use scalar::*;
4247
}
4348

src/nnue/forward/vectorized.rs

Lines changed: 75 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use std::arch::x86_64::*;
2-
31
use crate::{
42
nnue::{
53
accumulator::{PstAccumulator, ThreatAccumulator},
@@ -11,7 +9,7 @@ use crate::{
119
pub unsafe fn activate_ft(pst: &PstAccumulator, threat: &ThreatAccumulator, stm: Color) -> Aligned<[u8; L1_SIZE]> {
1210
let mut output = Aligned::new([0; L1_SIZE]);
1311

14-
let zero = simd::zeroed();
12+
let zero = simd::splat_i16(0);
1513
let one = simd::splat_i16(FT_QUANT as i16);
1614

1715
for flip in [0, 1] {
@@ -53,33 +51,6 @@ pub unsafe fn activate_ft(pst: &PstAccumulator, threat: &ThreatAccumulator, stm:
5351
output
5452
}
5553

56-
pub unsafe fn find_nnz(
57-
ft_out: &Aligned<[u8; L1_SIZE]>, nnz_table: &[SparseEntry],
58-
) -> (Aligned<[u16; L1_SIZE / 4]>, usize) {
59-
let mut indexes = Aligned::new([0; L1_SIZE / 4]);
60-
let mut count = 0;
61-
62-
let increment = _mm_set1_epi16(8);
63-
let mut base = _mm_setzero_si128();
64-
65-
for i in (0..L1_SIZE).step_by(2 * simd::I16_LANES) {
66-
let mask = simd::nnz_bitmask(*ft_out.as_ptr().add(i).cast());
67-
68-
for offset in (0..simd::I32_LANES).step_by(8) {
69-
let slice = (mask >> offset) & 0xFF;
70-
let entry = nnz_table.get_unchecked(slice as usize);
71-
72-
let store = indexes.as_mut_ptr().add(count).cast();
73-
_mm_storeu_si128(store, _mm_add_epi16(base, *entry.indexes.as_ptr().cast()));
74-
75-
count += entry.count;
76-
base = _mm_add_epi16(base, increment);
77-
}
78-
}
79-
80-
(indexes, count)
81-
}
82-
8354
pub unsafe fn propagate_l1(ft_out: Aligned<[u8; L1_SIZE]>, nnz: &[u16]) -> Aligned<[f32; L2_SIZE]> {
8455
const CHUNKS: usize = 4;
8556

@@ -100,23 +71,23 @@ pub unsafe fn propagate_l1(ft_out: Aligned<[u8; L1_SIZE]>, nnz: &[u16]) -> Align
10071
let weights2 = PARAMETERS.l1_weights.as_ptr().add(index2 * L2_SIZE * CHUNKS);
10172

10273
for j in (0..L2_SIZE).step_by(simd::F32_LANES) {
103-
let weights1 = weights1.add(j * CHUNKS).cast();
104-
let weights2 = weights2.add(j * CHUNKS).cast();
74+
let weights1 = *weights1.add(j * CHUNKS).cast();
75+
let weights2 = *weights2.add(j * CHUNKS).cast();
10576

10677
let vector = &mut pre_activations[j / simd::F32_LANES];
107-
*vector = simd::double_dpbusd(*vector, input1, *weights1, input2, *weights2);
78+
*vector = simd::double_dpbusd(*vector, input1, weights1, input2, weights2);
10879
}
10980
}
11081

11182
if let Some(last) = pairs.remainder().first() {
11283
let index = *last as usize;
113-
let pst_input = simd::splat_i32(*packed.get_unchecked(index));
84+
let input = simd::splat_i32(*packed.get_unchecked(index));
11485
let weights = PARAMETERS.l1_weights.as_ptr().add(index * L2_SIZE * CHUNKS);
11586

11687
for j in (0..L2_SIZE).step_by(simd::F32_LANES) {
117-
let weights = weights.add(j * CHUNKS).cast();
88+
let weights = *weights.add(j * CHUNKS).cast();
11889
let vector = &mut pre_activations[j / simd::F32_LANES];
119-
*vector = simd::dpbusd(*vector, pst_input, *weights);
90+
*vector = simd::dpbusd(*vector, input, weights);
12091
}
12192
}
12293

@@ -139,13 +110,13 @@ pub unsafe fn propagate_l2(l1_out: Aligned<[f32; L2_SIZE]>) -> Aligned<[f32; L3_
139110
let mut output = PARAMETERS.l2_biases.clone();
140111

141112
for i in 0..L2_SIZE {
142-
let pst_input = simd::splat_f32(l1_out[i]);
113+
let input = simd::splat_f32(l1_out[i]);
143114
let weights = PARAMETERS.l2_weights[i].as_ptr();
144115

145116
for j in (0..L3_SIZE).step_by(simd::F32_LANES) {
146-
let weights = weights.add(j).cast();
117+
let weights = *weights.add(j).cast();
147118
let vector = output.as_mut_ptr().add(j).cast();
148-
*vector = simd::mul_add_f32(*weights, pst_input, *vector);
119+
*vector = simd::mul_add_f32(weights, input, *vector);
149120
}
150121
}
151122

@@ -163,19 +134,80 @@ pub unsafe fn propagate_l2(l1_out: Aligned<[f32; L2_SIZE]>) -> Aligned<[f32; L3_
163134
pub unsafe fn propagate_l3(l2_out: Aligned<[f32; L3_SIZE]>) -> f32 {
164135
const LANES: usize = 16 / simd::F32_LANES;
165136

166-
let pst_input = l2_out.as_ptr();
137+
let input = l2_out.as_ptr();
167138
let weights = PARAMETERS.l3_weights.as_ptr();
168139

169140
let mut output = [simd::zero_f32(); LANES];
170141

171142
for (lane, result) in output.iter_mut().enumerate() {
172143
for i in (0..L3_SIZE).step_by(LANES * simd::F32_LANES) {
173-
let a = weights.add(i + lane * simd::F32_LANES).cast();
174-
let b = pst_input.add(i + lane * simd::F32_LANES).cast();
144+
let a = *weights.add(i + lane * simd::F32_LANES).cast();
145+
let b = *input.add(i + lane * simd::F32_LANES).cast();
175146

176-
*result = simd::mul_add_f32(*a, *b, *result);
147+
*result = simd::mul_add_f32(a, b, *result);
177148
}
178149
}
179150

180151
simd::horizontal_sum(output) + PARAMETERS.l3_biases
181152
}
153+
154+
/// Collects the indexes of non-zero 4-byte activation groups in the
/// feature-transformer output (SSE path), letting `propagate_l1` visit
/// only the sparse non-zero inputs.
///
/// Returns the index buffer plus the number of valid entries in it.
#[cfg(not(target_feature = "neon"))]
pub unsafe fn find_nnz(
    ft_out: &Aligned<[u8; L1_SIZE]>, nnz_table: &[SparseEntry],
) -> (Aligned<[u16; L1_SIZE / 4]>, usize) {
    use std::arch::x86_64::*;

    let mut indexes = Aligned::new([0; L1_SIZE / 4]);
    let mut count = 0;

    // Each table entry decodes 8 mask bits, i.e. 8 groups of 4 bytes, so
    // the running base index advances by 8 per entry.
    let increment = _mm_set1_epi16(8);
    let mut base = _mm_setzero_si128();

    for i in (0..L1_SIZE).step_by(2 * simd::I16_LANES) {
        let mask = simd::nnz_bitmask(*ft_out.as_ptr().add(i).cast());

        for offset in (0..simd::I32_LANES).step_by(8) {
            let slice = (mask >> offset) & 0xFF;
            let entry = nnz_table.get_unchecked(slice as usize);

            // NOTE(review): `*entry.indexes.as_ptr().cast()` dereferences a
            // pointer cast to a 128-bit vector type; unlike `_mm_loadu_si128`
            // this formally requires 16-byte alignment — confirm that
            // `SparseEntry::indexes` is suitably aligned.
            let store = indexes.as_mut_ptr().add(count).cast();
            _mm_storeu_si128(store, _mm_add_epi16(base, *entry.indexes.as_ptr().cast()));

            count += entry.count;
            base = _mm_add_epi16(base, increment);
        }
    }

    (indexes, count)
}
183+
184+
/// Collects the indexes of non-zero 4-byte activation groups in the
/// feature-transformer output (NEON path).
///
/// Two 16-byte vectors are scanned per iteration; their 4-bit masks are
/// combined into one 8-bit table key, so each iteration covers eight
/// 4-byte groups (32 bytes).
#[cfg(target_feature = "neon")]
pub unsafe fn find_nnz(
    ft_out: &Aligned<[u8; L1_SIZE]>, nnz_table: &[SparseEntry],
) -> (Aligned<[u16; L1_SIZE / 4]>, usize) {
    use std::arch::aarch64::*;

    let mut indexes = Aligned::new([0; L1_SIZE / 4]);
    let mut count = 0;

    // Eight groups are consumed per iteration, so group indexes advance by 8.
    let increment = vdupq_n_s16(8);
    let mut base = vdupq_n_s16(0);

    for i in (0..L1_SIZE).step_by(32) {
        let v0 = *ft_out.as_ptr().add(i).cast();
        let v1 = *ft_out.as_ptr().add(i + 16).cast();

        let mask = (simd::nnz_bitmask(v0) | (simd::nnz_bitmask(v1) << 4)) as usize;
        let entry = nnz_table.get_unchecked(mask);

        // NEON loads/stores have no alignment requirement, so the u16
        // table entries can be loaded directly.
        let store = indexes.as_mut_ptr().add(count).cast();
        let indexed = vaddq_s16(base, vld1q_s16(entry.indexes.as_ptr().cast()));

        vst1q_s16(store, indexed);

        count += entry.count;
        base = vaddq_s16(base, increment);
    }

    (indexes, count)
}

src/nnue/simd/neon.rs

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
use std::{arch::aarch64::*, mem::size_of};
2+
3+
// Number of f32 lanes in one 128-bit NEON vector (4).
pub const F32_LANES: usize = size_of::<float32x4_t>() / size_of::<f32>();
// Number of i16 lanes in one 128-bit NEON vector (8).
pub const I16_LANES: usize = size_of::<int16x8_t>() / size_of::<i16>();
5+
6+
/// Lane-wise addition of two i16x8 vectors.
///
/// NOTE(review): unlike the rest of this module, these two wrappers are
/// safe `fn`s with an internal `unsafe` block. That is sound only because
/// this module is compiled exclusively when the `neon` target feature is
/// enabled — confirm the cfg gate stays in sync.
pub fn add_i16(a: int16x8_t, b: int16x8_t) -> int16x8_t {
    unsafe { vaddq_s16(a, b) }
}

/// Lane-wise subtraction (`a - b`) of two i16x8 vectors.
pub fn sub_i16(a: int16x8_t, b: int16x8_t) -> int16x8_t {
    unsafe { vsubq_s16(a, b) }
}
13+
14+
/// Returns an all-zero i32x4 vector.
pub unsafe fn zeroed() -> int32x4_t {
    vdupq_n_s32(0)
}
17+
18+
/// Broadcasts `a` into every lane of an i16x8 vector.
pub unsafe fn splat_i16(a: i16) -> int16x8_t {
    vdupq_n_s16(a)
}
21+
22+
pub unsafe fn clamp_i16(x: int16x8_t, min: int16x8_t, max: int16x8_t) -> int16x8_t {
23+
vmaxq_s16(vminq_s16(x, max), min)
24+
}
25+
26+
/// Lane-wise minimum of two i16x8 vectors.
pub unsafe fn min_i16(a: int16x8_t, b: int16x8_t) -> int16x8_t {
    vminq_s16(a, b)
}
29+
30+
/// Shifts every i16 lane of `a` left by the compile-time constant `SHIFT` bits.
pub unsafe fn shift_left_i16<const SHIFT: i32>(a: int16x8_t) -> int16x8_t {
    vshlq_n_s16::<SHIFT>(a)
}
33+
34+
/// Lane-wise i16 multiply keeping only the high 16 bits of each 32-bit
/// product — the NEON counterpart of x86 `_mm_mulhi_epi16`.
pub unsafe fn mul_high_i16(a: int16x8_t, b: int16x8_t) -> int16x8_t {
    // Widen each half to full 32-bit products...
    let low = vmull_s16(vget_low_s16(a), vget_low_s16(b));
    let high = vmull_s16(vget_high_s16(a), vget_high_s16(b));

    // ...then shift-narrow by 16 to keep only the high half of every product.
    let low_hi = vshrn_n_s32::<16>(low);
    let high_hi = vshrn_n_s32::<16>(high);

    vcombine_s16(low_hi, high_hi)
}
43+
44+
/// Sign-extends eight i8 lanes to i16 lanes.
pub unsafe fn convert_i8_i16(a: int8x8_t) -> int16x8_t {
    vmovl_s8(a)
}
47+
48+
/// Packs two i16x8 vectors into one 16-byte vector with unsigned
/// saturation (negative lanes become 0, lanes above 255 become 255),
/// mirroring x86 `_mm_packus_epi16`. The unsigned result is
/// reinterpreted as i8x16 to match the signature shared with the other
/// SIMD backends.
pub unsafe fn packus(a: int16x8_t, b: int16x8_t) -> int8x16_t {
    let a_u8 = vqmovun_s16(a);
    let b_u8 = vqmovun_s16(b);
    vreinterpretq_s8_u8(vcombine_u8(a_u8, b_u8))
}
53+
54+
/// Identity on NEON. The 128-bit `packus` above already yields bytes in
/// order, so no cross-lane fix-up is needed here.
/// NOTE(review): presumably the AVX2 backend's `permute` undoes 256-bit
/// lane interleaving — confirm against that implementation.
pub unsafe fn permute(a: int8x16_t) -> int8x16_t {
    a
}
57+
58+
/// Broadcasts `a` into every lane of an i32x4 vector.
pub unsafe fn splat_i32(a: i32) -> int32x4_t {
    vdupq_n_s32(a)
}
61+
62+
/// Returns an all-zero f32x4 vector.
pub unsafe fn zero_f32() -> float32x4_t {
    vdupq_n_f32(0.0)
}
65+
66+
/// Broadcasts `a` into every lane of an f32x4 vector.
pub unsafe fn splat_f32(a: f32) -> float32x4_t {
    vdupq_n_f32(a)
}
69+
70+
/// Fused multiply-add per f32 lane: `a * b + c` (note the argument order
/// of `vfmaq_f32` is accumulator-first).
pub unsafe fn mul_add_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
    vfmaq_f32(c, a, b)
}
73+
74+
/// Converts each i32 lane to f32.
pub unsafe fn convert_to_f32(a: int32x4_t) -> float32x4_t {
    vcvtq_f32_s32(a)
}
77+
78+
pub unsafe fn clamp_f32(x: float32x4_t, min: float32x4_t, max: float32x4_t) -> float32x4_t {
79+
vmaxq_f32(vminq_f32(x, max), min)
80+
}
81+
82+
/// Emulates an x86 `dpbusd` step without the accumulate: multiplies 16
/// unsigned bytes (packed in `u8s`) by 16 signed bytes and sums each
/// group of four adjacent products into one i32 lane.
unsafe fn dot_bytes(u8s: int32x4_t, i8s: int8x16_t) -> int32x4_t {
    let u8s = vreinterpretq_u8_s32(u8s);

    // Widen both operands to i16 and multiply; u8 * i8 products fit in
    // i16 (255 * 127 = 32385, 255 * -128 = -32640).
    let products_low = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(u8s))), vmovl_s8(vget_low_s8(i8s)));
    let products_high = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(u8s))), vmovl_s8(vget_high_s8(i8s)));

    // Pairwise widen-add i16 -> i32, then pairwise add the two halves so
    // each output lane holds the sum of four adjacent byte products.
    let sums_low = vpaddlq_s16(products_low);
    let sums_high = vpaddlq_s16(products_high);

    vpaddq_s32(sums_low, sums_high)
}
93+
94+
pub unsafe fn dpbusd(i32s: int32x4_t, u8s: int32x4_t, i8s: int8x16_t) -> int32x4_t {
95+
vaddq_s32(i32s, dot_bytes(u8s, i8s))
96+
}
97+
98+
pub unsafe fn double_dpbusd(
99+
i32s: int32x4_t, u8s1: int32x4_t, i8s1: int8x16_t, u8s2: int32x4_t, i8s2: int8x16_t,
100+
) -> int32x4_t {
101+
let accum = vaddq_s32(dot_bytes(u8s1, i8s1), dot_bytes(u8s2, i8s2));
102+
vaddq_s32(i32s, accum)
103+
}
104+
105+
/// Sums all 16 f32 lanes of the four accumulator vectors into one scalar.
/// The pairwise reduction order is part of the result (float addition is
/// not associative), so it must not be reordered.
pub unsafe fn horizontal_sum(x: [float32x4_t; 4]) -> f32 {
    let sum01 = vaddq_f32(x[0], x[1]);
    let sum23 = vaddq_f32(x[2], x[3]);
    let sum = vaddq_f32(sum01, sum23);

    // Reduce the remaining four lanes: pair = (s0+s1, s2+s3); lane 0 of
    // the final pairwise add is (s0+s1) + (s2+s3).
    let pair = vpadd_f32(vget_low_f32(sum), vget_high_f32(sum));
    let final_sum = vpadd_f32(pair, pair);

    vget_lane_f32::<0>(final_sum)
}
115+
116+
/// Builds a 4-bit mask with bit i set when i32 lane i of `x` is strictly
/// positive.
///
/// NOTE(review): this tests `> 0`, not `!= 0`. Each lane packs four
/// activation bytes, so the two are equivalent only while the top byte
/// stays below 0x80 (activations never reach 128) — confirm against
/// FT_QUANT and the feature-transformer clamp.
pub unsafe fn nnz_bitmask(x: int32x4_t) -> u16 {
    let cmp = vcgtq_s32(x, vdupq_n_s32(0));

    // True comparison lanes are all-ones; extract each lane's sign bit
    // and place it at the lane's position in the mask.
    let mask0 = (vgetq_lane_u32::<0>(cmp) >> 31) & 1;
    let mask1 = ((vgetq_lane_u32::<1>(cmp) >> 31) & 1) << 1;
    let mask2 = ((vgetq_lane_u32::<2>(cmp) >> 31) & 1) << 2;
    let mask3 = ((vgetq_lane_u32::<3>(cmp) >> 31) & 1) << 3;

    (mask0 | mask1 | mask2 | mask3) as u16
}

0 commit comments

Comments
 (0)