11import argparse
2- import model as M
2+ import model
33import nnue_dataset
44import torch
55import time
@@ -61,7 +61,7 @@ def calculate_validation_loss(nnue, val_data_loader, wdl):
6161 for k , sample in enumerate (val_data_loader ):
6262 us , them , white , black , outcome , score = sample
6363 pred = nnue (us , them , white , black )
64- loss = M .loss_function (wdl , pred , sample )
64+ loss = model .loss_function (wdl , pred , sample )
6565 val_loss .append (loss )
6666
6767 val_loss = torch .mean (torch .tensor (val_loss ))
@@ -74,7 +74,7 @@ def train_step(nnue, sample, optimizer, wdl, epoch, idx, num_batches):
7474 us , them , white , black , outcome , score = sample
7575
7676 pred = nnue (us , them , white , black )
77- loss = M .loss_function (wdl , pred , sample )
77+ loss = model .loss_function (wdl , pred , sample )
7878 loss .backward ()
7979 optimizer .step ()
8080 nnue .zero_grad ()
@@ -119,6 +119,8 @@ def main(args):
119119 print (f'Batch size: { args .batch_size } ' )
120120 print (f'WDL: { args .wdl } ' )
121121 print (f'Validation check interval: { args .val_check_interval } ' )
122+ if args .resume :
123+ print (f'Resuming training from { args .resume } ' )
122124 if args .log :
123125 print (f'Logs written to: { output_path } ' )
124126 print (f'Data written to: { output_path } ' )
@@ -134,7 +136,9 @@ def main(args):
134136 train_data_loader , val_data_loader = create_data_loaders (args .train , args .val , train_size , val_size , args .batch_size , main_device )
135137
136138 # Create model
137- nnue = M .NNUE ().to (main_device )
139+ nnue = model .NNUE ().to (main_device )
140+ if args .resume :
141+ nnue .load_state_dict (torch .load (args .resume ))
138142
139143 # Configure optimizer
140144 optimizer = torch .optim .RAdam (nnue .parameters (), lr = 1e-3 , betas = (.95 , 0.999 ), eps = 1e-5 , weight_decay = 0 )
@@ -191,6 +195,8 @@ def main(args):
191195 parser .add_argument ('--val-check-interval' , default = 2000 , type = int , help = 'How often to check validation loss (default=2000)' )
192196 parser .add_argument ('--log' , action = 'store_true' , help = 'Enable logging during training' )
193197 parser .add_argument ('--top-n' , default = 2 , type = int , help = 'Number of models to save for each epoch (default=2)' )
198+ parser .add_argument ('--resume' ,
199+ help = 'Resume training from an existing snapshot' )
194200 args = parser .parse_args ()
195201
196202 main (args )
0 commit comments