77
88def print_epoch_stats (epoch , running_loss , iterations , fens , start_time , current_time ):
99 epoch_time = current_time - start_time
10- message = ("\n epoch {:<2} | time: {:.2f} s | epoch loss: {:.4f } | speed: {:.2f} pos/s"
10+ message = ("\n epoch {:<2} | time: {:.2f} s | epoch loss: {:.7f } | speed: {:.2f} pos/s"
1111 .format (epoch , epoch_time , running_loss .item () / iterations , fens / epoch_time ))
1212 print (message )
1313
@@ -20,7 +20,7 @@ def save_checkpoint(model, optimizer, epoch, loss, filename):
2020 }
2121 torch .save (checkpoint , filename )
2222
23- def load_checkpoint (model , optimizer , filename , resume_training = False ):
23+ def load_checkpoint (model , optimizer , filename ):
2424 checkpoint = torch .load (filename )
2525 model .load_state_dict (checkpoint ['model_state_dict' ])
2626 optimizer .load_state_dict (checkpoint ['optimizer_state_dict' ])
@@ -57,10 +57,12 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
5757 iterations = 0
5858 fens = 0
5959
60- model .eval ("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR" , device )
61-
6260 quantize (model , f"network/nnue_{ epoch } _scaled.bin" )
6361
62+ save_checkpoint (model , optimizer , epoch , running_loss , "checkpoint.pth" )
63+
64+ model .eval ("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR" , device )
65+
6466 optimizer .zero_grad ()
6567 prediction = model (batch )
6668
@@ -74,4 +76,12 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
7476 fens += batch .size
7577
7678 if fens % 163_840 == 0 :
77- print ("\r Total fens parsed in this epoch:" , fens , end = '' , flush = True )
79+ epoch_time = time () - epoch_start_time
80+ formatted_fens = "{0:_}" .format (fens )
81+ formatted_speed = "{0:_}" .format (int (fens / epoch_time ))
82+ print ("\r Total fens parsed in this epoch: {}, Time: {:.2f} s, Speed: {} pos/s"
83+ .format (formatted_fens , epoch_time , formatted_speed ), end = '' , flush = True )
84+
85+ if fens % 99_942_400 == 0 :
86+ print_epoch_stats (epoch , running_loss , iterations , fens , epoch_start_time , time ())
87+ model .eval ("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR" , device )
0 commit comments