Skip to content

Commit 9d72441

Browse files
committed
train.py: Added the --resume option
The --resume option can be used to resume training from an existing model.
1 parent ae5aab4 commit 9d72441

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

train.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import argparse
2-
import model as M
2+
import model
33
import nnue_dataset
44
import torch
55
import time
@@ -61,7 +61,7 @@ def calculate_validation_loss(nnue, val_data_loader, wdl):
6161
for k, sample in enumerate(val_data_loader):
6262
us, them, white, black, outcome, score = sample
6363
pred = nnue(us, them, white, black)
64-
loss = M.loss_function(wdl, pred, sample)
64+
loss = model.loss_function(wdl, pred, sample)
6565
val_loss.append(loss)
6666

6767
val_loss = torch.mean(torch.tensor(val_loss))
@@ -74,7 +74,7 @@ def train_step(nnue, sample, optimizer, wdl, epoch, idx, num_batches):
7474
us, them, white, black, outcome, score = sample
7575

7676
pred = nnue(us, them, white, black)
77-
loss = M.loss_function(wdl, pred, sample)
77+
loss = model.loss_function(wdl, pred, sample)
7878
loss.backward()
7979
optimizer.step()
8080
nnue.zero_grad()
@@ -119,6 +119,8 @@ def main(args):
119119
print(f'Batch size: {args.batch_size}')
120120
print(f'WDL: {args.wdl}')
121121
print(f'Validation check interval: {args.val_check_interval}')
122+
if args.resume:
123+
print(f'Resuming training from {args.resume}')
122124
if args.log:
123125
print(f'Logs written to: {output_path}')
124126
print(f'Data written to: {output_path}')
@@ -134,7 +136,9 @@ def main(args):
134136
train_data_loader, val_data_loader = create_data_loaders(args.train, args.val, train_size, val_size, args.batch_size, main_device)
135137

136138
# Create model
137-
nnue = M.NNUE().to(main_device)
139+
nnue = model.NNUE().to(main_device)
140+
if args.resume:
141+
nnue.load_state_dict(torch.load(args.resume))
138142

139143
# Configure optimizer
140144
optimizer = torch.optim.RAdam(nnue.parameters(), lr=1e-3, betas=(.95, 0.999), eps=1e-5, weight_decay=0)
@@ -191,6 +195,8 @@ def main(args):
191195
parser.add_argument('--val-check-interval', default=2000, type=int, help='How often to check validation loss (default=2000)')
192196
parser.add_argument('--log', action='store_true', help='Enable logging during training')
193197
parser.add_argument('--top-n', default=2, type=int, help='Number of models to save for each epoch (default=2)')
198+
parser.add_argument('--resume',
199+
help='Resume training from an existing snapshot')
194200
args = parser.parse_args()
195201

196202
main(args)

0 commit comments

Comments
 (0)