22import features
33import math
44import model as M
5- import numpy
65import struct
76import torch
7+ import io
88from torch import nn
99import pytorch_lightning as pl
1010from torch .utils .data import DataLoader
1111from functools import reduce
1212import operator
13+ import numpy as np
14+ from numba import njit
1315
1416def ascii_hist (name , x , bins = 6 ):
15- N ,X = numpy .histogram (x , bins = bins )
17+ N ,X = np .histogram (x , bins = bins )
1618 total = 1.0 * len (x )
1719 width = 50
1820 nmax = N .max ()
@@ -23,6 +25,36 @@ def ascii_hist(name, x, bins=6):
2325 xi = '{0: <8.4g}' .format (xi ).ljust (10 )
2426 print ('{0}| {1}' .format (xi ,bar ))
2527
@njit
def encode_leb_128_array(arr):
    """
    Encode every integer in `arr` as signed LEB128.

    Each value is emitted as a run of 7-bit groups, least significant group
    first. All groups except the last have the continuation bit (0x80) set.
    The encodings of all values are concatenated into one list of byte
    values, which is returned.
    """
    encoded = []
    for value in arr:
        more = True
        while more:
            low7 = value & 0x7f
            value = value >> 7
            # Bit 6 of the group is the sign bit the decoder will extend.
            sign_set = (low7 & 0x40) != 0
            # Stop once the remaining bits are pure sign extension:
            # all zeros for a non-negative value, all ones (-1) for a negative one.
            if (value == 0 and not sign_set) or (value == -1 and sign_set):
                encoded.append(low7)
                more = False
            else:
                encoded.append(low7 | 0x80)
    return encoded
40+
@njit
def decode_leb_128_array(arr, n):
    """
    Decode `n` signed LEB128 integers from the byte sequence `arr`.

    `arr` is indexed byte by byte; each value is a run of 7-bit groups
    (least significant first) terminated by a byte whose continuation bit
    (0x80) is clear. Bit 6 (0x40) of the final group is the sign bit.

    Returns an array of length `n` allocated with `np.zeros(n)` (NumPy's
    default float dtype) holding the decoded values.

    Raises IndexError if `arr` ends before `n` complete values have been
    decoded.
    """
    ints = np.zeros(n)
    size = len(arr)
    k = 0
    for i in range(n):
        r = 0
        shift = 0
        while True:
            # Numba-compiled code does not bounds-check indexing by default,
            # so a truncated or corrupt stream would otherwise be read past
            # its end silently. Fail loudly instead.
            if k >= size:
                raise IndexError('LEB128 data ended before all values were decoded.')
            byte = arr[k]
            k = k + 1
            r |= (byte & 0x7f) << shift
            shift += 7
            if (byte & 0x80) == 0:
                # Last group: sign-extend when its sign bit is set.
                ints[i] = r if (byte & 0x40) == 0 else r | ~((1 << shift) - 1)
                break
    return ints
57+
2658# hardcoded for now
2759VERSION = 0x7AF32F20
2860DEFAULT_DESCRIPTION = "Network trained with the https://github.com/glinscott/nnue-pytorch trainer."
@@ -31,7 +63,7 @@ class NNUEWriter():
3163 """
3264 All values are stored in little endian.
3365 """
34- def __init__ (self , model , description = None ):
66+ def __init__ (self , model , description = None , ft_compression = 'none' ):
3567 if description is None :
3668 description = DEFAULT_DESCRIPTION
3769
@@ -43,7 +75,7 @@ def __init__(self, model, description=None):
4375 fc_hash = self .fc_hash (model )
4476 self .write_header (model , fc_hash , description )
4577 self .int32 (model .feature_set .hash ^ (M .L1 * 2 )) # Feature transformer hash
46- self .write_feature_transformer (model )
78+ self .write_feature_transformer (model , ft_compression )
4779 for l1 , l2 , output in model .layer_stacks .get_coalesced_layer_stacks ():
4880 self .int32 (fc_hash ) # FC layers hash
4981 self .write_fc_layer (model , l1 )
@@ -76,7 +108,21 @@ def write_header(self, model, fc_hash, description):
76108 self .int32 (len (encoded_description )) # Network definition
77109 self .buf .extend (encoded_description )
78110
79- def write_feature_transformer (self , model ):
def write_leb_128_array(self, arr):
    """
    Append `arr` to the output buffer in LEB128-compressed form.

    The encoded payload is preceded by its length (number of encoded
    bytes), written through `self.int32`, so a reader knows how many bytes
    to consume.
    """
    encoded = encode_leb_128_array(arr)
    # Length prefix first, then the payload itself.
    self.int32(len(encoded))
    self.buf.extend(encoded)
