Skip to content

Commit 394c252

Browse files
maximlinrock
authored andcommitted
Compressed network parameters
Implemented LEB128 (de)compression for the feature transformer. Reduces embedded network size from 70 MiB to 39 Mib. The new nn-78bacfcee510.nnue corresponds to the master net compressed. closes official-stockfish#4617 No functional change
1 parent 91ff9ad commit 394c252

File tree

2 files changed

+83
-6
lines changed

2 files changed

+83
-6
lines changed

src/nnue/nnue_common.h

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ namespace Stockfish::Eval::NNUE {
5959
// Size of cache line (in bytes)
6060
constexpr std::size_t CacheLineSize = 64;
6161

62+
constexpr const char Leb128MagicString[] = "COMPRESSED_LEB128";
63+
constexpr const std::size_t Leb128MagicStringSize = sizeof(Leb128MagicString) - 1;
64+
6265
// SIMD width (in bytes)
6366
#if defined(USE_AVX2)
6467
constexpr std::size_t SimdWidth = 32;
@@ -161,6 +164,80 @@ namespace Stockfish::Eval::NNUE {
161164
write_little_endian<IntType>(stream, values[i]);
162165
}
163166

167+
template <typename IntType>
168+
inline void read_leb_128(std::istream& stream, IntType* out, std::size_t count) {
169+
static_assert(std::is_signed_v<IntType>, "Not implemented for unsigned types");
170+
char leb128MagicString[Leb128MagicStringSize];
171+
stream.read(leb128MagicString, Leb128MagicStringSize);
172+
assert(strncmp(Leb128MagicString, leb128MagicString, Leb128MagicStringSize) == 0);
173+
const std::uint32_t BUF_SIZE = 4096;
174+
std::uint8_t buf[BUF_SIZE];
175+
auto bytes_left = read_little_endian<std::uint32_t>(stream);
176+
std::uint32_t buf_pos = BUF_SIZE;
177+
for (std::size_t i = 0; i < count; ++i) {
178+
IntType result = 0;
179+
size_t shift = 0;
180+
do {
181+
if (buf_pos == BUF_SIZE) {
182+
stream.read(reinterpret_cast<char*>(buf), std::min(bytes_left, BUF_SIZE));
183+
buf_pos = 0;
184+
}
185+
std::uint8_t byte = buf[buf_pos++];
186+
--bytes_left;
187+
result |= (byte & 0x7f) << shift;
188+
shift += 7;
189+
if ((byte & 0x80) == 0) {
190+
out[i] = sizeof(IntType) * 8 <= shift || (byte & 0x40) == 0 ? result : result | ~((1 << shift) - 1);
191+
break;
192+
}
193+
} while (shift < sizeof(IntType) * 8);
194+
}
195+
assert(bytes_left == 0);
196+
}
197+
198+
template <typename IntType>
199+
inline void write_leb_128(std::ostream& stream, const IntType* values, std::size_t count) {
200+
static_assert(std::is_signed_v<IntType>, "Not implemented for unsigned types");
201+
stream.write(Leb128MagicString, Leb128MagicStringSize);
202+
std::uint32_t byte_count = 0;
203+
for (std::size_t i = 0; i < count; ++i) {
204+
IntType value = values[i];
205+
std::uint8_t byte;
206+
do {
207+
byte = value & 0x7f;
208+
value >>= 7;
209+
++byte_count;
210+
} while ((byte & 0x40) == 0 ? value != 0 : value != -1);
211+
}
212+
write_little_endian(stream, byte_count);
213+
const std::uint32_t BUF_SIZE = 4096;
214+
std::uint8_t buf[BUF_SIZE];
215+
std::uint32_t buf_pos = 0;
216+
auto flush = [&]() {
217+
if (buf_pos > 0) {
218+
stream.write(reinterpret_cast<char*>(buf), buf_pos);
219+
buf_pos = 0;
220+
}
221+
};
222+
auto write = [&](std::uint8_t byte) {
223+
buf[buf_pos++] = byte;
224+
if (buf_pos == BUF_SIZE) flush();
225+
};
226+
for (std::size_t i = 0; i < count; ++i) {
227+
IntType value = values[i];
228+
while (true) {
229+
std::uint8_t byte = value & 0x7f;
230+
value >>= 7;
231+
if ((byte & 0x40) == 0 ? value == 0 : value == -1) {
232+
write(byte);
233+
break;
234+
}
235+
write(byte | 0x80);
236+
}
237+
}
238+
flush();
239+
}
240+
164241
} // namespace Stockfish::Eval::NNUE
165242

166243
#endif // #ifndef NNUE_COMMON_H_INCLUDED

src/nnue/nnue_feature_transformer.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -255,19 +255,19 @@ namespace Stockfish::Eval::NNUE {
255255
// Read network parameters
256256
bool read_parameters(std::istream& stream) {
257257

258-
read_little_endian<BiasType >(stream, biases , HalfDimensions );
259-
read_little_endian<WeightType >(stream, weights , HalfDimensions * InputDimensions);
260-
read_little_endian<PSQTWeightType>(stream, psqtWeights, PSQTBuckets * InputDimensions);
258+
read_leb_128<BiasType >(stream, biases , HalfDimensions );
259+
read_leb_128<WeightType >(stream, weights , HalfDimensions * InputDimensions);
260+
read_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets * InputDimensions);
261261

262262
return !stream.fail();
263263
}
264264

265265
// Write network parameters
266266
bool write_parameters(std::ostream& stream) const {
267267

268-
write_little_endian<BiasType >(stream, biases , HalfDimensions );
269-
write_little_endian<WeightType >(stream, weights , HalfDimensions * InputDimensions);
270-
write_little_endian<PSQTWeightType>(stream, psqtWeights, PSQTBuckets * InputDimensions);
268+
write_leb_128<BiasType >(stream, biases , HalfDimensions );
269+
write_leb_128<WeightType >(stream, weights , HalfDimensions * InputDimensions);
270+
write_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets * InputDimensions);
271271

272272
return !stream.fail();
273273
}

0 commit comments

Comments
 (0)