Skip to content

Commit f96602a

Browse files
authored
Merge pull request #251 from Sopel97/leb
Support for LEB128 compression of feature transformer parameters.
2 parents 9883b2c + 0e5b220 commit f96602a

File tree

1 file changed

+103
-18
lines changed

1 file changed

+103
-18
lines changed

serialize.py

Lines changed: 103 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,19 @@
22
import features
33
import math
44
import model as M
5-
import numpy
65
import struct
76
import torch
7+
import io
88
from torch import nn
99
import pytorch_lightning as pl
1010
from torch.utils.data import DataLoader
1111
from functools import reduce
1212
import operator
13+
import numpy as np
14+
from numba import njit
1315

1416
def ascii_hist(name, x, bins=6):
15-
N,X = numpy.histogram(x, bins=bins)
17+
N,X = np.histogram(x, bins=bins)
1618
total = 1.0*len(x)
1719
width = 50
1820
nmax = N.max()
@@ -23,6 +25,36 @@ def ascii_hist(name, x, bins=6):
2325
xi = '{0: <8.4g}'.format(xi).ljust(10)
2426
print('{0}| {1}'.format(xi,bar))
2527

28+
@njit
def encode_leb_128_array(arr):
    """Encode every element of `arr` as a signed LEB128 byte sequence.

    Each value is emitted 7 bits at a time, least significant group first.
    Bit 0x80 of an emitted byte marks "more bytes follow"; bit 0x40 of the
    final byte carries the sign of the value.

    Assumes `arr` yields signed integers (e.g. a 1-D integer numpy array);
    confirm against callers.

    Returns a flat list of byte values (0..255) holding all encoded elements
    back to back, with no separators or element count.
    """
    encoded = []
    for value in arr:
        while True:
            low_bits = value & 0x7f
            # Arithmetic shift: negative values converge to -1, others to 0.
            value = value >> 7
            sign_bit_set = (low_bits & 0x40) != 0
            # The value is fully written once the remainder is pure sign
            # extension of the group that was just extracted.
            if (value == 0 and not sign_bit_set) or (value == -1 and sign_bit_set):
                encoded.append(low_bits)
                break
            encoded.append(low_bits | 0x80)
    return encoded
40+
41+
@njit
def decode_leb_128_array(arr, n):
    """Decode `n` signed LEB128 integers from the byte sequence `arr`.

    Args:
        arr: indexable sequence of byte values holding the encoded integers
            back to back.
        n: number of integers to decode.

    Returns:
        A numpy array of length `n` created with `np.zeros(n)` (numpy's
        default floating point dtype) holding the decoded values.

    Raises:
        ValueError: if `arr` ends before `n` complete values were decoded.
    """
    ints = np.zeros(n)
    size = len(arr)
    k = 0
    for i in range(n):
        r = 0
        shift = 0
        while True:
            # Compiled (nopython) indexing is not bounds-checked by default,
            # so truncated input must be detected explicitly rather than
            # silently reading past the end of the buffer.
            if k >= size:
                raise ValueError('Unexpected end of LEB128 data.')
            byte = arr[k]
            k = k + 1
            r |= (byte & 0x7f) << shift
            shift += 7
            if (byte & 0x80) == 0:
                # Last byte of this value: bit 0x40 is the sign bit, so
                # sign-extend by setting every bit above the decoded width.
                ints[i] = r if (byte & 0x40) == 0 else r | ~((1 << shift) - 1)
                break
    return ints
57+
2658
# hardcoded for now
2759
VERSION = 0x7AF32F20
2860
DEFAULT_DESCRIPTION = "Network trained with the https://github.com/glinscott/nnue-pytorch trainer."
@@ -31,7 +63,7 @@ class NNUEWriter():
3163
"""
3264
All values are stored in little endian.
3365
"""
34-
def __init__(self, model, description=None):
66+
def __init__(self, model, description=None, ft_compression='none'):
3567
if description is None:
3668
description = DEFAULT_DESCRIPTION
3769

@@ -43,7 +75,7 @@ def __init__(self, model, description=None):
4375
fc_hash = self.fc_hash(model)
4476
self.write_header(model, fc_hash, description)
4577
self.int32(model.feature_set.hash ^ (M.L1*2)) # Feature transformer hash
46-
self.write_feature_transformer(model)
78+
self.write_feature_transformer(model, ft_compression)
4779
for l1, l2, output in model.layer_stacks.get_coalesced_layer_stacks():
4880
self.int32(fc_hash) # FC layers hash
4981
self.write_fc_layer(model, l1)
@@ -76,7 +108,21 @@ def write_header(self, model, fc_hash, description):
76108
self.int32(len(encoded_description)) # Network definition
77109
self.buf.extend(encoded_description)
78110