115+
def write_tensor(self, arr, compression='none'):
    """
    Append the array `arr` to the output buffer.

    compression:
      'none'   - the raw bytes of `arr` (`arr.tobytes()`) are appended.
      'leb128' - the ASCII marker 'COMPRESSED_LEB128' is appended, followed
                 by the length-prefixed LEB128 encoding of `arr`.

    Raises Exception for any other compression value.
    """
    if compression not in ('none', 'leb128'):
        raise Exception('Invalid compression method.')

    if compression == 'leb128':
        # The marker lets the reader detect the compressed layout.
        self.buf.extend('COMPRESSED_LEB128'.encode('utf-8'))
        self.write_leb_128_array(arr)
        return

    self.buf.extend(arr.tobytes())
124+
125+ def write_feature_transformer (self , model , ft_compression ):
80126 layer = model .input
81127
82128 bias = layer .bias .data [:M .L1 ]
@@ -93,10 +139,11 @@ def write_feature_transformer(self, model):
93139 ascii_hist ('ft weight:' , weight .numpy ())
94140 ascii_hist ('ft psqt weight:' , psqt_weight .numpy ())
95141
96- self .buf .extend (bias .flatten ().numpy ().tobytes ())
97142 # Weights stored as [num_features][outputs]
98- self .buf .extend (weight .flatten ().numpy ().tobytes ())
99- self .buf .extend (psqt_weight .flatten ().numpy ().tobytes ())
143+
144+ self .write_tensor (bias .flatten ().numpy (), ft_compression )
145+ self .write_tensor (weight .flatten ().numpy (), ft_compression )
146+ self .write_tensor (psqt_weight .flatten ().numpy (), ft_compression )
100147
101148 def write_fc_layer (self , model , layer , is_output = False ):
102149 # FC layers are stored as int8 weights, and int32 biases
@@ -170,20 +217,51 @@ def read_header(self, feature_set, fc_hash):
170217 desc_len = self .read_int32 ()
171218 description = self .f .read (desc_len )
172219
def read_leb_128_array(self, dtype, shape):
    """
    Read a length-prefixed LEB128 payload and return it as a float tensor.

    The payload size in bytes is read first via `self.read_int32()`, then
    that many bytes are read from `self.f` and decoded into
    prod(shape) values. `dtype` is accepted for signature parity with the
    uncompressed path but is not used here.

    Raises Exception if fewer bytes than announced are available.
    """
    expected_len = self.read_int32()
    payload = self.f.read(expected_len)
    if len(payload) != expected_len:
        raise Exception('Unexpected end of file when reading compressed data.')

    count = reduce(operator.mul, shape, 1)
    values = decode_leb_128_array(payload, count)
    return torch.FloatTensor(values).reshape(shape)
229+
def peek(self, length=1):
    """
    Return up to `length` bytes from the current position of `self.f`
    without consuming them: the position is restored after the read.
    """
    stream = self.f
    mark = stream.tell()
    preview = stream.read(length)
    # Rewind so the next real read sees the same bytes again.
    stream.seek(mark)
    return preview
235+
def determine_compression(self):
    """
    Detect how the next tensor in the stream is stored.

    If the upcoming bytes are the marker b'COMPRESSED_LEB128', the marker
    is consumed and 'leb128' is returned. Otherwise nothing is consumed
    and 'none' is returned.
    """
    magic = b'COMPRESSED_LEB128'
    if self.peek(len(magic)) != magic:
        return 'none'

    # Marker matched: skip past it so the caller starts at the payload.
    self.f.read(len(magic))
    return 'leb128'
