
Commit dc5acfd

quantize.py: Fix quantization

1 parent d1ede87

File tree: 1 file changed (+18 −21 lines)

quantize.py

Lines changed: 18 additions & 21 deletions
@@ -25,34 +25,31 @@ def write_header(buf, version):
     buf.extend(struct.pack('<I', version))
 
 
+def write_input(buf, biases, weights):
+    buf.extend(biases.flatten().numpy().tobytes())
+    buf.extend(weights.transpose(0, 1).flatten().numpy().tobytes())
+
+
 def write_layer(buf, biases, weights):
-    buf.extend(biases.numpy().tobytes())
-    buf.extend(weights.numpy().tobytes())
+    buf.extend(biases.flatten().numpy().tobytes())
+    buf.extend(weights.flatten().numpy().tobytes())
 
 
 def quant_input(biases, weights):
-    biases = biases.mul(HALFKX_BIAS_SCALE).round().to(torch.int16)
-    weights = weights.mul(HALFKX_WEIGHT_SCALE).round().to(torch.int16)
+    biases = biases.data.mul(HALFKX_BIAS_SCALE).round().to(torch.int16)
+    weights = weights.data.mul(HALFKX_WEIGHT_SCALE).round().to(torch.int16)
     return (biases, weights)
 
 
 def quant_linear(biases, weights):
-    biases = biases.mul(HIDDEN_BIAS_SCALE).round().to(torch.int32)
-    weights = weights.clamp(-MAX_HIDDEN_WEIGHT, MAX_HIDDEN_WEIGHT).mul(HIDDEN_WEIGHT_SCALE).round().to(torch.int8)
+    biases = biases.data.mul(HIDDEN_BIAS_SCALE).round().to(torch.int32)
+    weights = weights.data.clamp(-MAX_HIDDEN_WEIGHT, MAX_HIDDEN_WEIGHT).mul(HIDDEN_WEIGHT_SCALE).round().to(torch.int8)
     return (biases, weights)
 
 
 def quant_output(biases, weights):
-    biases = biases.mul(OUTPUT_BIAS_SCALE).round().to(torch.int32)
-    weights = weights.clamp(-MAX_OUTPUT_WEIGHT, MAX_OUTPUT_WEIGHT).mul(OUTPUT_WEIGHT_SCALE).round().to(torch.int8)
-    return (biases, weights)
-
-
-def extract_layer(file, ninputs, size):
-    buf = numpy.fromfile(file, numpy.float32, size)
-    biases = torch.from_numpy(buf.astype(numpy.float32))
-    buf = numpy.fromfile(file, numpy.float32, size*ninputs)
-    weights = torch.from_numpy(buf.astype(numpy.float32))
+    biases = biases.data.mul(OUTPUT_BIAS_SCALE).round().to(torch.int32)
+    weights = weights.data.clamp(-MAX_OUTPUT_WEIGHT, MAX_OUTPUT_WEIGHT).mul(OUTPUT_WEIGHT_SCALE).round().to(torch.int8)
     return (biases, weights)
 
 
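Note on the new write_input(): nn.Linear stores its weight as (out_features, in_features), so flattening directly serializes one output neuron's weights at a time. The added transpose(0, 1) writes the input layer grouped per input feature instead, presumably so the engine can read the weights touched by one active feature as a contiguous block when updating its accumulator. A minimal sketch of the difference, using a deliberately tiny hypothetical layer in place of the real HalfKX input:

    import torch

    # Tiny stand-in for the HalfKX input layer; real dimensions are far larger.
    ft = torch.nn.Linear(4, 2)

    # .data yields the raw tensor behind the nn.Parameter, outside autograd
    # (.detach() would work equally well).
    w = ft.weight.data  # shape (2, 4): (out_features, in_features)

    # Direct flatten: output 0's four weights, then output 1's.
    row_major = w.flatten()

    # transpose(0, 1) then flatten: both outputs' weights for input feature 0,
    # then for feature 1, and so on -- the per-feature layout write_input()
    # serializes.
    col_major = w.transpose(0, 1).flatten()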

@@ -65,15 +62,15 @@ def quantization(source, target):
     nnue.eval()
 
     # Perform quantization
-    input = quant_input(nnue.input.weight, nnue.input.bias)
-    linear1 = quant_linear(nnue.l1.weight, nnue.l1.bias)
-    linear2 = quant_linear(nnue.l2.weight, nnue.l2.bias)
-    output = quant_output(nnue.output.weight, nnue.output.bias)
+    input = quant_input(nnue.input.bias, nnue.input.weight)
+    linear1 = quant_linear(nnue.l1.bias, nnue.l1.weight)
+    linear2 = quant_linear(nnue.l2.bias, nnue.l2.weight)
+    output = quant_output(nnue.output.bias, nnue.output.weight)
 
     # Write quantized layers
     outbuffer = bytearray()
     write_header(outbuffer, NNUE_FORMAT_VERSION)
-    write_layer(outbuffer, input[0], input[1])
+    write_input(outbuffer, input[0], input[1])
     write_layer(outbuffer, linear1[0], linear1[1])
     write_layer(outbuffer, linear2[0], linear2[1])
     write_layer(outbuffer, output[0], output[1])
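The second hunk is the core of the fix: the quant_* helpers are declared as (biases, weights) but were being called as (weight, bias), so each tensor received the other's scale, clamp, and integer width. A sketch of the failure mode, with made-up scale constants standing in for the real ones defined in quantize.py:

    import torch

    # Hypothetical values for illustration; the real constants live in quantize.py.
    HIDDEN_BIAS_SCALE = 8128.0
    HIDDEN_WEIGHT_SCALE = 64.0
    MAX_HIDDEN_WEIGHT = 127.0 / 64.0

    def quant_linear(biases, weights):
        biases = biases.data.mul(HIDDEN_BIAS_SCALE).round().to(torch.int32)
        weights = weights.data.clamp(-MAX_HIDDEN_WEIGHT, MAX_HIDDEN_WEIGHT).mul(HIDDEN_WEIGHT_SCALE).round().to(torch.int8)
        return (biases, weights)

    layer = torch.nn.Linear(32, 32)

    # Old call order: the weight matrix is scaled like a bias and written as
    # int32, while the bias vector is clamped and truncated to int8 -- both
    # end up in the file with the wrong scale and width.
    bad_biases, bad_weights = quant_linear(layer.weight, layer.bias)

    # Fixed call order matches the signature.
    biases, weights = quant_linear(layer.bias, layer.weight)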
