Skip to content

Commit ae5aab4

Browse files
committed
Change format for saved snapshots
Change snapshot format to Pytorch state dictionary.
1 parent ce4c2e6 commit ae5aab4

File tree

4 files changed

+16
-59
lines changed

4 files changed

+16
-59
lines changed

model.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@
33
from torch import nn
44
import torch.nn.functional as F
55

6-
# The version of the export format
7-
EXPORT_FORMAT_VERSION = 0x00000008
8-
96
# Number of inputs
107
NUM_SQ = 64
118
NUM_PT = 12
@@ -36,31 +33,6 @@ def forward(self, us, them, w_in, b_in):
3633
return x
3734

3835

39-
def serialize_halfkx_layer(self, buf, layer):
40-
bias = layer.bias.data.cpu()
41-
buf.extend(bias.flatten().numpy().tobytes())
42-
weight = self.input.weight.data.clone().cpu()
43-
buf.extend(weight.transpose(0, 1).flatten().numpy().tobytes())
44-
45-
46-
def serialize_linear_layer(self, buf, layer):
47-
bias = layer.bias.data.cpu()
48-
buf.extend(bias.flatten().numpy().tobytes())
49-
weight = layer.weight.data.cpu()
50-
buf.extend(weight.flatten().numpy().tobytes())
51-
52-
53-
def serialize(self, buf):
54-
# Write header
55-
buf.extend(struct.pack('<i', EXPORT_FORMAT_VERSION))
56-
57-
# Write layers
58-
self.serialize_halfkx_layer(buf, self.input)
59-
self.serialize_linear_layer(buf, self.l1)
60-
self.serialize_linear_layer(buf, self.l2)
61-
self.serialize_linear_layer(buf, self.output)
62-
63-
6436
def loss_function(wdl, pred, batch):
6537
us, them, white, black, outcome, score = batch
6638

nettest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def main(args):
7171
# Find all .bin files in the net folder
7272
bin_files = [f for f in os.listdir(args.net_dir)
7373
if (os.path.isfile(os.path.join(args.net_dir, f)) and
74-
os.path.splitext(f)[1] == '.bin')]
74+
os.path.splitext(f)[1] == '.pt')]
7575
time.sleep(5)
7676

7777
# Run a match with each new net

quantize.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
MAX_HIDDEN_WEIGHT = MAX_QUANTIZED_ACTIVATION/HIDDEN_WEIGHT_SCALE
2020
MAX_OUTPUT_WEIGHT = MAX_QUANTIZED_ACTIVATION/OUTPUT_WEIGHT_SCALE
2121

22+
NNUE_FORMAT_VERSION = 0x00000008
2223

2324
def write_header(buf, version):
2425
buf.extend(struct.pack('<I', version))
@@ -29,7 +30,7 @@ def write_layer(buf, biases, weights):
2930
buf.extend(weights.numpy().tobytes())
3031

3132

32-
def quant_halfkx(biases, weights):
33+
def quant_input(biases, weights):
3334
biases = biases.mul(HALFKX_BIAS_SCALE).round().to(torch.int16)
3435
weights = weights.mul(HALFKX_WEIGHT_SCALE).round().to(torch.int16)
3536
return (biases, weights)
@@ -47,14 +48,7 @@ def quant_output(biases, weights):
4748
return (biases, weights)
4849

4950

50-
def read_version(file):
51-
version = struct.unpack('<I', file.read(4))[0]
52-
if version != model.EXPORT_FORMAT_VERSION:
53-
raise Exception('Model format mismatch')
54-
return version
55-
56-
57-
def read_layer(file, ninputs, size):
51+
def extract_layer(file, ninputs, size):
5852
buf = numpy.fromfile(file, numpy.float32, size)
5953
biases = torch.from_numpy(buf.astype(numpy.float32))
6054
buf = numpy.fromfile(file, numpy.float32, size*ninputs)
@@ -65,24 +59,21 @@ def read_layer(file, ninputs, size):
6559
def quantization(source, target):
6660
print('Performing quantization ...')
6761

68-
# Read all layers
69-
with open(source, 'rb') as f:
70-
version = read_version(f)
71-
halfkx = read_layer(f, model.NUM_INPUTS, model.L1)
72-
linear1 = read_layer(f, model.L1*2, model.L2)
73-
linear2 = read_layer(f, model.L2, model.L3)
74-
output = read_layer(f, model.L3, 1)
62+
# Load model
63+
nnue = model.NNUE()
64+
nnue.load_state_dict(torch.load(source, map_location=torch.device('cpu')))
65+
nnue.eval()
7566

7667
# Perform quantization
77-
halfkx = quant_halfkx(halfkx[0], halfkx[1])
78-
linear1 = quant_linear(linear1[0], linear1[1])
79-
linear2 = quant_linear(linear2[0], linear2[1])
80-
output = quant_output(output[0], output[1])
68+
input = quant_input(nnue.input.weight, nnue.input.bias)
69+
linear1 = quant_linear(nnue.l1.weight, nnue.l1.bias)
70+
linear2 = quant_linear(nnue.l2.weight, nnue.l2.bias)
71+
output = quant_output(nnue.output.weight, nnue.output.bias)
8172

8273
# Write quantized layers
8374
outbuffer = bytearray()
84-
write_header(outbuffer, version)
85-
write_layer(outbuffer, halfkx[0], halfkx[1])
75+
write_header(outbuffer, NNUE_FORMAT_VERSION)
76+
write_layer(outbuffer, input[0], input[1])
8677
write_layer(outbuffer, linear1[0], linear1[1])
8778
write_layer(outbuffer, linear2[0], linear2[1])
8879
write_layer(outbuffer, output[0], output[1])

train.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,12 @@ def filter_saved_models(saved_models, top_n):
2929

3030

3131
def write_model(nnue, path):
32-
# Serialize the model to a buffer
33-
buf = bytearray()
34-
nnue.serialize(buf)
35-
36-
# Write the buffer
37-
with open(path, 'wb') as f:
38-
f.write(buf)
32+
torch.save(nnue.state_dict(), path)
3933

4034

4135
def save_model(nnue, output_path, epoch, idx, val_loss):
4236
# Construct the full path
43-
path = f'{output_path}/epoch_{epoch}_iter_{idx+1}_loss_{val_loss:.5f}.bin.tmp'
37+
path = f'{output_path}/epoch_{epoch}_iter_{idx+1}_loss_{val_loss:.5f}.pt.tmp'
4438

4539
# Save the model
4640
write_model(nnue, path)

0 commit comments

Comments
 (0)