243+
def tensor(self, dtype, shape):
    """
    Read one tensor of the given `shape` from `self.f`.

    The storage format is detected with `self.determine_compression()`:
      'leb128' - delegated to `self.read_leb_128_array(dtype, shape)`.
      'none'   - prod(shape) items of `dtype` are read with `np.fromfile`,
                 converted to float32 and reshaped to `shape`.

    Raises Exception if the detected compression is neither of the above.
    """
    method = self.determine_compression()

    if method == 'leb128':
        return self.read_leb_128_array(dtype, shape)
    if method != 'none':
        raise Exception('Invalid compression method.')

    count = reduce(operator.mul, shape, 1)
    raw = np.fromfile(self.f, dtype, count)
    return torch.from_numpy(raw.astype(np.float32)).reshape(shape)
178256
179257 def read_feature_transformer (self , layer , num_psqt_buckets ):
180258 shape = layer .weight .shape
181259
182- bias = self .tensor (numpy .int16 , [layer .bias .shape [0 ]- num_psqt_buckets ]).divide (self .model .quantized_one )
260+ bias = self .tensor (np .int16 , [layer .bias .shape [0 ]- num_psqt_buckets ]).divide (self .model .quantized_one )
183261 # weights stored as [num_features][outputs]
184- weights = self .tensor (numpy .int16 , [shape [0 ], shape [1 ]- num_psqt_buckets ])
262+ weights = self .tensor (np .int16 , [shape [0 ], shape [1 ]- num_psqt_buckets ])
185263 weights = weights .divide (self .model .quantized_one )
186- psqt_weights = self .tensor (numpy .int32 , [shape [0 ], num_psqt_buckets ])
264+ psqt_weights = self .tensor (np .int32 , [shape [0 ], num_psqt_buckets ])
187265 psqt_weights = psqt_weights .divide (self .model .nnue2score * self .model .weight_scale_out )
188266
189267 layer .bias .data = torch .cat ([bias , torch .tensor ([0 ]* num_psqt_buckets )])
@@ -202,8 +280,8 @@ def read_fc_layer(self, layer, is_output=False):
202280 non_padded_shape = layer .weight .shape
203281 padded_shape = (non_padded_shape [0 ], ((non_padded_shape [1 ]+ 31 )// 32 )* 32 )
204282
205- layer .bias .data = self .tensor (numpy .int32 , layer .bias .shape ).divide (kBiasScale )
206- layer .weight .data = self .tensor (numpy .int8 , padded_shape ).divide (kWeightScale )
283+ layer .bias .data = self .tensor (np .int32 , layer .bias .shape ).divide (kBiasScale )
284+ layer .weight .data = self .tensor (np .int8 , padded_shape ).divide (kWeightScale )
207285
208286 # Strip padding.
209287 layer .weight .data = layer .weight .data [:non_padded_shape [0 ], :non_padded_shape [1 ]]
@@ -219,6 +297,7 @@ def main():
219297 parser .add_argument ("source" , help = "Source file (can be .ckpt, .pt or .nnue)" )
220298 parser .add_argument ("target" , help = "Target file (can be .pt or .nnue)" )
221299 parser .add_argument ("--description" , default = None , type = str , dest = 'description' , help = "The description string to include in the network. Only works when serializing into a .nnue file." )
300+ parser .add_argument ("--ft_compression" , default = 'none' , type = str , dest = 'ft_compression' , help = "Compression method to use for FT weights and biases. Either 'none' or 'leb128'. Only allowed if saving to .nnue." )
222301 features .add_argparse_args (parser )
223302 args = parser .parse_args ()
224303
@@ -238,12 +317,18 @@ def main():
238317 else :
239318 raise Exception ('Invalid network input format.' )
240319
320+ if args .ft_compression != 'none' and not args .target .endswith ('.nnue' ):
321+ raise Exception ('Compression only allowed for .nnue target.' )
322+
323+ if args .ft_compression not in ['none' , 'leb128' ]:
324+ raise Exception ('Invalid compression method.' )
325+
241326 if args .target .endswith ('.ckpt' ):
242327 raise Exception ('Cannot convert into .ckpt' )
243328 elif args .target .endswith ('.pt' ):
244329 torch .save (nnue , args .target )
245330 elif args .target .endswith ('.nnue' ):
246- writer = NNUEWriter (nnue , args .description )
331+ writer = NNUEWriter (nnue , args .description , ft_compression = args . ft_compression )
247332 with open (args .target , 'wb' ) as f :
248333 f .write (writer .buf )
249334 else :
0 commit comments