
Commit d1ede87 (1 parent: ba72bdd)

train.py: Rename sample variable to batch

Rename the sample variable to batch to match what it actually holds: each
item yielded by the data loader is a whole batch, not a single sample.
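For context on the rename: a PyTorch DataLoader collates individual samples
into batches, so each item the loop receives is a whole batch of stacked
tensors. A minimal sketch of that behavior, using a toy dataset rather than
the one in train.py:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten toy samples of four features each, collated four at a time.
dataset = TensorDataset(torch.randn(10, 4), torch.randn(10, 1))
loader = DataLoader(dataset, batch_size=4)

for k, batch in enumerate(loader):
    features, targets = batch  # each element is a stacked tensor
    print(k, features.shape)   # 0 torch.Size([4, 4]), ... -- a batch, not one sample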

File tree: 1 file changed (+8 −8 lines)


train.py (8 additions, 8 deletions)
@@ -58,10 +58,10 @@ def calculate_validation_loss(nnue, val_data_loader, wdl):
     nnue.eval()
     with torch.no_grad():
         val_loss = []
-        for k, sample in enumerate(val_data_loader):
-            us, them, white, black, outcome, score = sample
+        for k, batch in enumerate(val_data_loader):
+            us, them, white, black, outcome, score = batch
             pred = nnue(us, them, white, black)
-            loss = model.loss_function(wdl, pred, sample)
+            loss = model.loss_function(wdl, pred, batch)
             val_loss.append(loss)
 
     val_loss = torch.mean(torch.tensor(val_loss))
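The context lines above are the standard PyTorch validation idiom: eval()
puts the network into inference mode and torch.no_grad() turns off gradient
tracking while the losses are collected. A self-contained sketch of the same
idiom, with a toy model standing in for the NNUE:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
val_batches = [torch.randn(2, 4) for _ in range(3)]

model.eval()              # inference behavior for layers like dropout
with torch.no_grad():     # no autograd bookkeeping during validation
    val_loss = [model(x).mean() for x in val_batches]

val_loss = torch.mean(torch.tensor(val_loss))  # same reduction as train.py
print(val_loss)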
@@ -70,11 +70,11 @@ def calculate_validation_loss(nnue, val_data_loader, wdl):
     return val_loss
 
 
-def train_step(nnue, sample, optimizer, wdl, epoch, idx, num_batches):
-    us, them, white, black, outcome, score = sample
+def train_step(nnue, batch, optimizer, wdl, epoch, idx, num_batches):
+    us, them, white, black, outcome, score = batch
 
     pred = nnue(us, them, white, black)
-    loss = model.loss_function(wdl, pred, sample)
+    loss = model.loss_function(wdl, pred, batch)
     loss.backward()
     optimizer.step()
     nnue.zero_grad()
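For reference, the lines touched in this hunk are the standard PyTorch
optimization step: compute the loss, backpropagate, apply the update, then
clear the gradients. A minimal stand-alone sketch of the same pattern (toy
model and optimizer, not the ones in train.py):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

loss = torch.nn.functional.mse_loss(model(inputs), targets)
loss.backward()     # accumulate gradients
optimizer.step()    # apply the parameter update
model.zero_grad()   # clear gradients before the next batch, as train.py does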
@@ -153,8 +153,8 @@ def main(args):
         best_val_loss = 1000000.0
         saved_models = []
 
-        for k, sample in enumerate(train_data_loader):
-            train_loss = train_step(nnue, sample, optimizer, args.wdl, epoch, k, num_batches)
+        for k, batch in enumerate(train_data_loader):
+            train_loss = train_step(nnue, batch, optimizer, args.wdl, epoch, k, num_batches)
             running_train_loss += train_loss.item()
 
             if k%args.val_check_interval == (args.val_check_interval-1):
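The final context line is worth a note: k % args.val_check_interval ==
(args.val_check_interval - 1) fires on every val_check_interval-th batch,
i.e. after a full interval of batches has been processed rather than on
batch 0. A quick illustration with a hypothetical interval of 3:

val_check_interval = 3
fired = [k for k in range(10) if k % val_check_interval == (val_check_interval - 1)]
print(fired)  # [2, 5, 8] -- validation runs after every third batch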
