Skip to content

Commit 5ded834

Browse files
committed
Improve printing
1 parent ebde751 commit 5ded834

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

trainer/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def eval(self, fen, device):
6262
hidden_features = torch.cat((stm_perspective, nstm_perspective))
6363
hidden_features = self.screlu(hidden_features)
6464

65-
print(int((torch.special.logit(torch.sigmoid(self.output_layer(hidden_features))) * 400).item()))
65+
print(self.output_layer(hidden_features) * 400)
6666

6767
def clamp_weights(self):
6868
self.feature_transformer.weight.data.clamp_(-1.27, 1.27)

trainer/train.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
5757
iterations = 0
5858
fens = 0
5959

60-
model.eval("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR", device)
61-
6260
quantize(model, f"network/nnue_{epoch}_scaled.bin")
6361

6462
save_checkpoint(model, optimizer, epoch, running_loss, "checkpoint.pth")
6563

64+
model.eval("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR", device)
65+
6666
optimizer.zero_grad()
6767
prediction = model(batch)
6868

@@ -79,7 +79,9 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
7979
epoch_time = time() - epoch_start_time
8080
formatted_fens = "{0:_}".format(fens)
8181
formatted_speed = "{0:_}".format(int(fens / epoch_time))
82-
print("\rTotal fens parsed in this epoch: {}, Time: {:.2f} s, Speed: {} pos/s".format(formatted_fens,
83-
epoch_time,
84-
formatted_speed),
85-
end='', flush=True)
82+
print("\rTotal fens parsed in this epoch: {}, Time: {:.2f} s, Speed: {} pos/s"
83+
.format(formatted_fens, epoch_time, formatted_speed), end='', flush=True)
84+
85+
if fens % 99_942_400 == 0:
86+
print_epoch_stats(epoch, running_loss, iterations, fens, epoch_start_time, time())
87+
model.eval("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR", device)

0 commit comments

Comments
 (0)