79-
def write_feature_transformer(self, model):
111+
def write_leb_128_array(self, arr):
    """Append `arr` to the output buffer as a length-prefixed LEB128 blob.

    The number of encoded bytes is emitted first through `self.int32`,
    followed by the encoded bytes themselves.
    """
    encoded = encode_leb_128_array(arr)
    encoded_size = len(encoded)
    self.int32(encoded_size)
    self.buf.extend(encoded)
115+
116+
def write_tensor(self, arr, compression='none'):
    """Append the array `arr` to the output buffer.

    compression:
        'none'   - the raw bytes of `arr` (`arr.tobytes()`) are appended.
        'leb128' - the ASCII magic 'COMPRESSED_LEB128' is appended, followed
                   by the length-prefixed LEB128 encoding of `arr`.

    Raises Exception for any other compression name.
    """
    if compression == 'leb128':
        # The magic string lets a reader detect compressed data.
        self.buf.extend('COMPRESSED_LEB128'.encode('utf-8'))
        self.write_leb_128_array(arr)
        return

    if compression != 'none':
        raise Exception('Invalid compression method.')

    self.buf.extend(arr.tobytes())
124+
125+
def write_feature_transformer(self, model, ft_compression):
80126
layer = model.input
81127

82128
bias = layer.bias.data[:M.L1]
@@ -93,10 +139,11 @@ def write_feature_transformer(self, model):
93139
ascii_hist('ft weight:', weight.numpy())
94140
ascii_hist('ft psqt weight:', psqt_weight.numpy())
95141

96-
self.buf.extend(bias.flatten().numpy().tobytes())
97142
# Weights stored as [num_features][outputs]
98-
self.buf.extend(weight.flatten().numpy().tobytes())
99-
self.buf.extend(psqt_weight.flatten().numpy().tobytes())
143+
144+
self.write_tensor(bias.flatten().numpy(), ft_compression)
145+
self.write_tensor(weight.flatten().numpy(), ft_compression)
146+
self.write_tensor(psqt_weight.flatten().numpy(), ft_compression)
100147

101148
def write_fc_layer(self, model, layer, is_output=False):
102149
# FC layers are stored as int8 weights, and int32 biases
@@ -170,20 +217,51 @@ def read_header(self, feature_set, fc_hash):
170217
desc_len = self.read_int32()
171218
description = self.f.read(desc_len)
172219

220+
def read_leb_128_array(self, dtype, shape):
    """Read a length-prefixed LEB128 blob and return it as a float tensor.

    An int32 byte count is read first, then that many bytes of encoded data.
    The product of `shape` gives the number of integers to decode; the
    result is returned as a `torch.FloatTensor` reshaped to `shape`.

    `dtype` is accepted for signature parity with `tensor` but is not used
    by this method.

    Raises Exception if the file ends before the announced number of bytes
    could be read.
    """
    byte_count = self.read_int32()
    payload = self.f.read(byte_count)
    if len(payload) != byte_count:
        raise Exception('Unexpected end of file when reading compressed data.')

    num_values = reduce(operator.mul, shape, 1)
    decoded = decode_leb_128_array(payload, num_values)
    return torch.FloatTensor(decoded).reshape(shape)
229+
230+
def peek(self, length=1):
    """Return up to `length` bytes from the current position of `self.f`
    and seek back to where the read started, so the position is unchanged.
    """
    start = self.f.tell()
    preview = self.f.read(length)
    # Rewind so the bytes just inspected are still available to the caller.
    self.f.seek(start)
    return preview
235+
236+
def determine_compression(self):
    """Detect how the data at the current file position is stored.

    Returns 'leb128' when the upcoming bytes are the magic string
    b'COMPRESSED_LEB128', in which case the magic is consumed. Otherwise
    returns 'none' and the file position is left where it was.
    """
    magic = b'COMPRESSED_LEB128'
    magic_len = len(magic)

    if self.peek(magic_len) != magic:
        return 'none'

    # Peeking does not move the file pointer, so skip the magic explicitly.
    self.f.read(magic_len)
    return 'leb128'
