Skip to content

Commit b7c5756

Browse files
committed
train.py: Add --train-size and --val-size arguments
1 parent ac8d6e1 commit b7c5756

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

train.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,11 @@ def main(args):
132132

133133
# Create data loaders
134134
train_size = int(os.path.getsize(args.train)/BIN_SAMPLE_SIZE)
135+
if args.train_size:
136+
train_size = min(args.train_size, train_size)
135137
val_size = int(os.path.getsize(args.val)/BIN_SAMPLE_SIZE)
138+
if args.val_size:
139+
val_size = min(args.val_size, val_size)
136140
train_data_loader, val_data_loader = create_data_loaders(args.train, args.val, train_size, val_size, args.batch_size, main_device)
137141

138142
# Create model
@@ -192,6 +196,8 @@ def main(args):
192196
parser.add_argument('val', help='Validation data (.bin)')
193197
parser.add_argument('--wdl', default=1.0, type=float, help='wdl=0.0 = train on evaluations, wdl=1.0 = train on game results, interpolates between (default=1.0)')
194198
parser.add_argument('--batch-size', default=16384, type=int, help='Number of positions per batch / per iteration (default=16384)')
199+
parser.add_argument('--train-size', type=int, help='Number of training samples to use (default=all)')
200+
parser.add_argument('--val-size', type=int, help='Number of validation samples to use (default=all)')
195201
parser.add_argument('--val-check-interval', default=2000, type=int, help='How often to check validation loss (default=2000)')
196202
parser.add_argument('--log', action='store_true', help='Enable logging during training')
197203
parser.add_argument('--top-n', default=3, type=int, help='Number of models to save for each epoch (default=3)')

0 commit comments

Comments
 (0)