Skip to content

Commit fafae62

Browse files
authored
Merge pull request #5 from martinnovaak/resume-train
Add print of startpos eval after each epoch
2 parents da0ca67 + a59f506 commit fafae62

File tree

4 files changed

+71
-8
lines changed

4 files changed

+71
-8
lines changed

trainer/config.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@
88
"wdl": 0.5,
99
"lr_drop_steps" : 4,
1010
"scale": 400,
11-
"hidden_layer_size": 16
11+
"hidden_layer_size": 16,
12+
"resume_training": false
1213
}

trainer/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def load_config(config_path="config.json"):
2727
wdl = config.get("wdl", 0.5)
2828
lr_drop_steps = config.get("lr_drop_steps", 10)
2929
scale = config.get("scale")
30+
resume_training = config.get("resume_training", False)
3031

3132
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
3233
model = PerspectiveNetwork(config["hidden_layer_size"]).to(device)
@@ -39,7 +40,7 @@ def load_config(config_path="config.json"):
3940

4041
start_time = time()
4142

42-
train(model, optimizer, dataloader, epochs, lr_drop_steps, device)
43+
train(model, optimizer, dataloader, epochs, lr_drop_steps, device, resume_training)
4344

4445
end_time = time()
4546
elapsed_time = end_time - start_time

trainer/model.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,39 @@ def forward(self, batch: Batch):
3131

3232
return torch.sigmoid(self.output_layer(hidden_features))
3333

34+
def eval(self, fen, device):
    """Print and return this network's evaluation of a FEN position,
    in centipawns (scale 400), from the side-to-move perspective.

    NOTE(review): this overrides ``torch.nn.Module.eval()``, so calling
    ``model.eval()`` with no arguments to enter inference mode no longer
    works for this class. Consider renaming (e.g. ``evaluate_fen``) in a
    follow-up; the name is kept here because existing callers use it.

    Only the piece-placement field of the FEN is read; the encoding treats
    the position as if white is the side to move (startpos probing assumes
    this — TODO confirm for black-to-move FENs).

    fen    -- FEN string; everything after the first space is ignored.
    device -- torch device on which the feature tensors are created.

    Returns the evaluation as an ``int`` (previously it was only printed;
    callers that ignore the return value are unaffected).
    """
    board = fen.split(" ")[0]

    # 768 = 2 colors * 6 piece types * 64 squares; one-hot piece features
    # for the side-to-move and not-side-to-move perspectives.
    stm_features = torch.zeros(768, device=device)
    nstm_features = torch.zeros(768, device=device)

    piece_index = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5}

    for rank_idx, rank in enumerate(board.split('/')):
        file_idx = 0
        for char in rank:
            if char.isdigit():
                # A digit encodes a run of empty squares.
                file_idx += int(char)
            else:
                # FEN ranks run from 8 down to 1, hence 7 - rank_idx.
                sq = 8 * (7 - rank_idx) + file_idx
                piece_type = piece_index[char.lower()]
                piece_color = 1 if char.islower() else 0  # lowercase = black

                stm_features[piece_color * 384 + piece_type * 64 + sq] = 1
                # Opponent view: flip color and mirror the square vertically
                # (sq ^ 56 flips the rank).
                nstm_features[(1 - piece_color) * 384 + piece_type * 64 + (sq ^ 56)] = 1

                file_idx += 1

    # Fixes vs. the original: no autograd graph is built for a probe
    # evaluation (torch.no_grad), and the no-op .to_dense() calls on
    # already-dense tensors were removed.
    with torch.no_grad():
        stm_perspective = self.feature_transformer(stm_features)
        nstm_perspective = self.feature_transformer(nstm_features)

        hidden_features = self.screlu(torch.cat((stm_perspective, nstm_perspective)))

        # logit(sigmoid(x)) == x, so scale the raw network output directly
        # instead of round-tripping through sigmoid/logit (avoids needless
        # float error near saturation).
        score = int((self.output_layer(hidden_features) * 400).item())

    print(score)
    return score
