Skip to content

Commit e78d25e

Browse files
committed
train.py: Save the n best models for each epoch
Instead of saving the best and last model for each epoch, save the n best ones (based on validation loss).
1 parent 5566a85 commit e78d25e

File tree

1 file changed

+32
-31
lines changed

1 file changed

+32
-31
lines changed

train.py

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,29 @@
33
import nnue_dataset
44
import torch
55
import time
6+
import os
67
import os.path
78
from datetime import timedelta
89
from torch.utils.data import DataLoader, Dataset
910
from torch.utils.tensorboard import SummaryWriter
1011

1112
BIN_SAMPLE_SIZE = 40
1213
OUTPUT_DIR = 'output'
13-
LATEST_LAST_PATH = ''
14-
LATEST_BEST_PATH = ''
15-
LATEST_EPOCH_PATH = ''
14+
15+
16+
def filter_saved_models(saved_models, top_n):
    """Keep the ``top_n`` best snapshots of an epoch and delete the rest.

    Args:
        saved_models: List of ``(val_loss, path)`` tuples, one per snapshot
            written during the epoch. Each path is expected to end in a
            temporary extension (e.g. ``.tmp``). The list is sorted in place.
        top_n: Maximum number of snapshots to keep. Values larger than the
            number of snapshots keep all of them; values <= 0 keep none.

    Raises:
        OSError: If renaming or removing one of the snapshot files fails.
    """
    # Sort saved models based on loss (lowest loss first)
    saved_models.sort()

    # Clamp the number of models to keep: an epoch may yield fewer snapshots
    # than top_n, and a negative top_n must not turn into negative indexing.
    keep = max(0, min(top_n, len(saved_models)))

    # Rename the best saved models (strip the trailing temporary extension)
    for model in saved_models[:keep]:
        old_path = model[1]
        new_path = os.path.splitext(old_path)[0]
        os.rename(old_path, new_path)

    # Remove remaining saved models for this epoch
    for model in saved_models[keep:]:
        os.remove(model[1])
1629

1730

1831
def write_model(nnue, path):
@@ -25,32 +38,14 @@ def write_model(nnue, path):
2538
f.write(buf)
2639

2740

28-
def save_model(nnue, output_path, epoch, idx, val_loss, new_best, epoch_end):
29-
global LATEST_LAST_PATH
30-
global LATEST_BEST_PATH
31-
global LATEST_EPOCH_PATH
41+
def save_model(nnue, output_path, epoch, idx, val_loss):
42+
# Construct the full path
43+
path = f'{output_path}/epoch_{epoch}_iter_{idx+1}_loss_{val_loss:.5f}.bin.tmp'
3244

33-
# Save the model as the latest version
34-
if os.path.exists(LATEST_LAST_PATH):
35-
os.remove(LATEST_LAST_PATH)
36-
last_path = f'{output_path}/last_epoch_{epoch}_iter_{idx+1}_loss_{val_loss:.5f}.bin'
37-
LATEST_LAST_PATH = last_path
38-
write_model(nnue, last_path)
45+
# Save the model
46+
write_model(nnue, path)
3947

40-
# Save the model as the new best version
41-
if new_best and not epoch_end:
42-
if os.path.exists(LATEST_BEST_PATH):
43-
os.remove(LATEST_BEST_PATH)
44-
best_path = f'{output_path}/best_epoch_{epoch}_iter_{idx+1}_loss_{val_loss:.5f}.bin'
45-
LATEST_BEST_PATH = best_path
46-
write_model(nnue, best_path)
47-
48-
# Save the model as the final version for this epoch
49-
if epoch_end:
50-
epoch_path = f'{output_path}/epoch_{epoch}_loss_{val_loss:.5f}.bin'
51-
LATEST_EPOCH_PATH = epoch_path
52-
LATEST_BEST_PATH = ''
53-
write_model(nnue, epoch_path)
48+
return path
5449

5550

5651
def prepare_output_directory():
@@ -158,18 +153,18 @@ def main(args):
158153
running_train_loss = 0.0
159154
while True:
160155
best_val_loss = 1000000.0
156+
saved_models = []
161157

162158
for k, sample in enumerate(train_data_loader):
163159
train_loss = train_step(nnue, sample, optimizer, args.lambda_, epoch, k, num_batches)
164160
running_train_loss += train_loss.item()
165161

166162
if k%args.val_check_interval == (args.val_check_interval-1):
167163
val_loss = calculate_validation_loss(nnue, val_data_loader, args.lambda_)
168-
new_best = False
169164
if (val_loss < best_val_loss):
170-
new_best = True
171165
best_val_loss = val_loss
172-
save_model(nnue, output_path, epoch, k, val_loss, new_best, False)
166+
path = save_model(nnue, output_path, epoch, k, val_loss)
167+
saved_models.append((val_loss, path))
173168
if args.log:
174169
writer.add_scalar('training loss', running_train_loss/args.val_check_interval, epoch*num_batches + k)
175170
writer.add_scalar('validation loss', val_loss, epoch*num_batches + k)
@@ -180,10 +175,15 @@ def main(args):
180175
if (val_loss < best_val_loss):
181176
new_best = True
182177
best_val_loss = val_loss
183-
save_model(nnue, output_path, epoch, num_batches-1, val_loss, new_best, True)
178+
path = save_model(nnue, output_path, epoch, num_batches-1, val_loss)
179+
saved_models.append((val_loss, path))
184180
stop = time.monotonic()
185181
print(f' ({timedelta(seconds=stop-start)})')
186182

183+
# Only keep the best snapshots (based on validation loss)
184+
filter_saved_models(saved_models, args.top_n)
185+
186+
# Update learning rate scheduler
187187
scheduler.step(best_val_loss)
188188
epoch += 1
189189

@@ -196,6 +196,7 @@ def main(args):
196196
parser.add_argument('--batch-size', default=16384, type=int, help='Number of positions per batch / per iteration (default=16384)')
197197
parser.add_argument('--val-check-interval', default=2000, type=int, help='How often to check validation loss (default=2000)')
198198
parser.add_argument('--log', action='store_true', help='Enable logging during training')
199+
parser.add_argument('--top-n', default=2, type=int, help='Number of models to save for each epoch (default=2)')
199200
args = parser.parse_args()
200201

201202
main(args)

0 commit comments

Comments
 (0)