Skip to content

Commit 45f00f5

Browse files
committed
Don't do unnecessary work in the affine transform when the input dimension is a multiple of 4 but not a multiple of 32
1 parent a6a1a3f commit 45f00f5

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

src/nnue/layers/affine_transform.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ namespace Stockfish::Eval::NNUE::Layers {
267267
#endif
268268

269269
#if defined (USE_SSSE3)
270+
// Different layout: we always process 4 inputs at a time.
271+
static_assert(InputDimensions % 4 == 0);
270272

271273
const auto output = reinterpret_cast<OutputType*>(buffer);
272274
const auto inputVector = reinterpret_cast<const vec_t*>(input);
@@ -277,7 +279,7 @@ namespace Stockfish::Eval::NNUE::Layers {
277279
// because then it is also an input dimension.
278280
if constexpr (OutputDimensions % OutputSimdWidth == 0)
279281
{
280-
constexpr IndexType NumChunks = PaddedInputDimensions / 4;
282+
constexpr IndexType NumChunks = InputDimensions / 4;
281283

282284
const auto input32 = reinterpret_cast<const std::int32_t*>(input);
283285
vec_t* outptr = reinterpret_cast<vec_t*>(output);
@@ -344,17 +346,21 @@ namespace Stockfish::Eval::NNUE::Layers {
344346
auto output = reinterpret_cast<OutputType*>(buffer);
345347

346348
#if defined(USE_SSE2)
347-
constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
349+
// With SSE2, InputDimensions is at least a multiple of 16.
350+
static_assert(InputDimensions % SimdWidth == 0);
351+
constexpr IndexType NumChunks = InputDimensions / SimdWidth;
348352
const __m128i Zeros = _mm_setzero_si128();
349353
const auto inputVector = reinterpret_cast<const __m128i*>(input);
350354

351355
#elif defined(USE_MMX)
352-
constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
356+
static_assert(InputDimensions % SimdWidth == 0);
357+
constexpr IndexType NumChunks = InputDimensions / SimdWidth;
353358
const __m64 Zeros = _mm_setzero_si64();
354359
const auto inputVector = reinterpret_cast<const __m64*>(input);
355360

356361
#elif defined(USE_NEON)
357-
constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
362+
static_assert(InputDimensions % SimdWidth == 0);
363+
constexpr IndexType NumChunks = InputDimensions / SimdWidth;
358364
const auto inputVector = reinterpret_cast<const int8x8_t*>(input);
359365
#endif
360366

0 commit comments

Comments
 (0)