33import nnue_dataset
44import torch
55import time
6+ import os
67import os .path
78from datetime import timedelta
89from torch .utils .data import DataLoader , Dataset
910from torch .utils .tensorboard import SummaryWriter
1011
1112BIN_SAMPLE_SIZE = 40
1213OUTPUT_DIR = 'output'
13- LATEST_LAST_PATH = ''
14- LATEST_BEST_PATH = ''
15- LATEST_EPOCH_PATH = ''
14+
15+
def filter_saved_models(saved_models, top_n):
    """Keep the `top_n` lowest-loss snapshots and delete the others.

    Args:
        saved_models: List of `(val_loss, path)` entries, one per snapshot
            written during the epoch. The list is sorted in place.
        top_n: Maximum number of snapshots to keep. Values larger than the
            number of snapshots keep all of them; values <= 0 keep none.

    Kept snapshots are renamed by dropping their last extension (the
    `.tmp` suffix); every other snapshot file is removed from disk.
    """
    # Sort saved models based on loss (ties fall back to comparing paths)
    saved_models.sort()

    # Clamp the count so an epoch with fewer than top_n snapshots (or a
    # negative top_n) cannot index past the list or remove a file twice.
    keep = max(0, min(top_n, len(saved_models)))

    # Rename the best saved models (strip the .tmp extension)
    for entry in saved_models[:keep]:
        old_path = entry[1]
        new_path = os.path.splitext(old_path)[0]
        os.rename(old_path, new_path)

    # Remove remaining saved models for this epoch
    for entry in saved_models[keep:]:
        os.remove(entry[1])
1629
1730
1831def write_model (nnue , path ):
@@ -25,32 +38,14 @@ def write_model(nnue, path):
2538 f .write (buf )
2639
2740
def save_model(nnue, output_path, epoch, idx, val_loss):
    """Write a snapshot of `nnue` to a temporary file and return its path.

    The file name encodes the epoch, the 1-based iteration (`idx + 1`) and
    the validation loss formatted to five decimals. It ends in `.bin.tmp`;
    `filter_saved_models` later strips the `.tmp` suffix from the snapshots
    it keeps.
    """
    # Build the file name first, then place it under the output directory
    file_name = f'epoch_{epoch}_iter_{idx + 1}_loss_{val_loss:.5f}.bin.tmp'
    snapshot_path = f'{output_path}/{file_name}'

    # Serialize the network to disk
    write_model(nnue, snapshot_path)

    return snapshot_path
5449
5550
5651def prepare_output_directory ():
@@ -158,18 +153,18 @@ def main(args):
158153 running_train_loss = 0.0
159154 while True :
160155 best_val_loss = 1000000.0
156+ saved_models = []
161157
162158 for k , sample in enumerate (train_data_loader ):
163159 train_loss = train_step (nnue , sample , optimizer , args .lambda_ , epoch , k , num_batches )
164160 running_train_loss += train_loss .item ()
165161
166162 if k % args .val_check_interval == (args .val_check_interval - 1 ):
167163 val_loss = calculate_validation_loss (nnue , val_data_loader , args .lambda_ )
168- new_best = False
169164 if (val_loss < best_val_loss ):
170- new_best = True
171165 best_val_loss = val_loss
172- save_model (nnue , output_path , epoch , k , val_loss , new_best , False )
166+ path = save_model (nnue , output_path , epoch , k , val_loss )
167+ saved_models .append ((val_loss , path ))
173168 if args .log :
174169 writer .add_scalar ('training loss' , running_train_loss / args .val_check_interval , epoch * num_batches + k )
175170 writer .add_scalar ('validation loss' , val_loss , epoch * num_batches + k )
@@ -180,10 +175,15 @@ def main(args):
180175 if (val_loss < best_val_loss ):
181176 new_best = True
182177 best_val_loss = val_loss
183- save_model (nnue , output_path , epoch , num_batches - 1 , val_loss , new_best , True )
178+ path = save_model (nnue , output_path , epoch , num_batches - 1 , val_loss )
179+ saved_models .append ((val_loss , path ))
184180 stop = time .monotonic ()
185181 print (f' ({ timedelta (seconds = stop - start )} )' )
186182
183+ # Only keep the best snapshots (based on validation loss)
184+ filter_saved_models (saved_models , args .top_n )
185+
186+ # Update learning rate scheduler
187187 scheduler .step (best_val_loss )
188188 epoch += 1
189189
@@ -196,6 +196,7 @@ def main(args):
196196 parser .add_argument ('--batch-size' , default = 16384 , type = int , help = 'Number of positions per batch / per iteration (default=16384)' )
197197 parser .add_argument ('--val-check-interval' , default = 2000 , type = int , help = 'How often to check validation loss (default=2000)' )
198198 parser .add_argument ('--log' , action = 'store_true' , help = 'Enable logging during training' )
199+ parser .add_argument ('--top-n' , default = 2 , type = int , help = 'Number of models to save for each epoch (default=2)' )
199200 args = parser .parse_args ()
200201
201202 main (args )
0 commit comments