
Commit 11aedb5

Merge pull request #8 from martinnovaak/logging
Enable finetuning
2 parents dbfdd8f + fcc8a4e commit 11aedb5

File tree

2 files changed: +22 -0 lines changed


trainer/main.py

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,7 @@
 from batchloader import BatchLoader
 from model import PerspectiveNetwork
 from train import train
+from quantize import load_quantized_net
 
 
 def main():
@@ -31,6 +32,8 @@ def load_config(config_path="config.json"):
 
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     model = PerspectiveNetwork(config["hidden_layer_size"]).to(device)
+    #model = load_quantized_net("nnue.bin", config["hidden_layer_size"], 403, 64).to(device)
+    model.eval("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR", device)
 
     paths = [os.path.join(data_root.encode("utf-8"), file.encode("utf-8")) for file in os.listdir(data_root)]
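The commented-out line above is the finetuning hook: swapping it in replaces the freshly initialised PerspectiveNetwork with weights reloaded from nnue.bin, and the model.eval(...) call on the starting-position FEN apparently serves as a quick sanity check of whatever net ends up on the device. A minimal sketch of a config-driven version of that switch; the finetune_from key and the hidden size of 512 are illustrative assumptions, not part of this commit, while 403 and 64 match QA/QB in quantize.py:

    # Hypothetical, config-driven variant of the commented-out finetuning hook.
    import torch
    from model import PerspectiveNetwork
    from quantize import load_quantized_net

    config = {"hidden_layer_size": 512, "finetune_from": "nnue.bin"}  # illustrative values
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if config.get("finetune_from"):
        # Resume from a previously quantized net instead of a random init.
        model = load_quantized_net(config["finetune_from"], config["hidden_layer_size"], 403, 64).to(device)
    else:
        model = PerspectiveNetwork(config["hidden_layer_size"]).to(device)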

trainer/quantize.py

Lines changed: 19 additions & 0 deletions
@@ -2,6 +2,8 @@
 import struct
 import numpy as np
 
+from model import PerspectiveNetwork
+
 QA = 403
 QB = 64
 QAB = QA * QB
@@ -31,3 +33,20 @@ def quantize(model, bin_path):
         bin_file.write(struct.pack('<' + 'h' * len(feature_bias_values), *feature_bias_values))
         bin_file.write(struct.pack('<' + 'h' * len(output_weight_values), *output_weight_values))
         bin_file.write(struct.pack('<' + 'h' * len(output_bias_values), *output_bias_values))
+
+
+def load_quantized_net(bin_path, hl_size, qa, qb):
+    with open(bin_path, "rb") as bin_file:
+        # Read feature weights
+        feature_weights = struct.unpack(f'<{768 * hl_size}h', bin_file.read(768 * hl_size * 2))
+        feature_bias = struct.unpack(f'<{hl_size}h', bin_file.read(hl_size * 2))
+        output_weights = struct.unpack(f'<{2 * hl_size}h', bin_file.read(2 * hl_size * 2))
+        output_bias = struct.unpack('<1h', bin_file.read(1 * 2))
+
+    model = PerspectiveNetwork(hl_size)
+    model.feature_transformer.weight.data = torch.tensor(np.array(feature_weights).reshape(768, hl_size).T / qa, dtype=torch.float32)
+    model.feature_transformer.bias.data = torch.tensor(np.array(feature_bias) / qa, dtype=torch.float32)
+    model.output_layer.weight.data = torch.tensor(np.array(output_weights).reshape(1, 2 * hl_size) / qb, dtype=torch.float32)
+    model.output_layer.bias.data = torch.tensor(np.array(output_bias) / (qa * qb), dtype=torch.float32)
+
+    return model
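Reading load_quantized_net together with the constants at the top of the file: the int16 feature weights and biases are divided by qa (QA = 403), the output weights by qb (QB = 64), and the output bias by qa * qb, undoing the scaling the quantizer applied, up to a rounding error on the order of 1/QA or 1/QB. A rough round-trip check under that assumption; hl_size = 512 is arbitrary, and this presumes quantize() scales by the same QA/QB factors, which the loader's divisions suggest but this hunk does not show:

    # Hypothetical round-trip check: quantize a fresh net, reload it, and
    # compare the recovered float weights against the originals.
    from model import PerspectiveNetwork
    from quantize import quantize, load_quantized_net, QA, QB

    hl_size = 512                                  # arbitrary hidden size for the check
    original = PerspectiveNetwork(hl_size)
    quantize(original, "roundtrip.bin")            # assumed to write int16 weights scaled by QA/QB
    restored = load_quantized_net("roundtrip.bin", hl_size, QA, QB)

    # The difference should be roughly on the order of the quantization step (about 1/QA here).
    diff = (original.feature_transformer.weight - restored.feature_transformer.weight).abs().max()
    print(f"max feature-weight error: {diff.item():.4f}")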
