1- use std:: arch:: x86_64:: * ;
2-
31use crate :: {
42 nnue:: {
53 accumulator:: { PstAccumulator , ThreatAccumulator } ,
@@ -11,7 +9,7 @@ use crate::{
119pub unsafe fn activate_ft ( pst : & PstAccumulator , threat : & ThreatAccumulator , stm : Color ) -> Aligned < [ u8 ; L1_SIZE ] > {
1210 let mut output = Aligned :: new ( [ 0 ; L1_SIZE ] ) ;
1311
14- let zero = simd:: zeroed ( ) ;
12+ let zero = simd:: splat_i16 ( 0 ) ;
1513 let one = simd:: splat_i16 ( FT_QUANT as i16 ) ;
1614
1715 for flip in [ 0 , 1 ] {
@@ -53,33 +51,6 @@ pub unsafe fn activate_ft(pst: &PstAccumulator, threat: &ThreatAccumulator, stm:
5351 output
5452}
5553
56- pub unsafe fn find_nnz (
57- ft_out : & Aligned < [ u8 ; L1_SIZE ] > , nnz_table : & [ SparseEntry ] ,
58- ) -> ( Aligned < [ u16 ; L1_SIZE / 4 ] > , usize ) {
59- let mut indexes = Aligned :: new ( [ 0 ; L1_SIZE / 4 ] ) ;
60- let mut count = 0 ;
61-
62- let increment = _mm_set1_epi16 ( 8 ) ;
63- let mut base = _mm_setzero_si128 ( ) ;
64-
65- for i in ( 0 ..L1_SIZE ) . step_by ( 2 * simd:: I16_LANES ) {
66- let mask = simd:: nnz_bitmask ( * ft_out. as_ptr ( ) . add ( i) . cast ( ) ) ;
67-
68- for offset in ( 0 ..simd:: I32_LANES ) . step_by ( 8 ) {
69- let slice = ( mask >> offset) & 0xFF ;
70- let entry = nnz_table. get_unchecked ( slice as usize ) ;
71-
72- let store = indexes. as_mut_ptr ( ) . add ( count) . cast ( ) ;
73- _mm_storeu_si128 ( store, _mm_add_epi16 ( base, * entry. indexes . as_ptr ( ) . cast ( ) ) ) ;
74-
75- count += entry. count ;
76- base = _mm_add_epi16 ( base, increment) ;
77- }
78- }
79-
80- ( indexes, count)
81- }
82-
8354pub unsafe fn propagate_l1 ( ft_out : Aligned < [ u8 ; L1_SIZE ] > , nnz : & [ u16 ] ) -> Aligned < [ f32 ; L2_SIZE ] > {
8455 const CHUNKS : usize = 4 ;
8556
@@ -100,23 +71,23 @@ pub unsafe fn propagate_l1(ft_out: Aligned<[u8; L1_SIZE]>, nnz: &[u16]) -> Align
10071 let weights2 = PARAMETERS . l1_weights . as_ptr ( ) . add ( index2 * L2_SIZE * CHUNKS ) ;
10172
10273 for j in ( 0 ..L2_SIZE ) . step_by ( simd:: F32_LANES ) {
103- let weights1 = weights1. add ( j * CHUNKS ) . cast ( ) ;
104- let weights2 = weights2. add ( j * CHUNKS ) . cast ( ) ;
74+ let weights1 = * weights1. add ( j * CHUNKS ) . cast ( ) ;
75+ let weights2 = * weights2. add ( j * CHUNKS ) . cast ( ) ;
10576
10677 let vector = & mut pre_activations[ j / simd:: F32_LANES ] ;
107- * vector = simd:: double_dpbusd ( * vector, input1, * weights1, input2, * weights2) ;
78+ * vector = simd:: double_dpbusd ( * vector, input1, weights1, input2, weights2) ;
10879 }
10980 }
11081
11182 if let Some ( last) = pairs. remainder ( ) . first ( ) {
11283 let index = * last as usize ;
113- let pst_input = simd:: splat_i32 ( * packed. get_unchecked ( index) ) ;
84+ let input = simd:: splat_i32 ( * packed. get_unchecked ( index) ) ;
11485 let weights = PARAMETERS . l1_weights . as_ptr ( ) . add ( index * L2_SIZE * CHUNKS ) ;
11586
11687 for j in ( 0 ..L2_SIZE ) . step_by ( simd:: F32_LANES ) {
117- let weights = weights. add ( j * CHUNKS ) . cast ( ) ;
88+ let weights = * weights. add ( j * CHUNKS ) . cast ( ) ;
11889 let vector = & mut pre_activations[ j / simd:: F32_LANES ] ;
119- * vector = simd:: dpbusd ( * vector, pst_input , * weights) ;
90+ * vector = simd:: dpbusd ( * vector, input , weights) ;
12091 }
12192 }
12293
@@ -139,13 +110,13 @@ pub unsafe fn propagate_l2(l1_out: Aligned<[f32; L2_SIZE]>) -> Aligned<[f32; L3_
139110 let mut output = PARAMETERS . l2_biases . clone ( ) ;
140111
141112 for i in 0 ..L2_SIZE {
142- let pst_input = simd:: splat_f32 ( l1_out[ i] ) ;
113+ let input = simd:: splat_f32 ( l1_out[ i] ) ;
143114 let weights = PARAMETERS . l2_weights [ i] . as_ptr ( ) ;
144115
145116 for j in ( 0 ..L3_SIZE ) . step_by ( simd:: F32_LANES ) {
146- let weights = weights. add ( j) . cast ( ) ;
117+ let weights = * weights. add ( j) . cast ( ) ;
147118 let vector = output. as_mut_ptr ( ) . add ( j) . cast ( ) ;
148- * vector = simd:: mul_add_f32 ( * weights, pst_input , * vector) ;
119+ * vector = simd:: mul_add_f32 ( weights, input , * vector) ;
149120 }
150121 }
151122
@@ -163,19 +134,80 @@ pub unsafe fn propagate_l2(l1_out: Aligned<[f32; L2_SIZE]>) -> Aligned<[f32; L3_
/// Propagates the L2 activations through the final output layer.
///
/// Computes a dot product between `l2_out` and `PARAMETERS.l3_weights`
/// using `LANES` independent SIMD accumulators (together spanning 16
/// floats per inner step), then horizontally sums the accumulators and
/// adds the scalar bias `PARAMETERS.l3_biases`.
///
/// # Safety
/// Performs raw pointer vector loads via `.add(..).cast()`. Assumes
/// `L3_SIZE` is a multiple of `LANES * simd::F32_LANES` and that
/// `l3_weights` holds at least `L3_SIZE` floats — TODO confirm against
/// the layer-size constants (not visible in this file section).
pub unsafe fn propagate_l3(l2_out: Aligned<[f32; L3_SIZE]>) -> f32 {
    // Enough vector accumulators to cover 16 floats per inner step
    // regardless of the backend's lane width.
    const LANES: usize = 16 / simd::F32_LANES;

    let input = l2_out.as_ptr();
    let weights = PARAMETERS.l3_weights.as_ptr();

    // One accumulator per lane group; reduced by horizontal_sum at the end.
    let mut output = [simd::zero_f32(); LANES];

    for (lane, result) in output.iter_mut().enumerate() {
        for i in (0..L3_SIZE).step_by(LANES * simd::F32_LANES) {
            // Load one vector of weights and one of inputs at this lane's offset.
            let a = *weights.add(i + lane * simd::F32_LANES).cast();
            let b = *input.add(i + lane * simd::F32_LANES).cast();

            // Fused multiply-add into this lane's accumulator.
            *result = simd::mul_add_f32(a, b, *result);
        }
    }

    simd::horizontal_sum(output) + PARAMETERS.l3_biases
}
153+
/// Collects the indices of non-zero 4-byte (i32) chunks of the feature
/// transformer output, for consumption by the sparse L1 propagation.
///
/// Returns the index buffer together with the number of valid entries
/// (`count`); entries past `count` are unspecified scratch data.
///
/// The non-zero bitmask is consumed 8 bits at a time: each 8-bit slice
/// indexes `nnz_table`, whose precomputed per-mask `indexes` are offset
/// by the running `base` (advanced by 8 per slice, matching the 8 chunk
/// positions each slice covers).
///
/// # Safety
/// Uses x86 SSE2 intrinsics and unchecked lookups: `nnz_table` must
/// have an entry for every 8-bit mask value (256 entries), and the
/// output buffer must leave room for a full 8-lane store at any valid
/// `count` — TODO confirm table size and buffer slack at call sites.
#[cfg(not(target_feature = "neon"))]
pub unsafe fn find_nnz(
    ft_out: &Aligned<[u8; L1_SIZE]>, nnz_table: &[SparseEntry],
) -> (Aligned<[u16; L1_SIZE / 4]>, usize) {
    use std::arch::x86_64::*;

    let mut indexes = Aligned::new([0; L1_SIZE / 4]);
    let mut count = 0;

    // Each 8-bit mask slice covers 8 chunk positions, so the base index
    // advances by 8 per inner iteration.
    let increment = _mm_set1_epi16(8);
    let mut base = _mm_setzero_si128();

    for i in (0..L1_SIZE).step_by(2 * simd::I16_LANES) {
        // Bitmask with one bit per non-zero i32 chunk of this vector.
        let mask = simd::nnz_bitmask(*ft_out.as_ptr().add(i).cast());

        for offset in (0..simd::I32_LANES).step_by(8) {
            let slice = (mask >> offset) & 0xFF;
            // Unchecked: assumes nnz_table covers all 256 mask values.
            let entry = nnz_table.get_unchecked(slice as usize);

            // Unaligned store of 8 candidate indices; only the first
            // `entry.count` are retained by advancing `count`.
            let store = indexes.as_mut_ptr().add(count).cast();
            _mm_storeu_si128(store, _mm_add_epi16(base, *entry.indexes.as_ptr().cast()));

            count += entry.count;
            base = _mm_add_epi16(base, increment);
        }
    }

    (indexes, count)
}
183+
/// NEON variant of `find_nnz`: collects the indices of non-zero 4-byte
/// (i32) chunks of the feature transformer output.
///
/// Processes 32 bytes per iteration as two 16-byte vectors, combining
/// their non-zero bitmasks into one table index (the second vector's
/// mask shifted left by 4 — presumably the NEON `simd::nnz_bitmask`
/// yields a 4-bit mask per 16-byte vector; verify against that
/// implementation, which is not visible here).
///
/// Returns the index buffer together with the number of valid entries
/// (`count`); entries past `count` are unspecified scratch data.
///
/// # Safety
/// Uses unchecked lookups: `nnz_table` must have an entry for every
/// possible combined mask value, and the output buffer must leave room
/// for a full 8-lane store at any valid `count` — TODO confirm at call
/// sites.
#[cfg(target_feature = "neon")]
pub unsafe fn find_nnz(
    ft_out: &Aligned<[u8; L1_SIZE]>, nnz_table: &[SparseEntry],
) -> (Aligned<[u16; L1_SIZE / 4]>, usize) {
    use std::arch::aarch64::*;

    let mut indexes = Aligned::new([0; L1_SIZE / 4]);
    let mut count = 0;

    // Each iteration covers 32 bytes = 8 i32 chunks, so the base index
    // advances by 8 per step.
    let increment = vdupq_n_s16(8);
    let mut base = vdupq_n_s16(0);

    for i in (0..L1_SIZE).step_by(32) {
        let v0 = *ft_out.as_ptr().add(i).cast();
        let v1 = *ft_out.as_ptr().add(i + 16).cast();

        // Combined non-zero bitmask over both 16-byte vectors.
        let mask = (simd::nnz_bitmask(v0) | (simd::nnz_bitmask(v1) << 4)) as usize;
        let entry = nnz_table.get_unchecked(mask);

        // Add the running base to the precomputed per-mask indices and
        // store all 8 lanes; only the first `entry.count` are retained.
        let store = indexes.as_mut_ptr().add(count).cast();
        let indexed = vaddq_s16(base, vld1q_s16(entry.indexes.as_ptr().cast()));

        vst1q_s16(store, indexed);

        count += entry.count;
        base = vaddq_s16(base, increment);
    }

    (indexes, count)
}
0 commit comments