@@ -25,34 +25,31 @@ def write_header(buf, version):
     buf.extend(struct.pack('<I', version))
 
 
+def write_input(buf, biases, weights):
+    buf.extend(biases.flatten().numpy().tobytes())
+    buf.extend(weights.transpose(0, 1).flatten().numpy().tobytes())
+
+
 def write_layer(buf, biases, weights):
-    buf.extend(biases.numpy().tobytes())
-    buf.extend(weights.numpy().tobytes())
+    buf.extend(biases.flatten().numpy().tobytes())
+    buf.extend(weights.flatten().numpy().tobytes())
 
 
 def quant_input(biases, weights):
-    biases = biases.mul(HALFKX_BIAS_SCALE).round().to(torch.int16)
-    weights = weights.mul(HALFKX_WEIGHT_SCALE).round().to(torch.int16)
+    biases = biases.data.mul(HALFKX_BIAS_SCALE).round().to(torch.int16)
+    weights = weights.data.mul(HALFKX_WEIGHT_SCALE).round().to(torch.int16)
     return (biases, weights)
 
 
 def quant_linear(biases, weights):
-    biases = biases.mul(HIDDEN_BIAS_SCALE).round().to(torch.int32)
-    weights = weights.clamp(-MAX_HIDDEN_WEIGHT, MAX_HIDDEN_WEIGHT).mul(HIDDEN_WEIGHT_SCALE).round().to(torch.int8)
+    biases = biases.data.mul(HIDDEN_BIAS_SCALE).round().to(torch.int32)
+    weights = weights.data.clamp(-MAX_HIDDEN_WEIGHT, MAX_HIDDEN_WEIGHT).mul(HIDDEN_WEIGHT_SCALE).round().to(torch.int8)
     return (biases, weights)
 
 
 def quant_output(biases, weights):
-    biases = biases.mul(OUTPUT_BIAS_SCALE).round().to(torch.int32)
-    weights = weights.clamp(-MAX_OUTPUT_WEIGHT, MAX_OUTPUT_WEIGHT).mul(OUTPUT_WEIGHT_SCALE).round().to(torch.int8)
-    return (biases, weights)
-
-
-def extract_layer(file, ninputs, size):
-    buf = numpy.fromfile(file, numpy.float32, size)
-    biases = torch.from_numpy(buf.astype(numpy.float32))
-    buf = numpy.fromfile(file, numpy.float32, size * ninputs)
-    weights = torch.from_numpy(buf.astype(numpy.float32))
+    biases = biases.data.mul(OUTPUT_BIAS_SCALE).round().to(torch.int32)
+    weights = weights.data.clamp(-MAX_OUTPUT_WEIGHT, MAX_OUTPUT_WEIGHT).mul(OUTPUT_WEIGHT_SCALE).round().to(torch.int8)
     return (biases, weights)
 
 
@@ -65,15 +62,15 @@ def quantization(source, target):
     nnue.eval()
 
     # Perform quantization
-    input = quant_input(nnue.input.weight, nnue.input.bias)
-    linear1 = quant_linear(nnue.l1.weight, nnue.l1.bias)
-    linear2 = quant_linear(nnue.l2.weight, nnue.l2.bias)
-    output = quant_output(nnue.output.weight, nnue.output.bias)
+    input = quant_input(nnue.input.bias, nnue.input.weight)
+    linear1 = quant_linear(nnue.l1.bias, nnue.l1.weight)
+    linear2 = quant_linear(nnue.l2.bias, nnue.l2.weight)
+    output = quant_output(nnue.output.bias, nnue.output.weight)
 
     # Write quantized layers
     outbuffer = bytearray()
     write_header(outbuffer, NNUE_FORMAT_VERSION)
-    write_layer(outbuffer, input[0], input[1])
+    write_input(outbuffer, input[0], input[1])
     write_layer(outbuffer, linear1[0], linear1[1])
     write_layer(outbuffer, linear2[0], linear2[1])
     write_layer(outbuffer, output[0], output[1])
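For reference, a minimal sketch of how a consumer could read the input layer back from the buffer this commit writes. NHIDDEN and NINPUTS are hypothetical placeholders for the accumulator width and feature count, which this diff does not show; the header layout (one little-endian uint32 version) and the int16 input-layer encoding follow write_header, write_input, and quant_input above. The weights come back input-major because write_input transposes the (out_features, in_features) nn.Linear weight before flattening, presumably so the engine can read each feature's weight row contiguously when updating the accumulator.

import struct
import numpy as np

NHIDDEN = 256     # hypothetical accumulator width; use the real layer size
NINPUTS = 41024   # hypothetical HalfKX feature count; use the real input size

def read_input(buf):
    # Header: one little-endian uint32 version (mirrors write_header).
    (version,) = struct.unpack_from('<I', buf, 0)
    offset = 4
    # Input biases: NHIDDEN int16 values (quant_input casts to torch.int16).
    biases = np.frombuffer(buf, np.int16, NHIDDEN, offset)
    offset += biases.nbytes
    # Input weights: NINPUTS x NHIDDEN int16 values, stored input-major
    # because write_input calls weights.transpose(0, 1) before flattening.
    weights = np.frombuffer(buf, np.int16, NINPUTS * NHIDDEN, offset)
    weights = weights.reshape(NINPUTS, NHIDDEN)
    offset += weights.nbytes
    return version, biases, weights, offset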