1919MAX_HIDDEN_WEIGHT = MAX_QUANTIZED_ACTIVATION / HIDDEN_WEIGHT_SCALE
2020MAX_OUTPUT_WEIGHT = MAX_QUANTIZED_ACTIVATION / OUTPUT_WEIGHT_SCALE
2121
22+ NNUE_FORMAT_VERSION = 0x00000008
2223
2324def write_header (buf , version ):
2425 buf .extend (struct .pack ('<I' , version ))
@@ -29,7 +30,7 @@ def write_layer(buf, biases, weights):
2930 buf .extend (weights .numpy ().tobytes ())
3031
3132
32- def quant_halfkx (biases , weights ):
33+ def quant_input (biases , weights ):
3334 biases = biases .mul (HALFKX_BIAS_SCALE ).round ().to (torch .int16 )
3435 weights = weights .mul (HALFKX_WEIGHT_SCALE ).round ().to (torch .int16 )
3536 return (biases , weights )
@@ -47,14 +48,7 @@ def quant_output(biases, weights):
4748 return (biases , weights )
4849
4950
50- def read_version (file ):
51- version = struct .unpack ('<I' , file .read (4 ))[0 ]
52- if version != model .EXPORT_FORMAT_VERSION :
53- raise Exception ('Model format mismatch' )
54- return version
55-
56-
57- def read_layer (file , ninputs , size ):
51+ def extract_layer (file , ninputs , size ):
5852 buf = numpy .fromfile (file , numpy .float32 , size )
5953 biases = torch .from_numpy (buf .astype (numpy .float32 ))
6054 buf = numpy .fromfile (file , numpy .float32 , size * ninputs )
@@ -65,24 +59,21 @@ def read_layer(file, ninputs, size):
6559def quantization (source , target ):
6660 print ('Performing quantization ...' )
6761
68- # Read all layers
69- with open (source , 'rb' ) as f :
70- version = read_version (f )
71- halfkx = read_layer (f , model .NUM_INPUTS , model .L1 )
72- linear1 = read_layer (f , model .L1 * 2 , model .L2 )
73- linear2 = read_layer (f , model .L2 , model .L3 )
74- output = read_layer (f , model .L3 , 1 )
62+ # Load model
63+ nnue = model .NNUE ()
64+ nnue .load_state_dict (torch .load (source , map_location = torch .device ('cpu' )))
65+ nnue .eval ()
7566
7667 # Perform quantization
77- halfkx = quant_halfkx ( halfkx [ 0 ], halfkx [ 1 ] )
78- linear1 = quant_linear (linear1 [ 0 ], linear1 [ 1 ] )
79- linear2 = quant_linear (linear2 [ 0 ], linear2 [ 1 ] )
80- output = quant_output (output [ 0 ], output [ 1 ] )
68+ input = quant_input ( nnue . input . weight , nnue . input . bias )
69+ linear1 = quant_linear (nnue . l1 . weight , nnue . l1 . bias )
70+ linear2 = quant_linear (nnue . l2 . weight , nnue . l2 . bias )
71+ output = quant_output (nnue . output . weight , nnue . output . bias )
8172
8273 # Write quantized layers
8374 outbuffer = bytearray ()
84- write_header (outbuffer , version )
85- write_layer (outbuffer , halfkx [0 ], halfkx [1 ])
75+ write_header (outbuffer , NNUE_FORMAT_VERSION )
76+ write_layer (outbuffer , input [0 ], input [1 ])
8677 write_layer (outbuffer , linear1 [0 ], linear1 [1 ])
8778 write_layer (outbuffer , linear2 [0 ], linear2 [1 ])
8879 write_layer (outbuffer , output [0 ], output [1 ])
0 commit comments