@@ -58,10 +58,10 @@ def calculate_validation_loss(nnue, val_data_loader, wdl):
5858 nnue .eval ()
5959 with torch .no_grad ():
6060 val_loss = []
61- for k , sample in enumerate (val_data_loader ):
62- us , them , white , black , outcome , score = sample
61+ for k , batch in enumerate (val_data_loader ):
62+ us , them , white , black , outcome , score = batch
6363 pred = nnue (us , them , white , black )
64- loss = model .loss_function (wdl , pred , sample )
64+ loss = model .loss_function (wdl , pred , batch )
6565 val_loss .append (loss )
6666
6767 val_loss = torch .mean (torch .tensor (val_loss ))
@@ -70,11 +70,11 @@ def calculate_validation_loss(nnue, val_data_loader, wdl):
7070 return val_loss
7171
7272
73- def train_step (nnue , sample , optimizer , wdl , epoch , idx , num_batches ):
74- us , them , white , black , outcome , score = sample
73+ def train_step (nnue , batch , optimizer , wdl , epoch , idx , num_batches ):
74+ us , them , white , black , outcome , score = batch
7575
7676 pred = nnue (us , them , white , black )
77- loss = model .loss_function (wdl , pred , sample )
77+ loss = model .loss_function (wdl , pred , batch )
7878 loss .backward ()
7979 optimizer .step ()
8080 nnue .zero_grad ()
@@ -153,8 +153,8 @@ def main(args):
153153 best_val_loss = 1000000.0
154154 saved_models = []
155155
156- for k , sample in enumerate (train_data_loader ):
157- train_loss = train_step (nnue , sample , optimizer , args .wdl , epoch , k , num_batches )
156+ for k , batch in enumerate (train_data_loader ):
157+ train_loss = train_step (nnue , batch , optimizer , args .wdl , epoch , k , num_batches )
158158 running_train_loss += train_loss .item ()
159159
160160 if k % args .val_check_interval == (args .val_check_interval - 1 ):
0 commit comments