243+
173244
def tensor(self, dtype, shape):
    """Read a tensor of the given `shape` from the current file position.

    The storage format is detected first. LEB128-compressed data is handed
    to `read_leb_128_array`. Uncompressed data is read as `prod(shape)`
    elements of `dtype`, converted to float32 and reshaped to `shape`.

    Raises Exception if the detected compression method is not recognised.
    """
    compression = self.determine_compression()

    if compression == 'leb128':
        return self.read_leb_128_array(dtype, shape)

    if compression != 'none':
        raise Exception('Invalid compression method.')

    element_count = reduce(operator.mul, shape, 1)
    raw = np.fromfile(self.f, dtype, element_count)
    return torch.from_numpy(raw.astype(np.float32)).reshape(shape)
178256

179257
def read_feature_transformer(self, layer, num_psqt_buckets):
180258
shape = layer.weight.shape
181259

182-
bias = self.tensor(numpy.int16, [layer.bias.shape[0]-num_psqt_buckets]).divide(self.model.quantized_one)
260+
bias = self.tensor(np.int16, [layer.bias.shape[0]-num_psqt_buckets]).divide(self.model.quantized_one)
183261
# weights stored as [num_features][outputs]
184-
weights = self.tensor(numpy.int16, [shape[0], shape[1]-num_psqt_buckets])
262+
weights = self.tensor(np.int16, [shape[0], shape[1]-num_psqt_buckets])
185263
weights = weights.divide(self.model.quantized_one)
186-
psqt_weights = self.tensor(numpy.int32, [shape[0], num_psqt_buckets])
264+
psqt_weights = self.tensor(np.int32, [shape[0], num_psqt_buckets])
187265
psqt_weights = psqt_weights.divide(self.model.nnue2score * self.model.weight_scale_out)
188266

189267
layer.bias.data = torch.cat([bias, torch.tensor([0]*num_psqt_buckets)])
@@ -202,8 +280,8 @@ def read_fc_layer(self, layer, is_output=False):
202280
non_padded_shape = layer.weight.shape
203281
padded_shape = (non_padded_shape[0], ((non_padded_shape[1]+31)//32)*32)
204282

205-
layer.bias.data = self.tensor(numpy.int32, layer.bias.shape).divide(kBiasScale)
206-
layer.weight.data = self.tensor(numpy.int8, padded_shape).divide(kWeightScale)
283+
layer.bias.data = self.tensor(np.int32, layer.bias.shape).divide(kBiasScale)
284+
layer.weight.data = self.tensor(np.int8, padded_shape).divide(kWeightScale)
207285

208286
# Strip padding.
209287
layer.weight.data = layer.weight.data[:non_padded_shape[0], :non_padded_shape[1]]
@@ -219,6 +297,7 @@ def main():
219297
parser.add_argument("source", help="Source file (can be .ckpt, .pt or .nnue)")
220298
parser.add_argument("target", help="Target file (can be .pt or .nnue)")
221299
parser.add_argument("--description", default=None, type=str, dest='description', help="The description string to include in the network. Only works when serializing into a .nnue file.")
300+
parser.add_argument("--ft_compression", default='none', type=str, dest='ft_compression', help="Compression method to use for FT weights and biases. Either 'none' or 'leb128'. Only allowed if saving to .nnue.")
222301
features.add_argparse_args(parser)
223302
args = parser.parse_args()
224303

@@ -238,12 +317,18 @@ def main():
238317
else:
239318
raise Exception('Invalid network input format.')
240319

320+
if args.ft_compression != 'none' and not args.target.endswith('.nnue'):
321+
raise Exception('Compression only allowed for .nnue target.')
322+
323+
if args.ft_compression not in ['none', 'leb128']:
324+
raise Exception('Invalid compression method.')
325+
241326
if args.target.endswith('.ckpt'):
242327
raise Exception('Cannot convert into .ckpt')
243328
elif args.target.endswith('.pt'):
244329
torch.save(nnue, args.target)
245330
elif args.target.endswith('.nnue'):
246-
writer = NNUEWriter(nnue, args.description)
331+
writer = NNUEWriter(nnue, args.description, ft_compression=args.ft_compression)
247332
with open(args.target, 'wb') as f:
248333
f.write(writer.buf)
249334
else:

0 commit comments

Comments
 (0)