 
 def print_epoch_stats(epoch, running_loss, iterations, fens, start_time, current_time):
     epoch_time = current_time - start_time
-    message = ("epoch {:<2} | time: {:.2f} s | epoch loss: {:.4f} | speed: {:.2f} pos/s"
+    message = ("\n epoch {:<2} | time: {:.2f} s | epoch loss: {:.4f} | speed: {:.2f} pos/s"
                .format(epoch, epoch_time, running_loss.item() / iterations, fens / epoch_time))
     print(message)
 
-def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader: BatchLoader, epochs: int, lr_drop_steps: int, device: torch.device):
+def save_checkpoint(model, optimizer, epoch, loss, filename):
+    checkpoint = {
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'loss': loss,
+    }
+    torch.save(checkpoint, filename)
+
+def load_checkpoint(model, optimizer, filename):
+    checkpoint = torch.load(filename)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+    epoch = checkpoint['epoch']
+    loss = checkpoint['loss']
+    return model, optimizer, epoch, loss
+
+def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader: BatchLoader, epochs: int, lr_drop_steps: int, device: torch.device, resume_training: bool = False):
+    if resume_training:
+        model, optimizer, start_epoch, best_loss = load_checkpoint(model, optimizer, "checkpoint.pth")
+    else:
+        start_epoch = 0
+
     running_loss = torch.zeros(1, device=device)
     epoch_start_time = time()
     iterations = 0
     fens = 0
-    epoch = 0
+    epoch = start_epoch
 
     while epoch < epochs:
         new_epoch, batch = dataloader.next_batch(device)
         if new_epoch:
             epoch += 1
 
-            if epoch % lr_drop_steps == 0:
-                optimizer.param_groups[0]["lr"] *= 0.1
-
             current_time = time()
             print_epoch_stats(epoch, running_loss, iterations, fens, epoch_start_time, current_time)
 
+            if epoch % lr_drop_steps == 0:
+                optimizer.param_groups[0]["lr"] *= 0.1
+                print("LR dropped")
+
             running_loss = torch.zeros(1, device=device)
             epoch_start_time = current_time
             iterations = 0
             fens = 0
 
+            # assumption: the model defines its own eval(fen, device) helper that
+            # scores a position (shadowing nn.Module.eval), used here to spot-check
+            # the starting position after each epoch
+            model.eval("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR", device)
+
             quantize(model, f"network/nnue_{epoch}_scaled.bin")
 
         optimizer.zero_grad()
@@ -47,3 +72,6 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
         running_loss += loss
         iterations += 1
         fens += batch.size
+
+        # progress ticker; note this only fires when the batch size divides 163,840 evenly
+        if fens % 163_840 == 0:
+            print("\r Total fens parsed in this epoch:", fens, end='', flush=True)
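
For reference, a minimal sketch of how the new resume path might be driven from a training script. Net, the BatchLoader arguments, the hyperparameters, and the assumption that save_checkpoint is wired into the elided part of the loop are all illustrative, not part of this commit:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net().to(device)                            # hypothetical model class
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    dataloader = BatchLoader("data/train.bin", 16_384)  # assumed constructor

    # fresh run
    train(model, optimizer, dataloader, epochs=40, lr_drop_steps=15, device=device)

    # interrupted run: restore weights, optimizer state, and the epoch counter
    # from the hard-coded checkpoint.pth, then continue until the epoch budget is spent
    train(model, optimizer, dataloader, epochs=40, lr_drop_steps=15, device=device,
          resume_training=True)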
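One caveat with the helper as written: torch.load restores tensors onto the device they were saved from, so a checkpoint written on a GPU machine cannot be opened on a CPU-only one without map_location. A hedged variant that threads the target device through:

    def load_checkpoint(model, optimizer, filename, device):
        # map_location keeps the restore device-agnostic
        checkpoint = torch.load(filename, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return model, optimizer, checkpoint['epoch'], checkpoint['loss']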