Skip to content

Commit dbfdd8f

Browse files
Authored commit dbfdd8f — merge pull request #7 from martinnovaak/logging: "Improve logging while training" (2 parents: 2c0eded + 5ded834).

File tree

2 files changed: +16 −6 lines changed

trainer/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def eval(self, fen, device):
6262
hidden_features = torch.cat((stm_perspective, nstm_perspective))
6363
hidden_features = self.screlu(hidden_features)
6464

65-
print(int((torch.special.logit(torch.sigmoid(self.output_layer(hidden_features))) * 400).item()))
65+
print(self.output_layer(hidden_features) * 400)
6666

6767
def clamp_weights(self):
6868
self.feature_transformer.weight.data.clamp_(-1.27, 1.27)

trainer/train.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
def print_epoch_stats(epoch, running_loss, iterations, fens, start_time, current_time):
99
epoch_time = current_time - start_time
10-
message = ("\nepoch {:<2} | time: {:.2f} s | epoch loss: {:.4f} | speed: {:.2f} pos/s"
10+
message = ("\nepoch {:<2} | time: {:.2f} s | epoch loss: {:.7f} | speed: {:.2f} pos/s"
1111
.format(epoch, epoch_time, running_loss.item() / iterations, fens / epoch_time))
1212
print(message)
1313

@@ -20,7 +20,7 @@ def save_checkpoint(model, optimizer, epoch, loss, filename):
2020
}
2121
torch.save(checkpoint, filename)
2222

23-
def load_checkpoint(model, optimizer, filename, resume_training=False):
23+
def load_checkpoint(model, optimizer, filename):
2424
checkpoint = torch.load(filename)
2525
model.load_state_dict(checkpoint['model_state_dict'])
2626
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
@@ -57,10 +57,12 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
5757
iterations = 0
5858
fens = 0
5959

60-
model.eval("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR", device)
61-
6260
quantize(model, f"network/nnue_{epoch}_scaled.bin")
6361

62+
save_checkpoint(model, optimizer, epoch, running_loss, "checkpoint.pth")
63+
64+
model.eval("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR", device)
65+
6466
optimizer.zero_grad()
6567
prediction = model(batch)
6668

@@ -74,4 +76,12 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
7476
fens += batch.size
7577

7678
if fens % 163_840 == 0:
77-
print("\rTotal fens parsed in this epoch:", fens, end='', flush=True)
79+
epoch_time = time() - epoch_start_time
80+
formatted_fens = "{0:_}".format(fens)
81+
formatted_speed = "{0:_}".format(int(fens / epoch_time))
82+
print("\rTotal fens parsed in this epoch: {}, Time: {:.2f} s, Speed: {} pos/s"
83+
.format(formatted_fens, epoch_time, formatted_speed), end='', flush=True)
84+
85+
if fens % 99_942_400 == 0:
86+
print_epoch_stats(epoch, running_loss, iterations, fens, epoch_start_time, time())
87+
model.eval("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR", device)

0 commit comments

Comments (0)