@@ -88,8 +88,9 @@ void find_nnz(const std::int32_t* RESTRICT input,
8888 constexpr IndexType SimdWidthOut = 32 ; // 512 bits / 16 bits
8989 constexpr IndexType NumChunks = InputDimensions / SimdWidthOut;
9090 const __m512i increment = _mm512_set1_epi16 (SimdWidthOut);
91- __m512i base = _mm512_set_epi16 (31 , 30 , 29 , 28 , 27 , 26 , 25 , 24 , 23 , 22 , 21 , 20 , 19 , 18 , 17 , 16 ,
92- 15 , 14 , 13 , 12 , 11 , 10 , 9 , 8 , 7 , 6 , 5 , 4 , 3 , 2 , 1 , 0 );
91+ __m512i base = _mm512_set_epi16 ( // Same permute order as _mm512_packus_epi32()
92+ 31 , 30 , 29 , 28 , 15 , 14 , 13 , 12 , 27 , 26 , 25 , 24 , 11 , 10 , 9 , 8 , 23 , 22 , 21 , 20 , 7 , 6 , 5 , 4 , 19 ,
93+ 18 , 17 , 16 , 3 , 2 , 1 , 0 );
9394
9495 IndexType count = 0 ;
9596 for (IndexType i = 0 ; i < NumChunks; ++i)
@@ -98,12 +99,12 @@ void find_nnz(const std::int32_t* RESTRICT input,
9899 const __m512i inputV1 = _mm512_load_si512 (input + i * 2 * SimdWidthIn + SimdWidthIn);
99100
100101 // Get a bitmask and gather non zero indices
101- const __mmask32 nnzMask = _mm512_kunpackw ( _mm512_test_epi32_mask (inputV1 , inputV1),
102- _mm512_test_epi32_mask (inputV0, inputV0) );
102+ const __m512i inputV01 = _mm512_packus_epi32 (inputV0 , inputV1);
103+ const __mmask32 nnzMask = _mm512_test_epi16_mask (inputV01, inputV01 );
103104
104105 // Avoid _mm512_mask_compressstoreu_epi16() as it's 256 uOps on Zen4
105106 __m512i nnz = _mm512_maskz_compress_epi16 (nnzMask, base);
106- _mm512_storeu_epi16 (out + count, nnz);
107+ _mm512_storeu_si512 (out + count, nnz);
107108
108109 count += popcount (nnzMask);
109110 base = _mm512_add_epi16 (base, increment);
0 commit comments