77
88def print_epoch_stats (epoch , running_loss , iterations , fens , start_time , current_time ):
99 epoch_time = current_time - start_time
10- message = ("\n epoch {:<2} | time: {:.2f} s | epoch loss: {:.4f } | speed: {:.2f} pos/s"
10+ message = ("\n epoch {:<2} | time: {:.2f} s | epoch loss: {:.7f } | speed: {:.2f} pos/s"
1111 .format (epoch , epoch_time , running_loss .item () / iterations , fens / epoch_time ))
1212 print (message )
1313
@@ -20,7 +20,7 @@ def save_checkpoint(model, optimizer, epoch, loss, filename):
2020 }
2121 torch .save (checkpoint , filename )
2222
23- def load_checkpoint (model , optimizer , filename , resume_training = False ):
23+ def load_checkpoint (model , optimizer , filename ):
2424 checkpoint = torch .load (filename )
2525 model .load_state_dict (checkpoint ['model_state_dict' ])
2626 optimizer .load_state_dict (checkpoint ['optimizer_state_dict' ])
@@ -61,6 +61,8 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
6161
6262 quantize (model , f"network/nnue_{ epoch } _scaled.bin" )
6363
64+ save_checkpoint (model , optimizer , epoch , running_loss , "checkpoint.pth" )
65+
6466 optimizer .zero_grad ()
6567 prediction = model (batch )
6668
@@ -74,4 +76,10 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
7476 fens += batch .size
7577
7678 if fens % 163_840 == 0 :
77- print ("\r Total fens parsed in this epoch:" , fens , end = '' , flush = True )
79+ epoch_time = time () - epoch_start_time
80+ formatted_fens = "{0:_}" .format (fens )
81+ formatted_speed = "{0:_}" .format (int (fens / epoch_time ))
82+ print ("\r Total fens parsed in this epoch: {}, Time: {:.2f} s, Speed: {} pos/s" .format (formatted_fens ,
83+ epoch_time ,
84+ formatted_speed ),
85+ end = '' , flush = True )
0 commit comments