
Commit 98329fa

Merge branch 'main' into refactor
2 parents faacaca + 11aedb5 commit 98329fa


4 files changed: +22 -7 lines changed


trainer/main.py

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,7 @@
 from batchloader import BatchLoader
 from model import PerspectiveNetwork
 from train import train
+from quantize import load_quantized_net


 def main():
@@ -31,6 +32,8 @@ def load_config(config_path="config.json"):

     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     model = PerspectiveNetwork(config["hidden_layer_size"]).to(device)
+    #model = load_quantized_net("nnue.bin", config["hidden_layer_size"], 403, 64).to(device)
+    model.eval("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR", device)

     paths = [os.path.join(data_root.encode("utf-8"), file.encode("utf-8")) for file in os.listdir(data_root)]

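For context, the commented-out line would replace the freshly constructed network with one rebuilt from the quantized file; the 403 and 64 literals mirror the QA and QB constants in trainer/quantize.py. A minimal usage sketch under those assumptions ("nnue.bin" is taken to be a file previously written by the quantizer, and the hl_size value is a placeholder for config["hidden_layer_size"]); the snippet is illustrative, not part of the commit:

# Sketch only: rebuild a float PerspectiveNetwork from the quantized weights
# and print its evaluation of the starting position, as the commented-out
# line in main() would. The path and hl_size value are assumptions.
import torch
from quantize import load_quantized_net

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
hl_size = 128  # assumed; the real value comes from config["hidden_layer_size"]
model = load_quantized_net("nnue.bin", hl_size, 403, 64).to(device)
model.eval("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR", device)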

trainer/model.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def eval(self, fen, device):
         hidden_features = torch.cat((stm_perspective, nstm_perspective))
         hidden_features = self.screlu(hidden_features)

-        print(int((torch.special.logit(torch.sigmoid(self.output_layer(hidden_features))) * 400).item()))
+        print(self.output_layer(hidden_features) * 400)

     def clamp_weights(self):
         self.feature_transformer.weight.data.clamp_(-1.27, 1.27)
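
The replaced line works because logit is the inverse of sigmoid, so torch.special.logit(torch.sigmoid(x)) is just x up to floating-point error; the new print shows the same scaled score, only as a tensor instead of a truncated int. A quick check of that identity (standalone snippet, not part of the commit):

# logit inverts sigmoid, so the old logit(sigmoid(output)) * 400 and the new
# output * 400 agree numerically; the old code additionally truncated to int.
import torch

x = torch.randn(8)
assert torch.allclose(torch.special.logit(torch.sigmoid(x)), x, atol=1e-5)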

trainer/quantize.py

Lines changed: 3 additions & 1 deletion
@@ -2,6 +2,8 @@
 import struct
 import numpy as np

+from model import PerspectiveNetwork
+
 QA = 403
 QB = 64
 QAB = QA * QB
@@ -39,4 +41,4 @@ def load_quantized_net(bin_path, hl_size, qa, qb):
     model.output_layer.weight.data = torch.tensor(np.array(output_weights).reshape(1, 2 * hl_size) / qb, dtype=torch.float32)
     model.output_layer.bias.data = torch.tensor(np.array(output_bias) / (qa * qb), dtype=torch.float32)

-    return model
+    return model
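
The divisions by qb and qa * qb here undo the integer scaling applied when the network was exported: quantization stores values rounded after multiplying by the scale, and loading divides the stored integers back out, so the rebuilt floats carry a small rounding error. A worked example of that arithmetic, using the QA/QB constants above (the actual struct/file layout is not shown in this diff):

# Round trip for a single output-layer weight under the scaling implied by
# load_quantized_net (output weights scaled by QB, output bias by QA * QB).
QA, QB = 403, 64

w_float = 0.8173
w_int = round(w_float * QB)   # integer as it would be written to the .bin file
w_restored = w_int / QB       # float rebuilt at load time
print(w_int, w_restored)      # 52 0.8125, a rounding error of about 0.005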

trainer/train.py

Lines changed: 15 additions & 5 deletions
@@ -7,7 +7,7 @@

 def print_epoch_stats(epoch, running_loss, iterations, fens, start_time, current_time):
     epoch_time = current_time - start_time
-    message = ("\nepoch {:<2} | time: {:.2f} s | epoch loss: {:.4f} | speed: {:.2f} pos/s"
+    message = ("\nepoch {:<2} | time: {:.2f} s | epoch loss: {:.7f} | speed: {:.2f} pos/s"
               .format(epoch, epoch_time, running_loss.item() / iterations, fens / epoch_time))
     print(message)

@@ -20,7 +20,7 @@ def save_checkpoint(model, optimizer, epoch, loss, filename):
     }
     torch.save(checkpoint, filename)

-def load_checkpoint(model, optimizer, filename, resume_training=False):
+def load_checkpoint(model, optimizer, filename):
     checkpoint = torch.load(filename)
     model.load_state_dict(checkpoint['model_state_dict'])
     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
@@ -57,10 +57,12 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
         iterations = 0
         fens = 0

-        model.eval("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR", device)
-
         quantize(model, f"network/nnue_{epoch}_scaled.bin")

+        save_checkpoint(model, optimizer, epoch, running_loss, "checkpoint.pth")
+
+        model.eval("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR", device)
+
         optimizer.zero_grad()
         prediction = model(batch)

@@ -74,4 +76,12 @@
         fens += batch.size

         if fens % 163_840 == 0:
-            print("\rTotal fens parsed in this epoch:", fens, end='', flush=True)
+            epoch_time = time() - epoch_start_time
+            formatted_fens = "{0:_}".format(fens)
+            formatted_speed = "{0:_}".format(int(fens / epoch_time))
+            print("\rTotal fens parsed in this epoch: {}, Time: {:.2f} s, Speed: {} pos/s"
+                  .format(formatted_fens, epoch_time, formatted_speed), end='', flush=True)
+
+        if fens % 99_942_400 == 0:
+            print_epoch_stats(epoch, running_loss, iterations, fens, epoch_start_time, time())
+            model.eval("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR", device)
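
The new progress line uses Python's underscore grouping option in format specs ("{0:_}"), which inserts _ as a thousands separator, and computes the speed from the elapsed epoch time. A short illustration of the formatting (standalone snippet, not taken from the repository):

# "_" in a format spec groups digits by thousands (available since Python 3.6).
fens = 1_638_400
epoch_time = 12.5
print("{0:_}".format(fens))                    # 1_638_400
print("{0:_}".format(int(fens / epoch_time)))  # 131_072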