66+
3467
def clamp_weights(self):
    """Clamp every layer's weights into [-1.27, 1.27] in place.

    Presumably this keeps weights representable by the quantized network
    format written by quantize() — confirm against the exporter.
    """
    for layer in (self.feature_transformer, self.output_layer):
        layer.weight.data.clamp_(-1.27, 1.27)

trainer/train.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,33 +7,58 @@
77

88
def print_epoch_stats(epoch, running_loss, iterations, fens, start_time, current_time):
    """Print a one-line summary for a finished epoch.

    running_loss -- one-element tensor accumulated over the epoch
    iterations   -- number of batches in the epoch (divides the loss)
    fens         -- number of positions seen (drives the pos/s figure)
    start_time, current_time -- wall-clock bounds of the epoch, in seconds
    """
    elapsed = current_time - start_time
    avg_loss = running_loss.item() / iterations
    speed = fens / elapsed
    print(f"\nepoch {epoch:<2} | time: {elapsed:.2f} s | epoch loss: {avg_loss:.4f} | speed: {speed:.2f} pos/s")
1313

14-
def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader: BatchLoader, epochs: int, lr_drop_steps: int, device: torch.device):
14+
def save_checkpoint(model, optimizer, epoch, loss, filename):
    """Serialize training state to *filename* via torch.save.

    The checkpoint holds the model and optimizer state dicts plus the
    epoch counter and loss, matching what load_checkpoint() expects.
    """
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        },
        filename,
    )
22+
23+
def load_checkpoint(model, optimizer, filename, resume_training=False, map_location=None):
    """Restore training state written by save_checkpoint().

    model, optimizer -- restored in place (also returned for convenience).
    filename         -- path to a checkpoint produced by save_checkpoint().
    resume_training  -- unused; kept only for backward compatibility with
                        existing callers (candidate for removal).
    map_location     -- forwarded to torch.load so a checkpoint saved on one
                        device (e.g. CUDA) can be restored on another
                        (e.g. CPU). Default None preserves old behavior.

    Returns (model, optimizer, epoch, loss) taken from the checkpoint.
    """
    checkpoint = torch.load(filename, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return model, optimizer, epoch, loss
30+
31+
def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader: BatchLoader, epochs: int, lr_drop_steps: int, device: torch.device, resume_training: bool = False):
32+
if resume_training:
33+
model, optimizer, start_epoch, best_loss = load_checkpoint(model, optimizer, "checkpoint.pth")
34+
else:
35+
start_epoch = 0
36+
1537
running_loss = torch.zeros(1, device=device)
1638
epoch_start_time = time()
1739
iterations = 0
1840
fens = 0
19-
epoch = 0
41+
epoch = start_epoch
2042

2143
while epoch < epochs:
2244
new_epoch, batch = dataloader.next_batch(device)
2345
if new_epoch:
2446
epoch += 1
2547

26-
if epoch % lr_drop_steps == 0:
27-
optimizer.param_groups[0]["lr"] *= 0.1
28-
2948
current_time = time()
3049
print_epoch_stats(epoch, running_loss, iterations, fens, epoch_start_time, current_time)
3150

51+
if epoch % lr_drop_steps == 0:
52+
optimizer.param_groups[0]["lr"] *= 0.1
53+
print("LR dropped")
54+
3255
running_loss = torch.zeros(1, device=device)
3356
epoch_start_time = current_time
3457
iterations = 0
3558
fens = 0
3659

60+
model.eval("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR", device)
61+
3762
quantize(model, f"network/nnue_{epoch}_scaled.bin")
3863

3964
optimizer.zero_grad()
@@ -47,3 +72,6 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
4772
running_loss += loss
4873
iterations += 1
4974
fens += batch.size
75+
76+
if fens % 163_840 == 0:
77+
print("\rTotal fens parsed in this epoch:", fens, end='', flush=True)

0 commit comments

Comments
 (0)