@@ -57,12 +57,12 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
5757 iterations = 0
5858 fens = 0
5959
60- model .eval ("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR" , device )
61-
6260 quantize (model , f"network/nnue_{ epoch } _scaled.bin" )
6361
6462 save_checkpoint (model , optimizer , epoch , running_loss , "checkpoint.pth" )
6563
64+ model .eval ("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR" , device )
65+
6666 optimizer .zero_grad ()
6767 prediction = model (batch )
6868
@@ -79,7 +79,9 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
7979 epoch_time = time () - epoch_start_time
8080 formatted_fens = "{0:_}" .format (fens )
8181 formatted_speed = "{0:_}" .format (int (fens / epoch_time ))
82- print ("\r Total fens parsed in this epoch: {}, Time: {:.2f} s, Speed: {} pos/s" .format (formatted_fens ,
83- epoch_time ,
84- formatted_speed ),
85- end = '' , flush = True )
82+ print ("\r Total fens parsed in this epoch: {}, Time: {:.2f} s, Speed: {} pos/s"
83+ .format (formatted_fens , epoch_time , formatted_speed ), end = '' , flush = True )
84+
85+ if fens % 99_942_400 == 0 :
86+ print_epoch_stats (epoch , running_loss , iterations , fens , epoch_start_time , time ())
87+ model .eval ("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR" , device )
0 commit comments